|
16 | 16 | import re |
17 | 17 | from typing import Dict, List, TypedDict |
18 | 18 |
|
| 19 | +class ModifiedLatexPrinter(LatexPrinter): |
| 20 | + """Modified LatexPrinter class that prints logarithms other than the natural logarithm correctly. |
| 21 | + """ |
| 22 | + def _print_log(self, expr, exp=None): |
| 23 | + if self._settings["ln_notation"] and len(expr.args) < 2: |
| 24 | + log_not = r"\ln" |
| 25 | + else: |
| 26 | + log_not = r"\log" |
| 27 | + if len(expr.args) > 1: |
| 28 | + base = self._print(expr.args[1]) |
| 29 | + log_not = r"\log_{%s}" % base |
| 30 | + tex = r"%s{\left(%s \right)}" % (log_not, self._print(expr.args[0])) |
| 31 | + |
| 32 | + if exp is not None: |
| 33 | + return r"%s^{%s}" % (tex, exp) |
| 34 | + else: |
| 35 | + return tex |
| 36 | + |
19 | 37 | elementary_functions_names = [ |
20 | 38 | ('sin', []), ('sinc', []), ('csc', ['cosec']), ('cos', []), ('sec', []), ('tan', []), ('cot', ['cotan']), |
21 | 39 | ('asin', ['arcsin']), ('acsc', ['arccsc', 'arccosec', 'acosec']), ('acos', ['arccos']), ('asec', ['arcsec']), |
22 | 40 | ('atan', ['arctan']), ('acot', ['arccot', 'arccotan', 'acotan']), ('atan2', ['arctan2']), |
23 | 41 | ('sinh', []), ('cosh', []), ('tanh', []), ('csch', ['cosech']), ('sech', []), |
24 | 42 | ('asinh', ['arcsinh']), ('acosh', ['arccosh']), ('atanh', ['arctanh']), |
25 | 43 | ('acsch', ['arccsch', 'arccosech']), ('asech', ['arcsech']), |
26 | | - ('exp', ['Exp']), ('E', ['e']), ('log', []), |
| 44 | + ('exp', ['Exp']), ('E', ['e']), ('log', ['ln']), |
27 | 45 | ('sqrt', []), ('sign', []), ('Abs', ['abs']), ('Max', ['max']), ('Min', ['min']), ('arg', []), ('ceiling', ['ceil']), ('floor', []), |
28 | 46 | # Below this line should probably not be collected with elementary functions. Some like 'common operations' would be a better name |
29 | 47 | ('summation', ['sum','Sum']), ('Derivative', ['diff']), |
@@ -453,10 +471,18 @@ def latex_symbols(symbols): |
453 | 471 | return symbol_dict |
454 | 472 |
|
455 | 473 |
|
456 | | -def sympy_to_latex(equation, symbols): |
457 | | - latex_out = LatexPrinter( |
458 | | - {"symbol_names": latex_symbols(symbols)} |
459 | | - ).doprint(equation) |
| 474 | +def sympy_to_latex(equation, symbols, settings=None): |
| 475 | + default_settings = { |
| 476 | + "symbol_names": latex_symbols(symbols), |
| 477 | + "ln_notation": True, |
| 478 | + } |
| 479 | + if settings is None: |
| 480 | + settings = default_settings |
| 481 | + else: |
| 482 | + for key in default_settings.keys(): |
| 483 | + if key not in settings.keys(): |
| 484 | + settings[key] = default_settings[key] |
| 485 | + latex_out = ModifiedLatexPrinter(settings).doprint(equation) |
460 | 486 | return latex_out |
461 | 487 |
|
462 | 488 |
|
|
0 commit comments