From fc6d57ffb8d5ddb5820fcc00b5491a585c259ebc Mon Sep 17 00:00:00 2001 From: Yuren Hao Date: Thu, 4 Sep 2025 22:16:22 -0500 Subject: Initial commit --- .../evaluation/latex2sympy/latex2sympy2.py | 1162 ++++++++++++++++++++ 1 file changed, 1162 insertions(+) create mode 100755 Qwen2.5-Eval/evaluation/latex2sympy/latex2sympy2.py (limited to 'Qwen2.5-Eval/evaluation/latex2sympy/latex2sympy2.py') diff --git a/Qwen2.5-Eval/evaluation/latex2sympy/latex2sympy2.py b/Qwen2.5-Eval/evaluation/latex2sympy/latex2sympy2.py new file mode 100755 index 0000000..859a8c7 --- /dev/null +++ b/Qwen2.5-Eval/evaluation/latex2sympy/latex2sympy2.py @@ -0,0 +1,1162 @@ +import sympy +import re +from sympy import matrix_symbols, simplify, factor, expand, apart, expand_trig +from antlr4 import InputStream, CommonTokenStream +from antlr4.error.ErrorListener import ErrorListener + +try: + from gen.PSParser import PSParser + from gen.PSLexer import PSLexer + from gen.PSListener import PSListener +except Exception: + from .gen.PSParser import PSParser + from .gen.PSLexer import PSLexer + from .gen.PSListener import PSListener + +from sympy.printing.str import StrPrinter + +from sympy.parsing.sympy_parser import parse_expr + +import hashlib + +is_real = None + +frac_type = r'\frac' + +variances = {} +var = {} + +VARIABLE_VALUES = {} + + +def set_real(value): + global is_real + is_real = value + + +def set_variances(vars): + global variances + variances = vars + global var + var = {} + for variance in vars: + var[str(variance)] = vars[variance] + + +def latex2sympy(sympy: str, variable_values={}): + # record frac + global frac_type + if sympy.find(r'\frac') != -1: + frac_type = r'\frac' + if sympy.find(r'\dfrac') != -1: + frac_type = r'\dfrac' + if sympy.find(r'\tfrac') != -1: + frac_type = r'\tfrac' + sympy = sympy.replace(r'\dfrac', r'\frac') + sympy = sympy.replace(r'\tfrac', r'\frac') + # Translate Transpose + sympy = sympy.replace(r'\mathrm{T}', 'T', -1) + # Translate Derivative + sympy = sympy.replace(r'\mathrm{d}', 'd', -1).replace(r'{\rm d}', 'd', -1) + # Translate Matrix + sympy = sympy.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}', -1).replace(r'\end{matrix}\right]', r'\end{bmatrix}', -1) + # Translate Permutation + sympy = re.sub(r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}", r"\\frac{(\1)!}{((\1)-(\2))!}", sympy) + # Remove \displaystyle + sympy = sympy.replace(r'\displaystyle', ' ', -1) + # Remove \quad + sympy = sympy.replace(r'\quad', ' ', -1).replace(r'\qquad', ' ', -1).replace(r'~', ' ', -1).replace(r'\,', ' ', -1) + # Remove $ + sympy = sympy.replace(r'$', ' ', -1) + + # variable values + global VARIABLE_VALUES + if len(variable_values) > 0: + VARIABLE_VALUES = variable_values + else: + VARIABLE_VALUES = {} + + # setup listener + matherror = MathErrorListener(sympy) + + # stream input + stream = InputStream(sympy) + lex = PSLexer(stream) + lex.removeErrorListeners() + lex.addErrorListener(matherror) + + tokens = CommonTokenStream(lex) + parser = PSParser(tokens) + + # remove default console error listener + parser.removeErrorListeners() + parser.addErrorListener(matherror) + + # process the input + return_data = None + math = parser.math() + + # if a list + if math.relation_list(): + return_data = [] + + # go over list items + relation_list = math.relation_list().relation_list_content() + for list_item in relation_list.relation(): + expr = convert_relation(list_item) + return_data.append(expr) + + # if not, do default + else: + relation = math.relation() + return_data = convert_relation(relation) + + return return_data + + +class MathErrorListener(ErrorListener): + def __init__(self, src): + super(ErrorListener, self).__init__() + self.src = src + + def syntaxError(self, recog, symbol, line, col, msg, e): + fmt = "%s\n%s\n%s" + marker = "~" * col + "^" + + if msg.startswith("missing"): + err = fmt % (msg, self.src, marker) + elif msg.startswith("no viable"): + err = fmt % ("I expected something else here", self.src, marker) + elif msg.startswith("mismatched"): + names = PSParser.literalNames + expected = [names[i] for i in e.getExpectedTokens() if i < len(names)] + if len(expected) < 10: + expected = " ".join(expected) + err = (fmt % ("I expected one of these: " + expected, + self.src, marker)) + else: + err = (fmt % ("I expected something else here", self.src, marker)) + else: + err = fmt % ("I don't understand this", self.src, marker) + raise Exception(err) + + +def convert_relation(rel): + if rel.expr(): + return convert_expr(rel.expr()) + + lh = convert_relation(rel.relation(0)) + rh = convert_relation(rel.relation(1)) + if rel.LT(): + return sympy.StrictLessThan(lh, rh, evaluate=False) + elif rel.LTE(): + return sympy.LessThan(lh, rh, evaluate=False) + elif rel.GT(): + return sympy.StrictGreaterThan(lh, rh, evaluate=False) + elif rel.GTE(): + return sympy.GreaterThan(lh, rh, evaluate=False) + elif rel.EQUAL(): + return sympy.Eq(lh, rh, evaluate=False) + elif rel.ASSIGNMENT(): + # !Use Global variances + if lh.is_Symbol: + # set value + variances[lh] = rh + var[str(lh)] = rh + return rh + else: + # find the symbols in lh - rh + equation = lh - rh + syms = equation.atoms(sympy.Symbol) + if len(syms) > 0: + # Solve equation + result = [] + for sym in syms: + values = sympy.solve(equation, sym) + for value in values: + result.append(sympy.Eq(sym, value, evaluate=False)) + return result + else: + return sympy.Eq(lh, rh, evaluate=False) + elif rel.IN(): + # !Use Global variances + if hasattr(rh, 'is_Pow') and rh.is_Pow and hasattr(rh.exp, 'is_Mul'): + n = rh.exp.args[0] + m = rh.exp.args[1] + if n in variances: + n = variances[n] + if m in variances: + m = variances[m] + rh = sympy.MatrixSymbol(lh, n, m) + variances[lh] = rh + var[str(lh)] = rh + else: + raise Exception("Don't support this form of definition of matrix symbol.") + return lh + elif rel.UNEQUAL(): + return sympy.Ne(lh, rh, evaluate=False) + + +def convert_expr(expr): + if expr.additive(): + return convert_add(expr.additive()) + + +def convert_elementary_transform(matrix, transform): + if transform.transform_scale(): + transform_scale = transform.transform_scale() + transform_atom = transform_scale.transform_atom() + k = None + num = int(transform_atom.NUMBER().getText()) - 1 + if transform_scale.expr(): + k = convert_expr(transform_scale.expr()) + elif transform_scale.group(): + k = convert_expr(transform_scale.group().expr()) + elif transform_scale.SUB(): + k = -1 + else: + k = 1 + if transform_atom.LETTER_NO_E().getText() == 'r': + matrix = matrix.elementary_row_op(op='n->kn', row=num, k=k) + elif transform_atom.LETTER_NO_E().getText() == 'c': + matrix = matrix.elementary_col_op(op='n->kn', col=num, k=k) + else: + raise Exception('Row and col don\'s match') + + elif transform.transform_swap(): + first_atom = transform.transform_swap().transform_atom()[0] + second_atom = transform.transform_swap().transform_atom()[1] + first_num = int(first_atom.NUMBER().getText()) - 1 + second_num = int(second_atom.NUMBER().getText()) - 1 + if first_atom.LETTER_NO_E().getText() != second_atom.LETTER_NO_E().getText(): + raise Exception('Row and col don\'s match') + elif first_atom.LETTER_NO_E().getText() == 'r': + matrix = matrix.elementary_row_op(op='n<->m', row1=first_num, row2=second_num) + elif first_atom.LETTER_NO_E().getText() == 'c': + matrix = matrix.elementary_col_op(op='n<->m', col1=first_num, col2=second_num) + else: + raise Exception('Row and col don\'s match') + + elif transform.transform_assignment(): + first_atom = transform.transform_assignment().transform_atom() + second_atom = transform.transform_assignment().transform_scale().transform_atom() + transform_scale = transform.transform_assignment().transform_scale() + k = None + if transform_scale.expr(): + k = convert_expr(transform_scale.expr()) + elif transform_scale.group(): + k = convert_expr(transform_scale.group().expr()) + elif transform_scale.SUB(): + k = -1 + else: + k = 1 + first_num = int(first_atom.NUMBER().getText()) - 1 + second_num = int(second_atom.NUMBER().getText()) - 1 + if first_atom.LETTER_NO_E().getText() != second_atom.LETTER_NO_E().getText(): + raise Exception('Row and col don\'s match') + elif first_atom.LETTER_NO_E().getText() == 'r': + matrix = matrix.elementary_row_op(op='n->n+km', k=k, row1=first_num, row2=second_num) + elif first_atom.LETTER_NO_E().getText() == 'c': + matrix = matrix.elementary_col_op(op='n->n+km', k=k, col1=first_num, col2=second_num) + else: + raise Exception('Row and col don\'s match') + + return matrix + + +def convert_matrix(matrix): + # build matrix + row = matrix.matrix_row() + tmp = [] + rows = 0 + mat = None + + for r in row: + tmp.append([]) + for expr in r.expr(): + tmp[rows].append(convert_expr(expr)) + rows = rows + 1 + + mat = sympy.Matrix(tmp) + + if hasattr(matrix, 'MATRIX_XRIGHTARROW') and matrix.MATRIX_XRIGHTARROW(): + transforms_list = matrix.elementary_transforms() + if len(transforms_list) == 1: + for transform in transforms_list[0].elementary_transform(): + mat = convert_elementary_transform(mat, transform) + elif len(transforms_list) == 2: + # firstly transform top of xrightarrow + for transform in transforms_list[1].elementary_transform(): + mat = convert_elementary_transform(mat, transform) + # firstly transform bottom of xrightarrow + for transform in transforms_list[0].elementary_transform(): + mat = convert_elementary_transform(mat, transform) + + return mat + + +def add_flat(lh, rh): + if hasattr(lh, 'is_Add') and lh.is_Add or hasattr(rh, 'is_Add') and rh.is_Add: + args = [] + if hasattr(lh, 'is_Add') and lh.is_Add: + args += list(lh.args) + else: + args += [lh] + if hasattr(rh, 'is_Add') and rh.is_Add: + args = args + list(rh.args) + else: + args += [rh] + return sympy.Add(*args, evaluate=False) + else: + return sympy.Add(lh, rh, evaluate=False) + + +def mat_add_flat(lh, rh): + if hasattr(lh, 'is_MatAdd') and lh.is_MatAdd or hasattr(rh, 'is_MatAdd') and rh.is_MatAdd: + args = [] + if hasattr(lh, 'is_MatAdd') and lh.is_MatAdd: + args += list(lh.args) + else: + args += [lh] + if hasattr(rh, 'is_MatAdd') and rh.is_MatAdd: + args = args + list(rh.args) + else: + args += [rh] + return sympy.MatAdd(*[arg.doit() for arg in args], evaluate=False) + else: + return sympy.MatAdd(lh.doit(), rh.doit(), evaluate=False) + + +def mul_flat(lh, rh): + if hasattr(lh, 'is_Mul') and lh.is_Mul or hasattr(rh, 'is_Mul') and rh.is_Mul: + args = [] + if hasattr(lh, 'is_Mul') and lh.is_Mul: + args += list(lh.args) + else: + args += [lh] + if hasattr(rh, 'is_Mul') and rh.is_Mul: + args = args + list(rh.args) + else: + args += [rh] + return sympy.Mul(*args, evaluate=False) + else: + return sympy.Mul(lh, rh, evaluate=False) + + +def mat_mul_flat(lh, rh): + if hasattr(lh, 'is_MatMul') and lh.is_MatMul or hasattr(rh, 'is_MatMul') and rh.is_MatMul: + args = [] + if hasattr(lh, 'is_MatMul') and lh.is_MatMul: + args += list(lh.args) + else: + args += [lh] + if hasattr(rh, 'is_MatMul') and rh.is_MatMul: + args = args + list(rh.args) + else: + args += [rh] + return sympy.MatMul(*[arg.doit() for arg in args], evaluate=False) + else: + if hasattr(lh, 'doit') and hasattr(rh, 'doit'): + return sympy.MatMul(lh.doit(), rh.doit(), evaluate=False) + elif hasattr(lh, 'doit') and not hasattr(rh, 'doit'): + return sympy.MatMul(lh.doit(), rh, evaluate=False) + elif not hasattr(lh, 'doit') and hasattr(rh, 'doit'): + return sympy.MatMul(lh, rh.doit(), evaluate=False) + else: + return sympy.MatMul(lh, rh, evaluate=False) + + +def convert_add(add): + if add.ADD(): + lh = convert_add(add.additive(0)) + rh = convert_add(add.additive(1)) + + if lh.is_Matrix or rh.is_Matrix: + return mat_add_flat(lh, rh) + else: + return add_flat(lh, rh) + elif add.SUB(): + lh = convert_add(add.additive(0)) + rh = convert_add(add.additive(1)) + + if lh.is_Matrix or rh.is_Matrix: + return mat_add_flat(lh, mat_mul_flat(-1, rh)) + else: + # If we want to force ordering for variables this should be: + # return Sub(lh, rh, evaluate=False) + if not rh.is_Matrix and rh.func.is_Number: + rh = -rh + else: + rh = mul_flat(-1, rh) + return add_flat(lh, rh) + else: + return convert_mp(add.mp()) + + +def convert_mp(mp): + if hasattr(mp, 'mp'): + mp_left = mp.mp(0) + mp_right = mp.mp(1) + else: + mp_left = mp.mp_nofunc(0) + mp_right = mp.mp_nofunc(1) + + if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT(): + lh = convert_mp(mp_left) + rh = convert_mp(mp_right) + + if lh.is_Matrix or rh.is_Matrix: + return mat_mul_flat(lh, rh) + else: + return mul_flat(lh, rh) + elif mp.DIV() or mp.CMD_DIV() or mp.COLON(): + lh = convert_mp(mp_left) + rh = convert_mp(mp_right) + if lh.is_Matrix or rh.is_Matrix: + return sympy.MatMul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False) + else: + return sympy.Mul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False) + elif mp.CMD_MOD(): + lh = convert_mp(mp_left) + rh = convert_mp(mp_right) + if rh.is_Matrix: + raise Exception("Cannot perform modulo operation with a matrix as an operand") + else: + return sympy.Mod(lh, rh, evaluate=False) + else: + if hasattr(mp, 'unary'): + return convert_unary(mp.unary()) + else: + return convert_unary(mp.unary_nofunc()) + + +def convert_unary(unary): + if hasattr(unary, 'unary'): + nested_unary = unary.unary() + else: + nested_unary = unary.unary_nofunc() + if hasattr(unary, 'postfix_nofunc'): + first = unary.postfix() + tail = unary.postfix_nofunc() + postfix = [first] + tail + else: + postfix = unary.postfix() + + if unary.ADD(): + return convert_unary(nested_unary) + elif unary.SUB(): + tmp_convert_nested_unary = convert_unary(nested_unary) + if tmp_convert_nested_unary.is_Matrix: + return mat_mul_flat(-1, tmp_convert_nested_unary, evaluate=False) + else: + if tmp_convert_nested_unary.func.is_Number: + return -tmp_convert_nested_unary + else: + return mul_flat(-1, tmp_convert_nested_unary) + elif postfix: + return convert_postfix_list(postfix) + + +def convert_postfix_list(arr, i=0): + if i >= len(arr): + raise Exception("Index out of bounds") + + res = convert_postfix(arr[i]) + + if isinstance(res, sympy.Expr) or isinstance(res, sympy.Matrix) or res is sympy.S.EmptySet: + if i == len(arr) - 1: + return res # nothing to multiply by + else: + # multiply by next + rh = convert_postfix_list(arr, i + 1) + + if res.is_Matrix or rh.is_Matrix: + return mat_mul_flat(res, rh) + else: + return mul_flat(res, rh) + elif isinstance(res, tuple) or isinstance(res, list) or isinstance(res, dict): + return res + else: # must be derivative + wrt = res[0] + if i == len(arr) - 1: + raise Exception("Expected expression for derivative") + else: + expr = convert_postfix_list(arr, i + 1) + return sympy.Derivative(expr, wrt) + + +def do_subs(expr, at): + if at.expr(): + at_expr = convert_expr(at.expr()) + syms = at_expr.atoms(sympy.Symbol) + if len(syms) == 0: + return expr + elif len(syms) > 0: + sym = next(iter(syms)) + return expr.subs(sym, at_expr) + elif at.equality(): + lh = convert_expr(at.equality().expr(0)) + rh = convert_expr(at.equality().expr(1)) + return expr.subs(lh, rh) + + +def convert_postfix(postfix): + if hasattr(postfix, 'exp'): + exp_nested = postfix.exp() + else: + exp_nested = postfix.exp_nofunc() + + exp = convert_exp(exp_nested) + for op in postfix.postfix_op(): + if op.BANG(): + if isinstance(exp, list): + raise Exception("Cannot apply postfix to derivative") + exp = sympy.factorial(exp, evaluate=False) + elif op.eval_at(): + ev = op.eval_at() + at_b = None + at_a = None + if ev.eval_at_sup(): + at_b = do_subs(exp, ev.eval_at_sup()) + if ev.eval_at_sub(): + at_a = do_subs(exp, ev.eval_at_sub()) + if at_b is not None and at_a is not None: + exp = add_flat(at_b, mul_flat(at_a, -1)) + elif at_b is not None: + exp = at_b + elif at_a is not None: + exp = at_a + elif op.transpose(): + try: + exp = exp.T + except: + try: + exp = sympy.transpose(exp) + except: + pass + pass + + return exp + + +def convert_exp(exp): + if hasattr(exp, 'exp'): + exp_nested = exp.exp() + else: + exp_nested = exp.exp_nofunc() + + if exp_nested: + base = convert_exp(exp_nested) + if isinstance(base, list): + raise Exception("Cannot raise derivative to power") + if exp.atom(): + exponent = convert_atom(exp.atom()) + elif exp.expr(): + exponent = convert_expr(exp.expr()) + return sympy.Pow(base, exponent, evaluate=False) + else: + if hasattr(exp, 'comp'): + return convert_comp(exp.comp()) + else: + return convert_comp(exp.comp_nofunc()) + + +def convert_comp(comp): + if comp.group(): + return convert_expr(comp.group().expr()) + elif comp.norm_group(): + return convert_expr(comp.norm_group().expr()).norm() + elif comp.abs_group(): + return sympy.Abs(convert_expr(comp.abs_group().expr()), evaluate=False) + elif comp.floor_group(): + return handle_floor(convert_expr(comp.floor_group().expr())) + elif comp.ceil_group(): + return handle_ceil(convert_expr(comp.ceil_group().expr())) + elif comp.atom(): + return convert_atom(comp.atom()) + elif comp.frac(): + return convert_frac(comp.frac()) + elif comp.binom(): + return convert_binom(comp.binom()) + elif comp.matrix(): + return convert_matrix(comp.matrix()) + elif comp.det(): + # !Use Global variances + return convert_matrix(comp.det()).subs(variances).det() + elif comp.func(): + return convert_func(comp.func()) + + +def convert_atom(atom): + if atom.atom_expr(): + atom_expr = atom.atom_expr() + + # find the atom's text + atom_text = '' + if atom_expr.LETTER_NO_E(): + atom_text = atom_expr.LETTER_NO_E().getText() + if atom_text == "I": + return sympy.I + elif atom_expr.GREEK_CMD(): + atom_text = atom_expr.GREEK_CMD().getText()[1:].strip() + elif atom_expr.OTHER_SYMBOL_CMD(): + atom_text = atom_expr.OTHER_SYMBOL_CMD().getText().strip() + elif atom_expr.accent(): + atom_accent = atom_expr.accent() + # get name for accent + name = atom_accent.start.text + # name = atom_accent.start.text[1:] + # exception: check if bar or overline which are treated both as bar + # if name in ["bar", "overline"]: + # name = "bar" + # if name in ["vec", "overrightarrow"]: + # name = "vec" + # if name in ["tilde", "widetilde"]: + # name = "tilde" + # get the base (variable) + base = atom_accent.base.getText() + # set string to base+name + atom_text = name + '{' + base + '}' + + # find atom's subscript, if any + subscript_text = '' + if atom_expr.subexpr(): + subexpr = atom_expr.subexpr() + subscript = None + if subexpr.expr(): # subscript is expr + subscript = subexpr.expr().getText().strip() + elif subexpr.atom(): # subscript is atom + subscript = subexpr.atom().getText().strip() + elif subexpr.args(): # subscript is args + subscript = subexpr.args().getText().strip() + subscript_inner_text = StrPrinter().doprint(subscript) + if len(subscript_inner_text) > 1: + subscript_text = '_{' + subscript_inner_text + '}' + else: + subscript_text = '_' + subscript_inner_text + + # construct the symbol using the text and optional subscript + atom_symbol = sympy.Symbol(atom_text + subscript_text, real=is_real) + # for matrix symbol + matrix_symbol = None + global var + if atom_text + subscript_text in var: + try: + rh = var[atom_text + subscript_text] + shape = sympy.shape(rh) + matrix_symbol = sympy.MatrixSymbol(atom_text + subscript_text, shape[0], shape[1]) + variances[matrix_symbol] = variances[atom_symbol] + except: + pass + + # find the atom's superscript, and return as a Pow if found + if atom_expr.supexpr(): + supexpr = atom_expr.supexpr() + func_pow = None + if supexpr.expr(): + func_pow = convert_expr(supexpr.expr()) + else: + func_pow = convert_atom(supexpr.atom()) + return sympy.Pow(atom_symbol, func_pow, evaluate=False) + + return atom_symbol if not matrix_symbol else matrix_symbol + elif atom.SYMBOL(): + s = atom.SYMBOL().getText().replace("\\$", "").replace("\\%", "") + if s == "\\infty": + return sympy.oo + elif s == '\\pi': + return sympy.pi + elif s == '\\emptyset': + return sympy.S.EmptySet + else: + raise Exception("Unrecognized symbol") + elif atom.NUMBER(): + s = atom.NUMBER().getText().replace(",", "") + try: + sr = sympy.Rational(s) + return sr + except (TypeError, ValueError): + return sympy.Number(s) + elif atom.E_NOTATION(): + s = atom.E_NOTATION().getText().replace(",", "") + try: + sr = sympy.Rational(s) + return sr + except (TypeError, ValueError): + return sympy.Number(s) + elif atom.DIFFERENTIAL(): + var = get_differential_var(atom.DIFFERENTIAL()) + return sympy.Symbol('d' + var.name, real=is_real) + elif atom.mathit(): + text = rule2text(atom.mathit().mathit_text()) + return sympy.Symbol(text, real=is_real) + elif atom.VARIABLE(): + text = atom.VARIABLE().getText() + is_percent = text.endswith("\\%") + trim_amount = 3 if is_percent else 1 + name = text[10:] + name = name[0:len(name) - trim_amount] + + # add hash to distinguish from regular symbols + hash = hashlib.md5(name.encode()).hexdigest() + symbol_name = name + hash + + # replace the variable for already known variable values + if name in VARIABLE_VALUES: + # if a sympy class + if isinstance(VARIABLE_VALUES[name], tuple(sympy.core.all_classes)): + symbol = VARIABLE_VALUES[name] + + # if NOT a sympy class + else: + symbol = parse_expr(str(VARIABLE_VALUES[name])) + else: + symbol = sympy.Symbol(symbol_name, real=is_real) + + if is_percent: + return sympy.Mul(symbol, sympy.Pow(100, -1, evaluate=False), evaluate=False) + + # return the symbol + return symbol + + elif atom.PERCENT_NUMBER(): + text = atom.PERCENT_NUMBER().getText().replace("\\%", "").replace(",", "") + try: + number = sympy.Rational(text) + except (TypeError, ValueError): + number = sympy.Number(text) + percent = sympy.Rational(number, 100) + return percent + + +def rule2text(ctx): + stream = ctx.start.getInputStream() + # starting index of starting token + startIdx = ctx.start.start + # stopping index of stopping token + stopIdx = ctx.stop.stop + + return stream.getText(startIdx, stopIdx) + + +def convert_frac(frac): + diff_op = False + partial_op = False + lower_itv = frac.lower.getSourceInterval() + lower_itv_len = lower_itv[1] - lower_itv[0] + 1 + if (frac.lower.start == frac.lower.stop and + frac.lower.start.type == PSLexer.DIFFERENTIAL): + wrt = get_differential_var_str(frac.lower.start.text) + diff_op = True + elif (lower_itv_len == 2 and + frac.lower.start.type == PSLexer.SYMBOL and + frac.lower.start.text == '\\partial' and + (frac.lower.stop.type == PSLexer.LETTER_NO_E or frac.lower.stop.type == PSLexer.SYMBOL)): + partial_op = True + wrt = frac.lower.stop.text + if frac.lower.stop.type == PSLexer.SYMBOL: + wrt = wrt[1:] + + if diff_op or partial_op: + wrt = sympy.Symbol(wrt, real=is_real) + if (diff_op and frac.upper.start == frac.upper.stop and + frac.upper.start.type == PSLexer.LETTER_NO_E and + frac.upper.start.text == 'd'): + return [wrt] + elif (partial_op and frac.upper.start == frac.upper.stop and + frac.upper.start.type == PSLexer.SYMBOL and + frac.upper.start.text == '\\partial'): + return [wrt] + upper_text = rule2text(frac.upper) + + expr_top = None + if diff_op and upper_text.startswith('d'): + expr_top = latex2sympy(upper_text[1:]) + elif partial_op and frac.upper.start.text == '\\partial': + expr_top = latex2sympy(upper_text[len('\\partial'):]) + if expr_top: + return sympy.Derivative(expr_top, wrt) + + expr_top = convert_expr(frac.upper) + expr_bot = convert_expr(frac.lower) + if expr_top.is_Matrix or expr_bot.is_Matrix: + return sympy.MatMul(expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False) + else: + return sympy.Mul(expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False) + + +def convert_binom(binom): + expr_top = convert_expr(binom.upper) + expr_bot = convert_expr(binom.lower) + return sympy.binomial(expr_top, expr_bot) + + +def convert_func(func): + if func.func_normal_single_arg(): + if func.L_PAREN(): # function called with parenthesis + arg = convert_func_arg(func.func_single_arg()) + else: + arg = convert_func_arg(func.func_single_arg_noparens()) + + name = func.func_normal_single_arg().start.text[1:] + + # change arc -> a + if name in ["arcsin", "arccos", "arctan", "arccsc", "arcsec", + "arccot"]: + name = "a" + name[3:] + expr = getattr(sympy.functions, name)(arg, evaluate=False) + elif name in ["arsinh", "arcosh", "artanh"]: + name = "a" + name[2:] + expr = getattr(sympy.functions, name)(arg, evaluate=False) + elif name in ["arcsinh", "arccosh", "arctanh"]: + name = "a" + name[3:] + expr = getattr(sympy.functions, name)(arg, evaluate=False) + elif name == "operatorname": + operatorname = func.func_normal_single_arg().func_operator_name.getText() + + if operatorname in ["arsinh", "arcosh", "artanh"]: + operatorname = "a" + operatorname[2:] + expr = getattr(sympy.functions, operatorname)(arg, evaluate=False) + elif operatorname in ["arcsinh", "arccosh", "arctanh"]: + operatorname = "a" + operatorname[3:] + expr = getattr(sympy.functions, operatorname)(arg, evaluate=False) + elif operatorname == "floor": + expr = handle_floor(arg) + elif operatorname == "ceil": + expr = handle_ceil(arg) + elif operatorname == 'eye': + expr = sympy.eye(arg) + elif operatorname == 'rank': + expr = sympy.Integer(arg.rank()) + elif operatorname in ['trace', 'tr']: + expr = arg.trace() + elif operatorname == 'rref': + expr = arg.rref()[0] + elif operatorname == 'nullspace': + expr = arg.nullspace() + elif operatorname == 'norm': + expr = arg.norm() + elif operatorname == 'cols': + expr = [arg.col(i) for i in range(arg.cols)] + elif operatorname == 'rows': + expr = [arg.row(i) for i in range(arg.rows)] + elif operatorname in ['eig', 'eigen', 'diagonalize']: + expr = arg.diagonalize() + elif operatorname in ['eigenvals', 'eigenvalues']: + expr = arg.eigenvals() + elif operatorname in ['eigenvects', 'eigenvectors']: + expr = arg.eigenvects() + elif operatorname in ['svd', 'SVD']: + expr = arg.singular_value_decomposition() + elif name in ["log", "ln"]: + if func.subexpr(): + if func.subexpr().atom(): + base = convert_atom(func.subexpr().atom()) + else: + base = convert_expr(func.subexpr().expr()) + elif name == "log": + base = 10 + elif name == "ln": + base = sympy.E + expr = sympy.log(arg, base, evaluate=False) + elif name in ["exp", "exponentialE"]: + expr = sympy.exp(arg) + elif name == "floor": + expr = handle_floor(arg) + elif name == "ceil": + expr = handle_ceil(arg) + elif name == 'det': + expr = arg.det() + + func_pow = None + should_pow = True + if func.supexpr(): + if func.supexpr().expr(): + func_pow = convert_expr(func.supexpr().expr()) + else: + func_pow = convert_atom(func.supexpr().atom()) + + if name in ["sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh", "tanh"]: + if func_pow == -1: + name = "a" + name + should_pow = False + expr = getattr(sympy.functions, name)(arg, evaluate=False) + + if func_pow and should_pow: + expr = sympy.Pow(expr, func_pow, evaluate=False) + + return expr + + elif func.func_normal_multi_arg(): + if func.L_PAREN(): # function called with parenthesis + args = func.func_multi_arg().getText().split(",") + else: + args = func.func_multi_arg_noparens().split(",") + + args = list(map(lambda arg: latex2sympy(arg, VARIABLE_VALUES), args)) + name = func.func_normal_multi_arg().start.text[1:] + + if name == "operatorname": + operatorname = func.func_normal_multi_arg().func_operator_name.getText() + if operatorname in ["gcd", "lcm"]: + expr = handle_gcd_lcm(operatorname, args) + elif operatorname == 'zeros': + expr = sympy.zeros(*args) + elif operatorname == 'ones': + expr = sympy.ones(*args) + elif operatorname == 'diag': + expr = sympy.diag(*args) + elif operatorname == 'hstack': + expr = sympy.Matrix.hstack(*args) + elif operatorname == 'vstack': + expr = sympy.Matrix.vstack(*args) + elif operatorname in ['orth', 'ortho', 'orthogonal', 'orthogonalize']: + if len(args) == 1: + arg = args[0] + expr = sympy.matrices.GramSchmidt([arg.col(i) for i in range(arg.cols)], True) + else: + expr = sympy.matrices.GramSchmidt(args, True) + elif name in ["gcd", "lcm"]: + expr = handle_gcd_lcm(name, args) + elif name in ["max", "min"]: + name = name[0].upper() + name[1:] + expr = getattr(sympy.functions, name)(*args, evaluate=False) + + func_pow = None + should_pow = True + if func.supexpr(): + if func.supexpr().expr(): + func_pow = convert_expr(func.supexpr().expr()) + else: + func_pow = convert_atom(func.supexpr().atom()) + + if func_pow and should_pow: + expr = sympy.Pow(expr, func_pow, evaluate=False) + + return expr + elif func.atom_expr_no_supexpr(): + # define a function + f = sympy.Function(func.atom_expr_no_supexpr().getText()) + # args + args = func.func_common_args().getText().split(",") + if args[-1] == '': + args = args[:-1] + args = [latex2sympy(arg, VARIABLE_VALUES) for arg in args] + # supexpr + if func.supexpr(): + if func.supexpr().expr(): + expr = convert_expr(func.supexpr().expr()) + else: + expr = convert_atom(func.supexpr().atom()) + return sympy.Pow(f(*args), expr, evaluate=False) + else: + return f(*args) + elif func.FUNC_INT(): + return handle_integral(func) + elif func.FUNC_SQRT(): + expr = convert_expr(func.base) + if func.root: + r = convert_expr(func.root) + return sympy.Pow(expr, 1 / r, evaluate=False) + else: + return sympy.Pow(expr, sympy.S.Half, evaluate=False) + elif func.FUNC_SUM(): + return handle_sum_or_prod(func, "summation") + elif func.FUNC_PROD(): + return handle_sum_or_prod(func, "product") + elif func.FUNC_LIM(): + return handle_limit(func) + elif func.EXP_E(): + return handle_exp(func) + + +def convert_func_arg(arg): + if hasattr(arg, 'expr'): + return convert_expr(arg.expr()) + else: + return convert_mp(arg.mp_nofunc()) + + +def handle_integral(func): + if func.additive(): + integrand = convert_add(func.additive()) + elif func.frac(): + integrand = convert_frac(func.frac()) + else: + integrand = 1 + + int_var = None + if func.DIFFERENTIAL(): + int_var = get_differential_var(func.DIFFERENTIAL()) + else: + for sym in integrand.atoms(sympy.Symbol): + s = str(sym) + if len(s) > 1 and s[0] == 'd': + if s[1] == '\\': + int_var = sympy.Symbol(s[2:], real=is_real) + else: + int_var = sympy.Symbol(s[1:], real=is_real) + int_sym = sym + if int_var: + integrand = integrand.subs(int_sym, 1) + else: + # Assume dx by default + int_var = sympy.Symbol('x', real=is_real) + + if func.subexpr(): + if func.subexpr().atom(): + lower = convert_atom(func.subexpr().atom()) + else: + lower = convert_expr(func.subexpr().expr()) + if func.supexpr().atom(): + upper = convert_atom(func.supexpr().atom()) + else: + upper = convert_expr(func.supexpr().expr()) + return sympy.Integral(integrand, (int_var, lower, upper)) + else: + return sympy.Integral(integrand, int_var) + + +def handle_sum_or_prod(func, name): + val = convert_mp(func.mp()) + iter_var = convert_expr(func.subeq().equality().expr(0)) + start = convert_expr(func.subeq().equality().expr(1)) + if func.supexpr().expr(): # ^{expr} + end = convert_expr(func.supexpr().expr()) + else: # ^atom + end = convert_atom(func.supexpr().atom()) + + if name == "summation": + return sympy.Sum(val, (iter_var, start, end)) + elif name == "product": + return sympy.Product(val, (iter_var, start, end)) + + +def handle_limit(func): + sub = func.limit_sub() + if sub.LETTER_NO_E(): + var = sympy.Symbol(sub.LETTER_NO_E().getText(), real=is_real) + elif sub.GREEK_CMD(): + var = sympy.Symbol(sub.GREEK_CMD().getText()[1:].strip(), real=is_real) + elif sub.OTHER_SYMBOL_CMD(): + var = sympy.Symbol(sub.OTHER_SYMBOL_CMD().getText().strip(), real=is_real) + else: + var = sympy.Symbol('x', real=is_real) + if sub.SUB(): + direction = "-" + else: + direction = "+" + approaching = convert_expr(sub.expr()) + content = convert_mp(func.mp()) + + return sympy.Limit(content, var, approaching, direction) + + +def handle_exp(func): + if func.supexpr(): + if func.supexpr().expr(): # ^{expr} + exp_arg = convert_expr(func.supexpr().expr()) + else: # ^atom + exp_arg = convert_atom(func.supexpr().atom()) + else: + exp_arg = 1 + return sympy.exp(exp_arg) + + +def handle_gcd_lcm(f, args): + """ + Return the result of gcd() or lcm(), as UnevaluatedExpr + + f: str - name of function ("gcd" or "lcm") + args: List[Expr] - list of function arguments + """ + + args = tuple(map(sympy.nsimplify, args)) + + # gcd() and lcm() don't support evaluate=False + return sympy.UnevaluatedExpr(getattr(sympy, f)(args)) + + +def handle_floor(expr): + """ + Apply floor() then return the floored expression. + + expr: Expr - sympy expression as an argument to floor() + """ + return sympy.functions.floor(expr, evaluate=False) + + +def handle_ceil(expr): + """ + Apply ceil() then return the ceil-ed expression. + + expr: Expr - sympy expression as an argument to ceil() + """ + return sympy.functions.ceiling(expr, evaluate=False) + + +def get_differential_var(d): + text = get_differential_var_str(d.getText()) + return sympy.Symbol(text, real=is_real) + + +def get_differential_var_str(text): + for i in range(1, len(text)): + c = text[i] + if not (c == " " or c == "\r" or c == "\n" or c == "\t"): + idx = i + break + text = text[idx:] + if text[0] == "\\": + text = text[1:] + return text + + +def latex(tex): + global frac_type + result = sympy.latex(tex) + result = result.replace(r'\frac', frac_type, -1).replace(r'\dfrac', frac_type, -1).replace(r'\tfrac', frac_type, -1) + result = result.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}', -1).replace(r'\end{matrix}\right]', r'\end{bmatrix}', -1) + result = result.replace(r'\left', r'', -1).replace(r'\right', r'', -1) + result = result.replace(r' )', r')', -1) + result = result.replace(r'\log', r'\ln', -1) + return result + + +def latex2latex(tex): + result = latex2sympy(tex) + # if result is a list or tuple or dict + if isinstance(result, list) or isinstance(result, tuple) or isinstance(result, dict): + return latex(result) + else: + return latex(simplify(result.subs(variances).doit().doit())) + + +# Set image value +latex2latex('i=I') +latex2latex('j=I') +# set Identity(i) +for i in range(1, 10): + lh = sympy.Symbol(r'\bm{I}_' + str(i), real=False) + lh_m = sympy.MatrixSymbol(r'\bm{I}_' + str(i), i, i) + rh = sympy.Identity(i).as_mutable() + variances[lh] = rh + variances[lh_m] = rh + var[str(lh)] = rh + +if __name__ == '__main__': + # latex2latex(r'A_1=\begin{bmatrix}1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8\end{bmatrix}') + # latex2latex(r'b_1=\begin{bmatrix}1 \\ 2 \\ 3 \\ 4\end{bmatrix}') + # tex = r"(x+2)|_{x=y+1}" + # tex = r"\operatorname{zeros}(3)" + tex = r"\operatorname{rows}(\begin{bmatrix}1 & 2 \\ 3 & 4\end{bmatrix})" + # print("latex2latex:", latex2latex(tex)) + math = latex2sympy(tex) + # math = math.subs(variances) + print("latex:", tex) + # print("var:", variances) + print("raw_math:", math) + # print("math:", latex(math.doit())) + # print("math_type:", type(math.doit())) + # print("shape:", (math.doit()).shape) + print("cal:", latex2latex(tex)) -- cgit v1.2.3