summaryrefslogtreecommitdiff
path: root/Qwen2.5-Eval/evaluation/latex2sympy/asciimath_printer.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@timan108.cs.illinois.edu>2025-09-04 22:16:22 -0500
committerYuren Hao <yurenh2@timan108.cs.illinois.edu>2025-09-04 22:16:22 -0500
commitfc6d57ffb8d5ddb5820fcc00b5491a585c259ebc (patch)
treee9841f93a353e2107225cfc721d1ce57c0e594dc /Qwen2.5-Eval/evaluation/latex2sympy/asciimath_printer.py
Initial commit
Diffstat (limited to 'Qwen2.5-Eval/evaluation/latex2sympy/asciimath_printer.py')
-rwxr-xr-xQwen2.5-Eval/evaluation/latex2sympy/asciimath_printer.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/Qwen2.5-Eval/evaluation/latex2sympy/asciimath_printer.py b/Qwen2.5-Eval/evaluation/latex2sympy/asciimath_printer.py
new file mode 100755
index 0000000..dd1b676
--- /dev/null
+++ b/Qwen2.5-Eval/evaluation/latex2sympy/asciimath_printer.py
@@ -0,0 +1,50 @@
+from sympy.printing.str import StrPrinter
+from sympy.core import S
+
+class AsciiMathPrinter(StrPrinter):
+
+ def _print_Limit(self, expr):
+ e, z = expr.args
+
+ return "lim_(%s -> %s) %s" % (self._print(z), self._print(z), self._print(e))
+
+ def _print_Integral(self, expr):
+ e, lims = expr.args
+ if len(lims) > 1:
+ return "int_(%s)^(%s) %s d%s" % (self._print(lims[1]), self._print(lims[2]), self._print(e), self._print(lims[0]))
+ else:
+ return "int %s d%s" % (self._print(e), self._print(lims))
+
+ def _print_Sum(self, expr):
+ e, lims = expr.args
+ return "sum_(%s = %s)^(%s) %s" % (self._print(lims[0]), self._print(lims[1]), self._print(lims[2]), self._print(e))
+
+ def _print_Product(self, expr):
+ e, lims = expr.args
+ return "prod_(%s = %s)^(%s) %s" % (self._print(lims[0]), self._print(lims[1]), self._print(lims[2]), self._print(e))
+
+ def _print_factorial(self, expr):
+ return "%s!" % self._print(expr.args[0])
+
+ def _print_Derivative(self, expr):
+ e = expr.args[0]
+ wrt = expr.args[1]
+ return "d/d%s %s" % (self._print(wrt), self._print(e))
+
+ def _print_Abs(self, expr):
+ return "|%s|" % self._print(expr.args[0])
+
+ def _print_Equality(self, expr):
+ return "%s = %s" % (self._print(expr.args[0]), self._print(expr.args[1]))
+
+ def _print_Pow(self, expr):
+ b = self._print(expr.base)
+ if expr.exp is S.Half:
+ return "sqrt(%s)" % b
+
+ if -expr.exp is S.Half:
+ return "1/sqrt(%s)" % b
+ if expr.exp is -S.One:
+ return "1/%s" % b
+
+ return "%s^(%s)" % (b, self._print(expr.exp))