import sys
+import os
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_score, recall_score
+from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import matplotlib.pyplot as plt

@@ -54,11 +56,7 @@ def build_model(input_dim, output_dim=128):
    model.compile(
        loss="binary_crossentropy",
        optimizer="adam",
-        metrics=[
-            "accuracy",
-            tf.keras.metrics.Precision(name="precision"),
-            tf.keras.metrics.Recall(name="recall"),
-        ],
+        metrics=["accuracy", tf.keras.metrics.Precision(name="precision"), tf.keras.metrics.Recall(name="recall")],
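+        # Keeping the explicit names pins the history keys to "precision"
+        # and "recall"; unnamed metric instances can get uniquified names
+        # (e.g. "precision_1") when a fresh model is built on each fold,
+        # which would break the key lookups in plot_history below.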
    )
    return model

@@ -75,31 +73,43 @@ def calculate_f1_f2(precision, recall, beta=1):

def plot_history(history):
    """Plot the training and validation loss, accuracy, precision, and recall."""
+    available_metrics = history.history.keys()  # Check which metrics are available
    plt.figure(figsize=(12, 8))
-    for i, metric in enumerate(["loss", "accuracy", "precision", "recall"], start=1):
-        plt.subplot(2, 2, i)
-        plt.plot(history.history[metric], label=f"Training {metric.capitalize()}")
-        plt.plot(
-            history.history[f"val_{metric}"], label=f"Validation {metric.capitalize()}"
-        )
-        plt.title(metric.capitalize())
-        plt.xlabel("Epochs")
-        plt.ylabel(metric.capitalize())
-        plt.legend()
+
+    # Define metrics to plot
+    metrics_to_plot = ["loss", "accuracy", "precision", "recall"]
+    for i, metric in enumerate(metrics_to_plot, start=1):
+        if metric in available_metrics:
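+            # Only plot metrics Keras actually recorded; a missing key
+            # leaves that subplot blank instead of raising a KeyError.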
+            plt.subplot(2, 2, i)
+            plt.plot(history.history[metric], label=f"Training {metric.capitalize()}")
+            plt.plot(
+                history.history[f"val_{metric}"],
+                label=f"Validation {metric.capitalize()}",
+            )
+            plt.title(metric.capitalize())
+            plt.xlabel("Epochs")
+            plt.ylabel(metric.capitalize())
+            plt.legend()
+
    plt.tight_layout()
    plt.savefig("training_history.png")


-# Main function
if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python train.py <input_file> <output_dir>")
        sys.exit(1)

+    # Constants
+    MAX_WORDS = 10000
+    MAX_LEN = 100
+    EPOCHS = 50
+    BATCH_SIZE = 32
+
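+    # NOTE: MAX_WORDS and MAX_LEN are presumably consumed by preprocess_text
+    # elsewhere in the file; they are not referenced in this hunk.
+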
    # Load and preprocess data
    data = load_data(sys.argv[1])
    X, tokenizer = preprocess_text(data)
-    y = data["Label"]
+    y = data["Label"].values  # NumPy array so KFold's positional indices work

    # Initialize cross-validation
    k_folds = 5
@@ -111,7 +121,13 @@ def plot_history(history):

        # Split the data
        X_train, X_val = X[train_idx], X[val_idx]
-        y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
+        y_train, y_val = y[train_idx], y[val_idx]
+
+        # Compute class weights to handle imbalance
+        class_weights = compute_class_weight(
+            "balanced", classes=np.unique(y_train), y=y_train
+        )
+        class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}
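+        # Assumes labels are integer-encoded as 0..n-1, so the positions in
+        # compute_class_weight's output (ordered by np.unique) line up with
+        # the class indices Keras expects in the class_weight dict.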

        # Build and train the model
        model = build_model(input_dim=len(tokenizer.word_index) + 1)
@@ -121,15 +137,16 @@ def plot_history(history):
        history = model.fit(
            X_train,
            y_train,
-            epochs=50,
-            batch_size=32,
+            epochs=EPOCHS,
+            batch_size=BATCH_SIZE,
            validation_data=(X_val, y_val),
+            class_weight=class_weight_dict,
            callbacks=[early_stopping],
            verbose=1,
        )

-        # Make predictions to manually calculate metrics
-        y_val_pred = (model.predict(X_val) > 0.5).astype(int)
+        # Make predictions to calculate metrics
+        y_val_pred = (model.predict(X_val) > 0.8).astype(int)
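+        # NOTE: 0.8 is a stricter cutoff than the conventional 0.5, trading
+        # recall for precision; it also means these scores will not match the
+        # compiled Keras Precision/Recall metrics, which threshold at 0.5.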
        accuracy = accuracy_score(y_val, y_val_pred)
        precision = precision_score(y_val, y_val_pred)
        recall = recall_score(y_val, y_val_pred)
@@ -143,12 +160,17 @@ def plot_history(history):
        fold_metrics["f1"].append(f1_score)
        fold_metrics["f2"].append(f2_score)

-    # Calculate average metrics across folds
+    # Calculate and display average metrics across folds
    avg_metrics = {metric: np.mean(scores) for metric, scores in fold_metrics.items()}
    print("\nCross-validation results:")
    for metric, value in avg_metrics.items():
        print(f"{metric.capitalize()}: {value:.2f}")

    # Save the final model trained on the last fold
-    model.export(sys.argv[2])
+    output_dir = sys.argv[2]
+    os.makedirs(output_dir, exist_ok=True)  # idempotent directory creation
+    model.export(output_dir)
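+    # model.export() writes an inference-only SavedModel artifact; use
+    # model.save() instead if the model must be reloaded for more training.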
+
+    # Plot training history of the last fold
    plot_history(history)