# ehtyalee's picture
# Upload 2 files
# 190598a verified
#@title 3. Load Model from HF Directory and Launch Gradio Interface
# --- Imports ---
import torch
import gradio as gr
from PIL import Image
import os
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, ViTForImageClassification
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch import device, cuda
import numpy as np
# --- Configuration ---
# Directory holding the fine-tuned model in Hugging Face format
# (relative path, no leading dot).
hf_model_directory = 'best-model-hf'
# Base checkpoint whose feature-extractor settings drive preprocessing.
model_checkpoint = "google/vit-base-patch16-224"
# Use the GPU when CUDA reports one, otherwise fall back to the CPU.
device_to_use = device('cuda') if cuda.is_available() else device('cpu')
print(f"Using device: {device_to_use}")
# --- Predictor Class ---
class ImagePredictor:
    """Image classifier backed by a fine-tuned ViT model loaded from disk.

    The feature extractor, the preprocessing transforms, the model weights
    and the id-to-label mapping are all loaded once in ``__init__`` so that
    ``predict`` can be passed directly to Gradio as a callback.
    """

    def __init__(self, model_dir, base_checkpoint, device):
        """Store the configuration and eagerly load every resource.

        Args:
            model_dir: Directory containing the fine-tuned model.
            base_checkpoint: Checkpoint name used to load the feature
                extractor (source of image size / mean / std).
            device: Device the model and input tensors are moved to.

        Raises:
            FileNotFoundError: If ``model_dir`` is not an existing directory.
            RuntimeError: If the feature extractor, the transforms or the
                model fail to load.
        """
        self.model_dir = model_dir
        self.base_checkpoint = base_checkpoint
        self.device = device
        self.model = None
        self.feature_extractor = None
        self.transforms = None
        self.id2label = None
        self.num_labels = 0
        self._load_resources()  # Load everything during initialization

    def _load_resources(self):
        """Load preprocessing first, then the model and its label mapping."""
        print("--- Loading Predictor Resources ---")
        self._load_transforms()
        self._load_model()
        print("--- Predictor Resources Loaded Successfully ---")

    def _load_transforms(self):
        """Load the feature extractor and build the inference transforms."""
        try:
            print(f"Loading feature extractor for: {self.base_checkpoint}")
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.base_checkpoint)
            print("Feature extractor loaded.")

            normalize = Normalize(
                mean=self.feature_extractor.image_mean,
                std=self.feature_extractor.image_std,
            )
            # The extractor's ``size`` is either a dict or a plain value;
            # for a dict, prefer 'shortest_edge', then 'height', then 224.
            size = self.feature_extractor.size
            if isinstance(size, dict):
                image_size = size.get('shortest_edge', size.get('height', 224))
            else:
                image_size = size
            print(f"Using image size: {image_size}")

            self.transforms = Compose([
                Resize(image_size),
                CenterCrop(image_size),
                ToTensor(),
                normalize,
            ])
            print("Inference transforms defined.")
        except Exception as e:
            print(f"FATAL: Error loading feature extractor or defining transforms: {e}")
            # Re-raise to prevent using a partially initialized object
            raise RuntimeError("Feature extractor/transforms loading failed.") from e

    def _load_model(self):
        """Load the fine-tuned model from ``model_dir`` and resolve its labels."""
        if not os.path.isdir(self.model_dir):
            print(f"FATAL: Model directory not found at '{self.model_dir}'.")
            raise FileNotFoundError(f"Model directory not found: {self.model_dir}")

        print(f"Attempting to load model from directory: {self.model_dir}")
        try:
            self.model = ViTForImageClassification.from_pretrained(self.model_dir)
            self.model.to(self.device)
            self.model.eval()  # Set model to evaluation mode
            print("Model loaded successfully from directory and moved to device.")
            self._load_label_mapping()
        except Exception as e:
            print(f"FATAL: An unexpected error occurred loading the model: {e}")
            # Reset model attribute to indicate failure clearly
            self.model = None
            # Re-raise to prevent using a partially initialized object
            raise RuntimeError("Model loading failed.") from e

    def _load_label_mapping(self):
        """Populate ``id2label`` / ``num_labels`` from the model config.

        Falls back to a manually defined mapping when the config has no
        usable ``id2label`` (attribute missing or set to ``None``).

        Raises:
            ValueError: If the resulting mapping is empty.
        """
        config = getattr(self.model, 'config', None)
        id2label = getattr(config, 'id2label', None)
        if id2label is not None:
            # Keys may be strings (e.g. after a JSON round trip); normalise to int.
            self.id2label = {int(k): v for k, v in id2label.items()}
            self.num_labels = len(self.id2label)
            print(f"Loaded id2label mapping from model config: {self.id2label}")
            print(f"Number of labels: {self.num_labels}")
        else:
            print("WARNING: Could not find 'id2label' in the loaded model's config.")
            # --- !! MANUALLY DEFINE FALLBACK IF NEEDED !! ---
            self.id2label = {0: 'fake', 1: 'real'}  # ENSURE THIS MATCHES TRAINING
            self.num_labels = len(self.id2label)
            print(f"Using manually defined id2label: {self.id2label}")
        if self.num_labels == 0:
            raise ValueError("Number of labels is zero after loading.")

    # --- Prediction Method ---
    def predict(self, image: "Image.Image"):
        """Classify ``image`` and return per-label probabilities.

        Args:
            image: Input image, or ``None`` when nothing was supplied.

        Returns:
            ``None`` if ``image`` is ``None``; otherwise a dict mapping each
            label name to its softmax probability. On failure a dict with a
            single ``"Error"`` key describing the problem is returned.
        """
        print("--- Predict function called ---")
        if image is None:
            print("Input image is None")
            return None
        if self.model is None or self.transforms is None or not self.id2label:
            print("Predictor is not fully initialized.")
            return {"Error": "Prediction failed: predictor is not fully initialized."}
        try:
            # Force three channels so the normalisation step always sees RGB.
            rgb_image = image.convert("RGB")
            # Add the batch dimension the model expects and move to its device.
            pixel_values = self.transforms(rgb_image).unsqueeze(0).to(self.device)
            with torch.no_grad():  # inference only, no gradients needed
                logits = self.model(pixel_values=pixel_values).logits
            probabilities = F.softmax(logits, dim=-1)[0].tolist()
            scores = {
                self.id2label.get(index, str(index)): float(probability)
                for index, probability in enumerate(probabilities)
            }
            print(f"Prediction: {scores}")
            return scores
        except Exception as e:
            print(f"Error in predict: {e}")
            return {"Error": f"Prediction failed: {str(e)}"}
