import gradio as gr
import torch
from gradio.themes.utils import sizes
from transformers import AutoModelForCausalLM, AutoTokenizer

import utils
from constants import END_OF_TEXT, MIN_TEMPERATURE
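
# Tokenizer and model setup: the checkpoint ships without a dedicated pad
# token, so padding is mapped onto the end-of-text token instead.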
tokenizer = AutoTokenizer.from_pretrained(
    "BEE-spoke-data/smol_llama-101M-GQA-python",
    use_fast=False,
)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = END_OF_TEXT
model = AutoModelForCausalLM.from_pretrained(
    "BEE-spoke-data/smol_llama-101M-GQA-python",
    device_map="auto",
)
# "reduce-overhead" compilation trims per-call framework overhead, a good fit
# for the many short generate() calls this demo makes.
model = torch.compile(model, mode="reduce-overhead")
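
# Static assets: custom CSS for the page, plus README.md split on "---" into
# the text sections rendered throughout the UI below.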
_styles = utils.get_file_as_string("styles.css")

readme_file_content = utils.get_file_as_string("README.md", path="./")
(
    manifest,
    description,
    disclaimer,
    base_model_info,
    formats,
) = utils.get_sections(readme_file_content, "---", up_to=5)
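
# Soft theme with a yellow/orange palette, slate neutrals, and IBM Plex Sans.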
theme = gr.themes.Soft(
    primary_hue="yellow",
    secondary_hue="orange",
    neutral_hue="slate",
    radius_size=sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("IBM Plex Sans", [400, 600]),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    text_size=sizes.text_lg,
)
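

# Sampling with guardrails: epsilon_cutoff prunes tokens below ~0.1%
# probability, renormalize_logits keeps the distribution normalized after the
# logit processors run, and the temperature is floored at MIN_TEMPERATURE so
# a slider value of 0 stays valid for sampling.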
def run_inference(
    prompt, temperature, max_new_tokens, top_p, repetition_penalty
) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        do_sample=True,
        epsilon_cutoff=1e-3,
        max_new_tokens=max_new_tokens,
        min_new_tokens=2,
        no_repeat_ngram_size=6,
        renormalize_logits=True,
        repetition_penalty=repetition_penalty,
        temperature=max(temperature, MIN_TEMPERATURE),
        top_p=top_p,
    )
    text = tokenizer.batch_decode(
        outputs,
        skip_special_tokens=True,
    )[0]
    return text
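

# Each example row supplies run_inference's inputs in order:
# prompt, temperature, max_new_tokens, top_p, repetition_penalty.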
examples = [
    [
        'def greet(name: str) -> None:\n """\n Greets the user\n """\n print(f"Hello,',
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        'for i in range(5):\n """\n Loop through 0 to 4\n """\n print(i,',
        0.2,
        64,
        0.9,
        1.2,
    ],
    ['x = 10\n"""Check if x is greater than 5"""\nif x > 5:', 0.2, 64, 0.9, 1.2],
    ["def square(x: int) -> int:\n return", 0.2, 64, 0.9, 1.2],
    ['import math\n"""Math operations"""\nmath.', 0.2, 64, 0.9, 1.2],
    [
        'def is_even(n) -> bool:\n """\n Check if a number is even\n """\n if n % 2 == 0:',
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        'while True:\n """Infinite loop example"""\n print("Infinite loop,',
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        "def sum_list(lst: list[int]) -> int:\n total = 0\n for item in lst:",
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        'try:\n """\n Exception handling\n """\n x = int(input("Enter a number: "))\nexcept ValueError:',
        0.2,
        64,
        0.9,
        1.2,
    ],
    [
        'def divide(a: float, b: float) -> float:\n """\n Divide a by b\n """\n if b != 0:',
        0.2,
        64,
        0.9,
        1.2,
    ],
]
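

# UI: prompt box and Generate button, a collapsible panel of sampling
# controls, and README-derived copy rendered around the widgets.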
with gr.Blocks(theme=theme, analytics_enabled=False, css=_styles) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    value=examples[0][0],
                    placeholder="Enter your code here",
                    label="Code",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", language="python", lines=10)
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=64,
                                        minimum=32,
                                        maximum=512,
                                        step=32,
                                        interactive=True,
                                        info="Number of tokens to generate",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
                    with gr.Column():
                        version = gr.Dropdown(
                            [
                                "smol_llama-101M-GQA-python",
                            ],
                            value="smol_llama-101M-GQA-python",
                            label="Version",
                        )
                gr.Markdown(disclaimer)
                gr.Examples(
                    examples=examples,
                    inputs=[
                        instruction,
                        temperature,
                        max_new_tokens,
                        top_p,
                        repetition_penalty,
                    ],
                    cache_examples=False,
                    fn=run_inference,
                    outputs=[output],
                )
                gr.Markdown(base_model_info)
                gr.Markdown(formats)

    submit.click(
        run_inference,
        inputs=[
            instruction,
            temperature,
            max_new_tokens,
            top_p,
            repetition_penalty,
        ],
        outputs=[output],
        show_progress=True,
    )
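
# In Google Colab the local server is not directly reachable, so request a
# public share link there; when running locally, sharing stays off.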
demo.launch(
    debug=True,
    show_api=False,
    share=utils.is_google_colab(),
)