Skip to content

Commit 9e527ac

Browse files
authored
Created create_random_permute function (#115)
* Update functional.py * Update __init__.py
1 parent 6eee8f7 commit 9e527ac

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

torchhd/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
inverse,
2626
negative,
2727
cleanup,
28+
create_random_permute,
2829
randsel,
2930
multirandsel,
3031
soft_quantize,
@@ -73,6 +74,7 @@
7374
"inverse",
7475
"negative",
7576
"cleanup",
77+
"create_random_permute",
7678
"randsel",
7779
"multirandsel",
7880
"soft_quantize",

torchhd/functional.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Type, Union
2+
from typing import Type, Union, Callable
33
import torch
44
from torch import LongTensor, FloatTensor, Tensor
55
from collections import deque
@@ -25,6 +25,7 @@
2525
"inverse",
2626
"negative",
2727
"cleanup",
28+
"create_random_permute",
2829
"hard_quantize",
2930
"soft_quantize",
3031
"hamming_similarity",
@@ -661,6 +662,50 @@ def permute(input: VSA_Model, *, shifts=1) -> VSA_Model:
661662
return input.permute(shifts)
662663

663664

665+
666+
def create_random_permute(dim: int) -> Callable[[VSA_Model, int], VSA_Model]:
667+
r"""Creates random permutation functions.
668+
669+
Args:
670+
dim (int): dimension of the hypervectors
671+
672+
Examples::
673+
674+
>>> a = torchhd.random_hv(3, 10)
675+
>>> a
676+
tensor([[-1., 1., 1., 1., -1., -1., -1., -1., 1., -1.],
677+
[-1., -1., -1., 1., -1., 1., -1., -1., 1., -1.],
678+
[ 1., 1., 1., -1., -1., 1., -1., 1., 1., 1.]])
679+
>>> p = torchhd.create_random_permute(10)
680+
>>> p(a, 2)
681+
tensor([[ 1., 1., -1., -1., -1., 1., -1., -1., 1., -1.],
682+
[ 1., -1., -1., -1., 1., 1., -1., -1., -1., -1.],
683+
[ 1., 1., 1., -1., 1., -1., -1., 1., 1., 1.]])
684+
>>> p(a, -2)
685+
tensor([[-1., 1., 1., 1., -1., -1., -1., -1., 1., -1.],
686+
[-1., -1., -1., 1., -1., 1., -1., -1., 1., -1.],
687+
[ 1., 1., 1., -1., -1., 1., -1., 1., 1., 1.]])
688+
689+
"""
690+
691+
forward = torch.randperm(dim)
692+
backward = torch.empty_like(forward)
693+
backward[forward] = torch.arange(dim)
694+
695+
def permute(input: VSA_Model, shifts: int = 1) -> VSA_Model:
696+
y = input
697+
if shifts > 0:
698+
for _ in range(shifts):
699+
y = y[..., forward]
700+
elif shifts < 0:
701+
for _ in range(shifts):
702+
y = y[..., backward]
703+
return y
704+
705+
return permute
706+
707+
708+
664709
def inverse(input: VSA_Model) -> VSA_Model:
665710
r"""Inverse for the binding operation.
666711

0 commit comments

Comments
 (0)