diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py
index 8cf66a5ad..e323843bf 100755
--- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py
+++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py
@@ -7,6 +7,14 @@ from lightllm.utils.dist_utils import get_current_device_id
 
 
 
+def get_optimal_load_workers():
+    explicit_value = os.environ.get("LOADWORKER")
+    if explicit_value is not None:
+        return int(explicit_value)
+    cpu_cores = os.cpu_count() or 1
+    return max(1, min(cpu_cores, 32))
+
+
 def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
     # fix bug for 多线程加载的时候,每个线程内部的cuda device 会切回 0, 修改后来保证不会出现bug
     import torch.distributed as dist
@@ -60,7 +68,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye
         transformer_layer_list=transformer_layer_list,
         weight_dir=weight_dir,
     )  # noqa
-    worker = int(os.environ.get("LOADWORKER", 1))
+    worker = get_optimal_load_workers()
     with Pool(worker) as p:
         iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1)
         desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers"
diff --git a/lightllm/models/vit/layer_weights/hf_load_utils.py b/lightllm/models/vit/layer_weights/hf_load_utils.py
index 3fa82af8e..cb17ab247 100644
--- a/lightllm/models/vit/layer_weights/hf_load_utils.py
+++ b/lightllm/models/vit/layer_weights/hf_load_utils.py
@@ -5,6 +5,14 @@ import lightllm.utils.petrel_helper as utils
 
 
 
+def get_optimal_load_workers():
+    explicit_value = os.environ.get("LOADWORKER")
+    if explicit_value is not None:
+        return int(explicit_value)
+    cpu_cores = os.cpu_count() or 1
+    return max(1, min(cpu_cores, 32))
+
+
 def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
     if use_safetensors:
         weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
@@ -62,7 +70,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye
         transformer_layer_list=transformer_layer_list,
         weight_dir=weight_dir,
     )  # noqa
-    worker = int(os.environ.get("LOADWORKER", 1))
+    worker = get_optimal_load_workers()
     with Pool(worker) as p:
         _ = p.map(partial_func, candidate_files)
     return