File size: 3,269 Bytes
73f064f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57c77c4
73f064f
 
9695e26
 
 
 
3965865
9695e26
73f064f
57c77c4
 
73f064f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
""" vlm.py

Utilities for working with Vision Language Models

:author: Didier Guillevic
:email: [email protected]
:creation: 2024-12-28
"""

import logging
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig runs when this module is imported and installs a
# handler on the *root* logger (unless it already has handlers), so importing
# this utility module changes logging for the whole application.
logging.basicConfig(level=logging.INFO)

import base64
import mimetypes
import os

from mistralai import Mistral

# ---------------------------------------------------------------------------
# Mistral AI client (created once, at import time)
# ---------------------------------------------------------------------------
# The key comes from the environment; if MISTRAL_API_KEY is unset, this line
# raises KeyError and the module fails to import.
api_key = os.environ["MISTRAL_API_KEY"]
model_id = "mistral-small-latest"  # author's note: 128k context window
client = Mistral(api_key=api_key)

#
# Encode images as base64
#
def encode_image(image_path):
    """Read a file and return its bytes encoded as a base64 string.

    Args:
        image_path: path of the file (typically an image) to encode.

    Returns:
        The base64 encoding of the file's contents as a ``str``, or ``None``
        if the file could not be read. Failures are logged, not raised.
    """
    try:
        with open(image_path, "rb") as image_file:
            raw = image_file.read()
    except FileNotFoundError:
        logger.error("Error: The file %s was not found.", image_path)
        return None
    except Exception:
        # Deliberately best-effort: callers receive None instead of an
        # exception, but the traceback now goes to the log rather than
        # being reduced to a bare print on stdout.
        logger.exception("Error: could not read image %s", image_path)
        return None
    return base64.b64encode(raw).decode("utf-8")


#
# Build messages
#
def _image_part(image_path):
    """Build one ``image_url`` content part for *image_path*, or None if it cannot be read."""
    encoded = encode_image(image_path)
    if encoded is None:
        # encode_image has already logged the reason; dropping the image avoids
        # sending the literal text "None" as the base64 payload.
        logger.warning("Skipping image that could not be encoded: %s", image_path)
        return None
    # Label the data URL with the file's actual type (PNG, WebP, ...) and fall
    # back to the previously hard-coded JPEG when the type is not a known image.
    mime_type, _ = mimetypes.guess_type(str(image_path))
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    return {"type": "image_url", "image_url": f"data:{mime_type};base64,{encoded}"}


def _image_parts(image_paths):
    """Return ``image_url`` content parts for every readable path in *image_paths*."""
    parts = (_image_part(path) for path in image_paths)
    return [part for part in parts if part is not None]


def build_messages(message: dict, history: list[tuple]):
    """Build messages given message & history from a **multimodal** chat interface.

    Args:
        message: dictionary with optional keys 'text' (the user's text) and
            'files' (list of image paths) for the current turn.
        history: list of (user_turn, bot_turn) pairs. ``user_turn`` is either
            a str (text turn) or a tuple/list of image paths (upload turn);
            presumably Gradio's "tuples" history format — TODO confirm.
            Consecutive user turns are merged into a single user message,
            which is emitted once a non-empty ``bot_turn`` follows.

    Returns:
        list of messages (to be sent to the model). The final entry is the
        user message for ``message``, prefixed with any user input from the
        history that never received a bot response (instead of dropping it).
        Images that cannot be read are skipped.
    """
    logger.info("message=%r", message)
    logger.info("history=%r", history)
    # Get the user's text and list of images
    user_text = message.get("text", "")
    user_images = message.get("files", [])

    # Replay the history, grouping user parts until the bot answers.
    messages = []
    pending = []  # user content parts not yet followed by a bot reply
    for user_turn, bot_turn in history:
        if isinstance(user_turn, (tuple, list)):  # upload turn: image paths
            pending.extend(_image_parts(user_turn))
        elif isinstance(user_turn, str):  # text turn
            pending.append({"type": "text", "text": user_turn})
        if pending and bot_turn:
            messages.append({"role": "user", "content": pending})
            messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": bot_turn}]}
            )
            pending = []

    # Current turn: leftover unanswered history input first (chronological
    # order), then the new text, then the new images.
    user_content = list(pending)
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    user_content.extend(_image_parts(user_images))

    messages.append({"role": "user", "content": user_content})
    # The full list embeds base64 image data, which can be megabytes; keep the
    # INFO line short and put the complete dump at DEBUG.
    logger.info("built %d messages", len(messages))
    logger.debug("messages=%r", messages)

    return messages

#
# get response
#
def get_response(messages: list[dict]):
    """Send *messages* to the Mistral chat API and return the reply text.

    This is a single, non-streaming ``client.chat.complete`` call against the
    module-level ``model_id``; the whole reply is returned at once.

    Args:
        messages: list of messages to send to the model (for instance the
            list returned by ``build_messages``).

    Returns:
        The ``content`` of the first choice's message in the API response.
    """
    response = client.chat.complete(model=model_id, messages=messages)
    logger.info("response=%r", response)
    return response.choices[0].message.content