File size: 5,174 Bytes
f1d5e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import pdb
from dataclasses import dataclass
from typing import Optional

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama

# Populate os.environ from a local .env file (python-dotenv) so that the
# os.getenv() lookups in get_env_value() can see provider keys/endpoints.
load_dotenv()

import sys

# Make packages under the current working directory importable; the lazy
# `from src.utils ...` imports inside the functions below depend on this.
# NOTE(review): "." is relative to the CWD, so this script presumably has to be
# launched from the repository root — confirm.
sys.path.append(".")

@dataclass
class LLMConfig:
    """Settings describing which model to call and how to reach it.

    Attributes:
        provider: Provider identifier, e.g. "openai", "ollama", "deepseek".
        model_name: Model identifier handed to the provider client.
        temperature: Sampling temperature passed to ``get_llm_model`` (the
            Ollama path in ``test_llm`` does not forward it).
        base_url: Endpoint override. ``None`` means "not set", in which case
            ``test_llm`` falls back to the provider's environment variable.
        api_key: API key override. ``None`` means "not set", in which case
            ``test_llm`` falls back to the provider's environment variable.
    """

    provider: str
    model_name: str
    temperature: float = 0.8
    # These default to None, so the annotation must admit None as a value.
    base_url: Optional[str] = None
    api_key: Optional[str] = None

def create_message_content(text, image_path=None):
    """Build a multimodal LangChain message ``content`` list.

    Args:
        text: Text part of the message; always the first entry.
        image_path: Optional path to an image file. When truthy, the file is
            base64-encoded with ``src.utils.utils.encode_image`` and appended
            as an ``image_url`` part carrying a ``data:image/<fmt>;base64,``
            URL.

    Returns:
        A list of content-part dicts: a ``{"type": "text", ...}`` entry,
        followed by an ``{"type": "image_url", ...}`` entry if an image was
        given.
    """
    content = [{"type": "text", "text": text}]
    if not image_path:
        return content

    from src.utils import utils

    # Derive the MIME subtype from the extension case-insensitively, so that
    # "shot.PNG", "anim.gif" or "photo.webp" are not mislabeled as JPEG.
    # Unknown or missing extensions keep the previous "jpeg" default.
    image_formats = {
        ".png": "png",
        ".gif": "gif",
        ".webp": "webp",
        ".jpg": "jpeg",
        ".jpeg": "jpeg",
    }
    extension = os.path.splitext(image_path)[1].lower()
    image_format = image_formats.get(extension, "jpeg")

    image_data = utils.encode_image(image_path)
    content.append({
        "type": "image_url",
        "image_url": {"url": f"data:image/{image_format};base64,{image_data}"}
    })
    return content

# Environment variable names per provider, keyed first by provider and then by
# setting name ("api_key" or "base_url"). Providers without an entry for a
# setting simply have no environment fallback for it.
_PROVIDER_ENV_VARS = {
    "openai": {"api_key": "OPENAI_API_KEY", "base_url": "OPENAI_ENDPOINT"},
    "azure_openai": {"api_key": "AZURE_OPENAI_API_KEY", "base_url": "AZURE_OPENAI_ENDPOINT"},
    "google": {"api_key": "GOOGLE_API_KEY"},
    "deepseek": {"api_key": "DEEPSEEK_API_KEY", "base_url": "DEEPSEEK_ENDPOINT"},
    "mistral": {"api_key": "MISTRAL_API_KEY", "base_url": "MISTRAL_ENDPOINT"},
    "alibaba": {"api_key": "ALIBABA_API_KEY", "base_url": "ALIBABA_ENDPOINT"},
    "moonshot": {"api_key": "MOONSHOT_API_KEY", "base_url": "MOONSHOT_ENDPOINT"},
}


def get_env_value(key, provider):
    """Look up a provider setting from the process environment.

    Args:
        key: Setting name, e.g. "api_key" or "base_url".
        provider: Provider identifier, e.g. "openai".

    Returns:
        The value of the mapped environment variable, or "" when the provider
        or setting has no mapping or the variable is not set.
    """
    provider_vars = _PROVIDER_ENV_VARS.get(provider)
    if provider_vars is None:
        return ""

    env_var_name = provider_vars.get(key)
    if env_var_name is None:
        return ""

    return os.getenv(env_var_name, "")

