File size: 3,876 Bytes
dc66050 3674844 5756803 f1aaba3 af7f1fe b1f7d5c f1aaba3 cf29376 f1aaba3 592b501 ee14926 6193310 e1a03f2 ee14926 cf29376 5992538 687b793 ee14926 7fdbbb2 1b0cc5f b1f7d5c 7a8951a 7fdbbb2 1a788a2 ee14926 dc66050 97413fe ee14926 2d1e4f3 ee14926 f7947fc 2d1e4f3 3d2d75d ee14926 2d1e4f3 ee14926 2d1e4f3 ee14926 2d1e4f3 3d2d75d 28fb60e 7a8951a 1b0cc5f 7a8951a 7fdbbb2 8586ebb 086f897 a37048e e95efe1 1b0cc5f e95efe1 d23e6f9 e95efe1 4e51d19 104e584 9d155c8 4e51d19 9d155c8 4e51d19 9d155c8 104e584 4e51d19 9d155c8 eebef7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import gradio as gr
from huggingface_hub import InferenceClient
import os
from smolagents import (
tool,
CodeAgent,
HfApiModel,
GradioUI,
MultiStepAgent,
stream_to_gradio,
)
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
Float,
insert,
inspect,
text,
select,
Engine,
)
import spaces
from dotenv import load_dotenv
load_dotenv()
# Sample prompts to try: "What is the average amount each customer paid?"
# "Create a SQL statement and invoke your sql_engine_tool tool."
@spaces.GPU
def dummy():
    """No-op placeholder that carries the ``spaces.GPU`` decorator.

    NOTE(review): presumably kept only so the Hugging Face Space registers at
    least one GPU-decorated function; confirm before removing.
    """
@tool
def sql_engine_tool(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'receipts'. Its description is as follows:
    Columns:
    - receipt_id: INTEGER
    - customer_name: VARCHAR(16)
    - price: FLOAT
    - tip: FLOAT
    Args:
        query: The query to perform. This should be correct SQL.
    """
    # NOTE: smolagents' ``@tool`` turns the docstring above into the tool
    # description and argument schema shown to the LLM, so its wording and the
    # "Args:" section must stay in sync with the signature.
    #
    # ``engine`` is the module-level engine created in the ``__main__`` block.
    # ``engine.begin()`` commits on success and rolls back if anything raises.
    # NOTE(review): this means write statements issued by the agent are
    # committed; switch to ``engine.connect()`` if the tool must be read-only.
    with engine.begin() as con:
        result = con.execute(text(query))
        # INSERT/UPDATE/DDL statements produce no rows, and iterating such a
        # result raises ResourceClosedError. Report the row count instead.
        if not result.returns_rows:
            return f"Query executed; no rows returned (rowcount={result.rowcount})."
        # One "\n"-prefixed line per row, matching the original output format.
        return "".join("\n" + str(row) for row in result)
def init_db(engine):
    """Create the ``receipts`` table on *engine* and seed it with sample rows.

    Safe to call more than once against a persistent database:
    ``create_all`` only creates tables that are missing, and seed rows whose
    primary key ``(receipt_id, customer_name)`` is already present are
    skipped instead of raising an IntegrityError.

    Args:
        engine: SQLAlchemy engine to create the table on and populate.

    Returns:
        A ``(engine, metadata_obj)`` tuple: the engine that was passed in and
        the ``MetaData`` object holding the ``receipts`` table definition.
    """
    metadata_obj = MetaData()
    receipts = Table(
        "receipts",
        metadata_obj,
        # receipt_id and customer_name together form a composite primary key.
        Column("receipt_id", Integer, primary_key=True),
        Column("customer_name", String(16), primary_key=True),
        Column("price", Float),
        Column("tip", Float),
    )
    # checkfirst defaults to True, so existing tables are left alone.
    metadata_obj.create_all(engine)

    rows = [
        {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
        {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
        {
            "receipt_id": 3,
            "customer_name": "Woodrow Wilson",
            "price": 53.43,
            "tip": 5.43,
        },
        {
            "receipt_id": 4,
            "customer_name": "Margaret James",
            "price": 21.11,
            "tip": 1.00,
        },
    ]

    # Seed in a single transaction (all-or-nothing), inserting only the rows
    # whose primary key is not already in the table.
    with engine.begin() as connection:
        existing = {
            tuple(key)
            for key in connection.execute(
                select(receipts.c.receipt_id, receipts.c.customer_name)
            )
        }
        missing = [
            row
            for row in rows
            if (row["receipt_id"], row["customer_name"]) not in existing
        ]
        if missing:
            # A list of dicts makes SQLAlchemy use executemany.
            connection.execute(insert(receipts), missing)

    return engine, metadata_obj
if __name__ == "__main__":
    # Local import: only the script entry point needs StaticPool.
    from sqlalchemy.pool import StaticPool

    # "sqlite:///:localhost:" was not an in-memory URL. It created an on-disk
    # file literally named ":localhost:", which is also an invalid name on Windows.
    # An in-memory SQLite database is private to the connection that created it,
    # and Gradio may call handlers (and thus the tool) from worker threads.
    # Share a single connection across threads so every caller sees the same tables.
    engine = create_engine(
        "sqlite://",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    # ``engine`` and ``metadata_objects`` are module globals read by sql_engine_tool.
    engine, metadata_objects = init_db(engine)

    model = HfApiModel(
        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens"),
    )
    agent = CodeAgent(
        tools=[sql_engine_tool],
        model=model,
        max_steps=10,
        verbosity_level=1,
    )

    def enter_message(new_message, conversation_history):
        """Append the user's message, then stream the agent's steps into the chat.

        Yields ``(textbox_value, conversation_history)`` after each update. The
        textbox is cleared once a message is sent. Blank input is ignored, so
        the agent is never run on an empty task.
        """
        if not (new_message or "").strip():
            yield new_message, conversation_history
            return
        conversation_history.append(gr.ChatMessage(role="user", content=new_message))
        yield "", conversation_history
        for msg in stream_to_gradio(agent, new_message):
            conversation_history.append(msg)
            yield "", conversation_history

    with gr.Blocks() as b:
        chatbot = gr.Chatbot(type="messages", height=1000)
        textbox = gr.Textbox()
        button = gr.Button("reply")
        button.click(enter_message, [textbox, chatbot], [textbox, chatbot])
        # Pressing Enter in the textbox sends the message as well.
        textbox.submit(enter_message, [textbox, chatbot], [textbox, chatbot])
    b.launch(debug=True)
|