from abc import ABCMeta

from ..py2help import with_metaclass
from ..coretypes import (
    DataShape,
    DateTime,
    Function,
    Option,
    Record,
    String,
    Time,
    TimeDelta,
    Tuple,
    Units,
)
from ..dispatch import dispatch


def _fmt_path(path):
    """Format the path for final display.

    Parameters
    ----------
    path : iterable of str
        The path to the values that are not equal.

    Returns
    -------
    fmtd : str
        The formatted path to put into the error message. An empty string
        is returned when ``path`` is empty or ``None``.
    """
    if not path:
        return ''
    return 'path: _' + ''.join(path)


@dispatch(DataShape, DataShape)
def assert_dshape_equal(a, b, check_dim=True, path=None, **kwargs):
    """Assert that two dshapes are equal, providing an informative error
    message when they are not equal.

    Parameters
    ----------
    a, b : dshape
        The dshapes to check for equality.
    check_dim : bool, optional
        Check shapes for equality with respect to their dimensions.
        default: True
    check_tz : bool, optional
        Checks times and datetimes for equality with respect to timezones.
        default: True
    check_timedelta_unit : bool, optional
        Checks timedeltas for equality with respect to their unit
        (us, ns, ...).
        default: True
    check_str_encoding : bool, optional
        Checks strings for equality with respect to their encoding.
        default: True
    check_str_fixlen : bool, optional
        Checks string for equality with respect to their fixlen.
        default: True
    check_record_order : bool, optional
        Checks records for equality with respect to the order of the fields.
        default: True

    Raises
    ------
    AssertionError
        Raised when the two dshapes are not equal.
    """
    ashape = a.shape
    bshape = b.shape

    if path is None:
        path = ()

    if check_dim:
        for n, (adim, bdim) in enumerate(zip(ashape, bshape)):
            if adim != bdim:
                path += '.shape[%d]' % n,
                raise AssertionError(
                    'dimensions do not match: %s != %s%s\n%s' % (
                        adim,
                        bdim,
                        ('\n%s != %s' % (
                            ' * '.join(map(str, ashape)),
                            ' * '.join(map(str, bshape)),
                        )) if len(a.shape) > 1 else '',
                        _fmt_path(path),
                    ),
                )

        # ``zip`` stops at the shorter shape, so a differing number of
        # dimensions would otherwise go unnoticed when the common prefix
        # matches. This runs after the per-dimension loop so that cases
        # which already failed above keep their original message.
        if len(ashape) != len(bshape):
            raise AssertionError(
                'dimension counts do not match: %d != %d\n%s != %s\n%s' % (
                    len(ashape),
                    len(bshape),
                    ' * '.join(map(str, ashape)),
                    ' * '.join(map(str, bshape)),
                    _fmt_path(path + ('.shape',)),
                ),
            )

    path += '.measure',
    assert_dshape_equal(
        a.measure,
        b.measure,
        check_dim=check_dim,
        path=path,
        **kwargs
    )


class Slotted(with_metaclass(ABCMeta)):
    """Abstract class that treats any class defining ``__slots__`` as a
    subclass, so that such objects can be dispatched on.
    """
    @classmethod
    def __subclasshook__(cls, subcls):
        return hasattr(subcls, '__slots__')


@assert_dshape_equal.register(Slotted, Slotted)
def _check_slots(a, b, path=None, **kwargs):
    """Compare two slotted objects slot by slot.

    Objects of different types fall back to a plain equality check.

    Raises
    ------
    AssertionError
        Raised when the slot names differ or any slot value differs.
    """
    if type(a) != type(b):
        return _base_case(a, b, path=path, **kwargs)

    assert a.__slots__ == b.__slots__, 'slots mismatch: %r != %r\n%s' % (
        a.__slots__, b.__slots__, _fmt_path(path),
    )
    if path is None:
        path = ()
    for slot in a.__slots__:
        assert getattr(a, slot) == getattr(b, slot), \
            "%s %ss do not match: %r != %r\n%s" % (
                type(a).__name__.lower(),
                slot,
                getattr(a, slot),
                getattr(b, slot),
                _fmt_path(path + ('.' + slot,)),
            )


@assert_dshape_equal.register(object, object)
def _base_case(a, b, path=None, **kwargs):
    """Fallback comparison: assert ``a == b``.

    Raises
    ------
    AssertionError
        Raised when ``a != b``; the message includes the formatted path.
    """
    assert a == b, '%s != %s\n%s' % (a, b, _fmt_path(path))


@dispatch((DateTime, Time), (DateTime, Time))
def assert_dshape_equal(a, b, path=None, check_tz=True, **kwargs):
    """Compare datetime/time types; slots are only compared when
    ``check_tz`` is true.
    """
    if type(a) != type(b):
        # Forward the path so the failure message still reports where in
        # the dshape the mismatch occurred.
        return _base_case(a, b, path=path)
    if check_tz:
        _check_slots(a, b, path)


