# -*- coding: utf-8 -*- from __future__ import print_function, division, absolute_import from itertools import chain import operator from .. import py2help from .. import parser from .. import type_symbol_table from ..validation import validate from .. import coretypes __all__ = 'dshape', 'dshapes', 'has_var_dim', 'has_ellipsis', 'cat_dshapes' subclasses = operator.methodcaller('__subclasses__') #------------------------------------------------------------------------ # Utility Functions for DataShapes #------------------------------------------------------------------------ def dshapes(*args): """ Parse a bunch of datashapes all at once. >>> a, b = dshapes('3 * int32', '2 * var * float64') """ return [dshape(arg) for arg in args] def dshape(o): """ Parse a datashape. For a thorough description see http://blaze.pydata.org/docs/datashape.html >>> ds = dshape('2 * int32') >>> ds[1] ctype("int32") """ if isinstance(o, coretypes.DataShape): return o if isinstance(o, py2help._strtypes): ds = parser.parse(o, type_symbol_table.sym) elif isinstance(o, (coretypes.CType, coretypes.String, coretypes.Record, coretypes.JSON, coretypes.Date, coretypes.Time, coretypes.DateTime, coretypes.Unit)): ds = coretypes.DataShape(o) elif isinstance(o, coretypes.Mono): ds = o elif isinstance(o, (list, tuple)): ds = coretypes.DataShape(*o) else: raise TypeError('Cannot create dshape from object of type %s' % type(o)) validate(ds) return ds def cat_dshapes(dslist): """ Concatenates a list of dshapes together along the first axis. Raises an error if there is a mismatch along another axis or the measures are different. Requires that the leading dimension be a known size for all data shapes. TODO: Relax this restriction to support streaming dimensions. >>> cat_dshapes(dshapes('10 * int32', '5 * int32')) dshape("15 * int32") """ if len(dslist) == 0: raise ValueError('Cannot concatenate an empty list of dshapes') elif len(dslist) == 1: return dslist[0] outer_dim_size = operator.index(dslist[0][0]) inner_ds = dslist[0][1:] for ds in dslist[1:]: outer_dim_size += operator.index(ds[0]) if ds[1:] != inner_ds: raise ValueError(('The datashapes to concatenate much' ' all match after' ' the first dimension (%s vs %s)') % (inner_ds, ds[1:])) return coretypes.DataShape(*[coretypes.Fixed(outer_dim_size)] + list(inner_ds)) def collect(pred, expr): """ Collect terms in expression that match predicate >>> from datashape import Unit, dshape >>> predicate = lambda term: isinstance(term, Unit) >>> dshape = dshape('var * {value: int64, loc: 2 * int32}') >>> sorted(set(collect(predicate, dshape)), key=str) [Fixed(val=2), ctype("int32"), ctype("int64"), Var()] >>> from datashape import var, int64 >>> sorted(set(collect(predicate, [var, int64])), key=str) [ctype("int64"), Var()] """ if pred(expr): return [expr] if isinstance(expr, coretypes.Record): return chain.from_iterable(collect(pred, typ) for typ in expr.types) if isinstance(expr, coretypes.Mono): return chain.from_iterable(collect(pred, typ) for typ in expr.parameters) if isinstance(expr, (list, tuple)): return chain.from_iterable(collect(pred, item) for item in expr) def has_var_dim(ds): """Returns True if datashape has a variable dimension Note currently treats variable length string as scalars. >>> has_var_dim(dshape('2 * int32')) False >>> has_var_dim(dshape('var * 2 * int32')) True """ return has((coretypes.Ellipsis, coretypes.Var), ds) def has(typ, ds): if isinstance(ds, typ): return True if isinstance(ds, coretypes.Record): return any(has(typ, t) for t in ds.types) if isinstance(ds, coretypes.Mono): return any(has(typ, p) for p in ds.parameters) if isinstance(ds, (list, tuple)): return any(has(typ, item) for item in ds) return False def has_ellipsis(ds): """Returns True if the datashape has an ellipsis >>> has_ellipsis(dshape('2 * int')) False >>> has_ellipsis(dshape('... * int')) True """ return has(coretypes.Ellipsis, ds)