From e4797dc174d027c636c1de3ceceebb0a492269ea Mon Sep 17 00:00:00 2001 From: wandouguo Date: Tue, 2 Dec 2025 19:14:43 +0800 Subject: [PATCH] Add support for NPU device in api.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加对NPU的判断和支持,以及在910B测试通过。 --- melo/api.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/melo/api.py b/melo/api.py index 236ea8f17..333159415 100644 --- a/melo/api.py +++ b/melo/api.py @@ -9,7 +9,13 @@ import torch.nn as nn from tqdm import tqdm import torch - +torch_npu_tag =False +try: + import torch_npu + torch_npu_tag =True +except: + + pass from . import utils from . import commons from .models import SynthesizerTrn @@ -28,9 +34,12 @@ def __init__(self, if device == 'auto': device = 'cpu' if torch.cuda.is_available(): device = 'cuda' + elif torch_npu_tag and torch_npu.is_available():device = 'npu' if torch.backends.mps.is_available(): device = 'mps' if 'cuda' in device: assert torch.cuda.is_available() + if 'npu' in device: + assert torch_npu.is_available() # config_path = hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)