This repository contains code for classifying events using Graph Neural Networks (GNNs). The included data is for HZZ vs ZZ classification. The pipeline consists of three main scripts:
- grapher.py: converts the input data into graph representations.
- trainer2.py: trains a GNN model on the generated graphs.
- tester2.py: evaluates the trained model and generates performance plots.
Repository layout:
.
├── Data_train
│   ├── raw
│   │   └── data_norm.txt
│   └── processed
│
├── Data_test
│   ├── raw
│   │   └── data_norm.txt
│   └── processed
│
├── grapher.py
├── trainer2.py
├── tester2.py
└── README.md
grapher.py:
- Reads data_norm.txt from Data_train/raw/ and Data_test/raw/.
- Converts the data into graph structures: each graph has 4 nodes with 4 features each, with node 0 connected to node 1 and node 2 connected to node 3.
- Saves the processed graph data in Data_train/processed/ and Data_test/processed/.
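The graph construction described above can be sketched in plain Python. This is a minimal, hypothetical sketch, not the code in grapher.py: it assumes each row of data_norm.txt is whitespace-separated and holds the 16 node features in node order (4 per node) followed by a single class label. The names `row_to_graph` and `EDGE_INDEX` are illustrative. Each undirected edge is stored in both directions, which matches the `edge_index` convention used by PyTorch Geometric.

```python
# Hypothetical sketch of the grapher.py step: one text row -> one 4-node graph.
# ASSUMPTION: a row holds 16 whitespace-separated features (node 0's 4 features,
# then node 1's, and so on) followed by one class label; the real layout may differ.
from typing import List, Tuple

NUM_NODES = 4
NUM_FEATURES = 4

# Edges 0-1 and 2-3, each listed in both directions as a 2 x E COO edge list.
EDGE_INDEX: List[List[int]] = [
    [0, 1, 2, 3],  # source nodes
    [1, 0, 3, 2],  # target nodes
]


def row_to_graph(line: str) -> Tuple[List[List[float]], List[List[int]], int]:
    """Parse one row into (node feature matrix, edge_index, label)."""
    values = [float(v) for v in line.split()]
    expected = NUM_NODES * NUM_FEATURES + 1
    if len(values) != expected:
        raise ValueError(f"expected {expected} values, got {len(values)}")
    features, label = values[:-1], int(values[-1])
    # Reshape the flat feature vector into a NUM_NODES x NUM_FEATURES matrix.
    x = [features[i * NUM_FEATURES:(i + 1) * NUM_FEATURES] for i in range(NUM_NODES)]
    # Return a copy so callers cannot mutate the shared edge list.
    return x, [list(row) for row in EDGE_INDEX], label


if __name__ == "__main__":
    row = " ".join(str(float(i)) for i in range(16)) + " 1"
    x, edge_index, y = row_to_graph(row)
    print(x[2])        # [8.0, 9.0, 10.0, 11.0]
    print(edge_index)  # [[0, 1, 2, 3], [1, 0, 3, 2]]
    print(y)           # 1
```

With torch and torch_geometric installed, the three returned values map onto a PyTorch Geometric graph as `Data(x=torch.tensor(x), edge_index=torch.tensor(edge_index), y=torch.tensor([y]))`, with `Data` imported from `torch_geometric.data`.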
Run:
python grapher.py

trainer2.py:
- Loads the processed graph data from Data_train/processed/.
- Trains a GNN model for graph classification.
- Saves the trained model.
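For readers new to graph classification, the forward pass of a generic message-passing GNN can be sketched in plain Python. This is an illustrative assumption, not the architecture in trainer2.py, which this README does not describe. One layer averages each node's features with those of its neighbours and applies a weight matrix and ReLU; global mean pooling then turns the node embeddings into one graph vector, and a sigmoid produces a score in (0, 1). All function names and weights are hypothetical. Training would adjust `w`, `v` and `b` by gradient descent on a loss such as binary cross-entropy.

```python
# Illustrative forward pass of a tiny message-passing GNN for graph classification.
# ASSUMPTION: generic architecture with hypothetical weights; not trainer2.py's model.
import math
from typing import List

Matrix = List[List[float]]


def neighbours(edge_index: List[List[int]], num_nodes: int) -> List[List[int]]:
    """Build an incoming-neighbour list from a 2 x E COO edge list."""
    adj: List[List[int]] = [[] for _ in range(num_nodes)]
    for src, dst in zip(edge_index[0], edge_index[1]):
        adj[dst].append(src)  # a message flows from src to dst
    return adj


def message_passing(x: Matrix, adj: List[List[int]], w: Matrix) -> Matrix:
    """One layer: h_i' = ReLU(W @ mean(h_i and its neighbours))."""
    out: Matrix = []
    for i, row in enumerate(x):
        group = [row] + [x[j] for j in adj[i]]
        mean = [sum(col) / len(group) for col in zip(*group)]
        # W has shape (out_dim, in_dim): one output feature per row of W.
        out.append([max(0.0, sum(wk * m for wk, m in zip(w_row, mean))) for w_row in w])
    return out


def classify(x: Matrix, edge_index: List[List[int]], w: Matrix,
             v: List[float], b: float) -> float:
    """Score one graph: message passing, global mean pooling, then a sigmoid."""
    h = message_passing(x, neighbours(edge_index, len(x)), w)
    pooled = [sum(col) / len(h) for col in zip(*h)]  # global mean pooling
    logit = sum(vi * pi for vi, pi in zip(v, pooled)) + b
    return 1.0 / (1.0 + math.exp(-logit))


if __name__ == "__main__":
    identity = [[1.0 if r == c else 0.0 for c in range(4)] for r in range(4)]
    edges = [[0, 1, 2, 3], [1, 0, 3, 2]]
    print(classify(identity, edges, identity, [1.0] * 4, 0.0))  # about 0.731
```

Mean aggregation is used here only because it is easy to follow by hand; real GNN layers (for example the convolution layers in torch_geometric) use their own normalisation schemes.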
Run:
python trainer2.py

tester2.py:
- Loads the trained model and evaluates it on the test dataset from Data_test/processed/.
- Generates and saves performance plots:
  - Loss vs Epochs
  - Accuracy vs Epochs
  - Neural Network Output Distribution
  - ROC Curve
  - Confusion Matrix
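Two of these plots rest on simple counting. The plain-Python sketch below computes the confusion matrix and the ROC curve with its area (AUC) from the network outputs. It assumes the outputs are scores in [0, 1], the labels are 1 for the positive class (taken here as HZZ) and 0 otherwise, and a 0.5 decision threshold; none of these details is stated for tester2.py itself, and the function names are hypothetical.

```python
# Hypothetical sketch of the confusion-matrix and ROC computations behind the plots.
# ASSUMPTION: scores in [0, 1], labels 1 (positive, e.g. HZZ) / 0 (e.g. ZZ),
# and a 0.5 threshold for the confusion matrix.
from typing import List, Sequence, Tuple


def confusion_matrix(labels: Sequence[int], scores: Sequence[float],
                     threshold: float = 0.5) -> List[List[int]]:
    """Return [[TN, FP], [FN, TP]]: rows are true classes, columns predictions."""
    m = [[0, 0], [0, 0]]
    for y, s in zip(labels, scores):
        m[y][1 if s >= threshold else 0] += 1
    return m


def roc_curve(labels: Sequence[int], scores: Sequence[float]) -> List[Tuple[float, float]]:
    """(FPR, TPR) points from sweeping the threshold down through every score."""
    pos = sum(labels)
    neg = len(labels) - pos
    if pos == 0 or neg == 0:
        raise ValueError("need at least one positive and one negative example")
    ranked = sorted(zip(scores, labels), reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    for i, (s, y) in enumerate(ranked):
        if y == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only after the last of a run of tied scores.
        if i + 1 == len(ranked) or ranked[i + 1][0] != s:
            points.append((fp / neg, tp / pos))
    return points


def auc(points: List[Tuple[float, float]]) -> float:
    """Area under the ROC curve via the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2
               for (x0, y0), (x1, y1) in zip(points, points[1:]))


if __name__ == "__main__":
    labels = [0, 0, 1, 1]
    scores = [0.1, 0.4, 0.35, 0.8]
    print(confusion_matrix(labels, scores))  # [[2, 0], [1, 1]]
    print(auc(roc_curve(labels, scores)))    # 0.75
```

In practice the same quantities are available from scikit-learn, which is already in the requirements: `sklearn.metrics.confusion_matrix`, `sklearn.metrics.roc_curve` and `sklearn.metrics.roc_auc_score`.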
Run:
python tester2.py

Requirements: ensure the following Python packages are installed:
pip install torch==2.3.0 torch_geometric matplotlib numpy scikit-learn pandas seaborn tqdm torchinfo h5py
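Before running the scripts, a quick import check can confirm that these packages are available. The helper below is a hypothetical addition, not part of the repository. Note that some pip package names differ from their import names: `scikit-learn`, for example, is imported as `sklearn`.

```python
# Hypothetical helper (not part of this repository): report which of the
# required packages cannot be imported in the current environment.
import importlib.util
from typing import Iterable, List

# Import names for the pip packages listed above (scikit-learn -> sklearn).
REQUIRED = ["torch", "torch_geometric", "matplotlib", "numpy", "sklearn",
            "pandas", "seaborn", "tqdm", "torchinfo", "h5py"]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that importlib cannot locate."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("missing:", ", ".join(missing) if missing else "none")
```

To confirm the pinned torch version, `importlib.metadata.version("torch")` returns the installed version string, which can be compared against 2.3.0.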
Notes:
- Modify the dataset paths in the scripts if needed.
- Adjust the model parameters in trainer2.py to improve performance.
- Results are stored in the respective output directories and visualized by tester2.py.
Author: Shreyas Bakare