from sympy.core.symbol import symbols, Dummy
from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
from sympy.core.function import Lambda
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.trigonometric import sin
from sympy.matrices.dense import Matrix
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.matrices.expressions.matmul import MatMul
from sympy.simplify.simplify import simplify
from sympy.testing.pytest import raises
from sympy.matrices.common import ShapeError


# Shared fixtures used by the tests below.
# Two 3x3 matrix symbols with concrete dimensions.
X = MatrixSymbol("X", 3, 3)
Y = MatrixSymbol("Y", 3, 3)

# A matrix symbol that is also named "X" but has symbolic dimensions k x k.
k = symbols("k")
Xk = MatrixSymbol("X", k, k)

# Explicit form of X, as returned by X.as_explicit().
Xd = X.as_explicit()

x, y, z, t = symbols("x y z t")


def test_applyfunc_matrix():
    """Check construction, evaluation, rebuilding and shapes of
    ElementwiseApplyFunction over explicit and symbolic arguments."""
    # This local Dummy shadows the module-level symbol ``x`` for the whole
    # test, including the dense matrix built further down.
    x = Dummy('x')
    square = Lambda(x, x**2)

    # --- applied to the explicit form of X ---
    on_explicit = ElementwiseApplyFunction(square, Xd)
    assert isinstance(on_explicit, ElementwiseApplyFunction)
    assert on_explicit.doit() == Xd.applyfunc(lambda x: x**2)
    assert on_explicit.shape == (3, 3)
    assert on_explicit.func(*on_explicit.args) == on_explicit
    assert simplify(on_explicit) == on_explicit
    assert on_explicit[0, 0] == square(Xd[0, 0])

    # --- applied to the symbol X: doit() must stay unevaluated ---
    on_symbol = ElementwiseApplyFunction(square, X)
    assert isinstance(on_symbol, ElementwiseApplyFunction)
    assert isinstance(on_symbol.doit(), ElementwiseApplyFunction)
    assert on_symbol == X.applyfunc(square)
    assert on_symbol.func(*on_symbol.args) == on_symbol

    # --- applied to a matrix product ---
    XY = X*Y
    exp_of_product = ElementwiseApplyFunction(exp, XY)
    assert exp_of_product.expr == XY
    assert exp_of_product.function.dummy_eq(Lambda(x, exp(x)))
    assert exp_of_product.dummy_eq(XY.applyfunc(exp))
    assert exp_of_product.func(*exp_of_product.args) == exp_of_product

    # --- multiplying the result with other matrix expressions ---
    left_product = X*exp_of_product
    assert isinstance(left_product, MatMul)
    assert left_product.shape == (3, 3)
    Z = MatrixSymbol("Z", 2, 3)
    assert (Z*exp_of_product).shape == (2, 3)

    exp_Z = ElementwiseApplyFunction(exp, Z)
    exp_Zt = ElementwiseApplyFunction(exp, Z.T)
    assert (exp_Zt*exp_Z).shape == (3, 3)
    assert (exp_Z*exp_Zt).shape == (2, 2)

    # (2, 3) * (2, 3) is not a valid product.
    raises(ShapeError, lambda: exp_Z*exp_Z)

    # --- applied to a dense 2x2 Matrix of scalars ---
    dense = Matrix([[x, y], [z, t]])
    sin_of_dense = ElementwiseApplyFunction(sin, dense)
    assert isinstance(sin_of_dense, ElementwiseApplyFunction)
    assert sin_of_dense.function.dummy_eq(Lambda(x, sin(x)))
    assert sin_of_dense.expr == dense
    assert sin_of_dense.doit() == dense.applyfunc(sin)
    assert sin_of_dense.doit() == Matrix([[sin(x), sin(y)], [sin(z), sin(t)]])
    assert sin_of_dense.func(*sin_of_dense.args) == sin_of_dense

    # --- applied to a matrix symbol with symbolic dimensions ---
    on_sized = ElementwiseApplyFunction(square, Xk)
    assert on_sized.doit() == on_sized
    assert on_sized.subs(k, 2).shape == (2, 2)
    assert (on_sized*on_sized).shape == (k, k)

    rect = MatrixSymbol("M", k, t)
    sandwich = rect.T*on_sized*rect
    assert isinstance(sandwich, MatMul)
    assert sandwich.args[1] == on_sized
    assert sandwich.shape == (t, t)
    assert (on_sized*rect).shape == (k, t)

    # (k, t) * (k, k) is not a valid product.
    raises(ShapeError, lambda: rect*on_sized)

    # --- different plain-Python callables give unequal expressions ---
    plus_one = ElementwiseApplyFunction(lambda x: x+1, Xk)
    unchanged = ElementwiseApplyFunction(lambda x: x, Xk)
    assert plus_one != unchanged


def test_applyfunc_entry():
    """Entry [0, 0] of sin applied to X, or to its explicit form, is
    sin(X[0, 0])."""
    for matrix in (X, Xd):
        applied = matrix.applyfunc(sin)
        assert applied[0, 0] == sin(X[0, 0])


def test_applyfunc_as_explicit():
    """as_explicit() on X.applyfunc(sin) gives the 3x3 Matrix of
    sin(X[i, j])."""
    expected = Matrix([[sin(X[i, j]) for j in range(3)] for i in range(3)])
    assert X.applyfunc(sin).as_explicit() == expected


def test_applyfunc_transpose():
    """Transposing after applying sin matches applying sin to the
    transpose (compared with dummy_eq)."""
    transposed_after = Xk.applyfunc(sin).T
    transposed_before = Xk.T.applyfunc(sin)
    assert transposed_after.dummy_eq(transposed_before)


def test_applyfunc_shape_11_matrices():
    """On a 1x1 MatrixSymbol, sin stays an ElementwiseApplyFunction while
    a times-two Lambda becomes the MatMul 2*M."""
    M = MatrixSymbol("M", 1, 1)
    twice = Lambda(x, x*2)

    sin_applied = M.applyfunc(sin)
    assert isinstance(sin_applied, ElementwiseApplyFunction)

    doubled = M.applyfunc(twice)
    assert isinstance(doubled, MatMul)
    assert doubled == 2*M