warhawkmonk committed
Commit 009e955 · verified · 1 Parent(s): 166edc5

Update app.py

Files changed (1)
app.py +6 -4
app.py CHANGED
@@ -1,12 +1,15 @@
 import gradio as gr
 from transformers import pipeline
 
-# Load the TinyLlama model for text generation
-pipe = pipeline("text-generation", model="TinyLlama/TinyLlama_v1.1").to("cuda")
+# Load the TinyLlama model for text generation on GPU
+pipe = pipeline(
+    "text-generation",
+    model="TinyLlama/TinyLlama_v1.1",
+    device=0  # 0 for 'cuda:0', -1 for CPU
+)  # No .to("cuda") needed[4][6]
 
 # Define the prediction function
 def generate_text(prompt, max_length=128, temperature=1.0, top_p=0.95):
-    # You can expose more parameters as needed
     result = pipe(
         prompt,
         max_length=max_length,
@@ -15,7 +18,6 @@ def generate_text(prompt, max_length=128, temperature=1.0, top_p=0.95):
         num_return_sequences=1,
         do_sample=True
     )
-    # The output is a list of dicts with 'generated_text'
     return result[0]['generated_text']
 
 # Create the Gradio interface
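
The substance of the change: transformers.pipeline() returns a Pipeline object, not a torch.nn.Module, so the removed .to("cuda") call would typically fail with an AttributeError; the supported way to place a pipeline on a GPU is the device argument (a device index, with -1 meaning CPU). Below is a minimal end-to-end sketch of app.py as it stands after this commit. It is illustrative rather than the committed code: the torch.cuda.is_available() guard is an added convenience, the temperature/top_p pass-through stands in for the two context lines elided between the hunks, and everything from "# Create the Gradio interface" onward (widget choices, slider ranges, demo.launch()) is assumed, since the diff is truncated at that comment.

import gradio as gr
import torch
from transformers import pipeline

# Loading pattern as committed; the CUDA-availability guard is an
# illustrative addition so the sketch also runs on CPU-only machines.
pipe = pipeline(
    "text-generation",
    model="TinyLlama/TinyLlama_v1.1",
    device=0 if torch.cuda.is_available() else -1,  # 0 -> 'cuda:0', -1 -> CPU
)

# Prediction function as committed; the temperature/top_p arguments are
# assumed to be forwarded in the context lines the diff does not show.
def generate_text(prompt, max_length=128, temperature=1.0, top_p=0.95):
    result = pipe(
        prompt,
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
        num_return_sequences=1,
        do_sample=True,
    )
    # The pipeline returns a list of dicts keyed by 'generated_text'.
    return result[0]["generated_text"]

# Create the Gradio interface. Hypothetical wiring: the commit's diff ends
# at this comment, so the widgets, labels, and ranges below are assumptions.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(16, 512, value=128, step=16, label="Max length"),
        gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature"),
        gr.Slider(0.05, 1.0, value=0.95, step=0.05, label="Top-p"),
    ],
    outputs=gr.Textbox(label="Generated text"),
)

if __name__ == "__main__":
    demo.launch()

Passing device at construction moves the model and its input tensors to cuda:0 in one step, which is why the chained .to("cuda") is no longer needed.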