Skip to content

Commit ee21f63

Browse files
Merge pull request #75 from lambda-feedback/tr111-improve-input-sanitisation-handwritten-preview
Added input sanitisation that removes \mathrm and \text from handwrit…
2 parents 3c14937 + be43ad5 commit ee21f63

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

app/preview.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,40 @@ class Preview(TypedDict):
3636
class Result(TypedDict):
3737
preview: Preview
3838

39+
def find_matching_parenthesis(string, index, delimiters=None):
40+
depth = 0
41+
if delimiters == None:
42+
delimiters = ('(', ')')
43+
for k in range(index, len(string)):
44+
if string[k] == delimiters[0]:
45+
depth += 1
46+
continue
47+
if string[k] == delimiters[1]:
48+
depth += -1
49+
if depth == 0:
50+
return k
51+
return -1
52+
53+
def sanitise_latex(response):
54+
response = "".join(response.split())
55+
response = response.replace('~',' ')
56+
wrappers = [r"\mathrm",r"\text"]
57+
for wrapper in wrappers:
58+
processed_response = []
59+
index = 0
60+
while index < len(response):
61+
wrapper_start = response.find(wrapper+"{", index)
62+
if wrapper_start > -1:
63+
processed_response.append(response[index:wrapper_start])
64+
wrapper_end = find_matching_parenthesis(response, wrapper_start+1, delimiters=('{','}'))
65+
inside_wrapper = response[(wrapper_start+len(wrapper+"{")):wrapper_end]
66+
processed_response.append(inside_wrapper)
67+
index = wrapper_end+1
68+
else:
69+
processed_response.append(response[index:])
70+
index = len(response)
71+
response = "".join(processed_response)
72+
return response
3973

4074
def parse_latex(response: str, symbols: SymbolDict) -> str:
4175
"""Parse a LaTeX string to a sympy string while preserving custom symbols.
@@ -53,6 +87,8 @@ def parse_latex(response: str, symbols: SymbolDict) -> str:
5387
"""
5488
substitutions = {}
5589

90+
response = sanitise_latex(response)
91+
5692
for sympy_symbol_str in symbols:
5793
symbol_str = symbols[sympy_symbol_str]["latex"]
5894
latex_symbol_str = extract_latex(symbol_str)
@@ -78,7 +114,6 @@ def parse_latex(response: str, symbols: SymbolDict) -> str:
78114
except Exception as e:
79115
raise ValueError(str(e))
80116

81-
82117
def parse_symbolic(response: str, params):
83118
response_list_in = create_expression_set(response, params)
84119
response_list_out = []

app/preview_tests.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -127,26 +127,6 @@ def test_sympy_conversion_preserves_default_symbols(self):
127127
preview = result["preview"]
128128
assert preview.get("latex") == "\\mu + x + 1"
129129

130-
def test_latex_conversion_preserves_optional_symbols(self):
131-
response = "m_{ \\text{table} } + \\text{hello}_\\text{world} - x + 1"
132-
params = Params(
133-
is_latex=True,
134-
simplify=False,
135-
symbols={
136-
"m_table": {
137-
"latex": r"hello \( m_{\text{table}} \) world",
138-
"aliases": [],
139-
},
140-
"test": {
141-
"latex": r"hello $ \text{hello}_\text{world} $ world.",
142-
"aliases": [],
143-
},
144-
},
145-
)
146-
result = preview_function(response, params)
147-
preview = result["preview"]
148-
assert preview.get("sympy") == "m_table + test - x + 1"
149-
150130
def test_sympy_conversion_preserves_optional_symbols(self):
151131
response = "m_table + test + x + 1"
152132
params = Params(

0 commit comments

Comments
 (0)