import gradio as gr
import cv2
from PIL import Image
import torch
import numpy as np
import os
from transformers import AutoProcessor, CLIPVisionModel
from detection import detect_image, detect_video
from model import LinearClassifier
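

# Load the CLIP vision backbone, its processor, and the linear classifier head for the chosen
# detection type ("facial" or "general"). The "linear.weight" key in the checkpoint suggests
# LinearClassifier is a single linear layer over CLIP features (see model.py for the actual definition).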
def load_model(detection_type):
    device = torch.device("cpu")
    processor = AutoProcessor.from_pretrained("clip-vit-large-patch14")
    clip_model = CLIPVisionModel.from_pretrained("clip-vit-large-patch14", output_attentions=True)
    model_path = f"pretrained_models/{detection_type}/clip_weights.pth"
    checkpoint = torch.load(model_path, map_location="cpu")
    input_dim = checkpoint["linear.weight"].shape[1]
    detection_model = LinearClassifier(input_dim)
    detection_model.load_state_dict(checkpoint)
    detection_model = detection_model.to(device)
    return processor, clip_model, detection_model

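# Single-image detection: returns the prediction score and the attention map from detect_image.
# The score is inverted (1 - pred_score), presumably so the displayed value follows the
# "0 - REAL, 1 - FAKE" convention of the output textbox.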
def process_image(image, detection_type):
    processor, clip_model, detection_model = load_model(detection_type)
    results = detect_image(image, processor, clip_model, detection_model)
    pred_score = 1 - results["pred_score"]
    attn_map = results["attn_map"]
    return pred_score, attn_map

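# Video detection: decode the upload with OpenCV, convert each BGR frame to an RGB PIL image,
# and run detect_video over the collected frames.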
def process_video(video, detection_type):
    processor, clip_model, detection_model = load_model(detection_type)
    cap = cv2.VideoCapture(video)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(frame)
        frames.append(pil_image)
    cap.release()
    results = detect_video(frames, processor, clip_model, detection_model)
    pred_score = results["pred_score"]
    attn_map = results["attn_map"]
    return pred_score, attn_map

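# Show only the upload widget that matches the selected input type, clearing any previous value.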
def change_input(input_type):
    if input_type == "Image":
        return gr.update(value=None, visible=True), gr.update(value=None, visible=False)
    elif input_type == "Video":
        return gr.update(value=None, visible=False), gr.update(value=None, visible=True)
    else:
        return gr.update(value=None, visible=False), gr.update(value=None, visible=False)

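# Infer the model type from an example's file path: "facial" / "general" in the path selects the
# corresponding model; anything else falls back to the facial model.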
def determine_model_type(image_path):
    if "facial" in image_path.lower():
        return "Facial"
    elif "general" in image_path.lower():
        return "General"
    else:
        return "Facial"  # default

def process_input(input_type, model_type, image, video):
    detection_type = "facial" if model_type == "Facial" else "general"
    if input_type == "Image" and image is not None:
        return process_image(image, detection_type)
    elif input_type == "Video" and video is not None:
        return process_video(video, detection_type)
    else:
        return None, None

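# Clicking an example loads the image and auto-selects the model type that matches its file name.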
def process_example(image_path):
    model_type = determine_model_type(image_path)
    return Image.open(image_path), model_type

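# Gather example image paths bundled with the app.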
fake_examples, real_examples = [], []
for example in os.listdir("examples/fake"):
    fake_examples.append(os.path.join("examples/fake", example))
for example in os.listdir("examples/real"):
    real_examples.append(os.path.join("examples/real", example))

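# Build the Gradio interface: input/model selectors, upload widgets, and outputs for the
# prediction score and attention map.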
with gr.Blocks() as demo:
    gr.Markdown("## Deepfake Detection : Facial / General")
    input_type = gr.Radio(["Image", "Video"], label="Choose Input Type", value="Image")
    model_type = gr.Radio(["Facial", "General"], label="Choose Model Type", value="General")
    H, W = 300, 300
    image_input = gr.Image(type="pil", label="Upload Image", visible=True, height=H, width=W)
    video_input = gr.Video(label="Upload Video", visible=False, height=H, width=W)
    process_button = gr.Button("Run Model")
    pred_score_output = gr.Textbox(label="Prediction Score : 0 - REAL, 1 - FAKE")
    attn_map_output = gr.Image(type="pil", label="Attention Map", height=H, width=W)

    # Example images
    gr.Examples(
        examples=fake_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        run_on_click=True,  # run process_example when an example is clicked so model_type updates
        cache_examples=False,
        examples_per_page=10,
        label="Fake Examples"
    )
    gr.Examples(
        examples=real_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        run_on_click=True,  # run process_example when an example is clicked so model_type updates
        cache_examples=False,
        examples_per_page=10,
        label="Real Examples"
    )

    input_type.change(fn=change_input, inputs=[input_type], outputs=[image_input, video_input])
    process_button.click(
        fn=process_input,
        inputs=[input_type, model_type, image_input, video_input],
        outputs=[pred_score_output, attn_map_output]
    )

if __name__ == "__main__":
    demo.launch()