|
1 | 1 | import sys |
2 | 2 | import pandas as pd |
| 3 | +import tensorflow as tf |
3 | 4 | from tensorflow.keras.preprocessing.text import Tokenizer |
4 | 5 | from tensorflow.keras.preprocessing.sequence import pad_sequences |
5 | 6 | from tensorflow.keras.models import Sequential |
6 | 7 | from tensorflow.keras.layers import ( |
7 | 8 | Bidirectional, |
8 | 9 | Conv1D, |
9 | 10 | Dense, |
| 11 | + Dropout, |
10 | 12 | Embedding, |
11 | 13 | Flatten, |
12 | 14 | LSTM, |
13 | 15 | MaxPooling1D, |
14 | 16 | ) |
15 | | -from tensorflow.keras.metrics import Accuracy, Recall, Precision |
16 | 17 | from tensorflow.keras.callbacks import EarlyStopping |
17 | | -from sklearn.model_selection import train_test_split |
18 | | -from sklearn.metrics import ( |
19 | | - accuracy_score, |
20 | | - recall_score, |
21 | | - precision_score, |
22 | | - f1_score, |
23 | | - confusion_matrix, |
24 | | -) |
| 18 | +from sklearn.model_selection import KFold |
| 19 | +from sklearn.metrics import accuracy_score, precision_score, recall_score |
25 | 20 | import numpy as np |
26 | 21 | import matplotlib.pyplot as plt |
27 | 22 |
|
28 | 23 |
|
29 | | -# Check if the input file and output directory are provided |
30 | | -if len(sys.argv) != 3: |
31 | | - print("Usage: python train.py <input_file> <output_dir>") |
32 | | - sys.exit(1) |
33 | | - |
34 | | -# Load dataset |
35 | | -data = pd.read_csv(sys.argv[1]) |
36 | | - |
37 | | -# Define parameters |
38 | | -MAX_WORDS = 10000 |
39 | | -MAX_LEN = 100 |
40 | | - |
41 | | -# Use Tokenizer to encode text |
42 | | -tokenizer = Tokenizer(num_words=MAX_WORDS, filters="") |
43 | | -tokenizer.fit_on_texts(data["Query"]) |
44 | | -sequences = tokenizer.texts_to_sequences(data["Query"]) |
45 | | - |
46 | | -# Pad the text sequence |
47 | | -X = pad_sequences(sequences, maxlen=MAX_LEN) |
| 24 | +def load_data(file_path): |
| 25 | + """Load data from a CSV file.""" |
| 26 | + try: |
| 27 | + return pd.read_csv(file_path) |
| 28 | + except Exception as e: |
| 29 | + print(f"Error loading data: {e}") |
| 30 | + sys.exit(1) |
| 31 | + |
| 32 | + |
| 33 | +def preprocess_text(data, max_words=10000, max_len=100): |
| 34 | + """Tokenize and pad text data.""" |
| 35 | + tokenizer = Tokenizer(num_words=max_words, oov_token="<OOV>") |
| 36 | + tokenizer.fit_on_texts(data["Query"]) |
| 37 | + sequences = tokenizer.texts_to_sequences(data["Query"]) |
| 38 | + return pad_sequences(sequences, maxlen=max_len), tokenizer |
| 39 | + |
| 40 | + |
| 41 | +def build_model(input_dim, output_dim=128): |
| 42 | + """Define and compile the CNN-BiLSTM model.""" |
| 43 | + model = Sequential( |
| 44 | + [ |
| 45 | + Embedding(input_dim=input_dim, output_dim=output_dim), |
| 46 | + Dropout(0.2), |
| 47 | + Conv1D(filters=64, kernel_size=3, padding="same", activation="relu"), |
| 48 | + MaxPooling1D(pool_size=2), |
| 49 | + Bidirectional(LSTM(64, dropout=0.2, recurrent_dropout=0.2)), |
| 50 | + Flatten(), |
| 51 | + Dense(1, activation="sigmoid"), |
| 52 | + ] |
| 53 | + ) |
| 54 | + model.compile( |
| 55 | + loss="binary_crossentropy", |
| 56 | + optimizer="adam", |
| 57 | + metrics=[ |
| 58 | + "accuracy", |
| 59 | + tf.keras.metrics.Precision(name="precision"), |
| 60 | + tf.keras.metrics.Recall(name="recall"), |
| 61 | + ], |
| 62 | + ) |
| 63 | + return model |
| 64 | + |
| 65 | + |
| 66 | +def calculate_f1_f2(precision, recall, beta=1): |
| 67 | + """Calculate F1 or F2 score based on precision and recall with given beta.""" |
| 68 | + beta_squared = beta**2 |
| 69 | + return ( |
| 70 | + (1 + beta_squared) |
| 71 | + * (precision * recall) |
| 72 | + / (beta_squared * precision + recall + tf.keras.backend.epsilon()) |
| 73 | + ) |
48 | 74 |
|
49 | | -# Split the training set and test set |
50 | | -y = data["Label"] |
51 | | -X_train, X_test, y_train, y_test = train_test_split( |
52 | | - X, y, test_size=0.2, random_state=42 |
53 | | -) |
54 | | - |
55 | | -# Create CNN-BiLSTM model |
56 | | -model = Sequential() |
57 | | -model.add(Embedding(MAX_WORDS, 128)) |
58 | | -model.add(Conv1D(filters=64, kernel_size=3, padding="same", activation="relu")) |
59 | | -model.add(MaxPooling1D(pool_size=2)) |
60 | | -model.add(Bidirectional(LSTM(64, dropout=0.2, recurrent_dropout=0.2))) |
61 | | -model.add(Flatten()) |
62 | | -model.add(Dense(1, activation="sigmoid")) |
63 | | - |
64 | | -model.compile( |
65 | | - loss="binary_crossentropy", |
66 | | - optimizer="adam", |
67 | | - metrics=[ |
68 | | - Accuracy(), |
69 | | - Recall(), |
70 | | - Precision(), |
71 | | - ], |
72 | | -) |
73 | | - |
74 | | -# Define early stopping callback with a rollback of 5 |
75 | | -early_stopping = EarlyStopping( |
76 | | - monitor="val_loss", patience=5, restore_best_weights=True |
77 | | -) |
78 | 75 |
|
79 | | -# Train model with early stopping |
80 | | -history = model.fit( |
81 | | - X_train, |
82 | | - y_train, |
83 | | - epochs=50, # Maximum number of epochs |
84 | | - batch_size=32, |
85 | | - validation_data=(X_test, y_test), |
86 | | - callbacks=[early_stopping], |
87 | | - verbose=1, |
88 | | -) |
89 | | - |
90 | | -# Predict test set |
91 | | -y_pred = model.predict(X_test, verbose=1) |
92 | | -y_pred_classes = np.argmax(y_pred, axis=1) |
93 | | - |
94 | | -# Calculate model performance indicators |
95 | | -accuracy = accuracy_score(y_test, y_pred_classes) |
96 | | -recall = recall_score(y_test, y_pred_classes, zero_division=1) |
97 | | -precision = precision_score(y_test, y_pred_classes, zero_division=1) |
98 | | -f1 = f1_score(y_test, y_pred_classes, zero_division=1) |
99 | | -tn, fp, fn, tp = confusion_matrix(y_test, y_pred_classes).ravel() |
100 | | - |
101 | | -# Output performance indicators |
102 | | -print("Accuracy: {:.2f}%".format(accuracy * 100)) |
103 | | -print("Recall: {:.2f}%".format(recall * 100)) |
104 | | -print("Precision: {:.2f}%".format(precision * 100)) |
105 | | -print("F1-score: {:.2f}%".format(f1 * 100)) |
106 | | -print("Specificity: {:.2f}%".format(tn / (tn + fp) * 100)) |
107 | | -print("ROC: {:.2f}%".format(tp / (tp + fn) * 100)) |
108 | | - |
109 | | -# Save model as SavedModel format |
110 | | -model.export(sys.argv[2]) |
111 | | - |
112 | | - |
113 | | -# Plot the training history |
114 | 76 | def plot_history(history): |
| 77 | + """Plot the training and validation loss, accuracy, precision, and recall.""" |
115 | 78 | plt.figure(figsize=(12, 8)) |
116 | | - |
117 | | - # Plot loss |
118 | | - plt.subplot(2, 2, 1) |
119 | | - plt.plot(history.history["loss"], label="Training Loss") |
120 | | - plt.plot(history.history["val_loss"], label="Validation Loss") |
121 | | - plt.title("Loss") |
122 | | - plt.xlabel("Epochs") |
123 | | - plt.ylabel("Loss") |
124 | | - plt.legend() |
125 | | - |
126 | | - # Plot accuracy |
127 | | - plt.subplot(2, 2, 2) |
128 | | - plt.plot(history.history["accuracy"], label="Training Accuracy") |
129 | | - plt.plot(history.history["val_accuracy"], label="Validation Accuracy") |
130 | | - plt.title("Accuracy") |
131 | | - plt.xlabel("Epochs") |
132 | | - plt.ylabel("Accuracy") |
133 | | - plt.legend() |
134 | | - |
135 | | - # Plot precision |
136 | | - plt.subplot(2, 2, 3) |
137 | | - plt.plot(history.history["precision"], label="Training Precision") |
138 | | - plt.plot(history.history["val_precision"], label="Validation Precision") |
139 | | - plt.title("Precision") |
140 | | - plt.xlabel("Epochs") |
141 | | - plt.ylabel("Precision") |
142 | | - plt.legend() |
143 | | - |
144 | | - # Plot recall |
145 | | - plt.subplot(2, 2, 4) |
146 | | - plt.plot(history.history["recall"], label="Training Recall") |
147 | | - plt.plot(history.history["val_recall"], label="Validation Recall") |
148 | | - plt.title("Recall") |
149 | | - plt.xlabel("Epochs") |
150 | | - plt.ylabel("Recall") |
151 | | - plt.legend() |
152 | | - |
| 79 | + for i, metric in enumerate(["loss", "accuracy", "precision", "recall"], start=1): |
| 80 | + plt.subplot(2, 2, i) |
| 81 | + plt.plot(history.history[metric], label=f"Training {metric.capitalize()}") |
| 82 | + plt.plot( |
| 83 | + history.history[f"val_{metric}"], label=f"Validation {metric.capitalize()}" |
| 84 | + ) |
| 85 | + plt.title(metric.capitalize()) |
| 86 | + plt.xlabel("Epochs") |
| 87 | + plt.ylabel(metric.capitalize()) |
| 88 | + plt.legend() |
153 | 89 | plt.tight_layout() |
154 | 90 | plt.savefig("training_history.png") |
155 | 91 |
|
156 | 92 |
|
157 | | -plot_history(history) |
| 93 | +# Main function |
| 94 | +if __name__ == "__main__": |
| 95 | + if len(sys.argv) != 3: |
| 96 | + print("Usage: python train.py <input_file> <output_dir>") |
| 97 | + sys.exit(1) |
| 98 | + |
| 99 | + # Load and preprocess data |
| 100 | + data = load_data(sys.argv[1]) |
| 101 | + X, tokenizer = preprocess_text(data) |
| 102 | + y = data["Label"] |
| 103 | + |
| 104 | + # Initialize cross-validation |
| 105 | + k_folds = 5 |
| 106 | + kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42) |
| 107 | + fold_metrics = {"accuracy": [], "precision": [], "recall": [], "f1": [], "f2": []} |
| 108 | + |
| 109 | + for fold, (train_idx, val_idx) in enumerate(kfold.split(X, y), 1): |
| 110 | + print(f"Training fold {fold}/{k_folds}") |
| 111 | + |
| 112 | + # Split the data |
| 113 | + X_train, X_val = X[train_idx], X[val_idx] |
| 114 | + y_train, y_val = y.iloc[train_idx], y.iloc[val_idx] |
| 115 | + |
| 116 | + # Build and train the model |
| 117 | + model = build_model(input_dim=len(tokenizer.word_index) + 1) |
| 118 | + early_stopping = EarlyStopping( |
| 119 | + monitor="val_loss", patience=5, restore_best_weights=True |
| 120 | + ) |
| 121 | + history = model.fit( |
| 122 | + X_train, |
| 123 | + y_train, |
| 124 | + epochs=50, |
| 125 | + batch_size=32, |
| 126 | + validation_data=(X_val, y_val), |
| 127 | + callbacks=[early_stopping], |
| 128 | + verbose=1, |
| 129 | + ) |
| 130 | + |
| 131 | + # Make predictions to manually calculate metrics |
| 132 | + y_val_pred = (model.predict(X_val) > 0.5).astype(int) |
| 133 | + accuracy = accuracy_score(y_val, y_val_pred) |
| 134 | + precision = precision_score(y_val, y_val_pred) |
| 135 | + recall = recall_score(y_val, y_val_pred) |
| 136 | + f1_score = calculate_f1_f2(precision, recall, beta=1) |
| 137 | + f2_score = calculate_f1_f2(precision, recall, beta=2) |
| 138 | + |
| 139 | + # Collect fold metrics |
| 140 | + fold_metrics["accuracy"].append(accuracy) |
| 141 | + fold_metrics["precision"].append(precision) |
| 142 | + fold_metrics["recall"].append(recall) |
| 143 | + fold_metrics["f1"].append(f1_score) |
| 144 | + fold_metrics["f2"].append(f2_score) |
| 145 | + |
| 146 | + # Calculate average metrics across folds |
| 147 | + avg_metrics = {metric: np.mean(scores) for metric, scores in fold_metrics.items()} |
| 148 | + print("\nCross-validation results:") |
| 149 | + for metric, value in avg_metrics.items(): |
| 150 | + print(f"{metric.capitalize()}: {value:.2f}") |
| 151 | + |
| 152 | + # Save the final model trained on the last fold |
| 153 | + model.export(sys.argv[2]) |
| 154 | + plot_history(history) |
0 commit comments