""" Typing declarations for np.timedelta64. """ from itertools import product import operator from numba.core import types from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate, AbstractTemplate, infer_global, infer, infer_getattr, signature) from numba.np import npdatetime_helpers # timedelta64-only operations class TimedeltaUnaryOp(AbstractTemplate): def generic(self, args, kws): if len(args) == 2: # Guard against binary + and - return op, = args if not isinstance(op, types.NPTimedelta): return return signature(op, op) class TimedeltaBinOp(AbstractTemplate): def generic(self, args, kws): if len(args) == 1: # Guard against unary + and - return left, right = args if not all(isinstance(tp, types.NPTimedelta) for tp in args): return if npdatetime_helpers.can_cast_timedelta_units(left.unit, right.unit): return signature(right, left, right) elif npdatetime_helpers.can_cast_timedelta_units(right.unit, left.unit): return signature(left, left, right) class TimedeltaCmpOp(AbstractTemplate): def generic(self, args, kws): # For equality comparisons, all units are inter-comparable left, right = args if not all(isinstance(tp, types.NPTimedelta) for tp in args): return return signature(types.boolean, left, right) class TimedeltaOrderedCmpOp(AbstractTemplate): def generic(self, args, kws): # For ordered comparisons, units must be compatible left, right = args if not all(isinstance(tp, types.NPTimedelta) for tp in args): return if (npdatetime_helpers.can_cast_timedelta_units(left.unit, right.unit) or npdatetime_helpers.can_cast_timedelta_units(right.unit, left.unit)): return signature(types.boolean, left, right) class TimedeltaMixOp(AbstractTemplate): def generic(self, args, kws): """ (timedelta64, {int, float}) -> timedelta64 ({int, float}, timedelta64) -> timedelta64 """ left, right = args if isinstance(right, types.NPTimedelta): td, other = right, left sig_factory = lambda other: signature(td, other, td) elif isinstance(left, types.NPTimedelta): td, 
other = left, right sig_factory = lambda other: signature(td, td, other) else: return if not isinstance(other, (types.Float, types.Integer)): return # Force integer types to convert to signed because it matches # timedelta64 semantics better. if isinstance(other, types.Integer): other = types.int64 return sig_factory(other) class TimedeltaDivOp(AbstractTemplate): def generic(self, args, kws): """ (timedelta64, {int, float}) -> timedelta64 (timedelta64, timedelta64) -> float """ left, right = args if not isinstance(left, types.NPTimedelta): return if isinstance(right, types.NPTimedelta): if (npdatetime_helpers.can_cast_timedelta_units(left.unit, right.unit) or npdatetime_helpers.can_cast_timedelta_units(right.unit, left.unit)): return signature(types.float64, left, right) elif isinstance(right, (types.Float)): return signature(left, left, right) elif isinstance(right, (types.Integer)): # Force integer types to convert to signed because it matches # timedelta64 semantics better. return signature(left, left, types.int64) @infer_global(operator.pos) class TimedeltaUnaryPos(TimedeltaUnaryOp): key = operator.pos @infer_global(operator.neg) class TimedeltaUnaryNeg(TimedeltaUnaryOp): key = operator.neg @infer_global(operator.add) @infer_global(operator.iadd) class TimedeltaBinAdd(TimedeltaBinOp): key = operator.add @infer_global(operator.sub) @infer_global(operator.isub) class TimedeltaBinSub(TimedeltaBinOp): key = operator.sub @infer_global(operator.mul) @infer_global(operator.imul) class TimedeltaBinMult(TimedeltaMixOp): key = operator.mul @infer_global(operator.truediv) @infer_global(operator.itruediv) class TimedeltaTrueDiv(TimedeltaDivOp): key = operator.truediv @infer_global(operator.floordiv) @infer_global(operator.ifloordiv) class TimedeltaFloorDiv(TimedeltaDivOp): key = operator.floordiv @infer_global(operator.eq) class TimedeltaCmpEq(TimedeltaCmpOp): key = operator.eq @infer_global(operator.ne) class TimedeltaCmpNe(TimedeltaCmpOp): key = operator.ne 
# Ordered timedelta64 comparisons (units must be mutually castable).

@infer_global(operator.lt)
class TimedeltaCmpLt(TimedeltaOrderedCmpOp):
    key = operator.lt


