# Spaces:
# Running
# Running
import gradio as gr | |
import torch | |
from transformers import SwinForImageClassification, AutoFeatureExtractor | |
import mediapipe as mp | |
import cv2 | |
from PIL import Image | |
import numpy as np | |
import os | |
# Indonesian description of each face shape, keyed by the classifier's labels.
# Each phrase completes the sentence "Kamu memiliki bentuk wajah ..." that
# predict() builds for the user.
face_shape_descriptions = dict(
    Heart="dengan dahi lebar dan dagu yang runcing.",
    Oblong="yang lebih panjang dari lebar dengan garis pipi lurus.",
    Oval="dengan proporsi seimbang dan dagu sedikit melengkung.",
    Round="dengan garis rahang melengkung dan pipi penuh.",
    Square="dengan rahang tegas dan dahi lebar.",
)

# Preview image for every frame style, keyed by the frame names that
# recommend_glasses_tree() returns. Paths are relative to the working directory.
glasses_images = {
    "Oval": "glasses/oval.jpg",
    "Round": "glasses/round.jpg",
    "Square": "glasses/square.jpg",
    "Octagon": "glasses/octagon.jpg",
    "Cat Eye": "glasses/cat eye.jpg",
    "Pilot (Aviator)": "glasses/aviator.jpg",
}
# Make sure every frame listed in glasses_images has an image on disk.
# A missing file is replaced by a plain gray 200x100 placeholder so the gallery
# can still render. A warning is printed because a placeholder means the real
# catalogue image was not deployed alongside the app.
os.makedirs("glasses", exist_ok=True)
for placeholder_path in glasses_images.values():
    if not os.path.exists(placeholder_path):
        print(f"Warning: frame image '{placeholder_path}' not found. Creating a gray placeholder.")
        Image.new('RGB', (200, 100), color='gray').save(placeholder_path)
# Class index <-> label mappings for the 5-way face-shape classifier head.
# NOTE(review): presumably this index order must match the label order used when
# the fine-tuned weights were trained -- confirm before reordering.
id2label = dict(enumerate(("Heart", "Oblong", "Oval", "Round", "Square")))
label2id = {label: index for index, label in id2label.items()}
# --- Model and face-detector initialisation ---------------------------------
# Run inference on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
# ignore_mismatched_sizes=True lets the checkpoint's classification head (sized
# for a different label set) be replaced by a freshly initialised head for the
# labels in id2label. That head only becomes meaningful once the fine-tuned
# weights below are loaded.
model = SwinForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

# Fine-tuned weights, looked up relative to the working directory.
weights_path = 'LR-0001-adamW-32-64swin.pth'
if os.path.exists(weights_path):
    # NOTE(review): torch.load unpickles the file -- only ever load a trusted checkpoint.
    state_dict = torch.load(weights_path, map_location=device)
    # strict=False tolerates key mismatches; report them instead of hiding them,
    # since a mismatch can leave layers (e.g. the new head) untrained.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"Warning: '{weights_path}' did not match the model exactly. "
              f"Missing keys: {incompatible.missing_keys}; "
              f"unexpected keys: {incompatible.unexpected_keys}")
else:
    print(f"Warning: Trained weights file '{weights_path}' not found. "
          "Using pre-trained backbone with an untrained classification head; "
          "predictions will not be meaningful.")

# Always move the model to the inference device and switch to eval mode,
# whether or not the fine-tuned weights were found. predict() sends its inputs
# to `device`, and train-time stochastic layers must be disabled for inference.
model.to(device)
model.eval()

# MediaPipe face detector used by preprocess_image(). model_selection=1 picks
# MediaPipe's full-range detection model; detections below 0.5 confidence are dropped.
mp_face_detection = mp.solutions.face_detection.FaceDetection(
    model_selection=1,
    min_detection_confidence=0.5,
)
# --- Rule table mapping a face shape to recommended frame styles
def recommend_glasses_tree(face_shape):
    """Return the frame styles recommended for a face shape.

    The lookup is case-insensitive. The returned names correspond to keys of
    ``glasses_images``.

    Args:
        face_shape: A face-shape label such as "Oval" (any letter case).

    Returns:
        A new list of frame-style names in recommendation order, or an empty
        list when the shape is not recognised.
    """
    recommendations_by_shape = {
        "square": ("Oval", "Round"),
        "round": ("Square", "Octagon", "Cat Eye"),
        "oval": ("Oval", "Pilot (Aviator)", "Cat Eye", "Round"),
        "heart": ("Pilot (Aviator)", "Cat Eye", "Round"),
        "oblong": ("Square", "Oval", "Pilot (Aviator)", "Cat Eye"),
    }
    shape_key = face_shape.lower()
    return list(recommendations_by_shape.get(shape_key, ()))
