File size: 6,812 Bytes
ed33d5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9374fca
cbc8a5e
3f264c3
 
 
 
 
ed33d5b
e925709
 
3f264c3
 
 
7739f9e
3f264c3
e925709
 
9374fca
93bb675
 
 
9374fca
2f457aa
93bb675
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import gradio as gr
import torch
from transformers import SwinForImageClassification, AutoFeatureExtractor
import mediapipe as mp
import cv2
from PIL import Image
import numpy as np
import os

# Face shape descriptions
# Maps each face-shape label to an Indonesian sentence fragment that is
# appended after "Kamu memiliki bentuk wajah" in the explanation text.
face_shape_descriptions = dict(
    Heart="dengan dahi lebar dan dagu yang runcing.",
    Oblong="yang lebih panjang dari lebar dengan garis pipi lurus.",
    Oval="dengan proporsi seimbang dan dagu sedikit melengkung.",
    Round="dengan garis rahang melengkung dan pipi penuh.",
    Square="dengan rahang tegas dan dahi lebar.",
)

# Frame images path
# Maps each frame-style name to the relative path of its catalogue picture.
glasses_images = dict([
    ("Oval", "glasses/oval.jpg"),
    ("Round", "glasses/round.jpg"),
    ("Square", "glasses/square.jpg"),
    ("Octagon", "glasses/octagon.jpg"),
    ("Cat Eye", "glasses/cat eye.jpg"),
    ("Pilot (Aviator)", "glasses/aviator.jpg"),
])

# Bootstrap the 'glasses' directory on first run.
# NOTE: placeholders are only generated when the directory itself is absent;
# an existing directory with missing pictures is left untouched.
if not os.path.exists("glasses"):
    os.makedirs("glasses")
    for placeholder_path in glasses_images.values():
        if os.path.exists(placeholder_path):
            continue
        # Plain gray 200x100 stand-in so the gallery has something to show.
        Image.new('RGB', (200, 100), color='gray').save(placeholder_path)

# Class index -> face-shape label, in the order used by the classifier head,
# plus the inverse lookup (label -> class index).
id2label = dict(enumerate(['Heart', 'Oblong', 'Oval', 'Round', 'Square']))
label2id = {label: index for index, label in id2label.items()}

# Load model
# Inference runs on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

# The label maps give the model a 5-class head; ignore_mismatched_sizes lets
# the checkpoint load even though its own head has a different size.
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True
)

# Load your trained weights
# Ensure 'LR-0001-adamW-32-64swin.pth' is in the same directory or provide the correct path
weights_path = 'LR-0001-adamW-32-64swin.pth'
if os.path.exists(weights_path):
    state_dict = torch.load(weights_path, map_location=device)
    # strict=False: keys that do not match the model are skipped, not fatal.
    model.load_state_dict(state_dict, strict=False)
else:
    print(f"Warning: Trained weights file '{weights_path}' not found. Using pre-trained weights only.")

# Always place the model on the inference device and switch to eval mode,
# whether or not the fine-tuned weights were found. predict() moves its input
# tensor to `device`, so a model left on the CPU would fail on CUDA hosts.
model.to(device)
model.eval()

# Initialize Mediapipe
# One face detector is created at import time and reused by preprocess_image()
# for every request. Detections below 0.5 confidence are discarded.
# NOTE(review): model_selection=1 presumably selects MediaPipe's full-range
# model - confirm against the MediaPipe documentation.
mp_face_detection = mp.solutions.face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)

# --- Frame recommendation lookup (face shape -> suitable frame styles)
def recommend_glasses_tree(face_shape):
    """Return the recommended frame styles for a face shape.

    Args:
        face_shape: Face-shape label; matching is case-insensitive.

    Returns:
        A new list of frame-style names in order of listing, or an empty
        list when the shape is not one of the five known labels.
    """
    frames_by_shape = {
        "square": ("Oval", "Round"),
        "round": ("Square", "Octagon", "Cat Eye"),
        "oval": ("Oval", "Pilot (Aviator)", "Cat Eye", "Round"),
        "heart": ("Pilot (Aviator)", "Cat Eye", "Round"),
        "oblong": ("Square", "Oval", "Pilot (Aviator)", "Cat Eye"),
    }
    # Copy into a list so every caller gets its own mutable result.
    return list(frames_by_shape.get(face_shape.lower(), ()))

def _bbox_to_pixel_box(bbox, width, height):
    """Convert a relative bounding box to pixel corners clamped to the image."""
    x1 = max(0, int(bbox.xmin * width))
    y1 = max(0, int(bbox.ymin * height))
    x2 = min(width, int((bbox.xmin + bbox.width) * width))
    y2 = min(height, int((bbox.ymin + bbox.height) * height))
    return x1, y1, x2, y2


