Skip to content

Implement val_temporal_unit for deciding how datetimes and dates timestamps get validated. #1751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
40 changes: 33 additions & 7 deletions src/input/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::pyclass::CompareOp;
use pyo3::types::PyTuple;
use pyo3::types::{PyDate, PyDateTime, PyDelta, PyDeltaAccess, PyDict, PyTime, PyTzInfo};
use pyo3::IntoPyObjectExt;
use speedate::DateConfig;
use speedate::{
Date, DateTime, DateTimeConfig, Duration, MicrosecondsPrecisionOverflowBehavior, ParseError, Time, TimeConfig,
};
Expand All @@ -21,6 +22,7 @@ use super::Input;
use crate::errors::ToErrorValue;
use crate::errors::{ErrorType, ValError, ValResult};
use crate::tools::py_err;
use crate::validators::TemporalUnitMode;

#[cfg_attr(debug_assertions, derive(Debug))]
pub enum EitherDate<'py> {
Expand Down Expand Up @@ -324,8 +326,12 @@ impl<'py> EitherDateTime<'py> {
}
}

pub fn bytes_as_date<'py>(input: &(impl Input<'py> + ?Sized), bytes: &[u8]) -> ValResult<EitherDate<'py>> {
match Date::parse_bytes(bytes) {
pub fn bytes_as_date<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
mode: TemporalUnitMode,
) -> ValResult<EitherDate<'py>> {
match Date::parse_bytes_with_config(bytes, &DateConfig::builder().timestamp_unit(mode.into()).build()) {
Ok(date) => Ok(date.into()),
Err(err) => Err(ValError::new(
ErrorType::DateParsing {
Expand Down Expand Up @@ -364,6 +370,7 @@ pub fn bytes_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValResult<EitherDateTime<'py>> {
match DateTime::parse_bytes_with_config(
bytes,
Expand All @@ -372,7 +379,7 @@ pub fn bytes_as_datetime<'py>(
microseconds_precision_overflow_behavior: microseconds_overflow_behavior,
unix_timestamp_offset: Some(0),
},
..Default::default()
timestamp_unit: mode.into(),
},
) {
Ok(dt) => Ok(dt.into()),
Expand All @@ -390,6 +397,7 @@ pub fn int_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
timestamp: i64,
timestamp_microseconds: u32,
mode: TemporalUnitMode,
) -> ValResult<EitherDateTime<'py>> {
match DateTime::from_timestamp_with_config(
timestamp,
Expand All @@ -399,7 +407,7 @@ pub fn int_as_datetime<'py>(
unix_timestamp_offset: Some(0),
..Default::default()
},
..Default::default()
timestamp_unit: mode.into(),
},
) {
Ok(dt) => Ok(dt.into()),
Expand Down Expand Up @@ -427,12 +435,30 @@ macro_rules! nan_check {
};
}

pub fn float_as_datetime<'py>(input: &(impl Input<'py> + ?Sized), timestamp: f64) -> ValResult<EitherDateTime<'py>> {
pub fn float_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
timestamp: f64,
mode: TemporalUnitMode,
) -> ValResult<EitherDateTime<'py>> {
nan_check!(input, timestamp, DatetimeParsing);
let microseconds = timestamp.fract().abs() * 1_000_000.0;
let microseconds = match mode {
TemporalUnitMode::Seconds => timestamp.fract().abs() * 1_000_000.0,
TemporalUnitMode::Milliseconds => timestamp.fract().abs() * 1_000.0,
TemporalUnitMode::Infer => {
// Use the same watershed from speedate to determine if we treat the float as seconds or milliseconds.
// TODO: should we expose this from speedate?
if timestamp.abs() <= 20_000_000_000.0 {
// treat as seconds
timestamp.fract().abs() * 1_000_000.0
} else {
// treat as milliseconds
timestamp.fract().abs() * 1_000.0
}
}
};
// checking for extra digits in microseconds is unreliable with large floats,
// so we just round to the nearest microsecond
int_as_datetime(input, timestamp.floor() as i64, microseconds.round() as u32)
int_as_datetime(input, timestamp.floor() as i64, microseconds.round() as u32, mode)
}

pub fn date_as_datetime<'py>(date: &Bound<'py, PyDate>) -> PyResult<EitherDateTime<'py>> {
Expand Down
5 changes: 3 additions & 2 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use pyo3::{intern, prelude::*, IntoPyObjectExt};
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::py_err;
use crate::validators::ValBytesMode;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherComplex, EitherInt, EitherString};
Expand Down Expand Up @@ -158,7 +158,7 @@ pub trait Input<'py>: fmt::Debug {

