Dhan98 committed on
Commit
dfd6e31
·
verified ·
1 Parent(s): 35304a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -121
app.py CHANGED
@@ -1,140 +1,132 @@
1
- # app.py
2
  import streamlit as st
3
- from transformers import BlipProcessor, BlipForConditionalGeneration
4
- from diffusers import DiffusionPipeline
5
  import torch
6
- import cv2
7
- import numpy as np
8
  from PIL import Image
 
 
 
9
  import tempfile
10
  import os
11
 
12
# Page-level Streamlit settings: browser tab title, tab icon, and a wide layout.
st.set_page_config(page_title="Video Generator", page_icon="🎥", layout="wide")
18
-
19
@st.cache_resource
def load_models():
    """Load the text-to-video pipeline and the BLIP captioning model.

    The result is wrapped in ``st.cache_resource`` so the models are built
    once and reused instead of being reloaded on each call.

    Returns:
        tuple: ``(pipeline, blip, blip_processor)`` where ``pipeline`` is the
        zeroscope text-to-video ``DiffusionPipeline``, ``blip`` is the BLIP
        captioning model and ``blip_processor`` is its matching processor.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision is only requested on GPU. On the CPU fallback float16
    # kernels are poorly supported, so use float32 there instead.
    dtype = torch.float16 if use_cuda else torch.float32

    # Text-to-video model.
    pipeline = DiffusionPipeline.from_pretrained(
        "cerspense/zeroscope_v2_576w",
        torch_dtype=dtype,
    )
    pipeline.to(device)

    # Image captioning model and its processor.
    blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip.to(device)

    return pipeline, blip, blip_processor
41
-
42
def enhance_image(image):
    """Return a brighter, higher-contrast copy of a PIL image.

    Args:
        image: PIL Image to enhance.

    Returns:
        A new PIL Image whose pixel values are ``|1.2 * value + 10|``,
        saturated to the 8-bit range by ``cv2.convertScaleAbs``.
    """
    # Uploaded PNGs may be palette ("P") or carry an alpha channel. Scaling
    # palette indices or alpha as if they were colour values corrupts the
    # picture, so normalise everything except plain RGB/greyscale to RGB.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    pixels = np.array(image)

    # Basic enhancement: alpha scales contrast, beta shifts brightness.
    enhanced = cv2.convertScaleAbs(pixels, alpha=1.2, beta=10)

    return Image.fromarray(enhanced)
50
-
51
def get_description(image, blip, blip_processor):
    """Caption an image with BLIP.

    Args:
        image: Image to describe.
        blip: BLIP conditional-generation model.
        blip_processor: Processor used to encode the image and decode tokens.

    Returns:
        The decoded caption string (special tokens stripped).
    """
    encoded = blip_processor(image, return_tensors="pt")

    # Move the encoded tensors to the GPU when one is available.
    if torch.cuda.is_available():
        encoded = {name: tensor.to("cuda") for name, tensor in encoded.items()}

    # Inference only: no gradients needed.
    with torch.no_grad():
        token_ids = blip.generate(pixel_values=encoded["pixel_values"], max_length=50)
        return blip_processor.decode(token_ids[0], skip_special_tokens=True)
64
-
65
def generate_video(pipeline, description):
    """Generate a short MP4 clip from a text description.

    Args:
        pipeline: Text-to-video diffusion pipeline; called with the prompt and
            expected to return an object exposing ``.frames``.
        description: Prompt text that drives the generation.

    Returns:
        Path to the written ``output.mp4`` inside a fresh temporary directory.
        The directory is not removed here because the caller still needs the file.

    Raises:
        RuntimeError: If OpenCV cannot open the output file for writing.
    """
    result = pipeline(
        description,
        num_inference_steps=30,
        num_frames=16,
    )

    frames = np.asarray(result.frames)
    # Some diffusers releases return a batched (batch, frames, H, W, C) array
    # rather than a flat list of frames; keep the first video in that case.
    if frames.ndim == 5:
        frames = frames[0]
    # VideoWriter needs 8-bit pixels. Float output is assumed to be in [0, 1]
    # (NOTE(review): confirm against the installed diffusers version).
    if frames.dtype != np.uint8:
        frames = (np.clip(frames, 0.0, 1.0) * 255).round().astype(np.uint8)

    # Create temporary directory and file path.
    temp_dir = tempfile.mkdtemp()
    temp_path = os.path.join(temp_dir, "output.mp4")

    # Encode at 8 frames per second using the frame size of the first frame.
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(temp_path, fourcc, 8, (width, height))
    try:
        if not video_writer.isOpened():
            raise RuntimeError(f"Could not open video writer for {temp_path}")
        for frame in frames:
            # The pipeline yields RGB; OpenCV writes BGR.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always release so the file handle is closed even if a write fails.
        video_writer.release()

    return temp_path
88
-
89
def main():
    """Render the page: upload an image, caption it, and generate a video.

    Any exception raised while loading models or handling the upload is
    reported to the user through ``st.error`` instead of propagating.
    """
    st.title("🎥 AI Video Generator")
    st.write("Upload an image to generate a video based on its content!")

    try:
        text_to_video, caption_model, caption_processor = load_models()

        upload = st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'])
        if not upload:
            return

        # Original on the left, enhanced copy on the right.
        left, right = st.columns(2)
        with left:
            source_image = Image.open(upload)
            st.image(source_image, caption="Original Image")
        with right:
            boosted_image = enhance_image(source_image)
            st.image(boosted_image, caption="Enhanced Image")

        # Caption the enhanced image and let the user adjust the prompt.
        auto_caption = get_description(boosted_image, caption_model, caption_processor)
        st.write("📝 Generated Description:", auto_caption)
        prompt = st.text_area("Edit description if needed:", auto_caption)

        if not st.button("🎬 Generate Video"):
            return

        with st.spinner("Generating video... This may take a few minutes."):
            result_path = generate_video(text_to_video, prompt)
            st.success("Video generated successfully!")
            st.video(result_path)

            # Offer the generated file for download.
            with open(result_path, 'rb') as video_handle:
                st.download_button(
                    label="Download Video",
                    data=video_handle,
                    file_name="generated_video.mp4",
                    mime="video/mp4",
                )

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.error("Please try again or contact support if the error persists.")


if __name__ == "__main__":
    main()
 
 
1
  import streamlit as st
 
 
2
  import torch
3
+ from transformers import pipeline
 
4
  from PIL import Image
5
+ from diffusers import LTXVideoProcessor, LTXVideoPipeline
6
+ import numpy as np
7
+ from moviepy.editor import ImageSequenceClip
8
  import tempfile
9
  import os
10
 
11
def generate_video_from_image(image, duration_seconds=10, progress_bar=None):
    """Generate a video from an image using LTX-Video and image captioning.

    The image is captioned with BLIP, the caption is used as the prompt for
    the LTX-Video pipeline, and the resulting frames are encoded to an MP4.

    Args:
        image: PIL Image object.
        duration_seconds: Duration of output video in seconds.
        progress_bar: Streamlit progress bar object, or None to skip
            progress reporting.

    Returns:
        ``(output_path, caption)`` on success, where ``output_path`` is a
        temporary ``.mp4`` file the caller must delete. ``(None, None)`` if
        any step fails; the error is shown with ``st.error``.
    """
    fps = 30
    output_path = None

    def _report(fraction, message):
        # Progress reporting is optional; do nothing when no bar was given.
        if progress_bar:
            progress_bar.progress(fraction, message)

    try:
        _report(0.1, "Generating image caption...")

        # Setup image captioning pipeline.
        captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

        # Generate caption.
        caption = captioner(image)[0]['generated_text']
        st.write(f"Generated caption: *{caption}*")

        _report(0.3, "Loading LTX-Video model...")

        # Initialize LTX-Video pipeline. It must NOT be bound to the name
        # `pipeline`: doing so makes `pipeline` local to this function and the
        # captioner call above fails with UnboundLocalError.
        # NOTE(review): confirm LTXVideoProcessor / LTXVideoPipeline exist in
        # the installed diffusers release.
        processor = LTXVideoProcessor()
        video_pipeline = LTXVideoPipeline.from_pretrained("Lightricks/ltx-video")

        _report(0.4, "Processing image...")

        # Process image for video generation and add a leading batch axis.
        processed_image = processor(image).pixel_values
        processed_image = torch.from_numpy(processed_image).unsqueeze(0)

        _report(0.5, "Generating video frames...")

        # Generate video frames.
        num_frames = duration_seconds * fps
        video_frames = video_pipeline(
            processed_image,
            num_inference_steps=50,
            num_frames=num_frames,
            guidance_scale=7.5,
            prompt=caption,
        ).videos

        _report(0.8, "Creating final video...")

        # Convert frames to format suitable for moviepy.
        frames = [np.array(frame) for frame in video_frames[0]]

        # Reserve a temporary file name; moviepy writes to it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
            output_path = tmp_file.name

        # Create and save video.
        clip = ImageSequenceClip(frames, fps=fps)
        clip.write_videofile(output_path, codec='libx264', audio=False)

        _report(1.0, "Video generation complete!")

        return output_path, caption

    except Exception as e:
        # Do not leave a half-written temp file behind on failure.
        if output_path and os.path.exists(output_path):
            os.unlink(output_path)
        st.error(f"Error generating video: {str(e)}")
        return None, None
80
+
81
def main():
    """Render the page: upload an image, pick a duration, generate a video.

    After the user clicks "Generate Video", the generated file is read into
    memory, offered as a download and played inline. The temporary file is
    removed as soon as its bytes have been read.
    """
    st.set_page_config(page_title="Video Generator", page_icon="🎥")

    st.title("🎥 AI Video Generator")
    st.write("""
    Upload an image to generate a video with AI-powered motion and transitions.
    The app will automatically generate a caption for your image and use it as inspiration for the video.
    """)

    # File uploader
    uploaded_file = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'])

    # Duration selector
    duration = st.slider("Video duration (seconds)", min_value=1, max_value=30, value=10)

    if uploaded_file is None:
        return

    # Display uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if not st.button("Generate Video"):
        return

    # Create a progress bar
    progress_text = "Operation in progress. Please wait..."
    my_bar = st.progress(0, text=progress_text)

    # Generate video (the caption is already shown by the generator itself).
    video_path, _ = generate_video_from_image(image, duration, my_bar)

    if not (video_path and os.path.exists(video_path)):
        st.error("Failed to generate video. Please try again.")
        return

    # Read the video into memory, then always remove the temporary file,
    # even if the read fails, so failed runs do not pile up files on disk.
    try:
        with open(video_path, 'rb') as video_file:
            video_bytes = video_file.read()
    finally:
        os.unlink(video_path)

    # Create download button
    st.download_button(
        label="Download Video",
        data=video_bytes,
        file_name="generated_video.mp4",
        mime="video/mp4",
    )

    # Display video
    st.video(video_bytes)


if __name__ == "__main__":
    main()