import concurrent.futures
import os

from dotenv import load_dotenv
from loguru import logger
from openai import OpenAI

from rag_demo.preprocessing.base import EmbeddedChunk
from rag_demo.rag.base.query import EmbeddedQuery, Query

from .prompt_templates import AnswerGenerationTemplate
from .query_classifier import QueryClassifier
from .query_expansion import QueryExpansion
from .reranker import Reranker
from .source_annotator import SourceAnnotator

# Load API keys (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()


def flatten(nested_list: list) -> list:
    """Flatten a list of lists into a single list."""
    return [item for sublist in nested_list for item in sublist]


class RAGPipeline:
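    """Retrieval-augmented generation pipeline: expands the user query,
    retrieves matching chunks from the vector store, optionally reranks
    them, generates an answer, and annotates it with its sources.
    """
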
    def __init__(self, mock: bool = False) -> None:
        self._query_expander = QueryExpansion(mock=mock)
        self._reranker = Reranker(mock=mock)
        self._source_annotator = SourceAnnotator()
        self._query_classifier = QueryClassifier(mock=mock)

    def search(
        self,
        query: str,
        k: int = 3,
        expand_to_n_queries: int = 3,
    ) -> list[EmbeddedChunk]:
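        """Expand the query into multiple variants, search for each in
        parallel, deduplicate the results, and return the top-k chunks.
        """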
        query_model = Query.from_str(query)

        n_generated_queries = self._query_expander.generate(
            query_model, expand_to_n=expand_to_n_queries
        )
        logger.info(
            f"Successfully generated {len(n_generated_queries)} search queries.",
        )

        # Run one vector search per expanded query in parallel.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            search_tasks = [
                executor.submit(self._search, _query_model, k)
                for _query_model in n_generated_queries
            ]
            n_k_documents = [
                task.result() for task in concurrent.futures.as_completed(search_tasks)
            ]

        # Merge the per-query result lists and drop duplicate chunks.
        n_k_documents = flatten(n_k_documents)
        n_k_documents = list(set(n_k_documents))
        logger.info(f"{len(n_k_documents)} documents retrieved successfully")

        if len(n_k_documents) > 0:
            # Reranking is currently disabled; keep the first k documents instead.
            # k_documents = self.rerank(query, chunks=n_k_documents, keep_top_k=k)
            k_documents = n_k_documents[:k]
        else:
            k_documents = []

        return k_documents

    def _search(self, query: Query, k: int = 3) -> list[EmbeddedChunk]:
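        """Embed a single query with OpenAI and run a vector search over
        the stored chunks, returning up to k matches.
        """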
        assert k >= 3, "k should be >= 3"

        def _search_data(
            data_category_odm: type[EmbeddedChunk], embedded_query: EmbeddedQuery
        ) -> list[EmbeddedChunk]:
            return data_category_odm.search(
                query_vector=embedded_query.embedding,
                limit=k,
            )

        # Embed the query with the same model used to embed the stored chunks.
        api = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        embedded_query: EmbeddedQuery = EmbeddedQuery(
            embedding=api.embeddings.create(
                model="text-embedding-3-small", input=query.content
            )
            .data[0]
            .embedding,
            id=query.id,
            content=query.content,
        )

        retrieved_chunks = _search_data(EmbeddedChunk, embedded_query)
        logger.info(f"{len(retrieved_chunks)} documents retrieved successfully")

        return retrieved_chunks

    def rerank(
        self, query: str | Query, chunks: list[EmbeddedChunk], keep_top_k: int
    ) -> list[EmbeddedChunk]:
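        """Rerank the retrieved chunks against the query and keep the
        keep_top_k best matches.
        """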
        if isinstance(query, str):
            query = Query.from_str(query)

        reranked_documents = self._reranker.generate(
            query=query, chunks=chunks, keep_top_k=keep_top_k
        )
        logger.info(f"{len(reranked_documents)} documents reranked successfully.")

        return reranked_documents

    def generate_answer(self, query: str, reranked_chunks: list[EmbeddedChunk]) -> str:
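        """Build a prompt from the retrieved chunks and ask the chat model
        to answer the query grounded in that context.
        """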
context = ""
for chunk in reranked_chunks:
context += "\n Document: "
context += chunk.content
api = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
answer_generation_template = AnswerGenerationTemplate()
prompt = answer_generation_template.create_template(context, query)
logger.info(prompt)
response = api.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": prompt}],
max_tokens=8192,
)
return response.choices[0].message.content
    def add_context(self, response: str, reranked_chunks: list[EmbeddedChunk]) -> str:
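        """Annotate the generated answer with its source documents."""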
logger.info("Adding context to the answer")
return self._source_annotator.annotate(response, reranked_chunks)
    def rag(self, query: str) -> tuple[str, list[str]]:
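        """End-to-end entry point: classify the query, retrieve sources if
        needed, generate an answer, and return it with the source names.
        """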
        # Only retrieve documents when the classifier says sources are needed.
        query_type = self._query_classifier.generate(Query.from_str(query))
        logger.info(f"Query type: {query_type}")

        if query_type == "Sources_needed":
            docs = self.search(query, k=10)
        else:
            docs = []

        answer = self.generate_answer(query, docs)

        if docs:
            annotated_answer = self.add_context(answer, docs)
        else:
            annotated_answer = answer

        # Return the answer together with the deduplicated source file names.
        return (
            annotated_answer,
            list({doc.metadata["filename"].split(".pdf")[0] for doc in docs}),
        )
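

# A minimal usage sketch, not part of the pipeline itself. It assumes a
# populated vector store and an OPENAI_API_KEY in the environment or .env;
# the example question is illustrative only.
if __name__ == "__main__":
    pipeline = RAGPipeline()
    answer, sources = pipeline.rag("What does the installation guide say about GPU support?")
    print(answer)
    print(f"Sources: {sources}")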