import os

# Configure the Hugging Face cache before transformers is imported, since the
# default cache location is resolved at import time.
CACHE_DIR = '/workspace/huggingface_cache'
os.environ["HF_HOME"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)

import sys
import gc
import base64
import asyncio

import aiohttp
import cv2
import numpy as np
import torch
from PIL import Image
from dotenv import load_dotenv
from transformers import AutoProcessor, AutoModelForCausalLM, CLIPProcessor, CLIPModel
import fal_client

load_dotenv()

sys.path.append('./ComfyUI_AutoCropFaces')
from Pytorch_Retinaface.pytorch_retinaface import Pytorch_RetinaFace

device = "cuda" if torch.cuda.is_available() else "cpu"
def clear_cuda_memory():
    """Aggressively release Python and CUDA memory caches."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
def load_vision_models():
    print("Loading CLIP and Florence models...")
    # Load CLIP
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        cache_dir=CACHE_DIR
    ).to(device)
    clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-large-patch14",
        cache_dir=CACHE_DIR
    )
    # Load Florence-2 (trust_remote_code is required for its custom modeling code)
    florence_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-large",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
        cache_dir=CACHE_DIR
    ).to(device)
    florence_processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-large",
        trust_remote_code=True,
        cache_dir=CACHE_DIR
    )
    return {
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'florence_model': florence_model,
        'florence_processor': florence_processor,
    }
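# generate_caption() below reloads both models on every call, which is slow if
# many images are captioned. A minimal sketch of a module-level cache that
# avoids the reload (the _VISION_MODELS name is introduced here for
# illustration, not part of the original pipeline):
_VISION_MODELS = None

def get_vision_models():
    """Return the vision models, loading them once and caching thereafter."""
    global _VISION_MODELS
    if _VISION_MODELS is None:
        _VISION_MODELS = load_vision_models()
    return _VISION_MODELS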
def generate_caption(image):
    # NOTE: this reloads the models on every call; see get_vision_models()
    # above for a cached alternative.
    vision_models = load_vision_models()
    # Ensure the image is a PIL Image
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Florence-2 expects RGB input, so normalize any other mode (RGBA, L, P, ...)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    prompt = "<DETAILED_CAPTION>"
    inputs = vision_models['florence_processor'](
        text=prompt,
        images=image,
        return_tensors="pt"
    ).to(device, torch.float16 if torch.cuda.is_available() else torch.float32)
    generated_ids = vision_models['florence_model'].generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False
    )
    # Keep special tokens: Florence-2's post-processing expects them present.
    generated_text = vision_models['florence_processor'].batch_decode(
        generated_ids, skip_special_tokens=False
    )[0]
    parsed_answer = vision_models['florence_processor'].post_process_generation(
        generated_text,
        task="<DETAILED_CAPTION>",
        image_size=(image.width, image.height)
    )
    clear_cuda_memory()
    return parsed_answer['<DETAILED_CAPTION>']
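# The CLIP model loaded in load_vision_models() is never used in this file. A
# minimal sketch of one thing it could do here: score how well a generated
# caption matches its image (clip_similarity is a name introduced for
# illustration):
def clip_similarity(image, text):
    """Return the CLIP image-text logit for one pair; higher means more similar."""
    vision_models = get_vision_models()
    inputs = vision_models['clip_processor'](
        text=[text], images=image, return_tensors="pt", padding=True
    ).to(device)
    with torch.no_grad():
        outputs = vision_models['clip_model'](**inputs)
    return outputs.logits_per_image.item()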
def crop_face(image_path, output_dir, output_name, scale_factor=4.0):
    image = Image.open(image_path).convert("RGB")
    img_raw = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    img_raw = img_raw.astype(np.float32)
    rf = Pytorch_RetinaFace(
        cfg='mobile0.25',
        pretrained_path='./weights/mobilenet0.25_Final.pth',
        confidence_threshold=0.02,
        nms_threshold=0.4,
        vis_thres=0.6
    )
    dets = rf.detect_faces(img_raw)
    print("Dets: ", dets)
    # Instead of asserting, handle zero or multiple faces gracefully
    if len(dets) == 0:
        print("No faces detected!")
        return False
    # If multiple faces are detected, keep only the highest-confidence one
    if len(dets) > 1:
        print(f"Warning: {len(dets)} faces detected, using the one with highest confidence")
        # Assuming dets is a list of [bbox, landmark, score], sort by score
        dets = sorted(dets, key=lambda x: x[2], reverse=True)
        dets = [dets[0]]
    try:
        # The function returns a (cropped_imgs, bbox_infos) tuple; scale_factor
        # controls how much context is kept around the detected face
        cropped_imgs, bbox_infos = rf.center_and_crop_rescale(
            img_raw, dets, shift_factor=0.45, scale_factor=scale_factor
        )
        if not cropped_imgs:
            print("No cropped images returned")
            return False
        # Use the first cropped face image directly - it's not nested
        img_to_save = cropped_imgs[0]
        os.makedirs(output_dir, exist_ok=True)
        # Cast to uint8 in case the crop is still float32 (img_raw was cast above)
        cv2.imwrite(os.path.join(output_dir, output_name),
                    np.clip(img_to_save, 0, 255).astype(np.uint8))
        print(f"Saved: {output_name}")
        return True
    except Exception as e:
        print(f"Error during face cropping: {e}")
        return False
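# A minimal sketch (directory names are hypothetical) of running the crop over
# every image in a folder:
def crop_all_faces(input_dir, output_dir, scale_factor=4.0):
    """Run crop_face() on every image file found in input_dir."""
    for name in sorted(os.listdir(input_dir)):
        if name.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
            crop_face(os.path.join(input_dir, name), output_dir,
                      f"cropped_{name}", scale_factor=scale_factor)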
async def upscale_image(image_path, output_path):
    """Upscale an image using fal.ai's RealESRGAN model."""
    # Read and encode the image as a data URI (the MIME type assumes a JPEG source)
    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
    data_uri = f"data:image/jpeg;base64,{encoded_image}"
    try:
        # Submit the upscaling request; fal_client reads FAL_KEY from the
        # environment (loaded by load_dotenv() above)
        handler = await fal_client.submit_async(
            "fal-ai/real-esrgan",
            arguments={
                "image_url": data_uri,
                "scale": 2,
                "model": "RealESRGAN_x4plus",
                "output_format": "png",
                "face": True
            },
        )
        result = await handler.get()
        # Download and save the upscaled image; fal returns the output file
        # nested under result["image"]["url"]
        image_url = result['image']['url']
        async with aiohttp.ClientSession() as session:
            async with session.get(image_url) as response:
                if response.status == 200:
                    with open(output_path, 'wb') as f:
                        f.write(await response.read())
                    return True
                else:
                    print(f"Failed to download upscaled image: {response.status}")
                    return False
    except Exception as e:
        print(f"Error during upscaling: {e}")
        return False
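# End-to-end sketch tying the pieces together (file names are hypothetical):
# crop a face, caption the crop, then upscale it via fal.ai.
async def main():
    if crop_face("input.jpg", "outputs", "face.png", scale_factor=4.0):
        caption = generate_caption(Image.open("outputs/face.png"))
        print("Caption:", caption)
        await upscale_image("outputs/face.png", "outputs/face_upscaled.png")

if __name__ == "__main__":
    asyncio.run(main())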