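"""Gradio demo for CLIP-based deepfake detection.

Loads a CLIP vision backbone and a per-type linear classifier ("facial" or
"general") and reports a prediction score plus an attention-map visualisation
for an uploaded image or video.
"""
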
import gradio as gr
import cv2
from PIL import Image
import torch
import numpy as np
import os

from transformers import AutoProcessor, CLIPVisionModel
from detection import detect_image, detect_video
from model import LinearClassifier


def load_model(detection_type):
    """Load the CLIP vision backbone and the linear deepfake classifier
    for the given detection type ("facial" or "general")."""
    device = torch.device("cpu")

    # CLIP ViT-L/14 vision encoder; attention outputs are kept for the attention map.
    processor = AutoProcessor.from_pretrained("clip-vit-large-patch14")
    clip_model = CLIPVisionModel.from_pretrained("clip-vit-large-patch14", output_attentions=True)

    # Infer the classifier's input dimension from the saved linear-layer weights.
    model_path = f"pretrained_models/{detection_type}/clip_weights.pth"
    checkpoint = torch.load(model_path, map_location="cpu")
    input_dim = checkpoint["linear.weight"].shape[1]

    detection_model = LinearClassifier(input_dim)
    detection_model.load_state_dict(checkpoint)
    detection_model = detection_model.to(device)

    return processor, clip_model, detection_model

def process_image(image, detection_type):
    """Run single-image detection and return (prediction score, attention map)."""
    processor, clip_model, detection_model = load_model(detection_type)

    results = detect_image(image, processor, clip_model, detection_model)

    # Invert the raw score so that 0 indicates REAL and 1 indicates FAKE,
    # matching the label on the output textbox.
    pred_score = 1 - results["pred_score"]
    attn_map = results["attn_map"]

    return pred_score, attn_map

def process_video(video, detection_type):
    """Decode a video into PIL frames and run frame-based detection."""
    processor, clip_model, detection_model = load_model(detection_type)

    # Read every frame with OpenCV and convert BGR -> RGB PIL images.
    cap = cv2.VideoCapture(video)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(frame)
        frames.append(pil_image)
    cap.release()

    results = detect_video(frames, processor, clip_model, detection_model)

    pred_score = results["pred_score"]
    attn_map = results["attn_map"]

    return pred_score, attn_map

def change_input(input_type):
    """Toggle visibility of the image/video inputs based on the selected type."""
    if input_type == "Image":
        return gr.update(value=None, visible=True), gr.update(value=None, visible=False)
    elif input_type == "Video":
        return gr.update(value=None, visible=False), gr.update(value=None, visible=True)
    else:
        return gr.update(value=None, visible=False), gr.update(value=None, visible=False)

def determine_model_type(image_path):
    """Pick the model type from the example file path ("Facial" by default)."""
    if "facial" in image_path.lower():
        return "Facial"
    elif "general" in image_path.lower():
        return "General"
    else:
        return "Facial"  # default


def process_input(input_type, model_type, image, video):
    """Dispatch to image or video processing based on the selected input type."""
    detection_type = "facial" if model_type == "Facial" else "general"

    if input_type == "Image" and image is not None:
        return process_image(image, detection_type)
    elif input_type == "Video" and video is not None:
        return process_video(video, detection_type)
    else:
        return None, None


def process_example(image_path):
    """Load an example image and pre-select the matching model type."""
    model_type = determine_model_type(image_path)
    return Image.open(image_path), model_type

# Collect example image paths from the bundled examples/ folders.
fake_examples, real_examples = [], []
for example in os.listdir("examples/fake"):
    fake_examples.append(os.path.join("examples/fake", example))
for example in os.listdir("examples/real"):
    real_examples.append(os.path.join("examples/real", example))

with gr.Blocks() as demo:
  
    gr.Markdown("## Deepfake Detection: Facial / General")
  
    input_type = gr.Radio(["Image", "Video"], label="Choose Input Type", value="Image")

    model_type = gr.Radio(["Facial", "General"], label="Choose Model Type", value="General")

    H, W = 300, 300
    image_input = gr.Image(type="pil", label="Upload Image", visible=True, height=H, width=W)
    video_input = gr.Video(label="Upload Video", visible=False, height=H, width=W)

    process_button = gr.Button("Run Model")

    pred_score_output = gr.Textbox(label="Prediction Score: 0 - REAL, 1 - FAKE")
    attn_map_output = gr.Image(type="pil", label="Attention Map", height=H, width=W)

    # Add example images
    gr.Examples(
        examples=fake_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        cache_examples=False,
        examples_per_page=10,
        label="Fake Examples"
    )
    gr.Examples(
        examples=real_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        cache_examples=False,
        examples_per_page=10,
        label="Real Examples"
    )
  
    input_type.change(fn=change_input, inputs=[input_type], outputs=[image_input, video_input])
  
    process_button.click(
        fn=process_input, 
        inputs=[input_type, model_type, image_input, video_input], 
        outputs=[pred_score_output, attn_map_output]
    )

if __name__ == "__main__":
    demo.launch()
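
# --- Optional programmatic usage (sketch, not executed by the app) ---
# Assumes an example file exists under examples/fake/ (the filename below is
# hypothetical) and that detect_image returns the keys used above.
#
#   from PIL import Image
#   img = Image.open("examples/fake/example.jpg")
#   score, attn = process_image(img, "facial")
#   print(f"Prediction score (0=REAL, 1=FAKE): {score:.3f}")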