Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@
from .config import InferenceConfig
from .data_loading import DatasetMMM
from .inference import generate
from .logits_processor import StopLogitsProcessor

__all__ = ["DatasetMMM", "generate", "InferenceConfig", "StopLogitsProcessor"]
__all__ = ["DatasetMMM", "InferenceConfig", "generate"]
3 changes: 1 addition & 2 deletions mmm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ class InferenceConfig:
the new tracks
"""

context_length: int = 4
bars_to_generate: dict[int, list[tuple[int, int, list[str]]]] | None = None
new_tracks: list[tuple[int, list[str]]] | None = None
context_length: int = 4
autoregressive: bool = False
infilling: bool = False

def __post_init__(self) -> None:
"""Check that the Inference config is consistent."""
self.context_tracks = self.bars_to_generate.keys()

if len(self.bars_to_generate) > 0:
self.infilling = True
Expand Down
19 changes: 19 additions & 0 deletions mmm/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,21 @@
"""
# The tokenization steps are outside the try bloc as if there are errors,
# we might want to catch them to fix them instead of skipping the iteration.

try:
score = Score.from_midi(self._dataset[idx]["music"])
except SCORE_LOADING_EXCEPTION:
item = {self.sample_key_name: None, self.labels_key_name: None}
if self.seq2seq:
item[self.decoder_key_name] = None
return item
except Exception as e:

Check failure on line 202 in mmm/data_loading.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (BLE001)

mmm/data_loading.py:202:16: BLE001 Do not catch blind exception: `Exception`
print(f"Here it is at {idx}: {e}")
item = {self.sample_key_name: None, self.labels_key_name: None}
if self.seq2seq:
item[self.decoder_key_name] = None
return item


# Tokenize the score
try:
Expand All @@ -207,6 +215,17 @@
if self.seq2seq:
item[self.decoder_key_name] = None
return item
except KeyError:
item = {self.sample_key_name: None, self.labels_key_name: None}
if self.seq2seq:
item[self.decoder_key_name] = None
return item
except Exception as e:

Check failure on line 223 in mmm/data_loading.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (BLE001)

mmm/data_loading.py:223:16: BLE001 Do not catch blind exception: `Exception`
print(f"Unexpected exception in dataloader (item {idx} ): {e}")
item = {self.sample_key_name: None, self.labels_key_name: None}
if self.seq2seq:
item[self.decoder_key_name] = None
return item
if tseq is None:
item = {self.sample_key_name: None, self.labels_key_name: None}
if self.seq2seq:
Expand Down
Loading
Loading