arjunanand13 commited on
Commit
e8bcf17
·
verified ·
1 Parent(s): 06105b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -21
app.py CHANGED
@@ -9,12 +9,9 @@ from typing_extensions import TypedDict
9
  from langgraph.graph import StateGraph, START, END
10
  import csv
11
 
12
- # Load API keys
13
  openai.api_key = os.getenv("OPENAI_API_KEY")
14
 
15
- # --- Initialize Database from external CSV ---
16
  def init_db_from_csv(csv_path: str = "transactions.csv") -> None:
17
- """Create 'transactions' table and load data from the provided CSV file."""
18
  conn = sqlite3.connect("shop.db")
19
  cur = conn.cursor()
20
  cur.execute(
@@ -23,7 +20,6 @@ def init_db_from_csv(csv_path: str = "transactions.csv") -> None:
23
  with open(csv_path, newline='') as f:
24
  reader = csv.DictReader(f)
25
  rows = [(row["date"], row["product"], float(row["amount"])) for row in reader]
26
- # Replace old data
27
  cur.execute("DELETE FROM transactions")
28
  cur.executemany(
29
  "INSERT INTO transactions (date, product, amount) VALUES (?, ?, ?)", rows
@@ -31,11 +27,8 @@ def init_db_from_csv(csv_path: str = "transactions.csv") -> None:
31
  conn.commit()
32
  conn.close()
33
 
34
- # Initialize DB at startup (ensure transactions.csv is present)
35
  init_db_from_csv()
36
 
37
- # --- Business Logic Functions ---
38
-
39
  def db_agent(query: str) -> str:
40
  try:
41
  conn = sqlite3.connect("shop.db")
@@ -57,7 +50,6 @@ def db_agent(query: str) -> str:
57
  except sqlite3.OperationalError as e:
58
  return f"Database error: {e}. Please check 'transactions' table in shop.db."
59
 
60
-
61
  def web_search_agent(query: str) -> str:
62
  try:
63
  resp = requests.get(
@@ -71,7 +63,6 @@ def web_search_agent(query: str) -> str:
71
  pass
72
  return llm_agent(query)
73
 
74
-
75
  def llm_agent(query: str) -> str:
76
  response = openai.chat.completions.create(
77
  model="gpt-4o-mini",
@@ -83,7 +74,6 @@ def llm_agent(query: str) -> str:
83
  )
84
  return response.choices[0].message.content.strip()
85
 
86
-
87
  def stt_agent(audio_path: str) -> str:
88
  with open(audio_path, "rb") as afile:
89
  transcript = openai.audio.transcriptions.create(
@@ -92,20 +82,16 @@ def stt_agent(audio_path: str) -> str:
92
  )
93
  return transcript.text.strip()
94
 
95
-
96
  def tts_agent(text: str, lang: str = 'en') -> str:
97
  tts = gTTS(text=text, lang=lang)
98
  out_path = "response_audio.mp3"
99
  tts.save(out_path)
100
  return out_path
101
 
102
- # --- LangGraph State and Nodes ---
103
  class State(TypedDict):
104
  query: str
105
  result: str
106
 
107
- # Routing logic based on query
108
-
109
  def route_fn(state: State) -> str:
110
  q = state["query"].lower()
111
  if any(k in q for k in ["max revenue", "revenue"]):
@@ -114,8 +100,6 @@ def route_fn(state: State) -> str:
114
  return "web"
115
  return "llm"
116
 
117
- # Node implementations
118
-
119
  def router_node(state: State) -> dict:
120
  return {"query": state["query"]}
121
 
@@ -128,7 +112,6 @@ def web_node(state: State) -> dict:
128
  def llm_node(state: State) -> dict:
129
  return {"result": llm_agent(state["query"]) }
130
 
131
- # Build the LangGraph
132
  builder = StateGraph(State)
133
  builder.add_node("router", router_node)
134
  builder.set_entry_point("router")
@@ -145,7 +128,6 @@ builder.add_edge("web", END)
145
  builder.add_edge("llm", END)
146
  graph = builder.compile()
147
 
148
- # Handler integrates STT/TTS and graph execution
149
  def handle_query(audio_or_text: str):
150
  is_audio = audio_or_text.endswith('.wav') or audio_or_text.endswith('.mp3')
151
  if is_audio:
@@ -161,14 +143,12 @@ def handle_query(audio_or_text: str):
161
  return response, audio_path
162
  return response
163
 
164
- # --- Gradio UI ---
165
  with gr.Blocks() as demo:
166
  gr.Markdown("## Shop Voice-Box Assistant (Speech In/Out)")
167
  inp = gr.Audio(sources=["microphone"], type="filepath", label="Speak or type your question or upload transactions.csv separately in root")
168
  out_text = gr.Textbox(label="Answer (text)")
169
  out_audio = gr.Audio(label="Answer (speech)")
170
  submit = gr.Button("Submit")
171
- # Examples
172
  gr.Examples(
173
  examples=[
174
  ["What is the max revenue product today?"],
@@ -181,4 +161,4 @@ with gr.Blocks() as demo:
181
  submit.click(fn=handle_query, inputs=inp, outputs=[out_text, out_audio])
182
 
183
  if __name__ == "__main__":
184
- demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
 
9
  from langgraph.graph import StateGraph, START, END
10
  import csv
11
 
 
12
  openai.api_key = os.getenv("OPENAI_API_KEY")
13
 
 
14
  def init_db_from_csv(csv_path: str = "transactions.csv") -> None:
 
15
  conn = sqlite3.connect("shop.db")
16
  cur = conn.cursor()
17
  cur.execute(
 
20
  with open(csv_path, newline='') as f:
21
  reader = csv.DictReader(f)
22
  rows = [(row["date"], row["product"], float(row["amount"])) for row in reader]
 
23
  cur.execute("DELETE FROM transactions")
24
  cur.executemany(
25
  "INSERT INTO transactions (date, product, amount) VALUES (?, ?, ?)", rows
 
27
  conn.commit()
28
  conn.close()
29
 
 
30
  init_db_from_csv()
31
 
 
 
32
  def db_agent(query: str) -> str:
33
  try:
34
  conn = sqlite3.connect("shop.db")
 
50
  except sqlite3.OperationalError as e:
51
  return f"Database error: {e}. Please check 'transactions' table in shop.db."
52
 
 
53
  def web_search_agent(query: str) -> str:
54
  try:
55
  resp = requests.get(
 
63
  pass
64
  return llm_agent(query)
65
 
 
66
  def llm_agent(query: str) -> str:
67
  response = openai.chat.completions.create(
68
  model="gpt-4o-mini",
 
74
  )
75
  return response.choices[0].message.content.strip()
76
 
 
77
  def stt_agent(audio_path: str) -> str:
78
  with open(audio_path, "rb") as afile:
79
  transcript = openai.audio.transcriptions.create(
 
82
  )
83
  return transcript.text.strip()
84
 
 
85
  def tts_agent(text: str, lang: str = 'en') -> str:
86
  tts = gTTS(text=text, lang=lang)
87
  out_path = "response_audio.mp3"
88
  tts.save(out_path)
89
  return out_path
90
 
 
91
  class State(TypedDict):
92
  query: str
93
  result: str
94
 
 
 
95
  def route_fn(state: State) -> str:
96
  q = state["query"].lower()
97
  if any(k in q for k in ["max revenue", "revenue"]):
 
100
  return "web"
101
  return "llm"
102
 
 
 
103
  def router_node(state: State) -> dict:
104
  return {"query": state["query"]}
105
 
 
112
  def llm_node(state: State) -> dict:
113
  return {"result": llm_agent(state["query"]) }
114
 
 
115
  builder = StateGraph(State)
116
  builder.add_node("router", router_node)
117
  builder.set_entry_point("router")
 
128
  builder.add_edge("llm", END)
129
  graph = builder.compile()
130
 
 
131
  def handle_query(audio_or_text: str):
132
  is_audio = audio_or_text.endswith('.wav') or audio_or_text.endswith('.mp3')
133
  if is_audio:
 
143
  return response, audio_path
144
  return response
145
 
 
146
  with gr.Blocks() as demo:
147
  gr.Markdown("## Shop Voice-Box Assistant (Speech In/Out)")
148
  inp = gr.Audio(sources=["microphone"], type="filepath", label="Speak or type your question or upload transactions.csv separately in root")
149
  out_text = gr.Textbox(label="Answer (text)")
150
  out_audio = gr.Audio(label="Answer (speech)")
151
  submit = gr.Button("Submit")
 
152
  gr.Examples(
153
  examples=[
154
  ["What is the max revenue product today?"],
 
161
  submit.click(fn=handle_query, inputs=inp, outputs=[out_text, out_audio])
162
 
163
  if __name__ == "__main__":
164
+ demo.launch(share=False, server_name="0.0.0.0", server_port=7860)