File size: 1,934 Bytes
dfa48cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import streamlit as st
import torch
from torch import nn
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

@st.cache_resource
def load_model():
    """Load the Segformer face-parsing processor and model.

    Streamlit re-executes this script from the top on every widget
    interaction; ``st.cache_resource`` makes the (slow) weight loading and
    device transfer happen once per server process instead of on every rerun.

    Returns:
        tuple: ``(processor, model, device)``. ``device`` is ``"cuda"``,
        ``"mps"`` or ``"cpu"`` (the first available, in that order), and
        ``model`` has already been moved onto it.
    """
    checkpoint = "jonathandinu/face-parsing"
    processor = SegformerImageProcessor.from_pretrained(checkpoint)
    model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)

    # Prefer CUDA, then Apple-silicon MPS, and fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    model.to(device)
    # from_pretrained already returns the model in eval mode; calling it here
    # makes the inference-only intent explicit.
    model.eval()
    return processor, model, device

def process_image(image: "Image.Image", processor, model, device) -> np.ndarray:
    """Run face-parsing inference on *image* and return a per-pixel label map.

    Args:
        image: RGB PIL image to segment.
        processor: Segformer image processor that builds the model inputs.
        model: Segformer semantic-segmentation model, already on ``device``.
        device: Torch device string the inputs are moved to.

    Returns:
        A ``(height, width)`` integer array of predicted class ids at the
        original image resolution.
    """
    # Inference only: disabling autograd avoids recording the graph and
    # keeping every intermediate activation alive for a backward pass that
    # never happens.
    with torch.inference_mode():
        inputs = processor(images=image, return_tensors="pt").to(device)
        logits = model(**inputs).logits
        # Resize logits to the image's (height, width) before taking argmax;
        # PIL's ``size`` is (width, height), hence the reversal.
        upsampled_logits = nn.functional.interpolate(
            logits, size=image.size[::-1], mode="bilinear", align_corners=False
        )
        # Class axis is dim 1; [0] drops the single-image batch dimension.
        labels = upsampled_logits.argmax(dim=1)[0]
        return labels.cpu().numpy()

def visualize_segmentation(labels: np.ndarray):
    """Render a segmentation label map in the Streamlit page via Matplotlib.

    Args:
        labels: 2-D array of integer class ids, one per pixel (the output of
            ``process_image`` in this app).
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(labels, cmap="jet", alpha=0.7)
        ax.axis("off")
        st.pyplot(fig)
    finally:
        # pyplot keeps every open figure in a global registry; since Streamlit
        # reruns this script on each interaction, unclosed figures would
        # accumulate and leak memory.
        plt.close(fig)

# Streamlit UI
st.title("Face Parsing using Segformer")
st.write("Upload an image to perform semantic segmentation on faces.")

# Load model
processor, model, device = load_model()

# File uploader
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    try:
        # convert() forces PIL to decode the pixel data, so corrupt or
        # truncated uploads fail here rather than later during inference.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot decode; report it instead of showing a raw traceback.
        st.error(f"Could not read the uploaded file as an image: {err}")
    else:
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Process image; the spinner covers only the model inference.
        with st.spinner("Processing..."):
            labels = process_image(image, processor, model, device)

        # Display result
        visualize_segmentation(labels)