@dispatch(TimeDelta, TimeDelta)
def assert_dshape_equal(a, b, path=None, check_timedelta_unit=True, **kwargs):
    """Compare timedelta types; slots are only compared when
    ``check_timedelta_unit`` is true.
    """
    if check_timedelta_unit:
        _check_slots(a, b, path)


@dispatch(Units, Units)
def assert_dshape_equal(a, b, path=None, **kwargs):
    """Compare unit types: the ``unit`` attributes must be equal and the
    ``tp`` attributes must themselves compare equal as dshapes.
    """
    if path is None:
        path = ()

    assert a.unit == b.unit, '%s units do not match: %r != %s\n%s' % (
        type(a).__name__.lower(), a.unit, b.unit, _fmt_path(path + ('.unit',)),
    )

    # ``path`` is a tuple, so extend it by concatenation (tuples have no
    # ``append``) and forward it so nested errors report their location.
    path += '.tp',
    assert_dshape_equal(a.tp, b.tp, path=path, **kwargs)


@dispatch(String, String)
def assert_dshape_equal(a,
                        b,
                        path=None,
                        check_str_encoding=True,
                        check_str_fixlen=True,
                        **kwargs):
    """Compare string types with respect to encoding and/or fixlen,
    depending on ``check_str_encoding`` and ``check_str_fixlen``.
    """
    if path is None:
        path = ()

    if check_str_encoding:
        assert a.encoding == b.encoding, \
            'string encodings do not match: %r != %r\n%s' % (
                a.encoding, b.encoding, _fmt_path(path + ('.encoding',)),
            )

    if check_str_fixlen:
        # ``%r`` rather than ``%d``: a non-integer fixlen (e.g. ``None``)
        # would make ``%d`` raise TypeError while building the message,
        # hiding the real assertion failure. Integers render identically.
        assert a.fixlen == b.fixlen, \
            'string fixlens do not match: %r != %r\n%s' % (
                a.fixlen, b.fixlen, _fmt_path(path + ('.fixlen',)),
            )


@dispatch(Option, Option)
def assert_dshape_equal(a, b, path=None, **kwargs):
    """Compare option types by comparing their ``ty`` attributes."""
    if path is None:
        path = ()
    path += '.ty',
    return assert_dshape_equal(a.ty, b.ty, path=path, **kwargs)


@dispatch(Record, Record)
def assert_dshape_equal(a, b, check_record_order=True, path=None, **kwargs):
    """Compare record types field by field.

    When ``check_record_order`` is false the fields of both records are
    sorted before being compared pairwise.
    """
    afields = a.fields
    bfields = b.fields

    assert len(afields) == len(bfields), \
        'records have mismatched field counts: %d != %d\n%r != %r\n%s' % (
            len(afields), len(bfields), a.names, b.names, _fmt_path(path),
        )

    if not check_record_order:
        afields = sorted(afields)
        bfields = sorted(bfields)

    if path is None:
        path = ()
    for n, ((aname, afield), (bname, bfield)) in enumerate(
            zip(afields, bfields)):

        assert aname == bname, \
            'record field name at position %d does not match: %r != %r\n%s' % (
                n, aname, bname, _fmt_path(path),
            )

        assert_dshape_equal(
            afield,
            bfield,
            path=path + ('[%s]' % repr(aname),),
            check_record_order=check_record_order,
            **kwargs
        )


@dispatch(Tuple, Tuple)
def assert_dshape_equal(a, b, path=None, **kwargs):
    """Compare tuple types element by element via their ``dshapes``."""
    assert len(a.dshapes) == len(b.dshapes), \
        'tuples have mismatched field counts: %d != %d\n%r != %r\n%s' % (
            len(a.dshapes), len(b.dshapes), a, b, _fmt_path(path),
        )

    if path is None:
        path = ()
    path += '.dshapes',
    for n, (ashape, bshape) in enumerate(zip(a.dshapes, b.dshapes)):
        assert_dshape_equal(
            ashape,
            bshape,
            path=path + ('[%d]' % n,),
            **kwargs
        )


@dispatch(Function, Function)
def assert_dshape_equal(a, b, path=None, **kwargs):
    """Compare function types: same arity, pairwise-equal argument types,
    and equal return types.
    """
    assert len(a.argtypes) == len(b.argtypes),\
        'functions have different arities: %d != %d\n%r != %r\n%s' % (
            len(a.argtypes), len(b.argtypes), a, b, _fmt_path(path),
        )

    if path is None:
        path = ()
    for n, (aarg, barg) in enumerate(zip(a.argtypes, b.argtypes)):
        assert_dshape_equal(
            aarg,
            barg,
            path=path + ('.argtypes[%d]' % n,),
            **kwargs
        )
    assert_dshape_equal(
        a.restype,
        b.restype,
        path=path + ('.restype',),
        **kwargs
    )