import os

# Point the Hugging Face caches at /tmp before importing the libraries that
# read these variables at import time.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.makedirs("/tmp/transformers_cache", exist_ok=True)
os.makedirs("/tmp/huggingface", exist_ok=True)

import io
import uuid

import numpy as np
import torch
import uvicorn
from PIL import Image, ImageFilter
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from transformers import CLIPModel, CLIPProcessor
from diffusers import StableDiffusionInpaintPipeline
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from huggingface_hub import HfApi, hf_hub_download
|
|
app = FastAPI()

CLOTHING_LABELS = ["a piece of clothing", "shirt", "jacket", "pants", "dress", "skirt"]
CLIP_THRESHOLD = 0.25
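# CLIP_THRESHOLD is the minimum CLIP cosine similarity between a SAM2 mask
# crop and its best-matching clothing label for the mask to count as clothing.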
|
|
|
print("Starting app.py...")
|
|
def process_image(pil_img: Image.Image, prompt: str, neg_prompt: str, hf_repo: str = None):
    """Detect clothing with SAM2 + CLIP, inpaint it with Stable Diffusion, and
    upload the segmentation overlay, mask, and result to a HF dataset repo.

    Returns the three download URLs (segmentation, mask, inpainted image).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load CLIP and pre-compute L2-normalized text embeddings for the labels.
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir="/tmp/transformers_cache"
    ).to(device)
    clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir="/tmp/transformers_cache"
    )
    text_inputs = clip_processor(text=CLOTHING_LABELS, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_embeddings = clip_model.get_text_features(**text_inputs)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
|
    np_img = np.array(pil_img)

    # Fetch the SAM2 checkpoint (and its config, kept for reference) from the
    # Hub, then build the automatic mask generator. build_sam2 resolves the
    # "sam2_hiera_l" config from the sam2 package itself.
    cache_dir = os.path.join("/tmp", "sam2_cache")
    os.makedirs(cache_dir, exist_ok=True)
    ckpt = os.path.join(cache_dir, "sam2_hiera_large.pt")
    cfg = os.path.join(cache_dir, "sam2_hiera_l.yaml")
    if not os.path.exists(ckpt):
        ckpt = hf_hub_download("facebook/sam2-hiera-large", "sam2_hiera_large.pt", repo_type="model", cache_dir=cache_dir)
    if not os.path.exists(cfg):
        cfg = hf_hub_download("facebook/sam2-hiera-large", "sam2_hiera_l.yaml", repo_type="model", cache_dir=cache_dir)
    sam2 = build_sam2("sam2_hiera_l", ckpt, device=device)
    mask_generator = SAM2AutomaticMaskGenerator(sam2)
|
    masks = mask_generator.generate(np_img)
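    # masks is a list of dicts; each entry carries a boolean "segmentation"
    # array plus metadata such as the mask's area and bounding box.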
|
|
|
|
|
    # Keep only the masks whose crops CLIP scores as clothing and merge them.
    combined = np.zeros(np_img.shape[:2], dtype=bool)
    for m in masks:
        seg = m.get("segmentation")
        if seg is None:
            continue
        ys, xs = np.where(seg)
        if ys.size == 0:
            continue
        y1, y2 = ys.min(), ys.max()
        x1, x2 = xs.min(), xs.max()
        # PIL crop boxes are exclusive on the right/bottom edge, hence the +1.
        patch = pil_img.crop((x1, y1, x2 + 1, y2 + 1))
        inputs = clip_processor(images=patch, return_tensors="pt").to(device)
        with torch.no_grad():
            img_emb = clip_model.get_image_features(**inputs)
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        sims = (img_emb @ text_embeddings.T).squeeze(0)
        if float(sims.max().cpu()) > CLIP_THRESHOLD:
            combined |= seg
|
    # Binarize, dilate (MaxFilter) and feather (GaussianBlur) the merged mask
    # so the inpainting covers clothing edges cleanly.
    mask_bin = Image.fromarray(combined.astype(np.uint8) * 255)
    mask_dilated = mask_bin.filter(ImageFilter.MaxFilter(15))
    mask_for_inpaint = mask_dilated.filter(ImageFilter.GaussianBlur(7))
|
    # Stable Diffusion inpainting pipeline (fp16 weights).
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "sd-legacy/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16,
        cache_dir="/tmp/diffusers_cache"
    ).to(device)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        # xformers is optional; fall back to the default attention implementation.
        pass

    if not combined.any():
        # No clothing detected: return the original image untouched.
        result = pil_img.copy()
    else:
        result = pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            image=pil_img,
            mask_image=mask_for_inpaint
        ).images[0]
|
    # Visualization: overlay every SAM2 mask on the input with a random color.
    viz = np.array(pil_img).astype(np.float32)
    rnd = np.random.RandomState(42)
    for m in masks:
        seg = m.get("segmentation")
        if seg is None:
            continue
        color = rnd.randint(0, 256, size=3, dtype=np.uint8)
        ys, xs = np.where(seg)
        viz[ys, xs] = viz[ys, xs] * 0.5 + color * 0.5
    seg_viz = Image.fromarray(viz.astype(np.uint8))
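    # Three artifacts now exist: the segmentation overlay (seg_viz), the binary
    # clothing mask (mask_bin), and the inpainted image (result). They are
    # uploaded to a Hugging Face dataset repo below and returned as URLs.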
|
|
|
|
|
    token = os.getenv("HF_TOKEN")
    if token is None:
        raise RuntimeError("HF_TOKEN is not set in the environment")
    api = HfApi()
    if hf_repo is None:
        user = api.whoami(token=token)["name"]
        hf_repo = f"{user}/sam2-inpaint-outputs"
    api.create_repo(repo_id=hf_repo, repo_type="dataset", token=token, exist_ok=True)
|
    uid = uuid.uuid4().hex[:8]

    temp_dir = "/tmp/sam2_outputs"
    os.makedirs(temp_dir, exist_ok=True)

    names = {
        "seg": os.path.join(temp_dir, f"sam_seg_{uid}.png"),
        "mask": os.path.join(temp_dir, f"mask_{uid}.png"),
        "out": os.path.join(temp_dir, f"inpaint_{uid}.png")
    }

    seg_viz.save(names["seg"])
    mask_bin.save(names["mask"])
    result.save(names["out"])
|
    url_names = {
        "seg": f"sam_seg_{uid}.png",
        "mask": f"mask_{uid}.png",
        "out": f"inpaint_{uid}.png"
    }

    # Upload each file to the dataset repo, then delete the local copy.
    for key, fname in names.items():
        api.upload_file(
            path_or_fileobj=fname,
            path_in_repo=url_names[key],
            repo_id=hf_repo,
            repo_type="dataset",
            token=token
        )
        os.remove(fname)

    # Direct-download URLs; note that they are only publicly reachable if the
    # dataset repo is public.
    base = f"https://huggingface.co./datasets/{hf_repo}/resolve/main"
    return (
        f"{base}/{url_names['seg']}",
        f"{base}/{url_names['mask']}",
        f"{base}/{url_names['out']}"
    )
|
|
@app.post("/inpaint/")
async def inpaint(
    file: UploadFile = File(...),
    prompt: str = Form(...),
    neg_prompt: str = Form("old clothes, residue, artifacts"),
    hf_repo: str = Form(None)
):
    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    try:
        seg_url, mask_url, out_url = process_image(img, prompt, neg_prompt, hf_repo)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return JSONResponse({
        "sam_segmentation": seg_url,
        "clothing_mask": mask_url,
        "inpainted": out_url
    })
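# Example request (illustrative only; assumes the server is running locally on
# port 8000 and that "photo.jpg" exists):
#
#   import requests
#   with open("photo.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/inpaint/",
#           files={"file": f},
#           data={"prompt": "a plain red t-shirt"},
#       )
#   print(resp.json())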
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Preload CLIP so its weights are cached under /tmp before the first
    # request; process_image reloads them from the same cache directory.
    print("Preloading CLIP model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        clip_model = CLIPModel.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir="/tmp/transformers_cache"
        ).to(device)
        clip_processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir="/tmp/transformers_cache"
        )
        print("CLIP model loaded successfully!")
    except Exception as e:
        print(f"Error preloading CLIP model: {e}")

    uvicorn.run(app, host="0.0.0.0", port=8000)