from sympy.core.relational import (Eq, Ne) from sympy.core.singleton import S from sympy.core.symbol import symbols from sympy.functions.elementary.miscellaneous import sqrt from sympy.functions.elementary.trigonometric import (cos, sin) from sympy.external import import_module from sympy.testing.pytest import skip from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar, Replacer matchpy = import_module("matchpy") x, y, z = symbols("x y z") def _get_first_match(expr, pattern): from matchpy import ManyToOneMatcher, Pattern matcher = ManyToOneMatcher() matcher.add(Pattern(pattern)) return next(iter(matcher.match(expr))) def test_matchpy_connector(): if matchpy is None: skip("matchpy not installed") from multiset import Multiset from matchpy import Pattern, Substitution w_ = WildDot("w_") w__ = WildPlus("w__") w___ = WildStar("w___") expr = x + y pattern = x + w_ p, subst = _get_first_match(expr, pattern) assert p == Pattern(pattern) assert subst == Substitution({'w_': y}) expr = x + y + z pattern = x + w__ p, subst = _get_first_match(expr, pattern) assert p == Pattern(pattern) assert subst == Substitution({'w__': Multiset([y, z])}) expr = x + y + z pattern = x + y + z + w___ p, subst = _get_first_match(expr, pattern) assert p == Pattern(pattern) assert subst == Substitution({'w___': Multiset()}) def test_matchpy_optional(): if matchpy is None: skip("matchpy not installed") from matchpy import Pattern, Substitution from matchpy import ManyToOneReplacer, ReplacementRule p = WildDot("p", optional=1) q = WildDot("q", optional=0) pattern = p*x + q expr1 = 2*x pa, subst = _get_first_match(expr1, pattern) assert pa == Pattern(pattern) assert subst == Substitution({'p': 2, 'q': 0}) expr2 = x + 3 pa, subst = _get_first_match(expr2, pattern) assert pa == Pattern(pattern) assert subst == Substitution({'p': 1, 'q': 3}) expr3 = x pa, subst = _get_first_match(expr3, pattern) assert pa == Pattern(pattern) assert subst == Substitution({'p': 1, 'q': 0}) expr4 = x*y + z pa, subst = _get_first_match(expr4, pattern) assert pa == Pattern(pattern) assert subst == Substitution({'p': y, 'q': z}) replacer = ManyToOneReplacer() replacer.add(ReplacementRule(Pattern(pattern), lambda p, q: sin(p)*cos(q))) assert replacer.replace(expr1) == sin(2)*cos(0) assert replacer.replace(expr2) == sin(1)*cos(3) assert replacer.replace(expr3) == sin(1)*cos(0) assert replacer.replace(expr4) == sin(y)*cos(z) def test_replacer(): if matchpy is None: skip("matchpy not installed") x1_ = WildDot("x1_") x2_ = WildDot("x2_") a_ = WildDot("a_", optional=S.One) b_ = WildDot("b_", optional=S.One) c_ = WildDot("c_", optional=S.Zero) replacer = Replacer(common_constraints=[ matchpy.CustomConstraint(lambda a_: not a_.has(x)), matchpy.CustomConstraint(lambda b_: not b_.has(x)), matchpy.CustomConstraint(lambda c_: not c_.has(x)), ]) # Rewrite the equation into implicit form, unless it's already solved: replacer.add(Eq(x1_, x2_), Eq(x1_ - x2_, 0), conditions_nonfalse=[Ne(x2_, 0), Ne(x1_, 0), Ne(x1_, x), Ne(x2_, x)]) # Simple equation solver for real numbers: replacer.add(Eq(a_*x + b_, 0), Eq(x, -b_/a_)) disc = b_**2 - 4*a_*c_ replacer.add( Eq(a_*x**2 + b_*x + c_, 0), Eq(x, (-b_ - sqrt(disc))/(2*a_)) | Eq(x, (-b_ + sqrt(disc))/(2*a_)), conditions_nonfalse=[disc >= 0] ) replacer.add( Eq(a_*x**2 + c_, 0), Eq(x, sqrt(-c_/a_)) | Eq(x, -sqrt(-c_/a_)), conditions_nonfalse=[-c_*a_ > 0] ) assert replacer.replace(Eq(3*x, y)) == Eq(x, y/3) assert replacer.replace(Eq(x**2 + 1, 0)) == Eq(x**2 + 1, 0) assert replacer.replace(Eq(x**2, 4)) == (Eq(x, 2) | Eq(x, -2)) assert replacer.replace(Eq(x**2 + 4*y*x + 4*y**2, 0)) == Eq(x, -2*y)