File size: 5,020 Bytes
dc66050
 
3674844
5756803
 
 
d17e40b
5756803
 
d17e40b
5756803
f1aaba3
 
 
 
 
 
 
 
 
 
 
af7f1fe
b1f7d5c
f1aaba3
cf29376
f1aaba3
592b501
 
 
f4b8a3e
6193310
 
e1a03f2
ee14926
cf29376
 
 
 
 
5992538
687b793
ee14926
 
 
 
 
 
 
 
 
 
 
 
 
b1f7d5c
ee14926
 
 
 
dc66050
 
97413fe
2d1e4f3
ee14926
f7947fc
2d1e4f3
 
 
3d2d75d
ee14926
2d1e4f3
 
 
 
 
 
 
 
 
 
ee14926
2d1e4f3
 
 
ee14926
 
 
 
 
 
 
 
 
 
 
 
2d1e4f3
 
7ded5fa
e23a574
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ded5fa
8586ebb
 
 
086f897
7ded5fa
d17e40b
 
 
 
 
 
 
 
 
e95efe1
 
 
 
 
d23e6f9
e95efe1
 
104e584
9d155c8
 
600606a
9d155c8
 
 
4e51d19
1cc1c81
 
97f927e
1cc1c81
 
4e51d19
4a206af
 
 
193e8c6
e6a7334
 
 
 
 
 
97f927e
 
 
1cc1c81
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import gradio as gr
from huggingface_hub import InferenceClient
import os
from smolagents import (
    tool,
    CodeAgent,
    TransformersModel,
    GradioUI,
    MultiStepAgent,
    stream_to_gradio, HfApiModel,
)
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    Float,
    insert,
    inspect,
    text,
    select,
    Engine,
)
import spaces

from dotenv import load_dotenv

# Load variables from a local .env file (if one exists) into os.environ, so
# that os.getenv("my_first_agents_hf_tokens") in the __main__ block can find
# the Hugging Face token.
load_dotenv()
# Sample questions to try in the chat box:
# - What is the average each customer paid?
# - Create a sql statement and invoke your sql_engine tool


@spaces.GPU
def dummy():
    """No-op placeholder wrapped by ``spaces.GPU``.

    Performs no work and returns None. It is presumably kept so that the
    Space declares at least one GPU-decorated function; confirm against the
    Space's hardware settings before removing it.
    """
    return None


