diff --git a/configs/crop_configs/alm_side.json b/configs/crop_configs/alm_side.json
deleted file mode 100644
index d0419dd..0000000
--- a/configs/crop_configs/alm_side.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
-"proportional_h_coord_top": 0.21666666666666667,
-"target_h": 120,
-"target_w": 112
-}
diff --git a/configs/data_configs/alm_side.json b/configs/data_configs/alm_side.json
new file mode 100644
index 0000000..c0f291d
--- /dev/null
+++ b/configs/data_configs/alm_side.json
@@ -0,0 +1,8 @@
+{
+"proportional_h_coord_top": 0.21666666666666667,
+"target_h": 120,
+"target_w": 112,
+"extension": ".png",
+"trial_pattern": "^CW44_20240522153954_side_trial_"
+
+}
diff --git a/configs/model_configs/alm_default.json b/configs/model_configs/alm_default.json
new file mode 100644
index 0000000..36d2db6
--- /dev/null
+++ b/configs/model_configs/alm_default.json
@@ -0,0 +1,24 @@
+{
+"num_blocks": 2,
+"in_channels_0": 1,
+"out_channels_0": 16,
+"in_channels_1": 16,
+"out_channels_1": 32,
+"kernel_preconv": 3,
+"stride_preconv": 1,
+"pool_size_preconv_0": 2,
+"pool_size_preconv_1": 4,
+"use_batch_norm_preconv": false,
+"kernel_residual": 3,
+"stride_residual": 1,
+"use_batch_norm_residual": false,
+"pool_size_residual_0": null,
+"pool_size_residual_1": 4,
+"n_layers_residual": 3,
+"out_conv": 288,
+"out_linear": 128,
+"embed_size": 16,
+"use_batch_norm_linear": false,
+"image_height": 120,
+"image_width": 112
+}
diff --git a/configs/train_configs/alm_default.json b/configs/train_configs/alm_default.json
new file mode 100644
index 0000000..5d4e0fd
--- /dev/null
+++ b/configs/train_configs/alm_default.json
@@ -0,0 +1,16 @@
+{
+"scheduler":"linear",
+"start_factor":0.1,
+"end_factor":1,
+"warmup_steps":10,
+"max_epochs":500,
+"weight_decay":0,
+"l2_weight":1,
+"learning_rate":0,
+"batch_size":10,
+"num_workers":2,
+"subsample_rate":10,
+"subsample_offset":0,
+"accelerator":"gpu",
+"fast_dev_run":false
+}
diff --git a/configs/train_configs/alm_default_dev.json b/configs/train_configs/alm_default_dev.json
new file mode 100644
index 0000000..cd48064
--- /dev/null
+++ b/configs/train_configs/alm_default_dev.json
@@ -0,0 +1,16 @@
+{
+"scheduler":"linear",
+"start_factor":0.1,
+"end_factor":1,
+"warmup_steps":10,
+"weight_decay":0,
+"l2_weight":1,
+"learning_rate":0,
+"batch_size":10,
+"num_workers":2,
+"subsample_rate":10,
+"subsample_offset":0,
+"max_epochs":100,
+"accelerator":"cpu",
+"fast_dev_run":true
+}
diff --git a/figures/CW44_naive_default_crop_img.png b/figures/CW44_naive_default_crop_img.png
new file mode 100644
index 0000000..ce119f9
Binary files /dev/null and b/figures/CW44_naive_default_crop_img.png differ
diff --git a/figures/CW44_naive_pre_crop_img.png b/figures/CW44_naive_pre_crop_img.png
new file mode 100644
index 0000000..59b9214
Binary files /dev/null and b/figures/CW44_naive_pre_crop_img.png differ
diff --git a/figures/preprocessing_difference_alias.png b/figures/preprocessing_difference_alias.png
new file mode 100644
index 0000000..aa70a14
Binary files /dev/null and b/figures/preprocessing_difference_alias.png differ
diff --git a/notes/preprocessing.md b/notes/preprocessing.md
new file mode 100644
index 0000000..c42d086
--- /dev/null
+++ b/notes/preprocessing.md
@@ -0,0 +1,13 @@
+# Preprocessing description
+
+## Default cropping
+
+[precrop](../figures/CW44_naive_pre_crop_img.png)
+
+[crop](../figures/CW44_naive_default_crop_img.png)
+
+## Antialiasing
+
+Some differences are greater than 1e-3; these are most likely due to aliasing.
+ +[alias](../figures/preprocessing_difference_alias.png) diff --git a/requirements_cpu.txt b/requirements_cpu.txt index 7fd84b2..5b3fc31 100644 --- a/requirements_cpu.txt +++ b/requirements_cpu.txt @@ -1,4 +1,10 @@ -pip install torch # should specify a version later -pip install numpy -pip install matplotlib -pip install pytest +torch # should specify a version later +torchvision +joblib +pytorch_lightning +numpy +matplotlib +pytest +opencv-python +fire +scikit-image diff --git a/scripts/eval_single_session.py b/scripts/eval_single_session.py new file mode 100644 index 0000000..da4f7b0 --- /dev/null +++ b/scripts/eval_single_session.py @@ -0,0 +1,98 @@ +""" +Evaluate a single session autoencoder. + +Assumes that we are given a path to a directory `preds/{modeltype}/date/time/`. Will use the configuration parameters stored there +""" +import os +import datetime +from tqdm import tqdm +import json +import numpy as np +import fire +import torch +import pytorch_lightning as pl +from behavioral_autoencoder.module import SingleSessionModule +from behavioral_autoencoder.dataloading import SessionFramesDataModule +from behavioral_autoencoder.dataset import CropResizeProportion + +here = os.path.join(os.path.abspath(os.path.dirname(__file__))) + +def eval_trialwise(model,test_dataset,mean_image,path): + """ + """ + for i, image_sequence in tqdm(enumerate(test_dataset)): + reconstructs,latents = model(image_sequence[None,:].cuda()) + reconstructs_centered = reconstructs + mean_image.cuda() + folder = test_dataset.trial_folders[i] + savepath = os.path.join(path,folder) + try: + os.mkdir(savepath) + except FileExistsError: + pass + np.save(os.path.join(savepath,"reconstruct.npy"),reconstructs_centered.cpu().detach().numpy()) + np.save(os.path.join(savepath,"latents.npy"),latents.cpu().detach().numpy()) + +def main(data_path,data_config_path,eval_config_path): + """ + get a model path, and use it to load in a given model. + """ + saved_checkpoint_path = os.path.join(".","models","single_session_autoencoder","03-07-25","18_05_06","epoch=99-step=33100.ckpt") + data_dir = os.path.join("home","ubuntu","Data","CW35","2023_12_15","Frames") + metadata_dir = os.path.join(".","preds","single_session_autoencoder","03-07-25","18_05_06") + video_fps = 400 + delay_start_time = 2.5+1.3 ## pre-sample and sample time intervals. + delay_end_time = 2.5+1.3+3 ## delay is 3 seconds. 
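+    ## Descriptive note: the `frame_subset` constructed below keeps every `subsample`-th frame
+    ## between delay_start_time and delay_end_time, converting seconds to frame indices via video_fps.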
+ subsample = 10 + eval_batch_size=1 + + ## Load in data related stuff + + with open(data_config_path,"r") as f: + data_process_config = json.load(f) + + with open(eval_config_path,"r") as f: + eval_config = json.load(f) + + alm_cropping = CropResizeProportion(data_config_path) + data_config = { + "data_path":data_path, + "transform":alm_cropping, + "extension":data_process_config["extension"], + "trial_pattern":data_process_config["trial_pattern"], + "frame_subset":[f"frame_{i:06d}.png" for i in np.arange(int(delay_start_time*video_fps),int(delay_end_time*video_fps),subsample)] + } + + date = datetime.datetime.now().strftime("%m-%d-%y") + time = datetime.datetime.now().strftime("%H_%M_%S") + datestamp_eval = os.path.join(here,"eval",date) + timestamp_eval = os.path.join(here,"eval",date,time) + for path in [datestamp_eval,timestamp_eval]: + try: + os.mkdir(path) + except FileExistsError: + pass + + sfdm = SessionFramesDataModule( + data_config, + eval_config["batch_size"], + eval_config["num_workers"], + eval_config["subsample_rate"], + eval_config["subsample_offset"], + eval_config["val_subsample_rate"], + eval_config["val_subsample_offset"] + ) + + model = SingleSessionModule.load_from_checkpoint( + checkpoint_path=saved_checkpoint_path + ) + + sfdm.setup("test") + + eval_trialwise(model,sfdm.dataset,sfdm.mean_image,path) + + + import pdb; pdb.set_trace() + +if __name__ == "__main__": + fire.Fire(main) + diff --git a/scripts/train_single_session.py b/scripts/train_single_session.py new file mode 100644 index 0000000..b316945 --- /dev/null +++ b/scripts/train_single_session.py @@ -0,0 +1,121 @@ +"""Train a single session autoencoder on provided data. + +Saves checkpoints for the corresponding model, outputs to tensorboard, and finally dumps all predictions and latents into a save directory. + +""" +import os +import fire +import json +import joblib +import datetime +import pytorch_lightning as pl +from behavioral_autoencoder.module import SingleSessionModule +from behavioral_autoencoder.dataset import CropResizeProportion +from behavioral_autoencoder.dataloading import SessionFramesDataModule +from behavioral_autoencoder.eval import get_all_predicts_latents,get_dl_predicts_latents +from pytorch_lightning.callbacks import ModelCheckpoint,LearningRateMonitor +from pytorch_lightning.loggers import TensorBoardLogger + +here = os.path.join(os.path.abspath(os.path.dirname(__file__))) + +def main(model_config_path, train_config_path, data_path, data_config_path): + """This main function takes as input four paths. These paths indicate the model configuration parameters, training configuration parameters, path to the data directory, and cropping configuration, respectively. By default we assume that we are training a single session autoencoder. 
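+
+    Example invocation via `fire` (the data path below is illustrative; the config paths are the
+    ones added in this repository):
+
+        python scripts/train_single_session.py \
+            --model_config_path configs/model_configs/alm_default.json \
+            --train_config_path configs/train_configs/alm_default.json \
+            --data_path /path/to/session/Frames \
+            --data_config_path configs/data_configs/alm_side.json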
+ """ + print("\n=== Starting Single Session Autoencoder Training ===") + + ## Model setup + print("\nLoading configurations...") + with open(model_config_path,"r") as f: + model_config = json.load(f) + with open(train_config_path,"r") as f: + train_config = json.load(f) + model_name = "single_session_autoencoder" + + print(f"Model config: {model_config}") + print(f"Training config: {train_config}") + + hparams = { + "model":model_name, + "model_config":model_config, + "train_config":train_config + } + + print("\nInitializing model...") + ssm = SingleSessionModule(hparams) + + ## Data setup + with open(data_config_path,"r") as f: + data_process_config = json.load(f) + print("\nSetting up data...") + alm_cropping = CropResizeProportion(data_config_path) + data_config = { + "data_path":data_path, + "transform":alm_cropping, + "data_config_path":data_config_path, + "extension":data_process_config["extension"], + "trial_pattern":data_process_config["trial_pattern"] + } + print(f"Data config: {data_config}") + + print("Initializing data module...") + sfdm = SessionFramesDataModule( + data_config, + train_config["batch_size"], + train_config["num_workers"], + train_config["subsample_rate"], + train_config["subsample_offset"], + train_config["val_subsample_rate"], + train_config["val_subsample_offset"] + ) + + import pdb; pdb.set_trace() + ## Set up logging and trainer + print("\nSetting up logging and checkpoints...") + date=datetime.datetime.now().strftime("%m-%d-%y") + time=datetime.datetime.now().strftime("%H_%M_%S") + timestamp_model = os.path.join(here,"models",model_name,date,time) + timestamp_pred = os.path.join(here,"preds",model_name,date,time) + print(f"Model will be saved to: {timestamp_model}") + print(f"Predictions will be saved to: {timestamp_pred}") + + logger = TensorBoardLogger("tb_logs",name="test_single_session_auto",log_graph=True) + checkpoint = ModelCheckpoint(monitor="mse/val", mode="min", save_last=True, dirpath=timestamp_model) + lr_monitor = LearningRateMonitor(logging_interval='epoch') + + print("\nInitializing trainer...") + trainer = pl.Trainer( + fast_dev_run=train_config["fast_dev_run"], + max_epochs=train_config["max_epochs"], + accelerator=train_config["accelerator"], + enable_checkpointing=True, + callbacks=[checkpoint,lr_monitor], + log_every_n_steps=1, + logger=logger, + enable_progress_bar=True, + ) + + ## Fit the model + print(f"\nStarting training for {train_config['max_epochs']} epochs...") + trainer.fit(ssm,sfdm) + print("Training completed!") + + ## Get out predictions + print("\nGenerating predictions and latents...") + preds,latents = get_dl_predicts_latents(ssm,sfdm.val_dataloader(),sfdm.mean_image,train_config["batch_size"],train_config["num_workers"]) + + ## Save out all relevant metadata + print("\nSaving results...") + os.makedirs(timestamp_pred, exist_ok=True) + joblib.dump(preds,os.path.join(timestamp_pred,"preds")) + joblib.dump(latents,os.path.join(timestamp_pred,"latents")) + with open(os.path.join(timestamp_pred,"model_config"),"w") as f: + json.dump(hparams, f) + with open(os.path.join(timestamp_pred,"data_config"),"w") as f: + data_config["data_path"] = data_path + json.dump(data_config, f) + + print("\n=== Training Complete ===") + print(f"Results saved to: {timestamp_pred}") + +if __name__ == "__main__": + fire.Fire(main) diff --git a/src/behavioral_autoencoder/README.md b/src/behavioral_autoencoder/README.md new file mode 100644 index 0000000..6363eb8 --- /dev/null +++ b/src/behavioral_autoencoder/README.md @@ -0,0 +1,7 @@ +# README 
+This package is organized as follows: +- `networks.py` contains actual descriptions of network architectures which can be used as components of an autoencoder model. +- `module.py` contains the logic which specifies which combination of network architectures correspond to what kind of models. +- `metrics.py` contains custom build evaluation metrics. +- `dataloading.py` contains code for the dataloaders with preprocessing. +- `dataset.py` contains code to specify how datasets should be structured. diff --git a/src/behavioral_autoencoder/data_utils.py b/src/behavioral_autoencoder/data_utils.py index 2b4823b..f0f6a1c 100644 --- a/src/behavioral_autoencoder/data_utils.py +++ b/src/behavioral_autoencoder/data_utils.py @@ -64,7 +64,7 @@ def check_exists(frame_dir: Path) -> bool: return response == "y" return True -def main(video_dir,frame_dir,video_suffix=".avi"): +def main(video_dir,frame_dir,video_suffix=".avi",match_str=None): """Given a directory of videos, write to a different directory with the following structure: 1. one subdirectory per video file, as well as a metadata file `annotations.txt`. 2. within each subdirectory, pngs per individual frames. @@ -76,12 +76,15 @@ def main(video_dir,frame_dir,video_suffix=".avi"): video_dir: directory containing video files. frame_dir: directory to write frames to. video_suffix (default=".avi"): suffix of video files to consider. + match_str: string to find within the video names to write only a subset. """ video_dir = Path(video_dir) frame_dir = Path(frame_dir) # 1. Get a directory which contains video files. Store video file names. video_files = os.listdir(video_dir) video_files = [f for f in video_files if f.endswith(video_suffix)] + if match_str is not None: + video_files = [f for f in video_files if match_str in f] # 2. Check that the directory we care about exists. If it doesn't create. If it does, ask user. video_files_write = [] for video_file in video_files: @@ -99,7 +102,6 @@ def main(video_dir,frame_dir,video_suffix=".avi"): with open(frame_dir / "annotations.txt", "a") as f: f.write(f"{video_dir_name} {first} {last} 0\n") -def npy_to_frames() if __name__ == "__main__": fire.Fire(main) diff --git a/src/behavioral_autoencoder/dataloading.py b/src/behavioral_autoencoder/dataloading.py new file mode 100644 index 0000000..c7fea34 --- /dev/null +++ b/src/behavioral_autoencoder/dataloading.py @@ -0,0 +1,185 @@ +import torchvision +from tqdm import tqdm +import numpy as np +import torch +from torch.utils.data import Subset,DataLoader +from behavioral_autoencoder.dataset import SessionFramesTorchvision,SessionSequenceTorchvision,CropResizeProportion +import pytorch_lightning as pl +from joblib import Memory +import os +import tempfile +from pathlib import Path + +# Set up cache location with multiple options for flexibility +def get_cache_dir(): + """Get cache directory with the following priority: + 1. BEHAVIORAL_AUTOENCODER_CACHE env variable if set + 2. Project's .cache directory if in development mode + 3. User's home directory under ~/.cache/behavioral_autoencoder + 4. 
System temp directory as fallback + """ + # Option 1: Environment variable (highest priority) + if "BEHAVIORAL_AUTOENCODER_CACHE" in os.environ: + cache_dir = Path(os.environ["BEHAVIORAL_AUTOENCODER_CACHE"]) + + # Option 2: Project directory if it exists (development mode) + elif (Path(__file__).parent.parent.parent / ".cache").exists(): + cache_dir = Path(__file__).parent.parent.parent / ".cache" + + # Option 3: User's home directory + else: + home_dir = Path.home() + cache_dir = home_dir / ".cache" / "behavioral_autoencoder" + + # Create directory if it doesn't exist + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + +# Initialize memory cache +memory = Memory(location=get_cache_dir(), verbose=1) + +@memory.cache +def calculate_mean_image(data_path, dataset_config, subsample_rate, subsample_offset, batch_size, num_workers): + """ + Calculate mean image from given training set parameters. + This function is cached to avoid redundant calculations. + """ + # 1. First construct the right training set indices: + sub_dataset = {} + for field in ["transform","extension","trial_pattern"]: + sub_dataset[field] = dataset_config[field] + dataset = SessionFramesTorchvision(data_path,**sub_dataset) + all_indices = np.arange(len(dataset)) + ## Subsample indices + train_inds = all_indices[subsample_offset::subsample_rate] + trainset = Subset(dataset, train_inds) + + ## use dataloader to compute sum image + trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers) + + sum_im = torch.zeros(dataset[0].shape[1:]) ## should be + for data in tqdm(trainloader): + sum_im += data.sum(axis=0).sum(axis=0) ## sum across the batch and sequence dimensions. + mean = sum_im / len(trainset) + return mean + +class SessionFramesDataModule(pl.LightningDataModule): + """Lightning data module collects together data loading logic frame subsampling, and then calculates the framewise mean of the training dataset. + This datamodule does the following: + 1. Subsamples at a given subsetting rate to generate the training data. + + 2. Checks if we have already calculated the mean image for the given training parameters + 3. Generates the training dataset by subsampling at the given subsampling rate and offset. Calculates mean image if does not exist yet. + 4. Defines datasets which compose given transformations with a subtraction of the training set mean. + + """ + def __init__( + self, + dataset_config, + batch_size, + num_workers, + train_subsample_rate, + train_subsample_offset, + val_subsample_rate, + val_subsample_offset, + trainset_subtract_mean=True, + ): + """ + By default, checks if we have already calculated the image mean on a particular dataset for a particular training set, and if so gets that cached mean. + + Parameters + ---------- + dataset_config : dict + dictionary containing configuration parameters for dataset. Must include: + data_path: str + path of the top level folder of per-trial frames. + transform : any + the transform function that we will apply to all the data. Can be followed by different training, testing transforms for different datasets, and can be None. + Can optionally also include: + extension : str + extension of files which + trial_pattern : str + regex which filters for certain folder prefixes in the trial. + batch_size : int + frames per batch + trainset_subsample_rate : int + take one frame for every `trainset_subsample_rate` frames. + trainset_subsample_offset : int + offset for subsampling. 
+ trainset_subtract_mean : bool + whether we should calculate and subtract the mean of the training set. + """ + super().__init__() + self.data_path = dataset_config.pop("data_path") ## each dataset does not take this argument, so we remove. + self.dataset_config = dataset_config + self.batch_size = batch_size + self.num_workers = num_workers + self.train_subsample_rate = train_subsample_rate + self.train_subsample_offset = train_subsample_offset + self.val_subsample_rate = val_subsample_rate + self.val_subsample_offset = val_subsample_offset + self.subtract_mean = trainset_subtract_mean + + if self.subtract_mean: + print("Calculating mean image") + self.mean_image = calculate_mean_image(self.data_path,dataset_config,self.train_subsample_rate,self.train_subsample_offset,self.batch_size,self.num_workers) + subtract_mean = torchvision.transforms.Lambda(self.subtract_mean_image) + # augment the transformation for our datasets: + if self.dataset_config["transform"] is not None: + mean_normed_transform = torchvision.transforms.Compose([ + self.dataset_config["transform"], + subtract_mean + ]) + else: + mean_normed_transform = torchvision.transforms.Compose([ + subtract_mean + ]) + self.transform = mean_normed_transform + else: + self.transform = self.dataset_config["transform"] + print("Done with setup") + + def subtract_mean_image(self,x): + return x-self.mean_image + + def setup(self,stage): + if stage == "fit": + _ = self.dataset_config.pop("transform") + self.dataset = SessionFramesTorchvision(self.data_path,transform = self.transform,**self.dataset_config) + all_indices = np.arange(len(self.dataset)) + ## Subsample indices + train_inds = all_indices[self.train_subsample_offset::self.train_subsample_rate] + val_inds = all_indices[self.val_subsample_offset::self.val_subsample_rate] + #test_inds = [i for i in all_indices if not (i in train_inds)] + self.trainset = Subset(self.dataset,train_inds) + self.valset = Subset(self.dataset,val_inds) + if stage == "test": + ### the only way to access this right now is to pass setup explicitly right now. + _ = self.dataset_config.pop("transform") + self.dataset = SessionSequenceTorchvision(self.data_path,transform = self.transform,**self.dataset_config) + + def train_dataloader(self,shuffle=True): + dataloader = DataLoader( + self.trainset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=shuffle, + drop_last =False, + pin_memory=True + ) + return dataloader + + def val_dataloader(self): + dataloader = DataLoader( + self.valset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + drop_last =False, + pin_memory=True + ) + return dataloader + + def test_dataloader(self): + return self.val_dataloader() + diff --git a/src/behavioral_autoencoder/dataset.py b/src/behavioral_autoencoder/dataset.py index 12f7ac1..27a3a99 100644 --- a/src/behavioral_autoencoder/dataset.py +++ b/src/behavioral_autoencoder/dataset.py @@ -9,6 +9,7 @@ import cv2 from PIL import Image from torchvision.io import read_image +import torch import torchvision.transforms.functional as F import os import json @@ -95,6 +96,7 @@ class SessionFramesDataset(Dataset): cumsum_n_trials : arraylike cumulative index for frame index across all trials. 
+ """ def __init__(self, base_folder, extension=".png", crop_info={'h_coord': 26}, trial_pattern=None): @@ -277,6 +279,114 @@ def crop_img_proportional(self,img): img = F.crop(img, y_top, x_left, y_bottom - y_top, x_right-x_left) return img +class SessionSequenceTorchvision(Dataset): + """Quite similar to SessionFramesTorchvision, but returns a whole sequence of frames instead of the single frame. + + Attributes + ---------- + base_folder : str + given by parameter at initialization + trial_folders : arraylike + sorted names of per-trial directories + extension : str + extension for frame files. + frame_dict : dict + dictionary with keys given by trial_folder names, and entries arraylikes of frames within each folder. + trial_lengths : list + number of frames within each trial dictionary + cumsum_n_trials : arraylike + cumulative index for frame index across all trials. + transform : any + None or transform function + frame_subset : any + The indices of frames which we will extract from each trial. If not given, will extract all frames for each trial in each batch. + """ + def __init__(self, base_folder, extension=".png", trial_pattern=None, transform = None, frame_subset=None): + """ + Parameters + ---------- + base_folder : string + path to the base folder which contains folders for each individual trial. + extension : string + file extension for frame files (default: ".png") + crop_info : dict + cropping information to be passed to `transform_image`, with one expected key, `h_coord`. + Crops out the image in the original space so that ~h_coord pixels to the right of the + image would be cropped following appropriate image transformation. + trial_pattern : string, optional + Regular expression pattern to match trial folders. If None, all directories are considered + trial folders. Example: r"^\d+_trial$" would match folders like "0_trial", "1_trial", etc. + frame_subset : any + Array like of filenames for frames to extract from each trial. If not given, extracts all frames fr each trial in a single batch. + + """ + self.base_folder = base_folder + self.extension = extension + self.frame_subset = frame_subset + + # Get all items in the base folder + all_items = os.listdir(base_folder) + + # Filter for directories only + self.trial_folders = [ + item for item in all_items + if os.path.isdir(os.path.join(base_folder, item)) + ] + + # Apply regex pattern if provided + if trial_pattern is not None: + pattern = re.compile(trial_pattern) + self.trial_folders = [ + folder for folder in self.trial_folders + if pattern.match(folder) + ] + + self.trial_folders = np.sort(self.trial_folders) + + self.frame_dict = {folder: np.sort(self.filter_frames(base_folder,folder)) for folder in self.trial_folders} + self.trial_lengths = [len(self.frame_dict[folder]) for folder in self.trial_folders] + self.cumsum_n_trials = np.cumsum(self.trial_lengths) + self.transform = transform + + def filter_frames(self,base_folder,folder): + """ + Given a trial folder, filters out extra files to return only those with a given extension, and matching a given frame subset names. . + """ + candidates = os.listdir(os.path.join(base_folder,folder)) + if self.frame_subset is None: + return [f for f in candidates if f.endswith(self.extension)] + else: + return [f for f in candidates if f.endswith(self.extension) and (f in self.frame_subset)] + + def __len__(self): + """ + Required method for pytorch datasets. 
+ """ + return np.sum(self.trial_lengths) + + def __getitem__(self, trial_idx, method = "searchsorted"): + """ + Get all frames which match a given + + Parameters + ---------- + idx: int + integer index into the data. + """ + ## get trial number + + batch = [] + for frame in self.frame_dict[self.trial_folders[trial_idx]]: + image_path = os.path.join( + self.base_folder, + self.trial_folders[trial_idx], + frame) + img = Image.open(image_path) + if self.transform: + img = self.transform(img) + batch.append(img[None,:]) + return torch.cat(batch,dim=0) + class SessionFramesTorchvision(Dataset): """Essentially the same as SessionFramesDataset above, but factors out image transformations into a separate class. Assumes we have a dataset which is organized as a directory of directories, with one directory per trial. @@ -299,9 +409,11 @@ class SessionFramesTorchvision(Dataset): cumulative index for frame index across all trials. transform : any None or transform function + seq_length : int + we end up outputting data which has a sequence dimension for consistency with other implementations as a singleton dimension preceding the others. """ - def __init__(self, base_folder, extension=".png", trial_pattern=None, transform = None): + def __init__(self, base_folder, extension=".png", trial_pattern=None, transform = None, seq_length = 1): """ Parameters ---------- @@ -319,6 +431,8 @@ def __init__(self, base_folder, extension=".png", trial_pattern=None, transform """ self.base_folder = base_folder self.extension = extension + self.seq_length = seq_length + assert seq_length == 1, "can't do more than this. " # Get all items in the base folder all_items = os.listdir(base_folder) @@ -386,6 +500,6 @@ def __getitem__(self, idx, method = "searchsorted"): img = Image.open(image_path) if self.transform: img = self.transform(img) - return img + return img[None,:] diff --git a/src/behavioral_autoencoder/eval.py b/src/behavioral_autoencoder/eval.py new file mode 100644 index 0000000..6a3ff60 --- /dev/null +++ b/src/behavioral_autoencoder/eval.py @@ -0,0 +1,35 @@ +""" +Functions for two kinds of evaluation: + 1. Bulk extraction of latents and predictions from the last timepoint. + 2. Evaluation of the latent space. +""" +from torch.utils.data import DataLoader +import torch +from tqdm import tqdm + +def get_all_predicts_latents(trained_model,datamodule,batch_size,num_workers): + """ + Get latents for the entire video, having trained on a subsampled set. + + """ + full_dataloader = DataLoader(datamodule.dataset,batch_size=batch_size,num_workers=num_workers) + predictions = [] + latents = [] + for batch in tqdm(full_dataloader): + prediction,latent = trained_model(batch) + predictions.append(prediction+full_dataloader.mean_image[None,None,:]) + latents.append(latent) + return torch.concatenate(predictions), torch.concatenate(latents) + +def get_dl_predicts_latents(trained_model,dataloader,mean,batch_size,num_workers): + """ + Get latents for the entire video, having trained on a subsampled set. 
+ + """ + predictions = [] + latents = [] + for batch in tqdm(dataloader): + prediction,latent = trained_model(batch) + predictions.append(prediction+mean[None,None,:]) + latents.append(latent) + return torch.concatenate(predictions), torch.concatenate(latents) diff --git a/src/behavioral_autoencoder/metrics.py b/src/behavioral_autoencoder/metrics.py new file mode 100644 index 0000000..bbb91b2 --- /dev/null +++ b/src/behavioral_autoencoder/metrics.py @@ -0,0 +1,13 @@ +""" +Metrics and losses for network training +""" +import torch + +def L2_loss(batch): + """Given a batch of latents (shaped batch,sequence,activations), calculates the squared norm of the activations across the entire batch and returns shape batch. + + """ + squared_acts = batch**2 + l2_norm = torch.sum(squared_acts,axis=-1) + mean_l2 = torch.mean(l2_norm) + return mean_l2 diff --git a/src/behavioral_autoencoder/module.py b/src/behavioral_autoencoder/module.py new file mode 100644 index 0000000..10d4e96 --- /dev/null +++ b/src/behavioral_autoencoder/module.py @@ -0,0 +1,99 @@ +import pytorch_lightning as pl +import torch +import torch.nn as nn +from behavioral_autoencoder.metrics import L2_loss +from behavioral_autoencoder.networks import SingleSessionAutoEncoder +from torch.optim.lr_scheduler import LinearLR,ChainedScheduler + +models = { + "single_session_autoencoder":SingleSessionAutoEncoder + } + +class Autoencoder_Models(pl.LightningModule): + """ + Abstract base class for Autoencoder models. Handles things like loss definition, model choice, + """ + def __init__(self,hparams): + super().__init__() + self.save_hyperparameters(hparams) + self.reconstruct_criterion = nn.MSELoss() + self.shrink_criterion = L2_loss + def forward(self,batch): + images = batch + latents,predictions = self.model(images) + return latents,predictions + def training_step(self, batch, batch_nb): + predictions,latents = self.forward(batch) + rec_loss = self.reconstruct_criterion(predictions,batch) + shrink_loss = self.shrink_criterion(latents) + loss = rec_loss+self.hparams["train_config"]["l2_weight"]*shrink_loss + self.log("loss/train", loss) + self.log("mse/train", rec_loss) + return loss + def validation_step(self, batch, batch_nb): + predictions,latents = self.forward(batch) + rec_loss = self.reconstruct_criterion(predictions,batch) + shrink_loss = self.shrink_criterion(latents) + loss = rec_loss+self.hparams["train_config"]["l2_weight"]*shrink_loss + self.log("loss/val", loss) + self.log("mse/val", rec_loss) + def test_step(self, batch, batch_nb): + predictions,latents = self.forward(batch) + rec_loss = self.reconstruct_criterion(predictions,batch) + shrink_loss = self.shrink_criterion(latents) + loss = rec_loss+self.hparams["train_config"]["l2_weight"]*shrink_loss + self.log("loss/val", loss) + self.log("mse/val", rec_loss) + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.model.parameters(), + lr=self.hparams["train_config"]["learning_rate"], + weight_decay=self.hparams["train_config"]["weight_decay"], + ) + scheduler = self.setup_scheduler(optimizer) + return [optimizer], [scheduler] + def setup_scheduler(self,optimizer): + """Chooses between the cosine learning rate scheduler, linear scheduler, or step scheduler. 
+ + """ + if self.hparams["train_config"]["scheduler"] == "linear": + scheduler = { + "scheduler": LinearLR( + optimizer,start_factor=self.hparams["train_config"]["start_factor"], + end_factor=self.hparams["train_config"]["end_factor"], + total_iters=self.hparams["train_config"]["warmup_steps"] + ), + "interval": "step", + "name": "learning_rate" + } + elif self.hparams["train_config"]["scheduler"] == "step": + scheduler = { + "scheduler": torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones = [10,20,30], gamma = 0.1, last_epoch=-1 + ), + "interval": "epoch", + "frequency":1, + "name": "learning_rate", + } + elif self.hparams["train_config"]["scheduler"] == "linear_step": + linear_scheduler = LinearLR( + optimizer,start_factor=self.hparams["train_config"]["start_factor"], + end_factor=self.hparams["train_config"]["end_factor"], + total_iters=self.hparams["train_config"]["warmup_steps"] + ) + step_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones = [7500,15000,22500], gamma = 0.1, last_epoch=-1) + scheduler = { + "scheduler":ChainedScheduler(optimizer=optimizer,schedulers=[linear_scheduler,step_scheduler]), + "interval": "step", + "frequency": 1, + "name": "learning_rate" + } + return scheduler + +class SingleSessionModule(Autoencoder_Models): + """Class for single session autoencoder. + """ + def __init__(self,hparams): + super().__init__(hparams) + self.model = models[hparams["model"]](hparams["model_config"]) diff --git a/src/behavioral_autoencoder/networks.py b/src/behavioral_autoencoder/networks.py new file mode 100644 index 0000000..41461dc --- /dev/null +++ b/src/behavioral_autoencoder/networks.py @@ -0,0 +1,166 @@ +""" +Network architectures. +""" +import torch.nn as nn + +class PreConvBlock(nn.Module): + def __init__(self, + in_channels = 1, + out_channels = 16, + kernel = 3, + stride=1, + pool_size = 2, + use_batch_norm = False): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int): Stride for the convolutional layers. + """ + super(PreConvBlock, self).__init__() + + self.layers = nn.ModuleList([ + nn.Conv2d(in_channels, out_channels, + kernel_size=kernel, stride=stride, + padding=(kernel - stride) // 2, + bias= not use_batch_norm), + ]) + if use_batch_norm: + self.layers.append(nn.BatchNorm2d(out_channels)) + + self.layers.append(nn.ReLU(inplace=True)) + self.layers.append(nn.MaxPool2d(kernel_size=pool_size)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + +class ResidualBlock(nn.Module): + def __init__(self, + n_channels = 16, + kernel = 3, + stride = 1, + use_batch_norm = False, + downsample=None, + pool_size = 4, + n_layers = 3): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int): Stride for the convolutional layers. + downsample (nn.Module, optional): Downsampling layer if input and output dimensions differ. 
+ """ + super(ResidualBlock, self).__init__() + + self.layers = nn.ModuleList() + for i_layer in range(n_layers): + self.layers.append(nn.Conv2d(n_channels, n_channels, + kernel_size=kernel, stride=stride, + padding=(kernel - stride) // 2, + bias= not use_batch_norm)) + if use_batch_norm: + self.layers.append(nn.BatchNorm2d(n_channels)) + if i_layer!=(n_layers - 1): + self.layers.append(nn.ReLU(inplace=True)) + + self.downsample = downsample # Optional downsampling layer + self.post_residual_layers = nn.ModuleList([nn.ReLU(inplace=True)]) + if pool_size is not None: + self.post_residual_layers.append(nn.MaxPool2d(kernel_size=pool_size)) + + def forward(self, x): + identity = x + + # Pass through the layers in the ModuleList + for layer in self.layers: + x = layer(x) + + # Apply downsampling to the identity if necessary + if self.downsample: + identity = self.downsample(identity) + + # Add the residual connection + x += identity + for layer in self.post_residual_layers: + x = layer(x) + + return x + +class Encoder(nn.Module): + + def __init__(self, + configs): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int): Stride for the convolutional layers. + """ + super(Encoder, self).__init__() + + self.configs = configs + + self.residual_layers = nn.ModuleList() + + for i_blocks in range(configs['num_blocks']): + self.residual_layers.append(PreConvBlock(configs['in_channels_%d'%i_blocks], + configs['out_channels_%d'%i_blocks], + configs['kernel_preconv'], + configs['stride_preconv'], + configs['pool_size_preconv_%d'%i_blocks], + configs['use_batch_norm_preconv'])) + self.residual_layers.append(ResidualBlock(configs['out_channels_%d'%i_blocks], + configs['kernel_residual'], + configs['stride_residual'], + configs['use_batch_norm_residual'], + pool_size=configs['pool_size_residual_%d'%i_blocks], + n_layers=configs['n_layers_residual'])) + + self.linear_layers = nn.ModuleList([nn.Linear(configs['out_conv'], configs['out_linear'], bias = not configs['use_batch_norm_linear']),]) + if configs['use_batch_norm_linear']: + self.linear_layers.append(nn.BatchNorm1d(configs['out_linear'])) + self.linear_layers.append(nn.ReLU(inplace=True)) + self.linear_layers.append(nn.Linear(configs['out_linear'], configs['embed_size'])) + + def forward(self, x): + bs, seq_length, c, h ,w = x.size() + + x = x.view(bs*seq_length, c, h, w) + for layer in self.residual_layers: + x = layer(x) + + x = x.view(bs*seq_length, -1) + for layer in self.linear_layers: + x = layer(x) + + x = x.view(bs, seq_length, -1) + return x + +class SingleSessionDecoder(nn.Module): + def __init__(self, configs): + super(SingleSessionDecoder, self).__init__() + self.configs = configs + self.linear_layer = nn.Linear(configs['embed_size'], configs['image_height']*configs['image_width']) + + def forward(self, x): + bs, seq_length, _ = x.size() + x = self.linear_layer(x) + x = x.view(bs, seq_length, 1, self.configs['image_height'], self.configs['image_width']) + return x + +class SingleSessionAutoEncoder(nn.Module): + """ + This autoencoder takes in data of shape (batch,sequence (for a sequence of images), channels (1 for grayscale), height, width). Thus, it is set up to work with multidimensional, video structured data, although in practice we work with grayscale and can usually flatten across a sequence of images for the examples that we often care about. The latents are of shape batch,sequence,embedding_size. 
+ """ + def __init__(self, configs): + super(SingleSessionAutoEncoder, self).__init__() + self.configs = configs + self.encoder = Encoder(configs) + self.decoder = SingleSessionDecoder(configs) + + def forward(self, x): + z = self.encoder(x) + x = self.decoder(z) + return x, z diff --git a/tests/test_dataloading.py b/tests/test_dataloading.py new file mode 100644 index 0000000..ae07b80 --- /dev/null +++ b/tests/test_dataloading.py @@ -0,0 +1,127 @@ +""" +Test what the dataloader does. +""" +from behavioral_autoencoder.dataset import CropResizeProportion +import json +import os +import cv2 +import numpy as np +from behavioral_autoencoder.dataloading import SessionFramesDataModule + + +here = os.path.abspath(os.path.dirname(__file__)) + +def temp_hierarchical_folder_generator(tmp_path, n_trials=3, n_ims_per_trial=10, extra_files=None): + """Creates a temporary hierarchical folder structure for testing. + + This fixture creates a folder structure that mimics a session with multiple trials, + where each trial contains randomly sampled images. The structure is: + temp_session/ + ├── 0_trial/ + │ ├── frame_000000.png + │ ├── frame_000001.png + │ ├── ... + │ └── extra_file.txt + ├── 1_trial/ + │ ├── frame_000000.png + │ ├── ... + │ └── extra_file.txt + └── ... + + Parameters + ---------- + tmp_path : Path + Pytest fixture providing temporary directory path + n_trials : int, optional + Number of trial folders to create, by default 3 + n_ims_per_trial : int, optional + Number of images to create per trial, by default 10 + extra_files : list of str, optional + List of extra file names to create in each trial directory + + Returns + ------- + Path + Path to the created temporary session directory + """ + # Set random seed for reproducibility + np.random.seed(42) + + # Load example images + example_images = np.load('./test_data/example_images.npy')*255 + + # Create session directory + session_dir = tmp_path / "temp_session" + session_dir.mkdir() + + # Create trial folders and populate with images + for trial in range(n_trials): + trial_dir = session_dir / f"{trial}_trial" + trial_dir.mkdir() + + # Randomly sample images + selected_images = example_images[ + np.random.choice(len(example_images), n_ims_per_trial, replace=True) + ] + + # Save images as PNGs + for i, img in enumerate(selected_images): + img_path = trial_dir / f"frame_{i:06d}.png" + cv2.imwrite(str(img_path), img) + + # Create extra files within each trial folder if specified + if extra_files: + for filename in extra_files: + (trial_dir / filename).touch() + + return session_dir + +class Test_SessionFramesDataModule(): + config_path = os.path.join(here,"..","configs","data_configs","alm_side.json") + def test_init(self,tmp_path): + with open(self.config_path,"r") as f: + crop_config = json.load(f) + # Create a larger dataset + n_trials = 10 # Increased number of trials + n_ims_per_trial = 50 # Increased images per trial + session_dir = temp_hierarchical_folder_generator( + tmp_path, + n_trials=n_trials, + n_ims_per_trial=n_ims_per_trial + ) + + alm_cropping = CropResizeProportion(self.config_path) + data_config = { + "data_path":session_dir, + "transform":alm_cropping, + "extension":".png", + "trial_pattern":None + } + sfdm = SessionFramesDataModule(data_config,10,2,10,1,10,1) + assert sfdm.mean_image.shape == (1,crop_config["target_h"],crop_config["target_w"]) + + def test_train_dataloader(self,tmp_path): + with open(self.config_path,"r") as f: + crop_config = json.load(f) + # Create a larger dataset + n_trials = 10 # Increased number of 
trials + n_ims_per_trial = 50 # Increased images per trial + session_dir = temp_hierarchical_folder_generator( + tmp_path, + n_trials=n_trials, + n_ims_per_trial=n_ims_per_trial + ) + + alm_cropping = CropResizeProportion(self.config_path) + data_config = { + "data_path":session_dir, + "transform":alm_cropping, + "extension":".png", + "trial_pattern":None + } + sfdm = SessionFramesDataModule(data_config,10,2,10,1,10,1) + sfdm.setup("fit") + dataloader = sfdm.train_dataloader() + + for batch in dataloader: + assert batch.shape[1:] == (1,1,crop_config["target_h"],crop_config["target_w"]) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 897aa60..e67bc32 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -2,7 +2,7 @@ Test functions within the dataset class. Uses test data inside `test_data` folder. """ -from behavioral_autoencoder.dataset import SessionFramesDataset,SessionFramesTorchvision,CropResizeProportion +from behavioral_autoencoder.dataset import SessionFramesDataset,SessionFramesTorchvision,SessionSequenceTorchvision,CropResizeProportion import pytest import numpy as np import time @@ -263,8 +263,43 @@ def test_getitem_performance(self, tmp_path): assert np.array_equal(result_argmax, result_searchsorted), \ f"Methods returned different results for index {idx}" +class Test_SessionSequenceTorchvision: + config_path = os.path.join(here,"..","configs","data_configs","alm_side.json") + def test_init(self, temp_hierarchical_folder): + """Test the SessionFramesTorchvision initialization and basic properties, including with default cropping given by CropResizeProportion with default. + + Tests: + 1. Dataset can be initialized + 2. Number of trials matches fixture + 3. Image dimensions and format are correct + 4. Dataset length matches expected total frames + """ + alm_cropping = CropResizeProportion(self.config_path) + len_sequence = 3 + dataset = SessionSequenceTorchvision(temp_hierarchical_folder,transform = alm_cropping,frame_subset = [f"frame_{i:06d}.png" for i in range(len_sequence)]) + + # Test number of trials + assert len(dataset.trial_folders) == 3, "Should have 3 trials by default" + + # Test image properties + first_image = dataset[0] + assert isinstance(first_image, torch.Tensor), "Dataset should return numpy arrays" + assert len(first_image.shape) == 4 + assert first_image.shape[0] == len_sequence, "Sequence dimension should be 0" + assert first_image.shape[1] == 1, "Images should be 2D grayscale (H,W)" + assert first_image.dtype == torch.float32, "Images should be float32" + + # Test dataset length + expected_length = 3 * len_sequence # n_trials * n_ims_per_trial + assert len(dataset) == expected_length, f"Dataset should have {expected_length} total frames" + + # Test all images are readable + for i in range(3): + img = dataset[i] + assert img is not None, f"Failed to load sequence at index {i}" + class Test_SessionFramesTorchvision: - config_path = os.path.join(here,"..","configs","crop_configs","alm_side.json") + config_path = os.path.join(here,"..","configs","data_configs","alm_side.json") def test_init(self, temp_hierarchical_folder): """Test the SessionFramesTorchvision initialization and basic properties, including with default cropping given by CropResizeProportion with default. 
@@ -285,7 +320,9 @@ def test_init(self, temp_hierarchical_folder): # Test image properties first_image = dataset[0] assert isinstance(first_image, torch.Tensor), "Dataset should return numpy arrays" - assert first_image.shape[0] == 1, "Images should be 2D grayscale (H,W)" + assert len(first_image.shape) == 4 + assert first_image.shape[0] == 1, "Sequence dimension should be 0" + assert first_image.shape[1] == 1, "Images should be 2D grayscale (H,W)" assert first_image.dtype == torch.float32, "Images should be float32" # Test dataset length diff --git a/tests/test_module.py b/tests/test_module.py new file mode 100644 index 0000000..4e689a8 --- /dev/null +++ b/tests/test_module.py @@ -0,0 +1,184 @@ +""" +""" +from behavioral_autoencoder.module import SingleSessionModule +from behavioral_autoencoder.dataset import CropResizeProportion +from behavioral_autoencoder.dataloading import SessionFramesDataModule +from pytorch_lightning.loggers import TensorBoardLogger +import numpy as np +from torch.utils.data import DataLoader +import pytorch_lightning as pl +import cv2 +import os +import json + +here = os.path.join(os.path.abspath(os.path.dirname(__file__))) + +def temp_hierarchical_folder_generator(tmp_path, n_trials=3, n_ims_per_trial=10, extra_files=None): + """Creates a temporary hierarchical folder structure for testing. + + This fixture creates a folder structure that mimics a session with multiple trials, + where each trial contains randomly sampled images. The structure is: + temp_session/ + ├── 0_trial/ + │ ├── frame_000000.png + │ ├── frame_000001.png + │ ├── ... + │ └── extra_file.txt + ├── 1_trial/ + │ ├── frame_000000.png + │ ├── ... + │ └── extra_file.txt + └── ... + + Parameters + ---------- + tmp_path : Path + Pytest fixture providing temporary directory path + n_trials : int, optional + Number of trial folders to create, by default 3 + n_ims_per_trial : int, optional + Number of images to create per trial, by default 10 + extra_files : list of str, optional + List of extra file names to create in each trial directory + + Returns + ------- + Path + Path to the created temporary session directory + """ + # Set random seed for reproducibility + np.random.seed(42) + + # Load example images + example_images = np.load('./test_data/example_images.npy')*255 + + # Create session directory + session_dir = tmp_path / "temp_session" + session_dir.mkdir() + + # Create trial folders and populate with images + for trial in range(n_trials): + trial_dir = session_dir / f"{trial}_trial" + trial_dir.mkdir() + + # Randomly sample images + selected_images = example_images[ + np.random.choice(len(example_images), n_ims_per_trial, replace=True) + ] + + # Save images as PNGs + for i, img in enumerate(selected_images): + img_path = trial_dir / f"frame_{i:06d}.png" + cv2.imwrite(str(img_path), img) + + # Create extra files within each trial folder if specified + if extra_files: + for filename in extra_files: + (trial_dir / filename).touch() + + return session_dir + +class Test_SingleSessionModule(): + model_config_path = os.path.join(here,"..","configs","model_configs","alm_default.json") + train_config_path = os.path.join(here,"..","configs","train_configs","alm_default.json") + crop_config_path = os.path.join(here,"..","configs","data_configs","alm_side.json") + def test_init(self): + with open(self.model_config_path,"r") as f: + model_config = json.load(f) + with open(self.train_config_path,"r") as f: + train_config = json.load(f) + hparams = { + "model":"single_session_autoencoder", + 
"model_config":model_config, + "train_config":train_config + } + ssm = SingleSessionModule(hparams) + def test_baby_train_loop(self,tmp_path): + model_config_path = os.path.join(here,"..","configs","model_configs","alm_default.json") + train_config_path = os.path.join(here,"..","configs","train_configs","alm_default.json") + + with open(self.model_config_path,"r") as f: + model_config = json.load(f) + with open(self.train_config_path,"r") as f: + train_config = json.load(f) + with open(self.crop_config_path,"r") as f: + crop_config = json.load(f) + + hparams = { + "model":"single_session_autoencoder", + "model_config":model_config, + "train_config":train_config + } + ssm = SingleSessionModule(hparams) + + session_dir = temp_hierarchical_folder_generator( + tmp_path, + ) + + alm_cropping = CropResizeProportion(self.crop_config_path) + data_config = { + "data_path":session_dir, + "transform":alm_cropping, + "extension":".png", + "trial_pattern":None + } + sfdm = SessionFramesDataModule(data_config,10,2,10,1,10,1) + logger = TensorBoardLogger("tb_logs",name="test_single_session_auto",log_graph=True) + trainer = pl.Trainer( + max_epochs=1, + accelerator='cpu', # Explicitly use CPU for testing + enable_checkpointing=False, # Disable for testing + logger=logger, # Disable logging for testing + enable_progress_bar=True, # See training progress + ) + trainer.fit(ssm,sfdm) + assert trainer.current_epoch == 1, "Training should complete 2 epochs" + assert trainer.global_step > 0, "Should have completed some training steps" + def test_eval_save(self,tmp_path): + model_config_path = os.path.join(here,"..","configs","model_configs","alm_default.json") + train_config_path = os.path.join(here,"..","configs","train_configs","alm_default.json") + + with open(self.model_config_path,"r") as f: + model_config = json.load(f) + with open(self.train_config_path,"r") as f: + train_config = json.load(f) + with open(self.crop_config_path,"r") as f: + crop_config = json.load(f) + + hparams = { + "model":"single_session_autoencoder", + "model_config":model_config, + "train_config":train_config + } + ssm = SingleSessionModule(hparams) + + session_dir = temp_hierarchical_folder_generator( + tmp_path, + ) + + alm_cropping = CropResizeProportion(self.crop_config_path) + data_config = { + "data_path":session_dir, + "transform":alm_cropping, + "extension":".png", + "trial_pattern":None + } + sfdm = SessionFramesDataModule(data_config,10,2,10,1,10,1) + logger = TensorBoardLogger("tb_logs",name="test_single_session_auto",log_graph=True) + trainer = pl.Trainer( + max_epochs=1, + accelerator='cpu', # Explicitly use CPU for testing + enable_checkpointing=False, # Disable for testing + logger=logger, # Disable logging for testing + enable_progress_bar=True, # See training progress + ) + trainer.fit(ssm,sfdm) + # Extract out some latents: + full_dataloader = DataLoader(sfdm.dataset,batch_size=10) + predictions = [] + latents = [] + for batch in full_dataloader: + prediction,latent = ssm(batch) + predictions.append() + import pdb; pdb.set_trace() + diff --git a/tests/test_networks.py b/tests/test_networks.py new file mode 100644 index 0000000..ed23580 --- /dev/null +++ b/tests/test_networks.py @@ -0,0 +1,28 @@ +""" +Test that networks work as intended. 
+""" +import os +import torch +import json +from behavioral_autoencoder.networks import SingleSessionAutoEncoder + +here = os.path.abspath(os.path.dirname(__file__)) + +class Test_SingleSessionAutoEncoder(): + config_path = os.path.join(here,"..","configs","model_configs","alm_default.json") + def test_init(self): + with open(self.config_path,"r") as f: + config = json.load(f) + ssae = SingleSessionAutoEncoder(config) + def test_forward(self): + with open(self.config_path,"r") as f: + config = json.load(f) + batch_size=10 + dummy_input = torch.zeros(10,1,1,config["image_height"],config["image_width"]) + ssae = SingleSessionAutoEncoder(config) + output = ssae(dummy_input) + assert output[0].shape == (10,1,1,config["image_height"],config["image_width"]) + assert output[1].shape == (10,1,config["embed_size"]) + + +