Spaces:
Sleeping
Sleeping
#@title 3. Load Model from HF Directory and Launch Gradio Interface | |
# --- Imports --- | |
import torch | |
import gradio as gr | |
from PIL import Image | |
import os | |
import torch.nn.functional as F | |
from transformers import AutoFeatureExtractor, ViTForImageClassification | |
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |
from torch import device, cuda | |
import numpy as np | |
# --- Configuration --- | |
hf_model_directory = 'best-model-hf' # Corrected path (no leading dot) | |
model_checkpoint = "google/vit-base-patch16-224" | |
device_to_use = device('cuda' if cuda.is_available() else 'cpu') | |
print(f"Using device: {device_to_use}") | |
# --- Predictor Class --- | |
class ImagePredictor: | |
def __init__(self, model_dir, base_checkpoint, device): | |
self.model_dir = model_dir | |
self.base_checkpoint = base_checkpoint | |
self.device = device | |
self.model = None | |
self.feature_extractor = None | |
self.transforms = None | |
self.id2label = None | |
self.num_labels = 0 | |
self._load_resources() # Load everything during initialization | |
def _load_resources(self): | |
print("--- Loading Predictor Resources ---") | |
# --- Load Feature Extractor (Needed for Preprocessing) --- | |
try: | |
print(f"Loading feature extractor for: {self.base_checkpoint}") | |
self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.base_checkpoint) | |
print("Feature extractor loaded.") | |
# --- Define Image Transforms --- | |
normalize = Normalize(mean=self.feature_extractor.image_mean, std=self.feature_extractor.image_std) | |
if isinstance(self.feature_extractor.size, dict): | |
image_size = self.feature_extractor.size.get('shortest_edge', self.feature_extractor.size.get('height', 224)) | |
else: | |
image_size = self.feature_extractor.size | |
print(f"Using image size: {image_size}") | |
self.transforms = Compose([ | |
Resize(image_size), | |
CenterCrop(image_size), | |
ToTensor(), | |
normalize, | |
]) | |
print("Inference transforms defined.") | |
except Exception as e: | |
print(f"FATAL: Error loading feature extractor or defining transforms: {e}") | |
# Re-raise to prevent using a partially initialized object | |
raise RuntimeError("Feature extractor/transforms loading failed.") from e | |
# --- Load the Fine-Tuned Model --- | |
if not os.path.isdir(self.model_dir): | |
print(f"FATAL: Model directory not found at '{self.model_dir}'.") | |
raise FileNotFoundError(f"Model directory not found: {self.model_dir}") | |
print(f"Attempting to load model from directory: {self.model_dir}") | |
try: | |
self.model = ViTForImageClassification.from_pretrained(self.model_dir) | |
self.model.to(self.device) | |
self.model.eval() # Set model to evaluation mode | |
print("Model loaded successfully from directory and moved to device.") | |
# --- Load Label Mapping --- | |
if hasattr(self.model, 'config') and hasattr(self.model.config, 'id2label'): | |
self.id2label = {int(k): v for k, v in self.model.config.id2label.items()} | |
self.num_labels = len(self.id2label) | |
print(f"Loaded id2label mapping from model config: {self.id2label}") | |
print(f"Number of labels: {self.num_labels}") | |
else: | |
print("WARNING: Could not find 'id2label' in the loaded model's config.") | |
# --- !! MANUALLY DEFINE FALLBACK IF NEEDED !! --- | |
self.id2label = {0: 'fake', 1: 'real'} # ENSURE THIS MATCHES TRAINING | |
self.num_labels = len(self.id2label) | |
print(f"Using manually defined id2label: {self.id2label}") | |
# ---------------------------------------------- | |
if self.num_labels == 0: | |
raise ValueError("Number of labels is zero after loading.") | |
print("--- Predictor Resources Loaded Successfully ---") | |
except Exception as e: | |
print(f"FATAL: An unexpected error occurred loading the model: {e}") | |
# Reset model attribute to indicate failure clearly | |
self.model = None | |
# Re-raise to prevent using a partially initialized object | |
raise RuntimeError("Model loading failed.") from e | |
# --- Prediction Method --- | |
# Inside the ImagePredictor class: | |
def predict(self, image: Image.Image): | |
print("--- Predict function called ---") # Check if this even prints in Space logs | |
if image is None: | |
print("Input image is None") | |
return None | |
try: | |
# Simulate some processing time | |
import time | |
time.sleep(0.1) | |
# Return a dummy dictionary, bypassing all model/transform logic | |
dummy_output = {"fake": 0.6, "real": 0.4} # Use your actual labels | |
print(f"Returning dummy output: {dummy_output}") | |
return dummy_output | |
except Exception as e: | |
print(f"Error in *simplified* predict: {e}") | |
return {"Error": f"Simplified prediction failed: {str(e)}"} | |
# --- Main Execution Logic --- | |
predictor = None | |
try: | |
# Instantiate the predictor ONCE globally | |
# This loads the model, tokenizer, transforms, etc. immediately | |
predictor = ImagePredictor( | |
model_dir=hf_model_directory, | |
base_checkpoint=model_checkpoint, | |
device=device_to_use | |
) | |
except Exception as e: | |
print(f"Failed to initialize ImagePredictor: {e}") | |
# predictor remains None | |
# --- Create and Launch the Gradio Interface --- | |
if predictor and predictor.model: # Check if predictor initialized successfully | |
print("\nSetting up Gradio Interface...") | |
try: | |
iface = gr.Interface( | |
# Pass the INSTANCE METHOD to fn | |
fn=predictor.predict, | |
inputs=gr.Image(type="pil", label="Upload Face Image"), | |
outputs=gr.Label(num_top_classes=predictor.num_labels, label="Prediction (Real/Fake)"), | |
title="Real vs. Fake Face Detector", | |
description=f"Upload an image of a face to classify it using the fine-tuned ViT model loaded from the '{hf_model_directory}' directory.", | |
) | |
print("Launching Gradio interface...") | |
# Set share=True as requested | |
iface.launch(share=True, debug=True, show_error=True).queue() | |
except Exception as e: | |
print(f"Error creating or launching Gradio interface: {e}") | |
else: | |
print("\nCould not launch Gradio interface because the Predictor failed to initialize.") | |
print("Please check the error messages above.") | |
# Optional: Add message for Colab/persistent running if needed | |
print("\nGradio setup finished. Interface should be running or an error reported above.") | |
# print("Stop this cell execution in Colab to shut down the Gradio server.") |