# File size: 5,020 Bytes
# (Web-view residue, namely a column of blame commit hashes and a 1-168
# line-number gutter, was captured here. It is not part of the program.)
import gradio as gr
from huggingface_hub import InferenceClient
import os
from smolagents import (
tool,
CodeAgent,
TransformersModel,
GradioUI,
MultiStepAgent,
stream_to_gradio, HfApiModel,
)
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
Float,
insert,
inspect,
text,
select,
Engine,
)
import spaces
from dotenv import load_dotenv
# Load environment variables from a local .env file, if one exists, so values
# such as the Hugging Face token read later via os.getenv(...) can be supplied
# locally instead of through the Space's secrets.
load_dotenv()
# Sample questions:
# - What is the average each customer paid?
# - Create a SQL statement and invoke your sql_engine_tool.
@spaces.GPU
def dummy():
    """No-op placeholder decorated with ``@spaces.GPU``.

    It is not called anywhere in this file. Presumably it exists so that a
    Hugging Face ZeroGPU Space finds at least one GPU-decorated function at
    startup. TODO: confirm this is still required now that the agent uses a
    hosted model.
    """
    return None
@tool
def sql_engine_tool(query: str) -> str:
    """
    Allows you to perform SQL queries on the tables. Returns a string representation of the result.
    It can use the following tables:

    Table 'receipts':
    Columns:
    - receipt_id: INTEGER
    - customer_name: VARCHAR(16)
    - price: FLOAT
    - tip: FLOAT

    Table 'waiters':
    Columns:
    - receipt_id: INTEGER
    - waiter_name: VARCHAR(16)

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # The docstring above is the tool description the agent sees. It must list
    # every table init_db() creates; previously 'waiters' was missing.
    #
    # NOTE(review): `engine` is a module-level global assigned in the __main__
    # block. Calling this tool before it is bound raises NameError.
    # NOTE(review): `query` is LLM-generated SQL executed verbatim inside a
    # committing transaction (engine.begin()), so write statements persist.
    with engine.begin() as con:
        rows = con.execute(text(query))
        # Same output format as before: every row's str() preceded by a
        # newline. Built with join instead of quadratic string +=.
        return "".join("\n" + str(row) for row in rows)
def init_db(engine):
    """Create the demo ``receipts`` and ``waiters`` tables and seed them with sample rows.

    Safe to call more than once against the same database. Existing tables are
    kept, and a table that already holds rows is not re-seeded. Re-seeding would
    otherwise violate the composite primary keys and raise IntegrityError.

    Args:
        engine: SQLAlchemy engine on which the tables are created.

    Returns:
        The same ``engine``, so callers can write ``engine = init_db(engine)``.
    """
    metadata_obj = MetaData()

    receipts = Table(
        "receipts",
        metadata_obj,
        Column("receipt_id", Integer, primary_key=True),
        Column("customer_name", String(16), primary_key=True),
        Column("price", Float),
        Column("tip", Float),
    )
    waiters = Table(
        "waiters",
        metadata_obj,
        Column("receipt_id", Integer, primary_key=True),
        Column("waiter_name", String(16), primary_key=True),
    )
    # One call creates both tables. Tables that already exist are skipped
    # (create_all's checkfirst defaults to True).
    metadata_obj.create_all(engine)

    def seed(table, rows):
        """Insert ``rows`` into ``table`` in one transaction unless it already has data."""
        with engine.begin() as connection:
            if connection.execute(select(table).limit(1)).first() is None:
                # A list of dicts runs as a single executemany instead of
                # opening one transaction per row.
                connection.execute(insert(table), rows)

    seed(
        receipts,
        [
            {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
            {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
            {"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
            {"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
        ],
    )
    seed(
        waiters,
        [
            {"receipt_id": 1, "waiter_name": "Corey Johnson"},
            {"receipt_id": 2, "waiter_name": "Michael Watts"},
            {"receipt_id": 3, "waiter_name": "Michael Watts"},
            {"receipt_id": 4, "waiter_name": "Margaret James"},
        ],
    )
    return engine
if __name__ == "__main__":
    # Local import: only the entry point needs a custom pool class.
    from sqlalchemy.pool import StaticPool

    # In-memory SQLite database shared by every thread.
    # The previous URL "sqlite:///:localhost:" was not in-memory. It created an
    # on-disk file literally named ":localhost:" in the working directory, so a
    # restart re-inserted the seed rows into the existing file.
    # A bare ":memory:" URL is not enough either. Gradio runs synchronous
    # handlers (and therefore the agent's tool calls) on worker threads, and
    # SQLAlchemy's default pool for in-memory SQLite gives each thread its own
    # connection, which means its own empty database.
    # StaticPool + check_same_thread=False share one connection process-wide.
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    engine = init_db(engine)

    # Not working at the moment:
    # model = TransformersModel(
    #     # model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     device_map="cuda",
    #     model_id="meta-llama/Llama-3.2-3B-Instruct"
    # )
    # Hosted inference. The token comes from the environment, which
    # load_dotenv() may have populated from a local .env file.
    model = HfApiModel(
        model_id="meta-llama/Llama-3.2-3B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens"),
    )
    agent = CodeAgent(
        tools=[sql_engine_tool],
        model=model,
        max_steps=10,
        verbosity_level=1,
    )

    def enter_message(new_message, conversation_history):
        """Run the agent on ``new_message`` and stream its messages into the chat.

        Appends the user's message, then each message produced by
        ``stream_to_gradio``. After every agent message it yields
        ``("", conversation_history)``, which clears the textbox and refreshes
        the chatbot.
        """
        conversation_history.append(gr.ChatMessage(role="user", content=new_message))
        # yield "", conversation_history
        for msg in stream_to_gradio(agent, new_message):
            conversation_history.append(msg)
            yield "", conversation_history

    def clear_message(chat_history=None):
        """Reset the agent's memory and return empty values for (chatbot, message box).

        The clear button is wired with no inputs, so ``chat_history`` must be
        optional. Previously it was required, and the handler failed with a
        TypeError before reaching ``agent.memory.reset()``. It is unused and
        kept only for backward compatibility.
        """
        agent.memory.reset()
        return [], ""

    with gr.Blocks() as b:
        gr.Markdown('''# Demo text to sql on paying customers' receipts
a self correcting text to sql ai agent using smolagents, gradio, HF Spaces, sqlalchemy improved from a smolagents guide
''')
        chatbot = gr.Chatbot(type="messages", height=2000)
        message_box = gr.Textbox(
            lines=1,
            label="chat message (with default sample question)",
            value="What is the average each customer paid?",
        )
        with gr.Row():
            stop_generating_button = gr.Button("stop generating")
            clear_messages_button = gr.ClearButton([message_box, chatbot])
            enter_button = gr.Button("enter")
        reply_button_click_event = enter_button.click(
            enter_message, [message_box, chatbot], [message_box, chatbot]
        )
        message_submit = message_box.submit(
            enter_message, [message_box, chatbot], [message_box, chatbot]
        )
        # The stop button only needs to cancel the running generators.
        # `stop_gen` was never defined (NameError at startup); fn=None is
        # Gradio's idiom for a cancel-only listener.
        stop_generating_button.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=[reply_button_click_event, message_submit],
        )
        clear_messages_button.click(clear_message, outputs=[chatbot, message_box])
    b.launch()