import spaces
import cv2
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
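
# Overview: this Space scores sliding windows of video frames against two CLIP
# text prompts ("fall" vs. "no fall") and reports the timestamps of windows
# that look like a fall.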

MODEL_NAME = "openai/clip-vit-base-patch32"
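
# Text prompts for zero-shot classification: every frame is compared against a
# "fall" description and a "no fall" description.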
fall_prompt = "A person falling on the ground."
nofall_prompt = "A person standing or walking."


def extract_frames(video_path, target_size=(224, 224)):
    """
    Extract all frames from the uploaded video and convert them to PIL Images.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes frames as BGR; convert to RGB and resize for CLIP.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, target_size)
        frames.append(Image.fromarray(frame))
    cap.release()
    return frames
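

# @spaces.GPU requests a Hugging Face ZeroGPU device while this function runs
# (for up to `duration` seconds); without CUDA the code falls back to CPU.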
@spaces.GPU(duration=600)
def process_window(frames_window):
    """
    Process a window of frames and compute the average fall score.
    """
    # The processor and model are re-instantiated on every window call; the
    # weights are cached on disk after the first download, which keeps this
    # simple for a demo (hoisting them to module level would be faster).
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    model = CLIPModel.from_pretrained(MODEL_NAME)

    text_inputs = processor(text=[fall_prompt, nofall_prompt], return_tensors="pt", padding=True)
    inputs = processor(images=frames_window, return_tensors="pt", padding=True)

    if torch.cuda.is_available():
        model = model.to(torch.device("cuda"))
        text_inputs = text_inputs.to(torch.device("cuda"))
        inputs = {k: v.cuda() for k, v in inputs.items()}

    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
        text_features = model.get_text_features(**text_inputs)

        # Normalize so the dot product below is cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Turn the similarities into per-frame probabilities over the two prompts
        # (standard CLIP zero-shot classification), so the detection threshold in
        # detect_fall() is a probability rather than a raw cosine similarity.
        logits = model.logit_scale.exp() * (image_features @ text_features.T)
        probs = logits.softmax(dim=-1).cpu().numpy()

    fall_scores = probs[:, 0]
    window_score = float(np.mean(fall_scores))
    return window_score, fall_scores
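

# Sliding-window detection: windows of `window_size` frames, advanced by
# `stride` frames, are scored independently; a window whose mean fall score
# exceeds `threshold` is reported at its start time (frame index / fps).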
def detect_fall(video_path, window_size=16, stride=8, threshold=0.8, fps=15):
    """
    Slide a window over the extracted frames and score each window.
    Returns a summary of detected fall timestamps and a representative frame.
    """
    frames = extract_frames(video_path)
    if len(frames) < window_size:
        return "Video too short for inference.", None

    window_scores = []
    window_indices = []
    for start in range(0, len(frames) - window_size + 1, stride):
        window = frames[start:start + window_size]
        score, _ = process_window(window)
        window_scores.append(score)
        window_indices.append(start)

    # Convert window start indices to timestamps, assuming the given frame rate.
    detected_events = []
    for idx, score in zip(window_indices, window_scores):
        if score > threshold:
            detected_events.append(idx / fps)

    if detected_events:
        result_text = "Fall events detected at (sec): " + ", ".join(f"{t:.1f}" for t in detected_events)
    else:
        result_text = "No fall detected."

    rep_frame = frames[len(frames) // 2]
    return result_text, rep_frame
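

# Gradio callback: the Video component passes the uploaded clip as a file path.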
def process_video(video_file):
    result_text, rep_frame = detect_fall(video_file)
    return result_text, rep_frame


demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[gr.Textbox(label="Detection Results"), gr.Image(label="Representative Frame")],
    title="LightCLIP Fall Detection Demo",
    description=(
        "This demo detects human falls in video clips using a lightweight transformer-based model (LightCLIP). "
        "A sliding-window approach aggregates scores over multiple frames to improve precision in complex scenes."
    ),
)
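

# Start the Gradio server when the script is executed directly; locally,
# `share=True` can be passed to `launch()` for a temporary public link.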
if __name__ == "__main__":
    demo.launch()