diff --git a/melo/api.py b/melo/api.py index 236ea8f17..333159415 100644 --- a/melo/api.py +++ b/melo/api.py @@ -9,7 +9,13 @@ import torch.nn as nn from tqdm import tqdm import torch - +torch_npu_tag =False +try: + import torch_npu + torch_npu_tag =True +except: + + pass from . import utils from . import commons from .models import SynthesizerTrn @@ -28,9 +34,12 @@ def __init__(self, if device == 'auto': device = 'cpu' if torch.cuda.is_available(): device = 'cuda' + elif torch_npu_tag and torch_npu.is_available():device = 'npu' if torch.backends.mps.is_available(): device = 'mps' if 'cuda' in device: assert torch.cuda.is_available() + if 'npu' in device: + assert torch_npu.is_available() # config_path = hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)