File size: 6,334 Bytes
dfd19f5
c1d3919
 
 
 
 
 
dfd19f5
 
c1d3919
 
dfd19f5
 
c1d3919
 
 
 
dfd19f5
c1d3919
dfd19f5
 
 
 
 
 
 
 
 
c1d3919
dfd19f5
 
 
 
c1d3919
dfd19f5
 
 
 
 
 
 
 
 
c1d3919
dfd19f5
0527a8f
dfd19f5
c1d3919
 
dfd19f5
 
 
 
 
 
 
 
 
 
 
 
c1d3919
dfd19f5
 
 
 
c1d3919
 
 
 
 
 
97d8b63
dfd19f5
 
ce2d7d4
c1d3919
dfd19f5
 
 
 
97d8b63
 
c1d3919
 
 
 
dfd19f5
c1d3919
 
 
 
 
 
 
 
dfd19f5
 
c1d3919
dfd19f5
b9a4880
dfd19f5
 
8310e6d
 
 
c1d3919
8310e6d
dfd19f5
c1d3919
 
dfd19f5
 
 
 
 
 
 
c1d3919
dfd19f5
 
 
c1d3919
dfd19f5
 
c1d3919
d367dae
c1d3919
 
 
 
 
 
d367dae
c1d3919
 
dfd19f5
c1d3919
 
 
 
 
 
d367dae
dfd19f5
c1d3919
dfd19f5
c1d3919
 
 
dfd19f5
 
d367dae
c1d3919
dfd19f5
c1d3919
 
dfd19f5
 
 
c1d3919
dfd19f5
 
 
 
