import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
from collections import Counter
import gradio as gr

from models import segformer_model, segformer_processor
from constants import class_names, color_map
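
# `models` and `constants` are local modules of the Space. A minimal sketch of
# what they presumably provide (the checkpoint name is an assumption; any
# SegFormer clothes-segmentation checkpoint with "Background" at index 0 and
# the class order used below would work):
#
#   from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
#   segformer_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
#   segformer_model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
#   class_names = ["Background", "Hat", "Hair", ..., "Upper-clothes", "Skirt", "Pants", "Dress", ...]
#   color_map = plt.get_cmap("tab20")  # callable: class index -> RGBA floats in [0, 1]
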
def segment_image(image, selected_classes=None, show_original=True, show_segmentation=True, show_overlay=True, fixed_size=(400, 400)):
    """Segment the image based on selected classes, with consistent output sizes."""
    # Process the image
    inputs = segformer_processor(images=image, return_tensors="pt")
    # Get model predictions (no gradient tracking is needed at inference time)
    with torch.no_grad():
        outputs = segformer_model(**inputs)
    logits = outputs.logits.cpu()
# Upsample the logits to match the original image size | |
upsampled_logits = nn.functional.interpolate( | |
logits, | |
size=image.size[::-1], # (height, width) | |
mode="bilinear", | |
align_corners=False, | |
) | |
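    # Shape flow: SegFormer predicts at 1/4 of the input resolution, so logits
    # are (1, num_classes, H/4, W/4) before and (1, num_classes, H, W) after
    # interpolation; argmax over dim=1 then gives an (H, W) map of class indices.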
    # Get the predicted segmentation map
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
    # Filter classes if specified
    if selected_classes and len(selected_classes) > 0:
        # Create a mask for selected classes
        mask = np.zeros_like(pred_seg, dtype=bool)
        for class_name in selected_classes:
            if class_name in class_names:
                class_idx = class_names.index(class_name)
                mask = np.logical_or(mask, pred_seg == class_idx)
        # Apply the mask to keep only selected classes, set others to background (0)
        filtered_seg = np.zeros_like(pred_seg)
        filtered_seg[mask] = pred_seg[mask]
        pred_seg = filtered_seg
    # Create a colored segmentation map
    colored_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3))
    for class_idx in range(len(class_names)):
        mask = pred_seg == class_idx
        if mask.any():
            colored_seg[mask] = color_map(class_idx)[:3]
    # Create an overlay of the segmentation on the original image
    image_array = np.array(image)
    overlay = image_array.copy()
    alpha = 0.5  # Transparency factor
    mask = pred_seg > 0  # Exclude background
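    # color_map returns RGB floats in [0, 1], hence the * 255 rescale below;
    # with alpha = 0.5 each masked pixel is an even mix of the original pixel
    # and its class color, truncated back to uint8 on assignment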
    overlay[mask] = overlay[mask] * (1 - alpha) + colored_seg[mask] * 255 * alpha
    # Prepare output images based on user selection (named output_images to
    # avoid shadowing the model outputs above)
    output_images = []
    if show_original:
        # Resize original image to ensure consistent size
        resized_original = image.resize(fixed_size)
        output_images.append(resized_original)
    if show_segmentation:
        seg_image = Image.fromarray((colored_seg * 255).astype('uint8'))
        # Ensure segmentation has consistent size
        seg_image = seg_image.resize(fixed_size)
        output_images.append(seg_image)
    if show_overlay:
        overlay_image = Image.fromarray(overlay.astype('uint8'))
        # Ensure overlay has consistent size
        overlay_image = overlay_image.resize(fixed_size)
        output_images.append(overlay_image)
    # Create a legend for the segmentation classes
    fig, ax = plt.subplots(figsize=(10, 2))
    fig.patch.set_alpha(0.0)
    ax.axis('off')
    # Create legend patches; always skip the background class so the patches
    # stay aligned with the label list built below
    legend_elements = []
    for i, class_name in enumerate(class_names):
        if i == 0:  # Skip background in legend
            continue
        if not selected_classes or class_name in selected_classes:
            color = color_map(i)[:3]
            legend_elements.append(plt.Rectangle((0, 0), 1, 1, color=color))
    # Only add legend if there are elements to show
    if legend_elements:
        legend_class_names = [name for name in class_names if name != "Background" and (not selected_classes or name in selected_classes)]
        ax.legend(legend_elements, legend_class_names, loc='center', ncol=6)
    # Save the legend to a bytes buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', transparent=True)
    buf.seek(0)
    legend_img = Image.open(buf)
    plt.close(fig)
    output_images.append(legend_img)
    return output_images
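
# A minimal usage sketch for segment_image (the file name is an assumption):
#
#   img = Image.open("person.jpg").convert("RGB")
#   panels = segment_image(img, selected_classes=["Upper-clothes", "Pants"])
#   # panels holds [resized original, colored mask, overlay, legend image],
#   # depending on the show_* flags; every panel except the legend is 400x400.
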
def identify_garment_segformer(image):
    """Identify the dominant garment type using SegFormer."""
    # Process the image
    inputs = segformer_processor(images=image, return_tensors="pt")
    # Get model predictions (again without gradient tracking)
    with torch.no_grad():
        outputs = segformer_model(**inputs)
    logits = outputs.logits.cpu()
    # Upsample the logits to match the original image size
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # PIL gives (width, height); interpolate wants (height, width)
        mode="bilinear",
        align_corners=False,
    )
    # Get the predicted segmentation map
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
    # Count the pixels for each class (excluding background)
    class_counts = Counter(pred_seg.flatten())
    if 0 in class_counts:  # Remove background
        del class_counts[0]
    # Find the most common clothing item
    clothing_classes = [4, 5, 6, 7]  # Upper-clothes, Skirt, Pants, Dress
    # Filter to only include clothing items
    clothing_counts = {k: v for k, v in class_counts.items() if k in clothing_classes}
    if not clothing_counts:
        return "No garment detected", None
    # Get the most common clothing item (cast to int so callers get a plain
    # Python int rather than a NumPy scalar)
    dominant_class = int(max(clothing_counts.items(), key=lambda x: x[1])[0])
    dominant_class_name = class_names[dominant_class]
    return dominant_class_name, dominant_class
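
# A hedged sketch of how these functions might be driven for a quick local
# test. The image path is an assumption, and the Gradio wiring below is a
# guess at the Space's interface, not the original app code:
if __name__ == "__main__":
    demo_image = Image.open("example.jpg").convert("RGB")  # assumed test image

    garment_name, garment_idx = identify_garment_segformer(demo_image)
    print(f"Dominant garment: {garment_name}")

    # gr.Gallery accepts a list of PIL images, which matches segment_image's
    # variable-length return value
    demo = gr.Interface(
        fn=segment_image,
        inputs=[gr.Image(type="pil"), gr.CheckboxGroup(choices=class_names, label="Classes")],
        outputs=gr.Gallery(label="Results"),
    )
    demo.launch()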