Adding discrete control functionality to libemg #129
Pull request overview
This pull request adds discrete control functionality to libemg, enabling discrete gesture recognition as an alternative to continuous classification. The discrete mode allows recording EMG data through spacebar presses instead of timers, and includes new classifier models and algorithms designed for discrete gesture detection.
Changes:
- Added discrete parameter to GUI, data collection, and data processing pipeline to support spacebar-based recording
- Implemented OnlineDiscreteClassifier for real-time discrete gesture detection with optional rejection thresholds and prediction buffering
- Added three discrete classification models: MVLDA (majority vote LDA), DTWClassifier (DTW-based k-NN), and MyoCrossUserPretrained (pretrained PyTorch model)
Reviewed changes
Copilot reviewed 10 out of 11 changed files in this pull request and generated 25 comments.
| File | Description |
|---|---|
| libemg/gui.py | Added discrete parameter to default args, implemented UI scaling and font loading for better cross-platform display |
| libemg/feature_extractor.py | Added discrete parameter to extract_features to handle list of templates separately, refactored into _extract_features_single helper |
| libemg/emg_predictor.py | Added OnlineDiscreteClassifier class for real-time discrete gesture classification with buffering and rejection threshold support |
| libemg/data_handler.py | Added discrete parameter to parse_windows to keep windows from each rep separate instead of concatenating them |
| libemg/_gui/_data_collection_panel.py | Added discrete mode UI controls and play_collection_visual_discrete method for spacebar-controlled recording |
| libemg/_discrete_models/__init__.py | Created new module for discrete models with imports for MVLDA, DTW, and MyoCrossUser |
| libemg/_discrete_models/MVLDA.py | Implemented majority vote LDA classifier with soft voting predict_proba |
| libemg/_discrete_models/DTW.py | Implemented DTW-based k-NN classifier with weighted probability estimation |
| libemg/_discrete_models/MyoCrossUser.py | Implemented pretrained PyTorch model loader with automatic model download |
| libemg/__init__.py | Added import for _discrete_models module |
| .gitignore | Added *.model to ignore downloaded model files |
| import torch | ||
| import torch.nn as nn | ||
| import numpy as np |
Copilot
AI
Jan 20, 2026
The MyoCrossUser module imports PyTorch (torch and torch.nn) but torch is not listed in requirements.txt. This will cause an ImportError when users try to use this pretrained model. Add torch to requirements.txt or document it as an optional dependency for this specific model.
| def __init__(self, | ||
| online_data_handler, | ||
| args={'media_folder': 'images/', 'data_folder':'data/', 'num_reps': 3, 'rep_time': 5, 'rest_time': 3, 'auto_advance': True}, | ||
| args={'media_folder': 'images/', 'data_folder':'data/', 'num_reps': 3, 'rep_time': 5, 'rest_time': 3, 'auto_advance': True, 'discrete': False}, |
Copilot
AI
Jan 20, 2026
Using a mutable default argument (dictionary) is a Python anti-pattern. If the same default object is modified, it will persist across function calls. Change to 'args=None' and handle the default dictionary creation in the function body.
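A minimal sketch of the `args=None` pattern suggested above. The class name `CollectionPanel` and the `DEFAULT_ARGS` constant are hypothetical stand-ins; only the keys and values are taken from the defaults shown in the diff.

```python
# Hypothetical stand-in for the panel class; only the argument handling matters.
DEFAULT_ARGS = {
    'media_folder': 'images/', 'data_folder': 'data/', 'num_reps': 3,
    'rep_time': 5, 'rest_time': 3, 'auto_advance': True, 'discrete': False,
}


class CollectionPanel:
    def __init__(self, online_data_handler, args=None):
        self.online_data_handler = online_data_handler
        # Build a fresh dictionary per instance: start from the defaults,
        # then overlay whatever the caller supplied.
        self.args = {**DEFAULT_ARGS, **(args or {})}


a = CollectionPanel(None)
b = CollectionPanel(None, {'discrete': True})
a.args['num_reps'] = 10  # mutating one instance does not leak into others
```

Note that this sketch merges a partial `args` dictionary with the defaults, which is a design choice beyond the review comment: the original signature replaces the defaults wholesale when `args` is passed.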
| return | ||
|
|
||
| model_dir = os.path.dirname(self.model_path) | ||
| os.makedirs(model_dir, exist_ok=True) |
Copilot
AI
Jan 20, 2026
When model_dir is an empty string (which happens when model_path is a bare filename such as 'Discrete.model'), os.makedirs(model_dir, exist_ok=True) is called with an empty path and raises FileNotFoundError. Add a check so the directory is only created when model_dir is not empty.
| os.makedirs(model_dir, exist_ok=True) | |
| if model_dir: | |
| os.makedirs(model_dir, exist_ok=True) |
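The guard in the suggestion can be exercised in isolation. The helper name `ensure_parent_dir` below is hypothetical and is not part of the PR; it is only a sketch of the same check.

```python
import os
import tempfile


def ensure_parent_dir(path):
    """Create the parent directory of `path` if it has one."""
    parent = os.path.dirname(path)
    # os.path.dirname('Discrete.model') is '', and os.makedirs('') raises
    # FileNotFoundError, so only call makedirs for a non-empty parent.
    if parent:
        os.makedirs(parent, exist_ok=True)
    return parent


with tempfile.TemporaryDirectory() as tmp:
    nested = os.path.join(tmp, 'models', 'Discrete.model')
    ensure_parent_dir(nested)
    created = os.path.isdir(os.path.join(tmp, 'models'))

bare_parent = ensure_parent_dir('Discrete.model')  # no directory to create
```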
libemg/emg_predictor.py
Outdated
| import time | ||
| import numpy as np | ||
| from libemg.feature_extractor import FeatureExtractor | ||
| from libemg.utils import get_windows | ||
|
|
Copilot
AI
Jan 20, 2026
These imports are redundant as they are already imported at the top of the file. The imports for time, numpy, FeatureExtractor, and get_windows are already present at lines 22, 14, 11, and 31 respectively. Remove these duplicate import statements.
| import time | |
| import numpy as np | |
| from libemg.feature_extractor import FeatureExtractor | |
| from libemg.utils import get_windows | |
libemg/_discrete_models/__init__.py
Outdated
| from libemg._discrete_models import MVLDA | ||
| from libemg._discrete_models import DTW | ||
| from libemg._discrete_models import MyoCrossUser |
Copilot
AI
Jan 20, 2026
The import syntax is incorrect. These should be importing the classes from the modules, not importing the module itself. Change to: 'from libemg._discrete_models.MVLDA import MVLDA', 'from libemg._discrete_models.DTW import DTWClassifier', and 'from libemg._discrete_models.MyoCrossUser import MyoCrossUserPretrained, DiscreteClassifier'.
| from libemg._discrete_models import MVLDA | |
| from libemg._discrete_models import DTW | |
| from libemg._discrete_models import MyoCrossUser | |
| from libemg._discrete_models.MVLDA import MVLDA | |
| from libemg._discrete_models.DTW import DTWClassifier | |
| from libemg._discrete_models.MyoCrossUser import MyoCrossUserPretrained, DiscreteClassifier |
| class OnlineDiscreteClassifier: | ||
| """OnlineDiscreteClassifier. | ||
| Real-time discrete gesture classifier that detects individual gestures from EMG data. | ||
| Unlike continuous classifiers, this classifier is designed for detecting discrete, | ||
| transient gestures and outputs a prediction only when a gesture is detected. | ||
| Parameters | ||
| ---------- | ||
| odh: OnlineDataHandler | ||
| An online data handler object for streaming EMG data. | ||
| model: object | ||
| A trained model with a predict_proba method (e.g., from libemg discrete models). | ||
| window_size: int | ||
| The number of samples in a window. | ||
| window_increment: int | ||
| The number of samples that advances before the next window. | ||
| null_label: int | ||
| The label corresponding to the null/no gesture class. | ||
| feature_list: list or None | ||
| A list of features that will be extracted during real-time classification. | ||
| Pass in None if the model expects raw windowed data. | ||
| template_size: int | ||
| The maximum number of samples to use for gesture template matching. | ||
| min_template_size: int, default=None | ||
| The minimum number of samples required before attempting classification. | ||
| If None, defaults to template_size. | ||
| key_mapping: dict, default=None | ||
| A dictionary mapping gesture names to keyboard keys for automated key presses. | ||
| Requires pyautogui to be installed. | ||
| feature_dic: dict, default=None | ||
| A dictionary containing feature extraction parameters. | ||
| gesture_mapping: dict, default=None | ||
| A dictionary mapping class indices to gesture names for debug output. | ||
| rejection_threshold: float, default=0.0 | ||
| The confidence threshold (0-1). Predictions with confidence below this | ||
| threshold will be rejected and treated as null gestures. | ||
| debug: bool, default=True | ||
| If True, prints accepted gestures with timestamps and confidence values. | ||
| buffer_size: int, default=1 | ||
| Number of successive predictions to buffer before accepting a gesture. | ||
| When buffer_size > 1, the mode (most frequent prediction) across the buffer | ||
| is used to determine the final prediction. This helps filter noisy predictions. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| odh, | ||
| model, | ||
| window_size, | ||
| window_increment, | ||
| null_label, | ||
| feature_list, | ||
| template_size, | ||
| min_template_size=None, | ||
| key_mapping=None, | ||
| feature_dic={}, | ||
| gesture_mapping=None, | ||
| rejection_threshold=0.0, | ||
| debug=True, | ||
| buffer_size=1 | ||
| ): | ||
| self.odh = odh | ||
| self.window_size = window_size | ||
| self.window_increment = window_increment | ||
| self.feature_list = feature_list | ||
| self.model = model | ||
| self.null_label = null_label | ||
| self.template_size = template_size | ||
| self.min_template_size = min_template_size if min_template_size is not None else template_size | ||
| self.key_mapping = key_mapping | ||
| self.feature_dic = feature_dic | ||
| self.gesture_mapping = gesture_mapping | ||
| self.rejection_threshold = rejection_threshold | ||
| self.debug = debug | ||
| self.buffer_size = buffer_size | ||
| self.prediction_buffer = deque(maxlen=buffer_size) | ||
| self.fe = FeatureExtractor() | ||
|
|
||
| def run(self): | ||
| """ | ||
| Main loop for gesture detection. | ||
| Uses predict_proba to apply an optional rejection threshold. | ||
| When buffer_size > 1, takes the mode across multiple successive predictions. | ||
| """ | ||
| expected_count = self.min_template_size | ||
|
|
||
| while True: | ||
| # Get and process EMG data | ||
| _, counts = self.odh.get_data(self.window_size) | ||
| if counts['emg'][0][0] >= expected_count: | ||
| data, _ = self.odh.get_data(self.template_size) | ||
| emg = data['emg'][::-1] | ||
| feats = get_windows(emg, window_size=self.window_size, window_increment=self.window_increment) | ||
| if self.feature_list is not None: | ||
| feats = self.fe.extract_features(self.feature_list, feats, array=True, feature_dic=self.feature_dic) | ||
|
|
||
| probas = self.model.predict_proba(np.array([feats]))[0] | ||
|
|
||
| # Get the class with the highest probability | ||
| pred = np.argmax(probas) | ||
| confidence = probas[pred] | ||
|
|
||
| # Check rejection threshold | ||
| if confidence < self.rejection_threshold: | ||
| pred = self.null_label | ||
|
|
||
| # Add prediction to buffer | ||
| self.prediction_buffer.append(pred) | ||
|
|
||
| # Check if buffer is full and compute mode | ||
| if len(self.prediction_buffer) >= self.buffer_size: | ||
| # Get mode of buffer predictions | ||
| buffer_list = list(self.prediction_buffer) | ||
| mode_result = stats.mode(buffer_list, keepdims=False) | ||
| buffered_pred = mode_result[0] | ||
|
|
||
| if buffered_pred != self.null_label: | ||
| if self.debug: | ||
| label = self.gesture_mapping[buffered_pred] if self.gesture_mapping else buffered_pred | ||
| print(f"{time.time()} ACCEPTED: {label} (Conf: {confidence:.2f})") | ||
|
|
||
| if self.key_mapping is not None: | ||
| self._key_press(buffered_pred) | ||
|
|
||
| self.odh.reset() | ||
| self.prediction_buffer.clear() | ||
| expected_count = self.min_template_size | ||
| else: | ||
| expected_count += self.window_increment | ||
| else: | ||
| expected_count += self.window_increment | ||
|
|
||
| def _key_press(self, pred): | ||
| import pyautogui | ||
| gesture_name = self.gesture_mapping[pred] | ||
| if gesture_name in self.key_mapping: | ||
| pyautogui.press(self.key_mapping[gesture_name]) |
Copilot
AI
Jan 20, 2026
The new discrete control functionality including OnlineDiscreteClassifier, discrete models (MVLDA, DTWClassifier, MyoCrossUserPretrained), and the discrete parameter in parse_windows and extract_features lacks test coverage. Given that the repository has existing test coverage for other features, consider adding tests for the discrete functionality to ensure correctness and prevent regressions.
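One way to make the decision logic testable without a live data stream is to factor the rejection and buffering steps of `run()` into pure functions. The sketch below is an assumption about how that might look, not code from the PR: `decide` and `buffered_mode` are hypothetical names, and `collections.Counter` stands in for `scipy.stats.mode`, so tie-breaking may differ.

```python
from collections import Counter, deque


def decide(probas, null_label, rejection_threshold):
    """Return (label, confidence) for one probability vector."""
    pred = max(range(len(probas)), key=lambda i: probas[i])  # argmax
    confidence = probas[pred]
    if confidence < rejection_threshold:
        pred = null_label  # low-confidence predictions count as "no gesture"
    return pred, confidence


def buffered_mode(buffer):
    """Most frequent label in the buffer (ties go to the earliest-seen label)."""
    return Counter(buffer).most_common(1)[0][0]


buffer = deque(maxlen=3)
stream = [[0.1, 0.8, 0.1], [0.4, 0.35, 0.25], [0.05, 0.9, 0.05]]
for probas in stream:
    label, _ = decide(probas, null_label=0, rejection_threshold=0.5)
    buffer.append(label)

result = buffered_mode(buffer)
```

The middle prediction falls below the threshold and is replaced by the null label, yet the buffered mode still reports gesture 1, which is the filtering behavior the `buffer_size` docstring describes.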
|
|
||
| MODEL_URL = "https://github.com/eeddy/DiscreteMCI/raw/main/Other/Discrete.model" | ||
| DEFAULT_MODEL_PATH = os.path.join("./Discrete.model") |
Copilot
AI
Jan 20, 2026
The DEFAULT_MODEL_PATH downloads the model to './Discrete.model' in the current working directory. This could pollute the user's working directory or cause permission issues in read-only directories. Consider using a more appropriate location such as a cache directory (e.g., using platformdirs or tempfile) or documenting that users should specify a custom model_path.
| MODEL_URL = "https://github.com/eeddy/DiscreteMCI/raw/main/Other/Discrete.model" | |
| DEFAULT_MODEL_PATH = os.path.join("./Discrete.model") | |
| import tempfile | |
| MODEL_URL = "https://github.com/eeddy/DiscreteMCI/raw/main/Other/Discrete.model" | |
| DEFAULT_MODEL_PATH = os.path.join(tempfile.gettempdir(), "Discrete.model") |
| Returns | ||
| ---------- | ||
| dictionary or list | ||
| A dictionary where each key is a specific feature and its value is a list of the computed | ||
| features for each window. | ||
| StandardScaler | ||
| If normalize is true it will return the normalizer object. This should be passed into the feature extractor for test data. | ||
| """ |
Copilot
AI
Jan 20, 2026
The Returns section of the docstring does not document the different return behavior when discrete=True. When discrete=True and normalize=False, it returns a list of dictionaries/arrays. When discrete=True and normalize=True, it returns a tuple of (list of arrays, scaler). Update the docstring to clearly document both cases.
| break | ||
| if not font_loaded: | ||
| # Fallback: scale the default bitmap font (lower quality) | ||
| default_font = None |
Copilot
AI
Jan 20, 2026
Variable default_font is not used.
libemg/_discrete_models/DTW.py
Outdated
| @@ -0,0 +1,54 @@ | |||
| from tslearn.metrics import dtw_path | |||
| import numpy as np | |||
| from collections import Counter | |||
Copilot
AI
Jan 20, 2026
Import of 'Counter' is not used.
Pull request overview
Copilot reviewed 11 out of 13 changed files in this pull request and generated 16 comments.
| sys.modules['DiscreteClassifier'] = sys.modules[__name__] | ||
|
|
||
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
| self.model = torch.load(self.model_path, map_location=device, weights_only=False) |
Copilot
AI
Jan 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using weights_only=False in torch.load() can pose a security risk, as it allows arbitrary code execution during deserialization. This is particularly concerning since the model is downloaded from an external URL. Consider adding validation of the downloaded file (e.g., checking a hash) or documenting this security consideration for users.
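A sketch of the hash check mentioned above, using only `hashlib`. The function names are hypothetical, and no real digest for `Discrete.model` is known here, so the demonstration hashes a throwaway file; in practice a pinned digest would be stored next to `MODEL_URL` and checked before `torch.load` is called.

```python
import hashlib
import os
import tempfile


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in chunks so large models need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def verify_model_file(path, expected_sha256):
    """Raise ValueError unless the file matches the pinned digest."""
    actual = sha256_of(path)
    if actual != expected_sha256.lower():
        raise ValueError(f'Checksum mismatch for {path}: got {actual}')


def demo():
    """Return True if a wrong digest is rejected and the right one accepted."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'Discrete.model')
        with open(path, 'wb') as f:
            f.write(b'not a real model')
        pinned = hashlib.sha256(b'not a real model').hexdigest()
        verify_model_file(path, pinned)        # matches: no exception
        try:
            verify_model_file(path, '0' * 64)  # wrong digest
        except ValueError:
            return True
        return False


rejected_bad_digest = demo()
```

A checksum only proves the file is the one that was pinned; it does not make unpickling with `weights_only=False` safe if the pinned file itself is untrusted.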
libemg/feature_extractor.py
Outdated
| Returns | ||
| ---------- When discrete=False: | ||
| dictionary or np.ndarray | ||
| A dictionary where each key is a specific feature and its value is a list of the computed | ||
| features for each window. If array=True, returns a np.ndarray instead. | ||
| tuple (np.ndarray, StandardScaler) | ||
| If normalize=True, returns a tuple of (features array, scaler). The scaler should be passed | ||
| into the feature extractor for test data. | ||
| When discrete=True: | ||
| list | ||
| A list of dictionaries/arrays (one per template). If array=True, each element is a np.ndarray. | ||
| tuple (list, StandardScaler) | ||
| If normalize=True, returns a tuple of (list of np.ndarrays, scaler). The scaler should be | ||
| passed into the feature extractor for test data. |
Copilot
AI
Jan 20, 2026
The Returns section in the docstring is malformed with "When discrete=False:" and "When discrete=True:" appearing on the same line as "----------". The formatting should have the type/description headers on their own lines for proper rendering in documentation tools.
| self.args = { | ||
| 'window_size': 10, 'window_increment': 5, 'null_label': 0, 'feature_list': None, 'template_size': 250, 'min_template_size': 150, 'gesture_mapping': ['Nothing', 'Close', 'Flexion', 'Extension', 'Open', 'Pinch'], 'buffer_size': 5, | ||
| } |
Copilot
AI
Jan 20, 2026
The hardcoded gesture mapping list ['Nothing', 'Close', 'Flexion', 'Extension', 'Open', 'Pinch'] in the args dictionary is not documented in the class docstring. Users need to know what gestures the pretrained model was trained on and in what order the class indices map to these gesture names.
| """Record while spacebar is held, stop when released.""" | ||
| # Display gesture name and grayscale image (waiting state) | ||
| dpg.set_value("__dc_prompt", value=media[1]) | ||
| dpg.set_item_width("__dc_prompt_spacer", width=self.video_player_width/2+30 - (7*len(media[1]))/2) |
Copilot
AI
Jan 20, 2026
The centering calculation width=self.video_player_width/2+30 - (7*len(media[1]))/2 uses a magic number (7) that appears to be an estimate of character width in pixels. This calculation may not work correctly for all fonts or font sizes, especially with the new UI scaling feature. Consider using actual text measurement methods if available, or document this assumption.
| dpg.set_item_width("__dc_prompt_spacer", width=self.video_player_width/2+30 - (7*len(media[1]))/2) | |
| # Center the prompt based on its actual rendered width instead of a per-character estimate | |
| text_width, _ = dpg.get_text_size(media[1]) | |
| dpg.set_item_width("__dc_prompt_spacer", width=self.video_player_width/2 + 30 - text_width/2) |
| print(f"Downloading model to {self.model_path}...") | ||
| urllib.request.urlretrieve(MODEL_URL, self.model_path) | ||
| print("Download complete.") |
Copilot
AI
Jan 20, 2026
The model download from an external URL (GitHub) lacks error handling for network failures or invalid responses. If the download fails or is interrupted, it could leave a corrupted file that would cause subsequent loads to fail. Consider adding error handling, retry logic, or at minimum validating the downloaded file before saving it.
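One possible shape for a safer download, sketched under assumptions: the function name `download_atomically` is hypothetical, and the approach (write to a temporary file in the destination directory, then `os.replace`) is a common pattern rather than anything the PR prescribes. The demonstration uses `file://` URLs so it runs without network access.

```python
import os
import pathlib
import tempfile
import urllib.error
import urllib.request


def download_atomically(url, dest, timeout=30):
    """Download `url` to `dest`, never leaving a partial file at `dest`."""
    dest_dir = os.path.dirname(os.path.abspath(dest))
    os.makedirs(dest_dir, exist_ok=True)
    # Write to a temporary file in the same directory so os.replace is atomic.
    fd, tmp_path = tempfile.mkstemp(dir=dest_dir, suffix='.part')
    try:
        with os.fdopen(fd, 'wb') as out, \
                urllib.request.urlopen(url, timeout=timeout) as resp:
            while True:
                chunk = resp.read(1 << 16)
                if not chunk:
                    break
                out.write(chunk)
        os.replace(tmp_path, dest)
    except (urllib.error.URLError, OSError) as exc:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise RuntimeError(f'Failed to download {url}: {exc}') from exc


with tempfile.TemporaryDirectory() as tmp:
    src = pathlib.Path(tmp, 'source.bin')
    src.write_bytes(b'payload')
    dest = os.path.join(tmp, 'cache', 'Discrete.model')
    download_atomically(src.as_uri(), dest)
    ok = pathlib.Path(dest).read_bytes() == b'payload'

    missing = os.path.join(tmp, 'missing.model')
    try:
        download_atomically(pathlib.Path(tmp, 'nope.bin').as_uri(), missing)
        failed_cleanly = False
    except RuntimeError:
        # Neither the destination nor a leftover *.part file should exist.
        failed_cleanly = not os.path.exists(missing) and not any(
            name.endswith('.part') for name in os.listdir(tmp))
```

Retry logic is left out on purpose; it can be layered around the call once failures surface as a single exception type.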
| for i, s in enumerate(X): | ||
| # DTW distances to templates | ||
| dists = np.array([dtw_path(t, s)[1] for t in self.templates], dtype=float) |
Copilot
AI
Jan 20, 2026
The DTW distance calculation in the inner loop (line 30) computes distances to all templates for every prediction, which could be slow for large template sets. The implementation uses a list comprehension with dtw_path which may not be optimized. Consider whether there are opportunities for caching or optimization, especially if the same samples are processed multiple times.
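To make the cost concrete: each prediction runs one DTW alignment per template, and each alignment is quadratic in sequence length. The sketch below is a plain-Python illustration, not the `tslearn` implementation used in the PR. It assumes 1-D sequences and an absolute-difference local cost, and `dtw_distance` and `knn_proba` are hypothetical names; the inverse-distance weighting is one plausible reading of "weighted probability estimation".

```python
import math


def dtw_distance(a, b):
    """Classic O(len(a) * len(b)) DTW between two 1-D sequences."""
    n, m = len(a), len(b)
    # cost[i][j] = best cumulative cost aligning a[:i] with b[:j]
    cost = [[math.inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            step = abs(a[i - 1] - b[j - 1])
            cost[i][j] = step + min(cost[i - 1][j],      # repeat b[j-1]
                                    cost[i][j - 1],      # repeat a[i-1]
                                    cost[i - 1][j - 1])  # advance both
    return cost[n][m]


def knn_proba(sample, templates, labels, n_classes, k=3, eps=1e-9):
    """Inverse-distance-weighted votes from the k nearest templates."""
    dists = sorted((dtw_distance(t, sample), y)
                   for t, y in zip(templates, labels))
    votes = [0.0] * n_classes
    for dist, y in dists[:k]:
        votes[y] += 1.0 / (dist + eps)
    total = sum(votes)
    return [v / total for v in votes]


templates = [[0, 1, 2, 1, 0], [0, 0, 1, 2, 1], [2, 2, 2, 2, 2], [2, 2, 1, 2, 2]]
labels = [0, 0, 1, 1]
probas = knn_proba([0, 1, 1, 2, 1, 0], templates, labels, n_classes=2)
```

With T templates of length L, one prediction costs roughly T * L^2 local-cost evaluations, which is why the template count matters more than any per-call caching.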
| def _key_press(self, pred): | ||
| import pyautogui | ||
| gesture_name = self.gesture_mapping[pred] | ||
| if gesture_name in self.key_mapping: | ||
| pyautogui.press(self.key_mapping[gesture_name]) |
Copilot
AI
Jan 20, 2026
The pyautogui import is done inside the method rather than at the module level, but there's no error handling if the library is not installed. Since key_mapping is an optional feature, consider adding a try-except block to provide a clear error message if pyautogui is not available, or document it as a required dependency for this feature.
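A sketch of the try-except approach for an optional dependency. `require_optional` is a hypothetical helper, not part of libemg, and the install hint assumes the pip package name matches the module name (true for `pyautogui`, not guaranteed in general).

```python
import importlib


def require_optional(module_name, feature):
    """Import an optional dependency or fail with an actionable message."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{feature} requires the optional package '{module_name}'. "
            f"Install it with: pip install {module_name}"
        ) from exc


# A stdlib module always imports; a missing one yields the friendly message.
json_module = require_optional('json', 'demo')
try:
    require_optional('surely_not_installed_pkg_xyz', 'key_mapping')
    message = ''
except ImportError as exc:
    message = str(exc)
```

Inside `_key_press`, the call would read `pyautogui = require_optional('pyautogui', 'key_mapping')`; the same helper could cover the `torch` import for the pretrained model.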
| from libemg._discrete_models.MVLDA import MVLDA | ||
| from libemg._discrete_models.DTW import DTWClassifier | ||
| from libemg._discrete_models.MyoCrossUser import MyoCrossUserPretrained |
Copilot
AI
Jan 20, 2026
There is no test coverage for the new discrete models (MVLDA, DTWClassifier, MyoCrossUserPretrained) or the OnlineDiscreteClassifier. Given that the repository has comprehensive test coverage for other components (test_feature_extractor.py, test_online_classifier.py, etc.), tests should be added for these new classes to maintain consistency and ensure reliability.
| sys.modules['DiscreteClassifier'] = sys.modules[__name__] | ||
|
|
||
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
| self.model = torch.load(self.model_path, map_location=device, weights_only=False) | ||
| self.model.eval() | ||
|
|
Copilot
AI
Jan 20, 2026
The method modifies sys.modules globally by adding a DiscreteClassifier entry (line 130). This could cause issues if multiple instances of MyoCrossUserPretrained are created or if there are naming conflicts with other modules. Consider using a more specific module name or documenting this side effect clearly.
| sys.modules['DiscreteClassifier'] = sys.modules[__name__] | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| self.model = torch.load(self.model_path, map_location=device, weights_only=False) | |
| self.model.eval() | |
| original_module = sys.modules.get('DiscreteClassifier') | |
| sys.modules['DiscreteClassifier'] = sys.modules[__name__] | |
| try: | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| self.model = torch.load(self.model_path, map_location=device, weights_only=False) | |
| self.model.eval() | |
| finally: | |
| # Restore previous sys.modules state to avoid global side effects | |
| if original_module is None: | |
| sys.modules.pop('DiscreteClassifier', None) | |
| else: | |
| sys.modules['DiscreteClassifier'] = original_module |
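The save-and-restore logic in the suggestion can also be packaged as a context manager, which keeps the loading code flat. This is a sketch under assumptions: `temporary_module_alias` is a hypothetical name, and an empty `types.ModuleType` stands in for the real module so the example runs without PyTorch.

```python
import contextlib
import sys
import types


@contextlib.contextmanager
def temporary_module_alias(alias, module):
    """Expose `module` as sys.modules[alias] only inside the with-block."""
    missing = object()  # sentinel: distinguishes "absent" from any real entry
    previous = sys.modules.get(alias, missing)
    sys.modules[alias] = module
    try:
        yield module
    finally:
        # Restore the previous state even if the body raised.
        if previous is missing:
            sys.modules.pop(alias, None)
        else:
            sys.modules[alias] = previous


stand_in = types.ModuleType('stand_in')
with temporary_module_alias('DiscreteClassifier', stand_in):
    visible_inside = sys.modules['DiscreteClassifier'] is stand_in
visible_after = 'DiscreteClassifier' in sys.modules
```

In the loader, the `torch.load(...)` and `self.model.eval()` calls would sit inside the `with` block, with `sys.modules[__name__]` passed as the module.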
| def extract_features(self, feature_list, windows, feature_dic={}, array=False, normalize=False, normalizer=None, fix_feature_errors=False, discrete=False): | ||
| """Extracts a list of features. | ||
| Parameters | ||
| ---------- | ||
| feature_list: list | ||
| The group of features to extract. Run get_feature_list() or checkout the API documentation | ||
| to find an up-to-date feature list. | ||
| windows: list | ||
| A list of windows - should be computed directly from the OfflineDataHandler or the utils.get_windows() method. | ||
| feature_dic: dict | ||
| A dictionary containing the parameters you'd like passed to each feature. ex. {"MDF_sf":1000} | ||
| array: bool (optional), default=False | ||
| If True, the dictionary will get converted to a list. | ||
| normalize: bool (optional), default=False | ||
| If True, the features will be normalized between using sklearn StandardScaler. The returned object will be a list. | ||
| normalizer: StandardScaler, default=None | ||
| This should be set to the output from feature extraction on the training data. Do not normalize testing features without this as this could be considered information leakage. | ||
| fix_feature_errors: bool (optional), default=False | ||
| If true, fixes all feature errors (NaN=0, INF=0, -INF=0). | ||
| discrete: bool (optional), default=False | ||
| If True, windows is expected to be a list of templates (from parse_windows with discrete=True). | ||
| Features will be extracted for each template separately and returned as a list. | ||
| Returns | ||
| ---------- When discrete=False: | ||
| dictionary or np.ndarray | ||
| A dictionary where each key is a specific feature and its value is a list of the computed | ||
| features for each window. If array=True, returns a np.ndarray instead. | ||
| tuple (np.ndarray, StandardScaler) | ||
| If normalize=True, returns a tuple of (features array, scaler). The scaler should be passed | ||
| into the feature extractor for test data. | ||
| When discrete=True: | ||
| list | ||
| A list of dictionaries/arrays (one per template). If array=True, each element is a np.ndarray. | ||
| tuple (list, StandardScaler) | ||
| If normalize=True, returns a tuple of (list of np.ndarrays, scaler). The scaler should be | ||
| passed into the feature extractor for test data. | ||
| """ | ||
| if discrete: | ||
| # Handle discrete mode: windows is a list of templates | ||
| all_features = [] | ||
| for template in windows: | ||
| template_features = self._extract_features_single(feature_list, template, feature_dic, array, fix_feature_errors) | ||
| all_features.append(template_features) | ||
|
|
||
| if normalize: | ||
| # For normalization in discrete mode, we need to flatten, normalize, then restructure | ||
| if not array: | ||
| all_features = [self._format_data(f) for f in all_features] | ||
| combined = np.vstack(all_features) | ||
| if not normalizer: | ||
| scaler = StandardScaler() | ||
| combined = scaler.fit_transform(combined) | ||
| else: | ||
| scaler = normalizer | ||
| combined = normalizer.transform(combined) | ||
| # Split back into list based on original sizes | ||
| result = [] | ||
| idx = 0 | ||
| for template in windows: | ||
| n_windows = template.shape[0] | ||
| result.append(combined[idx:idx+n_windows]) | ||
| idx += n_windows | ||
| return result, scaler | ||
| return all_features | ||
|
|
||
| return self._extract_features_single(feature_list, windows, feature_dic, array, fix_feature_errors, normalize, normalizer) | ||
|
|
||
| def _extract_features_single(self, feature_list, windows, feature_dic={}, array=False, fix_feature_errors=False, normalize=False, normalizer=None): |
Copilot
AI
Jan 20, 2026
Using a mutable default argument (dictionary) for the feature_dic parameter can lead to unexpected behavior if the default is modified. This pattern appears in both extract_features and _extract_features_single. Consider using None as the default and creating a new dictionary inside the method when needed.
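The discrete normalization path in `extract_features` above follows a flatten, normalize, split-back sequence. The sketch below illustrates that sequence with plain Python lists so it runs without NumPy or scikit-learn; `standardize_templates` is a hypothetical name, and `statistics.pstdev` (population standard deviation) stands in for `StandardScaler`.

```python
from statistics import fmean, pstdev


def standardize_templates(templates):
    """Standardize each feature column across all templates, keeping boundaries."""
    # Flatten: one long list of rows (the np.vstack step).
    rows = [row for template in templates for row in template]
    columns = list(zip(*rows))
    means = [fmean(col) for col in columns]
    # Constant columns get a divisor of 1.0 to avoid division by zero.
    stds = [pstdev(col) or 1.0 for col in columns]
    scaled = [[(x - m) / s for x, m, s in zip(row, means, stds)]
              for row in rows]
    # Split back using each template's original number of windows.
    result, idx = [], 0
    for template in templates:
        result.append(scaled[idx:idx + len(template)])
        idx += len(template)
    return result, (means, stds)


templates = [
    [[1.0, 10.0], [2.0, 10.0]],               # template with 2 windows
    [[3.0, 10.0], [4.0, 10.0], [5.0, 10.0]],  # template with 3 windows
]
scaled, (means, stds) = standardize_templates(templates)
```

The statistics are computed over all windows from all templates at once, so templates with more windows contribute more to the mean and scale.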
eeddy
left a comment
Approving as a beta release
* Update feature extraction for heatmap visualization
  Modified feature extraction to use new FeatureExtractor interface.
* Change features to fe in OnlineStreamer
  Replaced list of features with FeatureExtractor. This allows the user to pass in a FeatureExtractor object with feature parameters + standardization.
* Offline regression example
  Added offline regression example .md and .rst files. Also added link to this example in index.rst.
* Example for new streamer
* Revert "Example for new streamer"
  This reverts commit 353e286.
* Rename ColumnFetch to ColumnFetcher
  Class name was accidentally changed during a previous commit. Reverted so its name is more consistent with other classes.
* Remove duplicate method
  Previous commit added a duplicate write_output method definition in OnlineStreamer.
* Add feature queue for time series models
  Some time series models operate by predicting on a sequence of windows instead of raw EMG data. Added option to queue windows instead of just passing in a single window.
* Move constructor docstrings outside __init__
  sphinx-doc parsing wasn't detecting classes that didn't have docstrings before the constructor. Moved constructor docstrings to reflect this.
* Regression explanation
  Added regression details to tabs that just described classification.
* Add explicit conditional check instead of relying on casting
* Only pop if queue is at max length
* Add feature queue parameter to child classes
  Feature queue parameter was in OnlineStreamer, but not the online classifier and regressor. Added documentation to those classes and implemented parameters to create models that feature queue.
* Add online channel mask
  Channels could not be isolated for online data. Added functionality to only grab certain channels online.
* Revert "Revert "Rework FeatureExtractor""
  This reverts commit ec52763.
* Add skip until buffer fills up
* Add TODO
* Add wildcard to regex helper
  Default value for regex helper function often threw an error when searching for a pattern. Replaced the default value with the wildcard, so users can use this to grab the potential values without knowing them.
* Add check for None in RegexFilter
  Since the regex helper function can take None values, users may pass in None to RegexFilter. This wouldn't work since we store metadata as an index of the values they pass in. Added a check to ensure this doesn't happen.
* Add wildcard to regex helper
  Default value for regex helper function often threw an error when searching for a pattern. Replaced the default value with the wildcard, so users can use this to grab the potential values without knowing them.
* Add check for None in RegexFilter
  Since the regex helper function can take None values, users may pass in None to RegexFilter. This wouldn't work since we store metadata as an index of the values they pass in. Added a check to ensure this doesn't happen.
* Add MEC24 workshop to landing page
* Add post-processing visual to regression example
* Add skip until buffer fills up
* Regression visualization
* Handle coordinates without steady state frames
  An error was thrown if our method of calculating steady state frames failed (like in cases where there weren't any steady state frames). Added a try catch to deal with this.
* Animation documentation
  Added an animation tab to Modules and described some of the animations you can make.
* Add heatmap visualization
* Add string parameters for common metadata operations
  Users could pass in function handles to change metadata operations, but there are some operations that are very common and require users to define function handles (like grabbing the last sample). Added ability to pass in strings for some common operations.
* ninapro db2 dataglove support added
* Revert "Revert "Revert "Rework FeatureExtractor"""
  This reverts commit 19f2049.
* Add online standardization method Data could not be standardized online. Added a method to install standardization to online models. * Use proper FeatureExtractor interface Some changes were not reverted properly, causing the FeatureExtractor in OnlineStreamer to use the old interface. Updated to use the current interface (no breaking API). * Fix heatmap feature extraction Feature extraction in visualize_heatmap used old modified API for FeatureExtractor. Reverted to expected version. * Add clarification to install_standardization docstring * Fix online regressor visualize lag Online regressor method visualize() had lag when plotting. Modified so it uses the matplotlib animation API instead. * Add docstring to OnlineEMGRegressor.visualize * Skeleton for Controller class Added skeleton for Controller class and potential ideas for implementation. * Check for correct MacOS string * ninapro db2 * Remove unnecessary pop of oldest window There was a check for when the queue reached the max length that would manually pop the oldest window. This functionality is already built into deque by specifying the maxlen parameter, so replaced the check with maxlen. * Create SocketController Created base Controller class that all controllers (including keyboard) will inherit from. Also created a SocketController class that all controllers communicating over a port will inherit from. * Add comment idea * Add abstract timestamp parsing function Added another abstract parsing function because messages should also include a timestamp, so the user may want to parse those as well. * RegressorController Implementation of a regressor controller. Hasn't been tested yet. * updates sifi streamer to accept bridge version. forked sifi bridge into bp and ba * Remove model from RegressorController Users may not have access to the model itself in a script, so passing in the model may not work. Modified so users just pass in the IP and port. 
* db2 dataglove support limited to NinaproDB2 class * ClassifierController implementation Created concept for ClassifierController, but still hasn't been tested. * Fix online regressor visualize lag Online regressor method visualize() had lag when plotting. Modified so it uses the matplotlib animation API instead. * Add option to parse timestamp Added ability to grab timestamp when calling get_data. * Remove OnlineEMGRegressor.parse_output and references This method was essentially a copy of the RegressorController functionality. Replaced it with a call to the RegressorController. * documentation for NinaproDB2 * Remove duplicate method Removed a duplicate method that parent class already had. * Throw error when streaming stops during SGT * Update .gitignore .gitignore didn't ignore all .pyc files. * Convert methods to private Controller had many methods that the user shouldn't call. Converted these methods to private. * Regression Fitts task Added class for Fitts task using regression. Needs to be modified for classification. * Added sifi-bridge-py to dependencies. Renamed executable to sifibridge * cleanup init and sifibridge configurations * Change raw writes to sbp methods * Updated some redundant parameters and docstrings * Remove executable from repo * Fixed sifibridge process launch, added sifibridge to gitignore * Fix gitignore * Fixed cleanup * make streamer function more in-line with sifi-breakout * docstring, change 'device' to 'name' * Remove metric calculation functions * Implement classifier controller for visualization * Add check for data not being received * Add EMG hero legacy code * Replace deque with multiprocessing Queue SocketController used a deque, which isn't thread-safe and meant that the SocketController's data parameter didn't update in all threads when it was updated in another. Replaced deque with multiprocessing queue, which is thread safe. 
* Fix online ClassifierController parsing ClassifierController was not parsing all messages properly. Added parameters to allow for definite parsing and some checks to tell the user if they set something up incorrectly. * Fix OnlineEMGClassifier.visualize OnlineEMGClassifier.visualize hadn't been updated since shared memory changes. Modified to use new ClassifierController. * Automatically start Controller in Environment * Classifier working in FittsLawTest Regression worked in Fitts environment, but not classification. Modified data parsing to handle classification properly. Also added a parameter that the user can pass in to customize which commands cause which directions. * Rename FittsLawTest to IsoFitts * Move common functionality to run method Replaced abstract run method with abstract _run_helper method. * Update OnlineEMGClassifier visualize docstring Description for the legend parameter didn't fully explain what the user should pass in. * Hide pygame welcome message * Move functionality to Environment class and Fitts parameters Moved shared functionality into common parent class and added some helpful parameters to IsoFitts. * Rework prediction map Having commands be the keys made it a bit tedious to grab the commands. Reverted so the controller outputs are the keys and we grab the commands from that. * Convert EMGHero to Environment Modified legacy EMGHero code to fit Environment interface. * Rework _get_action for queue Blocking until we grabbed the desired value sometimes caused the program to stop. Replaced this so action will just be set to None if no data was received. * Remove _parse_timestamp from Controller class This method was only needed for the SocketController, so moved it there. * Fix prediction map for IsoFitts Prediction map had hand open/close swapped. * Implemented KeyboardController Added KeyboardController to allow users to test out environments with a keyboard first. 
* Properly receive None in EMGHero EMGHero didn't properly handle None b/c previous implementation used blocking with multiprocessing Queue. Setting key pressed to -1 on each loop meant that the notes flickered because the while loop updated quicker than the environment FPS. The key pressed is now only updated when data are received. * Split environments.py into multiple files Having all environments in one module would result in a massive file. Split into a submodule where each environment can get their own file. * Add environments to __init__ * Fix environment imports * Hide pygame welcome message * IsoFitts direction fix IsoFitts up/down directions were okay for classification/keyboard, but not regression due to pygame treating down as +. * Convert IsoFitts methods to hidden methods Only the run method should be public to the user, so all other methods were hidden. * Environments documentation Added docstrings and explanations for all added environments and controllers. * Rework SocketController SocketController used a multiprocessing Queue, which randomly caused issues when different threads tried to pop/put out of sync. Replaced with a multiprocessing Value to avoid this. * Add wait in OnlineEMGRegressor.visualize Without waiting for predictions to be received an error would be thrown. Added a while loop to wait until predictions are received. * fixed windows compilation target * Override terminate method Overriding terminate method to gracefully close the socket before stopping the process. * Allow non-callable dtypes Some numpy dtypes aren't callable without extra parameters, so added support for those. * Convert Manager to SharedMemory Previous implementation used a multiprocessing.Value and a string array, which is different from other low level memory interfaces in libemg (e.g., streamers) and isn't as fast. Replaced with straight shared memory. This way it uses the same interface as the streamers, has a fixed buffer size, and is quicker. 
* Add proportional control flag to IsoFitts * Automatically create parent directories when saving coordinates * Update Fitts example with regression Modified Fitts example documentation to include regressor comparisons. * added sifi bridge to setup * Launch the subprocess in run method for pickling issues * Update streamers.py * removed version from sifi bridge import * added more checks for creating variables * SocketController Lock Fix SocketController created a new Lock in the run method, which meant that the Lock wasn't actually serving its purpose (i.e., locking down that block of memory so we couldn't read in one process and write in another). Replaced it so the run method uses self.lock as well. * Add pygame to setup.py * Simplify SocketController Changing Controllers so they don't need a Process. * Make socket non-blocking Socket default behaviour is blocking, which means the FPS of an environment is tied to how often it is receiving data (i.e., the window increment). Made socket non-blocking so the FPS is always the same. * Replace classifier_smm_writes with model_smm_writes Old SMM key was still using classifier instead of model, which threw an error in the regressor. * Pass model timestamp instead of current time Model timestamp is needed for adaptation. * Fixed saving of IMU data data structure changed for the IMU. Previously it was acc_x, now it is ax. * Remove Delsys pasted docstring * Add numpy import to Delsys API streamer * Add check for num_channels in delsys_api_streamer Delsys API streamer required num_channels to be passed, but didn't explain this in the docstring or check. * Add support for .json when logging environment Some users may not want to read a .pkl file, so add support for json. 
* Add target and distance radii parameters * Add option to redo final rep * Show "Rep X of Y" during data collection * Automatically calculate rep time for videos * Find labels files and move to corresponding data directory * Throw ConnectionError during start callback * GUI docstring update Added information on convenience changes and behaviour when searching for media files. * Add game time parameter * Removed unused dependencies in sifibridge streamer. Set EDA to 0Hz by default. Removed version parameter. Set all modalities to Off except EMG by default. * Add regression info to docstring * Add timestamps to KeyboardController * Draw cursor for RotationalIsoFitts * WIP: converting to using polygons * Rename isofitts module to fitts * Rename IsoFitts class to ISOFitts * Create PolarFitts * Fitts base class * Remove duplicated code from ISOFitts * Inherit from Fitts instead of ISOFitts for PolarFitts * Remove dead code * Create Fitts targets in uniform circle * Handle different target sizes * Custom directions for PolarFitts prediction_map * Set minimum target size based on cursor * Reimplement feature parameters online Installing feature parameters online had been removed at some point. Reimplemented this by passing feature parameters from the offline model. * Only add to smm_items list if tag has not already been provided * Fix check for cursor being on new goal target * Keep radius on screen * Fix incorrect polar to cartesian action map * Restrict PolarFitts to a single side * Removed unused dependencies in sifibridge streamer. Set EDA to 0Hz by default. Removed version parameter. Set all modalities to Off except EMG by default. 
* Stop targets from generating in the center of the screen * Set theta pointed at bottom of screen instead of right * Docstring for PolarFitts * Initialize theta to pi/2 * Remap theta direction to correspond to prediction direction * Fix visualize_heatmap throwing error for feature lists of length 1 * WIP: Functional conversion from Cartesian to Polar space * Remove dead code * Pass in rect when drawing to polar space * Add fill parameter for drawing polar circle * Fix logical error when checking if new target is on cursor * Combine PolarFitts into parent Fitts class With the changes so everything is the same under the hood and the cursor is rendered differently, the PolarFitts class could be removed and replaced with a parameter in the Fitts class. This way ISOFitts gets this functionality for free. * Fix check for cursor remaining on screen * Update Fitts and ISOFitts docstrings * Add option to draw radius in polar Fitts * Extract origin to variable * Improve check for cursor staying in bounds * Add check in ISOFitts for target distance radius and screen size * Improve error handling for small screen size error * Fix fstring formatting in error message * Use full screen size when possible * Switch polar Fitts to vertical layout instead of horizontal * Change default num_targets in ISOFitts * Add parameters for target and cursor color * Extract Fitts parameters to config dataclass Managing constructor parameters and docstrings for Fitts and ISOFitts was quite tedious and would only get worse when adding more Fitts environments. Extracted many configuration parameters (e.g., target size and color) to a dataclass. Now the docstring and parameters of the dataclass can be changed to update both environments. * Update colors Changed colors to use a color-blind friendly palette. Also added parameters to change timer color, background color, and cursor in target color. 
* Make OnlineDataHandler processes daemons Processes are created for visualize and log_to_file, but they weren't created as daemon processes. This meant that users had to manually call stop_log() and stop_visualize(). Otherwise, the script would hang because these processes would continue even after the main script was finished. Created these as daemons so they will be killed when the parent process is killed. * Dataset Updates (#85) * Dataset updates * Updates * Added Grab Myo * Updates * Changed pathing * Update OneSubjectEMaGerDataset to new format Added in OneSubjectEMaGerDataset for regression tasks. Updated for new Dataset format. * Add split parameter to prepare_data Return value of prepare_data was changed to a dictionary, which would be a breaking change. Modified this so users can pass in a flag that will determine if it is returned as a dictionary or a data handler. * Updates * Updates * Fixed Grab Myo * Added resp to EPN * Updated the data handler to run faster * Made them all parse fast * Sped up window parsing * Fixed data handler * Updated ref * Made faster * Fixed * Updated * Updates * Updates * Updates * Added cropping * Updated to crop better * Updates * Undo * Hyser dataset Started working on Hyser dataset. Created parent class and started 1DOF class. Both classes still need to be tested. * Updated libemg * Fix zoom logical error Parameter passed to scipy zoom created a value for each column instead of one for each axis. This didn't throw an error in the past because most data that had been tested was also 2 DOF. Modified to proper zoom factor. * Add regex filter packaging and .hea support to FilePackager FilePackager could not read .hea files and creating a package function to match based on filename was tedious. Added the option to pass in a list of RegexFilters that will match the regex metadata from two files to package them. 
* Add check in regex package function Some metadata files would throw an error when calling get_metadata because the file didn't match the original regex filter. Added a check so False is returned if it doesn't match the RegexFilters. * Fix Hyser1DOF Added correct RegexFilters, package function, and data split. * Added CI dataset * Updates * Updates * Updates * added limb position * radman->radmand * added h5py req * added kaufmannMD * created kaufmann class * added submodules to _dataset * Updated myodisco * Updates * added h5py * HyserNDOF and HyserRandom Classes Added classes for Hyser NDOF and random datasets. * Add type hint to RegexFilter Specified that values should be a list of strings. * Handle single values from MetadataFetcher MetadataFetcher stated in the docs that it expected an N x M array, but didn't throw an error as long as an array was returned. Added a check to ensure that single values aren't being returned. If a single value array is returned, it is cast to an N x 1 array. * Hyser PR Dataset Implemented pattern recognition Hyser dataset. * Remove subject 10 from random task dataset Subject 10 is missing a labels file, so removed this subject from the random task dataset. * Rename Hyser to _Hyser Added _ to signify that this is a hidden class. * Hyser documentation Added documentation to Hyser dataset classes. * Don't do any processing on the dataset * Add NinaproDB8 * Add note to NinaproDB8 * Add OneSubjectEMaGerDataset import to datasets.py * Add OneSubjectEMaGerDataset to dataset list * Fix parse_windows for 2D metadata parse_windows called np.hstack to stack metadata, which works for 1D arrays but throws an error for 2D arrays since different files likely won't have the exact same number of samples. We also don't want this behaviour anyway since we want to stack along the sample axis. Replaced np.hstack with np.concatenate so metadata is always concatenated along the 0th axis. 
* Reimplement NinaPro cyberglove data Parsing cyberglove data wasn't brought over when modifying datasets. Reimplemented cyberglove parsing. * Allow empty strings in RegexFilter Added option to pass in an empty string as a description for a RegexFilter for cases where you want to filter files, like finding a labels file, but don't necessarily want that metadata to be stored. * Properly handle cyberglove data Previous implementation didn't consider that some files are skipped because they don't have cyberglove data. Added logic to parse all data, and then allow user to select what they want to grab based on if they're using the cyberglove. * added tmr data * Updates * Updated logging * Updates * Updates * Add UserComplianceDataset * Updates * Updates * Convert labels field to classes in HyserPR * added CIIL_WS. Fixed dataset exist check for regression & WS * initial commit for CIIL_WS * added onedrive download method * added onedrive download method * added one drive downloader * added arguments for unzip and clean * now downloads * Fixed one site bio * Hyser labels fix * Add subjects to Hyser classes * Continuous transitions debugging * Add subjects to OneSubjectEMaGerDataset * Evaluate method fixes * Fixed continuous * Fix subject indexing with Hyser * Handle default subject values for Hyser datasets * Fixed DB8 * Hyser missing subject fixes * added onesiteBP * Fixed continuous transitions * Biopoint * Fixed hyser * Fixed fougner * Add linear mapping to NinaproDB8 * Normalize NinaProDB8 labels * Add parameter to normalize NinaProDB8 labels * Replace string concatenation with fstring String concatenation only worked when strings were passed in. 
* fixed some docs * Small fixes * Updated doc * Changed to fast by default * Added cross_user evaluation * Trying to fix grab * Updates * Changed to just one session * Fixed error * Trying to fix grab myo * Fixed cross user * Updates * Added option for CNN * datasets * Added normalization * added subject to prepare method * Updated docs * Fixed cross-user * Add subjects parameter to UserComplianceDataset * Partially fixed the memory issue * Fixed EPN * Fixed * Removed disco * Updates * Fixed FORS-EMG * Updated hyser * Hopefully fixed hyser * Fixed hyser part 2 * Updates * Added meta * Fixes * Fixed all regression datasets * Updated meta * Update UserComplianceDataset subject IDs UserComplianceDataset subject IDs were reverted to original. Updated full subject list to reflect this. * Updated all to split True * fixed * cross user * Updated EPN * Update UserComplianceDataset baseline analysis This dataset was made to train on poor data and test on good, so having the default use just baseline data didn't make sense. * Added docs for most classification datasets * Just need to add Hyser and references * Doc changes * updated docs to include datasets * Updates * More doc updates * Updates and fixed offline metrics * Updated offline metrics * Updated TMR * Remove Hyser workaround A previous commit had added a hacky workaround for dealing with subject IDs that were removed in the Hyser dataset. Removed this change as it has been fixed by new subject indexing. --------- Co-authored-by: Christian Morrell <cmorrell@unb.ca> Co-authored-by: ECEEvanCampbell <evan.campbell1@unb.ca> * Fixed setup issue * Remove incorrect docstring * Reset dwell timer in _get_new_goal_target * Add HyserMVC Dataset (#105) * Add HyserMVC Dataset All Hyser datasets had been included except for MVC. Added MVC in case users wanted it for normalization. 
* Fix train/test split for HyserMVC * Wait to start timers until data have been received * Add cursor_radius parameter to FittsConfig * Updated the normalization in the feature extractor * Modify FittsConfig default parameters Many of these parameters were for testing and/or primarily for regression. Reduced dwell time and made timeout default to None to apply to most use cases. Also made num_trials a required parameter. * Update _delsys_streamer.py adding axis to the delsys streamer was done to the wrong location * Added option for dictionary list * Added option for dictionary list * Added a fix feature errors to feature extractor * Changed feature_extractor to default to fix errors is false * Changed dataset names * Updated feature dictionary * Updates * Remove old calls to Controller.start * Fixed code example for feature performance (#118) * added pca visualization to online data handler (#110) * Regressor Visualize On One Axis (#112) * Add single_axis parameter * Remove manual y-axis limits Some labels may not be in the range (-1, 1) so manually setting labels wasn't best. * Add docstring for visualize method * Streamer Cleanup (#115) * Add mindrove to requirements * Add MindroveStreamer * Rename mindrove.py to _mindrove.py * Pass correct append method to smm * Add comment * Add comment * Add mindrove streamer function * Add cleanup to MindroveStreamer * Improve variable name * Cleanup Mindrove streamer * Add proper cleanup for MyoStreamer * Handle proper cleanup for EMaGer streamer * Add MindRove to documentation * Updated README * Updated yaml * Added option to configure BLE Power and Memory Mode (#121) Added parameters ble_power and memory_mode to function configure and in the __init__ as this is an important parameter that users should be able to set when using the devices. 
Co-authored-by: ulysseTM <156504252+ulysseTM@users.noreply.github.com> * Updated version * added connection to specific myo armband via MAC address (#122) * EMG Streaming Pipeline for OTB MuoviPlus (#124) * added support of OTB muovi+ to the documentation * added working streamer and shared memory for OTB Muovi+ * Fixed CRC error * Emager v3 (#127) * test emager 3 data stream * args * version channel map fixed (to test with each) * version options * Adding discrete control functionality to libemg (beta) (#129) * Got data collection working * Added MVLDA and DTW * Discrete Working * Got the online classifier working * Updated based on copilots suggestions * Added torch * Added tslearn * Added docs for discrete stuff * Fixed documentation --------- Co-authored-by: Christian Morrell <cmorrell@unb.ca> Co-authored-by: Evan Campbell <evan.campbell1@unb.ca> Co-authored-by: Amir Reza Hariri <amirrezahariri2000@gmail.com> Co-authored-by: Gabriel Gagne <gabrielpgagne@gmail.com> Co-authored-by: Gabriel Gagne <gagag158@ulaval.ca> Co-authored-by: ulysseTM <156504252+ulysseTM@users.noreply.github.com> Co-authored-by: Tobias Konieczny <c.konieczny@t-online.de> Co-authored-by: Tobias Konieczny <konieczny.tobias@gmail.com> Co-authored-by: Etienne Michaud <60533283+Michiboi29@users.noreply.github.com>
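Note on the "Remove unnecessary pop of oldest window" entry in the log above: the behaviour it relies on can be sketched with the standard library alone. This is a minimal illustration, not the libemg code; the helper name `make_feature_queue` and the integer stand-ins for feature windows are hypothetical.

```python
from collections import deque


def make_feature_queue(queue_length):
    """Hypothetical helper: a bounded queue of the most recent windows."""
    # With maxlen set, appending to a full deque silently drops the oldest
    # item, so no manual "pop when full" check is needed.
    return deque(maxlen=queue_length)


queue = make_feature_queue(3)
for window_id in range(5):  # integers stand in for feature windows
    queue.append(window_id)

print(list(queue))  # the three most recent windows: [2, 3, 4]
```

A consumer that should only predict once the queue is full (as in "Add skip until buffer fills up") could compare `len(queue)` with `queue.maxlen` before using it.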
No description provided.
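Note on the "Make socket non-blocking" entry in the commit log: the idea can be sketched as below. This is an illustration under stated assumptions, not the SocketController implementation. `poll_message` is a hypothetical helper, the message format is made up, and `socket.socketpair()` stands in for the real model-to-controller connection so the sketch runs without network setup.

```python
import select
import socket


def poll_message(sock, bufsize=1024):
    """Hypothetical helper: return pending text, or None if nothing is waiting."""
    try:
        data = sock.recv(bufsize)
    except BlockingIOError:
        # Nothing to read yet: return at once so the caller's frame loop
        # is not tied to how often data arrive.
        return None
    return data.decode("utf-8")


# A connected pair of sockets stands in for the real connection (assumption).
sender, receiver = socket.socketpair()
receiver.setblocking(False)

before = poll_message(receiver)  # nothing has been sent yet

sender.sendall(b"1 0.25")  # made-up message format
# Wait up to 1 s for readability; only here to keep the demo deterministic.
select.select([receiver], [], [], 1.0)
after = poll_message(receiver)

sender.close()
receiver.close()
```

In an environment loop, a `None` result would simply mean "no new action this frame", which matches the log entry about setting the action to None when no data were received.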