@tool
def sql_engine_tool(query: str) -> str:
    """
    Allows you to perform SQL queries on the tables. Returns a string representation of the result.
    The database contains the following tables:
    Table 'receipts':
        Columns:
        - receipt_id: INTEGER
        - customer_name: VARCHAR(16)
        - price: FLOAT
        - tip: FLOAT
    Table 'waiters':
        Columns:
        - receipt_id: INTEGER
        - waiter_name: VARCHAR(16)

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # NOTE: this docstring is the tool description shown to the agent, so it
    # must list every table created by init_db (receipts AND waiters) and keep
    # the "Args:" section that the @tool decorator parses.
    #
    # `engine` is the module-level Engine assigned in the __main__ block.
    # engine.begin() commits on success, so statements that modify data are
    # persisted as well.
    with engine.begin() as con:
        rows = con.execute(text(query))
        # One "\n"-prefixed line per row (same format as before). The join is
        # evaluated before the connection closes, and it avoids the quadratic
        # cost of repeated string concatenation.
        return "".join("\n" + str(row) for row in rows)


def init_db(engine):
    """Create and seed the demo ``receipts`` and ``waiters`` tables.

    Tables are created only if they do not already exist. Each table is
    seeded only while it is empty. Calling this again against a database that
    was already initialised (e.g. a file-backed SQLite URL reused across runs)
    therefore no longer fails by re-inserting rows with duplicate primary keys.

    Args:
        engine: SQLAlchemy engine in which to create the tables.

    Returns:
        The same ``engine`` that was passed in.
    """
    metadata_obj = MetaData()

    def insert_rows_into_table(rows, table, engine=engine):
        """Insert ``rows`` into ``table`` in one transaction, unless the table already has data."""
        with engine.begin() as connection:
            # The fixture rows use fixed primary keys, so inserting them into a
            # table that was seeded on a previous run would violate uniqueness.
            if connection.execute(select(table).limit(1)).first() is not None:
                return
            # A list of parameter dicts runs as a single executemany batch
            # instead of one transaction per row.
            connection.execute(insert(table), rows)

    receipts = Table(
        "receipts",
        metadata_obj,
        Column("receipt_id", Integer, primary_key=True),
        Column("customer_name", String(16), primary_key=True),
        Column("price", Float),
        Column("tip", Float),
    )
    waiters = Table(
        "waiters",
        metadata_obj,
        Column("receipt_id", Integer, primary_key=True),
        Column("waiter_name", String(16), primary_key=True),
    )
    # create_all checks for existing tables by default, so this is safe to rerun.
    metadata_obj.create_all(engine)

    insert_rows_into_table(
        [
            {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
            {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
            {
                "receipt_id": 3,
                "customer_name": "Woodrow Wilson",
                "price": 53.43,
                "tip": 5.43,
            },
            {
                "receipt_id": 4,
                "customer_name": "Margaret James",
                "price": 21.11,
                "tip": 1.00,
            },
        ],
        receipts,
    )
    insert_rows_into_table(
        [
            {"receipt_id": 1, "waiter_name": "Corey Johnson"},
            {"receipt_id": 2, "waiter_name": "Michael Watts"},
            {"receipt_id": 3, "waiter_name": "Michael Watts"},
            {"receipt_id": 4, "waiter_name": "Margaret James"},
        ],
        waiters,
    )
    return engine


if __name__ == "__main__":
    # NOTE(review): "sqlite:///:localhost:" is NOT SQLite's in-memory database
    # (that would be "sqlite:///:memory:"). It creates/opens a file literally
    # named ":localhost:" in the working directory, so data persists between runs.
    engine = create_engine("sqlite:///:localhost:")
    engine = init_db(engine)
    # Not working at the moment
    # model = TransformersModel(
    #     # model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     device_map="cuda",
    #     model_id="meta-llama/Llama-3.2-3B-Instruct"
    # )
    # Hosted inference; the token is read from the environment (e.g. a .env file).
    model = HfApiModel(
        model_id="meta-llama/Llama-3.2-3B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens"),
    )

    agent = CodeAgent(
        tools=[sql_engine_tool],
        model=model,
        max_steps=10,
        verbosity_level=1,
    )

    def enter_message(new_message, conversation_history):
        """Stream the agent's answer to ``new_message`` into the chat.

        Appends the user's message, then each message produced by
        ``stream_to_gradio``. Every yield clears the textbox and refreshes
        the chat history.
        """
        conversation_history.append(gr.ChatMessage(role="user", content=new_message))
        for msg in stream_to_gradio(agent, new_message):
            conversation_history.append(msg)
            yield "", conversation_history

    def clear_message(chat_history=None):
        """Reset the agent's memory and return an empty chat and an empty textbox.

        ``chat_history`` is optional because the click listener below is
        registered without inputs, so Gradio calls this with no arguments.
        """
        agent.memory.reset()
        return [], ""

    with gr.Blocks() as b:
        # Keep the description flush-left. An indented line right after the
        # heading would be rendered by Markdown as a code block.
        gr.Markdown(
            "# Demo text to sql on paying customers' receipts\n"
            "a self correcting text to sql ai agent using smolagents, gradio, "
            "HF Spaces, sqlalchemy improved from a smolagents guide\n"
        )
        chatbot = gr.Chatbot(type="messages", height=2000)
        message_box = gr.Textbox(
            lines=1,
            label="chat message (with default sample question)",
            value="What is the average each customer paid?",
        )
        with gr.Row():
            stop_generating_button = gr.Button("stop generating")
            clear_messages_button = gr.ClearButton([message_box, chatbot])
            enter_button = gr.Button("enter")
        reply_button_click_event = enter_button.click(
            enter_message, [message_box, chatbot], [message_box, chatbot]
        )
        message_submit = message_box.submit(
            enter_message, [message_box, chatbot], [message_box, chatbot]
        )
        # fn=None: this listener runs no Python code; it only cancels the
        # in-flight generation events.
        stop_generating_button.click(
            fn=None, cancels=[reply_button_click_event, message_submit]
        )
        clear_messages_button.click(clear_message, outputs=[chatbot, message_box])

    b.launch()