Swekerr's picture
Update app.py
ef92f0b verified
raw
history blame
815 Bytes
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
# Load a full pickled model object (not just a state_dict) -- the result is
# used directly as a module below.
# map_location="cpu": a checkpoint saved from a CUDA model would otherwise be
# restored onto the GPU (or fail to load on a CPU-only host), while
# classify_brain_tumor feeds the model CPU tensors.
# weights_only=False: PyTorch >= 2.6 defaults to weights_only=True, which
# refuses to unpickle a full nn.Module object.
# SECURITY: this unpickles arbitrary Python objects -- only ever point it at a
# trusted, locally shipped checkpoint, never at user-supplied files.
model = torch.load("squeezenet.pth", map_location="cpu", weights_only=False)
# Put layers such as dropout into inference behaviour before serving predictions.
model.eval()
# Preprocessing applied to every uploaded image before inference:
#   1. resize to a fixed 128x128,
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. normalize with (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1]; the single
#      mean/std value is applied to every channel.
_PREPROCESS_STEPS = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_PREPROCESS_STEPS)
def classify_brain_tumor(image):
    """Classify an uploaded image as "Tumor" or "No Tumor".

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            the user submits without uploading anything.

    Returns:
        "Tumor" if the model's highest-scoring class index is 1, otherwise
        "No Tumor".

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable
            message instead of a stack trace.
    """
    # Gradio passes None for an empty image slot; transform(None) would
    # otherwise fail deep inside torchvision with an opaque TypeError.
    if image is None:
        raise gr.Error("Please upload an MRI image first.")

    batch = transform(image).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(batch)

    # Index of the highest score per row (first index on ties, as torch.max did).
    # NOTE(review): class index 1 is treated as "Tumor" -- confirm this matches
    # the label ordering used when squeezenet.pth was trained.
    predicted = logits.argmax(dim=1)
    return "Tumor" if predicted.item() == 1 else "No Tumor"
# Web UI: one image input (delivered to the callback as a PIL image) and a
# plain-text output holding the predicted label.
interface = gr.Interface(
    fn=classify_brain_tumor,
    # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4;
    # gr.Image is the supported component and also exists in Gradio 3.x.
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Brain Tumor Classification",
    description="Upload an MRI image to classify if it has a tumor or not. The Model is SqueezeNet."
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launching a web server.
if __name__ == "__main__":
    interface.launch()