warhawkmonk committed on
Commit 71b64d2 · verified · 1 Parent(s): 4392dee

Update app.py

Files changed (1)
  1. app.py +26 -33
app.py CHANGED
@@ -1,43 +1,36 @@
 import gradio as gr
-import torch
 from transformers import pipeline
 
-# Model ID for Llama 3 8B instruct (replace with the exact model you want)
-MODEL_ID = "manycore-research/SpatialLM-Llama-1B"
+# Load the TinyLlama model for text generation
+pipe = pipeline("text-generation", model="TinyLlama/TinyLlama_v1.1")
 
-# Load the text-generation pipeline with device_map="auto" to use GPU if available
-generator = pipeline(
-    "text-generation",
-    model=MODEL_ID,
-    torch_dtype=torch.float16,
-    device_map="auto",
-)
-
-def generate_response(prompt, max_length=512, temperature=0.7):
-    # Format prompt for Llama 3 instruct style
-    formatted_prompt = f"<s>[INST] {prompt} [/INST]"
-    output = generator(
-        formatted_prompt,
+# Define the prediction function
+def generate_text(prompt, max_length=128, temperature=1.0, top_p=0.95):
+    # You can expose more parameters as needed
+    result = pipe(
+        prompt,
         max_length=max_length,
         temperature=temperature,
-        do_sample=True,
-        top_p=0.95,
+        top_p=top_p,
         num_return_sequences=1,
+        do_sample=True
     )
-    generated_text = output[0]["generated_text"]
-    # Extract the response after the [/INST] token
-    response = generated_text.split("[/INST]")[-1].strip()
-    return response
+    # The output is a list of dicts with 'generated_text'
+    return result[0]['generated_text']
 
-with gr.Blocks() as demo:
-    gr.Markdown("# Chat with Llama 3 (8B Instruct)")
-    with gr.Row():
-        with gr.Column():
-            user_input = gr.Textbox(lines=3, placeholder="Type your message here...", label="Your Message")
-            submit_btn = gr.Button("Submit")
-        with gr.Column():
-            output = gr.Textbox(lines=10, label="Llama 3 Response")
-    submit_btn.click(fn=generate_response, inputs=user_input, outputs=output)
+# Create the Gradio interface
+demo = gr.Interface(
+    fn=generate_text,
+    inputs=[
+        gr.Textbox(lines=4, label="Input Prompt"),
+        gr.Slider(32, 512, value=128, step=8, label="Max Length"),
+        gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature"),
+        gr.Slider(0.5, 1.0, value=0.95, step=0.01, label="Top-p (nucleus sampling)")
+    ],
+    outputs=gr.Textbox(lines=8, label="Generated Text"),
+    title="TinyLlama Text Generation",
+    description="Enter a prompt and generate text using TinyLlama/TinyLlama_v1.1."
+)
 
-if __name__ == "__main__":
-    demo.launch()
+# Launch the app
+demo.launch()
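Note on behavior: unlike the removed generate_response, which split off everything after the [/INST] token, the new generate_text returns result[0]['generated_text'] directly, and a transformers text-generation pipeline includes the input prompt at the start of that string, so the app now echoes the prompt before the continuation. Once the app is launched, the gr.Interface endpoint can also be exercised programmatically. Below is a minimal client sketch; the local URL and the api_name="/predict" endpoint are assumptions based on Gradio defaults, not part of this commit.

# Minimal sketch of calling the updated app with gradio_client.
# Assumes the app is running locally on Gradio's default port and
# exposes the default "/predict" endpoint (both assumptions).
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")
result = client.predict(
    "Once upon a time",   # Input Prompt
    128,                  # Max Length
    1.0,                  # Temperature
    0.95,                 # Top-p (nucleus sampling)
    api_name="/predict",
)
print(result)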