import os
import io
import uuid
import numpy as np
from PIL import Image, ImageFilter
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
import torch
from transformers import CLIPModel, CLIPProcessor
from diffusers import StableDiffusionInpaintPipeline
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from huggingface_hub import HfApi, hf_hub_download
import uvicorn
# Configure cache directories before loading any model
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.makedirs("/tmp/transformers_cache", exist_ok=True)
os.makedirs("/tmp/huggingface", exist_ok=True)
app = FastAPI()
# Labels and threshold used to filter clothing regions
CLOTHING_LABELS = ["a piece of clothing", "shirt", "jacket", "pants", "dress", "skirt"]
CLIP_THRESHOLD = 0.25
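# CLIP_THRESHOLD is the minimum CLIP image-text similarity a cropped SAM region must
# reach against any of CLOTHING_LABELS (see process_image below) to count as clothing.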
print("Starting app.py...")
def process_image(pil_img: Image.Image, prompt: str, neg_prompt: str, hf_repo: str = None):
    # --- Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # --- Load CLIP and compute normalized text embeddings
    # from_tf=True added to handle weights stored in TensorFlow format
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        from_tf=True,
        cache_dir="/tmp/transformers_cache"
    ).to(device)
    clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir="/tmp/transformers_cache"
    )
    text_inputs = clip_processor(text=CLOTHING_LABELS, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_embeddings = clip_model.get_text_features(**text_inputs)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
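    # Normalizing the text embeddings (and the image embeddings below) makes the dot
    # product in the mask-filtering loop a cosine similarity in [-1, 1].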
    # --- Prepare the image as a numpy array for SAM
    np_img = np.array(pil_img)
    # --- Download and load SAM2
    cache_dir = os.path.join("/tmp", "sam2_cache")
    os.makedirs(cache_dir, exist_ok=True)
    ckpt = os.path.join(cache_dir, "sam2_hiera_large.pt")
    cfg = os.path.join(cache_dir, "sam2_hiera_l.yaml")
    if not os.path.exists(ckpt):
        ckpt = hf_hub_download("facebook/sam2-hiera-large", "sam2_hiera_large.pt", repo_type="model", cache_dir=cache_dir)
    if not os.path.exists(cfg):
        cfg = hf_hub_download("facebook/sam2-hiera-large", "sam2_hiera_l.yaml", repo_type="model", cache_dir=cache_dir)
    sam2 = build_sam2("sam2_hiera_l", ckpt, device=device)
    mask_generator = SAM2AutomaticMaskGenerator(sam2)
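    # The automatic mask generator returns one dict per proposed region; only its
    # boolean "segmentation" array is used below, the remaining fields are ignored.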
    # --- Generate all masks
    masks = mask_generator.generate(np_img)
    # --- Filter masks for clothing content using CLIP
    combined = np.zeros(np_img.shape[:2], dtype=bool)
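    # For each SAM region: crop its bounding box, embed the crop with CLIP, and keep
    # the region if its best similarity to any clothing label exceeds CLIP_THRESHOLD.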
    for m in masks:
        seg = m.get("segmentation")
        if seg is None:
            continue
        ys, xs = np.where(seg)
        if ys.size == 0:
            continue
        y1, y2 = ys.min(), ys.max()
        x1, x2 = xs.min(), xs.max()
        patch = pil_img.crop((x1, y1, x2, y2))
        inputs = clip_processor(images=patch, return_tensors="pt").to(device)
        with torch.no_grad():
            img_emb = clip_model.get_image_features(**inputs)
            img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        sims = (img_emb @ text_embeddings.T).squeeze(0)
        if float(sims.max().cpu()) > CLIP_THRESHOLD:
            combined |= seg
    # --- Build and post-process the clothing mask
    mask_bin = Image.fromarray((combined.astype(np.uint8)) * 255)
    mask_dilated = mask_bin.filter(ImageFilter.MaxFilter(15))
    mask_for_inpaint = mask_dilated.filter(ImageFilter.GaussianBlur(7))
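    # MaxFilter(15) dilates the mask so the inpainted area fully covers the region
    # edges, and GaussianBlur(7) feathers the boundary so the result blends in.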
    # --- Inpainting with Stable Diffusion
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "sd-legacy/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16,
        cache_dir="/tmp/diffusers_cache"
    ).to(device)
    try:
        # xformers is optional; fall back to the default attention if it is unavailable
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    if not combined.any():
        # No clothing region detected: return the original image unchanged
        result = pil_img.copy()
    else:
        result = pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            image=pil_img,
            mask_image=mask_for_inpaint
        ).images[0]
    # --- Build a visualization of the SAM segmentations
    viz = np.array(pil_img).astype(np.float32)
    rnd = np.random.RandomState(42)
    for m in masks:
        seg = m.get("segmentation")
        if seg is None:
            continue
        color = rnd.randint(0, 256, size=3, dtype=np.uint8)
        ys, xs = np.where(seg)
        viz[ys, xs] = viz[ys, xs] * 0.5 + color * 0.5
    seg_viz = Image.fromarray(viz.astype(np.uint8))
    # --- Upload results to the HF Hub (dataset repo)
    token = os.getenv("HF_TOKEN")
    if token is None:
        raise RuntimeError("HF_TOKEN is not set in the environment")
    api = HfApi()
    if hf_repo is None:
        user = api.whoami(token=token)["name"]
        hf_repo = f"{user}/sam2-inpaint-outputs"
    api.create_repo(repo_id=hf_repo, repo_type="dataset", token=token, exist_ok=True)
    uid = uuid.uuid4().hex[:8]
    # Use a temporary directory for intermediate files
    temp_dir = "/tmp/sam2_outputs"
    os.makedirs(temp_dir, exist_ok=True)
    names = {
        "seg": os.path.join(temp_dir, f"sam_seg_{uid}.png"),
        "mask": os.path.join(temp_dir, f"mask_{uid}.png"),
        "out": os.path.join(temp_dir, f"inpaint_{uid}.png")
    }
    # Save intermediate files
    seg_viz.save(names["seg"])
    mask_bin.save(names["mask"])
    result.save(names["out"])
    # Filenames used in the result URLs
    url_names = {
        "seg": f"sam_seg_{uid}.png",
        "mask": f"mask_{uid}.png",
        "out": f"inpaint_{uid}.png"
    }
    # Upload each file and remove the local copy
    for key, fname in names.items():
        api.upload_file(
            path_or_fileobj=fname,
            path_in_repo=url_names[key],
            repo_id=hf_repo,
            repo_type="dataset",
            token=token
        )
        os.remove(fname)
    base = f"https://huggingface.co./datasets/{hf_repo}/resolve/main"
    return (
        f"{base}/{url_names['seg']}",
        f"{base}/{url_names['mask']}",
        f"{base}/{url_names['out']}"
    )
@app.post("/inpaint/")
async def inpaint(
    file: UploadFile = File(...),
    prompt: str = Form(...),
    neg_prompt: str = Form("old clothes, residue, artifacts"),
    hf_repo: str = Form(None)
):
    # Read the uploaded image
    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")
    # Process
    try:
        seg_url, mask_url, out_url = process_image(img, prompt, neg_prompt, hf_repo)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    # Return the result URLs as JSON
    return JSONResponse({
        "sam_segmentation": seg_url,
        "clothing_mask": mask_url,
        "inpainted": out_url
    })
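# Example request (illustrative; the host, port, file path, and prompt below are
# placeholders to adapt to your deployment):
#   curl -X POST http://localhost:8000/inpaint/ \
#     -F "file=@photo.jpg" \
#     -F "prompt=a new blue denim jacket" \
#     -F "neg_prompt=old clothes, residue, artifacts"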
# Main entry point for running the app directly
if __name__ == "__main__":
    # Preload models to verify they load correctly before starting the server
    print("Preloading CLIP model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # Use from_tf=True to fix model loading
        clip_model = CLIPModel.from_pretrained(
            "openai/clip-vit-base-patch32",
            from_tf=True,
            cache_dir="/tmp/transformers_cache"
        ).to(device)
        clip_processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir="/tmp/transformers_cache"
        )
        print("CLIP model loaded successfully!")
    except Exception as e:
        print(f"Error preloading CLIP model: {e}")
        # Do not exit - let it fail at request time if necessary
    # Run the FastAPI app with uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)