File size: 5,451 Bytes
2dc21ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
from collections import Counter
import gradio as gr

from models import segformer_model, segformer_processor
from constants import class_names, color_map

def _segformer_predict(image):
    """Return SegFormer's per-pixel class-index map for ``image`` at the image's own resolution.

    The model logits are bilinearly upsampled to ``(height, width)`` before the
    argmax, so the returned 2-D integer array lines up pixel-for-pixel with
    ``image``.
    """
    inputs = segformer_processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = segformer_model(**inputs).logits.cpu()
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # PIL size is (width, height); interpolate wants (height, width)
        mode="bilinear",
        align_corners=False,
    )
    return upsampled_logits.argmax(dim=1)[0].numpy()

def _keep_selected_classes(pred_seg, selected_classes):
    """Return a copy of ``pred_seg`` where every pixel not in ``selected_classes`` is reset to 0 (background).

    Names that are not present in ``class_names`` are ignored.
    """
    selected_ids = [class_names.index(name) for name in selected_classes if name in class_names]
    return np.where(np.isin(pred_seg, selected_ids), pred_seg, 0)

def _colorize_segmentation(pred_seg):
    """Map a class-index map to an (H, W, 3) float RGB image using ``color_map``.

    Pixels whose class index is outside ``class_names`` stay black.
    NOTE(review): assumes ``color_map`` returns 0-1 float RGB(A) tuples; the
    caller scales the result by 255.
    """
    colored_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3))
    # Only visit classes that actually occur instead of scanning every known class.
    for class_idx in np.unique(pred_seg):
        class_idx = int(class_idx)
        if class_idx < len(class_names):
            colored_seg[pred_seg == class_idx] = color_map(class_idx)[:3]
    return colored_seg

def _render_legend(selected_classes):
    """Render a transparent PNG legend for the displayed classes and return it as a PIL image.

    Background (class index 0) is never listed. When ``selected_classes`` is
    truthy, only those classes are listed. Handles and labels are built in
    the same loop so each colour patch is always paired with its own name.
    """
    handles, labels = [], []
    for class_idx, class_name in enumerate(class_names):
        if class_idx == 0:
            continue  # background is not shown in the legend
        if selected_classes and class_name not in selected_classes:
            continue
        handles.append(plt.Rectangle((0, 0), 1, 1, color=color_map(class_idx)[:3]))
        labels.append(class_name)

    fig, ax = plt.subplots(figsize=(10, 2))
    buf = io.BytesIO()
    try:
        fig.patch.set_alpha(0.0)
        ax.axis('off')
        if handles:
            ax.legend(handles, labels, loc='center', ncol=6)
        # Save this specific figure, not pyplot's "current" one, which another request may have changed.
        fig.savefig(buf, format='png', bbox_inches='tight', transparent=True)
    finally:
        plt.close(fig)
    buf.seek(0)
    legend_img = Image.open(buf)
    legend_img.load()  # decode now rather than lazily on first access
    return legend_img

def segment_image(image, selected_classes=None, show_original=True, show_segmentation=True, show_overlay=True, fixed_size=(400, 400)):
    """Segment ``image`` with SegFormer and return the requested visualisations.

    Args:
        image: PIL image to segment.
        selected_classes: optional collection of class names (entries of
            ``class_names``). When truthy, only these classes are kept and
            every other pixel is treated as background (class 0).
        show_original: include the input image, resized to ``fixed_size``.
        show_segmentation: include the colourised class map, resized to ``fixed_size``.
        show_overlay: include the colourised map alpha-blended onto the input
            (background pixels are left untouched), resized to ``fixed_size``.
        fixed_size: ``(width, height)`` passed to PIL ``resize`` so all
            outputs have a consistent size.

    Returns:
        A list of PIL images: the requested views in the order
        original, segmentation, overlay, always followed by a legend image.
    """
    # The model and the overlay blend both expect 3 channels (RGBA/greyscale
    # input would otherwise break the blend). The untouched ``image`` is
    # still used for the "original" view.
    rgb_image = image.convert("RGB")

    pred_seg = _segformer_predict(rgb_image)
    if selected_classes:
        pred_seg = _keep_selected_classes(pred_seg, selected_classes)

    colored_seg = _colorize_segmentation(pred_seg)

    # Blend the colour map over the photo at 50% opacity, excluding background pixels.
    alpha = 0.5
    overlay = np.array(rgb_image)
    foreground = pred_seg > 0
    overlay[foreground] = overlay[foreground] * (1 - alpha) + colored_seg[foreground] * 255 * alpha

    views = []
    if show_original:
        views.append(image.resize(fixed_size))
    if show_segmentation:
        views.append(Image.fromarray((colored_seg * 255).astype('uint8')).resize(fixed_size))
    if show_overlay:
        views.append(Image.fromarray(overlay.astype('uint8')).resize(fixed_size))

    views.append(_render_legend(selected_classes))
    return views

def identify_garment_segformer(image, clothing_classes=(4, 5, 6, 7)):
    """Identify the dominant garment type in ``image`` using SegFormer.

    Each pixel is classified, and the clothing class covering the most pixels
    is reported.

    Args:
        image: PIL image to analyse.
        clothing_classes: class indices that count as garments. The default
            ``(4, 5, 6, 7)`` corresponds to Upper-clothes, Skirt, Pants and
            Dress in this project's label set. Class 0 (background) is never
            counted.

    Returns:
        ``(class_name, class_index)`` for the garment class with the most
        pixels, or ``("No garment detected", None)`` if none of
        ``clothing_classes`` appear. Ties go to the class listed first in
        ``clothing_classes``.
    """
    # The model expects 3-channel input; this is a no-op change for RGB images.
    rgb_image = image.convert("RGB")
    inputs = segformer_processor(images=rgb_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = segformer_model(**inputs).logits.cpu()

    # Upsample the logits to the image resolution so counts are in image pixels.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=rgb_image.size[::-1],  # PIL size is (width, height); interpolate wants (height, width)
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

    # Vectorised per-class pixel histogram (argmax indices are non-negative).
    pixel_counts = np.bincount(pred_seg.ravel())

    garment_counts = [
        (int(class_idx), int(pixel_counts[class_idx]))
        for class_idx in clothing_classes
        if 0 < class_idx < pixel_counts.size and pixel_counts[class_idx] > 0
    ]
    if not garment_counts:
        return "No garment detected", None

    dominant_class = max(garment_counts, key=lambda item: item[1])[0]
    return class_names[dominant_class], dominant_class