33from tensorflow .keras .preprocessing .text import Tokenizer
44from tensorflow .keras .preprocessing .sequence import pad_sequences
55from tensorflow .keras .layers import TFSMLayer
6- import numpy as np
76
87
98MAX_WORDS = 10000
3433)
3534def model (request ):
3635 # Load dataset
37- data = None
3836 prefix = ""
37+ data = None
3938 try :
4039 data = pd .read_csv (request .param ["dataset" ])
4140 except FileNotFoundError :
4241 # Check if the dataset is in the parent directory
4342 prefix = "../"
4443 data = pd .read_csv (prefix + request .param ["dataset" ])
4544
46- # Load TF model from SavedModel
47- sqli_model = TFSMLayer (
48- prefix + request .param ["model_path" ], call_endpoint = "serving_default"
49- )
45+ # Load TF model using TFSMLayer with the serving_default endpoint
46+ model_path = prefix + request .param ["model_path" ]
47+ sqli_model = TFSMLayer (model_path , call_endpoint = "serving_default" )
5048
51- # Tokenize the sample
49+ # Tokenizer setup
5250 tokenizer = Tokenizer (num_words = MAX_WORDS , filters = "" )
5351 tokenizer .fit_on_texts (data ["Query" ])
5452
@@ -62,18 +60,18 @@ def model(request):
6260@pytest .mark .parametrize (
6361 "sample" ,
6462 [
65- ("select * from users where id=1 or 1=1;" , [99.99 , 97.40 , 11.96 ]),
66- ("select * from users where id='1' or 1=1--" , [92.02 , 97.40 , 11.96 ]),
67- ("select * from users" , [0.077 , 0.015 , 0.002 ]),
68- ("select * from users where id=10000" , [14.83 , 88.93 , 0.229 ]),
69- ("select '1' union select 'a'; -- -'" , [99.99 , 97.32 , 99.97 ]),
63+ ("select * from users where id=1 or 1=1;" , [0.9202 , 0.974 , 0.0022 ]),
64+ ("select * from users where id='1' or 1=1--" , [0.9202 , 0.974 , 0.0022 ]),
65+ ("select * from users" , [0.00077 , 0.0015 , 0.0231 ]),
66+ ("select * from users where id=10000" , [0.1483 , 0.8893 , 0.0008 ]),
67+ ("select '1' union select 'a'; -- -'" , [0.9999 , 0.9732 , 0.0139 ]),
7068 (
7169 "select '' union select 'malicious php code' \\ g /var/www/test.php; -- -';" ,
72- [99.99 , 80.65 , 99.98 ],
70+ [0.9999 , 0.8065 , 0.0424 ],
7371 ),
7472 (
7573 "select '' || pg_sleep((ascii((select 'a' limit 1)) - 32) / 2); -- -';" ,
76- [99.99 , 99.99 , 99.93 ],
74+ [0.9999 , 0.9999 , 0.01543 ],
7775 ),
7876 ],
7977)
@@ -85,12 +83,12 @@ def test_sqli_model(model, sample):
8583 # Predict sample
8684 predictions = model ["sqli_model" ](sample_vec )
8785
88- # Scale up to 100
89- output = "dense"
90- if "output_0" in predictions :
91- output = "output_0" # Model v2 and v3 use output_0 instead of dense
86+ # Extract the prediction result
87+ output_key = "output_0" if "output_0" in predictions else "dense"
88+ predicted_value = predictions [output_key ].numpy ()[0 ][0 ]
9289
93- print (predictions [output ].numpy () * 100 ) # Debugging purposes (prints on error)
94- assert predictions [output ].numpy () * 100 == pytest .approx (
95- np .array ([[sample [1 ][model ["index" ]]]]), 0.1
90+ print (
91+ f"Predicted: { predicted_value :.4f} , Expected: { sample [1 ][model ['index' ]]:.4f} "
9692 )
93+
94+ assert predicted_value == pytest .approx (sample [1 ][model ["index" ]], abs = 0.05 )
0 commit comments