from pathlib import Path

import gradio
import numpy
from matplotlib import cm, colormaps
from PIL import Image
from fastai.vision.all import load_learner, PILImage, PILMask

MODEL_PATH = Path('.') / 'models'
TEST_IMAGES_PATH = Path('.') / 'test'
# Directory of ground-truth masks for the car-segmentation dataset
# (Kaggle mounts attached datasets under /kaggle/input).
MASKS_PATH = Path('/kaggle/input/car-segmentation/car-segmentation/masks')
# Resolution images are resized to before prediction; presumably the size the
# model was trained at — TODO confirm against the training notebook.
MODEL_INPUT_SIZE = (400, 400)
# Weight of the coloured mask in the blended output (photo gets 1 - this).
MASK_OPACITY = 0.3


def preprocess_mask(file_name):
    """Load the mask matching *file_name* as a grayscale PILMask without transparency.

    The mask is looked up by the base name of ``file_name`` inside
    ``MASKS_PATH``. Fully transparent pixels (alpha == 0) have their RGB
    channels zeroed, every pixel is made opaque, and the result is converted
    to single-channel ``'L'`` mode.

    NOTE(review): this function is never called in this file; it is most
    likely referenced by the pickled learner's DataBlock (as ``get_y``) and
    must therefore stay defined before ``load_learner`` runs — TODO confirm.

    Args:
        file_name: path (``Path`` or ``str``) of the image whose mask to load.

    Returns:
        A fastai ``PILMask`` in ``'L'`` mode.
    """
    mask_path = MASKS_PATH / Path(file_name).name
    # `with` closes the underlying file handle once the pixels are copied out.
    with Image.open(mask_path) as mask:
        rgba = numpy.array(mask.convert('RGBA'))
    # Transparent pixels become opaque black; every other pixel keeps its
    # colour but is made fully opaque.
    rgba[rgba[..., 3] == 0, :3] = 0
    rgba[..., 3] = 255
    return PILMask.create(Image.fromarray(rgba).convert('L'))


LEARNER = load_learner(MODEL_PATH / 'car-segmentation_v1.pkl')


def segment_image(image):
    """Segment *image* and return it blended with a colour-coded class mask.

    The image is resized to ``MODEL_INPUT_SIZE`` for prediction. The predicted
    mask is resized back (nearest-neighbour, so class ids are not
    interpolated) and coloured with the ``jet`` colormap. It is then
    alpha-blended over the original photo with weight ``MASK_OPACITY``.

    Args:
        image: ``PIL.Image`` supplied by Gradio; any mode is accepted and is
            converted to RGB first.

    Returns:
        ``numpy.uint8`` array of shape ``(height, width, 3)`` at the original
        image resolution.
    """
    # Normalise the mode so the blend below always sees exactly 3 channels
    # (RGBA / grayscale uploads would otherwise fail to broadcast).
    image = image.convert('RGB')
    original_size = image.size  # (width, height)

    model_input = PILImage.create(image.resize(MODEL_INPUT_SIZE))
    prediction, _, _ = LEARNER.predict(model_input)

    # NEAREST keeps class ids intact when scaling back to the source size.
    mask = numpy.asarray(prediction, dtype=numpy.uint8)
    mask = numpy.array(Image.fromarray(mask).resize(original_size, Image.NEAREST))

    # Scale class ids into [0, 1] for the colormap. An all-background mask
    # (max == 0) would otherwise be 0/0 -> NaN, so map it to 0 instead.
    peak = mask.max()
    normalized = mask / peak if peak > 0 else numpy.zeros(mask.shape, dtype=numpy.float64)
    colored_mask = colormaps['jet'](normalized)[:, :, :3]  # drop alpha channel

    image_array = numpy.asarray(image, dtype=numpy.float32) / 255.0
    overlay = image_array * (1.0 - MASK_OPACITY) + colored_mask * MASK_OPACITY
    return (overlay * 255).astype(numpy.uint8)


demo = gradio.Interface(
    segment_image,
    inputs=gradio.Image(type='pil'),
    outputs=gradio.Image(type='numpy'),
    # Sorted so the example gallery order is stable across platforms.
    examples=sorted(str(path) for path in TEST_IMAGES_PATH.iterdir() if path.is_file()),
)
demo.launch()