# deepfake-detection / detection.py
# Author: wooj0216
# Last commit: c96cf7e ("FIX: video")
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
def vis_attn(image, patch_attention_map, alpha=0.5, vis_option="none"):
    """Overlay a per-patch attention map on an image as a JET heatmap.

    Args:
        image: Input image. A PIL image is converted to RGB first (dropping
            any alpha channel / expanding grayscale). Anything else goes
            through ``np.array`` and is assumed to already be an (H, W, 3)
            uint8 RGB array.
        patch_attention_map: 1-D torch tensor with one weight per patch. Its
            length must be a perfect square (``grid_size ** 2``), because it
            is reshaped into a square grid.
        alpha: Weight of the heatmap in the blend; the image gets
            ``1 - alpha``.
        vis_option: Currently unused; kept for interface compatibility.

    Returns:
        PIL.Image.Image: the RGB image blended with the attention heatmap.

    Raises:
        ValueError: if the number of patches is not a perfect square.
    """
    if isinstance(image, Image.Image):
        # Guarantee 3 channels so the array matches the 3-channel heatmap
        # in cv2.addWeighted (RGBA / grayscale inputs would otherwise fail).
        image = image.convert("RGB")
    image = np.array(image)
    H, W = image.shape[:2]

    seq_len = patch_attention_map.shape[0]
    grid_size = int(round(seq_len ** 0.5))
    if grid_size * grid_size != seq_len:
        raise ValueError(
            f"expected a square number of patches, got {seq_len}"
        )
    attn = patch_attention_map.reshape(grid_size, grid_size)
    # float32: cv2.resize may not accept half-precision arrays.
    attn = attn.detach().float().cpu().numpy()
    # cv2 takes the target size as (width, height).
    attn = cv2.resize(attn, (W, H), interpolation=cv2.INTER_CUBIC)

    # Min-max normalize to [0, 1]; a constant map would otherwise divide by
    # zero and produce NaNs.
    attn_min, attn_max = attn.min(), attn.max()
    if attn_max > attn_min:
        attn = (attn - attn_min) / (attn_max - attn_min)
    else:
        attn = np.zeros_like(attn)
    attn = np.uint8(255 * attn)

    # applyColorMap returns a BGR image; convert it to RGB so it lines up
    # channel-for-channel with the RGB image before blending. (Previously
    # the blend was swapped afterwards, which swapped the photo's colors.)
    heatmap = cv2.applyColorMap(attn, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    blended_image = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)
    return Image.fromarray(blended_image)
def detect_image(image, processor, clip_model, detection_model, *, num_attn_layers=6):
    """Score an image with the detector and visualize the model's CLS attention.

    Args:
        image: Image accepted by ``processor(images=...)``; also passed to
            ``vis_attn`` to build the overlay.
        processor: Callable producing model inputs, called as
            ``processor(images=image, return_tensors="pt")``.
        clip_model: Model called as ``clip_model(**inputs)``. Its output must
            expose ``last_hidden_state`` and a non-empty ``attentions``
            sequence (i.e. it must be configured to return attention weights).
        detection_model: Callable applied to the token-0 embedding; element
            ``[0][0]`` of its output is taken as the prediction score.
        num_attn_layers: Number of leading attention layers summed for the
            visualization (default 6, the previously hard-coded value).

    Returns:
        dict: ``{"pred_score": float in [0, 1], "attn_map": PIL image}``.

    Raises:
        ValueError: if ``clip_model`` returned no attention weights, or
            ``num_attn_layers`` is less than 1.
        AssertionError: if the score falls outside [0, 1].
    """
    if num_attn_layers < 1:
        raise ValueError("num_attn_layers must be >= 1")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = clip_model(**inputs)
        # Token 0 is the CLS token (its attention row is used below).
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        pred_score = float(detection_model(cls_embedding)[0][0].cpu().detach().item())
    assert 0 <= pred_score <= 1

    if not outputs.attentions:
        raise ValueError(
            "clip_model returned no attentions; it must be run with output_attentions enabled"
        )

    # Sum the first `num_attn_layers` layers into a new tensor instead of
    # accumulating in place into outputs.attentions[0].
    summed_attn = torch.stack(outputs.attentions[:num_attn_layers]).sum(dim=0)
    # Average over dim 1 (presumably attention heads) and take batch item 0.
    head_mean_attn = summed_attn.mean(dim=1)[0]
    # Row 0 = attention from the CLS token; drop column 0 to keep only patches.
    cls_attention_map = head_mean_attn[0, 1:]
    blended_image = vis_attn(image, cls_attention_map)
    return {
        "pred_score": pred_score,
        "attn_map": blended_image,
    }
def detect_video(frames, processor, clip_model, detection_model):
    """Run image detection on the first frame of a video.

    Only ``frames[0]`` is scored and visualized; the remaining frames are
    ignored. The work is delegated to ``detect_image`` so the two code paths
    cannot drift apart.

    Args:
        frames: Indexable sequence of frames; each frame must be acceptable
            to ``detect_image``.
        processor: Passed through to ``detect_image``.
        clip_model: Passed through to ``detect_image``.
        detection_model: Passed through to ``detect_image``.

    Returns:
        dict: ``{"pred_score": float, "attn_map": PIL image}`` for the first
        frame, as returned by ``detect_image``.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("frames must contain at least one frame")
    return detect_image(frames[0], processor, clip_model, detection_model)