psyche committed
Commit ed9bdc3 · verified · 1 Parent(s): 0d3160d

Update app.py

Files changed (1)
  1. app.py +4 -3
app.py CHANGED
@@ -77,7 +77,7 @@ def generate(
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
-
+
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
@@ -90,15 +90,16 @@ def generate(
         num_beams=1,
         repetition_penalty=repetition_penalty,
     )
+    save_json("user", message)
+
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
-
+
     outputs = []
     for text in streamer:
         outputs.append(text)
         yield "".join(outputs)
 
-    save_json("user", message)
     save_json("assistant", "".join(outputs))
 
 
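The change itself is small: save_json("user", message) moves from after the streaming loop to just before the generation thread starts, plus a few blank-line tweaks, which is why the counter reads +4 -3. Because generate is a generator that yields partial outputs to Gradio, the code after the for loop only runs once the response has been fully streamed; logging the user turn up front means it gets recorded even if the stream is cut off early. Below is a minimal sketch of the resulting flow, assuming the usual Transformers streaming pattern; the model id, the body of save_json, and the prompt handling are placeholders inferred from context rather than the Space's actual code.

# Sketch only: reconstructs the post-commit control flow of generate().
# MODEL_ID, the conversation handling, and the body of save_json are assumptions;
# only the lines visible in the diff above come from the commit.
import json
from datetime import datetime, timezone
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_INPUT_TOKEN_LENGTH = 4096          # assumed limit
MODEL_ID = "your-org/your-chat-model"  # placeholder, not the Space's real model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")


def save_json(role: str, content: str) -> None:
    # Hypothetical logger; the diff only shows that it takes a role and a string.
    with open("chat_log.jsonl", "a", encoding="utf-8") as f:
        record = {"role": role, "content": content,
                  "time": datetime.now(timezone.utc).isoformat()}
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


def generate(message, chat_history, max_new_tokens=512, temperature=0.7,
             top_p=0.9, top_k=50, repetition_penalty=1.2):
    conversation = list(chat_history) + [{"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True,
                                              return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # The Space also raises gr.Warning here when it trims the prompt.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

    # Moved up by this commit: the user turn is logged before generation starts,
    # so it is saved even if the stream below is never fully consumed.
    save_json("user", message)

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # Runs only after the caller has drained the generator completely.
    save_json("assistant", "".join(outputs))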