Skip to content

Commit 7a9c098

Browse files
Merge pull request #73 from lambda-feedback/tr102-rtol-slow
Fixed similarly to compareExpressions
2 parents 44c324c + 0671d40 commit 7a9c098

File tree

3 files changed

+51
-24
lines changed

3 files changed

+51
-24
lines changed

app/evaluation.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -145,31 +145,48 @@ def check_equality(response, answer, params, eval_response) -> dict:
145145
eval_response.is_correct = ((res.args[0]-res.args[1])/(ans.args[0]-ans.args[1])).simplify().is_constant()
146146
return eval_response
147147

148-
error_below_atol = False
149-
error_below_rtol = False
150-
151-
if params.get("numerical", False) or params.get("rtol", False) or params.get("atol", False):
152-
# REMARK: 'pi' should be a reserve symbols but is sometimes not treated as one, possibly because of input symbols
153-
# The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
154-
# are other reserved symbols.
155-
ans = ans.subs(Symbol('pi'), float(pi))
156-
res = res.subs(Symbol('pi'), float(pi))
157-
if res.is_constant() and ans.is_constant():
148+
is_correct = bool((res - ans).simplify() == 0)
149+
eval_response.is_correct = is_correct
150+
151+
error_below_atol = None
152+
error_below_rtol = None
153+
154+
if eval_response.is_correct is False:
155+
if params.get("numerical", False) or params.get("rtol", False) or params.get("atol", False):
156+
# REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
157+
# The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
158+
# are other reserved symbols.
159+
def replace_pi(expr):
160+
pi_symbol = pi
161+
for s in expr.free_symbols:
162+
if str(s) == 'pi':
163+
pi_symbol = s
164+
return expr.subs(pi_symbol, float(pi))
165+
ans = replace_pi(ans)
166+
res = replace_pi(res)
158167
if "atol" in params.keys():
159-
error_below_atol = bool(abs(float(ans-res)) < float(params["atol"]))
168+
try:
169+
absolute_error = abs(float(ans-res))
170+
error_below_atol = bool(absolute_error < float(params["atol"]))
171+
except TypeError:
172+
error_below_atol = None
160173
else:
161174
error_below_atol = True
162175
if "rtol" in params.keys():
163-
rtol = float(params["rtol"])
164-
error_below_rtol = bool(float(abs(((ans-res)/ans).simplify())) < rtol)
176+
try:
177+
relative_error = abs(float((ans-res)/ans))
178+
error_below_rtol = bool(relative_error < float(params["rtol"]))
179+
except TypeError:
180+
error_below_rtol = None
165181
else:
166182
error_below_rtol = True
167-
if error_below_atol and error_below_rtol:
168-
eval_response.is_correct = True
169-
tag = "WITHIN_TOLERANCE"
170-
eval_response.add_feedback((tag, symbolic_equal_internal_messages[tag]))
171-
return eval_response
183+
if error_below_atol is None or error_below_rtol is None:
184+
eval_response.is_correct = False
185+
tag = "NOT_NUMERICAL"
186+
eval_response.add_feedback((tag, symbolic_equal_internal_messages[tag]))
187+
elif error_below_atol is True and error_below_rtol is True:
188+
eval_response.is_correct = True
189+
tag = "WITHIN_TOLERANCE"
190+
eval_response.add_feedback((tag, symbolic_equal_internal_messages[tag]))
172191

173-
is_correct = bool((res - ans).simplify() == 0)
174-
eval_response.is_correct = is_correct
175192
return eval_response

app/evaluation_tests.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,71 +548,81 @@ def test_empty_input_symbols_codes_and_alternatives(self):
548548
assert result["is_correct"] is True
549549

550550
@pytest.mark.parametrize(
551-
"description,response,answer,tolerance,outcome",
551+
"description,response,answer,tolerance,tags,outcome",
552552
[
553553
(
554554
"Correct response, tolerance specified with atol",
555555
"6.73",
556556
"sqrt(3)+5",
557557
{"atol": 0.005},
558+
["WITHIN_TOLERANCE"],
558559
True
559560
),
560561
(
561562
"Incorrect response, tolerance specified with atol",
562563
"6.7",
563564
"sqrt(3)+5",
564565
{"atol": 0.005},
566+
[],
565567
False
566568
),
567569
(
568570
"Correct response, tolerance specified with rtol",
569571
"6.73",
570572
"sqrt(3)+5",
571573
{"rtol": 0.0005},
574+
["WITHIN_TOLERANCE"],
572575
True
573576
),
574577
(
575578
"Incorrect response, tolerance specified with rtol",
576579
"6.7",
577580
"sqrt(3)+5",
578581
{"rtol": 0.0005},
582+
[],
579583
False
580584
),
581585
(
582586
"Response is not constant, tolerance specified with atol",
583587
"6.7+x",
584588
"sqrt(3)+5",
585589
{"atol": 0.005},
590+
["NOT_NUMERICAL"],
586591
False
587592
),
588593
(
589594
"Answer is not constant, tolerance specified with atol",
590595
"6.73",
591596
"sqrt(3)+x",
592597
{"atol": 0.005},
598+
["NOT_NUMERICAL"],
593599
False
594600
),
595601
(
596602
"Response is not constant, tolerance specified with rtol",
597603
"6.7+x",
598604
"sqrt(3)+5",
599605
{"rtol": 0.0005},
606+
["NOT_NUMERICAL"],
600607
False
601608
),
602609
(
603610
"Answer is not constant, tolerance specified with rtol",
604611
"6.73",
605612
"sqrt(3)+x",
606613
{"rtol": 0.0005},
614+
["NOT_NUMERICAL"],
607615
False
608616
),
609617
]
610618
)
611-
def test_numerical_comparison(self, description, response, answer, tolerance, outcome):
619+
def test_numerical_comparison_problem(self, description, response, answer, tolerance, tags, outcome):
612620
params = {"numerical": True}
613621
params.update(tolerance)
614-
result = evaluation_function(response, answer, params)
622+
result = evaluation_function(response, answer, params, include_test_data=True)
615623
assert result["is_correct"] is outcome
624+
for tag in tags:
625+
tag in result["tags"]
616626

617627
def test_warning_inappropriate_symbol(self):
618628
answer = 'factorial(2**4)'

app/feedback/symbolic_equal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111
"EXPRESSION_NOT_EQUALITY": "The response was an expression but was expected to be an equality.",
1212
"EQUALITY_NOT_EXPRESSION": "The response was an equality but was expected to be an expression.",
1313
"WITHIN_TOLERANCE": "", # "The difference between the response the answer is within specified error tolerance.",
14-
"SYMBOLICALLY_EQUAL": "The response and answer are symbolically equal.",
14+
"NOT_NUMERICAL": "", #"The expression cannot be evaluated numerically.",
1515
}

0 commit comments

Comments
 (0)