TheOneReborn's picture
fix: typo in method
8103efe
import gradio
import numpy
from matplotlib import cm
from pathlib import Path
from PIL import Image
from fastai.vision.all import load_learner, PILImage, PILMask
# Directory (relative to the current working directory) holding the exported
# model file that is loaded below.
MODEL_PATH = Path('.') / 'models'
# Directory whose entries are listed as example inputs in the Gradio interface.
TEST_IMAGES_PATH = Path('.') / 'test'
def preprocess_mask(file_name):
    """Load the mask belonging to an image as an opaque grayscale ``PILMask``.

    The mask is looked up by ``file_name.name`` inside the car-segmentation
    masks directory. Pixels with alpha 0 become opaque black; every other
    pixel keeps its colour and is made fully opaque. The result is then
    reduced to single-channel ``'L'`` mode.

    Args:
        file_name: Path-like object exposing a ``.name`` attribute; only the
            final path component is used to locate the mask.

    Returns:
        The ``PILMask`` created from the grayscale mask image.
    """
    # The original literal read '/kaggle/inumpyut/...', which looks like the
    # result of a global "np" -> "numpy" replacement applied to "input".
    masks_dir = Path('/kaggle/input/car-segmentation/car-segmentation/masks')
    mask_path = masks_dir / file_name.name

    # convert() loads the pixel data and returns a new image, so the file
    # handle can be released as soon as the conversion is done. A single
    # conversion also covers palette ('P') images; no separate branch needed.
    with Image.open(mask_path) as source:
        mask = source.convert('RGBA')

    mask.putdata([
        (r, g, b, 255) if a > 0 else (0, 0, 0, 255)
        for r, g, b, a in mask.getdata()
    ])
    return PILMask.create(mask.convert('L'))
# Load the exported learner once at import time; segment_image reuses this
# single instance for every prediction.
# NOTE(review): load_learner deserializes a pickle, so only trusted model files
# should be placed here. Presumably the export references preprocess_mask by
# name, which would be why that function is defined in this module -- confirm.
LEARNER = load_learner(MODEL_PATH / 'car-segmentation_v1.pkl')
def segment_image(image):
    """Return ``image`` blended with a colour-mapped segmentation mask.

    The image is resized to 400x400 for the model, the predicted mask is
    scaled back to the original size with nearest-neighbour resampling,
    coloured with the ``jet`` colormap and blended 30/70 with the original.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when no image was provided.

    Returns:
        ``numpy.uint8`` array of shape ``(height, width, 3)`` holding the
        overlay, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    # Public colormap registry; replaces the private ``cm._colormaps``.
    from matplotlib import colormaps

    # The blend below needs exactly three channels, so normalise the mode
    # (grayscale or RGBA inputs would otherwise fail to broadcast).
    rgb_image = image.convert('RGB')
    original_size = rgb_image.size  # (width, height)

    # The model works on a fixed 400x400 input.
    model_input = PILImage.create(rgb_image.resize((400, 400)))
    prediction, _, _ = LEARNER.predict(model_input)
    prediction_array = numpy.asarray(prediction, dtype=numpy.uint8)

    # NEAREST keeps the mask values discrete while scaling back up.
    mask_image = Image.fromarray(prediction_array).resize(original_size, Image.NEAREST)
    mask = numpy.array(mask_image)

    # Scale to [0, 1]. When nothing was segmented the maximum is 0; skip the
    # division (it would yield NaN) but still pass floats, because integer
    # input is interpreted by the colormap as lookup-table indices.
    peak = mask.max()
    normalized = mask / peak if peak > 0 else mask.astype(numpy.float64)
    colored_mask = colormaps['jet'](normalized)[:, :, :3]  # drop alpha

    image_array = numpy.array(rgb_image).astype(numpy.float32) / 255.0
    overlay = (image_array * 0.7) + (colored_mask * 0.3)
    return (overlay * 255).astype(numpy.uint8)
# Build the web UI: one PIL image in, one NumPy image (the overlay) out.
# Examples are limited to regular files and sorted so the list shown in the
# UI is stable across runs (iterdir() yields entries in arbitrary order and
# also returns sub-directories).
demo = gradio.Interface(
    segment_image,
    inputs=gradio.Image(type='pil'),
    outputs=gradio.Image(type='numpy'),
    examples=sorted(
        str(path) for path in TEST_IMAGES_PATH.iterdir() if path.is_file()
    ),
)
# Start the server at import time, as the original script does.
demo.launch()