Luigi committed
Commit 6a22aef · 1 Parent(s): d5cb4e0

fall detection with clip

Files changed (1)
  1. app.py +107 -28
app.py CHANGED
@@ -1,37 +1,116 @@
 import torch
-import spaces  # Import early to avoid potential issues
 import gradio as gr
 from transformers import CLIPProcessor, CLIPModel
 
-# Load the CLIP model and processor on the CPU initially
-model_name = "openai/clip-vit-base-patch32"
 
 @spaces.GPU
-def clip_similarity(image, text):
-    # Load the model and processor inside GPU context
-    model = CLIPModel.from_pretrained(model_name)
-    processor = CLIPProcessor.from_pretrained(model_name)
-
-    device = torch.device("cuda")
-    model.to(device)
-
-    inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
-    inputs = {k: v.to(device) for k, v in inputs.items()}
-
-    outputs = model(**inputs)
-    similarity_score = outputs.logits_per_image.detach().cpu().numpy()[0]
-    return float(similarity_score)
-
-# Set up the Gradio interface
-iface = gr.Interface(
-    fn=clip_similarity,
-    inputs=[
-        gr.Image(type="pil", label="Upload Image"),
-        gr.Text(label="Input Text")
-    ],
-    outputs=gr.Number(label="Similarity Score"),
-    title="CLIP Similarity Demo with ZeroGPU"
 )
 
 if __name__ == "__main__":
-    iface.launch()
 
+import spaces
+import cv2
 import torch
+import numpy as np
+from PIL import Image
 import gradio as gr
 from transformers import CLIPProcessor, CLIPModel
 
+# Load LightCLIP model.
+# Replace "openai/clip-vit-base-patch32" with your LightCLIP model checkpoint if available.
+MODEL_NAME = "openai/clip-vit-base-patch32"
+
+# Define text prompts for fall and non-fall.
+fall_prompt = "A person falling on the ground."
+nofall_prompt = "A person standing or walking."
+
+def extract_frames(video_path, target_size=(224, 224)):
+    """
+    Extract all frames from the uploaded video and convert them to PIL Images.
+    """
+    cap = cv2.VideoCapture(video_path)
+    frames = []
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        # Convert frame from BGR to RGB and resize.
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frame = cv2.resize(frame, target_size)
+        frames.append(Image.fromarray(frame))
+    cap.release()
+    return frames
 
 @spaces.GPU
+def process_window(frames_window):
+    """
+    Process a window of frames and compute the average fall score.
+    """
+    # Load the model and processor inside the GPU context and keep everything
+    # on the same device.
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    processor = CLIPProcessor.from_pretrained(MODEL_NAME)
+    model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
+
+    text_inputs = processor(text=[fall_prompt, nofall_prompt], return_tensors="pt", padding=True)
+    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
+
+    inputs = processor(images=frames_window, return_tensors="pt", padding=True)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        image_features = model.get_image_features(**inputs)
+        # Normalize embeddings.
+        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+    with torch.no_grad():
+        text_features = model.get_text_features(**text_inputs)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+    # Compute cosine similarity.
+    sims = (image_features @ text_features.T).cpu().numpy()  # shape: (num_frames, 2)
+    # We assume index 0 is for the fall prompt.
+    fall_scores = sims[:, 0]
+    window_score = np.mean(fall_scores)
+    return window_score, fall_scores
+
+def detect_fall(video_path, window_size=16, stride=8, threshold=0.8, fps=15):
+    """
+    Process the video file using a sliding window over frames.
+    Returns a list of timestamps where a fall is detected.
+    """
+    frames = extract_frames(video_path)
+    if len(frames) < window_size:
+        return "Video too short for inference.", None
+
+    window_scores = []
+    window_indices = []
+    for start in range(0, len(frames) - window_size + 1, stride):
+        window = frames[start:start + window_size]
+        score, _ = process_window(window)
+        window_scores.append(score)
+        window_indices.append(start)
+
+    detected_events = []
+    for idx, score in zip(window_indices, window_scores):
+        if score > threshold:
+            time_sec = idx / fps  # approximate timestamp
+            detected_events.append(time_sec)
+
+    if detected_events:
+        result_text = "Fall events detected at (sec): " + ", ".join(f"{t:.1f}" for t in detected_events)
+    else:
+        result_text = "No fall detected."
+    # Return result and a representative frame for visual reference.
+    rep_frame = frames[len(frames) // 2]
+    return result_text, rep_frame
+
+def process_video(video_file):
+    result_text, rep_frame = detect_fall(video_file)
+    return result_text, rep_frame
+
+# Gradio interface definition.
+demo = gr.Interface(
+    fn=process_video,
+    # gr.Video passes the uploaded clip to the function as a file path.
+    inputs=gr.Video(label="Upload Video Clip"),
+    outputs=[gr.Textbox(label="Detection Results"), gr.Image(label="Representative Frame")],
+    title="LightCLIP Fall Detection Demo",
+    description=(
+        "This demo detects human falls in video clips using a lightweight transformer-based model (LightCLIP). "
+        "A sliding window approach aggregates results over multiple frames to improve precision in complex scenes."
+    )
 )
 
 if __name__ == "__main__":
+    demo.launch()
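
A side note on trying the committed code outside the Space: CLIP image–text cosine similarities usually land well below 0.8, so the default threshold=0.8 in detect_fall may rarely fire on real clips. The sketch below is a rough local check, not part of the commit; it assumes the file above is saved as app.py, the spaces package is installed, and a test clip exists at the hypothetical path "sample.mp4". It also outlines one possible variant that softmaxes the fall / no-fall similarities into a probability, which would require process_window to return the full (num_frames, 2) similarity matrix rather than only the fall column.

# Rough local check (hypothetical path and threshold; not part of the commit).
import numpy as np
from app import detect_fall

result_text, rep_frame = detect_fall("sample.mp4", threshold=0.25)  # 0.25 is only an illustrative value
print(result_text)

# Possible scoring variant: softmax the per-frame (fall, no-fall) similarities
# into a fall probability. Assumes process_window were changed to return the
# full (num_frames, 2) similarity matrix instead of only the fall column.
def window_fall_probability(sims, temperature=100.0):
    logits = temperature * np.asarray(sims)                     # CLIP-style logit scaling (assumed scale)
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))   # numerically stable softmax
    probs = exp / exp.sum(axis=-1, keepdims=True)
    return float(probs[:, 0].mean())                            # mean probability of the fall prompt

Whichever scoring is used, the threshold would still need calibrating on a handful of labeled clips.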