Cero72 committed
Commit 2dc21ea · 1 Parent(s): c3509c8

Add garment segmentation application files

Files changed (8):
  1. README.md +25 -6
  2. app.py +138 -0
  3. classification.py +71 -0
  4. constants.py +168 -0
  5. models.py +21 -0
  6. requirements.txt +8 -0
  7. segmentation.py +150 -0
  8. utils.py +34 -0
README.md CHANGED
@@ -1,12 +1,31 @@
  ---
- title: Easel AI Engineering
- emoji: 👁
- colorFrom: blue
- colorTo: pink
+ title: Garment-based Segmentation
+ emoji: 👕
+ colorFrom: green
+ colorTo: blue
  sdk: gradio
- sdk_version: 5.23.0
+ sdk_version: 3.19.0
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Garment-based Segmentation with SegFormer and Fashion-CLIP
+
+ This application uses AI models to segment specific clothing items in images by matching a garment to a person.
+
+ ## Features
+
+ - **Targeted Garment Segmentation**: Segment only the specific garment type that matches a reference image
+ - **Interactive Interface**: User-friendly interface with clear sections for person and garment images
+ - **Visualization Options**: View original images, segmentation maps, and overlays
+
+ ## How to Use
+
+ 1. Upload an image of a person wearing clothes
+ 2. Upload a reference image of a garment (e.g., a t-shirt, pants, or a dress)
+ 3. Click "Process Images" to generate the targeted segmentation
+ 4. View the results in the gallery
+
+ ## Technical Details
+
+ This application combines SegFormer for image segmentation with Fashion-CLIP for garment classification to create a system that can identify and segment specific garment types in images.
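
Under the hood, the flow in this commit reduces to two calls; a minimal sketch using the modules added below (the image filenames are hypothetical):

```python
# Minimal end-to-end sketch of the pipeline added in this commit.
# "person.jpg" and "tshirt.jpg" are hypothetical local files.
from PIL import Image
from classification import get_segments_for_garment
from segmentation import segment_image

person = Image.open("person.jpg").convert("RGB")
garment = Image.open("tshirt.jpg").convert("RGB")

# Fashion-CLIP classifies the garment and expands it to related segment names
selected_classes, _, summary = get_segments_for_garment(garment)
print(summary)  # detected category, method used, and segmented parts

if selected_classes is not None:
    # SegFormer segments the person image, keeping only the matched classes
    original, segmentation, overlay, legend = segment_image(person, selected_classes)
```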
app.py ADDED
@@ -0,0 +1,138 @@
+ import gradio as gr
+ import os
+
+ # Import from our modules
+ from constants import class_names
+ from segmentation import segment_image
+ from utils import process_url, process_person_and_garment
+
+ # Define a fixed size for consistent image display
+ FIXED_IMAGE_SIZE = (400, 400)
+
+ def create_interface():
+     """Create the Gradio interface with improved image consistency"""
+     with gr.Blocks(title="Garment-based Segmentation") as demo:
+         gr.Markdown("""
+         # Garment-based Segmentation with SegFormer and Fashion-CLIP
+
+         This application uses AI models to segment specific clothing items in images by matching a garment to a person.
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 # Person image section
+                 gr.Markdown("### Person Image")
+                 person_image = gr.Image(
+                     type="pil",
+                     label="Upload a person wearing clothes",
+                     height=300,
+                     sources=["upload", "webcam", "clipboard"],
+                     elem_id="person-image-upload"
+                 )
+
+                 # Garment image section
+                 gr.Markdown("### Garment Image")
+                 garment_image = gr.Image(
+                     type="pil",
+                     label="Upload a garment to detect",
+                     height=300,
+                     sources=["upload", "webcam", "clipboard"],
+                     elem_id="garment-image-upload"
+                 )
+
+                 with gr.Row():
+                     show_original_dual = gr.Checkbox(label="Show Original", value=True)
+                     show_segmentation_dual = gr.Checkbox(label="Show Segmentation", value=True)
+                     show_overlay_dual = gr.Checkbox(label="Show Overlay", value=True)
+
+                 process_button = gr.Button(
+                     "Process Images",
+                     variant="primary",
+                     size="lg",
+                     elem_id="process-button"
+                 )
+
+             with gr.Column(scale=2):
+                 dual_output_images = gr.Gallery(
+                     label="Results",
+                     columns=3,
+                     height=450,
+                     object_fit="contain",
+                     elem_id="dual_gallery"
+                 )
+                 result_text = gr.Textbox(label="Result", interactive=False, lines=4)
+
+         # Set up the event handler for dual image processing
+         process_button.click(
+             fn=lambda p_img, g_img, orig, seg, over: process_person_and_garment(p_img, g_img, orig, seg, over, FIXED_IMAGE_SIZE),
+             inputs=[person_image, garment_image, show_original_dual, show_segmentation_dual, show_overlay_dual],
+             outputs=[dual_output_images, result_text]
+         )
+
+         # Add custom CSS for consistent image sizes and an improved UI
+         gr.HTML("""
+         <style>
+         .gradio-container img {
+             max-height: 400px !important;
+             object-fit: contain !important;
+         }
+         #dual_gallery {
+             min-height: 450px;
+         }
+         /* Larger upload buttons */
+         #person-image-upload .upload-button,
+         #garment-image-upload .upload-button {
+             font-size: 1.2em !important;
+             padding: 12px 20px !important;
+             border-radius: 8px !important;
+             margin: 10px auto !important;
+             display: block !important;
+             width: 80% !important;
+             text-align: center !important;
+             background-color: #4CAF50 !important;
+             color: white !important;
+             border: none !important;
+             cursor: pointer !important;
+             transition: background-color 0.3s ease !important;
+         }
+         #person-image-upload .upload-button:hover,
+         #garment-image-upload .upload-button:hover {
+             background-color: #45a049 !important;
+         }
+         /* Larger process button */
+         #process-button {
+             font-size: 1.3em !important;
+             padding: 15px 25px !important;
+             margin: 15px auto !important;
+             display: block !important;
+             width: 90% !important;
+         }
+         /* Better section headers */
+         h3 {
+             font-size: 1.5em !important;
+             margin-top: 20px !important;
+             margin-bottom: 15px !important;
+             color: #2c3e50 !important;
+             border-bottom: 2px solid #3498db !important;
+             padding-bottom: 8px !important;
+         }
+         /* Better main heading */
+         h1 {
+             color: #2c3e50 !important;
+             text-align: center !important;
+             margin-bottom: 30px !important;
+             font-size: 2.5em !important;
+         }
+         /* Better checkbox layout */
+         .gradio-checkbox {
+             margin: 10px 5px !important;
+         }
+         </style>
+         """)
+
+     return demo
+
+ # Main application entry point
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch()
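
Running `python app.py` starts the interface on Gradio's local development server with the library defaults. A sketch of a more explicit launch, should hosting require it (the host and port shown are Gradio's usual defaults, not values set by this commit):

```python
# Explicit launch settings -- a sketch; the bare demo.launch() above also works.
from app import create_interface

demo = create_interface()
demo.launch(server_name="0.0.0.0", server_port=7860)  # listen on all interfaces, default Gradio port
```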
classification.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+ from models import clip_model, clip_processor, segformer_model, segformer_processor
+ from constants import fashion_categories, fashion_clip_to_segformer, class_names, category_to_segment_mapping, garment_to_segments
+ from segmentation import identify_garment_segformer
+
+ def identify_garment_clip(image):
+     """Identify the garment type using the Fashion-CLIP model"""
+     # Prepare text prompts
+     texts = [f"a photo of a {category}" for category in fashion_categories]
+
+     # Process inputs
+     inputs = clip_processor(text=texts, images=image, return_tensors="pt", padding=True)
+
+     # Get predictions
+     with torch.no_grad():
+         outputs = clip_model(**inputs)
+         logits_per_image = outputs.logits_per_image
+         probs = logits_per_image.softmax(dim=1)
+
+     # Get the top prediction
+     top_idx = torch.argmax(probs[0]).item()
+     top_category = fashion_categories[top_idx]
+     confidence = probs[0][top_idx].item() * 100
+
+     # Map to a SegFormer class if possible
+     if top_category in fashion_clip_to_segformer:
+         segformer_idx = fashion_clip_to_segformer[top_category]
+         segformer_class = class_names[segformer_idx]
+         return top_category, segformer_idx, confidence
+     else:
+         # Fall back to using SegFormer directly
+         return top_category, None, confidence
+
+ def get_segments_for_garment(garment_image):
+     """Get the segments that should be included for a given garment image"""
+     # First try to identify the garment using Fashion-CLIP
+     clip_category, segformer_idx, confidence = identify_garment_clip(garment_image)
+
+     # If CLIP couldn't map to a SegFormer class, fall back to SegFormer
+     if segformer_idx is None:
+         garment_name, segformer_idx = identify_garment_segformer(garment_image)
+         method = "SegFormer (fallback)"
+         confidence_text = ""
+     else:
+         garment_name = class_names[segformer_idx]
+         method = "Fashion-CLIP"
+         confidence_text = f" with {confidence:.2f}% confidence"
+
+     if segformer_idx is None:
+         return None, None, "No clear garment detected in the garment image"
+
+     # Get all segments that should be included based on the detected garment
+     if method == "Fashion-CLIP" and clip_category in category_to_segment_mapping:
+         # Use the detailed mapping for CLIP categories
+         selected_class = category_to_segment_mapping[clip_category]
+     elif segformer_idx in garment_to_segments:
+         # Fall back to the SegFormer class-based mapping
+         segment_indices = garment_to_segments[segformer_idx]
+         selected_class = [class_names[idx] for idx in segment_indices]
+     else:
+         # Fall back to just the detected garment if no mapping exists
+         selected_class = [class_names[segformer_idx]]
+
+     # Prepare a more descriptive result text
+     included_segments = ", ".join(selected_class)
+     if method == "Fashion-CLIP":
+         result_text = f"Detected garment: {clip_category} (mapped to {garment_name})\nUsing {method}{confidence_text}\nSegmented parts: {included_segments}"
+     else:
+         result_text = f"Detected garment: {garment_name}\nUsing {method}\nSegmented parts: {included_segments}"
+
+     return selected_class, segformer_idx, result_text
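
To exercise the classifier in isolation, a small check (the image path is hypothetical):

```python
# Standalone check of the Fashion-CLIP garment classifier.
from PIL import Image
from classification import identify_garment_clip

category, seg_idx, confidence = identify_garment_clip(Image.open("garment.jpg").convert("RGB"))
print(f"{category} ({confidence:.1f}%) -> SegFormer class index {seg_idx}")  # seg_idx is None if unmapped
```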
constants.py ADDED
@@ -0,0 +1,168 @@
+ import matplotlib.pyplot as plt
+
+ # Define the class names for the segmentation model
+ class_names = [
+     "Background", "Hat", "Hair", "Sunglasses", "Upper-clothes", "Skirt", "Pants",
+     "Dress", "Belt", "Left-shoe", "Right-shoe", "Face", "Left-leg", "Right-leg",
+     "Left-arm", "Right-arm", "Bag", "Scarf"
+ ]
+
+ # Define a color map for visualization
+ color_map = plt.cm.get_cmap('tab20', len(class_names))
+
+ # Define a mapping of garment types to related segments that should be included
+ garment_to_segments = {
+     0: [0],                           # Background --> segment background only
+     1: [1, 2, 11],                    # Hat --> segment hat, hair, and face
+     2: [2],                           # Hair --> segment hair only
+     3: [3, 11],                       # Sunglasses --> segment sunglasses and face
+     4: [4, 14, 15],                   # Upper-clothes --> segment upper clothes, left arm, right arm
+     5: [5, 6, 12, 13],                # Skirt --> segment skirt, pants, left leg, right leg
+     6: [6, 12, 13],                   # Pants --> segment pants, left leg, right leg
+     7: [4, 5, 6, 7, 12, 13, 14, 15],  # Dress --> segment the whole body except face and hair
+     8: [8],                           # Belt --> segment belt only
+     9: [9],                           # Left-shoe --> segment left shoe only
+     10: [10],                         # Right-shoe --> segment right shoe only
+     11: [11],                         # Face --> segment face only
+     12: [12],                         # Left-leg --> segment left leg only
+     13: [13],                         # Right-leg --> segment right leg only
+     14: [14],                         # Left-arm --> segment left arm only
+     15: [15],                         # Right-arm --> segment right arm only
+     16: [16],                         # Bag --> segment bag only
+     17: [17, 2, 11]                   # Scarf --> segment scarf, hair, and face
+ }
+
+ # Define categories for Fashion-CLIP
+ fashion_categories = [
+     # Upper body
+     "t-shirt", "shirt", "blouse", "tank top", "polo shirt", "sweatshirt", "hoodie",
+
+     # Outerwear
+     "jacket", "coat", "blazer", "cardigan", "vest", "windbreaker",
+
+     # Dresses
+     "dress", "shirt dress", "sundress", "evening gown", "maxi dress", "mini dress",
+
+     # Lower body
+     "jeans", "pants", "trousers", "shorts", "skirt", "leggings", "joggers", "sweatpants",
+
+     # Formal wear
+     "suit", "tuxedo", "formal shirt", "formal dress",
+
+     # Undergarments
+     "bra", "underwear", "boxers", "briefs", "lingerie",
+
+     # Sleepwear
+     "pajamas", "nightgown", "bathrobe",
+
+     # Swimwear
+     "swimsuit", "bikini", "swim trunks",
+
+     # Footwear
+     "shoes", "boots", "sneakers", "sandals", "high heels", "loafers", "flats",
+
+     # Accessories
+     "hat", "cap", "scarf", "gloves", "belt", "tie", "socks",
+
+     # Bags
+     "handbag", "backpack", "purse", "tote bag"
+ ]
+
+ # Mapping from Fashion-CLIP categories to SegFormer classes
+ fashion_clip_to_segformer = {
+     # Upper body items -> Upper-clothes (4)
+     "t-shirt": 4, "shirt": 4, "blouse": 4, "tank top": 4, "polo shirt": 4, "sweatshirt": 4, "hoodie": 4,
+     "cardigan": 4, "vest": 4, "formal shirt": 4,
+
+     # Outerwear -> Upper-clothes (4)
+     "jacket": 4, "coat": 4, "blazer": 4, "windbreaker": 4,
+
+     # Dresses -> Dress (7)
+     "dress": 7, "shirt dress": 7, "sundress": 7, "evening gown": 7, "maxi dress": 7, "mini dress": 7,
+     "formal dress": 7,
+
+     # Lower body -> Pants (6) or Skirt (5)
+     "jeans": 6, "pants": 6, "trousers": 6, "shorts": 6, "leggings": 6, "joggers": 6, "sweatpants": 6,
+     "skirt": 5,
+
+     # Formal wear -> Upper-clothes (4)
+     "suit": 4, "tuxedo": 4,
+
+     # Footwear -> Left-shoe (9); the paired Right-shoe (10) comes from category_to_segment_mapping below
+     "shoes": 9, "boots": 9, "sneakers": 9, "sandals": 9, "high heels": 9, "loafers": 9, "flats": 9,
+
+     # Accessories
+     "hat": 1, "cap": 1, "scarf": 17, "belt": 8,
+
+     # Bags
+     "handbag": 16, "backpack": 16, "purse": 16, "tote bag": 16
+ }
+
+ # Detailed mapping from categories to segment names
+ category_to_segment_mapping = {
+     # Upper body items map to Upper-clothes and arms
+     "t-shirt": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "shirt": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "blouse": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "tank top": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "polo shirt": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "sweatshirt": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "hoodie": ["Upper-clothes", "Left-arm", "Right-arm"],
+
+     # Outerwear maps to Upper-clothes and arms
+     "jacket": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "coat": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "blazer": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "cardigan": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "vest": ["Upper-clothes"],
+     "windbreaker": ["Upper-clothes", "Left-arm", "Right-arm"],
+
+     # Dresses map to Dress plus the surrounding body segments
+     "dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
+     "shirt dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
+     "sundress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
+     "evening gown": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
+     "maxi dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
+     "mini dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
+     "formal dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],
+
+     # Lower body items map to Pants or Skirt and legs
+     "jeans": ["Pants", "Left-leg", "Right-leg"],
+     "pants": ["Pants", "Left-leg", "Right-leg"],
+     "trousers": ["Pants", "Left-leg", "Right-leg"],
+     "shorts": ["Pants", "Left-leg", "Right-leg"],
+     "skirt": ["Skirt", "Pants", "Left-leg", "Right-leg"],
+     "leggings": ["Pants", "Left-leg", "Right-leg"],
+     "joggers": ["Pants", "Left-leg", "Right-leg"],
+     "sweatpants": ["Pants", "Left-leg", "Right-leg"],
+
+     # Formal wear maps depending on type
+     "suit": ["Upper-clothes", "Left-arm", "Right-arm", "Pants", "Left-leg", "Right-leg"],
+     "tuxedo": ["Upper-clothes", "Left-arm", "Right-arm", "Pants", "Left-leg", "Right-leg"],
+     "formal shirt": ["Upper-clothes", "Left-arm", "Right-arm"],
+     "formal dress": ["Dress", "Upper-clothes", "Skirt", "Pants", "Left-arm", "Right-arm", "Left-leg", "Right-leg"],  # duplicate of the entry under Dresses; the later definition wins
+
+     # Footwear maps to both shoes
+     "shoes": ["Left-shoe", "Right-shoe"],
+     "boots": ["Left-shoe", "Right-shoe"],
+     "sneakers": ["Left-shoe", "Right-shoe"],
+     "sandals": ["Left-shoe", "Right-shoe"],
+     "high heels": ["Left-shoe", "Right-shoe"],
+     "loafers": ["Left-shoe", "Right-shoe"],
+     "flats": ["Left-shoe", "Right-shoe"],
+
+     # Accessories map to their respective parts
+     "hat": ["Hat"],
+     "cap": ["Hat"],
+     "scarf": ["Scarf", "Face", "Hair"],
+     "gloves": ["Left-arm", "Right-arm"],
+     "belt": ["Belt"],
+     "tie": ["Upper-clothes"],
+     "socks": ["Left-leg", "Right-leg"],
+
+     # Bags map to Bag
+     "handbag": ["Bag"],
+     "backpack": ["Bag"],
+     "purse": ["Bag"],
+     "tote bag": ["Bag"]
+ }
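
The two tables compose: a detected class index expands through `garment_to_segments` into the names that `segment_image` keeps. A quick illustration:

```python
# Expanding the SegFormer class "Pants" into the segment names kept in the mask.
from constants import class_names, garment_to_segments

idx = class_names.index("Pants")                           # 6
print([class_names[i] for i in garment_to_segments[idx]])  # ['Pants', 'Left-leg', 'Right-leg']
```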
models.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+ import torch.nn as nn
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation, CLIPProcessor, CLIPModel
+
+ # Load the SegFormer model and processor for segmentation
+ def load_segformer_model():
+     """Load and return the SegFormer model and processor"""
+     processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
+     model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
+     return processor, model
+
+ # Load the Fashion-CLIP model for garment classification
+ def load_clip_model():
+     """Load and return the Fashion-CLIP model and processor"""
+     model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
+     processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
+     return model, processor
+
+ # Initialize the models at import time so downstream modules can share them
+ segformer_processor, segformer_model = load_segformer_model()
+ clip_model, clip_processor = load_clip_model()
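
Both checkpoints are public on the Hugging Face Hub and are downloaded on first import. A hedged smoke test for the segmentation half (the blank test image is arbitrary):

```python
# Smoke test: run SegFormer on a blank image and check the class dimension.
import torch
from PIL import Image
from models import segformer_model, segformer_processor

inputs = segformer_processor(images=Image.new("RGB", (256, 256)), return_tensors="pt")
with torch.no_grad():
    logits = segformer_model(**inputs).logits
print(logits.shape)  # (1, num_classes, h, w); num_classes should be 18 for this checkpoint
```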
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=1.10.0
+ torchvision>=0.11.1
+ transformers>=4.15.0
+ Pillow>=8.4.0
+ requests>=2.26.0
+ matplotlib>=3.5.0
+ gradio>=3.19.0
+ numpy>=1.21.0
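
Install with `pip install -r requirements.txt`. One caveat: `app.py` passes parameters such as `sources=[...]` to `gr.Image`, which newer Gradio releases accept but Gradio 3.19 predates, so the `gradio` pin here (and the `sdk_version` in the README) may need to be raised for the interface to build unchanged.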
segmentation.py ADDED
@@ -0,0 +1,150 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import io
+ from collections import Counter
+ import gradio as gr
+
+ from models import segformer_model, segformer_processor
+ from constants import class_names, color_map
+
+ def segment_image(image, selected_classes=None, show_original=True, show_segmentation=True, show_overlay=True, fixed_size=(400, 400)):
+     """Segment the image based on selected classes, with consistent output sizes"""
+     # Process the image
+     inputs = segformer_processor(images=image, return_tensors="pt")
+
+     # Get model predictions
+     outputs = segformer_model(**inputs)
+     logits = outputs.logits.cpu()
+
+     # Upsample the logits to match the original image size
+     upsampled_logits = nn.functional.interpolate(
+         logits,
+         size=image.size[::-1],  # (height, width)
+         mode="bilinear",
+         align_corners=False,
+     )
+
+     # Get the predicted segmentation map
+     pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
+
+     # Filter classes if specified
+     if selected_classes and len(selected_classes) > 0:
+         # Create a mask for the selected classes
+         mask = np.zeros_like(pred_seg, dtype=bool)
+         for class_name in selected_classes:
+             if class_name in class_names:
+                 class_idx = class_names.index(class_name)
+                 mask = np.logical_or(mask, pred_seg == class_idx)
+
+         # Apply the mask to keep only the selected classes; set the rest to background (0)
+         filtered_seg = np.zeros_like(pred_seg)
+         filtered_seg[mask] = pred_seg[mask]
+         pred_seg = filtered_seg
+
+     # Create a colored segmentation map
+     colored_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3))
+     for class_idx in range(len(class_names)):
+         mask = pred_seg == class_idx
+         if mask.any():
+             colored_seg[mask] = color_map(class_idx)[:3]
+
+     # Create an overlay of the segmentation on the original image
+     image_array = np.array(image)
+     overlay = image_array.copy()
+     alpha = 0.5  # Transparency factor
+     mask = pred_seg > 0  # Exclude background
+     overlay[mask] = overlay[mask] * (1 - alpha) + colored_seg[mask] * 255 * alpha
+
+     # Prepare output images based on the user's selection
+     outputs = []
+
+     if show_original:
+         # Resize the original image to ensure a consistent size
+         resized_original = image.resize(fixed_size)
+         outputs.append(resized_original)
+
+     if show_segmentation:
+         seg_image = Image.fromarray((colored_seg * 255).astype('uint8'))
+         # Ensure the segmentation has a consistent size
+         seg_image = seg_image.resize(fixed_size)
+         outputs.append(seg_image)
+
+     if show_overlay:
+         overlay_image = Image.fromarray(overlay.astype('uint8'))
+         # Ensure the overlay has a consistent size
+         overlay_image = overlay_image.resize(fixed_size)
+         outputs.append(overlay_image)
+
+     # Create a legend for the segmentation classes
+     fig, ax = plt.subplots(figsize=(10, 2))
+     fig.patch.set_alpha(0.0)
+     ax.axis('off')
+
+     # Create legend patches
+     legend_elements = []
+     for i, class_name in enumerate(class_names):
+         if i == 0:  # Always skip background so the patches line up with the labels below
+             continue
+         if not selected_classes or class_name in selected_classes:
+             color = color_map(i)[:3]
+             legend_elements.append(plt.Rectangle((0, 0), 1, 1, color=color))
+
+     # Only add the legend if there are elements to show
+     if legend_elements:
+         legend_class_names = [name for name in class_names if name != "Background" and (not selected_classes or name in selected_classes)]
+         ax.legend(legend_elements, legend_class_names, loc='center', ncol=6)
+
+     # Save the legend to a bytes buffer
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', bbox_inches='tight', transparent=True)
+     buf.seek(0)
+     legend_img = Image.open(buf)
+
+     plt.close(fig)
+
+     outputs.append(legend_img)
+
+     return outputs
+
+ def identify_garment_segformer(image):
+     """Identify the dominant garment type using SegFormer"""
+     # Process the image
+     inputs = segformer_processor(images=image, return_tensors="pt")
+
+     # Get model predictions
+     outputs = segformer_model(**inputs)
+     logits = outputs.logits.cpu()
+
+     # Upsample the logits to match the original image size
+     upsampled_logits = nn.functional.interpolate(
+         logits,
+         size=image.size[::-1],  # (height, width)
+         mode="bilinear",
+         align_corners=False,
+     )
+
+     # Get the predicted segmentation map
+     pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
+
+     # Count the pixels for each class (excluding background)
+     class_counts = Counter(pred_seg.flatten())
+     if 0 in class_counts:  # Remove background
+         del class_counts[0]
+
+     # Find the most common clothing item
+     clothing_classes = [4, 5, 6, 7]  # Upper-clothes, Skirt, Pants, Dress
+
+     # Filter to only include clothing items
+     clothing_counts = {k: v for k, v in class_counts.items() if k in clothing_classes}
+
+     if not clothing_counts:
+         return "No garment detected", None
+
+     # Get the most common clothing item
+     dominant_class = max(clothing_counts.items(), key=lambda x: x[1])[0]
+     dominant_class_name = class_names[dominant_class]
+
+     return dominant_class_name, dominant_class
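
Called directly, `segment_image` returns PIL images ready for a gallery. For example, keeping only the upper-body classes (the input path is hypothetical):

```python
# Direct use of segment_image with an explicit class filter.
from PIL import Image
from segmentation import segment_image

views = segment_image(
    Image.open("person.jpg").convert("RGB"),
    selected_classes=["Upper-clothes", "Left-arm", "Right-arm"],
)
# views == [resized original, colored segmentation, overlay, legend]
```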
utils.py ADDED
@@ -0,0 +1,34 @@
+ import requests
+ from PIL import Image
+ import gradio as gr
+
+ from segmentation import segment_image
+ from classification import get_segments_for_garment
+
+ def process_url(url, selected_classes, show_original, show_segmentation, show_overlay, fixed_size=(400, 400)):
+     """Process an image from a URL"""
+     try:
+         image = Image.open(requests.get(url, stream=True).raw)
+         return segment_image(image, selected_classes, show_original, show_segmentation, show_overlay, fixed_size)
+     except Exception as e:
+         return [gr.update(value=None)] * 4, f"Error: {str(e)}"
+
+ def process_person_and_garment(person_image, garment_image, show_original, show_segmentation, show_overlay, fixed_size=(400, 400)):
+     """Process person and garment images for targeted segmentation"""
+     if person_image is None or garment_image is None:
+         return [gr.update(value=None)] * 4, "Please provide both person and garment images"
+
+     try:
+         # Get the segments that should be included based on the garment
+         selected_class, segformer_idx, result_text = get_segments_for_garment(garment_image)
+
+         if selected_class is None:
+             return [gr.update(value=None)] * 4, result_text
+
+         # Process the person image with the selected garment classes
+         result_images = segment_image(person_image, selected_class, show_original, show_segmentation, show_overlay, fixed_size)
+
+         return result_images, result_text
+
+     except Exception as e:
+         return [gr.update(value=None)] * 4, f"Error: {str(e)}"