from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Dict, Any
from agent import (
    LLMChain,
    PromptTemplate,
    llm,
    DOCUMENT_DIR,
    load_documents,
    split_documents,
    CHROMA_PATH,
    load_vectordb,
    create_and_store_embeddings,
)
import os
# State schema shared by every node in the workflow
class AgentState(TypedDict):
    query: str                  # the user's current message
    previous_conversation: str  # running transcript of the chat so far
    user_data: Dict[str, Any]   # user profile info, including the reply "style"
    requires_rag: bool          # whether document retrieval should run
    context: List[str]          # retrieved document chunks (empty if no RAG)
    response: str               # the generated assistant reply
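# An illustrative state value (a sketch; the field contents below are made up
# for demonstration and are not part of the original module):
#
#     {
#         "query": "Tell me about health schemes in my state",
#         "previous_conversation": "User: Hi\nAssistant: Hello! How can I help?\n",
#         "user_data": {"name": "Asha", "style": "normal"},
#         "requires_rag": True,
#         "context": ["<retrieved document chunk>"],
#         "response": "",
#     }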
# def query_classifier(state: AgentState) -> AgentState:
#     """Determine if the query requires RAG retrieval based on keywords.
#     No longer maintained; will be removed in a future revision."""
#     query_lower = state["query"].lower()
#     rag_keywords = [
#         "scheme", "schemes", "program", "programs", "policy", "policies",
#         "public health engineering", "phe", "public health", "government",
#         "benefit", "financial", "assistance", "aid", "initiative", "yojana",
#     ]
#     state["requires_rag"] = any(keyword in query_lower for keyword in rag_keywords)
#     return state
def query_classifier(state: AgentState) -> AgentState:
    """Classify the query's intent with the LLM instead of keyword matching."""
    query = state["query"]
    # Ask the LLM for a yes/no intent classification
    classification_prompt = f"""
    Answer with only 'Yes' or 'No'.
    Classify whether this query is asking about government schemes, policies, or benefits.
    The query may not be in English, so first detect the language and then interpret the query.
    Query: {query}
    Remember: answer with only 'Yes' or 'No'."""
    result = llm.predict(classification_prompt)
    state["requires_rag"] = "yes" in result.lower()
    return state
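# A minimal sanity-check sketch (commented out; assumes the `llm` imported
# from agent.py is configured, and the example query is illustrative only):
#
#     _state: AgentState = {
#         "query": "Which government schemes help with hospital bills?",
#         "previous_conversation": "", "user_data": {},
#         "requires_rag": False, "context": [], "response": "",
#     }
#     query_classifier(_state)["requires_rag"]  # expected: True, though the
#     LLM's answer is not guaranteed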
def retrieve_documents(state: AgentState) -> AgentState:
    """Retrieve documents from the vector store if needed."""
    if state["requires_rag"]:
        # `vector_store` is a module-level global, initialised in agent_with_db()
        docs = vector_store.as_retriever(search_kwargs={"k": 5}).get_relevant_documents(state["query"])
        state["context"] = [doc.page_content for doc in docs]
    else:
        state["context"] = []
    return state
def generate_response(state: AgentState) -> AgentState:
    """Generate a response, with retrieved context when RAG was triggered."""
    # style = state["user_data"].get("style", "normal") if isinstance(state["user_data"], dict) else "normal"
    base_prompt = """You are a helpful health assistant who talks to the user like a human and resolves their queries.
    Use Previous_Conversation to maintain consistency in the conversation.
    This is the previous conversation between you and the user.
    Previous_Conversation: {previous_conversation}
    This is information about the user.
    User_Data: {user_data}
    Points to adhere to:
    1. Only share scheme information if the user specifically asks for it; otherwise do not mention schemes.
    2. If the user asks about schemes, first ask which state they belong to.
    3. You can act as a mental-health counselor if needed.
    4. Give precautions and natural remedies only for common illnesses (e.g., the common cold or flu), when the user asks or it is clearly needed.
    5. Ask for the user's preferred language at the start of the conversation.
    6. Give the answer in a friendly and conversational tone.
    7. Answer in a {style} style.
    Question: {question}
    """
if state["requires_rag"] and state["context"]:
# Add context to prompt if we're using RAG
context = "\n".join(state["context"])
prompt_template = base_prompt + "\nContext from knowledge base:\n{context}\n\nAnswer:"
prompt = PromptTemplate(
template=prompt_template,
input_variables=["context", "question", "previous_conversation", "user_data", "style"]
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
result = llm_chain({
'context': context,
'question': state["query"],
'previous_conversation': state["previous_conversation"],
'user_data': state["user_data"],
'style': state["user_data"].get("style", "normal")
})
else:
# Answer directly without context
prompt_template = base_prompt + "\nAnswer:"
prompt = PromptTemplate(
template=prompt_template,
input_variables=["question", "previous_conversation", "user_data", "style"]
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
result = llm_chain({
'question': state["query"],
'previous_conversation': state["previous_conversation"],
'user_data': state["user_data"],
'style': state["user_data"].get("style", "normal")
})
state["response"] = result["text"]
return state
def create_agent_workflow():
    """Create the LangGraph workflow for the health agent."""
    # Initialize the state graph
    workflow = StateGraph(AgentState)
    # Add nodes
    workflow.add_node("classifier", query_classifier)
    workflow.add_node("retriever", retrieve_documents)
    workflow.add_node("responder", generate_response)
    # Create edges: a linear pipeline classifier -> retriever -> responder
    workflow.add_edge("classifier", "retriever")
    workflow.add_edge("retriever", "responder")
    workflow.add_edge("responder", END)
    # Set the entry point
    workflow.set_entry_point("classifier")
    # Compile the graph
    return workflow.compile()
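# A minimal invocation sketch for the compiled graph (commented out; assumes
# the vector store and LLM are ready, and the query is illustrative):
#
#     graph = create_agent_workflow()
#     out = graph.invoke({
#         "query": "I have a cold, what should I do?",
#         "previous_conversation": "",
#         "user_data": {"style": "normal"},
#         "requires_rag": False,
#         "context": [],
#         "response": "",
#     })
#     print(out["response"])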
def agent_with_db():
    # Load an existing vector store, or create one below
    global vector_store
    vector_store = None
    try:
        vector_store = load_vectordb(CHROMA_PATH)
    except ValueError:
        pass
    UPDATE_DB = os.getenv("UPDATE_DB", "false")
    if UPDATE_DB.lower() == "true" or vector_store is None:
        print("Loading and processing documents...")
        documents = load_documents(DOCUMENT_DIR)
        chunks = split_documents(documents)
        try:
            vector_store = create_and_store_embeddings(chunks)
        except Exception as e:
            print(f"Error creating embeddings: {e}")
            return None
    print("Creating the LangGraph health agent workflow...")
    agent_workflow = create_agent_workflow()
    class HealthAgent:
        def __init__(self, workflow):
            self.workflow = workflow
            self.conversation_history = ""

        def __call__(self, input_data):
            # Handle both dictionary input and a direct query string
            if isinstance(input_data, dict):
                query = input_data.get("query", "")
                previous_conversation = input_data.get("previous_conversation", "")
                user_data = input_data.get("user_data", {})
                style = input_data.get("style", "normal")
            else:
                # Assume it's a direct query string
                query = input_data
                previous_conversation = ""
                user_data = {}
                style = "normal"
            # Store previous conversation if provided
            if previous_conversation:
                self.conversation_history = previous_conversation
            # Update conversation history with the new user turn
            if self.conversation_history:
                self.conversation_history += f"\nUser: {query}\n"
            else:
                self.conversation_history = f"User: {query}\n"
            if "style" not in user_data:
                user_data["style"] = style
            # Prepare initial state
            initial_state = {
                "query": query,
                "previous_conversation": self.conversation_history,
                "user_data": user_data,
                "requires_rag": False,
                "context": [],
                "response": "",
            }
            final_state = self.workflow.invoke(initial_state)
            self.conversation_history += f"Assistant: {final_state['response']}\n"
            return {"result": final_state["response"]}

    return HealthAgent(agent_workflow)
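
# Usage sketch (an assumption, not part of the original module: it presumes
# DOCUMENT_DIR contains documents and the Chroma store can be built or loaded):
if __name__ == "__main__":
    agent = agent_with_db()
    if agent is not None:
        reply = agent({
            "query": "Hello! Can you suggest precautions for the flu?",
            "user_data": {"name": "Asha"},  # illustrative profile data
            "style": "friendly",
        })
        print(reply["result"])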