Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,8 @@ import spaces
|
|
11 |
import torch
|
12 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
13 |
from pathlib import Path
|
14 |
-
|
|
|
15 |
from huggingface_hub import CommitScheduler
|
16 |
|
17 |
HF_UPLOAD = os.environ.get("HF_UPLOAD")
|
@@ -29,6 +30,15 @@ scheduler = CommitScheduler(
|
|
29 |
token=HF_UPLOAD
|
30 |
)
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
def save_json(question: str, answer: str) -> None:
|
33 |
with scheduler.lock:
|
34 |
with JSON_DATASET_PATH.open("a") as f:
|
@@ -36,9 +46,9 @@ def save_json(question: str, answer: str) -> None:
|
|
36 |
f.write("\n")
|
37 |
|
38 |
|
39 |
-
MAX_MAX_NEW_TOKENS =
|
40 |
-
DEFAULT_MAX_NEW_TOKENS =
|
41 |
-
MAX_INPUT_TOKEN_LENGTH =
|
42 |
|
43 |
DESCRIPTION = """\
|
44 |
# Llama-3 8B Korean QA Chatbot \
|
@@ -71,8 +81,22 @@ def generate(
|
|
71 |
conversation.append({"role": "system", "content": system_prompt})
|
72 |
for user, assistant in chat_history:
|
73 |
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
|
74 |
-
conversation.append({"role": "user", "content": message})
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
|
77 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
78 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
|
|
11 |
import torch
|
12 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
13 |
from pathlib import Path
|
14 |
+
from pinecone.grpc import PineconeGRPC as Pinecone
|
15 |
+
import torch
|
16 |
from huggingface_hub import CommitScheduler
|
17 |
|
18 |
HF_UPLOAD = os.environ.get("HF_UPLOAD")
|
|
|
30 |
token=HF_UPLOAD
|
31 |
)
|
32 |
|
33 |
+
pc = Pinecone(api_key=os.environ.get("PINECONE"))
|
34 |
+
index = pc.Index("commonsense")
|
35 |
+
|
36 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
37 |
+
retriever_tokenizer = AutoTokenizer.from_pretrained("psyche/dpr-longformer-ko-4096")
|
38 |
+
retriever = AutoModel.from_pretrained("psyche/dpr-longformer-ko-4096")
|
39 |
+
retriever.eval()
|
40 |
+
retriever.to(device)
|
41 |
+
|
42 |
def save_json(question: str, answer: str) -> None:
|
43 |
with scheduler.lock:
|
44 |
with JSON_DATASET_PATH.open("a") as f:
|
|
|
46 |
f.write("\n")
|
47 |
|
48 |
|
49 |
+
MAX_MAX_NEW_TOKENS = 8192
|
50 |
+
DEFAULT_MAX_NEW_TOKENS = 4096
|
51 |
+
MAX_INPUT_TOKEN_LENGTH = 2048
|
52 |
|
53 |
DESCRIPTION = """\
|
54 |
# Llama-3 8B Korean QA Chatbot \
|
|
|
81 |
conversation.append({"role": "system", "content": system_prompt})
|
82 |
for user, assistant in chat_history:
|
83 |
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
|
|
|
84 |
|
85 |
+
|
86 |
+
retriever_inputs = retriever_tokenizer([message], max_length=1024, truncation=True, return_tensors="pt")
|
87 |
+
with torch.no_grad():
|
88 |
+
embeddings = model(**inputs).pooler_output
|
89 |
+
embeddings = embeddings.cpu().numpy()
|
90 |
+
|
91 |
+
results = index.query(
|
92 |
+
vector=embeddings[0],
|
93 |
+
top_k=1,
|
94 |
+
include_values=False,
|
95 |
+
include_metadata=True
|
96 |
+
)
|
97 |
+
|
98 |
+
conversation.append({"role": "user", "content": results["matches"][0]["metadata"]+f"\n\n위 문맥을 참고하여 질문 '{message}'에 답하면?"})
|
99 |
+
|
100 |
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
|
101 |
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
102 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|