Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,762 Bytes
3570591 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
import os, shutil
from typing_extensions import Literal, TypeAlias
from typing import List
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Fetch a single file from a ModelScope repository into ``local_dir``.

    If an entry with the file's base name is already listed in ``local_dir``
    nothing is downloaded. Otherwise the file is fetched with
    ``snapshot_download`` restricted to ``origin_file_path``, moved from its
    repo-relative location up to ``local_dir/<basename>``, and the top-level
    directory of the repo-relative path is deleted.

    Args:
        model_id: ModelScope repository id.
        origin_file_path: Path of the file inside the repository
            ("/"-separated).
        local_dir: Destination directory; created if missing.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    if file_name in os.listdir(local_dir):
        print(f" {file_name} has been already in {local_dir}.")
        return
    print(f" Start downloading {os.path.join(local_dir, file_name)}")
    snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
    # snapshot_download keeps the repo-relative layout; flatten it.
    nested_path = os.path.join(local_dir, origin_file_path)
    flat_path = os.path.join(local_dir, file_name)
    if nested_path == flat_path:
        return
    shutil.move(nested_path, flat_path)
    # The subfolder tree created for the nested path is no longer needed.
    leading_dir = origin_file_path.split("/")[0]
    shutil.rmtree(os.path.join(local_dir, leading_dir))
def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Fetch a single file from a HuggingFace Hub repository into ``local_dir``.

    If an entry with the file's base name is already listed in ``local_dir``
    nothing is downloaded. Otherwise the file is fetched with
    ``hf_hub_download`` into ``local_dir``, moved from its repo-relative
    location up to ``local_dir/<basename>``, and the top-level directory of
    the repo-relative path is deleted.

    Args:
        model_id: HuggingFace Hub repository id.
        origin_file_path: Path of the file inside the repository
            ("/"-separated).
        local_dir: Destination directory; created if missing.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    already_present = file_name in os.listdir(local_dir)
    if already_present:
        print(f" {file_name} has been already in {local_dir}.")
        return
    print(f" Start downloading {os.path.join(local_dir, file_name)}")
    hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
    src = os.path.join(local_dir, origin_file_path)
    dst = os.path.join(local_dir, file_name)
    if src == dst:
        # File lives at the repository root; it is already where we want it.
        return
    shutil.move(src, dst)
    # Remove the now-redundant subfolder tree created by the download.
    shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
# Names of the websites models can be downloaded from.
Preset_model_website: TypeAlias = Literal["HuggingFace", "ModelScope"]

# Website name -> table of preset model metadata hosted on that website.
website_to_preset_models = dict(
    HuggingFace=preset_models_on_huggingface,
    ModelScope=preset_models_on_modelscope,
)

# Website name -> function downloading one file from that website.
# Each function is called as fn(model_id, origin_file_path, local_dir).
website_to_download_fn = dict(
    HuggingFace=download_from_huggingface,
    ModelScope=download_from_modelscope,
)
def download_customized_models(
    model_id,
    origin_file_path,
    local_dir,
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download one file, trying websites in ``downloading_priority`` order.

    Each website's download function is called in turn until an entry with
    the file's base name appears in ``local_dir``; remaining websites are
    then skipped.

    Args:
        model_id: Repository id passed unchanged to every website.
        origin_file_path: Path of the file inside the repository.
        local_dir: Destination directory.
        downloading_priority: Website names, most preferred first. The
            default list is only read, never mutated.

    Returns:
        ``[local_dir/<basename>]`` if the file was obtained, else ``[]``.
    """
    base_name = os.path.basename(origin_file_path)
    local_path = os.path.join(local_dir, base_name)
    for website in downloading_priority:
        fetch = website_to_download_fn[website]
        fetch(model_id, origin_file_path, local_dir)
        if base_name in os.listdir(local_dir):
            # First website that produces the file wins.
            return [local_path]
    return []
def download_models(
    model_id_list: List[Preset_model_id] = [],
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download preset models, falling back across websites by priority.

    For each preset id in ``model_id_list``, websites are tried in
    ``downloading_priority`` order. A website is only used if its preset
    table contains the id. Every file listed for the model is downloaded
    unless it was already handled earlier in this call. If at least one
    file ends up in its ``local_dir``, the model counts as downloaded and
    later websites are skipped. Otherwise the next website is tried.

    Args:
        model_id_list: Preset model ids to download. The default list is
            only read, never mutated.
        downloading_priority: Website names, most preferred first.

    Returns:
        Paths to load. For each downloaded model this is the metadata's
        ``load_path`` when present, otherwise the local paths of the files
        obtained for that model in this call.
    """
    print(f"Downloading models: {model_id_list}")
    downloaded_files = set()  # local paths already handled in this call
    load_files = []
    for model_id in model_id_list:
        for website in downloading_priority:
            preset_models = website_to_preset_models[website]
            if model_id not in preset_models:
                continue
            # Metadata is either a bare file list or a dict holding
            # "file_list" and an optional "load_path".
            model_metadata = preset_models[model_id]
            if isinstance(model_metadata, list):
                file_data = model_metadata
            else:
                file_data = model_metadata.get("file_list", [])
            # Try downloading the model from this website.
            model_files = []
            # NOTE: the repository id gets its own name. Reusing `model_id`
            # here would overwrite the preset id, and the fallback to the
            # next website would then look up the wrong key.
            for repo_id, origin_file_path, local_dir in file_data:
                file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
                if file_to_download in downloaded_files:
                    continue
                website_to_download_fn[website](repo_id, origin_file_path, local_dir)
                if os.path.basename(origin_file_path) in os.listdir(local_dir):
                    downloaded_files.add(file_to_download)
                    model_files.append(file_to_download)
            # If the model is successfully downloaded, stop trying other websites.
            if model_files:
                if isinstance(model_metadata, dict) and "load_path" in model_metadata:
                    model_files = model_metadata["load_path"]
                load_files.extend(model_files)
                break
    return load_files