SeedOfEvil committed on
Commit
c677a45
·
verified ·
1 Parent(s): 5839a3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -1,27 +1,31 @@
1
  import gradio as gr
 
2
  from transformers import pipeline
3
  import torch
4
 
5
- # Global generator variable; we'll load it lazily.
6
  generator = None
7
 
8
  def get_generator():
9
  global generator
10
  if generator is None:
11
  try:
12
- # If GPU is available, load on GPU (device=0).
13
  if torch.cuda.is_available():
14
  generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=0)
15
  else:
16
  generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=-1)
17
  except Exception as e:
18
- # If any error occurs, fallback to CPU
19
  print("Error loading model on GPU, falling back to CPU:", e)
20
  generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=-1)
21
  return generator
22
 
 
23
  def expand_prompt(prompt, num_variants=5, max_length=100):
24
- # Lazy load the model when a prompt is submitted.
 
 
 
25
  gen = get_generator()
26
  outputs = gen(prompt, max_length=max_length, num_return_sequences=num_variants, do_sample=True)
27
  expanded = [out["generated_text"].strip() for out in outputs]
@@ -34,8 +38,9 @@ iface = gr.Interface(
34
  title="Prompt Expansion Generator",
35
  description=(
36
  "Enter a basic prompt and receive 5 creative, expanded prompt variants. "
37
- "This tool leverages the EleutherAI/gpt-j-6B model and defers loading it until the first prompt is received—"
38
- "letting ZeroGPU initialize properly. Simply copy the output for use with your downstream image-generation pipeline."
 
39
  )
40
  )
41
 
 
1
  import gradio as gr
2
+ import spaces # Import ZeroGPU's helper module
3
  from transformers import pipeline
4
  import torch
5
 
6
+ # Global generator variable; load lazily.
7
  generator = None
8
 
9
  def get_generator():
10
  global generator
11
  if generator is None:
12
  try:
13
+ # If GPU is available, load on GPU (device=0)
14
  if torch.cuda.is_available():
15
  generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=0)
16
  else:
17
  generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=-1)
18
  except Exception as e:
 
19
  print("Error loading model on GPU, falling back to CPU:", e)
20
  generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=-1)
21
  return generator
22
 
23
+ @spaces.GPU # This decorator ensures ZeroGPU allocates a GPU when the function is called.
24
  def expand_prompt(prompt, num_variants=5, max_length=100):
25
+ """
26
+ Given a basic prompt, generate `num_variants` expanded prompts using GPT-J-6B.
27
+ The GPU is only engaged during this function call.
28
+ """
29
  gen = get_generator()
30
  outputs = gen(prompt, max_length=max_length, num_return_sequences=num_variants, do_sample=True)
31
  expanded = [out["generated_text"].strip() for out in outputs]
 
38
  title="Prompt Expansion Generator",
39
  description=(
40
  "Enter a basic prompt and receive 5 creative, expanded prompt variants. "
41
+ "This tool leverages the EleutherAI/gpt-j-6B model on an A100 GPU via ZeroGPU. "
42
+ "The GPU is only allocated when a prompt is submitted, ensuring proper ZeroGPU initialization. "
43
+ "Simply copy the output for use with your downstream image-generation pipeline."
44
  )
45
  )
46