def test_llm(config, query, image_path=None, system_message=None, *, debug=False):
    """Send a single query to the configured model and print the reply.

    Args:
        config: ``LLMConfig`` selecting provider, model and connection
            settings. An unset ``base_url`` / ``api_key`` falls back to the
            provider's environment variable via ``get_env_value``.
        query: User prompt text.
        image_path: Optional image attached to the prompt (non-Ollama only).
        system_message: Optional system prompt (non-Ollama only).
        debug: If True, stop in ``pdb`` after a DeepSeek-R1 (Ollama) or
            ``deepseek-reasoner`` response so it can be inspected. Defaults to
            False so the script, and any pytest run that collects it, never
            blocks on an interactive breakpoint.

    Returns:
        None. The response (and reasoning content, when present) is printed.
    """
    # Ollama models are invoked directly with the raw query string.
    # NOTE: image_path, system_message and config.temperature are not used here.
    if config.provider == "ollama":
        is_deepseek_r1 = "deepseek-r1" in config.model_name
        if is_deepseek_r1:
            from src.utils.llm import DeepSeekR1ChatOllama
            llm = DeepSeekR1ChatOllama(model=config.model_name)
        else:
            llm = ChatOllama(model=config.model_name)

        ai_msg = llm.invoke(query)
        print(ai_msg.content)
        if debug and is_deepseek_r1:
            pdb.set_trace()
        return

    # Only the non-Ollama path needs the project's model factory.
    from src.utils import utils

    # Explicit config values win; otherwise fall back to environment variables.
    llm = utils.get_llm_model(
        provider=config.provider,
        model_name=config.model_name,
        temperature=config.temperature,
        base_url=config.base_url or get_env_value("base_url", config.provider),
        api_key=config.api_key or get_env_value("api_key", config.provider)
    )

    # Build the chat: optional system prompt, then the (possibly multimodal) query.
    messages = []
    if system_message:
        messages.append(SystemMessage(content=create_message_content(system_message)))
    messages.append(HumanMessage(content=create_message_content(query, image_path)))
    ai_msg = llm.invoke(messages)

    # Some responses carry a separate reasoning trace; print it before the answer.
    if hasattr(ai_msg, "reasoning_content"):
        print(ai_msg.reasoning_content)
    print(ai_msg.content)

    if config.provider == "deepseek" and "deepseek-reasoner" in config.model_name:
        print(llm.model_name)
        if debug:
            pdb.set_trace()

def test_openai_model():
    """Ask OpenAI's gpt-4o to describe the bundled example image."""
    openai_config = LLMConfig(provider="openai", model_name="gpt-4o")
    test_llm(
        openai_config,
        query="Describe this image",
        image_path="assets/examples/test.png",
    )

def test_google_model():
    """Ask Google's gemini-2.0-flash-exp to describe the bundled example image.

    Enable your API key first if you haven't:
    https://ai.google.dev/palm_docs/oauth_quickstart
    """
    gemini_config = LLMConfig(provider="google", model_name="gemini-2.0-flash-exp")
    test_llm(
        gemini_config,
        query="Describe this image",
        image_path="assets/examples/test.png",
    )

def test_azure_openai_model():
    """Ask an Azure OpenAI gpt-4o deployment to describe the example image."""
    azure_config = LLMConfig(provider="azure_openai", model_name="gpt-4o")
    test_llm(
        azure_config,
        query="Describe this image",
        image_path="assets/examples/test.png",
    )

def test_deepseek_model():
    """Send a plain text prompt to DeepSeek's deepseek-chat model."""
    deepseek_config = LLMConfig(provider="deepseek", model_name="deepseek-chat")
    test_llm(deepseek_config, query="Who are you?")

def test_deepseek_r1_model():
    """Ask DeepSeek's reasoner model a numeric-comparison question with a system prompt."""
    reasoner_config = LLMConfig(provider="deepseek", model_name="deepseek-reasoner")
    prompt = "Which is greater, 9.11 or 9.8?"
    system_prompt = "You are a helpful AI assistant."
    test_llm(reasoner_config, prompt, system_message=system_prompt)

def test_ollama_model():
    """Send a creative text prompt to a local Ollama qwen2.5:7b model."""
    qwen_config = LLMConfig(provider="ollama", model_name="qwen2.5:7b")
    test_llm(qwen_config, query="Sing a ballad of LangChain.")

def test_deepseek_r1_ollama_model():
    """Ask a local Ollama deepseek-r1:14b model a letter-counting question."""
    r1_local_config = LLMConfig(provider="ollama", model_name="deepseek-r1:14b")
    test_llm(r1_local_config, query="How many 'r's are in the word 'strawberry'?")

def test_mistral_model():
    """Ask Mistral's pixtral-large-latest to describe the bundled example image."""
    pixtral_config = LLMConfig(provider="mistral", model_name="pixtral-large-latest")
    test_llm(
        pixtral_config,
        query="Describe this image",
        image_path="assets/examples/test.png",
    )

def test_moonshot_model():
    """Ask Moonshot's vision preview model to describe the bundled example image."""
    moonshot_config = LLMConfig(
        provider="moonshot",
        model_name="moonshot-v1-32k-vision-preview",
    )
    test_llm(
        moonshot_config,
        query="Describe this image",
        image_path="assets/examples/test.png",
    )

if __name__ == "__main__":
    # Choose which provider checks to run by name on the command line, e.g.
    #   python <this script> openai ollama
    # With no arguments only the DeepSeek reasoner check runs, which was the
    # previously hard-coded default. This replaces toggling commented-out calls
    # and also makes the Moonshot check reachable.
    available_tests = {
        "openai": test_openai_model,
        "google": test_google_model,
        "azure_openai": test_azure_openai_model,
        "deepseek": test_deepseek_model,
        "deepseek_r1": test_deepseek_r1_model,
        "ollama": test_ollama_model,
        "deepseek_r1_ollama": test_deepseek_r1_ollama_model,
        "mistral": test_mistral_model,
        "moonshot": test_moonshot_model,
    }
    selected = sys.argv[1:] or ["deepseek_r1"]
    unknown = [name for name in selected if name not in available_tests]
    if unknown:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit(
            f"Unknown test(s): {', '.join(unknown)}. "
            f"Choose from: {', '.join(available_tests)}"
        )
    for name in selected:
        available_tests[name]()