Spaces: Running on Zero
fall detection with clip
app.py CHANGED
@@ -1,37 +1,116 @@
(previous app.py, 37 lines, replaced by the version below; the old file imported torch, gradio, and the transformers CLIPProcessor/CLIPModel classes, wrapped its inference code in a single @spaces.GPU-decorated function, and ended with an if __name__ == "__main__": block)

import spaces
import cv2
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import CLIPProcessor, CLIPModel

# Load LightCLIP model.
# Replace "openai/clip-vit-base-patch32" with your LightCLIP model checkpoint if available.
MODEL_NAME = "openai/clip-vit-base-patch32"

# Define text prompts for fall and non-fall.
fall_prompt = "A person falling on the ground."
nofall_prompt = "A person standing or walking."

# Device placement happens inside the @spaces.GPU-decorated function below,
# where the model and its inputs are moved to CUDA together when a GPU is available.

def extract_frames(video_path, target_size=(224, 224)):
    """
    Extract all frames from the uploaded video and convert them to PIL Images.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Convert frame from BGR to RGB and resize.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, target_size)
        frames.append(Image.fromarray(frame))
    cap.release()
    return frames

@spaces.GPU
def process_window(frames_window):
    """
    Process a window of frames and compute the average fall score.
    """
    # Loading the processor and model here keeps the GPU idle until this
    # @spaces.GPU call runs; cache them at module level if per-window latency matters.
    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
    model = CLIPModel.from_pretrained(MODEL_NAME)
    text_inputs = processor(text=[fall_prompt, nofall_prompt], return_tensors="pt", padding=True)

    inputs = processor(images=frames_window, return_tensors="pt", padding=True)
    if torch.cuda.is_available():
        model = model.cuda()
        inputs = {k: v.cuda() for k, v in inputs.items()}
        text_inputs = {k: v.cuda() for k, v in text_inputs.items()}

    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
        # Normalize embeddings.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Cosine similarity between each frame and the two prompts, shape (num_frames, 2).
    sims = image_features @ text_features.T
    # Raw CLIP similarities rarely exceed ~0.4, so scale by the model's logit scale and
    # softmax over the two prompts to get a per-frame fall probability; this keeps the
    # 0.8 threshold in detect_fall meaningful. Index 0 corresponds to the fall prompt.
    probs = (sims * model.logit_scale.exp()).softmax(dim=-1).cpu().numpy()
    fall_scores = probs[:, 0]
    window_score = float(np.mean(fall_scores))
    return window_score, fall_scores

def detect_fall(video_path, window_size=16, stride=8, threshold=0.8, fps=15):
    """
    Process the video file using a sliding window over frames.
    Returns a result string listing the timestamps where a fall is detected,
    plus a representative frame for display.
    """
    frames = extract_frames(video_path)
    if len(frames) < window_size:
        return "Video too short for inference.", None

    window_scores = []
    window_indices = []
    for start in range(0, len(frames) - window_size + 1, stride):
        window = frames[start:start + window_size]
        score, _ = process_window(window)
        window_scores.append(score)
        window_indices.append(start)

    detected_events = []
    for idx, score in zip(window_indices, window_scores):
        if score > threshold:
            time_sec = idx / fps  # approximate timestamp, assuming a fixed frame rate
            detected_events.append(time_sec)

    if detected_events:
        result_text = "Fall events detected at (sec): " + ", ".join(f"{t:.1f}" for t in detected_events)
    else:
        result_text = "No fall detected."
    # Return the result and a representative frame for visual reference.
    rep_frame = frames[len(frames) // 2]
    return result_text, rep_frame

def process_video(video_file):
    result_text, rep_frame = detect_fall(video_file)
    return result_text, rep_frame

# Gradio interface definition.
demo = gr.Interface(
    fn=process_video,
    # gr.Video hands the uploaded clip to the function as a file path by default.
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[gr.Textbox(label="Detection Results"), gr.Image(label="Representative Frame")],
    title="LightCLIP Fall Detection Demo",
    description=(
        "This demo detects human falls in video clips using a lightweight transformer-based model (LightCLIP). "
        "A sliding window approach aggregates results over multiple frames to improve precision in complex scenes."
    ),
)

if __name__ == "__main__":
    demo.launch()
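
For a quick check of the detection pipeline outside the Gradio UI, the sketch below calls detect_fall directly. It assumes app.py is importable from the same directory and that a short test clip exists at sample.mp4 (a hypothetical placeholder path); the keyword values simply mirror the defaults shown above. The Space presumably also ships a requirements.txt along the lines of spaces, gradio, transformers, torch, opencv-python, pillow, and numpy, matching the imports.

# smoke_test.py -- local sanity check (a sketch, not part of the Space itself).
# Assumes app.py sits alongside this file and that sample.mp4 is a placeholder
# path to any short test clip.
from app import detect_fall

result_text, rep_frame = detect_fall(
    "sample.mp4",    # hypothetical input clip
    window_size=16,  # defaults mirrored from app.py
    stride=8,
    threshold=0.8,
    fps=15,          # timestamps assume this fixed frame rate
)
print(result_text)
if rep_frame is not None:
    rep_frame.save("representative_frame.png")

Importing app builds the gr.Interface but does not start a server, since demo.launch() stays behind the __main__ guard.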