File size: 14,100 Bytes
9e909f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2982302
9e909f5
 
 
 
 
 
 
 
 
 
 
 
 
23e6380
9e909f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bf222a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e909f5
 
 
 
 
 
 
9bf222a
9e909f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bf222a
9e909f5
 
2982302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e909f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2982302
9e909f5
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin
from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
from google.adk.tools.tool_context import ToolContext
from google.adk.tools.base_tool import BaseTool
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents import LlmAgent
from google.adk.models import LlmResponse, LlmRequest
from google.adk.models.lite_llm import LiteLlm
from google.adk.apps import App
from google.genai import types
from mcp import ClientSession, StdioServerParameters
from mcp.types import CallToolResult, TextContent
from mcp.client.stdio import stdio_client
from typing import Dict, Any, Optional, Tuple
from prompts import Root, Run, Data, Plot, Install
import base64
import os

# Define MCP server parameters
# Launches server.R with Rscript; each connection spawns this process over STDIO
server_params = StdioServerParameters(
    command="Rscript",
    args=[
        # Use --vanilla to ignore .Rprofile, which is meant for the R instance running mcp_session()
        "--vanilla",
        "server.R",
    ],
)
# STDIO transport to local R MCP server
# timeout=60 — presumably seconds to wait for the server; confirm against ADK docs
connection_params = StdioConnectionParams(server_params=server_params, timeout=60)

# Define model
# If we're using the OpenAI API, get the value of OPENAI_MODEL_NAME set by entrypoint.sh
# If we're using an OpenAI-compatible endpoint (Docker Model Runner), use a fake API key
# Shared by all agents defined below
model = LiteLlm(
    model=os.environ.get("OPENAI_MODEL_NAME", ""),
    api_key=os.environ.get("OPENAI_API_KEY", "fake-API-key"),
)


async def select_r_session(
    callback_context: CallbackContext,
) -> Optional[types.Content]:
    """
    Before-agent callback that selects the first R session.

    Opens a short-lived STDIO connection to the R MCP server and invokes
    its `select_r_session` tool with session number 1, then returns None
    so the LlmAgent continues its normal execution.
    """
    async with stdio_client(server_params) as (read_stream, write_stream), ClientSession(
        read_stream, write_stream
    ) as mcp_session:
        await mcp_session.initialize()
        await mcp_session.call_tool("select_r_session", {"session": 1})
        print("[select_r_session] R session selected!")
    # None means "do not short-circuit the agent"
    return None


async def catch_tool_errors(tool: BaseTool, args: dict, tool_context: ToolContext):
    """
    Callback function to catch errors from tool calls and turn them into a message.

    Runs the tool and, on any exception, returns an MCP-style error-response
    dict instead of letting the exception propagate, so the LLM can see and
    explain the failure to the user.

    Modified from https://github.com/google/adk-python/discussions/795#discussioncomment-13460659
    """
    try:
        return await tool.run_async(args=args, tool_context=tool_context)
    except Exception as e:
        # McpError exposes the server's text via e.error.message, but other
        # exception types have no .error attribute — dereferencing it blindly
        # raised AttributeError inside this handler. Fall back to str(e).
        message = getattr(getattr(e, "error", None), "message", None) or str(e)
        # Format the error as a tool response
        # https://github.com/google/adk-python/commit/4df926388b6e9ebcf517fbacf2f5532fd73b0f71
        response = CallToolResult(
            content=[TextContent(type="text", text=message)],
            isError=True,
        )
        return response.model_dump(exclude_none=True, mode="json")


