Skip to content

Commit d9d532f

Browse files
committed
Changed list to sequence
1 parent c0a4926 commit d9d532f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/progpy/data_models/lstm_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright © 2021 United States Government as represented by the Administrator of the
22
# National Aeronautics and Space Administration. All Rights Reserved.
33

4+
from collections import abc
45
from itertools import chain
56
import matplotlib.pyplot as plt
67
from numbers import Number
@@ -476,8 +477,8 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
476477
raise ValueError(f"layers must be greater than 0, got {params['layers']}")
477478
if np.isscalar(params['units']):
478479
params['units'] = [params['units'] for _ in range(params['layers'])]
479-
if not isinstance(params['units'], (list, np.ndarray, tuple)):
480-
raise TypeError(f"units must be a list of integers, not {type(params['units'])}")
480+
if not isinstance(params['units'], (abc.Sequence, np.ndarray)):
481+
raise TypeError(f"units must be a Sequence (e.g., list or tuple) of integers, not {type(params['units'])}")
481482
if len(params['units']) != params['layers']:
482483
raise ValueError(f"units must be a list of integers of length {params['layers']}, got {params['units']}")
483484
for i in range(params['layers']):
@@ -487,7 +488,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
487488
raise TypeError(f"dropout must be an float greater than or equal to 0, not {type(params['dropout'])}")
488489
if params['dropout'] < 0:
489490
raise ValueError(f"dropout must be greater than or equal to 0, got {params['dropout']}")
490-
if not isinstance(params['activation'], (list, np.ndarray)):
491+
if not isinstance(params['activation'], (abc.Sequence, np.ndarray)):
491492
params['activation'] = [params['activation'] for _ in range(params['layers'])]
492493
if not np.isscalar(params['validation_split']):
493494
raise TypeError(f"validation_split must be an float between 0 and 1, not {type(params['validation_split'])}")

0 commit comments

Comments
 (0)