Update app.py
app.py CHANGED
@@ -1,12 +1,15 @@
 import gradio as gr
 from transformers import pipeline
 
-# Load the TinyLlama model for text generation
-pipe = pipeline(
+# Load the TinyLlama model for text generation on GPU
+pipe = pipeline(
+    "text-generation",
+    model="TinyLlama/TinyLlama_v1.1",
+    device=0  # 0 for 'cuda:0', -1 for CPU
+)  # No .to("cuda") needed
 
 # Define the prediction function
 def generate_text(prompt, max_length=128, temperature=1.0, top_p=0.95):
-    # You can expose more parameters as needed
     result = pipe(
         prompt,
         max_length=max_length,
@@ -15,7 +18,6 @@ def generate_text(prompt, max_length=128, temperature=1.0, top_p=0.95):
         num_return_sequences=1,
         do_sample=True
     )
-    # The output is a list of dicts with 'generated_text'
     return result[0]['generated_text']
 
 # Create the Gradio interface
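The substance of the commit is the loading block: passing device=0 to pipeline() places the model on the first CUDA device directly, so no separate .to("cuda") call is needed. For a Space that may also be scheduled on CPU hardware, a minimal sketch of a runtime fallback follows; the torch.cuda.is_available() check is an illustrative addition, not part of this commit:

import torch
from transformers import pipeline

# Choose the pipeline device at runtime: 0 maps to 'cuda:0' when a GPU is
# visible, -1 falls back to CPU. This fallback is illustrative, not committed.
device = 0 if torch.cuda.is_available() else -1

pipe = pipeline(
    "text-generation",
    model="TinyLlama/TinyLlama_v1.1",
    device=device,
)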
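One caveat the diff leaves untouched: max_length caps the prompt plus the completion, so a long prompt can leave almost no room for generated text. transformers also accepts max_new_tokens, which caps only the completion. A hedged alternative call is sketched below; the parameter swap is a suggestion, and the temperature/top_p arguments (elided in the diff) are assumed from the function signature:

result = pipe(
    prompt,
    max_new_tokens=max_length,  # caps only generated tokens, unlike max_length
    temperature=temperature,
    top_p=top_p,
    num_return_sequences=1,
    do_sample=True,
)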
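The diff cuts off at the "# Create the Gradio interface" comment, so the UI code itself is not shown. Below is a hypothetical sketch of how the interface could be wired to generate_text; the component choices, slider ranges, and labels are assumptions, not the Space's actual code:

import gradio as gr

# Hypothetical wiring; only generate_text and its parameters come from the diff.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(16, 512, value=128, step=16, label="max_length"),
        gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="temperature"),
        gr.Slider(0.05, 1.0, value=0.95, step=0.05, label="top_p"),
    ],
    outputs=gr.Textbox(label="Generated text"),
)

demo.launch()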