import gradio as gr
import cv2
from PIL import Image
import torch
import os
from transformers import AutoProcessor, CLIPVisionModel

from detection import detect_image, detect_video
from model import LinearClassifier
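from functools import lru_cache


# Cache the processor, CLIP backbone, and classifier per detection type so the
# heavy models are not reloaded on every request. The decorator is an addition
# to the original code; it is safe here because load_model has no side effects.
@lru_cache(maxsize=2)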
def load_model(detection_type):
    device = torch.device("cpu")
    # CLIP ViT-L/14 vision backbone from the Hugging Face Hub; attentions are
    # requested so the demo can visualize them.
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
    clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14", output_attentions=True)
    model_path = f"pretrained_models/{detection_type}/clip_weights.pth"
    checkpoint = torch.load(model_path, map_location="cpu")
    # Infer the classifier's input dimension from the saved linear layer.
    input_dim = checkpoint["linear.weight"].shape[1]
    detection_model = LinearClassifier(input_dim)
    detection_model.load_state_dict(checkpoint)
    detection_model = detection_model.to(device)
    return processor, clip_model, detection_model
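
# For reference, LinearClassifier (imported from model.py) is assumed here to
# be a thin linear head over CLIP features, along the lines of:
#
#     class LinearClassifier(torch.nn.Module):
#         def __init__(self, input_dim):
#             super().__init__()
#             self.linear = torch.nn.Linear(input_dim, 1)
#
#         def forward(self, features):
#             return torch.sigmoid(self.linear(features))
#
# Only the "linear.weight" key read from the checkpoint above is known for
# certain; the output activation is an assumption.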


def process_image(image, detection_type):
    processor, clip_model, detection_model = load_model(detection_type)
    results = detect_image(image, processor, clip_model, detection_model)
    # Invert the raw score so it follows the UI convention (0 - REAL, 1 - FAKE).
    pred_score = 1 - results["pred_score"]
    attn_map = results["attn_map"]
    return pred_score, attn_map


def process_video(video, detection_type):
    processor, clip_model, detection_model = load_model(detection_type)
    cap = cv2.VideoCapture(video)
    # Decode every frame up front; fine for short clips, memory-heavy for long videos.
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes to BGR; convert to RGB before building the PIL image.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame))
    cap.release()
    results = detect_video(frames, processor, clip_model, detection_model)
    pred_score = results["pred_score"]
    attn_map = results["attn_map"]
    return pred_score, attn_map


def change_input(input_type):
    # Show only the component that matches the selected input type and clear
    # any stale value from the hidden one.
    if input_type == "Image":
        return gr.update(value=None, visible=True), gr.update(value=None, visible=False)
    elif input_type == "Video":
        return gr.update(value=None, visible=False), gr.update(value=None, visible=True)
    else:
        return gr.update(value=None, visible=False), gr.update(value=None, visible=False)


def determine_model_type(image_path):
    # Pick the model type from keywords in the example file path.
    if "facial" in image_path.lower():
        return "Facial"
    elif "general" in image_path.lower():
        return "General"
    else:
        return "Facial"  # default


def process_input(input_type, model_type, image, video):
    detection_type = "facial" if model_type == "Facial" else "general"
    if input_type == "Image" and image is not None:
        return process_image(image, detection_type)
    elif input_type == "Video" and video is not None:
        return process_video(video, detection_type)
    else:
        return None, None


def process_example(example):
    # Depending on Gradio version and example caching, fn may receive either the
    # example's file path or an already-loaded PIL image; accept both.
    path = example if isinstance(example, str) else getattr(example, "filename", "")
    image = Image.open(example) if isinstance(example, str) else example
    return image, determine_model_type(path)


# Collect example file paths (sorted so the gallery order is deterministic).
fake_examples = [os.path.join("examples/fake", f) for f in sorted(os.listdir("examples/fake"))]
real_examples = [os.path.join("examples/real", f) for f in sorted(os.listdir("examples/real"))]


with gr.Blocks() as demo:
    gr.Markdown("## Deepfake Detection : Facial / General")
    input_type = gr.Radio(["Image", "Video"], label="Choose Input Type", value="Image")
    model_type = gr.Radio(["Facial", "General"], label="Choose Model Type", value="General")
    H, W = 300, 300
    image_input = gr.Image(type="pil", label="Upload Image", visible=True, height=H, width=W)
    video_input = gr.Video(label="Upload Video", visible=False, height=H, width=W)
    process_button = gr.Button("Run Model")
    pred_score_output = gr.Textbox(label="Prediction Score : 0 - REAL, 1 - FAKE")
    attn_map_output = gr.Image(type="pil", label="Attention Map", height=H, width=W)

    # Example galleries: clicking an example loads the image and, via
    # process_example, selects the matching model type.
    gr.Examples(
        examples=fake_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        run_on_click=True,  # with cache_examples=False, fn only runs on click if this is set
        cache_examples=False,
        examples_per_page=10,
        label="Fake Examples",
    )
    gr.Examples(
        examples=real_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        run_on_click=True,  # with cache_examples=False, fn only runs on click if this is set
        cache_examples=False,
        examples_per_page=10,
        label="Real Examples",
    )

    # Wire up the UI: toggle which input is visible, and run detection on click.
    input_type.change(fn=change_input, inputs=[input_type], outputs=[image_input, video_input])
    process_button.click(
        fn=process_input,
        inputs=[input_type, model_type, image_input, video_input],
        outputs=[pred_score_output, attn_map_output],
    )


if __name__ == "__main__":
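    demo.queue()  # added: queue requests, which helps long-running video jobs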
    demo.launch()
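
# On Hugging Face Spaces this app.py is executed automatically; when run
# locally, launch() prints a local URL (pass share=True for a public link).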