import itertools import numpy as np import operator from numba.core import types, errors from numba import prange from numba.parfors.parfor import internal_prange from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate, AbstractTemplate, infer_global, infer, infer_getattr, signature, bound_function, make_callable_template) from numba.cpython.builtins import get_type_min_value, get_type_max_value from numba.core.extending import ( typeof_impl, type_callable, models, register_model, make_attribute_wrapper, ) @infer_global(print) class Print(AbstractTemplate): def generic(self, args, kws): for a in args: sig = self.context.resolve_function_type("print_item", (a,), {}) if sig is None: raise TypeError("Type %s is not printable." % a) assert sig.return_type is types.none return signature(types.none, *args) @infer class PrintItem(AbstractTemplate): key = "print_item" def generic(self, args, kws): arg, = args return signature(types.none, *args) @infer_global(abs) class Abs(ConcreteTemplate): int_cases = [signature(ty, ty) for ty in sorted(types.signed_domain)] uint_cases = [signature(ty, ty) for ty in sorted(types.unsigned_domain)] real_cases = [signature(ty, ty) for ty in sorted(types.real_domain)] complex_cases = [signature(ty.underlying_float, ty) for ty in sorted(types.complex_domain)] cases = int_cases + uint_cases + real_cases + complex_cases @infer_global(slice) class Slice(ConcreteTemplate): cases = [ signature(types.slice2_type, types.intp), signature(types.slice2_type, types.none), signature(types.slice2_type, types.none, types.none), signature(types.slice2_type, types.none, types.intp), signature(types.slice2_type, types.intp, types.none), signature(types.slice2_type, types.intp, types.intp), signature(types.slice3_type, types.intp, types.intp, types.intp), signature(types.slice3_type, types.none, types.intp, types.intp), signature(types.slice3_type, types.intp, types.none, types.intp), signature(types.slice3_type, types.intp, types.intp, types.none), signature(types.slice3_type, types.intp, types.none, types.none), signature(types.slice3_type, types.none, types.intp, types.none), signature(types.slice3_type, types.none, types.none, types.intp), signature(types.slice3_type, types.none, types.none, types.none), ] @infer_global(range, typing_key=range) @infer_global(prange, typing_key=prange) @infer_global(internal_prange, typing_key=internal_prange) class Range(ConcreteTemplate): cases = [ signature(types.range_state32_type, types.int32), signature(types.range_state32_type, types.int32, types.int32), signature(types.range_state32_type, types.int32, types.int32, types.int32), signature(types.range_state64_type, types.int64), signature(types.range_state64_type, types.int64, types.int64), signature(types.range_state64_type, types.int64, types.int64, types.int64), signature(types.unsigned_range_state64_type, types.uint64), signature(types.unsigned_range_state64_type, types.uint64, types.uint64), signature(types.unsigned_range_state64_type, types.uint64, types.uint64, types.uint64), ] @infer class GetIter(AbstractTemplate): key = "getiter" def generic(self, args, kws): assert not kws [obj] = args if isinstance(obj, types.IterableType): return signature(obj.iterator_type, obj) @infer class IterNext(AbstractTemplate): key = "iternext" def generic(self, args, kws): assert not kws [it] = args if isinstance(it, types.IteratorType): return signature(types.Pair(it.yield_type, types.boolean), it) @infer class PairFirst(AbstractTemplate): """ Given a heterogeneous pair, return the first element. """ key = "pair_first" def generic(self, args, kws): assert not kws [pair] = args if isinstance(pair, types.Pair): return signature(pair.first_type, pair) @infer class PairSecond(AbstractTemplate): """ Given a heterogeneous pair, return the second element. """ key = "pair_second" def generic(self, args, kws): assert not kws [pair] = args if isinstance(pair, types.Pair): return signature(pair.second_type, pair) def choose_result_bitwidth(*inputs): return max(types.intp.bitwidth, *(tp.bitwidth for tp in inputs)) def choose_result_int(*inputs): """ Choose the integer result type for an operation on integer inputs, according to the integer typing NBEP. """ bitwidth = choose_result_bitwidth(*inputs) signed = any(tp.signed for tp in inputs) return types.Integer.from_bitwidth(bitwidth, signed) # The "machine" integer types to take into consideration for operator typing # (according to the integer typing NBEP) machine_ints = ( sorted(set((types.intp, types.int64))) + sorted(set((types.uintp, types.uint64))) ) # Explicit integer rules for binary operators; smaller ints will be # automatically upcast. integer_binop_cases = tuple( signature(choose_result_int(op1, op2), op1, op2) for op1, op2 in itertools.product(machine_ints, machine_ints) ) class BinOp(ConcreteTemplate): cases = list(integer_binop_cases) cases += [signature(op, op, op) for op in sorted(types.real_domain)] cases += [signature(op, op, op) for op in sorted(types.complex_domain)] @infer_global(operator.add) class BinOpAdd(BinOp): pass @infer_global(operator.iadd) class BinOpAdd(BinOp): pass @infer_global(operator.sub) class BinOpSub(BinOp): pass @infer_global(operator.isub) class BinOpSub(BinOp): pass @infer_global(operator.mul) class BinOpMul(BinOp): pass @infer_global(operator.imul) class BinOpMul(BinOp): pass @infer_global(operator.mod) class BinOpMod(ConcreteTemplate): cases = list(integer_binop_cases) cases += [signature(op, op, op) for op in sorted(types.real_domain)] @infer_global(operator.imod) class BinOpMod(ConcreteTemplate): cases = list(integer_binop_cases) cases += [signature(op, op, op) for op in sorted(types.real_domain)] @infer_global(operator.truediv) class BinOpTrueDiv(ConcreteTemplate): cases = [signature(types.float64, op1, op2) for op1, op2 in itertools.product(machine_ints, machine_ints)] cases += [signature(op, op, op) for op in sorted(types.real_domain)] cases += [signature(op, op, op) for op in sorted(types.complex_domain)] @infer_global(operator.itruediv) class BinOpTrueDiv(ConcreteTemplate): cases = [signature(types.float64, op1, op2) for op1, op2 in itertools.product(machine_ints, machine_ints)] cases += [signature(op, op, op) for op in sorted(types.real_domain)] cases += [signature(op, op, op) for op in sorted(types.complex_domain)] @infer_global(operator.floordiv) class BinOpFloorDiv(ConcreteTemplate): cases = list(integer_binop_cases) cases += [signature(op, op, op) for op in sorted(types.real_domain)] @infer_global(operator.ifloordiv) class BinOpFloorDiv(ConcreteTemplate): cases = list(integer_binop_cases) cases += [signature(op, op, op) for op in sorted(types.real_domain)] @infer_global(divmod) class DivMod(ConcreteTemplate): _tys = machine_ints + sorted(types.real_domain) cases = [signature(types.UniTuple(ty, 2), ty, ty) for ty in _tys] @infer_global(operator.pow) class BinOpPower(ConcreteTemplate): cases = list(integer_binop_cases) # Ensure that float32 ** int doesn't go through DP computations cases += [signature(types.float32, types.float32, op) for op in (types.int32, types.int64, types.uint64)] cases += [signature(types.float64, types.float64, op) for op in (types.int32, types.int64, types.uint64)] cases += [signature(op, op, op) for op in sorted(types.real_domain)] cases += [signature(op, op, op) for op in sorted(types.complex_domain)] @infer_global(operator.ipow) class BinOpPower(ConcreteTemplate): cases = list(integer_binop_cases) # Ensure that float32 ** int doesn't go through DP computations cases += [signature(types.float32, types.float32, op) for op in (types.int32, types.int64, types.uint64)] cases += [signature(types.float64, types.float64, op) for op in (types.int32, types.int64, types.uint64)] cases += [signature(op, op, op) for op in sorted(types.real_domain)] cases += [signature(op, op, op) for op in sorted(types.complex_domain)] @infer_global(pow) class PowerBuiltin(BinOpPower): # TODO add 3 operand version pass class BitwiseShiftOperation(ConcreteTemplate): # For bitshifts, only the first operand's signedness matters # to choose the operation's signedness (the second operand # should always be positive but will generally be considered # signed anyway, since it's often a constant integer). # (also, see issue #1995 for right-shifts) # The RHS type is fixed to 64-bit signed/unsigned ints. # The implementation will always cast the operands to the width of the # result type, which is the widest between the LHS type and (u)intp. cases = [signature(max(op, types.intp), op, op2) for op in sorted(types.signed_domain) for op2 in [types.uint64, types.int64]] cases += [signature(max(op, types.uintp), op, op2) for op in sorted(types.unsigned_domain) for op2 in [types.uint64, types.int64]] unsafe_casting = False @infer_global(operator.lshift) class BitwiseLeftShift(BitwiseShiftOperation): pass @infer_global(operator.ilshift) class BitwiseLeftShift(BitwiseShiftOperation): pass @infer_global(operator.rshift) class BitwiseRightShift(BitwiseShiftOperation): pass @infer_global(operator.irshift) class BitwiseRightShift(BitwiseShiftOperation): pass class BitwiseLogicOperation(BinOp): cases = [signature(types.boolean, types.boolean, types.boolean)] cases += list(integer_binop_cases) unsafe_casting = False @infer_global(operator.and_) class BitwiseAnd(BitwiseLogicOperation): pass @infer_global(operator.iand) class BitwiseAnd(BitwiseLogicOperation): pass @infer_global(operator.or_) class BitwiseOr(BitwiseLogicOperation): pass @infer_global(operator.ior) class BitwiseOr(BitwiseLogicOperation): pass @infer_global(operator.xor) class BitwiseXor(BitwiseLogicOperation): pass @infer_global(operator.ixor) class BitwiseXor(BitwiseLogicOperation): pass # Bitwise invert and negate are special: we must not upcast the operand # for unsigned numbers, as that would change the result. # (i.e. ~np.int8(0) == 255 but ~np.int32(0) == 4294967295). @infer_global(operator.invert) class BitwiseInvert(ConcreteTemplate): # Note Numba follows the Numpy semantics of returning a bool, # while Python returns an int. This makes it consistent with # np.invert() and makes array expressions correct. cases = [signature(types.boolean, types.boolean)] cases += [signature(choose_result_int(op), op) for op in sorted(types.unsigned_domain)] cases += [signature(choose_result_int(op), op) for op in sorted(types.signed_domain)] unsafe_casting = False class UnaryOp(ConcreteTemplate): cases = [signature(choose_result_int(op), op) for op in sorted(types.unsigned_domain)] cases += [signature(choose_result_int(op), op) for op in sorted(types.signed_domain)] cases += [signature(op, op) for op in sorted(types.real_domain)] cases += [signature(op, op) for op in sorted(types.complex_domain)] cases += [signature(types.intp, types.boolean)] @infer_global(operator.neg) class UnaryNegate(UnaryOp): pass @infer_global(operator.pos) class UnaryPositive(UnaryOp): pass @infer_global(operator.not_) class UnaryNot(ConcreteTemplate): cases = [signature(types.boolean, types.boolean)] cases += [signature(types.boolean, op) for op in sorted(types.signed_domain)] cases += [signature(types.boolean, op) for op in sorted(types.unsigned_domain)] cases += [signature(types.boolean, op) for op in sorted(types.real_domain)] cases += [signature(types.boolean, op) for op in sorted(types.complex_domain)] class OrderedCmpOp(ConcreteTemplate): cases = [signature(types.boolean, types.boolean, types.boolean)] cases += [signature(types.boolean, op, op) for op in sorted(types.signed_domain)] cases += [signature(types.boolean, op, op) for op in sorted(types.unsigned_domain)] cases += [signature(types.boolean, op, op) for op in sorted(types.real_domain)] class UnorderedCmpOp(ConcreteTemplate): cases = OrderedCmpOp.cases + [ signature(types.boolean, op, op) for op in sorted(types.complex_domain)] @infer_global(operator.lt) class CmpOpLt(OrderedCmpOp): pass @infer_global(operator.le) class CmpOpLe(OrderedCmpOp): pass @infer_global(operator.gt) class CmpOpGt(OrderedCmpOp): pass @infer_global(operator.ge) class CmpOpGe(OrderedCmpOp): pass # more specific overloads should be registered first @infer_global(operator.eq) class ConstOpEq(AbstractTemplate): def generic(self, args, kws): assert not kws (arg1, arg2) = args if isinstance(arg1, types.Literal) and isinstance(arg2, types.Literal): return signature(types.boolean, arg1, arg2) @infer_global(operator.ne) class ConstOpNotEq(ConstOpEq): pass @infer_global(operator.eq) class CmpOpEq(UnorderedCmpOp): pass @infer_global(operator.ne) class CmpOpNe(UnorderedCmpOp): pass class TupleCompare(AbstractTemplate): def generic(self, args, kws): [lhs, rhs] = args if isinstance(lhs, types.BaseTuple) and isinstance(rhs, types.BaseTuple): for u, v in zip(lhs, rhs): # Check element-wise comparability res = self.context.resolve_function_type(self.key, (u, v), {}) if res is None: break else: return signature(types.boolean, lhs, rhs) @infer_global(operator.eq) class TupleEq(TupleCompare): pass @infer_global(operator.ne) class TupleNe(TupleCompare): pass @infer_global(operator.ge) class TupleGe(TupleCompare): pass @infer_global(operator.gt) class TupleGt(TupleCompare): pass @infer_global(operator.le) class TupleLe(TupleCompare): pass @infer_global(operator.lt) class TupleLt(TupleCompare): pass @infer_global(operator.add) class TupleAdd(AbstractTemplate): def generic(self, args, kws): if len(args) == 2: a, b = args if (isinstance(a, types.BaseTuple) and isinstance(b, types.BaseTuple) and not isinstance(a, types.BaseNamedTuple) and not isinstance(b, types.BaseNamedTuple)): res = types.BaseTuple.from_types(tuple(a) + tuple(b)) return signature(res, a, b) class CmpOpIdentity(AbstractTemplate): def generic(self, args, kws): [lhs, rhs] = args return signature(types.boolean, lhs, rhs) @infer_global(operator.is_) class CmpOpIs(CmpOpIdentity): pass @infer_global(operator.is_not) class CmpOpIsNot(CmpOpIdentity): pass def normalize_1d_index(index): """ Normalize the *index* type (an integer or slice) for indexing a 1D sequence. """ if isinstance(index, types.SliceType): return index elif isinstance(index, types.Integer): return types.intp if index.signed else types.uintp @infer_global(operator.getitem) class GetItemCPointer(AbstractTemplate): def generic(self, args, kws): assert not kws ptr, idx = args if isinstance(ptr, types.CPointer) and isinstance(idx, types.Integer): return signature(ptr.dtype, ptr, normalize_1d_index(idx)) @infer_global(operator.setitem) class SetItemCPointer(AbstractTemplate): def generic(self, args, kws): assert not kws ptr, idx, val = args if isinstance(ptr, types.CPointer) and isinstance(idx, types.Integer): return signature(types.none, ptr, normalize_1d_index(idx), ptr.dtype) @infer_global(len) class Len(AbstractTemplate): def generic(self, args, kws): assert not kws (val,) = args if isinstance(val, (types.Buffer, types.BaseTuple)): return signature(types.intp, val) elif isinstance(val, (types.RangeType)): return signature(val.dtype, val) @infer_global(tuple) class TupleConstructor(AbstractTemplate): def generic(self, args, kws): assert not kws # empty tuple case if len(args) == 0: return signature(types.Tuple(())) (val,) = args # tuple as input if isinstance(val, types.BaseTuple): return signature(val, val) @infer_global(operator.contains) class Contains(AbstractTemplate): def generic(self, args, kws): assert not kws (seq, val) = args if isinstance(seq, (types.Sequence)): return signature(types.boolean, seq, val) @infer_global(operator.truth) class TupleBool(AbstractTemplate): def generic(self, args, kws): assert not kws (val,) = args if isinstance(val, (types.BaseTuple)): return signature(types.boolean, val) @infer class StaticGetItemTuple(AbstractTemplate): key = "static_getitem" def generic(self, args, kws): tup, idx = args ret = None if not isinstance(tup, types.BaseTuple): return if isinstance(idx, int): try: ret = tup.types[idx] except IndexError: raise errors.NumbaIndexError("tuple index out of range") elif isinstance(idx, slice): ret = types.BaseTuple.from_types(tup.types[idx]) if ret is not None: sig = signature(ret, *args) return sig @infer class StaticGetItemLiteralList(AbstractTemplate): key = "static_getitem" def generic(self, args, kws): tup, idx = args ret = None if not isinstance(tup, types.LiteralList): return if isinstance(idx, int): ret = tup.types[idx] if ret is not None: sig = signature(ret, *args) return sig @infer class StaticGetItemLiteralStrKeyDict(AbstractTemplate): key = "static_getitem" def generic(self, args, kws): tup, idx = args ret = None if not isinstance(tup, types.LiteralStrKeyDict): return if isinstance(idx, str): if idx in tup.fields: lookup = tup.fields.index(idx) else: raise errors.NumbaKeyError(f"Key '{idx}' is not in dict.") ret = tup.types[lookup] if ret is not None: sig = signature(ret, *args) return sig # Generic implementation for "not in" @infer class GenericNotIn(AbstractTemplate): key = "not in" def generic(self, args, kws): args = args[::-1] sig = self.context.resolve_function_type(operator.contains, args, kws) return signature(sig.return_type, *sig.args[::-1]) #------------------------------------------------------------------------------- @infer_getattr class MemoryViewAttribute(AttributeTemplate): key = types.MemoryView def resolve_contiguous(self, buf): return types.boolean def resolve_c_contiguous(self, buf): return types.boolean def resolve_f_contiguous(self, buf): return types.boolean def resolve_itemsize(self, buf): return types.intp def resolve_nbytes(self, buf): return types.intp def resolve_readonly(self, buf): return types.boolean def resolve_shape(self, buf): return types.UniTuple(types.intp, buf.ndim) def resolve_strides(self, buf): return types.UniTuple(types.intp, buf.ndim) def resolve_ndim(self, buf): return types.intp #------------------------------------------------------------------------------- @infer_getattr class BooleanAttribute(AttributeTemplate): key = types.Boolean def resolve___class__(self, ty): return types.NumberClass(ty) @bound_function("number.item") def resolve_item(self, ty, args, kws): assert not kws if not args: return signature(ty) @infer_getattr class NumberAttribute(AttributeTemplate): key = types.Number def resolve___class__(self, ty): return types.NumberClass(ty) def resolve_real(self, ty): return getattr(ty, "underlying_float", ty) def resolve_imag(self, ty): return getattr(ty, "underlying_float", ty) @bound_function("complex.conjugate") def resolve_conjugate(self, ty, args, kws): assert not args assert not kws return signature(ty) @bound_function("number.item") def resolve_item(self, ty, args, kws): assert not kws if not args: return signature(ty) @infer_getattr class NPTimedeltaAttribute(AttributeTemplate): key = types.NPTimedelta def resolve___class__(self, ty): return types.NumberClass(ty) @infer_getattr class NPDatetimeAttribute(AttributeTemplate): key = types.NPDatetime def resolve___class__(self, ty): return types.NumberClass(ty) @infer_getattr class SliceAttribute(AttributeTemplate): key = types.SliceType def resolve_start(self, ty): return types.intp def resolve_stop(self, ty): return types.intp def resolve_step(self, ty): return types.intp @bound_function("slice.indices") def resolve_indices(self, ty, args, kws): assert not kws if len(args) != 1: raise errors.NumbaTypeError( "indices() takes exactly one argument (%d given)" % len(args) ) typ, = args if not isinstance(typ, types.Integer): raise errors.NumbaTypeError( "'%s' object cannot be interpreted as an integer" % typ ) return signature(types.UniTuple(types.intp, 3), types.intp) #------------------------------------------------------------------------------- @infer_getattr class NumberClassAttribute(AttributeTemplate): key = types.NumberClass def resolve___call__(self, classty): """ Resolve a number class's constructor (e.g. calling int(...)) """ ty = classty.instance_type def typer(val): if isinstance(val, (types.BaseTuple, types.Sequence)): # Array constructor, e.g. np.int32([1, 2]) fnty = self.context.resolve_value_type(np.array) sig = fnty.get_call_type(self.context, (val, types.DType(ty)), {}) return sig.return_type elif isinstance(val, (types.Number, types.Boolean, types.IntEnumMember)): # Scalar constructor, e.g. np.int32(42) return ty elif isinstance(val, (types.NPDatetime, types.NPTimedelta)): # Constructor cast from datetime-like, e.g. # > np.int64(np.datetime64("2000-01-01")) if ty.bitwidth == 64: return ty else: msg = (f"Cannot cast {val} to {ty} as {ty} is not 64 bits " "wide.") raise errors.TypingError(msg) else: if (isinstance(val, types.Array) and val.ndim == 0 and val.dtype == ty): # This is 0d array -> scalar degrading return ty else: # unsupported msg = f"Casting {val} to {ty} directly is unsupported." if isinstance(val, types.Array): # array casts are supported a different way. msg += f" Try doing '.astype(np.{ty})' instead" raise errors.TypingError(msg) return types.Function(make_callable_template(key=ty, typer=typer)) @infer_getattr class TypeRefAttribute(AttributeTemplate): key = types.TypeRef def resolve___call__(self, classty): """ Resolve a number class's constructor (e.g. calling int(...)) Note: This is needed because of the limitation of the current type-system implementation. Specifically, the lack of a higher-order type (i.e. passing the ``DictType`` vs ``DictType(key_type, value_type)``) """ ty = classty.instance_type if isinstance(ty, type) and issubclass(ty, types.Type): # Redirect the typing to a: # @type_callable(ty) # def typeddict_call(context): # ... # For example, see numba/typed/typeddict.py # @type_callable(DictType) # def typeddict_call(context): class Redirect(object): def __init__(self, context): self.context = context def __call__(self, *args, **kwargs): result = self.context.resolve_function_type(ty, args, kwargs) if hasattr(result, "pysig"): self.pysig = result.pysig return result return types.Function(make_callable_template(key=ty, typer=Redirect(self.context))) #------------------------------------------------------------------------------ class MinMaxBase(AbstractTemplate): def _unify_minmax(self, tys): for ty in tys: if not isinstance(ty, types.Number): return return self.context.unify_types(*tys) def generic(self, args, kws): """ Resolve a min() or max() call. """ assert not kws if not args: return if len(args) == 1: # max(arg) only supported if arg is an iterable if isinstance(args[0], types.BaseTuple): tys = list(args[0]) if not tys: raise TypeError("%s() argument is an empty tuple" % (self.key.__name__,)) else: return else: # max(*args) tys = args retty = self._unify_minmax(tys) if retty is not None: return signature(retty, *args) @infer_global(max) class Max(MinMaxBase): pass @infer_global(min) class Min(MinMaxBase): pass @infer_global(round) class Round(ConcreteTemplate): cases = [ signature(types.intp, types.float32), signature(types.int64, types.float64), signature(types.float32, types.float32, types.intp), signature(types.float64, types.float64, types.intp), ] #------------------------------------------------------------------------------ @infer_global(bool) class Bool(AbstractTemplate): def generic(self, args, kws): assert not kws [arg] = args if isinstance(arg, (types.Boolean, types.Number)): return signature(types.boolean, arg) # XXX typing for bool cannot be polymorphic because of the # types.Function thing, so we redirect to the operator.truth # intrinsic. return self.context.resolve_function_type(operator.truth, args, kws) @infer_global(int) class Int(AbstractTemplate): def generic(self, args, kws): if kws: raise errors.NumbaAssertionError('kws not supported') [arg] = args if isinstance(arg, types.Integer): return signature(arg, arg) if isinstance(arg, (types.Float, types.Boolean)): return signature(types.intp, arg) @infer_global(float) class Float(AbstractTemplate): def generic(self, args, kws): assert not kws [arg] = args if arg not in types.number_domain: raise errors.NumbaTypeError("float() only support for numbers") if arg in types.complex_domain: raise errors.NumbaTypeError("float() does not support complex") if arg in types.integer_domain: return signature(types.float64, arg) elif arg in types.real_domain: return signature(arg, arg) @infer_global(complex) class Complex(AbstractTemplate): def generic(self, args, kws): assert not kws if len(args) == 1: [arg] = args if arg not in types.number_domain: raise TypeError("complex() only support for numbers") if arg == types.float32: return signature(types.complex64, arg) else: return signature(types.complex128, arg) elif len(args) == 2: [real, imag] = args if (real not in types.number_domain or imag not in types.number_domain): raise TypeError("complex() only support for numbers") if real == imag == types.float32: return signature(types.complex64, real, imag) else: return signature(types.complex128, real, imag) #------------------------------------------------------------------------------ @infer_global(enumerate) class Enumerate(AbstractTemplate): def generic(self, args, kws): assert not kws it = args[0] if len(args) > 1 and not isinstance(args[1], types.Integer): raise errors.NumbaTypeError("Only integers supported as start " "value in enumerate") elif len(args) > 2: #let python raise its own error enumerate(*args) if isinstance(it, types.IterableType): enumerate_type = types.EnumerateType(it) return signature(enumerate_type, *args) @infer_global(zip) class Zip(AbstractTemplate): def generic(self, args, kws): assert not kws if all(isinstance(it, types.IterableType) for it in args): zip_type = types.ZipType(args) return signature(zip_type, *args) @infer_global(iter) class Iter(AbstractTemplate): def generic(self, args, kws): assert not kws if len(args) == 1: it = args[0] if isinstance(it, types.IterableType): return signature(it.iterator_type, *args) @infer_global(next) class Next(AbstractTemplate): def generic(self, args, kws): assert not kws if len(args) == 1: it = args[0] if isinstance(it, types.IteratorType): return signature(it.yield_type, *args) #------------------------------------------------------------------------------ @infer_global(type) class TypeBuiltin(AbstractTemplate): def generic(self, args, kws): assert not kws if len(args) == 1: # One-argument type() -> return the __class__ # Avoid literal types arg = types.unliteral(args[0]) classty = self.context.resolve_getattr(arg, "__class__") if classty is not None: return signature(classty, *args) #------------------------------------------------------------------------------ @infer_getattr class OptionalAttribute(AttributeTemplate): key = types.Optional def generic_resolve(self, optional, attr): return self.context.resolve_getattr(optional.type, attr) #------------------------------------------------------------------------------ @infer_getattr class DeferredAttribute(AttributeTemplate): key = types.DeferredType def generic_resolve(self, deferred, attr): return self.context.resolve_getattr(deferred.get(), attr) #------------------------------------------------------------------------------ @infer_global(get_type_min_value) @infer_global(get_type_max_value) class MinValInfer(AbstractTemplate): def generic(self, args, kws): assert not kws assert len(args) == 1 assert isinstance(args[0], (types.DType, types.NumberClass)) return signature(args[0].dtype, *args) #------------------------------------------------------------------------------ class IndexValue(object): """ Index and value """ def __init__(self, ind, val): self.index = ind self.value = val def __repr__(self): return 'IndexValue(%f, %f)' % (self.index, self.value) class IndexValueType(types.Type): def __init__(self, val_typ): self.val_typ = val_typ super(IndexValueType, self).__init__( name='IndexValueType({})'.format(val_typ)) @typeof_impl.register(IndexValue) def typeof_index(val, c): val_typ = typeof_impl(val.value, c) return IndexValueType(val_typ) @type_callable(IndexValue) def type_index_value(context): def typer(ind, mval): if ind == types.intp or ind == types.uintp: return IndexValueType(mval) return typer @register_model(IndexValueType) class IndexValueModel(models.StructModel): def __init__(self, dmm, fe_type): members = [ ('index', types.intp), ('value', fe_type.val_typ), ] models.StructModel.__init__(self, dmm, fe_type, members) make_attribute_wrapper(IndexValueType, 'index', 'index') make_attribute_wrapper(IndexValueType, 'value', 'value')