File size: 3,909 Bytes
dc66050
 
3674844
5756803
 
 
 
 
 
 
 
f1aaba3
 
 
 
 
 
 
 
 
 
 
af7f1fe
b1f7d5c
f1aaba3
cf29376
f1aaba3
592b501
 
 
ee14926
6193310
 
e1a03f2
ee14926
cf29376
 
 
 
 
5992538
687b793
ee14926
 
 
 
 
 
 
 
 
 
 
 
 
7fdbbb2
1b0cc5f
b1f7d5c
7a8951a
7fdbbb2
1a788a2
 
 
 
 
 
 
ee14926
 
 
 
dc66050
 
97413fe
ee14926
2d1e4f3
ee14926
f7947fc
2d1e4f3
 
 
3d2d75d
ee14926
2d1e4f3
 
 
 
 
 
 
 
 
 
ee14926
2d1e4f3
 
 
ee14926
 
 
 
 
 
 
 
 
 
 
 
2d1e4f3
 
3d2d75d
28fb60e
7a8951a
1b0cc5f
7a8951a
7fdbbb2
8586ebb
 
 
086f897
a37048e
e95efe1
1b0cc5f
e95efe1
 
 
 
 
 
c25281e
e95efe1
 
4e51d19
104e584
9d155c8
4e51d19
9d155c8
 
 
 
 
 
4e51d19
 
9d155c8
104e584
4e51d19
9d155c8
eebef7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
from huggingface_hub import InferenceClient
import os
from smolagents import (
    tool,
    CodeAgent,
    HfApiModel,
    GradioUI,
    MultiStepAgent,
    stream_to_gradio,
)
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    Float,
    insert,
    inspect,
    text,
    select,
    Engine,
)
import spaces

from dotenv import load_dotenv

load_dotenv()

# Example prompt (presumably for trying the agent out by hand):
#   What is the average each customer paid?
#   Create a SQL statement and invoke your sql_engine tool.


@spaces.GPU
def dummy():
    """Do nothing; exists only as a function wrapped by ``spaces.GPU``.

    NOTE(review): presumably kept so the hosting platform finds at least one
    GPU-decorated function at startup -- confirm before removing.
    """
    return None


@tool
def sql_engine_tool(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'receipts'. Its description is as follows:
        Columns:
        - receipt_id: INTEGER
        - customer_name: VARCHAR(16)
        - price: FLOAT
        - tip: FLOAT

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # NOTE(review): the docstring above is left verbatim -- presumably the
    # `tool` decorator turns it into the description shown to the model, so
    # rewording it would change runtime behaviour. Confirm before editing.
    #
    # `engine` and `metadata_objects` are module-level names bound by the
    # script entry point, not parameters of this function.
    print("debug sql_engine_tool")
    print(engine)
    with engine.begin() as con:
        # Diagnostics: show the raw connection, the tables known to the
        # metadata, and whether 'receipts' is visible on this connection.
        print(con.connection)
        print(metadata_objects.tables.keys())
        table_probe = con.execute(
            text(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='receipts'"
            )
        )
        print("tables available:", table_probe.fetchone())

        # Every result row is rendered with str() and preceded by a newline,
        # so a non-empty result starts with "\n" and an empty one is "".
        rendered_rows = ["\n" + str(row) for row in con.execute(text(query))]
    return "".join(rendered_rows)


def init_db(engine: "Engine") -> "tuple[Engine, MetaData]":
    """Create the ``receipts`` table on *engine* and seed it with sample rows.

    The table has a composite primary key (``receipt_id``, ``customer_name``)
    plus two float columns, ``price`` and ``tip``. Four fixed sample receipts
    are inserted in a single transaction. On SQLite, rows whose primary key is
    already present are skipped, so calling this again on a database that was
    seeded earlier (for example a file-backed database that survived a
    restart) does not raise an integrity error.

    Args:
        engine: SQLAlchemy engine used to create the table and insert the rows.

    Returns:
        A ``(engine, metadata_obj)`` tuple: the engine that was passed in and
        the ``MetaData`` object holding the ``receipts`` table definition.
    """
    metadata_obj = MetaData()

    def insert_rows_into_table(rows, table, engine=engine):
        """Insert *rows* (dicts of column name -> value) into *table* atomically."""
        if not rows:
            # Nothing to insert; avoid executing an INSERT with no parameter sets.
            return
        # The "OR IGNORE" prefix is rendered for the SQLite dialect only, so
        # other backends still get a plain INSERT.
        stmt = insert(table).prefix_with("OR IGNORE", dialect="sqlite")
        # One transaction and one executemany-style call for the whole batch,
        # instead of a separate transaction per row.
        with engine.begin() as connection:
            connection.execute(stmt, rows)

    table_name = "receipts"
    receipts = Table(
        table_name,
        metadata_obj,
        Column("receipt_id", Integer, primary_key=True),
        Column("customer_name", String(16), primary_key=True),
        Column("price", Float),
        Column("tip", Float),
    )
    # create_all only creates tables that do not exist yet.
    metadata_obj.create_all(engine)

    rows = [
        {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
        {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
        {
            "receipt_id": 3,
            "customer_name": "Woodrow Wilson",
            "price": 53.43,
            "tip": 5.43,
        },
        {
            "receipt_id": 4,
            "customer_name": "Margaret James",
            "price": 21.11,
            "tip": 1.00,
        },
    ]
    insert_rows_into_table(rows, receipts)

    # Debug output: read the table back and show which engine was initialised.
    with engine.begin() as conn:
        print("SELECT test", conn.execute(text("SELECT * FROM receipts")).fetchall())
    print("init_db debug")
    print(engine)
    print()
    return engine, metadata_obj


if __name__ == "__main__":
    # `engine` and `metadata_objects` are bound at module level because
    # sql_engine_tool reads them as globals.
    # NOTE(review): ":localhost:" is not SQLite's special ":memory:" name, so
    # this presumably opens a file literally called ":localhost:" relative to
    # the working directory -- confirm that is intended.
    engine, metadata_objects = init_db(create_engine("sqlite:///:localhost:"))

    hf_token = os.getenv("my_first_agents_hf_tokens")
    model = HfApiModel(
        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
        token=hf_token,
    )

    agent = CodeAgent(
        tools=[sql_engine_tool],
        model=model,
        max_steps=1,
        verbosity_level=1,
    )
    # The stock `GradioUI(agent).launch()` front end is left disabled in
    # favour of the hand-built Blocks layout below.

    def enter_message(new_message, conversation_history):
        """Stream the agent's answer to *new_message* into the chat display.

        Yields ``("", history)`` pairs: the empty string clears the textbox
        and ``history`` is the message list shown in the chatbot.
        """
        # NOTE(review): the incoming history is discarded, so the chat only
        # ever shows the latest exchange -- confirm this is deliberate.
        conversation_history = [gr.ChatMessage(role="user", content=new_message)]
        yield "", conversation_history
        for agent_message in stream_to_gradio(agent, new_message):
            conversation_history.append(agent_message)
            yield "", conversation_history

    with gr.Blocks() as demo:
        chat_panel = gr.Chatbot(type="messages", height=1000)
        prompt_box = gr.Textbox()
        reply_button = gr.Button("reply")
        reply_button.click(
            fn=enter_message,
            inputs=[prompt_box, chat_panel],
            outputs=[prompt_box, chat_panel],
        )
    demo.launch(debug=True)