Update agent.py
Browse files
agent.py
CHANGED
@@ -49,7 +49,7 @@ def _download_file(file_id: str) -> bytes:
|
|
49 |
# --------------------------------------------------------------------------- #
|
50 |
class GeminiModel:
|
51 |
"""
|
52 |
-
Thin adapter around google-genai
|
53 |
"""
|
54 |
|
55 |
def __init__(
|
@@ -61,15 +61,15 @@ class GeminiModel:
|
|
61 |
api_key = os.getenv("GOOGLE_API_KEY")
|
62 |
if not api_key:
|
63 |
raise EnvironmentError("GOOGLE_API_KEY is not set.")
|
64 |
-
# One client per process is enough
|
65 |
self.client = genai.Client(api_key=api_key)
|
66 |
self.model_name = model_name
|
67 |
self.temperature = temperature
|
68 |
self.max_tokens = max_tokens
|
69 |
|
70 |
-
# ----------
|
71 |
def call(self, prompt: str, **kwargs) -> str:
|
72 |
-
|
|
|
73 |
model=self.model_name,
|
74 |
contents=prompt,
|
75 |
generation_config=gtypes.GenerateContentConfig(
|
@@ -77,23 +77,16 @@ class GeminiModel:
|
|
77 |
max_output_tokens=self.max_tokens,
|
78 |
),
|
79 |
)
|
80 |
-
return
|
81 |
|
82 |
-
# ---------- smolagents will use this when messages are present ---------- #
|
83 |
def call_messages(self, messages, **kwargs) -> str:
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
# Multimodal path – pass system text first, then structured user parts
|
92 |
-
contents = [sys_msg["content"], *user_msg["content"]]
|
93 |
-
else:
|
94 |
-
# Text prompt path
|
95 |
-
contents = f"{sys_msg['content']}\n\n{user_msg['content']}"
|
96 |
-
response = self.client.models.generate_content(
|
97 |
model=self.model_name,
|
98 |
contents=contents,
|
99 |
generation_config=gtypes.GenerateContentConfig(
|
@@ -101,7 +94,11 @@ class GeminiModel:
|
|
101 |
max_output_tokens=self.max_tokens,
|
102 |
),
|
103 |
)
|
104 |
-
return
|
|
|
|
|
|
|
|
|
105 |
|
106 |
|
107 |
# --------------------------------------------------------------------------- #
|
|
|
49 |
# --------------------------------------------------------------------------- #
|
50 |
class GeminiModel:
|
51 |
"""
|
52 |
+
Thin adapter around google-genai Client for smolagents.
|
53 |
"""
|
54 |
|
55 |
def __init__(
|
|
|
61 |
api_key = os.getenv("GOOGLE_API_KEY")
|
62 |
if not api_key:
|
63 |
raise EnvironmentError("GOOGLE_API_KEY is not set.")
|
|
|
64 |
self.client = genai.Client(api_key=api_key)
|
65 |
self.model_name = model_name
|
66 |
self.temperature = temperature
|
67 |
self.max_tokens = max_tokens
|
68 |
|
69 |
+
# ---------- main generation helpers ---------- #
|
70 |
def call(self, prompt: str, **kwargs) -> str:
|
71 |
+
"""Text-only helper used by __call__."""
|
72 |
+
resp = self.client.models.generate_content(
|
73 |
model=self.model_name,
|
74 |
contents=prompt,
|
75 |
generation_config=gtypes.GenerateContentConfig(
|
|
|
77 |
max_output_tokens=self.max_tokens,
|
78 |
),
|
79 |
)
|
80 |
+
return resp.text.strip()
|
81 |
|
|
|
82 |
def call_messages(self, messages, **kwargs) -> str:
|
83 |
+
sys_msg, user_msg = messages
|
84 |
+
contents = (
|
85 |
+
[sys_msg["content"], *user_msg["content"]]
|
86 |
+
if isinstance(user_msg["content"], list)
|
87 |
+
else f"{sys_msg['content']}\n\n{user_msg['content']}"
|
88 |
+
)
|
89 |
+
resp = self.client.models.generate_content(
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
model=self.model_name,
|
91 |
contents=contents,
|
92 |
generation_config=gtypes.GenerateContentConfig(
|
|
|
94 |
max_output_tokens=self.max_tokens,
|
95 |
),
|
96 |
)
|
97 |
+
return resp.text.strip()
|
98 |
+
|
99 |
+
# ---------- make the instance itself callable ---------- #
|
100 |
+
def __call__(self, prompt: str, **kwargs) -> str: # <-- NEW
|
101 |
+
return self.call(prompt, **kwargs)
|
102 |
|
103 |
|
104 |
# --------------------------------------------------------------------------- #
|