# Preprocess function
def preprocess_image(image):
    """Crop the first detected face and turn it into model input.

    Args:
        image: The uploaded picture (a PIL image from the Gradio input, or
            anything ``np.array`` can convert to an RGB uint8 array).

    Returns:
        The ``pixel_values`` tensor produced by ``feature_extractor`` for the
        224x224 face crop, with the leading batch dimension squeezed out.

    Raises:
        ValueError: If no image was supplied, no face is detected, or the
            detected box does not overlap the image.
    """
    if image is None:
        # Happens when the button is clicked before a photo is uploaded.
        raise ValueError("Gambar belum diunggah.")

    # Normalise PIL inputs (grayscale, palette, ...) to three RGB channels so
    # the colour conversions below always receive a valid layout.
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    img = np.array(image, dtype=np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # The detector is fed RGB data, hence the conversion back from BGR.
    results = mp_face_detection.process(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    if not results.detections:
        raise ValueError("Wajah tidak terdeteksi.")

    # Only the first detection is used.
    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = img.shape[:2]

    # Relative boxes may extend past the frame; unclamped negative values
    # would wrap around in NumPy slicing and yield a wrong or empty crop.
    x1, y1, x2, y2 = _bbox_to_pixel_box(bbox, w, h)
    if x2 <= x1 or y2 <= y1:
        raise ValueError("Wajah tidak terdeteksi.")

    img = img[y1:y2, x1:x2]

    img = cv2.resize(img, (224, 224))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    inputs = feature_extractor(images=img, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)

# Prediction function
def predict(image):
    """Classify the face shape in *image* and recommend glasses frames.

    Args:
        image: The uploaded picture, passed straight to ``preprocess_image``.

    Returns:
        A ``(label, explanation, gallery_items)`` tuple: the predicted
        face-shape label, an Indonesian explanation text, and a list of
        ``(PIL image, frame name)`` pairs for the gallery. On any failure the
        tuple is ``("Error", <message>, [])`` instead of raising.
    """
    try:
        batch = preprocess_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(batch).logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)
            pred_idx = torch.argmax(probabilities, dim=1).item()
            pred_label = id2label[pred_idx]
            pred_prob = probabilities[0][pred_idx].item() * 100

        # Frame styles come from the rule-based lookup, not from the model.
        frame_recommendations = recommend_glasses_tree(pred_label)
        description = face_shape_descriptions.get(pred_label, "tidak dikenali")

        # Collect a captioned picture for every recommended frame whose image
        # file is present; frames without a picture are silently skipped.
        gallery_items = []
        for frame in frame_recommendations:
            picture_path = glasses_images.get(frame)
            if not picture_path or not os.path.exists(picture_path):
                continue
            try:
                gallery_items.append((Image.open(picture_path), frame))
            except Exception as e:
                print(f"Error loading image for {frame}: {e}")

        # Both variants of the explanation start with the same headline.
        headline = f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
        if not frame_recommendations:
            explanation = headline + "Tidak ada rekomendasi frame untuk bentuk wajah ini."
        else:
            recommended_frames_text = ', '.join(frame_recommendations)
            explanation = (
                headline
                + f"Kamu memiliki bentuk wajah {description} "
                + f"Rekomendasi bentuk kacamata yang sesuai dengan wajah kamu adalah: {recommended_frames_text}."
            )

        return pred_label, explanation, gallery_items

    except ValueError as ve:
        # Expected failures (e.g. no face found) carry a user-facing message.
        return "Error", str(ve), []
    except Exception as e:
        # UI boundary: report anything else instead of crashing the app.
        return "Error", f"Terjadi kesalahan: {str(e)}", []

# Gradio Interface
def _reset_interface():
    """Return cleared values for the image, both text boxes and the gallery.

    The order matches the ``outputs`` list of the restart button: the image
    input gets ``None``, each text box an empty string, the gallery an empty
    list.
    """
    return None, "", "", []


with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Program Rekomendasi Kacamata Berdasarkan Bentuk Wajah")
    gr.Markdown("Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai.")

    with gr.Row():
        # Left column: photo upload and action buttons.
        with gr.Column():
            image_input = gr.Image(type="pil")
            confirm_button = gr.Button("Konfirmasi")
            restart_button = gr.Button("Restart")
        # Right column: prediction results.
        with gr.Column():
            detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi")
            explanation_output = gr.Textbox(label="Penjelasan")
            recommendation_gallery = gr.Gallery(label="Rekomendasi Kacamata", columns=3, show_label=False)

    confirm_button.click(predict, inputs=image_input, outputs=[detected_shape, explanation_output, recommendation_gallery])
    # The explanation box must be reset with a string; a list would be handed
    # to the Textbox as its value instead of clearing it.
    restart_button.click(_reset_interface, inputs=None, outputs=[image_input, detected_shape, explanation_output, recommendation_gallery])

    # Add source statement under the gallery
    gr.Markdown("**Sumber gambar kacamata**: Katalog dari [glassdirect.co.uk](https://www.glassdirect.co.uk)")

if __name__ == "__main__":
    iface.launch()