fn validate_iter(&self) -> ValResult<GenericIterator<'static>>;

fn validate_date(&self, strict: bool) -> ValMatch<EitherDate<'py>>;
fn validate_date(&self, strict: bool, mode: TemporalUnitMode) -> ValMatch<EitherDate<'py>>;

fn validate_time(
&self,
Expand All @@ -170,6 +170,7 @@ pub trait Input<'py>: fmt::Debug {
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValMatch<EitherDateTime<'py>>;

fn validate_timedelta(
Expand Down
20 changes: 11 additions & 9 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::input::return_enums::EitherComplex;
use crate::lookup_key::{LookupKey, LookupPath};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
Expand Down Expand Up @@ -277,9 +277,9 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
}
}

fn validate_date(&self, _strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>> {
fn validate_date(&self, _strict: bool, mode: TemporalUnitMode) -> ValResult<ValidationMatch<EitherDate<'py>>> {
match self {
JsonValue::Str(v) => bytes_as_date(self, v.as_bytes()).map(ValidationMatch::strict),
JsonValue::Str(v) => bytes_as_date(self, v.as_bytes(), mode).map(ValidationMatch::strict),
_ => Err(ValError::new(ErrorTypeDefaults::DateType, self)),
}
}
Expand Down Expand Up @@ -313,13 +313,14 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValResult<ValidationMatch<EitherDateTime<'py>>> {
match self {
JsonValue::Str(v) => {
bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::strict)
bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior, mode).map(ValidationMatch::strict)
}
JsonValue::Int(v) if !strict => int_as_datetime(self, *v, 0).map(ValidationMatch::lax),
JsonValue::Float(v) if !strict => float_as_datetime(self, *v).map(ValidationMatch::lax),
JsonValue::Int(v) if !strict => int_as_datetime(self, *v, 0, mode).map(ValidationMatch::lax),
JsonValue::Float(v) if !strict => float_as_datetime(self, *v, mode).map(ValidationMatch::lax),
_ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)),
}
}
Expand Down Expand Up @@ -485,8 +486,8 @@ impl<'py> Input<'py> for str {
Ok(string_to_vec(self).into())
}

fn validate_date(&self, _strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>> {
bytes_as_date(self, self.as_bytes()).map(ValidationMatch::lax)
fn validate_date(&self, _strict: bool, mode: TemporalUnitMode) -> ValResult<ValidationMatch<EitherDate<'py>>> {
bytes_as_date(self, self.as_bytes(), mode).map(ValidationMatch::lax)
}

fn validate_time(
Expand All @@ -501,8 +502,9 @@ impl<'py> Input<'py> for str {
&self,
_strict: bool,
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValResult<ValidationMatch<EitherDateTime<'py>>> {
bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax)
bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior, mode).map(ValidationMatch::lax)
}

fn validate_timedelta(
Expand Down
14 changes: 8 additions & 6 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::tools::{extract_i64, safe_repr};
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::{create_decimal, get_decimal_type};
use crate::validators::Exactness;
use crate::validators::TemporalUnitMode;
use crate::validators::ValBytesMode;
use crate::ArgsKwargs;

Expand Down Expand Up @@ -494,7 +495,7 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
}
}

fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>> {
fn validate_date(&self, strict: bool, mode: TemporalUnitMode) -> ValResult<ValidationMatch<EitherDate<'py>>> {
if let Ok(date) = self.downcast_exact::<PyDate>() {
Ok(ValidationMatch::exact(date.clone().into()))
} else if self.is_instance_of::<PyDateTime>() {
Expand All @@ -515,7 +516,7 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
None
}
} {
bytes_as_date(self, bytes).map(ValidationMatch::lax)
bytes_as_date(self, bytes, mode).map(ValidationMatch::lax)
} else {
Err(ValError::new(ErrorTypeDefaults::DateType, self))
}
Expand Down Expand Up @@ -559,6 +560,7 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
&self,
strict: bool,
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValResult<ValidationMatch<EitherDateTime<'py>>> {
if let Ok(dt) = self.downcast_exact::<PyDateTime>() {
return Ok(ValidationMatch::exact(dt.clone().into()));
Expand All @@ -570,15 +572,15 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
if !strict {
return if let Ok(py_str) = self.downcast::<PyString>() {
let str = py_string_str(py_str)?;
bytes_as_datetime(self, str.as_bytes(), microseconds_overflow_behavior)
bytes_as_datetime(self, str.as_bytes(), microseconds_overflow_behavior, mode)
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior, mode)
} else if self.is_exact_instance_of::<PyBool>() {
Err(ValError::new(ErrorTypeDefaults::DatetimeType, self))
} else if let Some(int) = extract_i64(self) {
int_as_datetime(self, int, 0)
int_as_datetime(self, int, 0, mode)
} else if let Ok(float) = self.extract::<f64>() {
float_as_datetime(self, float)
float_as_datetime(self, float, mode)
} else if let Ok(date) = self.downcast::<PyDate>() {
Ok(date_as_datetime(date)?)
} else {
Expand Down
13 changes: 8 additions & 5 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::lookup_key::{LookupKey, LookupPath};
use crate::tools::safe_repr;
use crate::validators::complex::string_to_complex;
use crate::validators::decimal::create_decimal;
use crate::validators::ValBytesMode;
use crate::validators::{TemporalUnitMode, ValBytesMode};

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime,
Expand Down Expand Up @@ -201,9 +201,9 @@ impl<'py> Input<'py> for StringMapping<'py> {
Err(ValError::new(ErrorTypeDefaults::IterableType, self))
}

fn validate_date(&self, _strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>> {
fn validate_date(&self, _strict: bool, mode: TemporalUnitMode) -> ValResult<ValidationMatch<EitherDate<'py>>> {
match self {
Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes()).map(ValidationMatch::strict),
Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes(), mode).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DateType, self)),
}
}
Expand All @@ -224,10 +224,13 @@ impl<'py> Input<'py> for StringMapping<'py> {
&self,
_strict: bool,
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
mode: TemporalUnitMode,
) -> ValResult<ValidationMatch<EitherDateTime<'py>>> {
match self {
Self::String(s) => bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior)
.map(ValidationMatch::strict),
Self::String(s) => {
bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior, mode)
.map(ValidationMatch::strict)
}
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)),
}
}
Expand Down
60 changes: 55 additions & 5 deletions src/validators/config.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use std::borrow::Cow;
use std::str::FromStr;

use crate::build_tools::py_schema_err;
use crate::errors::ErrorType;
use crate::input::EitherBytes;
use crate::serializers::BytesMode;
use crate::tools::SchemaDict;
use base64::engine::general_purpose::GeneralPurpose;
use base64::engine::{DecodePaddingMode, GeneralPurposeConfig};
use base64::{alphabet, DecodeError, Engine};
use pyo3::types::{PyDict, PyString};
use pyo3::{intern, prelude::*};

use crate::errors::ErrorType;
use crate::input::EitherBytes;
use crate::serializers::BytesMode;
use crate::tools::SchemaDict;
use speedate::TimestampUnit;

const URL_SAFE_OPTIONAL_PADDING: GeneralPurpose = GeneralPurpose::new(
&alphabet::URL_SAFE,
Expand All @@ -21,6 +22,55 @@ const STANDARD_OPTIONAL_PADDING: GeneralPurpose = GeneralPurpose::new(
GeneralPurposeConfig::new().with_decode_padding_mode(DecodePaddingMode::Indifferent),
);

#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemporalUnitMode {
Seconds,
Milliseconds,
#[default]
Infer,
}

impl FromStr for TemporalUnitMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"seconds" => Ok(Self::Seconds),
"milliseconds" => Ok(Self::Milliseconds),
"infer" => Ok(Self::Infer),

s => py_schema_err!(
"Invalid temporal_unit_mode serialization mode: `{}`, expected seconds, milliseconds or infer",
s
),
}
}
}

impl TemporalUnitMode {
pub fn from_config(config: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
let Some(config_dict) = config else {
return Ok(Self::default());
};
let raw_mode = config_dict.get_as::<Bound<'_, PyString>>(intern!(config_dict.py(), "val_temporal_unit"))?;
let temporal_unit = raw_mode.map_or_else(
|| Ok(TemporalUnitMode::default()),
|raw| TemporalUnitMode::from_str(&raw.to_cow()?),
)?;
Ok(temporal_unit)
}
}

impl From<TemporalUnitMode> for TimestampUnit {
fn from(value: TemporalUnitMode) -> Self {
match value {
TemporalUnitMode::Seconds => TimestampUnit::Second,
TemporalUnitMode::Milliseconds => TimestampUnit::Millisecond,
TemporalUnitMode::Infer => TimestampUnit::Infer,
}
}
}

#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct ValBytesMode {
pub ser: BytesMode,
Expand Down
Loading
Loading