import os
import hashlib
import requests
import numpy as np
from decord import VideoReader, cpu
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import gradio as gr
# ---------------------------------------------------
# 1. Set Up Device: Use Apple's MPS if available, else CPU
# ---------------------------------------------------
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
# For MPS, we can try using float16 to reduce memory usage.
torch_dtype = torch.float16 if device == "mps" else torch.float32
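# Note (assumption, not benchmarked here): float16 roughly halves memory use
# versus float32 on MPS; if you see NaNs or garbled generations, switching
# back to torch.float32 is the safe fallback.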
# ---------------------------------------------------
# 2. Initialize the Qwen 2.5 VL Model (3B) for Local Use
# ---------------------------------------------------
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch_dtype
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_path)
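# The AutoProcessor for this checkpoint bundles the tokenizer and the
# image/video preprocessor, so the same object applies the chat template and
# resizes the sampled frames for the model.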
# ---------------------------------------------------
# 3. Utility Functions for Video Processing
# ---------------------------------------------------
def download_video(url, dest_path):
    """
    Downloads a video from a URL.
    (This function is kept here in case you ever need to download via URL.)
    """
    response = requests.get(url, stream=True)
    response.raise_for_status()  # fail fast on HTTP errors instead of writing an empty file
    with open(dest_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    print(f"Video downloaded to {dest_path}")
def get_video_frames(video_path, num_frames=64, cache_dir='.cache'):
    """
    Extract frames and timestamps from a video file.
    If video_path is a URL, it downloads it; otherwise it assumes a local file.
    Caching is used to avoid re-processing.
    """
    os.makedirs(cache_dir, exist_ok=True)
    video_hash = hashlib.md5(video_path.encode('utf-8')).hexdigest()
    # If the path starts with 'http', download the file.
    if video_path.startswith("http"):
        video_file_path = os.path.join(cache_dir, f"{video_hash}.mp4")
        if not os.path.exists(video_file_path):
            print("Downloading video using requests...")
            download_video(video_path, video_file_path)
    else:
        video_file_path = video_path
    frames_cache_file = os.path.join(cache_dir, f"{video_hash}_{num_frames}_frames.npy")
    timestamps_cache_file = os.path.join(cache_dir, f"{video_hash}_{num_frames}_timestamps.npy")
    if os.path.exists(frames_cache_file) and os.path.exists(timestamps_cache_file):
        frames = np.load(frames_cache_file)
        timestamps = np.load(timestamps_cache_file)
        return video_file_path, frames, timestamps
    # Load the video using decord and sample num_frames evenly spaced frames.
    vr = VideoReader(video_file_path, ctx=cpu(0))
    total_frames = len(vr)
    indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
    frames = vr.get_batch(indices).asnumpy()
    timestamps = np.array([vr.get_frame_timestamp(idx) for idx in indices])
    # Cache the frames and timestamps for subsequent runs.
    np.save(frames_cache_file, frames)
    np.save(timestamps_cache_file, timestamps)
    return video_file_path, frames, timestamps
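# Example (hypothetical local file):
#   path, frames, ts = get_video_frames("example.mp4", num_frames=64)
# decord's get_batch returns frames as a (num_frames, H, W, 3) uint8 array,
# and get_frame_timestamp yields a (start, end) pair in seconds per frame.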
# ---------------------------------------------------
# 4. Inference Function Using Qwen 2.5 VL (3B)
# ---------------------------------------------------
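# Note on the pixel budgets in the signature below: Qwen 2.5 VL resizes video
# frames so their visual token count stays within bounds. Following the Qwen
# examples, `total_pixels` caps the combined pixel budget across all sampled
# frames and `min_pixels` sets a per-frame floor; 28 * 28 corresponds to the
# model's patch unit. Lowering `total_pixels` is the first knob to turn if you
# run out of memory.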
def inference(video_path, prompt, max_new_tokens=2048, total_pixels=20480 * 28 * 28, min_pixels=16 * 28 * 28):
    """
    Prepares the input with the prompt and video metadata,
    processes the video inputs, and runs inference through the model.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"video": video_path, "total_pixels": total_pixels, "min_pixels": min_pixels},
        ]},
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = process_vision_info([messages], return_video_kwargs=True)
    fps_inputs = video_kwargs["fps"]
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        fps=fps_inputs,
        padding=True,
        return_tensors="pt"
    )
    # Move inputs to our chosen device (MPS or CPU).
    inputs = inputs.to(device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Strip the prompt tokens so only the newly generated text is decoded.
    generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return output_text[0]
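# Example call (hypothetical local clip):
#   chapters = inference("example.mp4", "Split this video into chapters with timestamps.")
#   print(chapters)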
# ---------------------------------------------------
# 5. Define Sample Prompts
# ---------------------------------------------------
sample_prompts = [
    "Please analyze the video and split it into chapters with timestamps and descriptive titles in the format 'mm:ss Title'.",
    "Provide a breakdown of the video's content by segment, including starting times and summaries.",
    "Segment the video into logical chapters and output the start time and a brief description for each chapter.",
]
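# For the first sample prompt, the model is expected to answer in a form like
# (illustrative only, not a captured output):
#   00:00 Introduction
#   01:25 Main walkthrough
#   03:40 Summary and closing remarks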
# ---------------------------------------------------
# 6. Main Processing Function for the Gradio Interface
# ---------------------------------------------------
def process_video(video_file, custom_prompt, sample_prompt):
    """
    Called when the user clicks 'Process Video'.
    Uses a custom prompt (if provided) or the selected sample prompt.
    Processes the uploaded video and runs inference.
    """
    final_prompt = custom_prompt.strip() if custom_prompt.strip() != "" else sample_prompt
    try:
        # Here, video_file is the local file path from the uploader.
        video_path, frames, timestamps = get_video_frames(video_file, num_frames=64)
    except Exception as e:
        return f"Error processing video: {str(e)}"
    try:
        output = inference(video_path, final_prompt)
    except Exception as e:
        return f"Error during inference: {str(e)}"
    return output
# ---------------------------------------------------
# 7. Build the Gradio Interface for Local Use
# ---------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Video Chapter Splitter using Qwen 2.5 VL (3B) on Mac")
    gr.Markdown("Upload a video file and either type a custom prompt or select one of the sample prompts. Then click **Process Video** to generate the chapter breakdown.")
    with gr.Row():
        video_input = gr.Video(label="Upload Video")
    with gr.Row():
        custom_prompt_input = gr.Textbox(label="Custom Prompt", placeholder="Enter custom prompt (optional)...", lines=2)
    with gr.Row():
        sample_prompt_input = gr.Dropdown(label="Sample Prompts", choices=sample_prompts, value=sample_prompts[0])
    output_text = gr.Textbox(label="Output", lines=10)
    run_button = gr.Button("Process Video")
    run_button.click(fn=process_video, inputs=[video_input, custom_prompt_input, sample_prompt_input], outputs=output_text)

if __name__ == "__main__":
    demo.launch()
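# Tip: demo.launch(share=True) exposes a temporary public URL, and
# demo.launch(server_name="0.0.0.0") makes the app reachable on the local network.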