Ankit8544 commited on
Commit
041221d
·
verified ·
1 Parent(s): 5562446

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -18
app.py CHANGED
@@ -6,11 +6,11 @@ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepSchedu
6
  import os
7
  from uuid import uuid4
8
 
9
- # Check if CUDA is available and set device accordingly
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
- print(f"Using device: {device}") # Print device information
12
 
13
- # Load model on startup
14
  try:
15
  print("Loading model...")
16
  model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
@@ -23,16 +23,17 @@ try:
23
  )
24
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
25
  pipe.scheduler = scheduler
26
- pipe.to(device) # Move model to CUDA if available, otherwise CPU
27
- print("Model loaded successfully.")
28
  except Exception as e:
29
  print(f"Error loading model: {e}")
30
- raise e
 
31
 
32
  # Define the generation function
33
  def generate_video(prompt, negative_prompt="", height=720, width=1280, num_frames=81, guidance_scale=5.0):
34
  try:
35
- print(f"Generating video for prompt: {prompt}")
36
  output = pipe(
37
  prompt=prompt,
38
  negative_prompt=negative_prompt,
@@ -47,14 +48,13 @@ def generate_video(prompt, negative_prompt="", height=720, width=1280, num_frame
47
  os.makedirs("outputs", exist_ok=True)
48
  export_to_video(output, output_path, fps=16)
49
 
50
- print(f"Video generated successfully: {output_path}")
51
- return output_path # Gradio returns this as a downloadable file/video
52
-
53
  except Exception as e:
54
- print(f"Error generating video: {e}")
55
- return None # Return None in case of error
56
 
57
- # Gradio Interface with API support
58
  iface = gr.Interface(
59
  fn=generate_video,
60
  inputs=[
@@ -68,12 +68,11 @@ iface = gr.Interface(
68
  outputs=gr.File(label="Generated Video"),
69
  title="Wan2.1 Video Generator",
70
  description="Generate realistic videos from text prompts using the Wan2.1 T2V model.",
71
- api=True # This enables the API
72
  )
73
 
 
74
  try:
75
- print("Launching Gradio interface...")
76
- iface.launch(share=True) # `share=True` will allow others to access your app via a public link
77
- print("Gradio interface launched successfully.")
78
  except Exception as e:
79
- print(f"Error launching Gradio interface: {e}")
 
6
  import os
7
  from uuid import uuid4
8
 
9
+ # Check for available device (CUDA or CPU)
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ print(f"Running on {device}...")
12
 
13
+ # Load the model only once during startup
14
  try:
15
  print("Loading model...")
16
  model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
 
23
  )
24
  pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
25
  pipe.scheduler = scheduler
26
+ pipe.to(device) # Move model to GPU or CPU based on availability
27
+ print("Model loaded successfully!")
28
  except Exception as e:
29
  print(f"Error loading model: {e}")
30
+ device = "cpu" # Fallback to CPU if model loading fails on GPU
31
+ pipe.to(device)
32
 
33
  # Define the generation function
34
  def generate_video(prompt, negative_prompt="", height=720, width=1280, num_frames=81, guidance_scale=5.0):
35
  try:
36
+ print(f"Generating video with prompt: {prompt}")
37
  output = pipe(
38
  prompt=prompt,
39
  negative_prompt=negative_prompt,
 
48
  os.makedirs("outputs", exist_ok=True)
49
  export_to_video(output, output_path, fps=16)
50
 
51
+ print(f"Video generated and saved to {output_path}")
52
+ return output_path # Gradio returns this as downloadable file/video
 
53
  except Exception as e:
54
+ print(f"Error during video generation: {e}")
55
+ return None
56
 
57
+ # Gradio Interface
58
  iface = gr.Interface(
59
  fn=generate_video,
60
  inputs=[
 
68
  outputs=gr.File(label="Generated Video"),
69
  title="Wan2.1 Video Generator",
70
  description="Generate realistic videos from text prompts using the Wan2.1 T2V model.",
71
+ live=True
72
  )
73
 
74
+ # Launch Gradio app in API mode
75
  try:
76
+ iface.launch(share=True, server_name="0.0.0.0", server_port=7860)
 
 
77
  except Exception as e:
78
+ print(f"Error launching Gradio app: {e}")