Skip to content

Commit 8852041

Browse files
committed
Tokenize and predict in a single step (deprecates Serving API)
1 parent 9ccca91 commit 8852041

File tree

1 file changed

+37
-21
lines changed

1 file changed

+37
-21
lines changed

api/api.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,57 @@
1-
import os
2-
3-
import pandas as pd
41
from flask import Flask, jsonify, request
2+
import tensorflow as tf
53
from tensorflow.keras.preprocessing.sequence import pad_sequences
64
from tensorflow.keras.preprocessing.text import Tokenizer
5+
import pandas as pd
6+
import os
77

8+
app = Flask(__name__)


# --- Configuration -----------------------------------------------------------
# Tokenizer vocabulary cap and fixed sequence length; both must match the
# values the model was trained with.
MAX_WORDS = 10000
MAX_LEN = 100
# Paths are overridable via environment for container deployments.
DATASET_PATH = os.getenv("DATASET_PATH", "dataset/sqli_dataset1.csv")
MODEL_PATH = os.getenv("MODEL_PATH", "/app/sqli_model/3/")

# --- Tokenizer ---------------------------------------------------------------
# The tokenizer is re-fitted on the training corpus at startup so that
# incoming queries are vectorized exactly as the model saw them in training.
DATASET = pd.read_csv(DATASET_PATH)
TOKENIZER = Tokenizer(num_words=MAX_WORDS, filters="")
TOKENIZER.fit_on_texts(DATASET["Query"])

# --- Model -------------------------------------------------------------------
# Load the SavedModel once at startup and keep its default serving signature;
# calling the signature directly replaces the deprecated Serving API.
loaded_model = tf.saved_model.load(MODEL_PATH)
model_predict = loaded_model.signatures["serving_default"]
1925

26+
@app.route("/predict", methods=["POST"])
def predict():
    """Tokenize the posted query and score it with the SavedModel signature.

    Expects a JSON body ``{"query": "<sql text>"}``.

    Returns:
        200 with ``{"confidence": <float>}`` (rounded to 4 decimal places),
        400 with ``{"error": ...}`` when no usable query is supplied,
        500 with ``{"error": ...}`` on an unexpected model/output failure.
    """
    # silent=True makes malformed JSON or a wrong Content-Type yield None
    # instead of raising werkzeug's BadRequest, so every bad request gets
    # this handler's own 400 payload rather than a framework error page.
    body = request.get_json(silent=True)
    if not body or "query" not in body:
        return jsonify({"error": "No query provided"}), 400

    try:
        # Vectorize exactly as during training: tokenize with the
        # corpus-fitted TOKENIZER, then pad/truncate to MAX_LEN.
        query_seq = TOKENIZER.texts_to_sequences([body["query"]])
        query_vec = pad_sequences(query_seq, maxlen=MAX_LEN)

        # The serving signature expects a float32 tensor input.
        input_tensor = tf.convert_to_tensor(query_vec, dtype=tf.float32)

        # Call the SavedModel's serving signature directly
        # (this commit deprecates the separate Serving API).
        prediction = model_predict(input_tensor)

        # Sanity-check the signature output: one scalar score per request.
        if "output_0" not in prediction or prediction["output_0"].get_shape() != [1, 1]:
            return jsonify({"error": "Invalid model output"}), 500

        return jsonify(
            {
                "confidence": float("%.4f" % prediction["output_0"].numpy()[0][0]),
            }
        )
    except Exception as e:
        # Log the full traceback server-side; the client response keeps the
        # same shape as before (error message + 500).
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500
3754

3855

3956
if __name__ == "__main__":
    # Never enable the Werkzeug debugger on a publicly bound interface:
    # with debug=True, an unhandled exception serves an interactive console
    # (arbitrary code execution). Opt in explicitly with FLASK_DEBUG=1 for
    # local development only.
    debug = os.getenv("FLASK_DEBUG", "0") == "1"
    app.run(host="0.0.0.0", port=8000, debug=debug)

0 commit comments

Comments
 (0)