From 4437c0380d6e3578e670a4cdbcadb874d2d9337b Mon Sep 17 00:00:00 2001 From: hewuxingkong Date: Sat, 25 Apr 2026 22:51:58 +0800 Subject: [PATCH] Fix Mac MPS support + PyTorch 2.6+ weights_only compat MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two small fixes that block training on macOS with recent PyTorch: 1. extract_f0_print.py — RMVPE init was hardcoded to "cuda:0" so it crashed on Macs even when MPS was available. Now auto-detects: cuda → mps → cpu. 2. extract_feature_print.py — fairseq's checkpoint_utils.load_model_ensemble_and_task calls torch.load() without specifying weights_only=False. PyTorch 2.6 flipped the default to True, which rejects fairseq Dictionary pickles. Patches torch.load module-globally before importing fairseq so the existing flow keeps working. Tested with PyTorch 2.8.0 on macOS arm64 (M-series), Python 3.9.6. Hubert + RMVPE both load and run extraction on MPS without errors. --- extract_f0_print.py | 3 ++- extract_feature_print.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/extract_f0_print.py b/extract_f0_print.py index 063e322326..2feab32011 100644 --- a/extract_f0_print.py +++ b/extract_f0_print.py @@ -241,7 +241,8 @@ def compute_f0(self, path, f0_method, crepe_hop_length): from rmvpe import RMVPE print("loading rmvpe model") - self.model_rmvpe = RMVPE("rmvpe.pt", is_half=False, device="cuda:0") + _dev = "cuda:0" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") + self.model_rmvpe = RMVPE("rmvpe.pt", is_half=False, device=_dev) f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03) elif f0_method == "dio": f0, t = pyworld.dio( diff --git a/extract_feature_print.py b/extract_feature_print.py index 09d87c052d..ff24e66e9b 100644 --- a/extract_feature_print.py +++ b/extract_feature_print.py @@ -15,6 +15,14 @@ os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu) version = sys.argv[6] import torch +# Patch torch.load for PyTorch 2.6+ (fairseq calls it with weights_only=True by default) +_orig_torch_load = torch.load +def _patched_torch_load(*args, **kwargs): + if "weights_only" not in kwargs: + kwargs["weights_only"] = False + return _orig_torch_load(*args, **kwargs) +torch.load = _patched_torch_load + import torch.nn.functional as F import soundfile as sf import numpy as np