99
1010import torchhd
1111from torchhd import embeddings
12+ from torchhd .models import Centroid
1213from torchhd .datasets import EMGHandGestures
1314
1415device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
@@ -27,31 +28,23 @@ def transform(x):
2728 return x [SUBSAMPLES ]
2829
2930
30- class Model (nn .Module ):
31- def __init__ (self , num_classes , timestamps , channels ):
32- super (Model , self ).__init__ ()
31+ class Encoder (nn .Module ):
32+ def __init__ (self , out_features , timestamps , channels ):
33+ super (Encoder , self ).__init__ ()
3334
34- self .channels = embeddings .Random (channels , DIMENSIONS )
35- self .timestamps = embeddings .Random (timestamps , DIMENSIONS )
36- self .signals = embeddings .Level (NUM_LEVELS , DIMENSIONS , high = 20 )
35+ self .channels = embeddings .Random (channels , out_features )
36+ self .timestamps = embeddings .Random (timestamps , out_features )
37+ self .signals = embeddings .Level (NUM_LEVELS , out_features , high = 20 )
3738
38- self .classify = nn .Linear (DIMENSIONS , num_classes , bias = False )
39- self .classify .weight .data .fill_ (0.0 )
40-
41- def encode (self , x : torch .Tensor ) -> torch .Tensor :
42- signal = self .signals (x )
39+ def forward (self , input : torch .Tensor ) -> torch .Tensor :
40+ signal = self .signals (input )
4341 samples = torchhd .bind (signal , self .channels .weight .unsqueeze (0 ))
4442 samples = torchhd .bind (signal , self .timestamps .weight .unsqueeze (1 ))
4543
4644 samples = torchhd .multiset (samples )
4745 sample_hv = torchhd .ngrams (samples , n = N_GRAM_SIZE )
4846 return torchhd .hard_quantize (sample_hv )
4947
50- def forward (self , x : torch .Tensor ) -> torch .Tensor :
51- enc = self .encode (x )
52- logit = self .classify (enc )
53- return logit
54-
5548
5649def experiment (subjects = [0 ]):
5750 print ("List of subjects " + str (subjects ))
@@ -66,29 +59,32 @@ def experiment(subjects=[0]):
6659 train_ld = data .DataLoader (train_ds , batch_size = BATCH_SIZE , shuffle = True )
6760 test_ld = data .DataLoader (test_ds , batch_size = BATCH_SIZE , shuffle = False )
6861
62+ encode = Encoder (DIMENSIONS , ds [0 ][0 ].size (- 2 ), ds [0 ][0 ].size (- 1 ))
63+ encode = encode .to (device )
64+
6965 num_classes = len (ds .classes )
70- model = Model ( num_classes , ds [ 0 ][ 0 ]. size ( - 2 ), ds [ 0 ][ 0 ]. size ( - 1 ) )
66+ model = Centroid ( DIMENSIONS , num_classes )
7167 model = model .to (device )
7268
7369 with torch .no_grad ():
74- for samples , labels in tqdm (train_ld , desc = "Training" ):
70+ for samples , targets in tqdm (train_ld , desc = "Training" ):
7571 samples = samples .to (device )
76- labels = labels .to (device )
72+ targets = targets .to (device )
7773
78- samples_hv = model .encode (samples )
79- model .classify .weight [labels ] += samples_hv
80-
81- model .classify .weight [:] = F .normalize (model .classify .weight )
74+ sample_hv = encode (samples )
75+ model .add (sample_hv , targets )
8276
8377 accuracy = torchmetrics .Accuracy ("multiclass" , num_classes = num_classes )
8478
8579 with torch .no_grad ():
86- for samples , labels in tqdm (test_ld , desc = "Testing" ):
80+ model .normalize ()
81+
82+ for samples , targets in tqdm (test_ld , desc = "Testing" ):
8783 samples = samples .to (device )
8884
89- outputs = model (samples )
90- predictions = torch . argmax ( outputs , dim = - 1 )
91- accuracy .update (predictions .cpu (), labels )
85+ sample_hv = encode (samples )
86+ output = model ( sample_hv , dot = True )
87+ accuracy .update (output .cpu (), targets )
9288
9389 print (f"Testing accuracy of { (accuracy .compute ().item () * 100 ):.3f} %" )
9490
0 commit comments