import numpy as np


def prepare_shape_strides_dtype(shape, strides, dtype, order):
    """Normalize array-construction arguments into ``(shape, strides, dtype)``.

    Parameters
    ----------
    shape : int or sequence of int
        Array extent. A scalar integer (Python ``int`` or NumPy integer
        scalar) is wrapped into a 1-tuple; any other value is returned as-is.
    strides : int, sequence of int, or None
        Byte strides. A scalar integer is wrapped into a 1-tuple. A falsy
        value (``None`` or an empty sequence) is replaced by contiguous
        strides computed from *shape*, *dtype* and *order*.
    dtype : object
        Anything accepted by ``np.dtype``; it is normalized to a ``np.dtype``.
    order : str
        ``'C'`` or ``'F'``. Only consulted when strides must be computed.

    Returns
    -------
    tuple
        ``(shape, strides, np.dtype(dtype))``.

    Raises
    ------
    ValueError
        If strides must be computed for a shape with at least one dimension
        and *order* is neither ``'C'`` nor ``'F'``.
    """
    dtype = np.dtype(dtype)
    # NumPy integer scalars (e.g. np.int64) are not ``int`` subclasses on
    # Python 3, so they must be recognized explicitly. Otherwise a scalar
    # shape would reach len() in _fill_stride_by_order and raise TypeError.
    if isinstance(shape, (int, np.integer)):
        shape = (int(shape),)
    if isinstance(strides, (int, np.integer)):
        strides = (int(strides),)
    else:
        # None / empty strides mean "compute contiguous strides".
        strides = strides or _fill_stride_by_order(shape, dtype, order)
    return shape, strides, dtype


def _fill_stride_by_order(shape, dtype, order):
    """Return contiguous byte strides for *shape* in C or Fortran order.

    A 0-d shape yields ``()`` without validating *order*. Otherwise *order*
    must be ``'C'`` (last axis varies fastest) or ``'F'`` (first axis varies
    fastest); anything else raises ``ValueError``.
    """
    nd = len(shape)
    if nd == 0:
        return ()
    strides = [0] * nd
    if order == 'C':
        # Innermost (last) axis steps by one element; each outer axis steps
        # over the full extent of the axis after it.
        strides[-1] = dtype.itemsize
        for d in reversed(range(nd - 1)):
            strides[d] = strides[d + 1] * shape[d + 1]
    elif order == 'F':
        # Innermost (first) axis steps by one element; each later axis steps
        # over the full extent of the axis before it.
        strides[0] = dtype.itemsize
        for d in range(1, nd):
            strides[d] = strides[d - 1] * shape[d - 1]
    else:
        raise ValueError('must be either C/F order')
    return tuple(strides)