Skip to content

Commit be43ad5

Browse files
Added input sanitisation that removes \mathrm and \text from handwritten input.
Note that this means that \mathrm and \text cannot be used for input symbols
1 parent 3c14937 commit be43ad5

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

app/preview.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,40 @@ class Preview(TypedDict):
3636
class Result(TypedDict):
3737
preview: Preview
3838

39+
def find_matching_parenthesis(string, index, delimiters=None):
40+
depth = 0
41+
if delimiters == None:
42+
delimiters = ('(', ')')
43+
for k in range(index, len(string)):
44+
if string[k] == delimiters[0]:
45+
depth += 1
46+
continue
47+
if string[k] == delimiters[1]:
48+
depth += -1
49+
if depth == 0:
50+
return k
51+
return -1
52+
53+
def sanitise_latex(response):
54+
response = "".join(response.split())
55+
response = response.replace('~',' ')
56+
wrappers = [r"\mathrm",r"\text"]
57+
for wrapper in wrappers:
58+
processed_response = []
59+
index = 0
60+
while index < len(response):
61+
wrapper_start = response.find(wrapper+"{", index)
62+
if wrapper_start > -1:
63+
processed_response.append(response[index:wrapper_start])
64+
wrapper_end = find_matching_parenthesis(response, wrapper_start+1, delimiters=('{','}'))
65+
inside_wrapper = response[(wrapper_start+len(wrapper+"{")):wrapper_end]
66+
processed_response.append(inside_wrapper)
67+
index = wrapper_end+1
68+
else:
69+
processed_response.append(response[index:])
70+
index = len(response)
71+
response = "".join(processed_response)
72+
return response
3973

4074
def parse_latex(response: str, symbols: SymbolDict) -> str:
4175
"""Parse a LaTeX string to a sympy string while preserving custom symbols.
@@ -53,6 +87,8 @@ def parse_latex(response: str, symbols: SymbolDict) -> str:
5387
"""
5488
substitutions = {}
5589

90+
response = sanitise_latex(response)
91+
5692
for sympy_symbol_str in symbols:
5793
symbol_str = symbols[sympy_symbol_str]["latex"]
5894
latex_symbol_str = extract_latex(symbol_str)
@@ -78,7 +114,6 @@ def parse_latex(response: str, symbols: SymbolDict) -> str:
78114
except Exception as e:
79115
raise ValueError(str(e))
80116

81-
82117
def parse_symbolic(response: str, params):
83118
response_list_in = create_expression_set(response, params)
84119
response_list_out = []

app/preview_tests.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -127,26 +127,6 @@ def test_sympy_conversion_preserves_default_symbols(self):
127127
preview = result["preview"]
128128
assert preview.get("latex") == "\\mu + x + 1"
129129

130-
def test_latex_conversion_preserves_optional_symbols(self):
131-
response = "m_{ \\text{table} } + \\text{hello}_\\text{world} - x + 1"
132-
params = Params(
133-
is_latex=True,
134-
simplify=False,
135-
symbols={
136-
"m_table": {
137-
"latex": r"hello \( m_{\text{table}} \) world",
138-
"aliases": [],
139-
},
140-
"test": {
141-
"latex": r"hello $ \text{hello}_\text{world} $ world.",
142-
"aliases": [],
143-
},
144-
},
145-
)
146-
result = preview_function(response, params)
147-
preview = result["preview"]
148-
assert preview.get("sympy") == "m_table + test - x + 1"
149-
150130
def test_sympy_conversion_preserves_optional_symbols(self):
151131
response = "m_table + test + x + 1"
152132
params = Params(

0 commit comments

Comments
 (0)