-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathPrepare_Data.py
More file actions
61 lines (46 loc) · 2.25 KB
/
Prepare_Data.py
File metadata and controls
61 lines (46 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding: utf-8 -*-
"""
Load datasets for models
@author: jpeeples
"""
## Python standard libraries
from __future__ import print_function
from __future__ import division
## PyTorch dependencies
import torch
## Local external libraries
from Utils.Get_min_max import get_min_max_minibatch
from Utils.Get_min_max_zero import get_min_max_minibatch_zero
from Utils.Get_standarize import get_standardization_minibatch
from Datasets.DeepShipSegments import DeepShipSegments
from Datasets.Get_preprocessed_data import process_data
def Prepare_DataLoaders(Network_parameters):
    """Build train/val/test DataLoaders for the configured dataset.

    Validates the requested dataset, runs ``process_data`` with that
    dataset's sample rate and segment length, then wraps each partition in a
    ``torch.utils.data.DataLoader``.

    Args:
        Network_parameters: dict of settings. Keys read here:
            ``'Dataset'``, ``'data_dir'``, ``'sample_rate'`` and
            ``'segment_length'`` (both indexed by dataset name),
            ``'batch_size'`` (indexed by ``'train'``/``'val'``/``'test'``),
            ``'num_workers'`` and ``'pin_memory'``.

    Returns:
        dict mapping ``'train'``, ``'val'`` and ``'test'`` to DataLoaders.

    Raises:
        RuntimeError: if ``Network_parameters['Dataset']`` is not supported.
    """
    Dataset = Network_parameters['Dataset']

    # Fail fast on unsupported datasets, before any preprocessing is run
    # for them.
    if Dataset != 'DeepShip':
        raise RuntimeError('Dataset not implemented')

    data_dir = Network_parameters['data_dir']
    sample_rate = Network_parameters['sample_rate'][Dataset]
    segment_length = Network_parameters['segment_length'][Dataset]
    process_data(sample_rate=sample_rate, segment_length=segment_length)

    phases = ('train', 'val', 'test')

    # One dataset per partition, built in train/val/test order.
    image_datasets = {phase: DeepShipSegments(data_dir, partition=phase)
                      for phase in phases}

    # Only the training loader is shuffled; val/test keep a fixed order so
    # evaluation outputs are reproducible across runs.
    dataloaders_dict = {
        phase: torch.utils.data.DataLoader(
            image_datasets[phase],
            batch_size=Network_parameters['batch_size'][phase],
            shuffle=(phase == 'train'),
            num_workers=Network_parameters['num_workers'],
            pin_memory=Network_parameters['pin_memory'],
        )
        for phase in phases
    }
    return dataloaders_dict