zaldivards's picture
feat: enhance predictions quality
f66e4f3
raw
history blame
5.19 kB
import glob
import os
import pickle
import re
from pathlib import Path
import gradio as gr
import spaces
import numpy as np
from pypdf import PdfReader
from transformers import AutoModel
# Runtime knobs, overridable through environment variables.
chunk_size = int(os.getenv("CHUNK_SIZE", "1000"))  # max_length passed to generate_chunks
default_k = int(os.getenv("DEFAULT_K", "5"))  # initial value of the "k" input box

# Embedding model used both for document chunks (init) and queries (predict).
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-es",
    trust_remote_code=True,
)

# Index filled by init(): slugified filename -> {"chunks": [...], "embeddings": ...}
docs: dict[str, dict] = {}
def extract_text_from_pdf(reader: PdfReader) -> str:
    """Concatenate the text of every page of a PDF.

    Parameters
    ----------
    reader : PdfReader
        An already-opened pypdf reader.

    Returns
    -------
    str
        Each page's stripped text, pages separated by a blank line, with
        surrounding whitespace removed from the combined result.
    """
    page_texts = []
    for page in reader.pages:
        page_texts.append(page.extract_text().strip())
    joined = "\n\n".join(page_texts)
    return joined.strip()
def convert(filename: str) -> str:
    """Convert file content to raw text

    Extensions are matched case-insensitively, so ``Ley.PDF`` is treated
    like ``ley.pdf``.

    Parameters
    ----------
    filename : str
        The filename or path

    Returns
    -------
    str
        The raw text

    Raises
    ------
    ValueError
        If the file type is not supported.
    """
    plain_text_filetypes = (
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    )
    # Lowercase once so uppercase/mixed-case extensions are accepted too.
    lowered = filename.lower()
    # Plain-text formats need no conversion: return the content as-is.
    if lowered.endswith(plain_text_filetypes):
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()
    if lowered.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))
    raise ValueError(f"Unsupported file type: {filename}")
def generate_chunks(text: str, max_length: int) -> list[str]:
    """Generate chunks from a file's raw text. Chunks are calculated based
    on the `max_length` parameter and the split character (.)

    Whitespace inside each sentence is collapsed to single spaces, empty
    sentences are dropped, and sentences within a chunk are joined with
    ". " (no leading separator).

    Parameters
    ----------
    text : str
        The raw text
    max_length : int
        Maximum number of characters a chunk can have. Note that chunks
        may not have this exact length: a sentence is appended while the
        current chunk is still shorter than `max_length`, so a chunk can
        exceed it by up to one sentence.

    Returns
    -------
    list[str]
        A list of chunks/nodes. Empty when `text` contains no non-whitespace
        sentence.
    """
    chunks: list[str] = []
    chunk = ""
    for raw_segment in text.split("."):
        # Normalize whitespace (newlines, tabs, repeated spaces).
        segment = re.sub(r"\s+", " ", raw_segment).strip()
        if not segment:
            # Produced by a trailing "." or runs like "..": nothing to add.
            continue
        if not chunk:
            chunk = segment
        elif len(chunk) < max_length:
            chunk += f". {segment}"
        else:
            chunks.append(chunk)
            chunk = segment
    if chunk:
        chunks.append(chunk)
    return chunks
@spaces.GPU
def predict(query: str, k: int = 5) -> str:
    """Find k most relevant chunks based on the given query

    Relevance is the cosine similarity between the query embedding and each
    chunk embedding, ranked across all indexed documents.

    Parameters
    ----------
    query : str
        The input query
    k : int, optional
        Number of relevant chunks to return, by default 5. Float values
        (as delivered by a Gradio Number box) are truncated to int, negative
        values are treated as 0, and None falls back to `default_k`.

    Returns
    -------
    str
        The k chunks concatenated together as a single string.

    Example
    -------
    If k=2, the returned string might look like:
    "CONTEXT:\n\nchunk-1\n\nchunk-2"
    """
    # Slicing below needs a non-negative int; Gradio may pass a float or None.
    k = default_k if k is None else max(int(k), 0)

    query_embedding = model.encode(query)
    query_norm = np.linalg.norm(query_embedding)

    # (filename, chunk, similarity) for every chunk of every document
    all_chunks = []
    for filename, doc in docs.items():
        embeddings = np.asarray(doc["embeddings"])
        if embeddings.ndim != 2 or embeddings.shape[0] == 0:
            # Nothing to compare against (e.g. a document without text).
            continue
        # Cosine similarity needs each chunk's own norm. Dividing by the norm
        # of the whole matrix only rescales a document's scores, which
        # penalizes documents with many chunks when ranking across documents.
        denominators = np.linalg.norm(embeddings, axis=1) * query_norm
        denominators = np.where(denominators == 0, 1.0, denominators)
        similarities = np.dot(embeddings, query_embedding) / denominators
        all_chunks.extend(
            (filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities)
        )

    all_chunks.sort(key=lambda item: item[2], reverse=True)
    return "CONTEXT:\n\n" + "\n\n".join(
        f"{filename}: {chunk}" for filename, chunk, _ in all_chunks[:k]
    )
def init():
    """Init function

    Populates the module-level ``docs`` index. If ``embeddings-es.pickle``
    exists it is loaded as-is; otherwise every file under ``sources/`` is
    converted, chunked and embedded, and the result is written to that pickle
    for later runs.
    """
    global docs  # pylint: disable=W0603
    embeddings_file = Path("embeddings-es.pickle")
    if embeddings_file.exists():
        # NOTE(review): pickle.load can execute arbitrary code, so this is only
        # acceptable because the file is written by this app itself. The cache
        # is never invalidated when sources/ or CHUNK_SIZE change; delete the
        # file to force a rebuild.
        with open(embeddings_file, "rb") as embeddings_pickle:
            docs = pickle.load(embeddings_pickle)
        return

    # Sorted so the index (and tie-breaking in predict) is reproducible.
    for filename in sorted(glob.glob("sources/*")):
        chunks = generate_chunks(convert(filename), chunk_size)
        if not chunks:
            # No text to embed (e.g. a scanned PDF without a text layer).
            continue
        embeddings = model.encode(chunks)
        # Use the basename (independent of the path separator) and slugify it.
        slug = Path(filename).name.lower().replace(" ", "-")
        docs[slug] = {
            "chunks": chunks,
            "embeddings": embeddings,
        }
    with open(embeddings_file, "wb") as pickle_file:
        pickle.dump(docs, pickle_file)
init()
gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        # precision=0 makes Gradio deliver an int (a float by default), which
        # predict needs to slice the ranked chunks.
        gr.Number(
            label="Number of relevant sources returned (k)",
            value=default_k,
            precision=0,
        ),
    ],
    outputs=[gr.Text(label="Relevant chunks")],
    title="ContextQA tool - El Salvador",
    description="Forked and customized RAG tool working with law documents from El Salvador",
).launch()