import os
from typing import List

from chainlit.types import AskFileResponse
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import chainlit as cl
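
# Overview: a Chainlit RAG app. The user uploads a .txt or .pdf file, the text is
# chunked and embedded into an in-memory vector store, and each question is then
# answered by the LLM using only the retrieved chunks as context.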

system_template = """\
You are a helpful AI assistant that answers questions based on the provided context.

Your task is to:
1. Carefully read and understand the context
2. Answer the user's question using ONLY the information from the context
3. If the answer cannot be found in the context, say "I cannot find the answer in the provided context"
4. If you find partial information, share what you found and indicate if more information might be needed

Remember: Only use information from the provided context to answer questions."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}

Based on the above context, please answer the following question. If the answer cannot be found in the context, say "I cannot find the answer in the provided context." If you find partial information, share what you found and indicate if more information might be needed.

Question:
{question}

Please provide a clear and concise answer based ONLY on the information in the context above."""
user_role_prompt = UserRolePrompt(user_prompt_template)
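
# Illustrative note (assumption, not shown in this file): create_message() on these
# prompt classes is expected to return an OpenAI-style chat message, e.g. roughly
#   user_role_prompt.create_message(question="...", context="...")
#   -> {"role": "user", "content": "Context:\n..."}
# The exact return shape is defined inside aimakerspace.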

class RetrievalAugmentedQAPipeline:
    """Retrieves relevant chunks from the vector store and streams an LLM answer."""

    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        # Retrieve a broad set of candidate contexts for the query
        print("\nSearching for relevant contexts...")
        context_list = self.vector_db_retriever.search_by_text(user_query, k=5)  # increased from 3 to 5

        print("\nRetrieved contexts:")
        for i, (context, score) in enumerate(context_list):
            print(f"\nContext {i+1} (score: {score:.3f}):")
            print((context[:500] + "...") if len(context) > 500 else context)

        # Cap the total context at roughly 3000 tokens, using the common
        # ~4-characters-per-token heuristic (12000 characters)
        context_prompt = ""
        total_length = 0
        max_length = 12000  # reduced from 24000

        # Sort contexts by similarity score so the best ones are kept when truncating
        sorted_contexts = sorted(context_list, key=lambda x: x[1], reverse=True)

        for context, score in sorted_contexts:
            if total_length + len(context) > max_length:
                print(f"\nSkipping context with score {score:.3f} due to length limit")
                continue
            context_prompt += context + "\n"
            total_length += len(context)

        print(f"\nUsing {len(context_prompt.split())} words of context")

        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)

        print("\nFinal messages being sent to the model:")
        print("\nSystem prompt:")
        print(formatted_system_prompt)
        print("\nUser prompt:")
        print(formatted_user_prompt)

        # Stream tokens lazily; the caller iterates this generator to render the answer
        async def generate_response():
            async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
                yield chunk

        return {"response": generate_response(), "context": context_list}

text_splitter = CharacterTextSplitter()


def process_file(file: AskFileResponse):
    import tempfile
    import shutil

    print(f"Processing file: {file.name}")

    # Create a temporary file with the correct extension
    suffix = f".{file.name.split('.')[-1]}"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        # Copy the uploaded file content to the temporary file
        shutil.copyfile(file.path, temp_file.name)
        print(f"Created temporary file at: {temp_file.name}")

        # Create the appropriate loader for the file type
        if file.name.lower().endswith('.pdf'):
            loader = PDFLoader(temp_file.name)
        else:
            loader = TextFileLoader(temp_file.name)

        try:
            # Load and split the documents into chunks
            documents = loader.load_documents()
            texts = text_splitter.split_texts(documents)
            return texts
        finally:
            # Clean up the temporary file
            try:
                os.unlink(temp_file.name)
            except Exception as e:
                print(f"Error cleaning up temporary file: {e}")


@cl.on_chat_start
async def on_chat_start():
    files = None

    # Wait for the user to upload a file
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a Text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=2,
            timeout=180,
        ).send()

    file = files[0]
    print(f"Received file: {file.name} ({file.type})")

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    # Load and chunk the file
    try:
        texts = process_file(file)
        print(f"Successfully processed file. Generated {len(texts)} text chunks")
        print("Sample of first chunk:", texts[0][:200] if texts else "No texts generated")
    except Exception as e:
        print(f"Error processing file: {str(e)}")
        await cl.Message(content=f"Error processing file: {str(e)}").send()
        return

    # Build the in-memory vector store from the chunks
    try:
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)
        print("Successfully created vector database")
    except Exception as e:
        print(f"Error creating vector database: {str(e)}")
        await cl.Message(content=f"Error creating vector database: {str(e)}").send()
        return

    # Initialize the chat model
    try:
        chat_openai = ChatOpenAI()
        print("Successfully initialized ChatOpenAI")
    except Exception as e:
        print(f"Error initializing ChatOpenAI: {str(e)}")
        await cl.Message(content=f"Error initializing ChatOpenAI: {str(e)}").send()
        return

    # Create the retrieval-augmented QA pipeline
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=vector_db,
        llm=chat_openai
    )

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)


@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    if not chain:
        await cl.Message(content="Error: Chat session not initialized. Please try uploading the file again.").send()
        return

    msg = cl.Message(content="")

    try:
        result = await chain.arun_pipeline(message.content)
        print(f"Retrieved {len(result['context'])} relevant contexts")

        # Stream the answer token by token into the Chainlit message
        async for stream_resp in result["response"]:
            await msg.stream_token(stream_resp)

        await msg.send()
    except Exception as e:
        print(f"Error in chat pipeline: {str(e)}")
        await cl.Message(content=f"Error processing your question: {str(e)}").send()