# Easel_AI_Engineering / segmentation.py
import io
from collections import Counter

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import gradio as gr

from models import segformer_model, segformer_processor
from constants import class_names, color_map
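
# The two imports above come from companion modules not shown in this file.
# As a hedged sketch only: the class indices used below (4-7 for
# Upper-clothes/Skirt/Pants/Dress) match the public
# "mattmdjaga/segformer_b2_clothes" checkpoint, so models.py plausibly looks
# something like the following -- an assumption, not this repo's confirmed code:
#
#     from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
#     segformer_processor = SegformerImageProcessor.from_pretrained(
#         "mattmdjaga/segformer_b2_clothes")
#     segformer_model = SegformerForSemanticSegmentation.from_pretrained(
#         "mattmdjaga/segformer_b2_clothes")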


def segment_image(image, selected_classes=None, show_original=True,
                  show_segmentation=True, show_overlay=True,
                  fixed_size=(400, 400)):
    """Segment the image based on selected classes with consistent output sizes."""
    # Process the image
    inputs = segformer_processor(images=image, return_tensors="pt")

    # Get model predictions (no gradients needed for inference)
    with torch.no_grad():
        outputs = segformer_model(**inputs)
    logits = outputs.logits.cpu()

    # Upsample the logits to match the original image size
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # (height, width)
        mode="bilinear",
        align_corners=False,
    )

    # Get the predicted segmentation map
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

    # Filter classes if specified
    if selected_classes:
        # Create a mask covering all selected classes
        mask = np.zeros_like(pred_seg, dtype=bool)
        for class_name in selected_classes:
            if class_name in class_names:
                class_idx = class_names.index(class_name)
                mask = np.logical_or(mask, pred_seg == class_idx)
        # Keep only the selected classes; set everything else to background (0)
        filtered_seg = np.zeros_like(pred_seg)
        filtered_seg[mask] = pred_seg[mask]
        pred_seg = filtered_seg

    # Create a colored segmentation map (RGB values in [0, 1])
    colored_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3))
    for class_idx in range(len(class_names)):
        mask = pred_seg == class_idx
        if mask.any():
            colored_seg[mask] = color_map(class_idx)[:3]

    # Blend the segmentation onto the original image
    image_array = np.array(image)
    overlay = image_array.astype(float)
    alpha = 0.5  # Transparency factor
    mask = pred_seg > 0  # Exclude background
    overlay[mask] = overlay[mask] * (1 - alpha) + colored_seg[mask] * 255 * alpha
    overlay = overlay.astype('uint8')

    # Prepare output images based on user selection
    # (a separate name avoids shadowing the model outputs above)
    result_images = []
    if show_original:
        # Resize original image to ensure consistent size
        result_images.append(image.resize(fixed_size))
    if show_segmentation:
        seg_image = Image.fromarray((colored_seg * 255).astype('uint8'))
        # Ensure segmentation has consistent size
        result_images.append(seg_image.resize(fixed_size))
    if show_overlay:
        overlay_image = Image.fromarray(overlay)
        # Ensure overlay has consistent size
        result_images.append(overlay_image.resize(fixed_size))

    # Create a legend for the segmentation classes
    fig, ax = plt.subplots(figsize=(10, 2))
    fig.patch.set_alpha(0.0)
    ax.axis('off')

    # Build legend patches and labels together so they stay aligned
    legend_elements = []
    legend_labels = []
    for i, class_name in enumerate(class_names):
        if i == 0:  # Always skip background in the legend
            continue
        if not selected_classes or class_name in selected_classes:
            legend_elements.append(plt.Rectangle((0, 0), 1, 1, color=color_map(i)[:3]))
            legend_labels.append(class_name)

    # Only add legend if there are elements to show
    if legend_elements:
        ax.legend(legend_elements, legend_labels, loc='center', ncol=6)

    # Save the legend to a bytes buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', transparent=True)
    buf.seek(0)
    legend_img = Image.open(buf)
    plt.close(fig)
    result_images.append(legend_img)

    return result_images
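
# A minimal usage sketch (the file path and class names below are
# illustrative assumptions, not part of the application):
#
#     from PIL import Image
#     img = Image.open("example.jpg").convert("RGB")
#     original, seg_map, overlay, legend = segment_image(
#         img, selected_classes=["Upper-clothes", "Pants"])
#     overlay.save("overlay.png")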


def identify_garment_segformer(image):
    """Identify the dominant garment type using SegFormer."""
    # Process the image
    inputs = segformer_processor(images=image, return_tensors="pt")

    # Get model predictions (no gradients needed for inference)
    with torch.no_grad():
        outputs = segformer_model(**inputs)
    logits = outputs.logits.cpu()

    # Upsample the logits to match the original image size
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # (height, width)
        mode="bilinear",
        align_corners=False,
    )

    # Get the predicted segmentation map
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

    # Count the pixels for each class, excluding background
    class_counts = Counter(pred_seg.flatten())
    class_counts.pop(0, None)  # Remove background

    # Restrict to garment classes: Upper-clothes, Skirt, Pants, Dress
    clothing_classes = [4, 5, 6, 7]
    clothing_counts = {k: v for k, v in class_counts.items() if k in clothing_classes}
    if not clothing_counts:
        return "No garment detected", None

    # The dominant garment is the class covering the most pixels
    dominant_class = max(clothing_counts.items(), key=lambda x: x[1])[0]
    dominant_class_name = class_names[dominant_class]
    return dominant_class_name, dominant_class
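

# Hedged demo wiring: the module imports gradio but builds no UI here, so the
# launcher below is only a sketch of how these functions might be exposed.
# The component choices (a Gallery for the variable-length image list, a
# CheckboxGroup for class filtering) are assumptions, not the application's
# actual interface.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=segment_image,
        inputs=[
            gr.Image(type="pil", label="Input image"),
            gr.CheckboxGroup(choices=class_names, label="Classes to keep"),
        ],
        outputs=gr.Gallery(label="Original / segmentation / overlay / legend"),
        title="Garment segmentation (SegFormer)",
    )
    demo.launch()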