2 changes: 1 addition & 1 deletion .gitignore
@@ -21,4 +21,4 @@ cfg_constructor/out
**/.venv
**/logdat

!**/demo_file_id.txt
!**/demo_file_id.txt
317 changes: 317 additions & 0 deletions GraphIsomorphismNetwork/jax_gat.py
@@ -0,0 +1,317 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state, checkpoints
from new_loader import get_paths, load_dataset_jax_new
import optax
from pathlib import Path
import pandas as pd
import numpy as np
import time
import os
import logging
import tqdm

log_file = Path(__file__).resolve().parent.parent / "logs" / "gat_training.log"
log_file.parent.mkdir(parents=True, exist_ok=True)

# Remove any handlers configured elsewhere so basicConfig can attach the file handler
root_logger = logging.getLogger()
if root_logger.handlers:
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)

logging.basicConfig(
filename=log_file,
level=logging.INFO,
filemode="a",
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt='%Y-%m-%d %H:%M:%S'
)

logging.info("Starting GAT training...")

class GraphAttentionLayer(nn.Module):
out_dim: int
dropout_rate: float = 0.5

@nn.compact
def __call__(self, x, senders, receivers, training: bool):
W = nn.Dense(self.out_dim, use_bias=False)
x = W(x)

a1 = nn.Dense(1, use_bias=False, name='attention_left')
a2 = nn.Dense(1, use_bias=False, name='attention_right')

f_1 = a1(x) # [N, 1]
f_2 = a2(x)
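# The GAT edge score a^T [Wh_i || Wh_j] decomposes into a1(Wh_i) + a2(Wh_j),
# so the per-node terms are computed once here and gathered per edge below.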

# Create attention logits for edges
edge_logits = f_1[senders] + f_2[receivers]
edge_logits = nn.leaky_relu(edge_logits.squeeze(-1), negative_slope=0.2)

# Softmax-normalize the logits over each node's incoming edges (subtract the per-receiver max for numerical stability)
max_logits = jax.ops.segment_max(
edge_logits, receivers, num_segments=x.shape[0], indices_are_sorted=False
)
edge_logits_normalized = edge_logits - max_logits[receivers]
edge_weights = jnp.exp(edge_logits_normalized)


weight_sum = jax.ops.segment_sum(
edge_weights, receivers, num_segments=x.shape[0], indices_are_sorted=False
)
edge_weights = edge_weights / (weight_sum[receivers] + 1e-16)
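# The weights of a node's incoming edges now sum to (approximately) 1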

# Apply dropout to attention weights
edge_weights = nn.Dropout(rate=self.dropout_rate)(
edge_weights, deterministic=not training
)


x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not training)


# Weighted message passing: scale each sender's features by its attention weight and sum at the receiver
weighted_features = edge_weights[:, None] * x[senders]
aggregated = jax.ops.segment_sum(
weighted_features, receivers, num_segments=x.shape[0], indices_are_sorted=False
)

return nn.elu(aggregated)

class MultiHeadAttention(nn.Module):
num_heads: int
hidden_dim: int
dropout_rate: float = 0.5
last_layer: bool = False

@nn.compact
def __call__(self, x, senders, receivers, training: bool):
# Compute each attention head independently on the same edge structure
heads = []
for i in range(self.num_heads):
head = GraphAttentionLayer(
out_dim=self.hidden_dim,
dropout_rate=self.dropout_rate,
name=f'attention_head_{i}'
)
head_out = head(x, senders, receivers, training)
heads.append(head_out)

if self.last_layer:
# Final layer: average the heads so the output keeps hidden_dim features
return jnp.mean(jnp.stack(heads), axis=0)
else:
# Hidden layers: concatenate the heads into num_heads * hidden_dim features
return jnp.concatenate(heads, axis=-1)

class GAT(nn.Module):
in_channels: int
hidden_channels: int
out_channels: int
num_heads: list
dropout_rate: float = 0.5

@nn.compact
def __call__(self, x, edge_index, batch, training: bool):
senders, receivers = edge_index

x = MultiHeadAttention(
num_heads=self.num_heads[0],
hidden_dim=self.hidden_channels,
dropout_rate=self.dropout_rate,
last_layer=False,
name='layer_0'
)(x, senders, receivers, training)


x = MultiHeadAttention(
num_heads=self.num_heads[1],
hidden_dim=self.out_channels,
dropout_rate=self.dropout_rate,
last_layer=True,
name='layer_1'
)(x, senders, receivers, training)

# Global add pooling: sum node embeddings into a per-graph embedding.
# num_segments=1 assumes each mini-batch holds exactly one graph.
x = jax.ops.segment_sum(x, batch, num_segments=1)

return jax.nn.log_softmax(x, axis=-1)
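
# Illustrative sketch (not part of the training script): a forward pass on a toy
# 3-node graph with 2 directed edges, using the GAT class above.
#   model = GAT(in_channels=4, hidden_channels=8, out_channels=2, num_heads=[2, 1])
#   x = jnp.ones((3, 4))
#   edge_index = (jnp.array([0, 1]), jnp.array([1, 2]))
#   batch = jnp.zeros(3, dtype=jnp.int32)
#   variables = model.init(jax.random.PRNGKey(0), x, edge_index, batch, training=False)
#   log_probs = model.apply(variables, x, edge_index, batch, training=False)  # shape [1, 2]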

class TrainState(train_state.TrainState):
batch_stats: dict
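# The model above has no BatchNorm layers, so batch_stats stays empty; the field
# is kept so the apply/checkpoint boilerplate below works unchanged.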

def create_train_state(rng, model, learning_rate, sample_input):
# Initialize with training=False so Dropout is a no-op and no "dropout" RNG is
# needed at init time; parameter shapes are identical either way.
variables = model.init(
rng,
sample_input["x"],
sample_input["edge_index"],
sample_input["batch"],
training=False,
)
tx = optax.adam(learning_rate)
return TrainState.create(
apply_fn=model.apply,
params=variables["params"],
tx=tx,
batch_stats=variables.get("batch_stats", {})
)
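
# Each batch is a dict from load_dataset_jax_new: "x" node features, "edge_index"
# as (senders, receivers), "batch" node-to-graph assignments, "y" graph labels,
# and "num_graphs" used for loss/accuracy averaging.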

@jax.jit
def train_step(state, batch, dropout_rng):
def loss_fn(params):
variables = {"params": params, "batch_stats": state.batch_stats}
logits, new_model_state = state.apply_fn(
variables,
batch["x"],
batch["edge_index"],
batch["batch"],
training=True,
mutable=["batch_stats"],
rngs={"dropout": dropout_rng},
)
labels = batch["y"]
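# Mean negative log-likelihood of the true class; logits are already log-probabilities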
nll = -jnp.mean(jnp.take_along_axis(logits, labels[:, None], axis=-1).squeeze())

# L2 regularization (weight decay 5e-4) over all parameters
l2_loss = 5e-4 * sum(jnp.sum(p**2) for p in jax.tree_util.tree_leaves(params))
total_loss = nll + l2_loss

return total_loss, new_model_state

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, new_model_state), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads, batch_stats=new_model_state["batch_stats"])
return state, loss

@jax.jit
def test_step(state, batch):
variables = {"params": state.params, "batch_stats": state.batch_stats}
logits = state.apply_fn(
variables,
batch["x"],
batch["edge_index"],
batch["batch"],
training=False,
mutable=False
)
pred = jnp.argmax(logits, axis=-1)
correct = jnp.sum(pred == batch["y"])
return correct

def save_model(state, checkpoint_dir, step):
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoints.save_checkpoint(
ckpt_dir=checkpoint_dir,
target=state,
step=step,
overwrite=True
)
msg = f"Model saved at step {step} to {checkpoint_dir}"
print(msg)
logging.info(msg)

def load_model_if_exists(state, checkpoint_dir):
latest_ckpt = checkpoints.latest_checkpoint(checkpoint_dir)
if latest_ckpt:
state = checkpoints.restore_checkpoint(
ckpt_dir=checkpoint_dir,
target=state
)
msg = f"Model restored from checkpoint: {latest_ckpt}"
print(msg)
logging.info(msg)
else:
msg = "No checkpoint found. Training from scratch."
print(msg)
logging.info(msg)
return state

def main():
# print(jax.devices())
dev = jax.devices()[0] if jax.devices() else None

paths = get_paths(samples_2000=True)
data_loader, num_features, num_classes = load_dataset_jax_new(paths, max_files=200)

split_ratio = 0.8
n_total = len(data_loader)
split_idx = int(n_total * split_ratio)
train_loader = data_loader[:split_idx]
test_loader = data_loader[split_idx:]
print(f"Total batches: {n_total}; Training batches: {len(train_loader)}; Test batches: {len(test_loader)}")

print(f"Number of Features: {num_features}, model hidden layer dim: {max(1, int(num_features * 1e-4))}")

# GAT config: 8 attention heads in the hidden layer, one averaged head in the output layer
num_heads = [8, 1]
hidden_dim = max(1, int(num_features * 1e-4))

model = GAT(
in_channels=num_features,
hidden_channels=hidden_dim,
out_channels=num_classes,
num_heads=num_heads,
dropout_rate=0.5
)

rng = jax.random.PRNGKey(0)
dropout_rng, init_rng = jax.random.split(rng)

print(f"Total samples: {n_total}")
sample_input = train_loader[0]
state = create_train_state(init_rng, model, learning_rate=0.005, sample_input=sample_input)

checkpoint_dir = Path(__file__).resolve().parent.parent / "weights" / "gat"
#state = load_model_if_exists(state, checkpoint_dir)

start_train = time.perf_counter()
num_epochs = 400

state = jax.device_put(state, device=dev)

for epoch in range(1, num_epochs + 1):
epoch_start = time.perf_counter()

# Train Loop
epoch_loss = 0.0
total_graphs = 0
for batch in tqdm.tqdm(train_loader, desc=f"training epoch {epoch}"):
# Split off a fresh PRNG key for dropout at every step
dropout_rng, step_rng = jax.random.split(dropout_rng)
state, loss = train_step(state, batch, step_rng)
state = jax.device_put(state, device=dev)
epoch_loss += loss * batch["num_graphs"]
total_graphs += batch["num_graphs"]
avg_loss = epoch_loss / total_graphs

# Test Loop
total_correct = 0
total_test_graphs = 0
for batch in test_loader:
correct = test_step(state, batch)
total_correct += correct
total_test_graphs += batch["num_graphs"]
test_acc = total_correct / total_test_graphs

epoch_end = time.perf_counter()
msg1 = f"Time for epoch {epoch} was {(epoch_end - epoch_start):.6f} seconds"
msg2 = f"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, Test Acc: {test_acc:.4f}"
print(msg1)
print(msg2)
logging.info(msg1)
logging.info(msg2)

if epoch % 50 == 0:
save_model(state, checkpoint_dir, epoch)


save_model(state, checkpoint_dir, num_epochs)

end_train = time.perf_counter()
msg3 = f"Total training time : {(end_train - start_train):.6f} seconds"
msg4 = f"Average time per epoch: {((end_train - start_train) / num_epochs):.6f} seconds"
print(msg3)
print(msg4)
logging.info(msg3)
logging.info(msg4)

if __name__ == "__main__":
main()