#############################################################################
# Title: BERUFENET.AI
# Author: Andreas Fischer
# Date: January 4th, 2024
# Last update: October 15th, 2024
#############################################################################
import os
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.utils import embedding_functions
import torch # chromaDB
from transformers import AutoTokenizer, AutoModel # chromaDB
from huggingface_hub import InferenceClient # Gradio-Interface
import gradio as gr # Gradio-Interface
import json # Gradio-Interface
dbPath = "/home/af/Schreibtisch/Code/gradio/BERUFENET/db"
if not os.path.exists(dbPath): # fall back to the path used inside a HF Space
    dbPath = "/home/user/app/db"
print(dbPath)
# Chroma-DB
#-----------
jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
#jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
device='cuda:0' if torch.cuda.is_available() else 'cpu'
jina.to(device) #cuda:0
print(device)
class JinaEmbeddingFunction(EmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        embeddings = jina.encode(input) # max_length=2048
        return embeddings.tolist()
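# Optional sanity check (illustrative sketch; "Probetext" is an arbitrary sample string):
#test_vectors = JinaEmbeddingFunction()(["Probetext"])
#print(len(test_vectors[0])) # jina-embeddings-v2-base-de should yield 768-dimensional vectors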
client = chromadb.PersistentClient(path=dbPath)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())
#default_ef = embedding_functions.DefaultEmbeddingFunction()
#sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="T-Systems-onsite/cross-en-de-roberta-sentence-transformer")
#instructor_ef = embedding_functions.InstructorEmbeddingFunction(model_name="hkunlp/instructor-large", device="cuda")
jina_ef=JinaEmbeddingFunction()
embeddingFunction=jina_ef
print(str(client.list_collections()))
global collection
if "name=BerufenetDB1" in str(client.list_collections()):
    print("BerufenetDB1 found!")
    collection = client.get_collection(name="BerufenetDB1", embedding_function=embeddingFunction)
    print("Database ready!")
    print(collection.count())
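# The collection is assumed to have been populated in a separate ingestion step,
# roughly like the hypothetical sketch below (the "source" metadata key matches
# its use in response() further down):
#collection.add(
#    documents=["Berufsbeschreibung ..."],
#    metadatas=[{"source": "BERUFENET"}],
#    ids=["beruf-0001"],
#)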
# Gradio-GUI
#------------
myModel="mistralai/Mixtral-8x7B-Instruct-v0.1"
def format_prompt(message, history):
    prompt = "" #"<s>"
    #for user_prompt, bot_response in history:
    #    prompt += f"[INST] {user_prompt} [/INST]"
    #    prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
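# Example: format_prompt("Hallo", []) returns "[INST] Hallo [/INST]", the
# instruction format expected by Mixtral-8x7B-Instruct; the commented-out loop
# would prepend previous chat turns in the same format.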
def response(prompt, history, hfToken):
    # Use the HF Inference API; a user-supplied token takes precedence
    if hfToken.startswith("hf_"):
        inferenceClient = InferenceClient(model=myModel, token=hfToken)
    else:
        inferenceClient = InferenceClient(myModel)
    generate_kwargs = dict(temperature=0.9, max_new_tokens=500, top_p=0.95, repetition_penalty=1.0, do_sample=True, seed=42)
    addon = ""
    # Retrieve the five most similar occupation descriptions from the database
    results = collection.query(
        query_texts=[prompt],
        n_results=5
    )
    dists = ["<br><small>(relevance: "+str(round((1-d)*100)/100)+";" for d in results['distances'][0]]
    sources = ["source: "+s["source"]+")</small>" for s in results['metadatas'][0]]
    results = results['documents'][0]
    combination = zip(results, dists, sources)
    combination = [' '.join(triplet) for triplet in combination]
    print(str(prompt)+"\n\n"+str(combination))
    if len(results) > 1:
        # German instruction: consider the following database excerpts where relevant,
        # answer concisely, and silently ignore excerpts that do not fit
        addon = " Bitte berücksichtige bei deiner Antwort ggf. folgende Auszüge aus unserer Datenbank, sofern sie für die Antwort erforderlich sind. Beantworte die Frage knapp und präzise. Ignoriere unpassende Datenbank-Auszüge OHNE sie zu kommentieren, zu erwähnen oder aufzulisten:\n"+"\n".join(results)
    # German system prompt: a German-language assistant that recommends suitable occupations
    system = "Du bist ein deutschsprachiges KI-basiertes Assistenzsystem, das zu jedem Anliegen möglichst geeignete Berufe empfiehlt."+addon+"\n\nUser-Anliegen:"
    formatted_prompt = format_prompt(system+"\n"+prompt, history)
    output = ""
    print(str(inferenceClient))
    try:
        stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        for chunk in stream: # renamed from "response" to avoid shadowing the function itself
            output += chunk.token.text
            yield output
    except Exception as e:
        # German fallback message: asks the user to provide a valid HuggingFace token
        output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an."
        if len(combination) > 0:
            output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
        yield output
        print(str(e))
    # Append the retrieved sources as a collapsible HTML block
    output = output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+"".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
    yield output
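# Note: response() is a generator; gr.ChatInterface consumes each yielded
# "output" value to stream the partial answer into the chat window.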
gr.ChatInterface(
    response,
    chatbot=gr.Chatbot(
        value=[[None, "Herzlich willkommen! Ich bin ein KI-basiertes Assistenzsystem, das für jede Anfrage die am besten passenden Berufe empfiehlt.<br>Erzähle mir, was du gerne tust!"]],
        render_markdown=True),
    title="BERUFENET.AI (Jina-Embeddings)",
    additional_inputs=[
        gr.Textbox(
            value="",
            label="HF_token"),
    ]
).queue().launch(share=True) #False, server_name="0.0.0.0", server_port=7864
print("Interface up and running!")