File size: 3,977 Bytes
683e45d
 
d44e0f3
683e45d
d44e0f3
0d7ada2
683e45d
 
 
d44e0f3
 
 
 
0d7ada2
683e45d
d44e0f3
0d7ada2
683e45d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d7ada2
683e45d
 
 
 
 
0d7ada2
683e45d
 
 
 
 
0d7ada2
d44e0f3
683e45d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# application/chat_inference.py
import time  # Import the time module
from application.utils.chat_completion_api import ChatCompletionAPI
from config import Response, pipeline_dict, convs_dict
import os
from application.utils.image_captioning import ImageCaptioning
from application.utils.web_search import WebScarper
from application.utils.image_generation import generate_image  # Import


class ChatInference:
    """Turn client chat payloads into chat-completion requests and send them.

    Model configuration is read from ``pipeline_dict['api']['models']`` and
    per-user conversation state is read from and written to ``convs_dict``
    (indexed as ``convs_dict[user][convId]``).
    """

    def __init__(self):
        self.chatCompletionAPI = ChatCompletionAPI()
        self.image_captioning = ImageCaptioning()
        # Not referenced by the methods of this class.
        self.web_scraper = WebScarper()

    @staticmethod
    def _resolve_headers(headers):
        """Return the HTTP headers to send for a model's ``headers`` config.

        Top-level entries are copied verbatim.  Entries under the ``config``
        key are Python expressions evaluated at call time; for
        ``Authorization`` the first word (the scheme, e.g. ``Bearer``) is kept
        literally and only the remainder is evaluated.  A ``comment`` entry
        inside ``config`` is skipped.  Insertion order follows ``headers``.
        """
        resolved = {}
        for name, value in headers.items():
            if name != "config":
                resolved[name] = value
                continue
            for config_name, expression in value.items():
                if config_name == "comment":
                    continue
                # SECURITY NOTE(review): eval() runs arbitrary code. This is only
                # acceptable while pipeline_dict comes from a trusted config
                # file; client-supplied data must never reach these strings.
                if config_name == "Authorization":
                    # Split on the first space only, so token expressions that
                    # contain spaces (e.g. "os.environ.get('A', 'b')") survive.
                    scheme, token_expression = expression.split(' ', 1)
                    resolved[config_name] = f"{scheme} {eval(token_expression)}"
                else:
                    resolved[config_name] = f"{eval(expression)}"
        return resolved

    def validate(self, data, user):
        """Normalise a chat payload and record the user's turn.

        Args:
            data: client payload (dict). Required keys: ``model``, ``prompt``,
                ``convId``. Optional: ``max_token`` / ``max_tokens``,
                ``temperature`` (clamped to [0, 2]), ``top_p`` (clamped to
                [0.1, 1]), ``system_prompt``, ``image``.
            user: key of the requesting user in ``convs_dict``.

        Side effects:
            Sets ``self.headers`` and ``self.updateHeaders``.  On the first
            user turn (conversation holds only one message) it rewrites that
            message's content with the system prompt plus the user's memory,
            and records the conversation title/metadata.  It then appends the
            user message.  All fallible lookups and the image captioning run
            before any conversation state is modified.

        Returns:
            ``data`` updated in place with ``model``, ``prompt``,
            ``messages``, ``max_tokens``, ``temperature``, ``top_p`` and
            ``stream=True``; ``400`` if a required key is missing; ``500`` on
            any other error.
        """
        try:
            pipeline = pipeline_dict['api']['models']
            model = data['model']
            model_config = pipeline[model]
            # NOTE(review): headers live on the instance and chat() reads them
            # back after validate(); if one ChatInference serves concurrent
            # requests they can be overwritten in between — confirm usage.
            self.headers = model_config['headers']
            self.updateHeaders = self._resolve_headers(self.headers)
            model_type = model_config['type']

            user_prompt = data['prompt']
            # 'max_token' keeps priority for existing callers; the standard
            # 'max_tokens' is honoured instead of being overwritten below.
            max_tokens = data.get('max_token', data.get('max_tokens', 10020))
            temperature = max(0, min(data.get('temperature', 0.7), 2))
            top_p = max(0.1, min(data.get('top_p', 0.9), 1))
            system = data.get('system_prompt', 'You are a helpful and harmless AI assistant. You are xylaria made by sk md saad amin. You should think step-by-step')
            convId = data['convId']
            image = data.get('image')

            user_convs = convs_dict[user]
            conversation = user_convs[convId]
            messages = conversation['messages']

            # Caption before mutating state so a failure cannot leave a
            # half-recorded turn (e.g. duplicated metadata on retry).
            prompt = user_prompt
            if image:
                caption = self.image_captioning.generate_caption(image)
                prompt = f"{caption}\n\n{user_prompt}"

            if len(messages) == 1:
                if system:
                    # Include user memory in the system prompt; messages[0] is
                    # assumed to be the existing system message.
                    memory = user_convs['memory']
                    messages[0]['content'] = f"{system}\n\nMemory: {memory}"
                # Titles come from the user's own text, never the caption.
                user_convs['metadata'].insert(0, {"convId": convId, "title": user_prompt[:23]})
                conversation['title'] = user_prompt[:30]

            if model_type == 'image-text-to-text':
                messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
            else:
                messages.append({"role": "user", "content": prompt})

            data.update({
                "model": model,
                "prompt": prompt,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "stream": True,
            })
            return data
        except KeyError as e:
            print(f"KeyError: {e}")  # Debugging
            return 400
        except Exception as e:
            print(f"An unexpected error occurred: {e}")  # Debugging
            return 500

    def chat(self, data, handle_stream, user):
        """Validate ``data`` and stream a chat completion for ``user``.

        ``base_url`` and ``webSearch`` must be present in ``data``.  They are
        checked before validate() so that a rejected request leaves the
        conversation untouched.

        Returns:
            Whatever ``ChatCompletionAPI.make_request`` returns, or a
            ``(message, status_code)`` tuple when the request is rejected.
        """
        start_time = time.time()  # Capture start time
        if not isinstance(data, dict) or any(key not in data for key in ('base_url', 'webSearch')):
            return "Required Parameters are Missing!", 400

        data = self.validate(data=data, user=user)
        if isinstance(data, int):  # validate() returned an error status code
            if data == 400:
                return "Required Parameters are Missing!", data
            return "An unexpected error occurred!", data

        return self.chatCompletionAPI.make_request(
            json=data,
            url=data['base_url'],
            handle_stream=handle_stream,
            messages=data['messages'],
            headers=self.updateHeaders,
            webSearch=data['webSearch'],
            start_time=start_time,
        )