import json
import logging
import requests
import urllib3

# Suppress InsecureRequestWarning, since all requests below use verify=False
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Set up logging
logging.basicConfig(level=logging.INFO)


def request_generation(
    header: dict,
    messages: list,
    cloud_gateway_api: str,
    model_name: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.3,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
):
    """
    Request streaming generation from the cloud gateway API. Uses the simple requests module with stream=True to utilize
    token-by-token generation from LLM.

    Args:
        header: authorization header for the API.
        message: prompt from the user.
        system_prompt: system prompt to append.
        cloud_gateway_api (str): API endpoint to send the request.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
        temperature: the value used to module the next token probabilities.
        top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
                or higher are kept for generation.
        repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.

    Returns:

    """

    payload = {
        "model": model_name,
        "messages": messages,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": True,  # Enable streaming
        "serving_runtime": "vllm",
    }

    try:
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
        )

        response.raise_for_status()

        # Add the conversation ID to the header (key X-Conversation-ID) so the
        # streaming request below is tied to the conversation just created
        header["X-Conversation-ID"] = response.json()["conversationId"]

        with requests.get(
            cloud_gateway_api + "conversation/stream",
            headers=header,
            verify=False,
            stream=True,
        ) as response:
            for chunk in response.iter_lines():
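                # A typical streamed line looks like (illustrative; OpenAI-style
                # delta format assumed from the parsing below):
                #   data: {"choices": [{"delta": {"content": "Hello"}}]}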
                if chunk:
                    # Convert the chunk from bytes to a string and then parse it as json
                    chunk_str = chunk.decode("utf-8")

                    # Strip any `data: ` SSE prefixes (chunks can arrive double-prefixed)
                    while chunk_str.startswith("data: "):
                        chunk_str = chunk_str[len("data: ") :]

                    # Stop when the end-of-stream sentinel arrives
                    if chunk_str.strip() == "[DONE]":
                        break

                    # Parse the chunk into a JSON object
                    try:
                        chunk_json = json.loads(chunk_str)

                        # Extract the "content" field from the choices
                        if "choices" in chunk_json and chunk_json["choices"]:
                            content = chunk_json["choices"][0]["delta"].get(
                                "content", ""
                            )
                        else:
                            content = ""

                        # Yield the generated content as it is streamed
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        # Skip chunks that cannot be decoded as JSON
                        continue
    except requests.RequestException as e:
        logging.error(f"Failed to generate response: {e}")
        yield "Server not responding. Please try again later."