hadadrjt committed
Commit be88a2b · 1 Parent(s): 4da5eac

ai: Implementing server-side streaming responses.


* Say Hi! to fast responses.
* Say No! to slow responses.

Files changed (1)
  1. jarvis.py +73 -58
jarvis.py CHANGED
@@ -43,11 +43,13 @@ LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS = {}
 LINUX_SERVER_ERRORS = set(map(int, os.getenv("LINUX_SERVER_ERROR", "").split(",")))
 
 AI_TYPES = {f"AI_TYPE_{i}": os.getenv(f"AI_TYPE_{i}") for i in range(1, 8)}
-RESPONSES = {f"RESPONSE_{i}": os.getenv(f"RESPONSE_{i}") for i in range(1, 10)}
+
+RESPONSES = {f"RESPONSE_{i}": os.getenv(f"RESPONSE_{i}") for i in range(1, 11)}
 
 MODEL_MAPPING = json.loads(os.getenv("MODEL_MAPPING", "{}"))
 MODEL_CONFIG = json.loads(os.getenv("MODEL_CONFIG", "{}"))
 MODEL_CHOICES = list(MODEL_MAPPING.values())
+
 DEFAULT_CONFIG = json.loads(os.getenv("DEFAULT_CONFIG", "{}"))
 DEFAULT_MODEL_KEY = list(MODEL_MAPPING.keys())[0] if MODEL_MAPPING else None
 
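The configuration change is small: the RESPONSES table grows from nine slots to ten, and the new RESPONSE_10 value is what the streaming fetch below compares each SSE "data:" payload against as an end-of-stream sentinel. A minimal sketch of that check, assuming the env var carries the OpenAI-style "[DONE]" marker (the real value is configured outside this repo):

import os

# Hypothetical value for illustration; the actual sentinel is env-configured.
os.environ.setdefault("RESPONSE_10", "[DONE]")

RESPONSES = {f"RESPONSE_{i}": os.getenv(f"RESPONSE_{i}") for i in range(1, 11)}

line = "data: [DONE]"
if line.startswith("data: ") and line[6:].strip() == RESPONSES["RESPONSE_10"]:
    print("stream finished")  # fetch_response_stream_async returns at this point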
@@ -199,30 +201,36 @@ def extract_file_content(fp):
     except Exception as e:
         return f"{fp}: {e}"
 
-async def fetch_response_async(host, key, model, msgs, cfg, sid):
+async def fetch_response_stream_async(host, key, model, msgs, cfg, sid):
     for t in [1, 2]:
         try:
             async with httpx.AsyncClient(timeout=t) as client:
-                r = await client.post(host, json={"model": model, "messages": msgs, **cfg, "session_id": sid}, headers={"Authorization": f"Bearer {key}"})
-                if r.status_code in LINUX_SERVER_ERRORS:
-                    marked_item(key, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS)
-                    return None
-                r.raise_for_status()
-                j = r.json()
-                if isinstance(j, dict) and j.get("choices"):
-                    ch = j["choices"][0]
-                    if ch.get("message") and isinstance(ch["message"].get("content"), str):
-                        return ch["message"]["content"]
-                return None
+                async with client.stream("POST", host, json={**{"model": model, "messages": msgs, "session_id": sid, "stream": True}, **cfg}, headers={"Authorization": f"Bearer {key}"}) as response:
+                    async for line in response.aiter_lines():
+                        if not line:
+                            continue
+                        if line.startswith("data: "):
+                            data = line[6:]
+                            if data.strip() == RESPONSES["RESPONSE_10"]:
+                                return
+                            try:
+                                j = json.loads(data)
+                                if isinstance(j, dict) and j.get("choices"):
+                                    ch = j["choices"][0]
+                                    if ch.get("delta") and isinstance(ch["delta"].get("content"), str):
+                                        yield ch["delta"]["content"]
+                            except:
+                                continue
         except:
             continue
-    marked_item(key, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS)
-    return None
+    marked_item(key, LINUX_SERVER_PROVIDER_KEYS_MARKED, LINUX_SERVER_PROVIDER_KEYS_ATTEMPTS)
+    return
 
 async def chat_with_model_async(history, user_input, model_display, sess, custom_prompt):
     ensure_stop_event(sess)
     if not get_available_items(LINUX_SERVER_PROVIDER_KEYS, LINUX_SERVER_PROVIDER_KEYS_MARKED) or not get_available_items(LINUX_SERVER_HOSTS, LINUX_SERVER_HOSTS_ATTEMPTS):
-        return RESPONSES["RESPONSE_3"]
+        yield RESPONSES["RESPONSE_3"]
+        return
     if not hasattr(sess, "session_id") or not sess.session_id:
         sess.session_id = str(uuid.uuid4())
         sess.stop_event = asyncio.Event()
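fetch_response_async becomes fetch_response_stream_async, an async generator: instead of a single client.post followed by a blocking r.json(), it opens an httpx streaming request and yields each delta the moment it arrives, still marking the key as bad only after both timeout attempts fail. A self-contained sketch of the same parsing pattern, with a hypothetical endpoint and key standing in for the env-configured hosts (the delta/content fields imply an OpenAI-compatible server):

import asyncio
import json

import httpx

async def stream_chat(url, api_key, payload):
    # Yield content deltas from an OpenAI-style SSE chat endpoint.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", url,
            json={**payload, "stream": True},
            headers={"Authorization": f"Bearer {api_key}"},
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[6:]
                if data.strip() == "[DONE]":  # assumed SSE terminator
                    return
                try:
                    delta = json.loads(data)["choices"][0].get("delta", {})
                except (json.JSONDecodeError, KeyError, IndexError):
                    continue
                if isinstance(delta.get("content"), str):
                    yield delta["content"]

async def main():
    # Hypothetical URL, key, and model for illustration only.
    url = "https://api.example.com/v1/chat/completions"
    payload = {"model": "demo", "messages": [{"role": "user", "content": "Hi"}]}
    async for chunk in stream_chat(url, "sk-demo", payload):
        print(chunk, end="", flush=True)

if __name__ == "__main__":
    asyncio.run(main())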
@@ -235,25 +243,27 @@ async def chat_with_model_async(history, user_input, model_display, sess, custom
     msgs.insert(0, {"role": "system", "content": prompt})
     msgs.append({"role": "user", "content": user_input})
     if sess.active_candidate:
-        res = await fetch_response_async(sess.active_candidate[0], sess.active_candidate[1], model_key, msgs, cfg, sess.session_id)
-        if res:
-            return res
-        sess.active_candidate = None
+        async for chunk in fetch_response_stream_async(sess.active_candidate[0], sess.active_candidate[1], model_key, msgs, cfg, sess.session_id):
+            yield chunk
+        return
     keys = get_available_items(LINUX_SERVER_PROVIDER_KEYS, LINUX_SERVER_PROVIDER_KEYS_MARKED)
     hosts = get_available_items(LINUX_SERVER_HOSTS, LINUX_SERVER_HOSTS_ATTEMPTS)
     random.shuffle(keys)
     random.shuffle(hosts)
     for k in keys:
         for h in hosts:
-            task = asyncio.create_task(fetch_response_async(h, k, model_key, msgs, cfg, sess.session_id))
-            done, _ = await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED)
-            if task in done:
-                result = task.result()
-                if result:
+            stream_gen = fetch_response_stream_async(h, k, model_key, msgs, cfg, sess.session_id)
+            full_text = ""
+            got_any = False
+            async for chunk in stream_gen:
+                if not got_any:
+                    got_any = True
                     sess.active_candidate = (h, k)
-                    return result
-                task.cancel()
-    return RESPONSES["RESPONSE_2"]
+                full_text += chunk
+                yield chunk
+            if got_any and full_text:
+                return
+    yield RESPONSES["RESPONSE_2"]
 
 async def respond_async(multi, history, model_display, sess, custom_prompt):
     ensure_stop_event(sess)
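chat_with_model_async keeps the shuffled keys/hosts failover but now forwards chunks as they arrive, pinning the first (host, key) pair that actually produces output in sess.active_candidate so the next turn can skip the scan. A reduced sketch of that pattern under hypothetical names (stream_with_failover, try_stream, and a plain dict in place of the session object):

import asyncio
import random

async def stream_with_failover(hosts, keys, try_stream, state):
    # Try shuffled (key, host) pairs; pin the first pair that actually streams.
    # try_stream(host, key) is any async generator of text chunks, standing in
    # for fetch_response_stream_async with the remaining arguments bound.
    random.shuffle(keys)
    random.shuffle(hosts)
    for key in keys:
        for host in hosts:
            got_any = False
            async for chunk in try_stream(host, key):
                if not got_any:
                    got_any = True
                    state["active_candidate"] = (host, key)  # reuse next turn
                yield chunk
            if got_any:
                return  # a bare return ends the async generator
    yield "All providers failed."  # stands in for RESPONSES["RESPONSE_2"]

async def demo():
    async def fake_provider(host, key):
        if host == "good":
            for part in ("str", "eam", "ed"):
                yield part
    state = {}
    async for chunk in stream_with_failover(["bad", "good"], ["k1"], fake_provider, state):
        print(chunk, end="")
    print("\npinned:", state.get("active_candidate"))

asyncio.run(demo())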
@@ -270,37 +280,42 @@ async def respond_async(multi, history, model_display, sess, custom_prompt):
         inp += msg_input["text"]
     history.append([inp, RESPONSES["RESPONSE_8"]])
     yield history, gr.update(interactive=False, submit_btn=False, stop_btn=True), sess
-    task = asyncio.create_task(chat_with_model_async(history, inp, model_display, sess, custom_prompt))
+    queue = asyncio.Queue()
+    async def background():
+        full = ""
+        async for chunk in chat_with_model_async(history, inp, model_display, sess, custom_prompt):
+            full += chunk
+            await queue.put(chunk)
+        await queue.put(None)
+        return full
+    bg_task = asyncio.create_task(background())
     stop_task = asyncio.create_task(sess.stop_event.wait())
-    done, pending = await asyncio.wait({task, stop_task}, return_when=asyncio.FIRST_COMPLETED)
-    if stop_task in done:
-        task.cancel()
-        history[-1][1] = RESPONSES["RESPONSE_1"]
-        yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess
-        sess.stop_event.clear()
-        return
-    stop_task.cancel()
-    ai = task.result()
-    history[-1][1] = ""
-    buffer = []
-    last_update = asyncio.get_event_loop().time()
-    for char in ai:
-        if sess.stop_event.is_set():
-            history[-1][1] = RESPONSES["RESPONSE_1"]
-            yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess
-            sess.stop_event.clear()
-            return
-        buffer.append(char)
-        current_time = asyncio.get_event_loop().time()
-        if len(buffer) >= 4 or (current_time - last_update) > 0.001:
-            history[-1][1] += "".join(buffer)
-            buffer.clear()
-            last_update = current_time
-            yield history, gr.update(interactive=False, submit_btn=False, stop_btn=True), sess
-            await asyncio.sleep(0.003)
-    if buffer:
-        history[-1][1] += "".join(buffer)
-        yield history, gr.update(interactive=False, submit_btn=False, stop_btn=True), sess
+    first_meaningful_chunk_found = False
+    try:
+        while True:
+            done, _ = await asyncio.wait({stop_task, asyncio.create_task(queue.get())}, return_when=asyncio.FIRST_COMPLETED)
+            if stop_task in done:
+                bg_task.cancel()
+                history[-1][1] = RESPONSES["RESPONSE_1"]
+                yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess
+                sess.stop_event.clear()
+                return
+            for d in done:
+                chunk = d.result()
+                if chunk is None:
+                    raise StopAsyncIteration
+                if not first_meaningful_chunk_found:
+                    if chunk.strip():
+                        history[-1][1] = chunk
+                        first_meaningful_chunk_found = True
+                else:
+                    history[-1][1] += chunk
+            yield history, gr.update(interactive=False, submit_btn=False, stop_btn=True), sess
+    except StopAsyncIteration:
+        pass
+    finally:
+        stop_task.cancel()
+        full_response = await bg_task
     yield history, gr.update(value="", interactive=True, submit_btn=True, stop_btn=False), sess
 
 def change_model(new):
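The biggest change is in respond_async: rather than awaiting the full reply and replaying it character by character with a buffer and sleep(0.003), it starts a background task that pumps chunks into an asyncio.Queue, and the UI loop races each queue.get() against the session's stop event, with None as the completion sentinel. A reduced sketch of that producer/consumer handoff, with print standing in for the Gradio history updates:

import asyncio

async def produce(queue):
    # Producer: emit chunks, then a None sentinel to signal completion.
    for chunk in ("Hel", "lo, ", "world!"):
        await asyncio.sleep(0.1)  # stands in for network latency
        await queue.put(chunk)
    await queue.put(None)

async def consume():
    # Consumer: race the queue against a stop event, as respond_async does.
    queue = asyncio.Queue()
    stop = asyncio.Event()
    bg_task = asyncio.create_task(produce(queue))
    stop_task = asyncio.create_task(stop.wait())
    try:
        while True:
            get_task = asyncio.create_task(queue.get())
            done, _ = await asyncio.wait({stop_task, get_task}, return_when=asyncio.FIRST_COMPLETED)
            if stop_task in done:
                bg_task.cancel()   # user pressed stop: drop the producer
                get_task.cancel()
                return
            chunk = get_task.result()
            if chunk is None:      # sentinel: the stream is complete
                return
            print(chunk, end="", flush=True)
    finally:
        stop_task.cancel()
        print()

asyncio.run(consume())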
 