SeedOfEvil committed on
Commit
e94cd94
·
verified ·
1 Parent(s): e15cf75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -10
app.py CHANGED
@@ -1,16 +1,29 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
3
 
4
- # Load the larger text-generation model that uses GPU.
5
- # Here we use EleutherAI/gpt-j-6B: https://huggingface.co/EleutherAI/gpt-j-6B
6
- # Setting device=0 tells the pipeline to use GPU 0.
7
- generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def expand_prompt(prompt, num_variants=5, max_length=100):
10
- """
11
- Given a basic prompt, generate `num_variants` expanded prompts using GPT-J-6B.
12
- """
13
- outputs = generator(prompt, max_length=max_length, num_return_sequences=num_variants, do_sample=True)
14
  expanded = [out["generated_text"].strip() for out in outputs]
15
  return "\n\n".join(expanded)
16
 
@@ -21,8 +34,8 @@ iface = gr.Interface(
21
  title="Prompt Expansion Generator",
22
  description=(
23
  "Enter a basic prompt and receive 5 creative, expanded prompt variants. "
24
- "This tool leverages the EleutherAI/gpt-j-6B model on an A100 GPU for fast, expressive prompt expansion. "
25
- "Simply copy the output for use with your downstream image-generation pipeline."
26
  )
27
  )
28
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ import torch
4
 
5
# Shared pipeline instance; loaded lazily so app startup (and ZeroGPU init)
# is not blocked by the multi-gigabyte model download.
generator = None


def get_generator():
    """Return the shared GPT-J-6B text-generation pipeline, loading it on first use.

    Device selection: GPU 0 when CUDA is available, otherwise CPU (device=-1).
    If the GPU load fails (e.g. driver issue or out-of-memory), we retry once
    on CPU; a CPU load failure is re-raised, since retrying the identical CPU
    load again could not succeed. The loaded pipeline is cached in the
    module-level ``generator`` so the model loads at most once per process.

    Returns:
        transformers.Pipeline: the cached "text-generation" pipeline.
    """
    global generator
    if generator is None:
        # Decide the device up front so the except clause only handles a
        # genuine GPU failure, instead of re-attempting an identical CPU load.
        device = 0 if torch.cuda.is_available() else -1
        try:
            generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=device)
        except Exception as e:
            if device == -1:
                # CPU load failed; nothing to fall back to.
                raise
            # GPU load failed (e.g. OOM) — fall back to CPU once.
            print("Error loading model on GPU, falling back to CPU:", e)
            generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B", device=-1)
    return generator
22
 
23
def expand_prompt(prompt, num_variants=5, max_length=100):
    """Generate ``num_variants`` expanded variants of ``prompt`` using GPT-J-6B.

    The model is loaded lazily on the first call (see ``get_generator``) so
    the app starts quickly and ZeroGPU can initialize properly.

    Args:
        prompt: The basic prompt to expand.
        num_variants: Number of sampled expansions to return (default 5).
        max_length: Maximum token length of each generated sequence.

    Returns:
        str: The expanded prompts, separated by blank lines.
    """
    # Lazy-load the (cached) pipeline only when a prompt is actually submitted.
    gen = get_generator()
    outputs = gen(prompt, max_length=max_length, num_return_sequences=num_variants, do_sample=True)
    expanded = [out["generated_text"].strip() for out in outputs]
    return "\n\n".join(expanded)
29
 
 
34
  title="Prompt Expansion Generator",
35
  description=(
36
  "Enter a basic prompt and receive 5 creative, expanded prompt variants. "
37
+ "This tool leverages the EleutherAI/gpt-j-6B model and defers loading it until the first prompt is received—"
38
+ "letting ZeroGPU initialize properly. Simply copy the output for use with your downstream image-generation pipeline."
39
  )
40
  )
41