import streamlit as st import torch from torch import nn from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation from PIL import Image import numpy as np import matplotlib.pyplot as plt def load_model(): """Load the Segformer model and processor.""" processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing") model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing") device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" model.to(device) return processor, model, device def process_image(image: Image.Image, processor, model, device): """Run inference on the image and return the segmentation mask.""" inputs = processor(images=image, return_tensors="pt").to(device) outputs = model(**inputs) logits = outputs.logits upsampled_logits = nn.functional.interpolate( logits, size=image.size[::-1], mode="bilinear", align_corners=False ) labels = upsampled_logits.argmax(dim=1)[0].cpu().numpy() return labels def visualize_segmentation(labels: np.ndarray): """Visualize segmentation mask using Matplotlib.""" fig, ax = plt.subplots() ax.imshow(labels, cmap="jet", alpha=0.7) ax.axis("off") st.pyplot(fig) # Streamlit UI st.title("Face Parsing using Segformer") st.write("Upload an image to perform semantic segmentation on faces.") # Load model processor, model, device = load_model() # File uploader uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"]) if uploaded_file: image = Image.open(uploaded_file).convert("RGB") st.image(image, caption="Uploaded Image", use_column_width=True) # Process image with st.spinner("Processing..."): labels = process_image(image, processor, model, device) # Display result visualize_segmentation(labels)