Spaces:
Running
Running
import torch | |
from models import clip_model, clip_processor, segformer_model, segformer_processor | |
from constants import fashion_categories, fashion_clip_to_segformer, class_names, category_to_segment_mapping, garment_to_segments | |
from segmentation import identify_garment_segformer | |
def identify_garment_clip(image): | |
"""Identify the garment type using Fashion-CLIP model""" | |
# Prepare text prompts | |
texts = [f"a photo of a {category}" for category in fashion_categories] | |
# Process inputs | |
inputs = clip_processor(text=texts, images=image, return_tensors="pt", padding=True) | |
# Get predictions | |
with torch.no_grad(): | |
outputs = clip_model(**inputs) | |
logits_per_image = outputs.logits_per_image | |
probs = logits_per_image.softmax(dim=1) | |
# Get the top prediction | |
top_idx = torch.argmax(probs[0]).item() | |
top_category = fashion_categories[top_idx] | |
confidence = probs[0][top_idx].item() * 100 | |
# Map to SegFormer class if possible | |
if top_category in fashion_clip_to_segformer: | |
segformer_idx = fashion_clip_to_segformer[top_category] | |
segformer_class = class_names[segformer_idx] | |
return top_category, segformer_idx, confidence | |
else: | |
# Fallback to using SegFormer directly | |
return top_category, None, confidence | |
def get_segments_for_garment(garment_image): | |
"""Get the segments that should be included for a given garment image""" | |
# First try to identify the garment using Fashion-CLIP | |
clip_category, segformer_idx, confidence = identify_garment_clip(garment_image) | |
# If CLIP couldn't map to a SegFormer class, fall back to SegFormer | |
if segformer_idx is None: | |
garment_name, segformer_idx = identify_garment_segformer(garment_image) | |
method = "SegFormer (fallback)" | |
confidence_text = "" | |
else: | |
garment_name = class_names[segformer_idx] | |
method = "Fashion-CLIP" | |
confidence_text = f" with {confidence:.2f}% confidence" | |
if segformer_idx is None: | |
return None, None, "No clear garment detected in the garment image" | |
# Get all segments that should be included based on the detected garment | |
if method == "Fashion-CLIP" and clip_category in category_to_segment_mapping: | |
# Use the detailed mapping for CLIP categories | |
selected_class = category_to_segment_mapping[clip_category] | |
elif segformer_idx in garment_to_segments: | |
# Fall back to the SegFormer class-based mapping | |
segment_indices = garment_to_segments[segformer_idx] | |
selected_class = [class_names[idx] for idx in segment_indices] | |
else: | |
# Fallback to just the detected garment if no mapping exists | |
selected_class = [class_names[segformer_idx]] | |
# Prepare a more descriptive result text | |
included_segments = ", ".join(selected_class) | |
if method == "Fashion-CLIP": | |
result_text = f"Detected garment: {clip_category} (mapped to {garment_name})\nUsing {method}{confidence_text}\nSegmented parts: {included_segments}" | |
else: | |
result_text = f"Detected garment: {garment_name}\nUsing {method}\nSegmented parts: {included_segments}" | |
return selected_class, segformer_idx, result_text | |