-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
138 lines (113 loc) · 3.74 KB
/
run.py
File metadata and controls
138 lines (113 loc) · 3.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# -*- coding: utf-8 -*-
"""
RVC AI 翻唱 - 主入口
"""
import os
import sys
import argparse
from pathlib import Path
# Put the project root (the directory containing this file) at the front of
# sys.path so the project-local `lib`, `tools` and `ui` packages resolve.
# resolve() makes the entry absolute: a relative `__file__` would otherwise
# add a cwd-relative path (e.g. ".") that breaks the lazy imports done later
# in this script if the working directory changes.
ROOT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(ROOT_DIR))
from lib.ffmpeg_runtime import configure_ffmpeg_runtime
from lib.logger import log
from lib.runtime_build import get_runtime_build_label
# Configure the FFmpeg runtime once, at import time, before main() parses
# arguments or runs any checks. NOTE(review): what exactly is configured
# (binary path, env vars, ...) lives in lib.ffmpeg_runtime — confirm there.
configure_ffmpeg_runtime()
def check_environment():
    """Log interpreter, PyTorch and accelerator info; verify PyTorch is importable.

    Returns:
        bool: ``False`` if ``torch`` cannot be imported, otherwise ``True``.
        An old Python version, a missing GPU, or a failed accelerator probe
        only produces a warning.
    """
    log.header("RVC AI 翻唱系统")
    log.info(get_runtime_build_label())

    # Python version: anything older than 3.8 only triggers a warning.
    py_version = sys.version_info
    log.info(f"Python 版本: {py_version.major}.{py_version.minor}.{py_version.micro}")
    if py_version < (3, 8):
        log.warning("建议使用 Python 3.8 或更高版本")

    # Only a failed `import torch` means PyTorch is missing. Keep this try
    # narrow so unrelated ImportErrors are not misreported as that.
    try:
        import torch
    except ImportError:
        log.error("未安装 PyTorch")
        return False

    log.info(f"PyTorch 版本: {torch.__version__}")
    try:
        _log_accelerator()
    except ImportError as exc:
        # e.g. lib.device or the optional torch_directml package failing to
        # import. PyTorch itself is present, so this is not fatal.
        log.warning(f"加速后端检测失败: {exc}")
    return True


def _log_accelerator():
    """Log the detected acceleration backends and the device that will be used."""
    import torch
    from lib.device import get_device_info, _is_rocm, _has_xpu, _has_directml, _has_mps

    info = get_device_info()
    log.info(f"可用加速后端: {', '.join(info['backends'])}")

    if torch.cuda.is_available():
        # ROCm is handled under torch.cuda as well; its version string comes
        # from torch.version.hip instead of torch.version.cuda.
        if _is_rocm():
            log.info(f"ROCm 版本: {torch.version.hip}")
        else:
            log.info(f"CUDA 版本: {torch.version.cuda}")
        log.info(f"GPU: {torch.cuda.get_device_name(0)}")
    elif _has_xpu():
        log.info(f"Intel GPU: {torch.xpu.get_device_name(0)}")
    elif _has_directml():
        import torch_directml
        log.info(f"DirectML 设备: {torch_directml.device_name(0)}")
    elif _has_mps():
        log.info("Apple MPS 加速可用")
    else:
        log.warning("未检测到 GPU 加速,将使用 CPU")
def check_models():
    """Ensure every model in ``REQUIRED_MODELS`` is present, downloading missing ones.

    Returns:
        bool: ``True`` if nothing was missing or the download reported success,
        ``False`` if ``download_required_models()`` reported failure.
    """
    from tools.download_models import REQUIRED_MODELS, check_model, download_required_models

    missing = [name for name in REQUIRED_MODELS if not check_model(name)]
    if not missing:
        return True

    log.warning(f"缺少必需模型: {', '.join(missing)}")
    log.info("正在下载...")
    if not download_required_models():
        log.error("模型下载失败,请检查网络连接")
        return False
    return True
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="RVC AI 翻唱系统")
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="服务器地址 (默认: 127.0.0.1)"
)
parser.add_argument(
"--port",
type=int,
default=7860,
help="服务器端口 (默认: 7860)"
)
parser.add_argument(
"--share",
action="store_true",
help="创建公共链接"
)
parser.add_argument(
"--skip-check",
action="store_true",
help="跳过环境检查"
)
parser.add_argument(
"--download-models",
action="store_true",
help="仅下载模型"
)
args = parser.parse_args()
# 仅下载模型
if args.download_models:
from tools.download_models import download_all_models
download_all_models()
return
# 环境检查
if not args.skip_check:
if not check_environment():
sys.exit(1)
# 模型检查
if not check_models():
log.info("提示: 可以使用 --skip-check 跳过检查")
sys.exit(1)
# 启动界面
log.info(f"启动 Gradio 界面: http://{args.host}:{args.port}")
from ui.app import launch
launch(host=args.host, port=args.port, share=args.share)
# Start the launcher only when this file is executed directly
# (e.g. `python run.py`); importing the module does not run main().
if __name__ == "__main__":
    main()