Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,108 +1,148 @@
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
-
|
|
|
4 |
import requests
|
5 |
-
from typing import Dict, List
|
6 |
-
from langchain_core.messages import HumanMessage
|
7 |
from langchain_core.tools import tool
|
8 |
from langchain_openai import ChatOpenAI
|
9 |
from langgraph.checkpoint.memory import MemorySaver
|
10 |
from langgraph.prebuilt import create_react_agent
|
11 |
|
12 |
-
#
|
13 |
@tool
|
14 |
def get_lat_lng(location_description: str) -> dict[str, float]:
|
15 |
"""Get the latitude and longitude of a location."""
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
@tool
|
19 |
def get_weather(lat: float, lng: float) -> dict[str, str]:
|
20 |
"""Get the weather at a location."""
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
llm = ChatOpenAI(temperature=0, model="gpt-4")
|
29 |
-
memory = MemorySaver()
|
30 |
tools = [get_lat_lng, get_weather]
|
31 |
agent_executor = create_react_agent(llm, tools, checkpointer=memory)
|
32 |
-
|
33 |
-
#
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
{"messages": past_messages},
|
44 |
-
config={"configurable": {"thread_id": "abc123"}}
|
45 |
):
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
yield messages_to_display
|
79 |
-
|
80 |
-
# Create the Gradio interface
|
81 |
demo = gr.ChatInterface(
|
82 |
fn=stream_from_agent,
|
83 |
-
type="messages"
|
84 |
title="🌤️ Weather Assistant",
|
85 |
description="Ask about the weather anywhere! Watch as I gather the information step by step.",
|
86 |
examples=[
|
87 |
-
"What's the weather like in Tokyo?",
|
88 |
-
"Is it sunny in Paris right now?",
|
89 |
-
"Should I bring an umbrella in New York today?"
|
90 |
],
|
91 |
-
|
92 |
-
|
93 |
-
"https://cdn2.iconfinder.com/data/icons/city-icons-for-offscreen-magazine/80/new-york-256.png"
|
94 |
-
],
|
95 |
save_history=True,
|
96 |
-
editable=True
|
97 |
-
|
98 |
)
|
99 |
|
100 |
if __name__ == "__main__":
|
101 |
# Load environment variables
|
102 |
try:
|
103 |
from dotenv import load_dotenv
|
104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
except ImportError:
|
|
|
106 |
pass
|
107 |
|
108 |
-
|
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
+
# Keep using gradio.ChatMessage for type hints if needed, but not for yielding complex structures directly to ChatInterface
|
4 |
+
# from gradio import ChatMessage # Maybe remove this import if not used elsewhere
|
5 |
import requests
|
6 |
+
from typing import Dict, List, AsyncGenerator, Union, Tuple
|
7 |
+
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage # Use LangChain messages internally
|
8 |
from langchain_core.tools import tool
|
9 |
from langchain_openai import ChatOpenAI
|
10 |
from langgraph.checkpoint.memory import MemorySaver
|
11 |
from langgraph.prebuilt import create_react_agent
|
12 |
|
13 |
+
# --- Tools remain the same ---
|
14 |
@tool
|
15 |
def get_lat_lng(location_description: str) -> dict[str, float]:
|
16 |
"""Get the latitude and longitude of a location."""
|
17 |
+
print(f"Tool: Getting lat/lng for {location_description}")
|
18 |
+
# Replace with actual API call in a real app
|
19 |
+
if "tokyo" in location_description.lower():
|
20 |
+
return {"lat": 35.6895, "lng": 139.6917}
|
21 |
+
elif "paris" in location_description.lower():
|
22 |
+
return {"lat": 48.8566, "lng": 2.3522}
|
23 |
+
elif "new york" in location_description.lower():
|
24 |
+
return {"lat": 40.7128, "lng": -74.0060}
|
25 |
+
else:
|
26 |
+
return {"lat": 51.5072, "lng": -0.1276} # Default London
|
27 |
|
28 |
@tool
|
29 |
def get_weather(lat: float, lng: float) -> dict[str, str]:
|
30 |
"""Get the weather at a location."""
|
31 |
+
print(f"Tool: Getting weather for lat={lat}, lng={lng}")
|
32 |
+
# Replace with actual API call in a real app
|
33 |
+
# Dummy logic based on lat
|
34 |
+
if lat > 45: # Northern locations
|
35 |
+
return {"temperature": "15°C", "description": "Cloudy"}
|
36 |
+
elif lat > 30: # Mid locations
|
37 |
+
return {"temperature": "25°C", "description": "Sunny"}
|
38 |
+
else: # Southern locations
|
39 |
+
return {"temperature": "30°C", "description": "Very Sunny"}
|
40 |
|
41 |
+
# --- Modified Agent Function ---
|
42 |
+
# Change return type hint for clarity if desired, e.g., AsyncGenerator[str, None]
|
43 |
+
# Or keep it simple, Gradio infers based on yields
|
44 |
+
async def stream_from_agent(message: str, history: List[List[str]]) -> AsyncGenerator[str, None]:
|
45 |
+
"""Processes message through LangChain agent, yielding intermediate steps as strings."""
|
46 |
+
|
47 |
+
# Convert Gradio history to LangChain messages
|
48 |
+
lc_messages = []
|
49 |
+
for user_msg, ai_msg in history:
|
50 |
+
if user_msg:
|
51 |
+
lc_messages.append(HumanMessage(content=user_msg))
|
52 |
+
if ai_msg:
|
53 |
+
# Important: Handle potential previous intermediate strings from AI
|
54 |
+
# If the ai_msg contains markers like "🛠️ Using", it was an intermediate step.
|
55 |
+
# For simplicity here, we assume full AI responses were stored previously.
|
56 |
+
# A more robust solution might involve storing message types in history.
|
57 |
+
if not ai_msg.startswith("🛠️ Using") and not ai_msg.startswith("Result:"):
|
58 |
+
lc_messages.append(AIMessage(content=ai_msg))
|
59 |
+
|
60 |
+
lc_messages.append(HumanMessage(content=message))
|
61 |
+
|
62 |
+
# Initialize the agent (consider initializing outside the function if stateful across calls)
|
63 |
llm = ChatOpenAI(temperature=0, model="gpt-4")
|
64 |
+
memory = MemorySaver() # Be mindful of memory state if agent is re-initialized every time
|
65 |
tools = [get_lat_lng, get_weather]
|
66 |
agent_executor = create_react_agent(llm, tools, checkpointer=memory)
|
67 |
+
|
68 |
+
# Use a unique thread_id per session if needed, or manage state differently
|
69 |
+
# Using a fixed one like "abc123" means all users share the same memory if server restarts aren't frequent
|
70 |
+
thread_id = "user_session_" + str(os.urandom(4).hex()) # Example: generate unique ID
|
71 |
+
|
72 |
+
full_response = "" # Accumulate the response parts
|
73 |
+
|
74 |
+
async for chunk in agent_executor.astream_events(
|
75 |
+
{"messages": lc_messages},
|
76 |
+
config={"configurable": {"thread_id": thread_id}},
|
77 |
+
version="v1" # Use v1 for events streaming
|
|
|
|
|
78 |
):
|
79 |
+
event = chunk["event"]
|
80 |
+
data = chunk["data"]
|
81 |
+
|
82 |
+
if event == "on_chat_model_stream":
|
83 |
+
# Stream content from the LLM (final answer parts)
|
84 |
+
content = data["chunk"].content
|
85 |
+
if content:
|
86 |
+
full_response += content
|
87 |
+
yield full_response # Yield the accumulating final response
|
88 |
+
|
89 |
+
elif event == "on_tool_start":
|
90 |
+
# Show tool usage start
|
91 |
+
tool_input_str = str(data.get('input', '')) # Get tool input safely
|
92 |
+
yield f"🛠️ Using tool: **{data['name']}** with input: `{tool_input_str}`"
|
93 |
+
|
94 |
+
elif event == "on_tool_end":
|
95 |
+
# Show tool result (optional, can make chat verbose)
|
96 |
+
tool_output_str = str(data.get('output', '')) # Get tool output safely
|
97 |
+
# Find the corresponding start message to potentially update, or just yield new message
|
98 |
+
# For simplicity, just yield the result as a new message line
|
99 |
+
yield f"Tool **{data['name']}** finished.\nResult: `{tool_output_str}`"
|
100 |
+
# Yield the accumulated response again after tool use in case LLM continues
|
101 |
+
if full_response:
|
102 |
+
yield full_response
|
103 |
+
|
104 |
+
# Ensure the final accumulated response is yielded if not already done by the last LLM chunk
|
105 |
+
# (stream might end on tool end sometimes)
|
106 |
+
if full_response and (not chunk or chunk["event"] != "on_chat_model_stream"):
|
107 |
+
yield full_response
|
108 |
+
|
109 |
+
|
110 |
+
# --- Gradio Interface (mostly unchanged) ---
|
|
|
|
|
|
|
111 |
demo = gr.ChatInterface(
|
112 |
fn=stream_from_agent,
|
113 |
+
# No type="messages" needed when yielding strings; ChatInterface handles it.
|
114 |
title="🌤️ Weather Assistant",
|
115 |
description="Ask about the weather anywhere! Watch as I gather the information step by step.",
|
116 |
examples=[
|
117 |
+
["What's the weather like in Tokyo?"],
|
118 |
+
["Is it sunny in Paris right now?"],
|
119 |
+
["Should I bring an umbrella in New York today?"]
|
120 |
],
|
121 |
+
# Example icons removed for simplicity, ensure they are accessible if added back
|
122 |
+
cache_examples=False, # Turn off caching initially to ensure it's not the issue
|
|
|
|
|
123 |
save_history=True,
|
124 |
+
editable=True,
|
|
|
125 |
)
|
126 |
|
127 |
if __name__ == "__main__":
|
128 |
# Load environment variables
|
129 |
try:
|
130 |
from dotenv import load_dotenv
|
131 |
+
print("Attempting to load .env file...")
|
132 |
+
loaded = load_dotenv()
|
133 |
+
if loaded:
|
134 |
+
print(".env file loaded successfully.")
|
135 |
+
else:
|
136 |
+
print(".env file not found or empty.")
|
137 |
+
# Check if the key is loaded
|
138 |
+
openai_api_key = os.getenv("OPENAI_API_KEY")
|
139 |
+
if openai_api_key:
|
140 |
+
print("OPENAI_API_KEY found.")
|
141 |
+
else:
|
142 |
+
print("Warning: OPENAI_API_KEY not found in environment variables.")
|
143 |
except ImportError:
|
144 |
+
print("dotenv not installed, skipping .env load.")
|
145 |
pass
|
146 |
|
147 |
+
# Add server_name="0.0.0.0" if running in Docker or need external access
|
148 |
+
demo.launch(debug=True, server_name="0.0.0.0")
|