# Preprocess function
def preprocess_image(image):
    """Detect the first face in *image*, crop it and convert it to model input.

    Args:
        image: The uploaded photo. This is a PIL image when it comes from
            ``gr.Image(type="pil")``, or any array-like accepted by ``np.array``.
            Greyscale and RGBA inputs are converted to 3-channel RGB.

    Returns:
        The ``pixel_values`` tensor that ``feature_extractor`` produces for the
        cropped face, with the leading batch dimension removed.

    Raises:
        ValueError: If no image was given, no face is detected, or the detected
            face box lies entirely outside the image.
    """
    if image is None:
        raise ValueError("Gambar belum diunggah. Silakan unggah foto wajah terlebih dahulu.")
    if isinstance(image, Image.Image):
        # Normalise PIL modes such as "L", "P" or "RGBA" to 3-channel RGB.
        image = image.convert("RGB")
    img = np.array(image, dtype=np.uint8)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    # The array stays in RGB throughout: MediaPipe expects RGB, cv2.resize does
    # not care about channel order, and the feature extractor is given RGB.
    results = mp_face_detection.process(img)
    if not results.detections:
        raise ValueError("Wajah tidak terdeteksi.")

    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = img.shape[:2]
    # The relative box can extend past the image edges (e.g. a face cut off by
    # the frame). Clamp it, otherwise negative indices wrap around in the slice
    # and crop the wrong region (or nothing at all).
    x1 = min(max(int(bbox.xmin * w), 0), w)
    y1 = min(max(int(bbox.ymin * h), 0), h)
    x2 = min(max(int((bbox.xmin + bbox.width) * w), 0), w)
    y2 = min(max(int((bbox.ymin + bbox.height) * h), 0), h)
    if x2 <= x1 or y2 <= y1:
        raise ValueError("Wajah tidak terdeteksi.")

    face = cv2.resize(img[y1:y2, x1:x2], (224, 224))
    inputs = feature_extractor(images=face, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)
def _load_frame_gallery(frame_names):
    """Return ``(image, caption)`` pairs for every frame whose preview image loads."""
    gallery_items = []
    for frame in frame_names:
        frame_image_path = glasses_images.get(frame)
        if not frame_image_path or not os.path.exists(frame_image_path):
            continue
        try:
            # Image.open is lazy and keeps the file open. copy() loads the pixels
            # into memory so the context manager can close the handle immediately.
            with Image.open(frame_image_path) as opened:
                frame_image = opened.copy()
        except Exception as e:
            # Best effort: a broken preview image should not fail the prediction.
            print(f"Error loading image for {frame}: {e}")
            continue
        gallery_items.append((frame_image, frame))  # caption = frame name
    return gallery_items


def _build_explanation(pred_label, pred_prob, frame_recommendations):
    """Compose the Indonesian result text shown in the "Penjelasan" box."""
    description = face_shape_descriptions.get(pred_label, "tidak dikenali")
    if frame_recommendations:
        recommended_frames_text = ', '.join(frame_recommendations)
        return (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
                f"Kamu memiliki bentuk wajah {description} "
                f"Rekomendasi bentuk kacamata yang sesuai dengan wajah kamu adalah: {recommended_frames_text}.")
    return (f"Bentuk wajah kamu adalah {pred_label} ({pred_prob:.2f}%). "
            "Tidak ada rekomendasi frame untuk bentuk wajah ini.")


# Prediction function
def predict(image):
    """Classify the face shape in *image* and recommend matching frames.

    Args:
        image: The uploaded photo, passed straight to ``preprocess_image``.

    Returns:
        A tuple ``(label, explanation, gallery_items)`` for the three Gradio
        outputs. On failure it returns ``("Error", message, [])``.
    """
    try:
        inputs = preprocess_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()
            pred_prob = probs[0, pred_idx].item() * 100
        pred_label = id2label[pred_idx]

        # --- Use the rule table for recommendations
        frame_recommendations = recommend_glasses_tree(pred_label)
        gallery_items = _load_frame_gallery(frame_recommendations)
        explanation = _build_explanation(pred_label, pred_prob, frame_recommendations)
        return pred_label, explanation, gallery_items
    except ValueError as ve:
        # Expected, user-facing failures (e.g. no face detected).
        return "Error", str(ve), []
    except Exception as e:
        # Top-level UI boundary: show the user a message, but keep the full
        # traceback in the server log instead of swallowing it.
        import traceback
        traceback.print_exc()
        return "Error", f"Terjadi kesalahan: {str(e)}", []
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Program Rekomendasi Kacamata Berdasarkan Bentuk Wajah")
    gr.Markdown("Upload foto wajahmu untuk mendapatkan rekomendasi bentuk kacamata yang sesuai.")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            confirm_button = gr.Button("Konfirmasi")
            restart_button = gr.Button("Restart")
        with gr.Column():
            detected_shape = gr.Textbox(label="Bentuk Wajah Terdeteksi")
            explanation_output = gr.Textbox(label="Penjelasan")
            recommendation_gallery = gr.Gallery(label="Rekomendasi Kacamata", columns=3, show_label=False)

    confirm_button.click(
        predict,
        inputs=image_input,
        outputs=[detected_shape, explanation_output, recommendation_gallery],
    )
    # One reset value per output, in order: clear the upload, empty both
    # Textboxes (strings, not lists) and empty the gallery.
    restart_button.click(
        lambda: (None, "", "", []),
        inputs=None,
        outputs=[image_input, detected_shape, explanation_output, recommendation_gallery],
    )

    # Source statement shown under the results
    gr.Markdown("**Sumber gambar kacamata**: Katalog dari [glassdirect.co.uk](https://www.glassdirect.co.uk)")

if __name__ == "__main__":
    iface.launch()