psyche commited on
Commit
d9f674b
·
verified ·
1 Parent(s): f1ac54e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -84,6 +84,7 @@ def generate(
84
 
85
 
86
  retriever_inputs = retriever_tokenizer([message], max_length=1024, truncation=True, return_tensors="pt")
 
87
  with torch.no_grad():
88
  embeddings = retriever(**retriever_inputs).pooler_output
89
  embeddings = embeddings.cpu().numpy()
 
84
 
85
 
86
  retriever_inputs = retriever_tokenizer([message], max_length=1024, truncation=True, return_tensors="pt")
87
+ retriever_inputs = {k:v.to(retriever.device) for k,v in retriever_inputs.items()}
88
  with torch.no_grad():
89
  embeddings = retriever(**retriever_inputs).pooler_output
90
  embeddings = embeddings.cpu().numpy()