#@title 3. Load Model from HF Directory and Launch Gradio Interface

# --- Imports ---
import torch
import gradio as gr
from PIL import Image
import os
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, ViTForImageClassification
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import numpy as np

# --- Configuration ---
hf_model_directory = 'best-model-hf'  # Corrected path (no leading dot)
model_checkpoint = "google/vit-base-patch16-224"
device_to_use = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device_to_use}")

# --- Predictor Class ---
class ImagePredictor:
    def __init__(self, model_dir, base_checkpoint, device):
        self.model_dir = model_dir
        self.base_checkpoint = base_checkpoint
        self.device = device
        self.model = None
        self.feature_extractor = None
        self.transforms = None
        self.id2label = None
        self.num_labels = 0
        self._load_resources() # Load everything during initialization

    def _load_resources(self):
        print("--- Loading Predictor Resources ---")
        # --- Load Feature Extractor (Needed for Preprocessing) ---
        try:
            print(f"Loading feature extractor for: {self.base_checkpoint}")
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(self.base_checkpoint)
            print("Feature extractor loaded.")

            # --- Define Image Transforms ---
            normalize = Normalize(mean=self.feature_extractor.image_mean, std=self.feature_extractor.image_std)
            if isinstance(self.feature_extractor.size, dict):
               image_size = self.feature_extractor.size.get('shortest_edge', self.feature_extractor.size.get('height', 224))
            else:
               image_size = self.feature_extractor.size
            print(f"Using image size: {image_size}")

            self.transforms = Compose([
                Resize(image_size),
                CenterCrop(image_size),
                ToTensor(),
                normalize,
            ])
            print("Inference transforms defined.")

        except Exception as e:
            print(f"FATAL: Error loading feature extractor or defining transforms: {e}")
            # Re-raise to prevent using a partially initialized object
            raise RuntimeError("Feature extractor/transforms loading failed.") from e

        # --- Load the Fine-Tuned Model ---
        if not os.path.isdir(self.model_dir):
            print(f"FATAL: Model directory not found at '{self.model_dir}'.")
            raise FileNotFoundError(f"Model directory not found: {self.model_dir}")

        print(f"Attempting to load model from directory: {self.model_dir}")
        try:
            self.model = ViTForImageClassification.from_pretrained(self.model_dir)
            self.model.to(self.device)
            self.model.eval() # Set model to evaluation mode
            print("Model loaded successfully from directory and moved to device.")

            # --- Load Label Mapping ---
            if hasattr(self.model, 'config') and hasattr(self.model.config, 'id2label'):
                self.id2label = {int(k): v for k, v in self.model.config.id2label.items()}
                self.num_labels = len(self.id2label)
                print(f"Loaded id2label mapping from model config: {self.id2label}")
                print(f"Number of labels: {self.num_labels}")
            else:
                print("WARNING: Could not find 'id2label' in the loaded model's config.")
                # --- !! MANUALLY DEFINE FALLBACK IF NEEDED !! ---
                self.id2label = {0: 'fake', 1: 'real'} # ENSURE THIS MATCHES TRAINING
                self.num_labels = len(self.id2label)
                print(f"Using manually defined id2label: {self.id2label}")
                # ----------------------------------------------

            if self.num_labels == 0:
                raise ValueError("Number of labels is zero after loading.")

            print("--- Predictor Resources Loaded Successfully ---")

        except Exception as e:
            print(f"FATAL: An unexpected error occurred loading the model: {e}")
            # Reset model attribute to indicate failure clearly
            self.model = None
            # Re-raise to prevent using a partially initialized object
            raise RuntimeError("Model loading failed.") from e

    # --- Prediction Method ---
    def predict(self, image: Image.Image):
        print("--- Predict function called ---")
        if image is None:
            print("Input image is None")
            return None
        try:
            # Preprocess: ensure RGB, apply the inference transforms, add a batch dimension
            pixel_values = self.transforms(image.convert("RGB")).unsqueeze(0).to(self.device)

            # Forward pass with gradients disabled
            with torch.no_grad():
                logits = self.model(pixel_values=pixel_values).logits

            # Convert logits to class probabilities
            probabilities = F.softmax(logits, dim=-1)[0]

            # Map each label to its probability (the format expected by gr.Label)
            output = {self.id2label[i]: float(probabilities[i]) for i in range(self.num_labels)}
            print(f"Prediction: {output}")
            return output
        except Exception as e:
            print(f"Error during prediction: {e}")
            return f"Prediction failed: {str(e)}"


# --- Main Execution Logic ---
predictor = None
try:
    # Instantiate the predictor ONCE globally
    # This loads the model, feature extractor, transforms, etc. immediately
    predictor = ImagePredictor(
        model_dir=hf_model_directory,
        base_checkpoint=model_checkpoint,
        device=device_to_use
    )
except Exception as e:
     print(f"Failed to initialize ImagePredictor: {e}")
     # predictor remains None
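
# --- Optional local sanity check (illustrative sketch; assumes a test image named 'sample.jpg' exists) ---
# if predictor is not None:
#     test_image = Image.open('sample.jpg').convert('RGB')
#     print(predictor.predict(test_image))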


# --- Create and Launch the Gradio Interface ---
if predictor and predictor.model: # Check if predictor initialized successfully
    print("\nSetting up Gradio Interface...")
    try:
        iface = gr.Interface(
            # Pass the INSTANCE METHOD to fn
            fn=predictor.predict,
            inputs=gr.Image(type="pil", label="Upload Face Image"),
            outputs=gr.Label(num_top_classes=predictor.num_labels, label="Prediction (Real/Fake)"),
            title="Real vs. Fake Face Detector",
            description=f"Upload an image of a face to classify it using the fine-tuned ViT model loaded from the '{hf_model_directory}' directory.",
        )

        print("Launching Gradio interface...")
        # Set share=True as requested
        iface.queue().launch(share=True, debug=True, show_error=True)

    except Exception as e:
        print(f"Error creating or launching Gradio interface: {e}")

else:
    print("\nCould not launch Gradio interface because the Predictor failed to initialize.")
    print("Please check the error messages above.")


# Optional: Add message for Colab/persistent running if needed
print("\nGradio setup finished. Interface should be running or an error reported above.")
# print("Stop this cell execution in Colab to shut down the Gradio server.")