|
import os |
|
import json |
|
import requests |
|
import torch |
|
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation, CLIPProcessor, CLIPModel |
|
from controlnet_aux import MLSDdetector |
|
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionControlNetInpaintPipeline |
|
import urllib.request |
|
import shutil |
|
|
|
|
|
def create_directories():
    """Create the local ``resources`` directory tree.

    Makes ``resources`` and its ``models``, ``images``, ``labels`` and
    ``output`` subdirectories, relative to the current working directory.
    Directories that already exist are left alone (``exist_ok=True``), so
    calling this more than once is harmless.
    """
    base = "resources"
    subdirs = ("models", "images", "labels", "output")
    # Parent first, then each child, in the same order as before.
    for path in [base, *(f"{base}/{name}" for name in subdirs)]:
        os.makedirs(path, exist_ok=True)
    print("目录结构创建完成")
|
|
|
|
|
def download_labels(
    url="https://huggingface.co/datasets/huggingface/label-files/raw/main/ade20k-id2label.json",
    labels_path="resources/labels/ade20k-id2label.json",
    timeout=30,
):
    """Download the ADE20K id->label JSON mapping and save it locally.

    Args:
        url: Location of the JSON label file. The default points at the
            ``huggingface/label-files`` dataset on the Hugging Face Hub.
        labels_path: Destination file. Its directory must already exist
            (``create_directories`` creates ``resources/labels``).
        timeout: Seconds ``requests`` waits on connect/read before giving up,
            so a stalled connection cannot hang the script forever.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: The request itself failed or timed out.
        ValueError: The response body is not valid JSON; nothing is written.
    """
    response = requests.get(url, timeout=timeout)
    # Without this check, an HTML error page would be saved as the label file.
    response.raise_for_status()
    # Parse before writing so a truncated or non-JSON body never reaches disk.
    json.loads(response.content)
    # Write the raw bytes to avoid re-encoding through the platform default codec.
    with open(labels_path, "wb") as f:
        f.write(response.content)
    print(f"标签文件已保存到: {labels_path}")
|
|
|
|
|
def download_sample_image(
    raw_url="https://raw.githubusercontent.com/naderAsadi/DesignGenie/main/examples/images/sample_input.png",
    img_path="resources/images/sample_input.png",
    copy_path="sample_input.png",
    timeout=60,
):
    """Download the example input image (best effort).

    Saves the image to ``img_path`` and then copies it to ``copy_path``
    (by default ``sample_input.png`` in the current working directory).
    Any failure is reported with a printed message rather than raised, so the
    rest of the setup script keeps running.

    Args:
        raw_url: Image URL; any scheme ``urllib.request.urlopen`` supports
            (``https://``, ``file://``, ...).
        img_path: Destination path. Its directory must already exist.
        copy_path: Extra copy destination, or ``None`` to skip the copy.
        timeout: Seconds before a stalled connection is abandoned.
    """
    # Stream into a sibling temp file and rename only on success, so a failed
    # or interrupted transfer never leaves a truncated image at img_path.
    part_path = f"{img_path}.part"
    try:
        with urllib.request.urlopen(raw_url, timeout=timeout) as response, open(part_path, "wb") as out:
            shutil.copyfileobj(response, out)
        os.replace(part_path, img_path)
        print(f"示例图片已保存到: {img_path}")

        if copy_path is not None:
            shutil.copy(img_path, copy_path)
    except Exception as e:  # deliberate best-effort boundary: report and continue
        # Discard any partial download. It may never have been created.
        try:
            os.remove(part_path)
        except OSError:
            pass
        print(f"图片下载失败: {e}")
|
|
|
|
|
def _download_mask2former(cache_dir):
    """Fetch the Mask2Former ADE20K semantic processor and weights, then drop them."""
    repo_id = "facebook/mask2former-swin-large-ade-semantic"
    AutoImageProcessor.from_pretrained(repo_id, cache_dir=cache_dir)
    Mask2FormerForUniversalSegmentation.from_pretrained(repo_id, cache_dir=cache_dir)


