Spaces:
Sleeping
Sleeping
File size: 7,109 Bytes
9abf394 190598a 9abf394 190598a 9abf394 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
#@title 3. Load Model from HF Directory and Launch Gradio Interface
# --- Imports ---
import torch
import gradio as gr
from PIL import Image
import os
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, ViTForImageClassification
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch import device, cuda
import numpy as np
# --- Configuration ---
# Directory with the fine-tuned model (save_pretrained output). Relative paths
# are resolved against the current working directory.
hf_model_directory = "best-model-hf"
# Base checkpoint whose feature-extractor settings drive preprocessing.
model_checkpoint = "google/vit-base-patch16-224"
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_to_use = device("cuda") if cuda.is_available() else device("cpu")
print(f"Using device: {device_to_use}")
# --- Predictor Class ---
class ImagePredictor:
    """Fine-tuned ViT image classifier wrapped for single-image inference.

    Everything needed for inference (preprocessing transforms, model weights,
    label mapping) is loaded eagerly in ``__init__``. Construction raises if
    any step fails, so a partially initialised instance is never returned.
    """

    def __init__(self, model_dir, base_checkpoint, device):
        """Store configuration and load all resources immediately.

        Args:
            model_dir: Directory loaded with
                ``ViTForImageClassification.from_pretrained``.
            base_checkpoint: Checkpoint id/path whose feature extractor supplies
                the normalisation mean/std and the input image size.
            device: ``torch.device`` the model and input tensors are moved to.

        Raises:
            RuntimeError: If the feature extractor, transforms or model fail to load.
            FileNotFoundError: If ``model_dir`` is not an existing directory.
        """
        self.model_dir = model_dir
        self.base_checkpoint = base_checkpoint
        self.device = device
        self.model = None
        self.feature_extractor = None
        self.transforms = None
        self.id2label = None
        self.num_labels = 0
        self._load_resources()  # Load everything during initialization

    def _load_resources(self):
        """Load transforms, then the model and its labels; fail fast on error."""
        print("--- Loading Predictor Resources ---")
        self._load_transforms()
        self._load_model()
        print("--- Predictor Resources Loaded Successfully ---")

    def _load_transforms(self):
        """Load the base feature extractor and build the matching eval transforms."""
        try:
            print(f"Loading feature extractor for: {self.base_checkpoint}")
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.base_checkpoint)
            print("Feature extractor loaded.")
            normalize = Normalize(mean=self.feature_extractor.image_mean, std=self.feature_extractor.image_std)
            size = self.feature_extractor.size
            # Newer processors store size as a dict ('shortest_edge' or
            # 'height'/'width'); older ones store a plain int.
            if isinstance(size, dict):
                image_size = size.get('shortest_edge', size.get('height', 224))
            else:
                image_size = size
            print(f"Using image size: {image_size}")
            self.transforms = Compose([
                Resize(image_size),
                CenterCrop(image_size),
                ToTensor(),
                normalize,
            ])
            print("Inference transforms defined.")
        except Exception as e:
            print(f"FATAL: Error loading feature extractor or defining transforms: {e}")
            # Re-raise to prevent using a partially initialized object
            raise RuntimeError("Feature extractor/transforms loading failed.") from e

    def _load_model(self):
        """Load the fine-tuned model onto ``self.device`` in eval mode, plus its labels."""
        if not os.path.isdir(self.model_dir):
            print(f"FATAL: Model directory not found at '{self.model_dir}'.")
            raise FileNotFoundError(f"Model directory not found: {self.model_dir}")
        print(f"Attempting to load model from directory: {self.model_dir}")
        try:
            self.model = ViTForImageClassification.from_pretrained(self.model_dir)
            self.model.to(self.device)
            self.model.eval()  # Disable dropout etc. for inference
            print("Model loaded successfully from directory and moved to device.")
            self._load_labels()
        except Exception as e:
            print(f"FATAL: An unexpected error occurred loading the model: {e}")
            # Reset model attribute to indicate failure clearly
            self.model = None
            raise RuntimeError("Model loading failed.") from e

    def _load_labels(self):
        """Populate ``id2label``/``num_labels`` from the model config (with fallback).

        Raises:
            ValueError: If the resulting mapping is empty.
        """
        if hasattr(self.model, 'config') and hasattr(self.model.config, 'id2label'):
            # Keys may be strings after a JSON round-trip; normalise to int ids.
            self.id2label = {int(k): v for k, v in self.model.config.id2label.items()}
            self.num_labels = len(self.id2label)
            print(f"Loaded id2label mapping from model config: {self.id2label}")
            print(f"Number of labels: {self.num_labels}")
        else:
            print("WARNING: Could not find 'id2label' in the loaded model's config.")
            # NOTE(review): this fallback must match the label order used in
            # training — confirm before relying on it.
            self.id2label = {0: 'fake', 1: 'real'}
            self.num_labels = len(self.id2label)
            print(f"Using manually defined id2label: {self.id2label}")
        if self.num_labels == 0:
            raise ValueError("Number of labels is zero after loading.")

    def predict(self, image: "Image.Image"):
        """Classify one image and return per-label probabilities.

        Args:
            image: PIL image (any mode), or ``None`` when the input is empty.

        Returns:
            ``None`` if ``image`` is ``None``; otherwise a dict mapping each
            label name from ``self.id2label`` to its softmax probability,
            the ``{label: confidence}`` shape ``gr.Label`` renders.

        Raises:
            RuntimeError: If the predictor is not initialised, or if
                preprocessing or the forward pass fails.
        """
        if image is None:
            print("Input image is None")
            return None
        if self.model is None or self.transforms is None:
            raise RuntimeError("Predictor is not initialised; cannot run prediction.")
        try:
            # Normalize expects 3 channels; uploads may be RGBA, L or P mode.
            rgb_image = image.convert("RGB")
            # Add a batch dimension of 1 before the forward pass.
            pixel_values = self.transforms(rgb_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                logits = self.model(pixel_values=pixel_values).logits
            probabilities = F.softmax(logits, dim=-1)[0].cpu().tolist()
        except Exception as e:
            print(f"Error during prediction: {e}")
            # Raise instead of returning a string-valued dict: gr.Label needs
            # numeric confidences and cannot display {"Error": "<message>"}.
            raise RuntimeError(f"Prediction failed: {e}") from e
        return {self.id2label.get(idx, str(idx)): float(prob) for idx, prob in enumerate(probabilities)}
# --- Main Execution Logic ---
predictor = None
try:
    # Instantiate the predictor once, at import time. Construction loads the
    # feature extractor, transforms and model immediately and raises on failure.
    predictor = ImagePredictor(
        model_dir=hf_model_directory,
        base_checkpoint=model_checkpoint,
        device=device_to_use,
    )
except Exception as e:
    # Top-level boundary: report, leave predictor as None, and fall through
    # to the "could not launch" branch below.
    print(f"Failed to initialize ImagePredictor: {e}")

# --- Create and Launch the Gradio Interface ---
if predictor is not None and predictor.model is not None:
    print("\nSetting up Gradio Interface...")
    try:
        iface = gr.Interface(
            # Pass the bound instance method so the loaded model is reused per call.
            fn=predictor.predict,
            inputs=gr.Image(type="pil", label="Upload Face Image"),
            outputs=gr.Label(num_top_classes=predictor.num_labels, label="Prediction (Real/Fake)"),
            title="Real vs. Fake Face Detector",
            description=f"Upload an image of a face to classify it using the fine-tuned ViT model loaded from the '{hf_model_directory}' directory.",
        )
        print("Launching Gradio interface...")
        # queue() must be configured BEFORE launch(): launch() returns a tuple
        # (app, local_url, share_url) rather than the Interface, and with
        # debug=True it blocks until the server stops, so a trailing .queue()
        # never took effect and raised AttributeError on shutdown.
        iface.queue()
        # share=True kept as requested. NOTE(review): Gradio warns that
        # share=True is not supported on Hugging Face Spaces — confirm need.
        iface.launch(share=True, debug=True, show_error=True)
    except Exception as e:
        print(f"Error creating or launching Gradio interface: {e}")
else:
    print("\nCould not launch Gradio interface because the Predictor failed to initialize.")
    print("Please check the error messages above.")
print("\nGradio setup finished. Interface should be running or an error reported above.")
# print("Stop this cell execution in Colab to shut down the Gradio server.") |