import torch import numpy as np import gradio as gr from pathlib import Path from PIL import Image from torchvision import transforms from huggingface_hub import hf_hub_download from ResNet_for_CC import CC_model # Define the Clothing1M class labels CLOTHING1M_CLASSES = [ "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie", "Windbreaker", "Jacket", "Downcoat", "Suit", "Shawl", "Dress", "Vest", "Underwear" ] # Initialize the model model = CC_model() model_path = hf_hub_download(repo_id="mohamdlog/CC", filename="CC_net.pt") model.load_state_dict(torch.load(model_path, map_location='cpu')) model.eval() # Define preprocessing pipeline def preprocess_image(image): if isinstance(image, np.ndarray): image = Image.fromarray(image) transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), ]) return transform(image).unsqueeze(0) # Define classification function def classify_image(image): input_tensor = preprocess_image(image) with torch.no_grad(): output = model(input_tensor) # Get predicted class and confidence probabilities = torch.nn.functional.softmax(output, dim=1) predicted_class_idx = output.argmax(dim=1).item() predicted_class = CLOTHING1M_CLASSES[predicted_class_idx] confidence = probabilities[0][predicted_class_idx].item() return f"Category: {predicted_class}\nConfidence: {confidence:.2f}" # Create Gradio interface interface = gr.Interface( fn=classify_image, inputs=gr.Image(label="Uploaded Image"), outputs=gr.Text(label="Predicted Clothing"), title="Clothing Category Classifier", description = """ **Upload an image of clothing, and the model will predict its category.** Try using an image that doesn't belong to any of the available categories, and see how the result differs! **Categories:** | T-Shirt | Shirt | Knitwear | Chiffon | Sweater | Hoodie | Windbreaker | | Jacket | Downcoat | Suit | Shawl | Dress | Vest | Underwear | """, examples=[[str(file)] for file in Path("examples").glob("*")], flagging_mode="never", theme="soft" ) # Launch the interface if __name__ == "__main__": interface.launch()