@@ -145,31 +145,48 @@ def check_equality(response, answer, params, eval_response) -> dict:
145145 eval_response .is_correct = ((res .args [0 ]- res .args [1 ])/ (ans .args [0 ]- ans .args [1 ])).simplify ().is_constant ()
146146 return eval_response
147147
148- error_below_atol = False
149- error_below_rtol = False
150-
151- if params .get ("numerical" , False ) or params .get ("rtol" , False ) or params .get ("atol" , False ):
152- # REMARK: 'pi' should be a reserve symbols but is sometimes not treated as one, possibly because of input symbols
153- # The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
154- # are other reserved symbols.
155- ans = ans .subs (Symbol ('pi' ), float (pi ))
156- res = res .subs (Symbol ('pi' ), float (pi ))
157- if res .is_constant () and ans .is_constant ():
148+ is_correct = bool ((res - ans ).simplify () == 0 )
149+ eval_response .is_correct = is_correct
150+
151+ error_below_atol = None
152+ error_below_rtol = None
153+
154+ if eval_response .is_correct is False :
155+ if params .get ("numerical" , False ) or params .get ("rtol" , False ) or params .get ("atol" , False ):
156+ # REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
157+ # The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
158+ # are other reserved symbols.
159+ def replace_pi (expr ):
160+ pi_symbol = pi
161+ for s in expr .free_symbols :
162+ if str (s ) == 'pi' :
163+ pi_symbol = s
164+ return expr .subs (pi_symbol , float (pi ))
165+ ans = replace_pi (ans )
166+ res = replace_pi (res )
158167 if "atol" in params .keys ():
159- error_below_atol = bool (abs (float (ans - res )) < float (params ["atol" ]))
168+ try :
169+ absolute_error = abs (float (ans - res ))
170+ error_below_atol = bool (absolute_error < float (params ["atol" ]))
171+ except TypeError :
172+ error_below_atol = None
160173 else :
161174 error_below_atol = True
162175 if "rtol" in params .keys ():
163- rtol = float (params ["rtol" ])
164- error_below_rtol = bool (float (abs (((ans - res )/ ans ).simplify ())) < rtol )
176+ try :
177+ relative_error = abs (float ((ans - res )/ ans ))
178+ error_below_rtol = bool (relative_error < float (params ["rtol" ]))
179+ except TypeError :
180+ error_below_rtol = None
165181 else :
166182 error_below_rtol = True
167- if error_below_atol and error_below_rtol :
168- eval_response .is_correct = True
169- tag = "WITHIN_TOLERANCE"
170- eval_response .add_feedback ((tag , symbolic_equal_internal_messages [tag ]))
171- return eval_response
183+ if error_below_atol is None or error_below_rtol is None :
184+ eval_response .is_correct = False
185+ tag = "NOT_NUMERICAL"
186+ eval_response .add_feedback ((tag , symbolic_equal_internal_messages [tag ]))
187+ elif error_below_atol is True and error_below_rtol is True :
188+ eval_response .is_correct = True
189+ tag = "WITHIN_TOLERANCE"
190+ eval_response .add_feedback ((tag , symbolic_equal_internal_messages [tag ]))
172191
173- is_correct = bool ((res - ans ).simplify () == 0 )
174- eval_response .is_correct = is_correct
175192 return eval_response
0 commit comments