import spaces
import cv2
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
# Model checkpoint for zero-shot scoring; loaded lazily inside process_window().
# Replace "openai/clip-vit-base-patch32" with your LightCLIP checkpoint if available.
MODEL_NAME = "openai/clip-vit-base-patch32"
# Define text prompts for fall and non-fall.
fall_prompt = "A person falling on the ground."
nofall_prompt = "A person standing or walking."
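# Prompt order matters: downstream code treats index 0 as the fall class.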
def extract_frames(video_path, target_size=(224, 224)):
    """
    Extract all frames from the uploaded video as resized RGB PIL Images,
    and return them together with the source frame rate (0 if unknown).
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Convert frame from BGR to RGB and resize to the CLIP input size.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, target_size)
        frames.append(Image.fromarray(frame))
    cap.release()
    return frames, fps
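# ZeroGPU: @spaces.GPU reserves a GPU for the decorated call; duration=600
# allows up to 10 minutes per invocation.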
@spaces.GPU(duration=600)
def process_window(frames_window):
    """
    Score a window of frames against the fall / no-fall prompts and return
    the mean fall probability along with the per-frame probabilities.
    """
    # Load processor and model inside the GPU-decorated call so weights only
    # occupy the ZeroGPU slot while a window is being scored.
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    model = CLIPModel.from_pretrained(MODEL_NAME)
    text_inputs = processor(text=[fall_prompt, nofall_prompt], return_tensors="pt", padding=True)
    inputs = processor(images=frames_window, return_tensors="pt", padding=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    text_inputs = text_inputs.to(device)
    inputs = inputs.to(device)
    with torch.no_grad():
        # Normalize embeddings so the dot product below is cosine similarity.
        image_features = model.get_image_features(**inputs)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = model.get_text_features(**text_inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # Scale by CLIP's learned temperature and softmax over the two prompts,
        # giving each frame a fall probability in [0, 1]; index 0 is the fall
        # prompt, so the scores are comparable to the 0.8 threshold below.
        logits = model.logit_scale.exp() * (image_features @ text_features.T)
        probs = logits.softmax(dim=-1).cpu().numpy()  # shape: (num_frames, 2)
    fall_scores = probs[:, 0]
    window_score = float(np.mean(fall_scores))
    return window_score, fall_scores
def detect_fall(video_path, window_size=16, stride=8, threshold=0.8, fallback_fps=15):
    """
    Score the video with a sliding window over its frames.
    Returns a summary string of detected fall timestamps and a
    representative frame for visual reference.
    """
    frames, fps = extract_frames(video_path)
    if len(frames) < window_size:
        return "Video too short for inference.", None
    if not fps or fps <= 0:
        fps = fallback_fps  # frame rate unavailable; use a rough default
    window_scores = []
    window_indices = []
    for start in range(0, len(frames) - window_size + 1, stride):
        window = frames[start:start + window_size]
        score, _ = process_window(window)
        window_scores.append(score)
        window_indices.append(start)
    detected_events = []
    for idx, score in zip(window_indices, window_scores):
        if score > threshold:
            time_sec = idx / fps  # timestamp of the window's first frame
            detected_events.append(time_sec)
    if detected_events:
        result_text = "Fall events detected at (sec): " + ", ".join(f"{t:.1f}" for t in detected_events)
    else:
        result_text = "No fall detected."
    # Return the result and the video's middle frame for visual reference.
    rep_frame = frames[len(frames) // 2]
    return result_text, rep_frame
def process_video(video_file):
    # Gradio passes the uploaded video as a filepath string.
    result_text, rep_frame = detect_fall(video_file)
    return result_text, rep_frame
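# For a quick local check without the UI, something like the following works
# (the clip path is illustrative):
#     text, frame = detect_fall("sample_fall.mp4")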
# Gradio interface definition.
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[gr.Textbox(label="Detection Results"), gr.Image(label="Representative Frame")],
    title="LightCLIP Fall Detection Demo",
    description=(
        "This demo detects human falls in video clips using a lightweight transformer-based model (LightCLIP). "
        "A sliding-window approach aggregates results over multiple frames to improve precision in complex scenes."
    ),
)
if __name__ == "__main__":
    demo.launch()