import json
import logging

import requests
import urllib3

# Suppress TLS warnings triggered by the verify=False requests below.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Set up logging
logging.basicConfig(level=logging.INFO)

def request_generation(
    header: dict,
    messages: list,
    cloud_gateway_api: str,
    model_name: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.3,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
):
"""
Request streaming generation from the cloud gateway API. Uses the simple requests module with stream=True to utilize
token-by-token generation from LLM.
Args:
header: authorization header for the API.
message: prompt from the user.
system_prompt: system prompt to append.
cloud_gateway_api (str): API endpoint to send the request.
max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
temperature: the value used to module the next token probabilities.
top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
or higher are kept for generation.
repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
Returns:
"""
    payload = {
        "model": model_name,
        "messages": messages,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": True,  # Enable streaming
        "serving_runtime": "vllm",
    }
    try:
        # Step 1: create the conversation and obtain its ID.
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
        )
        response.raise_for_status()

        # Append the conversation ID with the key X-Conversation-ID to the header.
        header["X-Conversation-ID"] = response.json()["conversationId"]

        # Step 2: stream the generated tokens for that conversation.
        with requests.get(
            cloud_gateway_api + "conversation/stream",
            headers=header,
            verify=False,
            stream=True,
        ) as response:
            for chunk in response.iter_lines():
                if chunk:
                    # Decode the chunk from bytes to a string.
                    chunk_str = chunk.decode("utf-8")
                    # Strip the SSE `data: ` prefix (it can appear doubled on some chunks).
                    for _ in range(2):
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: "):]
                    # Stop on the end-of-stream sentinel.
                    if chunk_str.strip() == "[DONE]":
                        break
                    # Parse the chunk as JSON and extract the content delta.
                    try:
                        chunk_json = json.loads(chunk_str)
                        if "choices" in chunk_json and chunk_json["choices"]:
                            content = chunk_json["choices"][0]["delta"].get("content", "")
                        else:
                            content = ""
                        # Yield the generated content as it is streamed.
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        # Skip chunks that cannot be decoded as JSON.
                        continue
    except requests.RequestException as e:
        logging.error(f"Failed to generate response: {e}")
        yield "Server not responding. Please try again later."