d367dae
 
 
dfd19f5
c1d3919
 
 
d367dae
c1d3919
dfd19f5
c1d3919
 
 
d367dae
dfd19f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
agent.py – Gemini-smolagents baseline using google-genai SDK
-----------------------------------------------------------
Environment
-----------
GOOGLE_API_KEY   – API key from Google AI Studio
GAIA_API_URL     – (optional) override for the GAIA scoring endpoint
"""

from __future__ import annotations

import base64
import mimetypes
import os
import re
from typing import List

import google.genai as genai
import requests
from google.genai import types as gtypes
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    PythonInterpreterTool,
    tool,
)

# --------------------------------------------------------------------------- #
# constants & helpers
# --------------------------------------------------------------------------- #
# Base URL of the GAIA scoring endpoint. The GAIA_API_URL environment
# variable, when set, takes precedence over the built-in default.
DEFAULT_API_URL = os.environ.get(
    "GAIA_API_URL",
    "https://agents-course-unit4-scoring.hf.space",
)

# Matches a <file:ID> placeholder; group 1 captures ID, i.e. every character
# up to (but excluding) the closing ">".
FILE_TAG = re.compile(r"<file:([^>]+)>")

def _download_file(file_id: str) -> bytes:
    """Fetch the raw bytes of a GAIA task attachment.

    Sends a GET request to ``{DEFAULT_API_URL}/files/{file_id}`` with a
    30-second timeout.

    Args:
        file_id: identifier of the attachment to download.

    Returns:
        The response body as bytes.

    Raises:
        requests.HTTPError: via ``raise_for_status`` when the server answers
            with an error status code.
    """
    response = requests.get(f"{DEFAULT_API_URL}/files/{file_id}", timeout=30)
    response.raise_for_status()
    return response.content

# --------------------------------------------------------------------------- #
# model wrapper
# --------------------------------------------------------------------------- #
class GeminiModel:
    """
    Minimal adapter around google-genai.Client so the instance itself is
    callable (required by smolagents).

    The API key is read from the GOOGLE_API_KEY environment variable when the
    instance is created.
    """

    def __init__(
        self,
        model_name: str = "gemini-2.0-flash",
        temperature: float = 0.1,
        max_tokens: int = 128,
    ):
        """
        Create the google-genai client and remember the generation settings.

        Args:
            model_name: model identifier passed to ``generate_content``.
            temperature: sampling temperature passed in the request config.
            max_tokens: forwarded to the request config as
                ``max_output_tokens``.

        Raises:
            EnvironmentError: if GOOGLE_API_KEY is unset or empty.
        """
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise EnvironmentError("GOOGLE_API_KEY is not set.")
        self.client = genai.Client(api_key=api_key)

        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens

    # internal helpers ------------------------------------------------------- #
    @staticmethod
    def _response_text(resp) -> str:
        """
        Return the whitespace-stripped text of a response, or "" if it has none.

        ``resp.text`` can be ``None`` instead of a string (e.g. when the
        response carries no text part); calling ``.strip()`` on it directly
        would raise AttributeError, so that case is mapped to an empty string.
        """
        text = resp.text
        return text.strip() if text else ""

    def _genai_call(
        self,
        contents,
        system_instruction: str,
    ) -> str:
        """
        Send one ``generate_content`` request and return the reply text.

        Args:
            contents: payload handed to ``generate_content`` unchanged (a
                plain prompt string or a list of Content objects).
            system_instruction: system prompt placed in the request config.

        Returns:
            The stripped reply text, or "" when the response has no text.
        """
        resp = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=gtypes.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=self.temperature,
                max_output_tokens=self.max_tokens,
            ),
        )
        return self._response_text(resp)

    # public helpers --------------------------------------------------------- #
    def __call__(self, prompt: str, system_instruction: str, **__) -> str:
        """
        Used by CodeAgent for plain-text questions.

        Extra keyword arguments are accepted and ignored.
        """
        return self._genai_call(prompt, system_instruction)

    def call_parts(
        self,
        parts: List[gtypes.Part],
        system_instruction: str,
    ) -> str:
        """
        Multimodal path used by GeminiAgent for <file:…> questions.

        Wraps ``parts`` in a single user-role Content and sends it.
        """
        user_content = gtypes.Content(role="user", parts=parts)
        return self._genai_call([user_content], system_instruction)

# --------------------------------------------------------------------------- #
# custom tool: fetch GAIA attachments
# --------------------------------------------------------------------------- #
@tool
def gaia_file_reader(file_id: str) -> str:
    """
    Download a GAIA attachment and return its contents.

    Args:
        file_id: identifier that appears inside a <file:...> placeholder.

    Returns:
        base64-encoded string for binary files (images, PDFs, …) or decoded
        UTF-8 text for textual files.
    """
    # NOTE(review): the @tool decorator presumably builds the tool description
    # from the docstring above (summary + "Args:" section), so keep that
    # layout intact when editing it.
    try:
        payload = _download_file(file_id)

        # The MIME type is guessed from the identifier itself; anything that
        # cannot be guessed is handled as opaque binary data.
        guessed, _encoding = mimetypes.guess_type(file_id)
        kind = guessed if guessed else "application/octet-stream"

        is_textual = kind.startswith("text") or kind == "application/json"
        if not is_textual:
            return base64.b64encode(payload).decode()
        # Undecodable bytes are dropped rather than raising.
        return payload.decode(errors="ignore")
    except Exception as exc:  # pragma: no cover
        # Failures are reported as a string result instead of being raised.
        return f"ERROR downloading {file_id}: {exc}"

# --------------------------------------------------------------------------- #
# final agent
# --------------------------------------------------------------------------- #
class GeminiAgent:
    """
    Instantiated once in app.py; called once per question.

    Questions that contain ``<file:ID>`` placeholders take the multimodal
    path: each attachment is downloaded and sent to the model together with
    the remaining question text. All other questions are sent as plain text.
    """

    # Instruction passed as the system prompt on every request.
    SYSTEM_PROMPT = (
        "You are a concise, highly accurate assistant. "
        "Unless explicitly required, reply with ONE short sentence. "
        "Use the provided tools if needed. "
        "All answers are graded by exact string match."
    )

    def __init__(self):
        """Create the Gemini model wrapper and the CodeAgent with its tools."""
        self.model = GeminiModel()
        self.agent = CodeAgent(
            model=self.model,
            tools=[
                PythonInterpreterTool(),
                DuckDuckGoSearchTool(),
                gaia_file_reader,
            ],
            verbosity_level=0,
        )
        print("✅ GeminiAgent initialised.")

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """Drop trailing spaces, periods, newlines and tabs from a reply."""
        return answer.rstrip(" .\n\r\t")

    # --------------------------------------------------------------------- #
    # main entry point
    # --------------------------------------------------------------------- #
    def __call__(self, question: str) -> str:
        """
        Answer one question.

        Args:
            question: question text, optionally containing ``<file:ID>``
                placeholders that reference attachments.

        Returns:
            The model's reply with trailing spaces, periods and line
            breaks/tabs removed.
        """
        file_ids = FILE_TAG.findall(question)

        # ---------- multimodal branch (images / files) -------------------- #
        if file_ids:
            parts: List[gtypes.Part] = []

            # Question text with the placeholders removed, if anything is left.
            text_part = FILE_TAG.sub("", question).strip()
            if text_part:
                # ``text`` is passed by keyword: current google-genai releases
                # declare it keyword-only, so a positional call raises
                # TypeError.
                parts.append(gtypes.Part.from_text(text=text_part))

            for fid in file_ids:
                try:
                    file_bytes = _download_file(fid)
                    # Fall back to PNG when the type cannot be guessed from
                    # the identifier.
                    mime = mimetypes.guess_type(fid)[0] or "image/png"
                    parts.append(
                        gtypes.Part.from_bytes(data=file_bytes, mime_type=mime)
                    )
                except Exception as exc:
                    # Best effort: tell the model the attachment is missing
                    # instead of failing the whole question.
                    parts.append(
                        gtypes.Part.from_text(text=f"[FILE {fid} ERROR: {exc}]")
                    )

            answer = self.model.call_parts(parts, system_instruction=self.SYSTEM_PROMPT)

        # ---------- plain-text branch ------------------------------------- #
        else:
            # Prepend system prompt to make sure CodeAgent->model sees it.
            prompt = f"{self.SYSTEM_PROMPT}\n\n{question}"
            answer = self.agent.model(prompt, system_instruction=self.SYSTEM_PROMPT)

        return self._clean_answer(answer)