def _download_mlsd(cache_dir):
    """Fetch the MLSD line-segment detector from lllyasviel/Annotators, then drop it."""
    MLSDdetector.from_pretrained("lllyasviel/Annotators", cache_dir=cache_dir)


def _download_controlnet(cache_dir):
    """Fetch the MLSD ControlNet (fp16) and return it so both SD pipelines reuse it."""
    return ControlNetModel.from_pretrained(
        "lllyasviel/control_v11p_sd15_mlsd",
        torch_dtype=torch.float16,
        cache_dir=cache_dir,
        use_safetensors=False,
    )


def _download_sd_pipeline(pipeline_cls, repo_id, controlnet, cache_dir):
    """Fetch one ControlNet Stable Diffusion pipeline (fp16); the loaded pipeline is discarded."""
    pipeline_cls.from_pretrained(
        repo_id,
        controlnet=controlnet,
        torch_dtype=torch.float16,
        cache_dir=cache_dir,
        use_safetensors=False,
    )


def _download_clip(cache_dir):
    """Best-effort fetch of CLIP ViT-B/32; a failure is printed instead of raised."""
    try:
        CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_dir)
        CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=cache_dir)
        print("图像特征提取模型下载完成")
    except Exception as e:
        print(f"图像特征提取模型下载失败: {e}")


def download_models(cache_dir="resources/models"):
    """Download every model the project uses into ``cache_dir``.

    Each model is loaded once with ``from_pretrained(..., cache_dir=...)`` so
    its files land in the local cache. The loaded object is then released
    immediately instead of being held until the function returns. This keeps
    peak RAM at roughly one large model, plus the shared ControlNet, rather
    than all of them at once.

    Args:
        cache_dir: Directory passed as ``cache_dir`` to every
            ``from_pretrained`` call. Defaults to ``resources/models``.

    Raises:
        Any exception from the Mask2Former, MLSD, ControlNet or Stable
        Diffusion downloads propagates. Only the CLIP download is best-effort.
    """
    print("正在下载模型,这可能需要一些时间...")

    print("下载 Mask2Former 模型...")
    _download_mask2former(cache_dir)
    print("Mask2Former 模型下载完成")

    print("下载 MLSD 检测器...")
    _download_mlsd(cache_dir)
    print("MLSD 检测器下载完成")

    print("下载 ControlNet 模型...")
    controlnet = _download_controlnet(cache_dir)
    print("ControlNet 模型下载完成")

    # NOTE(review): verify the runwayml/* repo IDs below still resolve on the
    # Hub. Any change must stay in sync with whatever code loads them later.
    print("下载 Stable Diffusion 模型...")
    _download_sd_pipeline(StableDiffusionControlNetPipeline, "runwayml/stable-diffusion-v1-5", controlnet, cache_dir)
    print("Stable Diffusion 模型下载完成")

    print("下载 Stable Diffusion Inpainting 模型...")
    _download_sd_pipeline(
        StableDiffusionControlNetInpaintPipeline, "runwayml/stable-diffusion-inpainting", controlnet, cache_dir
    )
    print("Stable Diffusion Inpainting 模型下载完成")

    # The ControlNet is no longer needed; release it before CLIP is loaded.
    del controlnet

    print("下载图像特征提取模型...")
    _download_clip(cache_dir)
|
|
|
if __name__ == "__main__":
    # Run the setup steps in order. The first step creates the resources/
    # subdirectories that the later download steps write into.
    # NOTE(review): download_sample_image and the CLIP download only print on
    # failure, so the final message below appears even after a partial failure.
    setup_steps = (
        create_directories,
        download_labels,
        download_sample_image,
        download_models,
    )
    for run_step in setup_steps:
        run_step()
    print("所有资源下载完成!您可以将整个 'resources' 文件夹保存到本地使用。")