import glob
import os
import pickle
import re
from pathlib import Path

import gradio as gr
import spaces
import numpy as np
from pypdf import PdfReader
from transformers import AutoModel

chunk_size = int(os.environ.get("CHUNK_SIZE", 250))
default_k = int(os.environ.get("DEFAULT_K", 5))

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-es", trust_remote_code=True)

replace_pairs = [
    (r"¢\s+100.00", "$50"),
    (r"¢\s+300.00", "$100"),
    (r"¢\s+500.00", "$150"),
    # Normalize the text by collapsing runs of more than one consecutive space,
    # while preserving single spaces within words
    (r"(?<!\w|[.,;]) +", " "),
    # Remove extra line breaks, runs of underscores and unwanted headers or footers
    (r"(?<!\w|[ .:])\n|_+|INDICE LEGISLATIVO|ASAMBLEA LEGISLATIVA \- REPUBLICA DE EL SALVADOR", ""),
]

docs = {}


def extract_text_from_pdf(reader: PdfReader) -> str:
    """Extract text from PDF pages

    Parameters
    ----------
    reader : PdfReader
        PDF reader

    Returns
    -------
    str
        Raw text
    """
    content = [page.extract_text().strip() for page in reader.pages]
    return "\n\n".join(content).strip()


def convert(filename: str) -> str:
    """Convert file content to raw text

    Parameters
    ----------
    filename : str
        The filename or path

    Returns
    -------
    str
        The raw text

    Raises
    ------
    ValueError
        If the file type is not supported.
    """
    plain_text_filetypes = [".txt", ".csv", ".tsv", ".md", ".yaml", ".toml", ".json", ".json5", ".jsonc"]
    # Plain-text formats need no conversion, so return the content as-is
    if any(filename.endswith(ft) for ft in plain_text_filetypes):
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    if filename.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))

    raise ValueError(f"Unsupported file type: {filename}")


def add_prefix(chunk: str, art_prefix: str) -> tuple[str, str]:
    """Add a prefix to chunks that are a continuation of a certain article

    Parameters
    ----------
    chunk : str
        Original chunk
    art_prefix : str
        Current prefix

    Returns
    -------
    tuple[str, str]
        The updated chunk and the new prefix
    """
    results = re.findall(r"(Articulo \d+)\s+-", chunk)
    ignore_results = False
    if (len(results) == 1 and chunk.find(results[0]) > 4 and art_prefix) or not results:
        results.insert(0, art_prefix)
    elif len(results) == 1 and chunk.find(results[0]) <= 4:
        ignore_results = True
    art_prefix = results[-1]
    # If the current chunk is a continuation of a certain article, an identifier prefix is added to it
    return (f"<<{'|'.join(results)}>>{chunk}" if results and not ignore_results else chunk), art_prefix


def generate_chunks(text: str, max_length: int) -> list[str]:
    """Generate chunks from a file's raw text. Chunks are calculated based
    on the `max_length` parameter and the split character (.)

    Parameters
    ----------
    text : str
        The raw text
    max_length : int
        Maximum number of characters a chunk can have. Note that chunks
        may not have this exact length, as another component is also
        involved in the splitting process

    Returns
    -------
    list[str]
        A list of chunks/nodes
    """
    for match_result in re.finditer(r"Art\. (\d+)\.", text):
        # Replace "Art. X." with "Articulo X"
        text = text.replace(match_result.group(), f"Articulo {match_result.group(1)} ")
    for regex, new in replace_pairs:
        text = re.sub(regex, new, text)

    chunks = []
    chunk = ""
    art_prefix = ""
    for current_segment in text.split("\n"):
        remaining = ""
        if len(chunk) + len(current_segment) + 1 <= max_length:
            chunk += f" {current_segment}"
        else:
            remaining = current_segment
            # Split on periods (.) but ignore numbers such as 1.0, 2.000, etc.
            for idx, little_segment in enumerate(re.split(r"(?<!\d)\.", remaining)):
                if len(chunk) + len(little_segment) + 2 <= max_length:
                    remaining = remaining.removeprefix(f"{little_segment}.")
                    chunk += f"{'.' if idx > 0 else ''} {little_segment}"
                else:
                    break
        if remaining:
            chunk, art_prefix = add_prefix(chunk, art_prefix)
            chunks.append(chunk.lower())
            chunk = remaining
    if chunk:
        chunk, _ = add_prefix(chunk, art_prefix)
        chunks.append(chunk.lower())
    return chunks
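
# Chunking sketch (sample snippet of legal text; exact chunk boundaries depend on CHUNK_SIZE):
# raw = "Art. 1. El Salvador reconoce a la persona humana como el origen y el fin de la actividad del Estado. ..."
# generate_chunks(raw, chunk_size)
# -> lowercased chunks of roughly `chunk_size` characters, where a chunk that continues
#    an article is prefixed with a marker such as "<<articulo 1>>".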


def predict(query: str, k: int = 5) -> str:
    """Find the k most relevant chunks based on the given query

    Parameters
    ----------
    query : str
        The input query
    k : int, optional
        Number of relevant chunks to return, by default 5

    Returns
    -------
    str
        The k chunks concatenated together as a single string.

    Example
    -------
    If k=2, the returned string might look like:
    "CONTEXT:\n\nchunk-1\n\nchunk-2"
    """
    # Embed the query
    query_embedding = model.encode(query)

    # Initialize a list to store all chunks and their similarities across all documents
    all_chunks = []

    # Iterate through all documents
    for filename, doc in docs.items():
        # Calculate the cosine similarity between the query and each of the document's chunk embeddings
        similarities = np.dot(doc["embeddings"], query_embedding) / (
            np.linalg.norm(doc["embeddings"], axis=1) * np.linalg.norm(query_embedding)
        )
        # Add chunks and similarities to the all_chunks list
        all_chunks.extend([(filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities)])

    # Sort all chunks by similarity, best first
    all_chunks.sort(key=lambda x: x[2], reverse=True)
    return "CONTEXT:\n\n" + "\n\n".join(f"{filename}: {chunk}" for filename, chunk, _ in all_chunks[:k])


def init():
    """Init function

    It will load or calculate the embeddings
    """
    global docs  # pylint: disable=W0603
    embeddings_file = Path("embeddings-es.pickle")
    if embeddings_file.exists():
        with open(embeddings_file, "rb") as embeddings_pickle:
            docs = pickle.load(embeddings_pickle)
    else:
        for filename in glob.glob("sources/*"):
            converted_doc = convert(filename)
            chunks = generate_chunks(converted_doc, chunk_size)
            embeddings = model.encode(chunks)
            # Get the filename and slugify it
            docs[filename.rsplit("/", 1)[-1].lower().replace(" ", "-")] = {
                "chunks": chunks,
                "embeddings": embeddings,
            }
        with open(embeddings_file, "wb") as pickle_file:
            pickle.dump(docs, pickle_file)
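
# Rough shape of `docs` after init() (keys come from the slugified file names under sources/;
# the embedding width is whatever jina-embeddings-v2-base-es produces, typically 768):
# {
#     "constitucion-de-la-republica.pdf": {
#         "chunks": ["<<articulo 1>>...", "articulo 2 - ...", ...],
#         "embeddings": np.ndarray of shape (num_chunks, 768),
#     },
#     ...
# }
# Deleting embeddings-es.pickle forces the embeddings to be recomputed on the next start.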


init()

gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Question asked about the documents"),
        gr.Number(label="Number of relevant sources to return (k)", value=default_k, precision=0),
    ],
    outputs=[gr.Text(label="Relevant chunks")],
    title="ContextqaSV",
    description="RAG tool enabling questions and answers on legal documents from El Salvador. Legal"
    " documents supported:\n- Constitución de la república\n- Reglamento de tránsito y seguridad vial",
).launch()