LEGENDCODER1 commited on
Commit
5b5f841
·
verified ·
1 Parent(s): 1269013

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -25
app.py CHANGED
@@ -1,34 +1,23 @@
1
- import cv2
2
- import numpy as np
3
- import mediapipe as mp
4
  import gradio as gr
5
 
6
- # Initialize MediaPipe Pose model
7
- mp_pose = mp.solutions.pose
8
- pose = mp_pose.Pose()
9
 
10
- # Function to detect pose
11
- def detect_pose(image):
12
- # Convert image from RGB to BGR for OpenCV processing
13
- image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
14
 
15
- # Run MediaPipe Pose model
16
- results = pose.process(image_rgb)
17
-
18
- if results.pose_landmarks:
19
- return "Person Detected: Standing or Sitting Pose Identified"
20
- else:
21
- return "No person detected, please try again"
22
-
23
- # Gradio Interface
24
  interface = gr.Interface(
25
- fn=detect_pose,
26
- inputs=gr.Image(type="pil"), # Accepts an image as input
27
- outputs="text", # Outputs the detected pose
28
- title="Pose Detection for Exoskeleton",
29
- description="Upload an image of a person sitting or standing. The model will determine their pose."
30
  )
31
 
32
- # Launch the Gradio App
33
  if __name__ == "__main__":
34
  interface.launch(server_name="0.0.0.0")
 
1
+ from transformers import pipeline
 
 
2
  import gradio as gr
3
 
4
+ # Load a pre-trained image classification model
5
+ model = pipeline("image-classification", model="google/vit-base-patch16-224")
 
6
 
7
+ # Define a function for detecting actions
8
+ def classify_image(image):
9
+ predictions = model(image)
10
+ return {pred["label"]: round(pred["score"], 4) for pred in predictions}
11
 
12
+ # Gradio interface
 
 
 
 
 
 
 
 
13
  interface = gr.Interface(
14
+ fn=classify_image,
15
+ inputs=gr.Image(type="pil"), # Accepts image input
16
+ outputs="json", # Outputs predictions
17
+ title="Action Classifier",
18
+ description="Upload an image, and the model will classify actions (e.g., standing, sitting)."
19
  )
20
 
21
+ # Launch the app
22
  if __name__ == "__main__":
23
  interface.launch(server_name="0.0.0.0")