Skip to content

Commit 99f4d1c

Browse files
committed
Don't scale the values
Update prediction scores
1 parent b9faf5c commit 99f4d1c

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

training/test_train.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from tensorflow.keras.preprocessing.text import Tokenizer
44
from tensorflow.keras.preprocessing.sequence import pad_sequences
55
from tensorflow.keras.layers import TFSMLayer
6-
import numpy as np
76

87

98
MAX_WORDS = 10000
@@ -34,21 +33,20 @@
3433
)
3534
def model(request):
3635
# Load dataset
37-
data = None
3836
prefix = ""
37+
data = None
3938
try:
4039
data = pd.read_csv(request.param["dataset"])
4140
except FileNotFoundError:
4241
# Check if the dataset is in the parent directory
4342
prefix = "../"
4443
data = pd.read_csv(prefix + request.param["dataset"])
4544

46-
# Load TF model from SavedModel
47-
sqli_model = TFSMLayer(
48-
prefix + request.param["model_path"], call_endpoint="serving_default"
49-
)
45+
# Load TF model using TFSMLayer with the serving_default endpoint
46+
model_path = prefix + request.param["model_path"]
47+
sqli_model = TFSMLayer(model_path, call_endpoint="serving_default")
5048

51-
# Tokenize the sample
49+
# Tokenizer setup
5250
tokenizer = Tokenizer(num_words=MAX_WORDS, filters="")
5351
tokenizer.fit_on_texts(data["Query"])
5452

@@ -62,18 +60,18 @@ def model(request):
6260
@pytest.mark.parametrize(
6361
"sample",
6462
[
65-
("select * from users where id=1 or 1=1;", [99.99, 97.40, 11.96]),
66-
("select * from users where id='1' or 1=1--", [92.02, 97.40, 11.96]),
67-
("select * from users", [0.077, 0.015, 0.002]),
68-
("select * from users where id=10000", [14.83, 88.93, 0.229]),
69-
("select '1' union select 'a'; -- -'", [99.99, 97.32, 99.97]),
63+
("select * from users where id=1 or 1=1;", [0.9202, 0.974, 0.0022]),
64+
("select * from users where id='1' or 1=1--", [0.9202, 0.974, 0.0022]),
65+
("select * from users", [0.00077, 0.0015, 0.0231]),
66+
("select * from users where id=10000", [0.1483, 0.8893, 0.0008]),
67+
("select '1' union select 'a'; -- -'", [0.9999, 0.9732, 0.0139]),
7068
(
7169
"select '' union select 'malicious php code' \\g /var/www/test.php; -- -';",
72-
[99.99, 80.65, 99.98],
70+
[0.9999, 0.8065, 0.0424],
7371
),
7472
(
7573
"select '' || pg_sleep((ascii((select 'a' limit 1)) - 32) / 2); -- -';",
76-
[99.99, 99.99, 99.93],
74+
[0.9999, 0.9999, 0.01543],
7775
),
7876
],
7977
)
@@ -85,12 +83,12 @@ def test_sqli_model(model, sample):
8583
# Predict sample
8684
predictions = model["sqli_model"](sample_vec)
8785

88-
# Scale up to 100
89-
output = "dense"
90-
if "output_0" in predictions:
91-
output = "output_0" # Model v2 and v3 use output_0 instead of dense
86+
# Extract the prediction result
87+
output_key = "output_0" if "output_0" in predictions else "dense"
88+
predicted_value = predictions[output_key].numpy()[0][0]
9289

93-
print(predictions[output].numpy() * 100) # Debugging purposes (prints on error)
94-
assert predictions[output].numpy() * 100 == pytest.approx(
95-
np.array([[sample[1][model["index"]]]]), 0.1
90+
print(
91+
f"Predicted: {predicted_value:.4f}, Expected: {sample[1][model['index']]:.4f}"
9692
)
93+
94+
assert predicted_value == pytest.approx(sample[1][model["index"]], abs=0.05)

0 commit comments

Comments
 (0)