File size: 3,902 Bytes
dc66050 3674844 5756803 f1aaba3 af7f1fe b1f7d5c f1aaba3 cf29376 f1aaba3 592b501 ee14926 6193310 e1a03f2 ee14926 cf29376 5992538 687b793 ee14926 b1f7d5c ee14926 dc66050 97413fe 2d1e4f3 ee14926 f7947fc 2d1e4f3 3d2d75d ee14926 2d1e4f3 ee14926 2d1e4f3 ee14926 2d1e4f3 7ded5fa e23a574 7ded5fa 8586ebb 086f897 7ded5fa e95efe1 e23a574 e95efe1 d23e6f9 e95efe1 104e584 9d155c8 4e51d19 9d155c8 e23a574 4e51d19 9d155c8 7ded5fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import gradio as gr
from huggingface_hub import InferenceClient
import os
from smolagents import (
tool,
CodeAgent,
HfApiModel,
GradioUI,
MultiStepAgent,
stream_to_gradio,
)
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
Float,
insert,
inspect,
text,
select,
Engine,
)
import spaces
from dotenv import load_dotenv
# Load variables from a local .env file (when one exists) into os.environ,
# e.g. the Hugging Face token that the __main__ block reads via os.getenv.
load_dotenv()

# Sample question to try in the chat UI (presumably):
#   "What is the average amount each customer paid?
#    Create a SQL statement and invoke your sql_engine tool."


@spaces.GPU
def dummy():
    """No-op placeholder decorated with ``spaces.GPU``.

    NOTE(review): presumably exists only because a ZeroGPU Hugging Face Space
    expects at least one ``@spaces.GPU`` function; nothing in this file calls it.
    """
@tool
def sql_engine_tool(query: str) -> str:
    """
    Allows you to perform SQL queries on the tables. Returns a string representation of the result.
    It can use the following tables:

    Table 'receipts':
    Columns:
    - receipt_id: INTEGER
    - customer_name: VARCHAR(16)
    - price: FLOAT
    - tip: FLOAT

    Table 'waiters':
    Columns:
    - receipt_id: INTEGER
    - waiter_name: VARCHAR(16)

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # The docstring above is the tool description shown to the LLM, so it must
    # list every table init_db creates (previously 'waiters' was missing).
    #
    # NOTE(review): `query` is LLM-generated and runs inside engine.begin(),
    # which commits on success, so write statements are executed too.
    # `engine` is the module-level engine assigned in the __main__ block.
    with engine.begin() as con:
        result = con.execute(text(query))
        # One join instead of repeated string concatenation; each row is still
        # rendered as "\n" + str(row), exactly as before.
        output = "".join("\n" + str(row) for row in result)
    return output
def init_db(engine):
    """Create the demo ``receipts`` and ``waiters`` tables and seed them with sample rows.

    Tables are created only if they do not exist yet, and each table is seeded
    only while it is still empty. Calling this again on an already-initialised
    database therefore does not fail with duplicate-primary-key errors.

    Args:
        engine: SQLAlchemy engine of the database to initialise.

    Returns:
        The same ``engine`` that was passed in.
    """
    metadata_obj = MetaData()

    # NOTE(review): flagging both columns primary_key=True gives each table a
    # *composite* primary key, so receipt_id alone is not enforced as unique.
    # Kept unchanged to preserve the existing schema.
    receipts = Table(
        "receipts",
        metadata_obj,
        Column("receipt_id", Integer, primary_key=True),
        Column("customer_name", String(16), primary_key=True),
        Column("price", Float),
        Column("tip", Float),
    )
    waiters = Table(
        "waiters",
        metadata_obj,
        Column("receipt_id", Integer, primary_key=True),
        Column("waiter_name", String(16), primary_key=True),
    )
    # A single call creates both tables; tables that already exist are skipped.
    metadata_obj.create_all(engine)

    receipt_rows = [
        {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
        {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
        {
            "receipt_id": 3,
            "customer_name": "Woodrow Wilson",
            "price": 53.43,
            "tip": 5.43,
        },
        {
            "receipt_id": 4,
            "customer_name": "Margaret James",
            "price": 21.11,
            "tip": 1.00,
        },
    ]
    waiter_rows = [
        {"receipt_id": 1, "waiter_name": "Corey Johnson"},
        {"receipt_id": 2, "waiter_name": "Michael Watts"},
        {"receipt_id": 3, "waiter_name": "Michael Watts"},
        {"receipt_id": 4, "waiter_name": "Margaret James"},
    ]

    def seed(connection, table, rows):
        """Insert *rows* into *table* unless the table already contains data."""
        if connection.execute(select(table).limit(1)).first() is None:
            # Passing a list of dicts makes SQLAlchemy run a single executemany.
            connection.execute(insert(table), rows)

    # All seed data goes in one transaction (committed on success, rolled back
    # on error) instead of opening a separate transaction per row.
    with engine.begin() as connection:
        seed(connection, receipts, receipt_rows)
        seed(connection, waiters, waiter_rows)

    return engine
if __name__ == "__main__":
    from sqlalchemy.pool import StaticPool

    # Use a private in-memory SQLite database. (A URL like
    # "sqlite:///:localhost:" is not special to SQLite: it creates a file
    # literally named ":localhost:" in the working directory. That file outlives
    # the process, and the name is invalid on Windows.)
    # StaticPool makes every checkout reuse the single connection that owns the
    # in-memory database. Without it, a connection opened from another thread
    # would see an empty database. check_same_thread=False lets sqlite3 use that
    # connection from the worker threads Gradio runs synchronous handlers on,
    # which is where the agent calls sql_engine_tool.
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    # sql_engine_tool reads this module-level `engine`.
    engine = init_db(engine)

    model = HfApiModel(
        model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
        # Alternative model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
        token=os.getenv("my_first_agents_hf_tokens"),
    )
    agent = CodeAgent(
        tools=[sql_engine_tool],
        model=model,
        max_steps=10,
        verbosity_level=1,
    )

    def enter_message(new_message, conversation_history):
        """Append the user's message, then stream the agent's reply into the chat.

        Yields ``("", history)`` after every update so the textbox is cleared
        and the Chatbot re-renders as each agent message arrives.
        """
        conversation_history.append(gr.ChatMessage(role="user", content=new_message))
        yield "", conversation_history
        for msg in stream_to_gradio(agent, new_message):
            conversation_history.append(msg)
            yield "", conversation_history

    with gr.Blocks() as b:
        chatbot = gr.Chatbot(type="messages", height=1000)
        textbox = gr.Textbox(lines=3, label="")
        button = gr.Button("reply")
        button.click(enter_message, [textbox, chatbot], [textbox, chatbot])
    b.launch()
|