
Commit d78f3b6

Revamp training v3 script
Add cross-validation with KFold.
Add F1 and F2 score metrics.
Fix training metrics.
1 parent d766243 commit d78f3b6
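
The F1 and F2 scores added here are both instances of the F-beta score: F1 (beta = 1) is the harmonic mean of precision and recall, while F2 (beta = 2) weights recall more heavily than precision. Below is a minimal sketch of the formula the new calculate_f1_f2 helper implements, cross-checked against sklearn's fbeta_score on made-up labels (the toy data and the sklearn comparison are illustrative only, not part of this commit):

from sklearn.metrics import precision_score, recall_score, fbeta_score


def f_beta(precision, recall, beta=1.0, eps=1e-7):
    """F-beta = (1 + beta^2) * P * R / (beta^2 * P + R); beta=1 gives F1, beta=2 gives F2."""
    b2 = beta**2
    return (1 + b2) * precision * recall / (b2 * precision + recall + eps)


# Hypothetical labels and predictions, for illustration only.
y_true = [1, 0, 1, 1, 0, 1, 0, 0]
y_pred = [1, 0, 1, 0, 0, 1, 1, 0]

p = precision_score(y_true, y_pred)
r = recall_score(y_true, y_pred)
print(f_beta(p, r, beta=1.0), fbeta_score(y_true, y_pred, beta=1.0))  # ~0.75 for both
print(f_beta(p, r, beta=2.0), fbeta_score(y_true, y_pred, beta=2.0))  # ~0.75 for both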

File tree

1 file changed: +127 -130 lines


training/train_v3.py

Lines changed: 127 additions & 130 deletions
@@ -1,157 +1,154 @@
 import sys
 import pandas as pd
+import tensorflow as tf
 from tensorflow.keras.preprocessing.text import Tokenizer
 from tensorflow.keras.preprocessing.sequence import pad_sequences
 from tensorflow.keras.models import Sequential
 from tensorflow.keras.layers import (
     Bidirectional,
     Conv1D,
     Dense,
+    Dropout,
     Embedding,
     Flatten,
     LSTM,
     MaxPooling1D,
 )
-from tensorflow.keras.metrics import Accuracy, Recall, Precision
 from tensorflow.keras.callbacks import EarlyStopping
-from sklearn.model_selection import train_test_split
-from sklearn.metrics import (
-    accuracy_score,
-    recall_score,
-    precision_score,
-    f1_score,
-    confusion_matrix,
-)
+from sklearn.model_selection import KFold
+from sklearn.metrics import accuracy_score, precision_score, recall_score
 import numpy as np
 import matplotlib.pyplot as plt


-# Check if the input file and output directory are provided
-if len(sys.argv) != 3:
-    print("Usage: python train.py <input_file> <output_dir>")
-    sys.exit(1)
-
-# Load dataset
-data = pd.read_csv(sys.argv[1])
-
-# Define parameters
-MAX_WORDS = 10000
-MAX_LEN = 100
-
-# Use Tokenizer to encode text
-tokenizer = Tokenizer(num_words=MAX_WORDS, filters="")
-tokenizer.fit_on_texts(data["Query"])
-sequences = tokenizer.texts_to_sequences(data["Query"])
-
-# Pad the text sequence
-X = pad_sequences(sequences, maxlen=MAX_LEN)
+def load_data(file_path):
+    """Load data from a CSV file."""
+    try:
+        return pd.read_csv(file_path)
+    except Exception as e:
+        print(f"Error loading data: {e}")
+        sys.exit(1)
+
+
+def preprocess_text(data, max_words=10000, max_len=100):
+    """Tokenize and pad text data."""
+    tokenizer = Tokenizer(num_words=max_words, oov_token="<OOV>")
+    tokenizer.fit_on_texts(data["Query"])
+    sequences = tokenizer.texts_to_sequences(data["Query"])
+    return pad_sequences(sequences, maxlen=max_len), tokenizer
+
+
+def build_model(input_dim, output_dim=128):
+    """Define and compile the CNN-BiLSTM model."""
+    model = Sequential(
+        [
+            Embedding(input_dim=input_dim, output_dim=output_dim),
+            Dropout(0.2),
+            Conv1D(filters=64, kernel_size=3, padding="same", activation="relu"),
+            MaxPooling1D(pool_size=2),
+            Bidirectional(LSTM(64, dropout=0.2, recurrent_dropout=0.2)),
+            Flatten(),
+            Dense(1, activation="sigmoid"),
+        ]
+    )
+    model.compile(
+        loss="binary_crossentropy",
+        optimizer="adam",
+        metrics=[
+            "accuracy",
+            tf.keras.metrics.Precision(name="precision"),
+            tf.keras.metrics.Recall(name="recall"),
+        ],
+    )
+    return model
+
+
+def calculate_f1_f2(precision, recall, beta=1):
+    """Calculate F1 or F2 score based on precision and recall with given beta."""
+    beta_squared = beta**2
+    return (
+        (1 + beta_squared)
+        * (precision * recall)
+        / (beta_squared * precision + recall + tf.keras.backend.epsilon())
+    )

-# Split the training set and test set
-y = data["Label"]
-X_train, X_test, y_train, y_test = train_test_split(
-    X, y, test_size=0.2, random_state=42
-)
-
-# Create CNN-BiLSTM model
-model = Sequential()
-model.add(Embedding(MAX_WORDS, 128))
-model.add(Conv1D(filters=64, kernel_size=3, padding="same", activation="relu"))
-model.add(MaxPooling1D(pool_size=2))
-model.add(Bidirectional(LSTM(64, dropout=0.2, recurrent_dropout=0.2)))
-model.add(Flatten())
-model.add(Dense(1, activation="sigmoid"))
-
-model.compile(
-    loss="binary_crossentropy",
-    optimizer="adam",
-    metrics=[
-        Accuracy(),
-        Recall(),
-        Precision(),
-    ],
-)
-
-# Define early stopping callback with a rollback of 5
-early_stopping = EarlyStopping(
-    monitor="val_loss", patience=5, restore_best_weights=True
-)

-# Train model with early stopping
-history = model.fit(
-    X_train,
-    y_train,
-    epochs=50,  # Maximum number of epochs
-    batch_size=32,
-    validation_data=(X_test, y_test),
-    callbacks=[early_stopping],
-    verbose=1,
-)
-
-# Predict test set
-y_pred = model.predict(X_test, verbose=1)
-y_pred_classes = np.argmax(y_pred, axis=1)
-
-# Calculate model performance indicators
-accuracy = accuracy_score(y_test, y_pred_classes)
-recall = recall_score(y_test, y_pred_classes, zero_division=1)
-precision = precision_score(y_test, y_pred_classes, zero_division=1)
-f1 = f1_score(y_test, y_pred_classes, zero_division=1)
-tn, fp, fn, tp = confusion_matrix(y_test, y_pred_classes).ravel()
-
-# Output performance indicators
-print("Accuracy: {:.2f}%".format(accuracy * 100))
-print("Recall: {:.2f}%".format(recall * 100))
-print("Precision: {:.2f}%".format(precision * 100))
-print("F1-score: {:.2f}%".format(f1 * 100))
-print("Specificity: {:.2f}%".format(tn / (tn + fp) * 100))
-print("ROC: {:.2f}%".format(tp / (tp + fn) * 100))
-
-# Save model as SavedModel format
-model.export(sys.argv[2])
-
-
-# Plot the training history
 def plot_history(history):
+    """Plot the training and validation loss, accuracy, precision, and recall."""
     plt.figure(figsize=(12, 8))
-
-    # Plot loss
-    plt.subplot(2, 2, 1)
-    plt.plot(history.history["loss"], label="Training Loss")
-    plt.plot(history.history["val_loss"], label="Validation Loss")
-    plt.title("Loss")
-    plt.xlabel("Epochs")
-    plt.ylabel("Loss")
-    plt.legend()
-
-    # Plot accuracy
-    plt.subplot(2, 2, 2)
-    plt.plot(history.history["accuracy"], label="Training Accuracy")
-    plt.plot(history.history["val_accuracy"], label="Validation Accuracy")
-    plt.title("Accuracy")
-    plt.xlabel("Epochs")
-    plt.ylabel("Accuracy")
-    plt.legend()
-
-    # Plot precision
-    plt.subplot(2, 2, 3)
-    plt.plot(history.history["precision"], label="Training Precision")
-    plt.plot(history.history["val_precision"], label="Validation Precision")
-    plt.title("Precision")
-    plt.xlabel("Epochs")
-    plt.ylabel("Precision")
-    plt.legend()
-
-    # Plot recall
-    plt.subplot(2, 2, 4)
-    plt.plot(history.history["recall"], label="Training Recall")
-    plt.plot(history.history["val_recall"], label="Validation Recall")
-    plt.title("Recall")
-    plt.xlabel("Epochs")
-    plt.ylabel("Recall")
-    plt.legend()
-
+    for i, metric in enumerate(["loss", "accuracy", "precision", "recall"], start=1):
+        plt.subplot(2, 2, i)
+        plt.plot(history.history[metric], label=f"Training {metric.capitalize()}")
+        plt.plot(
+            history.history[f"val_{metric}"], label=f"Validation {metric.capitalize()}"
+        )
+        plt.title(metric.capitalize())
+        plt.xlabel("Epochs")
+        plt.ylabel(metric.capitalize())
+        plt.legend()
     plt.tight_layout()
     plt.savefig("training_history.png")


-plot_history(history)
+# Main function
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        print("Usage: python train.py <input_file> <output_dir>")
+        sys.exit(1)
+
+    # Load and preprocess data
+    data = load_data(sys.argv[1])
+    X, tokenizer = preprocess_text(data)
+    y = data["Label"]
+
+    # Initialize cross-validation
+    k_folds = 5
+    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
+    fold_metrics = {"accuracy": [], "precision": [], "recall": [], "f1": [], "f2": []}
+
+    for fold, (train_idx, val_idx) in enumerate(kfold.split(X, y), 1):
+        print(f"Training fold {fold}/{k_folds}")
+
+        # Split the data
+        X_train, X_val = X[train_idx], X[val_idx]
+        y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
+
+        # Build and train the model
+        model = build_model(input_dim=len(tokenizer.word_index) + 1)
+        early_stopping = EarlyStopping(
+            monitor="val_loss", patience=5, restore_best_weights=True
+        )
+        history = model.fit(
+            X_train,
+            y_train,
+            epochs=50,
+            batch_size=32,
+            validation_data=(X_val, y_val),
+            callbacks=[early_stopping],
+            verbose=1,
+        )
+
+        # Make predictions to manually calculate metrics
+        y_val_pred = (model.predict(X_val) > 0.5).astype(int)
+        accuracy = accuracy_score(y_val, y_val_pred)
+        precision = precision_score(y_val, y_val_pred)
+        recall = recall_score(y_val, y_val_pred)
+        f1_score = calculate_f1_f2(precision, recall, beta=1)
+        f2_score = calculate_f1_f2(precision, recall, beta=2)
+
+        # Collect fold metrics
+        fold_metrics["accuracy"].append(accuracy)
+        fold_metrics["precision"].append(precision)
+        fold_metrics["recall"].append(recall)
+        fold_metrics["f1"].append(f1_score)
+        fold_metrics["f2"].append(f2_score)
+
+    # Calculate average metrics across folds
+    avg_metrics = {metric: np.mean(scores) for metric, scores in fold_metrics.items()}
+    print("\nCross-validation results:")
+    for metric, value in avg_metrics.items():
+        print(f"{metric.capitalize()}: {value:.2f}")
+
+    # Save the final model trained on the last fold
+    model.export(sys.argv[2])
+    plot_history(history)
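
A note on the "Fix training metrics" part of the message: the removed code derived classes with np.argmax(y_pred, axis=1), but a single sigmoid unit produces predictions of shape (n, 1), and argmax over axis 1 of such an array is always 0; the new code thresholds the probability at 0.5 instead. A minimal sketch with made-up probabilities (illustrative only, not part of the commit):

import numpy as np

# Single-column sigmoid outputs, shape (n, 1); the values are hypothetical.
probs = np.array([[0.91], [0.12], [0.67], [0.43]])

old_classes = np.argmax(probs, axis=1)   # always [0 0 0 0] for a (n, 1) array
new_classes = (probs > 0.5).astype(int)  # [[1], [0], [1], [0]]

print(old_classes)
print(new_classes.ravel())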

0 commit comments
