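"""Utilities for the gossip_semantic_search pipeline: RSS/XML parsing, HTML
sanitisation, LLM-based question generation and answering (Groq), and text
embedding (Cohere)."""
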
from datetime import datetime
import warnings
from typing import List
from pickle import Unpickler
import re

from bs4 import BeautifulSoup
from groq import Groq
from cohere import Client
from numpy.typing import NDArray
from numpy import array

from gossip_semantic_search.models import Article, Answer
from gossip_semantic_search.constant import (AUTHOR_KEY, TITLE_KEY, LINK_KEY, DESCRIPTION_KEY,
                                             PUBLICATION_DATE_KEY, CONTENT_KEY, LLAMA_70B_MODEL,
                                             DATE_FORMAT, EMBEDING_MODEL)
from gossip_semantic_search.prompts import (generate_question_prompt,
                                            generate_context_retriver_prompt)


def xml_to_dict(element):
    """Recursively convert an XML element into nested dicts.

    Repeated child tags are gathered into a list; for a leaf element,
    its stripped text replaces the dict entirely.
    """
    result = {}

    for child in element:
        child_dict = xml_to_dict(child)
        if child.tag in result:
            # Tag seen before: collect the repeated children in a list.
            if isinstance(result[child.tag], list):
                result[child.tag].append(child_dict)
            else:
                result[child.tag] = [result[child.tag], child_dict]
        else:
            result[child.tag] = child_dict

    # A leaf element carries text rather than children.
    if element.text and element.text.strip():
        result = element.text.strip()

    return result


def sanitize_html_content(html_content):
    """Strip the markup from an article body and return plain text."""
    soup = BeautifulSoup(html_content, 'html.parser')

    # Unwrap links and emphasis tags, keeping their inner text.
    for tag in soup.find_all(['a', 'em', 'strong']):
        tag.unwrap()

    # Drop quoted blocks entirely.
    for blockquote in soup.find_all('blockquote'):
        blockquote.extract()

    # Collapse the whitespace left behind by the removed tags.
    cleaned_text = re.sub(r'\s+', ' ', soup.get_text()).strip()
    return cleaned_text

def article_raw_to_article(raw_article) -> Article:
    """Build an Article from a raw RSS item dict."""
    return Article(
        author=raw_article[AUTHOR_KEY],
        title=raw_article[TITLE_KEY],
        link=raw_article[LINK_KEY],
        description=raw_article[DESCRIPTION_KEY],
        published_date=datetime.strptime(
            raw_article[PUBLICATION_DATE_KEY],
            DATE_FORMAT
        ),
        content=sanitize_html_content(raw_article[CONTENT_KEY])
    )

def generates_questions(context: str,
                        nb_questions: int,
                        client: Groq) -> List[str]:
    """Ask the LLM to generate nb_questions questions about the given context."""
    completion = client.chat.completions.create(
        model=LLAMA_70B_MODEL,
        messages=[
            {
                "role": "user",
                "content": generate_question_prompt(context, nb_questions)
            },
        ],
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )
    questions_str = "".join(chunk.choices[0].delta.content or "" for chunk in completion)

    # Each generated question ends with "?"; strip any leading "N. " numbering.
    questions = re.findall(r'[^?]*\?', questions_str)
    questions = [re.sub(r'^\d+[.)]?\s*', '', question.strip()) for question in questions]

    if not questions:
        warnings.warn(f"No question found.\n"
                      f"String returned: {questions_str}")
        return []

    if len(questions) != nb_questions:
        warnings.warn(f"Expected {nb_questions} questions, but found "
                      f"{len(questions)}: {', '.join(questions)}", UserWarning)

    return questions

def choose_context_and_answer_questions(articles: List[Article],
                                        query: str,
                                        generative_client: Groq) -> Answer:
    """Try each candidate article in turn until the LLM can answer the query from it."""
    for article in articles:
        completion = generative_client.chat.completions.create(
            model=LLAMA_70B_MODEL,
            messages=[
                {
                    "role": "user",
                    "content": generate_context_retriver_prompt(query, article.content)
                },
            ],
            temperature=1,
            max_tokens=1024,
            top_p=1,
            stream=True,
            stop=None,
        )
        answer = "".join(chunk.choices[0].delta.content or "" for chunk in completion)

        # The model's reply is parsed as "answer_in_text = <bool>, answer = <text>".
        match = re.search(r"answer_in_text\s*=\s*(.*?),", answer)
        if match and match.group(1) == "True":
            match = re.search(r"answer\s*=\s*(.*)", answer)
            if match:
                # Strip the quotes wrapping the extracted answer.
                return Answer(
                    answer=match.group(1)[1:-2],
                    link=article.link,
                    content=article.content
                )

    # No article yielded an answer; fall back to the first candidate.
    return Answer(
        answer="unable to generate an answer",
        link=articles[0].link,
        content=articles[0].content
    )


def embed_content(contexts: List[str],
                  client: Client) -> NDArray:
    """Embed the given texts with the Cohere client and return a numpy array."""
    return array(client.embed(
        model=EMBEDING_MODEL,
        texts=contexts,
        input_type='classification',
        truncate='NONE'
    ).embeddings)

class CustomUnpickler(Unpickler):
    """Unpickler that remaps classes pickled under the legacy 'models' module."""

    def find_class(self, module, name):
        if module == 'models':
            return Article  # Return the replacement class.
        return super().find_class(module, name)
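

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library. It assumes GROQ_API_KEY
    # and COHERE_API_KEY are set in the environment; the Article below is
    # built from placeholder values for illustration only.
    import os

    groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
    cohere_client = Client(os.environ["COHERE_API_KEY"])

    sample_article = Article(
        author="A. Author",
        title="Sample title",
        link="https://example.com/article",
        description="Sample description",
        published_date=datetime.now(),
        content="Sample article body."
    )

    # Generate questions about the article, embed its content, then try to
    # answer the first generated question from the article itself.
    questions = generates_questions(sample_article.content, 2, groq_client)
    print(embed_content([sample_article.content], cohere_client).shape)
    if questions:
        result = choose_context_and_answer_questions([sample_article],
                                                     questions[0],
                                                     groq_client)
        print(result.answer)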