import glob
import os
import pickle
import re
from pathlib import Path

import gradio as gr
import spaces
import numpy as np
from pypdf import PdfReader
from transformers import AutoModel


chunk_size = int(os.environ.get("CHUNK_SIZE", 1000))
default_k = int(os.environ.get("DEFAULT_K", 5))
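# Both can be overridden through the environment, e.g. (hypothetical values):
#   CHUNK_SIZE=1500 DEFAULT_K=8 python app.py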

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-es", trust_remote_code=True)

docs = {}
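# Maps a slugified source filename to {"chunks": list[str], "embeddings": ndarray};
# populated by init() below (an example key such as "constitucion.pdf" is hypothetical).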


def extract_text_from_pdf(reader: PdfReader) -> str:
    """Extract text from PDF pages

    Parameters
    ----------
    reader : PdfReader
        PDF reader

    Returns
    -------
    str
        Raw text
    """
    content = [page.extract_text().strip() for page in reader.pages]
    return "\n\n".join(content).strip()


def convert(filename: str) -> str:
    """Convert file content to raw text

    Parameters
    ----------
    filename : str
        The filename or path

    Returns
    -------
    str
        The raw text

    Raises
    ------
    ValueError
        If the file type is not supported.
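
    Examples
    --------
    Hypothetical paths, for illustration only::

        convert("sources/notas.md")    # plain text, returned as-is
        convert("sources/codigo.pdf")  # text extracted with pypdf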
    """
    plain_text_filetypes = [
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    ]
    # Plain-text formats need no conversion, so return the content as-is
    if any(filename.endswith(ft) for ft in plain_text_filetypes):
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    if filename.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))

    raise ValueError(f"Unsupported file type: {filename}")


def generate_chunks(text: str, max_length: int) -> list[str]:
    """Generate chunks from a file's raw text. Chunks are calculated based
    on the `max_lenght` parameter and the split character (.)

    Parameters
    ----------
    text : str
        The raw text
    max_length : int
        Target number of characters per chunk. A chunk may slightly exceed
        this length, because the text is only split at sentence boundaries
        (the `.` character)

    Returns
    -------
    list[str]
        A list of chunks/nodes
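
    Examples
    --------
    A toy input chosen purely for illustration:

    >>> generate_chunks("Aaa. Bbb. Ccc. Ddd.", max_length=8)
    ['Aaa. Bbb', 'Ccc. Ddd']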
    """

    segments = text.split(".")
    chunks = []
    chunk = ""

    for current_segment in segments:
        # normalize whitespace in the current segment
        current_segment = re.sub(r"\s+", " ", current_segment).strip()
        if len(chunk) < max_length:
            # keep accumulating; avoid a leading ". " on a fresh chunk
            chunk = f"{chunk}. {current_segment}" if chunk else current_segment
        else:
            chunks.append(chunk)
            chunk = current_segment
    if chunk:
        chunks.append(chunk)
    return chunks


@spaces.GPU
def predict(query: str, k: int = 5) -> str:
    """Find k most relevant chunks based on the given query

    Parameters
    ----------
    query : str
        The input query
    k : int, optional
        Number of relevant chunks to return, by default 5

    Returns
    -------
    str
        The k chunks concatenated together as a single string.

    Example
    -------
    If k=2, the returned string might look like:

    "CONTEXT:\n\nchunk-1\n\nchunk-2"

    """
    # Embed the query
    query_embedding = model.encode(query)

    # Initialize a list to store all chunks and their similarities across all documents
    all_chunks = []
    # Iterate through all documents
    for filename, doc in docs.items():
        # Cosine similarity between the query and every chunk embedding;
        # norms are taken per row (axis=1), not over the whole matrix
        similarities = np.dot(doc["embeddings"], query_embedding) / (
            np.linalg.norm(doc["embeddings"], axis=1) * np.linalg.norm(query_embedding)
        )
        # Add chunks and similarities to the all_chunks list
        all_chunks.extend([(filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities)])

    # Sort all chunks by similarity
    all_chunks.sort(key=lambda x: x[2], reverse=True)

    return "CONTEXT:\n\n" + "\n\n".join(f"{filename}: {chunk}" for filename, chunk, _ in all_chunks[:k])


def init():
    """Load or compute the document embeddings

    Reads them from the pickle cache if it exists; otherwise chunks and
    embeds every file under `sources/`, then writes the cache.
    """
    global docs  # pylint: disable=W0603
    embeddings_file = Path("embeddings-es.pickle")
    if embeddings_file.exists():
        with open(embeddings_file, "rb") as embeddings_pickle:
            docs = pickle.load(embeddings_pickle)
    else:
        for filename in glob.glob("sources/*"):
            converted_doc = convert(filename)
            chunks = generate_chunks(converted_doc, chunk_size)
            embeddings = model.encode(chunks)
            # get the filename and slugify it
            docs[filename.rsplit("/", 1)[-1].lower().replace(" ", "-")] = {
                "chunks": chunks,
                "embeddings": embeddings,
            }
        with open(embeddings_file, "wb") as pickle_file:
            pickle.dump(docs, pickle_file)
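
# Note: delete `embeddings-es.pickle` to force re-embedding after the sources
# or the chunking parameters change.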


init()


gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        # precision=0 makes Gradio pass `k` as an int, so the list slice in predict works
        gr.Number(label="Number of relevant sources returned (k)", value=default_k, precision=0),
    ],
    outputs=[gr.Text(label="Relevant chunks")],
    title="ContextQA tool - El Salvador",
    description="Forked and customized RAG tool working with law documents from El Salvador",
).launch()