|
import json |
|
import logging |
|
import requests |
|
import urllib3 |
|
|
|
# Suppress the InsecureRequestWarning that urllib3 emits on every request made
# with verify=False (see request_generation below). NOTE(review): this hides
# the warning only — TLS certificate verification is still disabled; acceptable
# only for a trusted internal gateway.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Configure the root logger once at import time: INFO and above to stderr.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
def request_generation(
    header: dict,
    messages: dict,
    cloud_gateway_api: str,
    model_name: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.3,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
):
    """
    Request streaming generation from the cloud gateway API.

    First POSTs the conversation payload to ``chat/conversation``, then opens a
    streaming GET on ``conversation/stream`` and yields content deltas as they
    arrive (server-sent-event style ``data: ``-prefixed JSON lines).

    Args:
        header: authorization headers for the API. The conversation id returned
            by the POST is added to it as ``X-Conversation-ID``.
        messages: chat messages payload forwarded verbatim to the gateway
            (presumably a list of ``{"role": ..., "content": ...}`` dicts —
            TODO confirm against callers; the annotation says ``dict``).
        cloud_gateway_api: base API endpoint, expected to end with a slash
            (endpoint paths are appended directly).
        model_name: name of the model the gateway should route to.
        max_new_tokens: maximum number of tokens to generate, ignoring the
            number of tokens in the prompt.
        temperature: the value used to modulate the next-token probabilities.
        frequency_penalty: frequency penalty; 0.0 means no penalty.
        presence_penalty: presence penalty; 0.0 means no penalty.

    Yields:
        str: non-empty content deltas from the model, token by token. On any
        request failure, yields a single fallback error message instead.
    """
    payload = {
        "model": model_name,
        "messages": messages,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": True,
        "serving_runtime": "vllm",
    }

    try:
        # NOTE(review): verify=False disables TLS certificate checking (the
        # warning is suppressed at module level). Only acceptable against a
        # trusted internal gateway — consider supplying a CA bundle instead.
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
        )
        response.raise_for_status()

        # NOTE(review): mutates the caller's header dict in place; kept as-is
        # in case callers read X-Conversation-ID back — confirm before copying.
        # Because this is a generator, the mutation only happens on first next().
        header["X-Conversation-ID"] = response.json()["conversationId"]

        with requests.get(
            cloud_gateway_api + "conversation/stream",  # was a pointless f-string
            headers=header,
            verify=False,
            stream=True,
        ) as response:
            for chunk in response.iter_lines():
                if not chunk:
                    # Skip SSE keep-alive blank lines.
                    continue

                chunk_str = chunk.decode("utf-8")

                # Strip SSE "data: " prefixes. The gateway occasionally
                # double-wraps lines; the old code stripped exactly two.
                while chunk_str.startswith("data: "):
                    chunk_str = chunk_str[len("data: "):]

                # End-of-stream sentinel.
                if chunk_str.strip() == "[DONE]":
                    break

                try:
                    chunk_json = json.loads(chunk_str)
                except json.JSONDecodeError:
                    # Ignore partial / non-JSON lines and keep streaming.
                    continue

                if "choices" in chunk_json and chunk_json["choices"]:
                    content = chunk_json["choices"][0]["delta"].get("content", "")
                    if content:
                        yield content
    except requests.RequestException as e:
        logging.error(f"Failed to generate response: {e}")
        yield "Server not responding. Please try again later."
|
|