# video_splitter / local_video_understant_app.py
import os
import hashlib
import requests
import numpy as np
from PIL import Image
import decord
from decord import VideoReader, cpu
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import gradio as gr
# ---------------------------------------------------
# 1. Set Up Device: Use Apple's MPS if available, else CPU
# ---------------------------------------------------
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
# For MPS, we can try using float16 to reduce memory usage.
torch_dtype = torch.float16 if device == "mps" else torch.float32
# ---------------------------------------------------
# 2. Initialize the Qwen 2.5 VL Model (3B) for Local Use
# ---------------------------------------------------
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch_dtype
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_path)
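# Optional: the Qwen2.5-VL processor also accepts pixel bounds at load time, which caps the
# visual-token count (and memory use) per image/frame. A minimal sketch; the values below are
# the ones commonly shown in the Qwen2.5-VL examples, not tuned for this app:
#
#   processor = AutoProcessor.from_pretrained(
#       model_path,
#       min_pixels=256 * 28 * 28,
#       max_pixels=1280 * 28 * 28,
#   )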
# ---------------------------------------------------
# 3. Utility Functions for Video Processing
# ---------------------------------------------------
def download_video(url, dest_path):
    """
    Download a video from a URL to dest_path.
    (Kept here in case you ever need to fetch a video by URL.)
    """
    response = requests.get(url, stream=True)
    response.raise_for_status()  # fail early on HTTP errors
    with open(dest_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    print(f"Video downloaded to {dest_path}")
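# Example (hypothetical URL and destination path):
#   download_video("https://example.com/clip.mp4", ".cache/clip.mp4")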
def get_video_frames(video_path, num_frames=64, cache_dir='.cache'):
    """
    Extract frames and timestamps from a video file.
    If video_path is a URL, it is downloaded first; otherwise it is assumed to be a local file.
    Results are cached on disk to avoid re-processing the same video.
    """
    os.makedirs(cache_dir, exist_ok=True)
    video_hash = hashlib.md5(video_path.encode('utf-8')).hexdigest()

    # If the path starts with 'http', download the file into the cache.
    if video_path.startswith("http"):
        video_file_path = os.path.join(cache_dir, f"{video_hash}.mp4")
        if not os.path.exists(video_file_path):
            print("Downloading video using requests...")
            download_video(video_path, video_file_path)
    else:
        video_file_path = video_path

    frames_cache_file = os.path.join(cache_dir, f"{video_hash}_{num_frames}_frames.npy")
    timestamps_cache_file = os.path.join(cache_dir, f"{video_hash}_{num_frames}_timestamps.npy")

    # Return cached frames and timestamps if they already exist.
    if os.path.exists(frames_cache_file) and os.path.exists(timestamps_cache_file):
        frames = np.load(frames_cache_file)
        timestamps = np.load(timestamps_cache_file)
        return video_file_path, frames, timestamps

    # Load the video with decord and sample num_frames evenly spaced frames.
    vr = VideoReader(video_file_path, ctx=cpu(0))
    total_frames = len(vr)
    indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
    frames = vr.get_batch(indices).asnumpy()
    timestamps = np.array([vr.get_frame_timestamp(idx) for idx in indices])

    # Cache the frames and timestamps for subsequent runs.
    np.save(frames_cache_file, frames)
    np.save(timestamps_cache_file, timestamps)
    return video_file_path, frames, timestamps
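# Example usage (hypothetical local path): sample 64 evenly spaced frames and cache them under .cache/.
#   path, frames, timestamps = get_video_frames("videos/demo.mp4", num_frames=64)
#   frames.shape      -> (64, H, W, 3) uint8 array from decord
#   timestamps[0]     -> (start, end) time in seconds of the first sampled frame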
# ---------------------------------------------------
# 4. Inference Function Using Qwen 2.5 VL (3B)
# ---------------------------------------------------
def inference(video_path, prompt, max_new_tokens=2048, total_pixels=20480 * 28 * 28, min_pixels=16 * 28 * 28):
    """
    Prepares the input with the prompt and video metadata,
    processes the video inputs, and runs inference through the model.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"video": video_path, "total_pixels": total_pixels, "min_pixels": min_pixels},
        ]},
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = process_vision_info([messages], return_video_kwargs=True)
    fps_inputs = video_kwargs["fps"]
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        fps=fps_inputs,
        padding=True,
        return_tensors="pt"
    )
    # Move inputs to our chosen device (MPS or CPU).
    inputs = inputs.to(device)

    # Generate, then strip the prompt tokens from each sequence before decoding.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return output_text[0]
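# Example usage (hypothetical file): ask for a chapter breakdown of a single local video.
#   result = inference("videos/demo.mp4", "Split this video into chapters in the format 'mm:ss Title'.")
#   print(result)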
# ---------------------------------------------------
# 5. Define Sample Prompts
# ---------------------------------------------------
sample_prompts = [
    "Please analyze the video and split it into chapters with timestamps and descriptive titles in the format 'mm:ss Title'.",
    "Provide a breakdown of the video's content by segment, including starting times and summaries.",
    "Segment the video into logical chapters and output the start time and a brief description for each chapter.",
]
# ---------------------------------------------------
# 6. Main Processing Function for the Gradio Interface
# ---------------------------------------------------
def process_video(video_file, custom_prompt, sample_prompt):
    """
    Called when the user clicks 'Process Video'.
    Uses the custom prompt if one was provided, otherwise the selected sample prompt.
    Extracts frames from the uploaded video and runs inference.
    """
    custom_prompt = (custom_prompt or "").strip()
    final_prompt = custom_prompt if custom_prompt else sample_prompt
    try:
        # video_file is the local file path provided by the Gradio uploader.
        video_path, frames, timestamps = get_video_frames(video_file, num_frames=64)
    except Exception as e:
        return f"Error processing video: {str(e)}"
    try:
        output = inference(video_path, final_prompt)
    except Exception as e:
        return f"Error during inference: {str(e)}"
    return output
# ---------------------------------------------------
# 7. Build the Gradio Interface for Local Use
# ---------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Video Chapter Splitter using Qwen 2.5 VL (3B) on Mac")
    gr.Markdown("Upload a video file and either type a custom prompt or select one of the sample prompts. Then click **Process Video** to generate the chapter breakdown.")
    with gr.Row():
        video_input = gr.Video(label="Upload Video")
    with gr.Row():
        custom_prompt_input = gr.Textbox(label="Custom Prompt", placeholder="Enter custom prompt (optional)...", lines=2)
    with gr.Row():
        sample_prompt_input = gr.Dropdown(label="Sample Prompts", choices=sample_prompts, value=sample_prompts[0])
    output_text = gr.Textbox(label="Output", lines=10)
    run_button = gr.Button("Process Video")
    run_button.click(fn=process_video, inputs=[video_input, custom_prompt_input, sample_prompt_input], outputs=output_text)
if __name__ == "__main__":
    demo.launch()
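# To try it locally: run `python local_video_understant_app.py` and open the URL Gradio prints
# (http://127.0.0.1:7860 by default). Pass share=True to demo.launch() for a temporary public link.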