@infer_global(operator.le)
class TimedeltaCmpLE(TimedeltaOrderedCmpOp):
    key = operator.le


@infer_global(operator.gt)
class TimedeltaCmpGt(TimedeltaOrderedCmpOp):
    key = operator.gt


@infer_global(operator.ge)
class TimedeltaCmpGE(TimedeltaOrderedCmpOp):
    key = operator.ge


@infer_global(abs)
class TimedeltaAbs(TimedeltaUnaryOp):
    """Type ``abs()`` of a timedelta64 as the same timedelta64 type."""


# ---------------------------------------------------------------------------
# datetime64 operations


@infer_global(operator.add)
@infer_global(operator.iadd)
class DatetimePlusTimedelta(AbstractTemplate):
    """Type ``datetime64 + timedelta64`` with the operands in either order.

    The result is a datetime64 whose unit is chosen by
    ``npdatetime_helpers.combine_datetime_timedelta_units``; if that helper
    returns None, no signature is produced.
    """
    key = operator.add

    def generic(self, args, kws):
        # Guard against unary +: only two-operand calls are typed.
        if len(args) == 1:
            return None
        lhs, rhs = args
        # Find the timedelta operand; the other one must be the datetime.
        if isinstance(rhs, types.NPTimedelta):
            moment, delta = lhs, rhs
        elif isinstance(lhs, types.NPTimedelta):
            moment, delta = rhs, lhs
        else:
            return None
        if not isinstance(moment, types.NPDatetime):
            return None
        unit = npdatetime_helpers.combine_datetime_timedelta_units(
            moment.unit, delta.unit)
        if unit is None:
            return None
        return signature(types.NPDatetime(unit), lhs, rhs)


@infer_global(operator.sub)
@infer_global(operator.isub)
class DatetimeMinusTimedelta(AbstractTemplate):
    """Type ``datetime64 - timedelta64`` (datetime on the left only).

    The result unit comes from
    ``npdatetime_helpers.combine_datetime_timedelta_units``; a None unit
    means the operation is not typed.
    """
    key = operator.sub

    def generic(self, args, kws):
        # Guard against unary -: only two-operand calls are typed.
        if len(args) == 1:
            return None
        moment, delta = args
        if not isinstance(moment, types.NPDatetime):
            return None
        if not isinstance(delta, types.NPTimedelta):
            return None
        unit = npdatetime_helpers.combine_datetime_timedelta_units(
            moment.unit, delta.unit)
        if unit is None:
            return None
        return signature(types.NPDatetime(unit), moment, delta)


@infer_global(operator.sub)
class DatetimeMinusDatetime(AbstractTemplate):
    """Type ``datetime64 - datetime64`` as a timedelta64.

    Every pair of datetime units is accepted; the timedelta unit is the one
    returned by ``npdatetime_helpers.get_best_unit``.
    """
    key = operator.sub

    def generic(self, args, kws):
        # Guard against unary -: only two-operand calls are typed.
        if len(args) == 1:
            return None
        lhs, rhs = args
        if not (isinstance(lhs, types.NPDatetime)
                and isinstance(rhs, types.NPDatetime)):
            return None
        unit = npdatetime_helpers.get_best_unit(lhs.unit, rhs.unit)
        return signature(types.NPTimedelta(unit), lhs, rhs)


class DatetimeCmpOp(AbstractTemplate):
    """Type comparisons between two datetime64 values.

    Datetime units are treated as inter-comparable, so any pair of
    datetime64 operands yields a boolean signature.
    """

    def generic(self, args, kws):
        lhs, rhs = args
        if isinstance(lhs, types.NPDatetime) and isinstance(rhs, types.NPDatetime):
            return signature(types.boolean, lhs, rhs)
        return None


@infer_global(operator.eq)
class DatetimeCmpEq(DatetimeCmpOp):
    key = operator.eq


@infer_global(operator.ne)
class DatetimeCmpNe(DatetimeCmpOp):
    key = operator.ne


@infer_global(operator.lt)
class DatetimeCmpLt(DatetimeCmpOp):
    key = operator.lt


@infer_global(operator.le)
class DatetimeCmpLE(DatetimeCmpOp):
    key = operator.le


@infer_global(operator.gt)
class DatetimeCmpGt(DatetimeCmpOp):
    key = operator.gt


@infer_global(operator.ge)
class DatetimeCmpGE(DatetimeCmpOp):
    key = operator.ge