async def preprocess_artifact(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
    """
    Callback function to copy the latest artifact to a temporary file.

    If the last user message mentions an uploaded artifact, the most recent
    artifact's bytes are written to /tmp/uploads/<filename> so the R process
    can read them. On any problem, an explanatory text part is appended to
    the request so the LLM can tell the user. Always returns None so the
    (possibly modified) request continues on to the LLM.
    """

    # Callback and artifact handling code modified from:
    # https://google.github.io/adk-docs/callbacks/types-of-callbacks/#before-model-callback
    # https://github.com/google/adk-python/issues/2176#issuecomment-3395469070

    # Get the last user message in the request contents
    last_user_message = llm_request.contents[-1].parts[-1].text

    # Function call events have no text part, so set this to "" for string search in the next step
    if last_user_message is None:
        last_user_message = ""

    # If a file was uploaded then SaveFilesAsArtifactsPlugin() adds "[Uploaded Artifact: file_name.csv]" to the user message
    # Check for "Uploaded Artifact:" in the last user message
    if "Uploaded Artifact:" in last_user_message:

        # Add a text part only if there are any issues with accessing or saving the artifact
        added_text = ""
        # List available artifacts
        artifacts = await callback_context.list_artifacts()
        if len(artifacts) == 0:
            added_text = "No uploaded file is available"
        else:
            most_recent_file = artifacts[-1]
            try:
                # Get artifact and byte data
                artifact = await callback_context.load_artifact(
                    filename=most_recent_file
                )
                byte_data = artifact.inline_data.data
                # Save artifact as temporary file
                tmp_dir = "/tmp/uploads"
                # Bug fix: make sure the upload directory exists — open() below
                # raises FileNotFoundError on the first upload otherwise
                os.makedirs(tmp_dir, exist_ok=True)
                tmp_file_path = os.path.join(tmp_dir, most_recent_file)
                # Write the file
                with open(tmp_file_path, "wb") as f:
                    f.write(byte_data)
                # Set appropriate permissions
                os.chmod(tmp_file_path, 0o644)
                print(f"[preprocess_artifact] Saved artifact as '{tmp_file_path}'")

            except Exception as e:
                added_text = f"Error processing artifact: {str(e)}"

        # If there were any issues, add a new part to the user message
        if added_text:
            # NOTE(review): this appends to the FIRST content entry; appending to
            # contents[-1] (the latest user message) may be what was intended — confirm
            llm_request.contents[0].parts.append(types.Part(text=added_text))
            print(
                f"[preprocess_artifact] Added text part to user message: '{added_text}'"
            )

    # Return None to allow the possibly modified request to go to the LLM
    return None


async def preprocess_messages(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
    """
    Callback function to modify user messages to point to temporary artifact file paths.

    Rewrites the marker inserted by SaveFilesAsArtifactsPlugin() so it names
    the temporary copy written by preprocess_artifact(). Always returns None
    so the (possibly modified) request continues on to the LLM.
    """

    # Marker inserted by SaveFilesAsArtifactsPlugin():
    #   [Uploaded Artifact: "breast-cancer.csv"]
    # Rewritten form pointing at the temp file saved by preprocess_artifact():
    #   [Uploaded File: "/tmp/uploads/breast-cancer.csv"]
    marker = '[Uploaded Artifact: "'
    replacement = '[Uploaded File: "/tmp/uploads/'

    # Changes to session state made by callbacks are not preserved across events
    # See: https://github.com/google/adk-docs/issues/904
    # Therefore, for every callback invocation we need to loop over all events, not just the most recent one
    for content in llm_request.contents:
        message_text = content.parts[-1].text
        if message_text and marker in message_text:
            rewritten = message_text.replace(marker, replacement)
            content.parts[-1].text = rewritten
            print(f"[preprocess_messages] Modified user message: '{rewritten}'")

    return None


def detect_file_type(byte_data: bytes) -> Tuple[str, str]:
    """
    Detect file type from magic number/bytes and return (mime_type, file_extension).

    Supports BMP, JPEG, PNG, TIFF, and PDF; anything unrecognized (including
    empty data) defaults to PNG.
    """
    # Ordered magic-number table. bytes.startswith is safe on data of any
    # length, so no minimum-size gate is needed — the old `len < 8` check
    # misclassified short-but-valid signatures (e.g. b"BM...", b"%PDF") as PNG.
    signatures = (
        ((b"\x89PNG\r\n\x1a\n",), "image/png", "png"),
        ((b"\xff\xd8\xff",), "image/jpeg", "jpg"),
        ((b"BM",), "image/bmp", "bmp"),
        # TIFF has two byte orders: little-endian (II) and big-endian (MM)
        ((b"II*\x00", b"MM\x00*"), "image/tiff", "tiff"),
        ((b"%PDF",), "application/pdf", "pdf"),
    )
    for prefixes, mime_type, extension in signatures:
        # startswith accepts a tuple of candidate prefixes
        if byte_data.startswith(prefixes):
            return mime_type, extension
    # Default to PNG if we can't determine
    return "image/png", "png"


async def skip_summarization_for_plot_success(
    tool: BaseTool, args: Dict[str, Any], tool_context: ToolContext, tool_response: Dict
) -> Optional[Dict]:
    """
    Callback function to turn off summarization if plot succeeded.

    Failed plot calls keep the default behavior (skip_summarization stays
    False) so the LLM can explain the error to the user; a successful plot
    needs no extra LLM round-trip just to announce that it exists.
    """
    plot_tools = ("make_plot", "make_ggplot")
    if tool.name in plot_tools and not tool_response["isError"]:
        tool_context.actions.skip_summarization = True

    return None


async def save_plot_artifact(
    tool: BaseTool, args: Dict[str, Any], tool_context: ToolContext, tool_response: Dict
) -> Optional[Dict]:
    """
    Callback function to save plot files as an ADK artifact.

    On a successful make_plot/make_ggplot call, decodes the hex-encoded plot
    bytes from the tool response, saves them as an ADK artifact, and returns
    a replacement tool response announcing the artifact's filename. Returns
    None (passthrough) for other tools or error responses.
    """

    # Look for plot tool (so we don't bother with transfer_to_agent or other functions)
    if tool.name in ["make_plot", "make_ggplot"]:
        # In ADK 1.17.0, tool_response is a dict (i.e. result of model_dump method invoked on MCP CallToolResult instance):
        # https://github.com/google/adk-python/commit/4df926388b6e9ebcf517fbacf2f5532fd73b0f71
        # https://github.com/modelcontextprotocol/python-sdk?tab=readme-ov-file#parsing-tool-results
        if "content" in tool_response and not tool_response["isError"]:
            for content in tool_response["content"]:
                if "type" in content and content["type"] == "text":
                    # Convert tool response (hex string) to bytes
                    byte_data = bytes.fromhex(content["text"])

                    # Detect file type from magic number
                    mime_type, file_extension = detect_file_type(byte_data)

                    # Encode binary data to Base64 format
                    # (pydantic decodes the base64 string into Blob.data bytes — presumably; confirm against google-genai docs)
                    encoded = base64.b64encode(byte_data).decode("utf-8")
                    artifact_part = types.Part(
                        inline_data={
                            "data": encoded,
                            "mime_type": mime_type,
                        }
                    )
                    # Use second part of tool name (e.g. make_ggplot -> ggplot.png)
                    # Bug fix: single quotes inside the f-string — reusing double
                    # quotes is a SyntaxError before Python 3.12 (PEP 701)
                    filename = f"{tool.name.split('_', 1)[1]}.{file_extension}"
                    await tool_context.save_artifact(
                        filename=filename, artifact=artifact_part
                    )
                    # Format the success message as a tool response
                    # Bug fix: interpolate the actual artifact filename instead
                    # of a literal placeholder
                    text = f"Plot created and saved as an artifact: {filename}"
                    response = CallToolResult(
                        content=[TextContent(type="text", text=text)],
                    )
                    return response.model_dump(exclude_none=True, mode="json")

    # Passthrough for other tools or no matching content (e.g. tool error)
    return None


# Create agent to run R code
# All four sub-agents share the same LiteLlm model and STDIO connection to server.R;
# they differ only in prompt (instruction) and which MCP tools they may call
run_agent = LlmAgent(
    name="Run",
    description="Runs R code without making plots. Use the `Run` agent for executing code that does not load data or make a plot.",
    model=model,
    # Prompt text imported from prompts.py
    instruction=Run,
    tools=[
        McpToolset(
            connection_params=connection_params,
            # Tool names are defined by server.R
            tool_filter=["run_visible", "run_hidden"],
        )
    ],
    # Save any uploaded artifact to a temp file, then rewrite messages to point there
    before_model_callback=[preprocess_artifact, preprocess_messages],
    # Convert tool exceptions into error responses the LLM can read
    before_tool_callback=catch_tool_errors,
)

# Create agent to load data
data_agent = LlmAgent(
    name="Data",
    description="Loads data into an R data frame and summarizes it. Use the `Data` agent for loading data from a file or URL before making a plot.",
    model=model,
    instruction=Data,
    tools=[
        McpToolset(
            connection_params=connection_params,
            tool_filter=["run_visible"],
        )
    ],
    before_model_callback=[preprocess_artifact, preprocess_messages],
    before_tool_callback=catch_tool_errors,
)

# Create agent to make plots using R code
plot_agent = LlmAgent(
    name="Plot",
    description="Makes plots using R code. Use the `Plot` agent after loading any required data.",
    model=model,
    instruction=Plot,
    tools=[
        McpToolset(
            connection_params=connection_params,
            tool_filter=["make_plot", "make_ggplot"],
        )
    ],
    before_model_callback=[preprocess_artifact, preprocess_messages],
    before_tool_callback=catch_tool_errors,
    # After each plot tool call: disable summarization on success, then save the
    # plot bytes as an ADK artifact (presumably invoked in list order — confirm in ADK docs)
    after_tool_callback=[skip_summarization_for_plot_success, save_plot_artifact],
)

# Create agent to install R packages
install_agent = LlmAgent(
    name="Install",
    description="Installs R packages. Use the `Install` agent when an R package needs to be installed.",
    model=model,
    instruction=Install,
    tools=[
        McpToolset(
            connection_params=connection_params,
            tool_filter=["run_visible"],
        )
    ],
    before_model_callback=[preprocess_artifact, preprocess_messages],
    before_tool_callback=catch_tool_errors,
)

# Create parent agent and assign children via sub_agents
root_agent = LlmAgent(
    name="Coordinator",
    # "Use the..." tells sub-agents to transfer to Coordinator for help requests
    description="Multi-agent system for performing actions in R. Use the `Coordinator` agent for getting help on packages, datasets, and functions.",
    model=model,
    instruction=Root,
    # To pass control back to root, the help and run functions should be tools or a ToolAgent (not sub_agent)
    tools=[
        McpToolset(
            connection_params=connection_params,
            # Help-lookup tools defined by server.R
            tool_filter=["help_package", "help_topic"],
        )
    ],
    # The LLM can transfer control to any of these specialized agents
    sub_agents=[
        run_agent,
        data_agent,
        plot_agent,
        install_agent,
    ],
    # Select R session 1 before the Coordinator runs
    before_agent_callback=select_r_session,
    # Save user-uploaded artifact as a temporary file and modify messages to point to this file
    before_model_callback=[preprocess_artifact, preprocess_messages],
    before_tool_callback=catch_tool_errors,
)

# Top-level ADK application wiring the root agent together with plugins
app = App(
    name="PlotMyData",
    root_agent=root_agent,
    # This inserts user messages like '[Uploaded Artifact: "breast-cancer.csv"]'
    plugins=[SaveFilesAsArtifactsPlugin()],
)