1 | | -import os |
2 | | - |
3 | | -import pandas as pd |
4 | 1 | from flask import Flask, jsonify, request |
| 2 | +import tensorflow as tf |
5 | 3 | from tensorflow.keras.preprocessing.sequence import pad_sequences |
6 | 4 | from tensorflow.keras.preprocessing.text import Tokenizer |
| 5 | +import pandas as pd |
| 6 | +import os |
7 | 7 |
| 8 | +app = Flask(__name__) |
| 9 | + |
| 10 | +# Constants and configurations |
8 | 11 | MAX_WORDS = 10000 |
9 | 12 | MAX_LEN = 100 |
10 | 13 | DATASET_PATH = os.getenv("DATASET_PATH", "dataset/sqli_dataset1.csv") |
| 14 | +MODEL_PATH = os.getenv("MODEL_PATH", "/app/sqli_model/3/") |
11 | 15 | DATASET = pd.read_csv(DATASET_PATH) |
| 16 | + |
| 17 | +# Tokenizer setup |
12 | 18 | TOKENIZER = Tokenizer(num_words=MAX_WORDS, filters="") |
13 | 19 | TOKENIZER.fit_on_texts(DATASET["Query"]) |
14 | | -CONFIG = {"DEBUG": False} |
15 | 20 |
| 21 | +# Load the model using tf.saved_model.load and get the serving signature |
| 22 | +loaded_model = tf.saved_model.load(MODEL_PATH) |
| 23 | +model_predict = loaded_model.signatures["serving_default"] |
16 | 24 |
17 | | -app = Flask(__name__) |
18 | | -app.config.from_mapping(CONFIG) |
19 | 25 |
| 26 | +@app.route("/predict", methods=["POST"]) |
| 27 | +def predict(): |
| 28 | + if not request.json or "query" not in request.json: |
| 29 | + return jsonify({"error": "No query provided"}), 400 |
| 30 | + |
| 31 | + try: |
| 32 | + # Tokenize and pad the input query |
| 33 | + query = request.json["query"] |
| 34 | + query_seq = TOKENIZER.texts_to_sequences([query]) |
| 35 | + query_vec = pad_sequences(query_seq, maxlen=MAX_LEN) |
20 | 36 |
21 | | -@app.route("/tokenize_and_sequence", methods=["POST"]) |
22 | | -def tokenize_and_sequence(): |
23 | | - """Tokenize and sequence the input query from the request |
24 | | - and return the vectorized output. |
25 | | - """ |
| 37 | + # Convert input to tensor |
| 38 | + input_tensor = tf.convert_to_tensor(query_vec, dtype=tf.float32) |
26 | 39 |
27 | | - body = request.get_json() |
28 | | - if not body: |
29 | | - return jsonify({"error": "No JSON body provided"}), 400 |
| 40 | + # Use the loaded model's serving signature to make the prediction |
| 41 | + prediction = model_predict(input_tensor) |
30 | 42 |
31 | | - # Vectorize the sample |
32 | | - query_seq = TOKENIZER.texts_to_sequences([body["query"]]) |
33 | | - query_vec = pad_sequences(query_seq, maxlen=MAX_LEN) |
| 43 | + if "output_0" not in prediction or prediction["output_0"].get_shape() != [1, 1]: |
| 44 | + return jsonify({"error": "Invalid model output"}), 500 |
34 | 45 |
35 | | - tokens = query_vec.tolist() |
36 | | - return jsonify({"tokens": tokens[0]}) |
| 46 | + return jsonify( |
| 47 | + { |
| 48 | + "confidence": float("%.4f" % prediction["output_0"].numpy()[0][0]), |
| 49 | + } |
| 50 | + ) |
| 51 | + except Exception as e: |
| 52 | + # TODO: Log the error and return a proper error message |
| 53 | + return jsonify({"error": str(e)}), 500 |
37 | 54 |
38 | 55 |
39 | 56 | if __name__ == "__main__": |
40 | | - # Run the app in debug mode |
41 | | - app.run(host="localhost", port=8000, debug=True) |
| 57 | + app.run(host="0.0.0.0", port=8000, debug=True) |
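
For reference, a minimal client call against the new `/predict` route might look like the sketch below (a smoke test, not part of this diff; the URL, port, and sample query are assumptions):

```python
import requests

# Hypothetical smoke test for the new /predict endpoint.
resp = requests.post(
    "http://localhost:8000/predict",
    json={"query": "SELECT * FROM users WHERE id=1 OR 1=1"},
)
print(resp.status_code)  # 200 on success, 400 if "query" is missing
print(resp.json())       # e.g. {"confidence": 0.9876}
```

Since the route hardcodes the `serving_default` signature and the `output_0` key, it can help to verify those names against the exported model before pointing `MODEL_PATH` at a new version; a quick check, assuming the same SavedModel directory used above:

```python
import tensorflow as tf

loaded = tf.saved_model.load("/app/sqli_model/3/")
infer = loaded.signatures["serving_default"]
print(infer.structured_input_signature)  # expected input TensorSpec(s)
print(infer.structured_outputs)          # output tensor names, e.g. "output_0"
```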