# app.py
import os

# Configure Hugging Face caches before importing any model libraries
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.makedirs("/tmp/transformers_cache", exist_ok=True)
os.makedirs("/tmp/huggingface", exist_ok=True)

import io
import uuid
import numpy as np
from PIL import Image, ImageFilter
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
import torch
from transformers import CLIPModel, CLIPProcessor
from diffusers import StableDiffusionInpaintPipeline
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from huggingface_hub import HfApi, hf_hub_download
import uvicorn
app = FastAPI()
# Labels and CLIP similarity threshold used to decide whether a region contains clothing
CLOTHING_LABELS = ["a piece of clothing", "shirt", "jacket", "pants", "dress", "skirt"]
CLIP_THRESHOLD = 0.25
print("Starting app.py...")
def process_image(pil_img: Image.Image, prompt: str, neg_prompt: str, hf_repo: str = None):
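    """Segment clothing with SAM2, filter the regions with CLIP, inpaint them with
    Stable Diffusion, and upload the results to a Hugging Face dataset repo.

    Returns a tuple of three URLs: (SAM segmentation overlay, binary clothing mask,
    inpainted image).
    """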
    # --- Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # --- Load CLIP and compute normalized text embeddings
    # from_tf=True added to load the TensorFlow-format weights
clip_model = CLIPModel.from_pretrained(
"openai/clip-vit-base-patch32",
from_tf=True,
cache_dir="/tmp/transformers_cache"
).to(device)
clip_processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-base-patch32",
cache_dir="/tmp/transformers_cache"
)
text_inputs = clip_processor(text=CLOTHING_LABELS, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
text_embeddings = clip_model.get_text_features(**text_inputs)
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
    # --- Prepare numpy image for SAM2
np_img = np.array(pil_img)
    # --- Download and load SAM2
cache_dir = os.path.join("/tmp", "sam2_cache")
os.makedirs(cache_dir, exist_ok=True)
ckpt = os.path.join(cache_dir, "sam2_hiera_large.pt")
cfg = os.path.join(cache_dir, "sam2_hiera_l.yaml")
if not os.path.exists(ckpt):
ckpt = hf_hub_download("facebook/sam2-hiera-large", "sam2_hiera_large.pt", repo_type="model", cache_dir=cache_dir)
if not os.path.exists(cfg):
cfg = hf_hub_download("facebook/sam2-hiera-large", "sam2_hiera_l.yaml", repo_type="model", cache_dir=cache_dir)
sam2 = build_sam2("sam2_hiera_l", ckpt, device=device)
mask_generator = SAM2AutomaticMaskGenerator(sam2)
    # --- Generate all candidate masks
masks = mask_generator.generate(np_img)
    # --- Filter masks for clothing content using CLIP
combined = np.zeros(np_img.shape[:2], dtype=bool)
for m in masks:
seg = m.get("segmentation")
        if seg is None:
            continue
        ys, xs = np.where(seg)
        if ys.size == 0:
            continue
y1, y2 = ys.min(), ys.max()
x1, x2 = xs.min(), xs.max()
patch = pil_img.crop((x1, y1, x2, y2))
inputs = clip_processor(images=patch, return_tensors="pt").to(device)
with torch.no_grad():
img_emb = clip_model.get_image_features(**inputs)
img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
sims = (img_emb @ text_embeddings.T).squeeze(0)
if float(sims.max().cpu()) > CLIP_THRESHOLD:
combined |= seg
    # --- Build the binary mask, then dilate and blur it for inpainting
mask_bin = Image.fromarray((combined.astype(np.uint8)) * 255)
mask_dilated = mask_bin.filter(ImageFilter.MaxFilter(15))
mask_for_inpaint = mask_dilated.filter(ImageFilter.GaussianBlur(7))
    # --- Inpainting with Stable Diffusion
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"sd-legacy/stable-diffusion-inpainting",
revision="fp16",
torch_dtype=torch.float16,
cache_dir="/tmp/diffusers_cache"
).to(device)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        # xformers is optional; fall back to the default attention implementation
        pass
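    # If no clothing regions were detected, skip inpainting and return the original image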
if not combined.any():
result = pil_img.copy()
else:
result = pipe(
prompt=prompt,
negative_prompt=neg_prompt,
image=pil_img,
mask_image=mask_for_inpaint
).images[0]
    # --- Build a visualization overlay of the SAM2 segmentations
viz = np.array(pil_img).astype(np.float32)
rnd = np.random.RandomState(42)
for m in masks:
seg = m.get("segmentation")
        if seg is None:
            continue
color = rnd.randint(0, 256, size=3, dtype=np.uint8)
ys, xs = np.where(seg)
viz[ys, xs] = viz[ys, xs] * 0.5 + color * 0.5
seg_viz = Image.fromarray(viz.astype(np.uint8))
    # --- Upload results to the HF Hub (dataset repo)
token = os.getenv("HF_TOKEN")
if token is None:
        raise RuntimeError("HF_TOKEN is not set in the environment")
api = HfApi()
if hf_repo is None:
user = api.whoami(token=token)["name"]
hf_repo = f"{user}/sam2-inpaint-outputs"
api.create_repo(repo_id=hf_repo, repo_type="dataset", token=token, exist_ok=True)
uid = uuid.uuid4().hex[:8]
    # Use a temporary directory for intermediate files
temp_dir = "/tmp/sam2_outputs"
os.makedirs(temp_dir, exist_ok=True)
names = {
"seg": os.path.join(temp_dir, f"sam_seg_{uid}.png"),
"mask": os.path.join(temp_dir, f"mask_{uid}.png"),
"out": os.path.join(temp_dir, f"inpaint_{uid}.png")
}
    # Save intermediate files
seg_viz.save(names["seg"])
mask_bin.save(names["mask"])
result.save(names["out"])
    # File names as they will appear in the repo and result URLs
url_names = {
"seg": f"sam_seg_{uid}.png",
"mask": f"mask_{uid}.png",
"out": f"inpaint_{uid}.png"
}
    # Upload each file and delete the local copy
for key, fname in names.items():
api.upload_file(
path_or_fileobj=fname,
path_in_repo=url_names[key],
repo_id=hf_repo,
repo_type="dataset",
token=token
)
os.remove(fname)
    base = f"https://huggingface.co/datasets/{hf_repo}/resolve/main"
return (
f"{base}/{url_names['seg']}",
f"{base}/{url_names['mask']}",
f"{base}/{url_names['out']}"
)
@app.post("/inpaint/")
async def inpaint(
file: UploadFile = File(...),
prompt: str = Form(...),
neg_prompt: str = Form("old clothes, residue, artifacts"),
hf_repo: str = Form(None)
):
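    """Receive an uploaded image, replace the detected clothing according to the
    prompt, and return URLs of the segmentation overlay, clothing mask, and
    inpainted result."""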
    # Read the uploaded image
try:
data = await file.read()
img = Image.open(io.BytesIO(data)).convert("RGB")
except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")
    # Run the segmentation and inpainting pipeline
try:
seg_url, mask_url, out_url = process_image(img, prompt, neg_prompt, hf_repo)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
    # Return the result URLs as JSON
return JSONResponse({
"sam_segmentation": seg_url,
"clothing_mask": mask_url,
"inpainted": out_url
})
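
# Example client call (a minimal sketch; assumes the server is running locally on
# port 8000 and that HF_TOKEN is exported in the server's environment; the file
# name and prompt below are only placeholders):
#
#   import requests
#   with open("photo.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/inpaint/",
#           files={"file": f},
#           data={"prompt": "a plain white t-shirt"},
#       )
#   print(resp.json())  # {"sam_segmentation": ..., "clothing_mask": ..., "inpainted": ...}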
# Main entry point to run the app directly
if __name__ == "__main__":
    # Preload models to verify they work before starting the server
print("Preloading CLIP model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
        # Use from_tf=True to fix model loading
clip_model = CLIPModel.from_pretrained(
"openai/clip-vit-base-patch32",
from_tf=True,
cache_dir="/tmp/transformers_cache"
).to(device)
clip_processor = CLIPProcessor.from_pretrained(
"openai/clip-vit-base-patch32",
cache_dir="/tmp/transformers_cache"
)
print("CLIP model loaded successfully!")
except Exception as e:
print(f"Error preloading CLIP model: {e}")
        # Do not exit; let it fail at request time if necessary
    # Run the FastAPI application with uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)