real-jiakai commited on
Commit
97d8b63
·
verified ·
1 Parent(s): d367dae

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +17 -20
agent.py CHANGED
@@ -49,7 +49,7 @@ def _download_file(file_id: str) -> bytes:
49
  # --------------------------------------------------------------------------- #
50
  class GeminiModel:
51
  """
52
- Thin adapter around google-genai.Client so it can be used by smolagents.
53
  """
54
 
55
  def __init__(
@@ -61,15 +61,15 @@ class GeminiModel:
61
  api_key = os.getenv("GOOGLE_API_KEY")
62
  if not api_key:
63
  raise EnvironmentError("GOOGLE_API_KEY is not set.")
64
- # One client per process is enough
65
  self.client = genai.Client(api_key=api_key)
66
  self.model_name = model_name
67
  self.temperature = temperature
68
  self.max_tokens = max_tokens
69
 
70
- # ---------- Text-only convenience ---------- #
71
  def call(self, prompt: str, **kwargs) -> str:
72
- response = self.client.models.generate_content(
 
73
  model=self.model_name,
74
  contents=prompt,
75
  generation_config=gtypes.GenerateContentConfig(
@@ -77,23 +77,16 @@ class GeminiModel:
77
  max_output_tokens=self.max_tokens,
78
  ),
79
  )
80
- return response.text.strip()
81
 
82
- # ---------- smolagents will use this when messages are present ---------- #
83
  def call_messages(self, messages, **kwargs) -> str:
84
- """
85
- `messages` is a list of dictionaries with keys 'role' | 'content'.
86
- If `content` is already a list[types.Content], we forward it as-is.
87
- Otherwise we concatenate to a single string prompt.
88
- """
89
- sys_msg, user_msg = messages # CodeAgent always sends two
90
- if isinstance(user_msg["content"], list):
91
- # Multimodal path – pass system text first, then structured user parts
92
- contents = [sys_msg["content"], *user_msg["content"]]
93
- else:
94
- # Text prompt path
95
- contents = f"{sys_msg['content']}\n\n{user_msg['content']}"
96
- response = self.client.models.generate_content(
97
  model=self.model_name,
98
  contents=contents,
99
  generation_config=gtypes.GenerateContentConfig(
@@ -101,7 +94,11 @@ class GeminiModel:
101
  max_output_tokens=self.max_tokens,
102
  ),
103
  )
104
- return response.text.strip()
 
 
 
 
105
 
106
 
107
  # --------------------------------------------------------------------------- #
 
49
  # --------------------------------------------------------------------------- #
50
  class GeminiModel:
51
  """
52
+ Thin adapter around google-genai Client for smolagents.
53
  """
54
 
55
  def __init__(
 
61
  api_key = os.getenv("GOOGLE_API_KEY")
62
  if not api_key:
63
  raise EnvironmentError("GOOGLE_API_KEY is not set.")
 
64
  self.client = genai.Client(api_key=api_key)
65
  self.model_name = model_name
66
  self.temperature = temperature
67
  self.max_tokens = max_tokens
68
 
69
+ # ---------- main generation helpers ---------- #
70
  def call(self, prompt: str, **kwargs) -> str:
71
+ """Text-only helper used by __call__."""
72
+ resp = self.client.models.generate_content(
73
  model=self.model_name,
74
  contents=prompt,
75
  generation_config=gtypes.GenerateContentConfig(
 
77
  max_output_tokens=self.max_tokens,
78
  ),
79
  )
80
+ return resp.text.strip()
81
 
 
82
  def call_messages(self, messages, **kwargs) -> str:
83
+ sys_msg, user_msg = messages
84
+ contents = (
85
+ [sys_msg["content"], *user_msg["content"]]
86
+ if isinstance(user_msg["content"], list)
87
+ else f"{sys_msg['content']}\n\n{user_msg['content']}"
88
+ )
89
+ resp = self.client.models.generate_content(
 
 
 
 
 
 
90
  model=self.model_name,
91
  contents=contents,
92
  generation_config=gtypes.GenerateContentConfig(
 
94
  max_output_tokens=self.max_tokens,
95
  ),
96
  )
97
+ return resp.text.strip()
98
+
99
+ # ---------- make the instance itself callable ---------- #
100
+ def __call__(self, prompt: str, **kwargs) -> str: # <-- NEW
101
+ return self.call(prompt, **kwargs)
102
 
103
 
104
  # --------------------------------------------------------------------------- #