Cero72 committed · Commit 2dc21ea · 1 Parent(s): c3509c8

Add garment segmentation application files

Files changed:
- README.md +25 -6
- app.py +138 -0
- classification.py +71 -0
- constants.py +168 -0
- models.py +21 -0
- requirements.txt +8 -0
- segmentation.py +150 -0
- utils.py +34 -0
README.md
CHANGED
@@ -1,12 +1,31 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Garment-based Segmentation
+emoji: π
+colorFrom: green
+colorTo: blue
 sdk: gradio
-sdk_version:
+sdk_version: 3.19.0
 app_file: app.py
 pinned: false
 ---

-
+# Garment-based Segmentation with SegFormer and Fashion-CLIP
+
+This application uses AI models to segment specific clothing items in images by matching a garment to a person.
+
+## Features
+
+- **Targeted Garment Segmentation**: Segment only the specific garment type that matches a reference image
+- **Interactive Interface**: User-friendly interface with clear sections for person and garment images
+- **Visualization Options**: View original images, segmentation maps, and overlays
+
+## How to Use
+
+1. Upload an image of a person wearing clothes
+2. Upload a reference image of a garment (e.g., a t-shirt, pants, dress)
+3. Click "Process Images" to generate the targeted segmentation
+4. View the results in the gallery
+
+## Technical Details
+
+This application combines SegFormer for image segmentation with Fashion-CLIP for garment classification to create a system that can identify and segment specific garment types in images.
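The pipeline described in the README reduces to two calls from the modules added below: classify the reference garment, then segment only the matching classes on the person image. A minimal sketch of that flow, assuming the modules from this commit are importable and that person.jpg and garment.jpg are hypothetical local files:

from PIL import Image
from classification import get_segments_for_garment
from segmentation import segment_image

person = Image.open("person.jpg").convert("RGB")    # hypothetical input files
garment = Image.open("garment.jpg").convert("RGB")

# Stage 1: Fashion-CLIP (with a SegFormer fallback) decides which segments to keep
selected_classes, _, summary = get_segments_for_garment(garment)
print(summary)

# Stage 2: SegFormer segments the person image, restricted to those classes
images = segment_image(person, selected_classes)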
app.py
ADDED
@@ -0,0 +1,138 @@
import gradio as gr
import os

# Import from our modules
from constants import class_names
from segmentation import segment_image
from utils import process_url, process_person_and_garment

# Define fixed size for consistent image display
FIXED_IMAGE_SIZE = (400, 400)

def create_interface():
    """Create the Gradio interface with improved image consistency"""
    with gr.Blocks(title="Garment-based Segmentation") as demo:
        gr.Markdown("""
        # Garment-based Segmentation with SegFormer and Fashion-CLIP

        This application uses AI models to segment specific clothing items in images by matching a garment to a person.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Person image section
                gr.Markdown("### Person Image")
                person_image = gr.Image(
                    type="pil",
                    label="Upload a person wearing clothes",
                    height=300,
                    sources=["upload", "webcam", "clipboard"],
                    elem_id="person-image-upload"
                )

                # Garment image section
                gr.Markdown("### Garment Image")
                garment_image = gr.Image(
                    type="pil",
                    label="Upload a garment to detect",
                    height=300,
                    sources=["upload", "webcam", "clipboard"],
                    elem_id="garment-image-upload"
                )

                with gr.Row():
                    show_original_dual = gr.Checkbox(label="Show Original", value=True)
                    show_segmentation_dual = gr.Checkbox(label="Show Segmentation", value=True)
                    show_overlay_dual = gr.Checkbox(label="Show Overlay", value=True)

                process_button = gr.Button(
                    "Process Images",
                    variant="primary",
                    size="lg",
                    elem_id="process-button"
                )

            with gr.Column(scale=2):
                dual_output_images = gr.Gallery(
                    label="Results",
                    columns=3,
                    height=450,
                    object_fit="contain",
                    elem_id="dual_gallery"
                )
                result_text = gr.Textbox(label="Result", interactive=False, lines=4)

        # Set up event handler for dual image processing
        process_button.click(
            fn=lambda p_img, g_img, orig, seg, over: process_person_and_garment(p_img, g_img, orig, seg, over, FIXED_IMAGE_SIZE),
            inputs=[person_image, garment_image, show_original_dual, show_segmentation_dual, show_overlay_dual],
            outputs=[dual_output_images, result_text]
        )

        # Add custom CSS for consistent image sizes and improved UI
        gr.HTML("""
        <style>
        .gradio-container img {
            max-height: 400px !important;
            object-fit: contain !important;
        }
        #dual_gallery {
            min-height: 450px;
        }
        /* Larger upload buttons */
        #person-image-upload .upload-button,
        #garment-image-upload .upload-button {
            font-size: 1.2em !important;
            padding: 12px 20px !important;
            border-radius: 8px !important;
            margin: 10px auto !important;
            display: block !important;
            width: 80% !important;
            text-align: center !important;
            background-color: #4CAF50 !important;
            color: white !important;
            border: none !important;
            cursor: pointer !important;
            transition: background-color 0.3s ease !important;
        }
        #person-image-upload .upload-button:hover,
        #garment-image-upload .upload-button:hover {
            background-color: #45a049 !important;
        }
        /* Larger process button */
        #process-button {
            font-size: 1.3em !important;
            padding: 15px 25px !important;
            margin: 15px auto !important;
            display: block !important;
            width: 90% !important;
        }
        /* Better section headers */
        h3 {
            font-size: 1.5em !important;
            margin-top: 20px !important;
            margin-bottom: 15px !important;
            color: #2c3e50 !important;
            border-bottom: 2px solid #3498db !important;
            padding-bottom: 8px !important;
        }
        /* Better main heading */
        h1 {
            color: #2c3e50 !important;
            text-align: center !important;
            margin-bottom: 30px !important;
            font-size: 2.5em !important;
        }
        /* Better checkbox layout */
        .gradio-checkbox {
            margin: 10px 5px !important;
        }
        </style>
        """)

    return demo

# Main application entry point
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
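demo.launch() uses Gradio's defaults, which is all a Space needs. For local debugging it can help to pass a few of Gradio's standard launch options; the sketch below shows the common ones and is not part of this commit:

if __name__ == "__main__":
    demo = create_interface()
    # server_name="0.0.0.0" exposes the app on the local network;
    # share=True would request a temporary public URL from Gradio
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)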
classification.py
ADDED
@@ -0,0 +1,71 @@
import torch
from models import clip_model, clip_processor, segformer_model, segformer_processor
from constants import fashion_categories, fashion_clip_to_segformer, class_names, category_to_segment_mapping, garment_to_segments
from segmentation import identify_garment_segformer

def identify_garment_clip(image):
    """Identify the garment type using Fashion-CLIP model"""
    # Prepare text prompts
    texts = [f"a photo of a {category}" for category in fashion_categories]

    # Process inputs
    inputs = clip_processor(text=texts, images=image, return_tensors="pt", padding=True)

    # Get predictions
    with torch.no_grad():
        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)

    # Get the top prediction
    top_idx = torch.argmax(probs[0]).item()
    top_category = fashion_categories[top_idx]
    confidence = probs[0][top_idx].item() * 100

    # Map to SegFormer class if possible
    if top_category in fashion_clip_to_segformer:
        segformer_idx = fashion_clip_to_segformer[top_category]
        segformer_class = class_names[segformer_idx]
        return top_category, segformer_idx, confidence
    else:
        # Fallback to using SegFormer directly
        return top_category, None, confidence

def get_segments_for_garment(garment_image):
    """Get the segments that should be included for a given garment image"""
    # First try to identify the garment using Fashion-CLIP
    clip_category, segformer_idx, confidence = identify_garment_clip(garment_image)

    # If CLIP couldn't map to a SegFormer class, fall back to SegFormer
    if segformer_idx is None:
        garment_name, segformer_idx = identify_garment_segformer(garment_image)
        method = "SegFormer (fallback)"
        confidence_text = ""
    else:
        garment_name = class_names[segformer_idx]
        method = "Fashion-CLIP"
        confidence_text = f" with {confidence:.2f}% confidence"

    if segformer_idx is None:
        return None, None, "No clear garment detected in the garment image"

    # Get all segments that should be included based on the detected garment
    if method == "Fashion-CLIP" and clip_category in category_to_segment_mapping:
        # Use the detailed mapping for CLIP categories
        selected_class = category_to_segment_mapping[clip_category]
    elif segformer_idx in garment_to_segments:
        # Fall back to the SegFormer class-based mapping
        segment_indices = garment_to_segments[segformer_idx]
        selected_class = [class_names[idx] for idx in segment_indices]
    else:
        # Fallback to just the detected garment if no mapping exists
        selected_class = [class_names[segformer_idx]]

    # Prepare a more descriptive result text
    included_segments = ", ".join(selected_class)
    if method == "Fashion-CLIP":
        result_text = f"Detected garment: {clip_category} (mapped to {garment_name})\nUsing {method}{confidence_text}\nSegmented parts: {included_segments}"
    else:
        result_text = f"Detected garment: {garment_name}\nUsing {method}\nSegmented parts: {included_segments}"

    return selected_class, segformer_idx, result_text
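identify_garment_clip keeps only the single best category, which can hide near-misses when a garment is ambiguous. A minimal sketch of the same zero-shot scoring with a top-3 printout, reusing the prompts and models from this commit; garment.jpg is a hypothetical reference image:

import torch
from PIL import Image
from models import clip_model, clip_processor
from constants import fashion_categories

image = Image.open("garment.jpg").convert("RGB")   # hypothetical reference garment
texts = [f"a photo of a {category}" for category in fashion_categories]

inputs = clip_processor(text=texts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    probs = clip_model(**inputs).logits_per_image.softmax(dim=1)[0]

# Show the three most likely garment categories with their probabilities
values, indices = probs.topk(3)
for p, idx in zip(values.tolist(), indices.tolist()):
    print(f"{fashion_categories[idx]}: {p * 100:.1f}%")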
constants.py
ADDED
@@ -0,0 +1,168 @@
import matplotlib.pyplot as plt

# Define the class names for the segmentation model
class_names = [
    "Background", "Hat", "Hair", "Sunglasses", "Upper-clothes", "Skirt", "Pants",
    "Dress", "Belt", "Left-shoe", "Right-shoe", "Face", "Left-leg", "Right-leg",
    "Left-arm", "Right-arm", "Bag", "Scarf"
]

# Define a color map for visualization
color_map = plt.cm.get_cmap('tab20', len(class_names))

# Define a mapping of garment types to related segments that should be included
garment_to_segments = {
    0: [0],                  # Background --> segment background only
    1: [1, 2, 11],           # Hat --> segment hat, hair, and face
    2: [2],                  # Hair --> segment hair only
    3: [3, 11],              # Sunglasses --> segment sunglasses and face
    4: [4, 14, 15],          # Upper-clothes --> segment upper clothes, left arm, right arm
    5: [5, 6, 12, 13],       # Skirt --> segment skirt, pants, left leg, right leg
    6: [6, 12, 13],          # Pants --> segment pants, left leg, right leg
    7: [4, 5, 6, 7, 12, 13, 14, 15],  # Dress --> segment whole body except face and hair
    8: [8],                  # Belt --> segment belt only
    9: [9],                  # Left-shoe --> segment left shoe only
    10: [10],                # Right-shoe --> segment right shoe only
    11: [11],                # Face --> segment face only
    12: [12],                # Left-leg --> segment left leg only
    13: [13],                # Right-leg --> segment right leg only
    14: [14],                # Left-arm --> segment left arm only
    15: [15],                # Right-arm --> segment right arm only
    16: [16],                # Bag --> segment bag only
    17: [17, 2, 11]          # Scarf --> segment scarf, hair and face
}

# Define categories for Fashion-CLIP
fashion_categories = [
    # Upper body
    "t-shirt", "shirt", "blouse", "tank top", "polo shirt", "sweatshirt", "hoodie",

    # Outerwear
    "jacket", "coat", "blazer", "cardigan", "vest", "windbreaker",

    # Dresses
    "dress", "shirt dress", "sundress", "evening gown", "maxi dress", "mini dress",

    # Lower body
    "jeans", "pants", "trousers", "shorts", "skirt", "leggings", "joggers", "sweatpants",

    # Formal wear
    "suit", "tuxedo", "formal shirt", "formal dress",

    # Undergarments
    "bra", "underwear", "boxers", "briefs", "lingerie",

    # Sleepwear
    "pajamas", "nightgown", "bathrobe",

    # Swimwear
    "swimsuit", "bikini", "swim trunks",

    # Footwear
    "shoes", "boots", "sneakers", "sandals", "high heels", "loafers", "flats",

    # Accessories
    "hat", "cap", "scarf", "gloves", "belt", "tie", "socks",

    # Bags
    "handbag", "backpack", "purse", "tote bag"
]

# Mapping from Fashion-CLIP categories to SegFormer classes
fashion_clip_to_segformer = {
    # Upper body items -> Upper-clothes (4)
    "t-shirt": 4, "shirt": 4, "blouse": 4, "tank top": 4, "polo shirt": 4, "sweatshirt": 4, "hoodie": 4,
    "cardigan": 4, "vest": 4, "formal shirt": 4,

    # Outerwear -> Upper-clothes (4)
    "jacket": 4, "coat": 4, "blazer": 4, "windbreaker": 4,

    # Dresses -> Dress (7)
    "dress": 7, "shirt dress": 7, "sundress": 7, "evening gown": 7, "maxi dress": 7, "mini dress": 7,
    "formal dress": 7,

    # Lower body -> Pants (6) or Skirt (5)
    "jeans": 6, "pants": 6, "trousers": 6, "shorts": 6, "leggings": 6, "joggers": 6, "sweatpants": 6,
    "skirt": 5,

    # Formal wear -> Upper-clothes (4) or Dress (7)
    "suit": 4, "tuxedo": 4,

    # Footwear -> Left-shoe/Right-shoe (9/10)
    "shoes": 9, "boots": 9, "sneakers": 9, "sandals": 9, "high heels": 9, "loafers": 9, "flats": 9,

    # Accessories
    "hat": 1, "cap": 1, "scarf": 17, "belt": 8,

    # Bags
    "handbag": 16, "backpack": 16, "purse": 16, "tote bag": 16
}

# Detailed mapping from categories to segment names
category_to_segment_mapping = {
    # Upper body items map to Upper-clothes and arms
    "t-shirt": ["Upper-clothes", "Left-arm", "Right-arm"],
    "shirt": ["Upper-clothes", "Left-arm", "Right-arm"],
    "blouse": ["Upper-clothes", "Left-arm", "Right-arm"],
    "tank top": ["Upper-clothes", "Left-arm", "Right-arm"],
    "polo shirt": ["Upper-clothes", "Left-arm", "Right-arm"],
    "sweatshirt": ["Upper-clothes", "Left-arm", "Right-arm"],
    "hoodie": ["Upper-clothes", "Left-arm", "Right-arm"],

    # Outerwear maps to Upper-clothes and arms
    "jacket": ["Upper-clothes", "Left-arm", "Right-arm"],
    "coat": ["Upper-clothes", "Left-arm", "Right-arm"],
    "blazer": ["Upper-clothes", "Left-arm", "Right-arm"],
    "cardigan": ["Upper-clothes", "Left-arm", "Right-arm"],
    "vest": ["Upper-clothes"],
    "windbreaker": ["Upper-clothes", "Left-arm", "Right-arm"],

    # Dresses map to Dress
    "dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
    "shirt dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
    "sundress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
    "evening gown": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
    "maxi dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
    "mini dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
    "formal dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],

    # Lower body items map to Pants or Skirt and legs
    "jeans": ["Pants", "Left-leg", "Right-leg"],
    "pants": ["Pants", "Left-leg", "Right-leg"],
    "trousers": ["Pants", "Left-leg", "Right-leg"],
    "shorts": ["Pants", "Left-leg", "Right-leg"],
    "skirt": ["Skirt", "Pants", "Left-leg", "Right-leg"],
    "leggings": ["Pants", "Left-leg", "Right-leg"],
    "joggers": ["Pants", "Left-leg", "Right-leg"],
    "sweatpants": ["Pants", "Left-leg", "Right-leg"],

    # Formal wear maps depending on type
    "suit": ["Upper-clothes", "Left-arm", "Right-arm", "Pants", "Left-leg", "Right-leg"],
    "tuxedo": ["Upper-clothes", "Left-arm", "Right-arm", "Pants", "Left-leg", "Right-leg"],
    "formal shirt": ["Upper-clothes", "Left-arm", "Right-arm"],
    "formal dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],

    # Footwear maps to shoes
    "shoes": ["Left-shoe", "Right-shoe"],
    "boots": ["Left-shoe", "Right-shoe"],
    "sneakers": ["Left-shoe", "Right-shoe"],
    "sandals": ["Left-shoe", "Right-shoe"],
    "high heels": ["Left-shoe", "Right-shoe"],
    "loafers": ["Left-shoe", "Right-shoe"],
    "flats": ["Left-shoe", "Right-shoe"],

    # Accessories map to their respective parts
    "hat": ["Hat"],
    "cap": ["Hat"],
    "scarf": ["Scarf", "Face", "Hair"],
    "gloves": ["Left-arm", "Right-arm"],
    "belt": ["Belt"],
    "tie": ["Upper-clothes"],
    "socks": ["Left-leg", "Right-leg"],

    # Bags map to Bag
    "handbag": ["Bag"],
    "backpack": ["Bag"],
    "purse": ["Bag"],
    "tote bag": ["Bag"]
}
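These three structures encode a single lookup chain: a Fashion-CLIP label maps to a SegFormer class index, and then either the name-based mapping or the index-based mapping decides which body parts to keep. A small sketch of that chain for one label ("jeans" is just an example):

from constants import (class_names, fashion_clip_to_segformer,
                       category_to_segment_mapping, garment_to_segments)

label = "jeans"                              # example Fashion-CLIP category
seg_idx = fashion_clip_to_segformer[label]   # -> 6
print(class_names[seg_idx])                  # Pants

# Preferred path: the explicit name-based mapping
print(category_to_segment_mapping[label])    # ['Pants', 'Left-leg', 'Right-leg']

# Fallback path: the index-based mapping, resolved back to names
print([class_names[i] for i in garment_to_segments[seg_idx]])  # ['Pants', 'Left-leg', 'Right-leg']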
models.py
ADDED
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation, CLIPProcessor, CLIPModel

# Load the SegFormer model and processor for segmentation
def load_segformer_model():
    """Load and return the SegFormer model and processor"""
    processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
    return processor, model

# Load Fashion-CLIP model for garment classification
def load_clip_model():
    """Load and return the Fashion-CLIP model and processor"""
    model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
    processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
    return model, processor

# Initialize models
segformer_processor, segformer_model = load_segformer_model()
clip_model, clip_processor = load_clip_model()
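Both models are loaded at import time and stay on the CPU, which matches what a free Space provides. If a GPU is available the same module-level objects can be moved over; note that the inference code in segmentation.py and classification.py would then also need to move its input tensors to the same device, which this commit does not do. A minimal sketch under that assumption:

import torch
from models import segformer_model, clip_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# eval() disables dropout and similar training-only behavior;
# .to(device) is a no-op on CPU-only machines
segformer_model.to(device).eval()
clip_model.to(device).eval()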
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch>=1.10.0
torchvision>=0.11.1
transformers>=4.15.0
Pillow>=8.4.0
requests>=2.26.0
matplotlib>=3.5.0
gradio>=3.19.0
numpy>=1.21.0
segmentation.py
ADDED
@@ -0,0 +1,150 @@
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io
from collections import Counter
import gradio as gr

from models import segformer_model, segformer_processor
from constants import class_names, color_map

def segment_image(image, selected_classes=None, show_original=True, show_segmentation=True, show_overlay=True, fixed_size=(400, 400)):
    """Segment the image based on selected classes with consistent output sizes"""
    # Process the image
    inputs = segformer_processor(images=image, return_tensors="pt")

    # Get model predictions
    outputs = segformer_model(**inputs)
    logits = outputs.logits.cpu()

    # Upsample the logits to match the original image size
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # (height, width)
        mode="bilinear",
        align_corners=False,
    )

    # Get the predicted segmentation map
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

    # Filter classes if specified
    if selected_classes and len(selected_classes) > 0:
        # Create a mask for selected classes
        mask = np.zeros_like(pred_seg, dtype=bool)
        for class_name in selected_classes:
            if class_name in class_names:
                class_idx = class_names.index(class_name)
                mask = np.logical_or(mask, pred_seg == class_idx)

        # Apply the mask to keep only selected classes, set others to background (0)
        filtered_seg = np.zeros_like(pred_seg)
        filtered_seg[mask] = pred_seg[mask]
        pred_seg = filtered_seg

    # Create a colored segmentation map
    colored_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3))
    for class_idx in range(len(class_names)):
        mask = pred_seg == class_idx
        if mask.any():
            colored_seg[mask] = color_map(class_idx)[:3]

    # Create an overlay of the segmentation on the original image
    image_array = np.array(image)
    overlay = image_array.copy()
    alpha = 0.5  # Transparency factor
    mask = pred_seg > 0  # Exclude background
    overlay[mask] = overlay[mask] * (1 - alpha) + colored_seg[mask] * 255 * alpha

    # Prepare output images based on user selection
    outputs = []

    if show_original:
        # Resize original image to ensure consistent size
        resized_original = image.resize(fixed_size)
        outputs.append(resized_original)

    if show_segmentation:
        seg_image = Image.fromarray((colored_seg * 255).astype('uint8'))
        # Ensure segmentation has consistent size
        seg_image = seg_image.resize(fixed_size)
        outputs.append(seg_image)

    if show_overlay:
        overlay_image = Image.fromarray(overlay.astype('uint8'))
        # Ensure overlay has consistent size
        overlay_image = overlay_image.resize(fixed_size)
        outputs.append(overlay_image)

    # Create a legend for the segmentation classes
    fig, ax = plt.subplots(figsize=(10, 2))
    fig.patch.set_alpha(0.0)
    ax.axis('off')

    # Create legend patches
    legend_elements = []
    for i, class_name in enumerate(class_names):
        if i == 0 and selected_classes:  # Skip background in legend if filtering
            continue
        if not selected_classes or class_name in selected_classes:
            color = color_map(i)[:3]
            legend_elements.append(plt.Rectangle((0, 0), 1, 1, color=color))

    # Only add legend if there are elements to show
    if legend_elements:
        legend_class_names = [name for name in class_names if name != "Background" and (not selected_classes or name in selected_classes)]
        ax.legend(legend_elements, legend_class_names, loc='center', ncol=6)

    # Save the legend to a bytes buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', transparent=True)
    buf.seek(0)
    legend_img = Image.open(buf)

    plt.close(fig)

    outputs.append(legend_img)

    return outputs

def identify_garment_segformer(image):
    """Identify the dominant garment type using SegFormer"""
    # Process the image
    inputs = segformer_processor(images=image, return_tensors="pt")

    # Get model predictions
    outputs = segformer_model(**inputs)
    logits = outputs.logits.cpu()

    # Upsample the logits to match the original image size
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # (height, width)
        mode="bilinear",
        align_corners=False,
    )

    # Get the predicted segmentation map
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

    # Count the pixels for each class (excluding background)
    class_counts = Counter(pred_seg.flatten())
    if 0 in class_counts:  # Remove background
        del class_counts[0]

    # Find the most common clothing item
    clothing_classes = [4, 5, 6, 7]  # Upper-clothes, Skirt, Pants, Dress

    # Filter to only include clothing items
    clothing_counts = {k: v for k, v in class_counts.items() if k in clothing_classes}

    if not clothing_counts:
        return "No garment detected", None

    # Get the most common clothing item
    dominant_class = max(clothing_counts.items(), key=lambda x: x[1])[0]
    dominant_class_name = class_names[dominant_class]

    return dominant_class_name, dominant_class
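segment_image returns plain PIL images (original, segmentation map, overlay, plus a legend), so its output can be inspected without the Gradio UI. A minimal sketch that writes the four outputs to disk, assuming photo.jpg is a hypothetical local image:

from PIL import Image
from segmentation import segment_image

image = Image.open("photo.jpg").convert("RGB")   # hypothetical input

# Keep only upper-body classes; with the default flags this returns
# [original, segmentation, overlay, legend]
results = segment_image(image, ["Upper-clothes", "Left-arm", "Right-arm"])

for name, img in zip(["original", "segmentation", "overlay", "legend"], results):
    img.save(f"{name}.png")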
utils.py
ADDED
@@ -0,0 +1,34 @@
import requests
from PIL import Image
import gradio as gr

from segmentation import segment_image
from classification import get_segments_for_garment

def process_url(url, selected_classes, show_original, show_segmentation, show_overlay, fixed_size=(400, 400)):
    """Process an image from a URL"""
    try:
        image = Image.open(requests.get(url, stream=True).raw)
        return segment_image(image, selected_classes, show_original, show_segmentation, show_overlay, fixed_size)
    except Exception as e:
        return [gr.update(value=None)] * 4, f"Error: {str(e)}"

def process_person_and_garment(person_image, garment_image, show_original, show_segmentation, show_overlay, fixed_size=(400, 400)):
    """Process person and garment images for targeted segmentation"""
    if person_image is None or garment_image is None:
        return [gr.update(value=None)] * 4, "Please provide both person and garment images"

    try:
        # Get segments that should be included based on the garment
        selected_class, segformer_idx, result_text = get_segments_for_garment(garment_image)

        if selected_class is None:
            return [gr.update(value=None)] * 4, result_text

        # Process the person image with the selected garment classes
        result_images = segment_image(person_image, selected_class, show_original, show_segmentation, show_overlay, fixed_size)

        return result_images, result_text

    except Exception as e:
        return [gr.update(value=None)] * 4, f"Error: {str(e)}"
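process_url is not wired into the interface in app.py, but it can be called directly to segment a photo fetched over HTTP. A minimal sketch, where the URL is only a placeholder for any publicly reachable image of a person; on success the return value is the list of PIL images produced by segment_image, and on failure it is the (updates, error-message) tuple shown above:

from utils import process_url

results = process_url(
    "https://example.com/person.jpg",            # placeholder URL
    ["Upper-clothes", "Left-arm", "Right-arm"],  # segment only the top and arms
    show_original=True,
    show_segmentation=True,
    show_overlay=True,
)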