2 changes: 1 addition & 1 deletion bindsnet/datasets/preprocess.py
@@ -151,7 +151,7 @@ def crop_sample(sample):
     opts = {}
     image, bb = sample["image"], sample["bb"]
     orig_bbox = BoundingBox(bb[0], bb[1], bb[2], bb[3])
-    (output_image, pad_image_location, edge_spacing_x, edge_spacing_y) = cropPadImage(
+    output_image, pad_image_location, edge_spacing_x, edge_spacing_y = cropPadImage(
         orig_bbox, image
     )
     new_bbox = BoundingBox(0, 0, 0, 0)
9 changes: 2 additions & 7 deletions bindsnet/datasets/torchvision_wrapper.py
@@ -16,18 +16,13 @@ def create_torchvision_dataset_wrapper(ds_type):
     ds_type = getattr(torchDB, ds_type)

     class TorchvisionDatasetWrapper(ds_type):
-        __doc__ = (
-            """BindsNET torchvision dataset wrapper for:
+        __doc__ = """BindsNET torchvision dataset wrapper for:

         The core difference is the output of __getitem__ is no longer
         (image, label) rather a dictionary containing the image, label,
         and their encoded versions if encoders were provided.

-        \n\n"""
-            + str(ds_type)
-            if ds_type.__doc__ is None
-            else ds_type.__doc__
-        )
+        \n\n""" + str(ds_type) if ds_type.__doc__ is None else ds_type.__doc__

         def __init__(
             self,
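
For context, the wrapped dataset's __getitem__ returns a dictionary rather than an (image, label) tuple. A minimal usage sketch, following BindsNET's MNIST examples; the exact output keys (e.g. "encoded_image") are assumptions based on the docstring above:

    from torchvision import transforms

    from bindsnet.datasets import MNIST
    from bindsnet.encoding import PoissonEncoder

    # First positional argument encodes images; second encodes labels (None to skip).
    dataset = MNIST(
        PoissonEncoder(time=250, dt=1.0),
        None,
        root="./data/MNIST",
        download=True,
        transform=transforms.ToTensor(),
    )

    sample = dataset[0]
    print(sample["image"].shape, sample["label"])
    print(sample["encoded_image"].shape)  # assumed key: [time, ...] spike tensor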
2 changes: 1 addition & 1 deletion bindsnet/learning/learning.py
@@ -877,7 +877,7 @@ def _conv1d_connection_update(self, **kwargs) -> None:
         ``AbstractConnection`` class.
         """
         # Get convolutional layer parameters.
-        (out_channels, in_channels, kernel_size) = self.connection.w.size()
+        out_channels, in_channels, kernel_size = self.connection.w.size()
         padding, stride = self.connection.padding, self.connection.stride
         batch_size = self.source.batch_size
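
The parentheses are redundant because Tensor.size() returns a torch.Size, a tuple subclass that unpacks directly. A quick check (toy shape assumed):

    import torch

    # Toy conv1d weight shape: (out_channels, in_channels, kernel_size).
    w = torch.zeros(8, 4, 3)
    out_channels, in_channels, kernel_size = w.size()
    assert (out_channels, in_channels, kernel_size) == (8, 4, 3)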
194 changes: 193 additions & 1 deletion bindsnet/network/topology_features.py
@@ -281,7 +281,7 @@ def link(self, parent_feature) -> None:
         Allow two features to share tensor values
         """

-        valid_features = (Probability, Weight, Bias, Intensity)
+        valid_features = (Probability, Delay, Weight, Bias, Intensity)

         assert isinstance(self, valid_features), f"A {self} cannot use feature linking"
         assert isinstance(
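
With Delay added to valid_features, a delay tensor can evidently be shared between two features via link(). A hedged sketch: the constructor arguments follow Delay.__init__ below, but how linked features are then attached to connections is not shown in this diff and is assumed:

    import torch

    shared = torch.rand(100, 100)  # unscaled delays in [0, 1]

    parent = Delay(name="delay_a", value=shared, max_delay=32)
    child = Delay(name="delay_b", max_delay=32)

    # Per the docstring, link() lets the two features share tensor values.
    child.link(parent)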
@@ -464,6 +464,198 @@ def assert_valid_range(self):
         ), f"Invalid range for feature {self.name}: the min value must be of type torch.Tensor, float, or int"


+class Delay(AbstractFeature):
+    def __init__(
+        self,
+        name: str,
+        value: Union[torch.Tensor, float, int] = None,
+        range: Optional[Sequence[float]] = None,
+        norm: Optional[Union[torch.Tensor, float, int]] = None,
+        learning_rule: Optional[bindsnet.learning.LearningRule] = None,
+        nu: Optional[Union[list, tuple]] = None,
+        reduction: Optional[callable] = None,
+        decay: float = 0.0,
+        max_delay: Optional[int] = 32,
+        delay_decay: Optional[float] = 0,  # TODO: Make this global + lambda
+        drop_late_spikes: Optional[bool] = False,
+        refractory: Optional[bool] = False,  # TODO: Change this name
+        normalize_delays: Optional[bool] = False,  # force-normalize delays instead of clipping
+    ) -> None:
+        # language=rst
+        """
+        Delays outgoing signals based on the values of :code:`value` and :code:`max_delay`. Delays are computed
+        as :code:`(1 - value) * (max_delay - 1)`, where :code:`value` is in the range [0, 1], so larger values
+        yield shorter delays.
+
+        :param name: Name of the feature
+        :param value: Unscaled delays. Unscaled implies that these values are in [0, 1] and will be scaled by
+            :code:`max_delay` to determine the delay time
+        :param range: Range of acceptable values for the :code:`value` parameter
+        :param norm: Value to which all entries of :code:`value` will sum. Normalization occurs after each sample
+            and after the value has been updated by the learning rule (if there is one)
+        :param learning_rule: Rule which will modify the :code:`value` after each sample
+        :param nu: Learning rate for the learning rule
+        :param reduction: Method for reducing parameter updates along the minibatch dimension
+        :param decay: Constant multiple to decay weights by on each iteration
+        :param max_delay: Maximum possible delay
+        :param delay_decay: Decay :code:`value` by this amount every time step
+        :param drop_late_spikes: Suppress spikes when the delay is at its maximum
+        :param refractory: Block spikes in a synapse until earlier ones pass
+        :param normalize_delays: Force-normalize delays on every run instead of clipping values
+        """

+        ### Assertions ###
+        super().__init__(
+            name=name,
+            value=value,
+            range=[0, 1],  # note: this range isn't used; set to [0, 1] rather than None to avoid errors
+            norm=norm,
+            learning_rule=learning_rule,
+            nu=nu,
+            reduction=reduction,
+            decay=decay,
+        )
+        self.max_delay = max_delay
+        self.delay_decay = delay_decay
+        self.drop_late_spikes = drop_late_spikes
+        self.refractory = refractory
+        self.normalize_delays = normalize_delays

+    def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
+        value = self.value.clone().detach().flatten()
+        if self.normalize_delays:
+            # Force-normalize delay values into [0, 1]
+            tmp_min, tmp_max = torch.min(value), torch.max(value)
+            if tmp_max > 1 or tmp_min < 0:
+                value = (value - tmp_min) / (tmp_max - tmp_min)
+        else:
+            # Force-clip delay values into [0, 1]
+            value = torch.clamp(value, 0, 1)
+
+        # Generate new delays for insertion into buffers; larger values map to shorter delays
+        delays = ((1 - value) * (self.max_delay - 1)).long()
+
+        if self.refractory:
+            # TODO: Is there a reason why this is in here?
+            if self.drop_late_spikes:
+                conn_spikes[delays == self.max_delay] = 0
+
+            # Prevent additional spikes if one is already on the synapse
+            conn_spikes &= self.refrac_count <= 0
+            self.refrac_count -= 1
+            bool_spikes = conn_spikes.bool()
+            self.refrac_count[bool_spikes] = delays[bool_spikes]
+
+        # Add circular time index to delays
+        # TODO: Dead spikes of delay = self.max_delay don't properly die if self.time_idx > 0
+        delays = (delays + self.time_idx) % self.max_delay
+
+        # Fill the delay buffer according to connection delays
+        # |delay_buffer| = [source.n * target.n, max_delay]
+        # TODO: Can we remove .float() for performance? (Change delay buffer type?)
+        flattened_conn_spikes = conn_spikes.flatten().float()
+        self.delay_buffer[self.delays_idx, delays] = flattened_conn_spikes
+
+        # Outgoing signal is the set of spikes scheduled to fire at time_idx
+        # TODO: Detach + clone is likely less efficient than passing a reference to the
+        # buffer at the current time index
+        out_signal = (
+            self.delay_buffer[:, self.time_idx]
+            .view(self.source_n, self.target_n)
+            .detach()
+            .clone()
+        )
+
+        # Clear transmitted spikes
+        self.delay_buffer[:, self.time_idx] = 0.0
+
+        # Suppress spikes at the maximum delay
+        if self.drop_late_spikes and not self.refractory:
+            late_spikes_time = (self.time_idx - 1) % self.max_delay
+            self.delay_buffer[:, late_spikes_time] = 0.0
+
+        # Increment circular time pointer
+        self.time_idx = (self.time_idx + 1) % self.max_delay
+
+        # TODO: Remember to move this to global
+        # Decay the buffer, clamping at zero
+        if self.delay_decay:
+            self.delay_buffer = self.delay_buffer - self.delay_decay.to(self.delay_buffer.device)
+            self.delay_buffer[self.delay_buffer < 0] = 0  # TODO: Determine if this is faster than clamp(min=0)
+
+        return out_signal

+    def reset_state_variables(self) -> None:
+        super().reset_state_variables()
+
+        # Reset time index and empty the buffer
+        self.time_idx = 0
+        self.delay_buffer.zero_()

+    def prime_feature(self, connection, device, **kwargs) -> None:
+        #### Initialize value ####
+        if self.value is None:
+            self.initialize_value = lambda: torch.clamp(
+                torch.rand(
+                    (connection.source.n, connection.target.n),
+                    dtype=torch.float32,
+                    device=device,
+                ),
+                self.range[0],
+                self.range[1],
+            )
+        else:
+            self.value = self.value.to(torch.float32).to(device)
+
+        super().prime_feature(connection, device, **kwargs)
+
+        #### Initialize additional class variables ####
+        self.delays_idx = torch.arange(
+            0, connection.source.n * connection.target.n, dtype=torch.int32
+        ).to(device)
+        self.delay_buffer = torch.zeros(
+            connection.source.n * connection.target.n,
+            self.max_delay,
+            dtype=torch.float32,
+        ).to(device)
+        self.time_idx = 0
+        self.source_n = connection.source.n
+        self.target_n = connection.target.n
+
+        # Tensor necessary for interaction with the delay buffer
+        if self.delay_decay:
+            self.delay_decay = torch.tensor([self.delay_decay], device=device)
+
+        if self.refractory:
+            self.refrac_count = torch.zeros(
+                connection.source.n * connection.target.n,
+                dtype=torch.long,
+                device=device,
+            )

+    def assert_valid_range(self):
+        super().assert_valid_range()
+
+        r = self.range
+
+        ## Check that the min is non-negative ##
+        if isinstance(r[0], torch.Tensor):
+            assert (
+                r[0] >= 0
+            ).all(), f"Invalid range for feature {self.name}: a min value is less than 0"
+        elif isinstance(r[0], (float, int)):
+            assert (
+                r[0] >= 0
+            ), f"Invalid range for feature {self.name}: the min value is less than 0"
+        else:
+            assert (
+                False
+            ), f"Invalid range for feature {self.name}: the min value must be of type torch.Tensor, float, or int"


 class Mask(AbstractFeature):
     def __init__(
         self,
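
To make the scheduling in Delay.compute easier to follow, here is a minimal, self-contained sketch of the same circular delay-buffer idea (the toy sizes and random input are assumptions; it mirrors the logic above rather than calling the BindsNET class):

    import torch

    n_synapses, max_delay = 4, 5
    delay_buffer = torch.zeros(n_synapses, max_delay)
    syn_idx = torch.arange(n_synapses)
    time_idx = 0

    # Per-synapse unscaled delays in [0, 1]; larger value -> shorter delay.
    value = torch.tensor([1.0, 0.75, 0.5, 0.0])
    delays = ((1 - value) * (max_delay - 1)).long()  # tensor([0, 1, 2, 4])

    for t in range(8):
        incoming = (torch.rand(n_synapses) < 0.3).float()  # random input spikes

        # Schedule each spike `delays` steps ahead, using a circular slot index.
        slot = (delays + time_idx) % max_delay
        delay_buffer[syn_idx, slot] = incoming

        # Emit whatever is due at the current slot, then clear it and advance.
        out = delay_buffer[:, time_idx].clone()
        delay_buffer[:, time_idx] = 0.0
        time_idx = (time_idx + 1) % max_delay

        print(f"t={t} out={out.tolist()}")

Each (synapse, slot) pair is written at most once per max_delay steps and is read and cleared before its next write, which is why writing zeros for silent synapses is safe.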
1 change: 0 additions & 1 deletion examples/benchmark/sparse_vs_dense_tensors.py
@@ -4,7 +4,6 @@

 from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting

-
 parser = argparse.ArgumentParser()
 parser.add_argument("--benchmark_type", choices=["memory", "runtime"], default="memory")
 args = parser.parse_args()
1 change: 0 additions & 1 deletion examples/mnist/MCC_reservoir.py
@@ -26,7 +26,6 @@
 from bindsnet.network.topology import MulticompartmentConnection
 from bindsnet.utils import get_square_weights

-
 parser = argparse.ArgumentParser()
 parser.add_argument("--seed", type=int, default=0)
 parser.add_argument("--n_neurons", type=int, default=500)