# --- Main Execution Logic ---
# Build the single, module-wide predictor. Construction immediately loads
# the feature extractor, the transforms and the model; on any failure the
# error is reported and ``predictor`` is left as None.
try:
    predictor = ImagePredictor(
        model_dir=hf_model_directory,
        base_checkpoint=model_checkpoint,
        device=device_to_use,
    )
except Exception as init_error:
    print(f"Failed to initialize ImagePredictor: {init_error}")
    predictor = None
# --- Create and Launch the Gradio Interface ---
# Only build the UI when the predictor (and its model) loaded successfully.
if predictor is not None and predictor.model is not None:
    print("\nSetting up Gradio Interface...")
    try:
        iface = gr.Interface(
            # Pass the INSTANCE METHOD to fn
            fn=predictor.predict,
            inputs=gr.Image(type="pil", label="Upload Face Image"),
            outputs=gr.Label(num_top_classes=predictor.num_labels, label="Prediction (Real/Fake)"),
            title="Real vs. Fake Face Detector",
            description=f"Upload an image of a face to classify it using the fine-tuned ViT model loaded from the '{hf_model_directory}' directory.",
        )
        print("Launching Gradio interface...")
        # Enable the queue BEFORE launching: queue() is a method of the
        # interface, not of launch()'s return value, so chaining it after
        # launch() fails once launch() returns.
        # share=True is kept as originally requested.
        iface.queue().launch(share=True, debug=True, show_error=True)
    except Exception as e:
        print(f"Error creating or launching Gradio interface: {e}")
else:
    print("\nCould not launch Gradio interface because the Predictor failed to initialize.")
    print("Please check the error messages above.")
# Optional: Add message for Colab/persistent running if needed
print("\nGradio setup finished. Interface should be running or an error reported above.")
# print("Stop this cell execution in Colab to shut down the Gradio server.")