from datetime import datetime
import re
import warnings
from pickle import Unpickler
from typing import List

from bs4 import BeautifulSoup
from cohere import Client
from groq import Groq
from numpy import array
from numpy.typing import NDArray

from gossip_semantic_search.models import Article, Answer
from gossip_semantic_search.constant import (AUTHOR_KEY, TITLE_KEY, LINK_KEY, DESCRIPTION_KEY,
                                             PUBLICATION_DATE_KEY, CONTENT_KEY, LLAMA_70B_MODEL,
                                             DATE_FORMAT, EMBEDING_MODEL)
from gossip_semantic_search.prompts import (generate_question_prompt,
                                            generate_context_retriver_prompt)

def xml_to_dict(element):
    """Recursively convert an ElementTree element into nested dicts.

    Repeated child tags are collected into a list; an element whose own
    text is non-empty is returned as its stripped text, replacing any
    children already gathered.
    """
    result = {}
    for child in element:
        child_dict = xml_to_dict(child)
        if child.tag in result:
            # Tag already seen: promote the entry to a list of siblings.
            if isinstance(result[child.tag], list):
                result[child.tag].append(child_dict)
            else:
                result[child.tag] = [result[child.tag], child_dict]
        else:
            result[child.tag] = child_dict
    if element.text and element.text.strip():
        result = element.text.strip()
    return result

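# Example (sketch, not from the original module): feeding a parsed RSS feed
# through xml_to_dict. The 'feed.xml' path is hypothetical.
#
#   import xml.etree.ElementTree as ET
#   root = ET.parse("feed.xml").getroot()
#   feed = xml_to_dict(root)
#   # feed["channel"]["item"] is a list when the feed holds several articles.
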
def sanitize_html_content(html_content):
    """Strip markup from an article body and normalize whitespace."""
    soup = BeautifulSoup(html_content, 'html.parser')
    # Keep the text of links and emphasis tags, drop the tags themselves.
    for tag in soup.find_all(['a', 'em', 'strong']):
        tag.unwrap()
    # Remove quoted embeds (tweets, Instagram posts, ...) entirely.
    for blockquote in soup.find_all('blockquote'):
        blockquote.extract()
    # Collapse the whitespace runs left behind by removed tags.
    cleaned_text = re.sub(r'\s+', ' ', soup.get_text()).strip()
    return cleaned_text

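# Example (sketch): link and emphasis text survives, embeds do not.
#
#   html = ('<p>Hello <strong>world</strong> <a href="/x">link</a></p>'
#           '<blockquote>embedded tweet</blockquote>')
#   sanitize_html_content(html)  # -> 'Hello world link'
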
def article_raw_to_article(raw_article) -> Article:
    """Map one raw RSS item dict onto the Article model."""
    return Article(
        author=raw_article[AUTHOR_KEY],
        title=raw_article[TITLE_KEY],
        link=raw_article[LINK_KEY],
        description=raw_article[DESCRIPTION_KEY],
        published_date=datetime.strptime(
            raw_article[PUBLICATION_DATE_KEY],
            DATE_FORMAT
        ),
        content=sanitize_html_content(raw_article[CONTENT_KEY])
    )

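# Example (sketch): a minimal raw item such as xml_to_dict produces. The
# actual keys behind the *_KEY constants and the DATE_FORMAT pattern live in
# gossip_semantic_search.constant; the values below are illustrative only.
#
#   raw = {AUTHOR_KEY: "Jane Doe",
#          TITLE_KEY: "Some headline",
#          LINK_KEY: "https://example.com/article",
#          DESCRIPTION_KEY: "Short teaser",
#          PUBLICATION_DATE_KEY: "Mon, 01 Jan 2024 10:00:00 +0000",
#          CONTENT_KEY: "<p>Body</p>"}
#   article = article_raw_to_article(raw)
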
def generates_questions(context: str,
                        nb_questions: int,
                        client: Groq) -> List[str]:
    """Ask the LLM for nb_questions questions about the given context."""
    completion = client.chat.completions.create(
        model=LLAMA_70B_MODEL,
        messages=[
            {
                "role": "user",
                "content": generate_question_prompt(context, nb_questions)
            },
        ],
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )
    questions_str = "".join(chunk.choices[0].delta.content or "" for chunk in completion)

    # The model answers with a numbered list; keep each segment up to its
    # '?' and strip the leading enumeration ("1. ", "2) ", ...).
    questions = re.findall(r'([^?]*\?)', questions_str)
    questions = [re.sub(r'^\s*\d+[.)]\s*', '', question.strip())
                 for question in questions]

    if not questions:
        warnings.warn(f"No question found.\n"
                      f"String returned: {questions_str}")
        return []

    if len(questions) != nb_questions:
        warnings.warn(f"Expected {nb_questions} questions, but found "
                      f"{len(questions)}: {', '.join(questions)}", UserWarning)
    return questions

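# Example (sketch): generating evaluation questions for one article. Assumes
# the GROQ_API_KEY environment variable is set; Groq() reads it by default.
#
#   groq_client = Groq()
#   questions = generates_questions(article.content, nb_questions=3,
#                                   client=groq_client)
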
def choose_context_and_answer_questions(articles: List[Article],
                                        query: str,
                                        generative_client) -> Answer:
    """Try each candidate article until the LLM finds the answer in one."""
    for article in articles:
        completion = generative_client.chat.completions.create(
            model=LLAMA_70B_MODEL,
            messages=[
                {
                    "role": "user",
                    "content": generate_context_retriver_prompt(query, article.content)
                },
            ],
            temperature=1,
            max_tokens=1024,
            top_p=1,
            stream=True,
            stop=None,
        )
        answer = "".join(chunk.choices[0].delta.content or "" for chunk in completion)

        # The prompt asks the model to reply in the form:
        #   answer_in_text = True, answer = "..."
        match = re.search(r"answer_in_text\s*=\s*(.*?),", answer)
        if match and match.group(1) == "True":
            match = re.search(r"answer\s*=\s*(.*)", answer)
            if match:
                # Drop the quote characters wrapping the answer.
                answer_value = match.group(1)[1:-2]
                return Answer(
                    answer=answer_value,
                    link=article.link,
                    content=article.content
                )

    # No article contained the answer: fall back to the top-ranked one.
    return Answer(
        answer="unable to generate an answer",
        link=articles[0].link,
        content=articles[0].content
    )

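# Example (sketch): answering a user query over retrieval candidates, where
# 'candidates' is assumed to be a relevance-ranked List[Article] coming from
# the embedding search step.
#
#   result = choose_context_and_answer_questions(candidates,
#                                                "Who attended the gala?",
#                                                groq_client)
#   print(result.answer, result.link)
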
def embed_content(contexts: List[str],
                  client: Client) -> NDArray:
    """Embed a batch of texts with Cohere, returned as a 2D numpy array."""
    return array(client.embed(
        model=EMBEDING_MODEL,
        texts=contexts,
        input_type='classification',
        truncate='NONE'
    ).embeddings)

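# Example (sketch): embedding all article bodies into one matrix. Passing
# api_key explicitly is shown here; recent cohere SDKs can also read the
# CO_API_KEY environment variable.
#
#   cohere_client = Client(api_key="...")
#   embeddings = embed_content([a.content for a in articles], cohere_client)
#   embeddings.shape  # (len(articles), embedding_dim)
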
class CustomUnpickler(Unpickler):
    """Unpickler that remaps the legacy top-level 'models' module."""
    def find_class(self, module, name):
        if module == 'models':
            # Return a replacement class for objects pickled before the
            # package rename.
            return Article
        return super().find_class(module, name)
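
# Example (sketch): loading articles pickled back when Article still lived
# in a top-level 'models' module. The 'articles.pkl' path is hypothetical.
#
#   with open("articles.pkl", "rb") as f:
#       articles = CustomUnpickler(f).load()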