import base64
import os
import re
from typing import Dict, Any, Optional, Tuple

from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin
from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
from google.adk.tools.tool_context import ToolContext
from google.adk.tools.base_tool import BaseTool
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents import LlmAgent
from google.adk.models import LlmResponse, LlmRequest
from google.adk.models.lite_llm import LiteLlm
from google.adk.apps import App
from google.genai import types
from mcp import ClientSession, StdioServerParameters
from mcp.types import CallToolResult, TextContent
from mcp.client.stdio import stdio_client

from prompts import Root, Run, Data, Plot, Install

# Directory where uploaded artifacts are copied so they can be referenced by
# file path in the (rewritten) user messages.
UPLOAD_DIR = "/tmp/uploads"

# Substring of the marker SaveFilesAsArtifactsPlugin() inserts into the user
# message, e.g. '[Uploaded Artifact: "breast-cancer.csv"]'.
ARTIFACT_MARKER = "Uploaded Artifact:"

# Matches one complete upload marker and captures the artifact file name.
_UPLOADED_ARTIFACT_RE = re.compile(r'\[Uploaded Artifact: "(.+?)"\]')

# MCP tools whose successful text result is the plot file as a hex string.
PLOT_TOOLS = ("make_plot", "make_ggplot")

# Define MCP server parameters
server_params = StdioServerParameters(
    command="Rscript",
    args=[
        # Use --vanilla to ignore .Rprofile, which is meant for the R instance running mcp_session()
        "--vanilla",
        "server.R",
    ],
)

# STDIO transport to local R MCP server
connection_params = StdioConnectionParams(server_params=server_params, timeout=60)

# Define model
# If we're using the OpenAI API, get the value of OPENAI_MODEL_NAME set by entrypoint.sh
# If we're using an OpenAI-compatible endpoint (Docker Model Runner), use a fake API key
model = LiteLlm(
    model=os.environ.get("OPENAI_MODEL_NAME", ""),
    api_key=os.environ.get("OPENAI_API_KEY", "fake-API-key"),
)


async def select_r_session(
    callback_context: CallbackContext,
) -> Optional[types.Content]:
    """
    Callback function to select the first R session.

    Opens a short-lived stdio connection to the R MCP server (``server_params``)
    and calls its ``select_r_session`` tool with ``{"session": 1}``.

    Returns:
        None, so the LlmAgent's normal execution continues.
    """
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            await session.call_tool("select_r_session", {"session": 1})
            print("[select_r_session] R session selected!")
    # Return None to allow the LlmAgent's normal execution
    return None


async def catch_tool_errors(tool: BaseTool, args: dict, tool_context: ToolContext):
    """
    Callback function to catch errors from tool calls and turn them into a message.

    Runs the tool itself and returns its result. If the tool raises, returns
    a dumped MCP ``CallToolResult`` with ``isError=True`` and the error text.

    Modified from https://github.com/google/adk-python/discussions/795#discussioncomment-13460659
    """
    try:
        return await tool.run_async(args=args, tool_context=tool_context)
    except Exception as e:
        # McpError keeps its text in e.error.message. Other exceptions have no
        # `.error`, so fall back to str(e) rather than raising AttributeError
        # from inside this handler.
        error_data = getattr(e, "error", None)
        message = getattr(error_data, "message", None) or str(e) or type(e).__name__
        # Format the error as a tool response
        # https://github.com/google/adk-python/commit/4df926388b6e9ebcf517fbacf2f5532fd73b0f71
        response = CallToolResult(
            content=[TextContent(type="text", text=message)],
            isError=True,
        )
        return response.model_dump(exclude_none=True, mode="json")


def _content_text(content: types.Content) -> str:
    """Join the text of all text-bearing parts of `content` ("" if none; function call parts have no text)."""
    return "\n".join(part.text for part in (content.parts or []) if part.text)


async def _copy_artifact_to_upload_dir(
    callback_context: CallbackContext, filename: str
) -> str:
    """
    Load artifact `filename` and write its bytes into UPLOAD_DIR.

    Returns:
        The path of the written file.

    Raises:
        ValueError: if the artifact is missing or has no inline data.
        OSError: if the directory or file cannot be written.
    """
    artifact = await callback_context.load_artifact(filename=filename)
    if (
        artifact is None
        or artifact.inline_data is None
        or artifact.inline_data.data is None
    ):
        raise ValueError(f"artifact '{filename}' not found or has no data")
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    # basename() keeps an uploaded name such as "../x" from escaping UPLOAD_DIR;
    # preprocess_messages() applies the same basename() so the paths match.
    tmp_file_path = os.path.join(UPLOAD_DIR, os.path.basename(filename))
    with open(tmp_file_path, "wb") as f:
        f.write(artifact.inline_data.data)
    # Set appropriate permissions (owner read/write, everyone else read)
    os.chmod(tmp_file_path, 0o644)
    return tmp_file_path


async def preprocess_artifact(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
    """
    Callback function to copy uploaded artifacts to temporary files.

    When the last content of the request contains an upload marker, each
    uploaded artifact is copied into UPLOAD_DIR. Any problem is reported to
    the LLM by appending a text part to the request.

    Returns:
        None, so the (possibly modified) request goes to the LLM.
    """
    # Callback and artifact handling code modified from:
    # https://google.github.io/adk-docs/callbacks/types-of-callbacks/#before-model-callback
    # https://github.com/google/adk-python/issues/2176#issuecomment-3395469070

    # Text of the last message in the request contents ("" when there are no
    # contents or the last one has no text part, e.g. function call events)
    last_user_message = (
        _content_text(llm_request.contents[-1]) if llm_request.contents else ""
    )

    # If a file was uploaded then SaveFilesAsArtifactsPlugin() adds
    # '[Uploaded Artifact: "file_name.csv"]' to the user message
    if ARTIFACT_MARKER not in last_user_message:
        return None

    # Add a text part only if there are any issues with accessing or saving the artifact
    added_text = ""
    artifacts = await callback_context.list_artifacts()
    if not artifacts:
        added_text = "No uploaded file is available"
    else:
        # Use the file names quoted in the marker(s). artifacts[-1] is not
        # necessarily the newest upload: the listing may be ordered by name and
        # can also include plot artifacts saved by save_plot_artifact().
        filenames = [
            name
            for name in dict.fromkeys(_UPLOADED_ARTIFACT_RE.findall(last_user_message))
            if name in artifacts
        ]
        if not filenames:
            # Fall back to the previous heuristic when no marker name matches
            filenames = [artifacts[-1]]
        errors = []
        for filename in filenames:
            try:
                tmp_file_path = await _copy_artifact_to_upload_dir(
                    callback_context, filename
                )
                print(f"[preprocess_artifact] Saved artifact as '{tmp_file_path}'")
            except Exception as e:
                # Best effort: report the problem to the LLM instead of failing the turn
                errors.append(f"Error processing artifact: {str(e)}")
        added_text = "\n".join(errors)

    # If there were any issues, add a new part to the request
    if added_text:
        # NOTE(review): the note is appended to the FIRST content of the request;
        # an earlier version appended to contents[-1]. Confirm the intended target.
        llm_request.contents[0].parts.append(types.Part(text=added_text))
        print(f"[preprocess_artifact] Added text part to user message: '{added_text}'")

    # Return None to allow the possibly modified request to go to the LLM
    return None


async def preprocess_messages(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
    """
    Callback function to modify user messages to point to temporary artifact file paths.

    Rewrites each marker inserted by SaveFilesAsArtifactsPlugin():
        [Uploaded Artifact: "breast-cancer.csv"]
    into the path written by preprocess_artifact():
        [Uploaded File: "/tmp/uploads/breast-cancer.csv"]

    Returns:
        None, so the modified request goes to the LLM.
    """
    # Changes to session state made by callbacks are not preserved across events
    # See: https://github.com/google/adk-docs/issues/904
    # Therefore, for every callback invocation we need to loop over all events, not just the most recent one
    for content in llm_request.contents:
        # Check every text part (not only the last): a message can carry several
        # uploads, and parts may be empty or None.
        for part in content.parts or []:
            if not part.text or ARTIFACT_MARKER not in part.text:
                continue
            new_text, count = _UPLOADED_ARTIFACT_RE.subn(
                lambda m: f'[Uploaded File: "{UPLOAD_DIR}/{os.path.basename(m.group(1))}"]',
                part.text,
            )
            if count:
                part.text = new_text
                print(f"[preprocess_messages] Modified user message: '{new_text}'")
    return None


def detect_file_type(byte_data: bytes) -> Tuple[str, str]:
    """
    Detect file type from magic number/bytes and return (mime_type, file_extension).

    Supports BMP, JPEG, PNG, TIFF, and PDF. Inputs shorter than 8 bytes, or
    with an unrecognized signature, default to PNG.
    """
    if len(byte_data) < 8:
        # Default to PNG if we can't determine
        return "image/png", "png"
    # Check magic numbers
    if byte_data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png", "png"
    elif byte_data.startswith(b"\xff\xd8\xff"):
        return "image/jpeg", "jpg"
    elif byte_data.startswith(b"BM"):
        return "image/bmp", "bmp"
    elif byte_data.startswith(b"II*\x00") or byte_data.startswith(b"MM\x00*"):
        return "image/tiff", "tiff"
    elif byte_data.startswith(b"%PDF"):
        return "application/pdf", "pdf"
    else:
        # Default to PNG if we can't determine
        return "image/png", "png"


async def skip_summarization_for_plot_success(
    tool: BaseTool, args: Dict[str, Any], tool_context: ToolContext, tool_response: Dict
) -> Optional[Dict]:
    """
    Callback function to turn off summarization if plot succeeded.

    Returns:
        None, leaving the tool response unchanged.
    """
    # If there was an error making the plot, the LLM tells the user what happened.
    # This happens because skip_summarization is False by default.
    # But if the plot was created successfully, there's
    # no need for an extra LLM call to tell us it's there.
    # .get(): a missing "isError" is treated like MCP's default (False)
    if tool.name in PLOT_TOOLS and not tool_response.get("isError", False):
        tool_context.actions.skip_summarization = True
    return None


async def save_plot_artifact(
    tool: BaseTool, args: Dict[str, Any], tool_context: ToolContext, tool_response: Dict
) -> Optional[Dict]:
    """
    Callback function to save plot files as an ADK artifact.

    For a successful plot tool call, decodes the hex-string text content,
    saves it as artifact ``<plot|ggplot>.<ext>``, and returns a short
    replacement tool response naming the artifact.

    Returns:
        The replacement tool response, or None to pass the original through.
    """
    # Look for plot tool (so we don't bother with transfer_to_agent or other functions)
    if tool.name not in PLOT_TOOLS:
        return None
    # In ADK 1.17.0, tool_response is a dict (i.e. result of model_dump method invoked on MCP CallToolResult instance):
    # https://github.com/google/adk-python/commit/4df926388b6e9ebcf517fbacf2f5532fd73b0f71
    # https://github.com/modelcontextprotocol/python-sdk?tab=readme-ov-file#parsing-tool-results
    if tool_response.get("isError", False):
        return None
    for content in tool_response.get("content", []):
        if content.get("type") != "text":
            continue
        try:
            # Convert tool response (hex string) to bytes
            byte_data = bytes.fromhex(content.get("text"))
        except (TypeError, ValueError):
            # Not a hex-encoded plot; skip it instead of failing the callback
            continue
        # Detect file type from magic number
        mime_type, file_extension = detect_file_type(byte_data)
        # Encode binary data to Base64 format
        encoded = base64.b64encode(byte_data).decode("utf-8")
        artifact_part = types.Part(
            inline_data={
                "data": encoded,
                "mime_type": mime_type,
            }
        )
        # Use second part of tool name (e.g. make_ggplot -> ggplot.png).
        # Computed outside the f-string: reusing the outer quote type inside an
        # f-string expression is a SyntaxError before Python 3.12.
        stem = tool.name.split("_", 1)[1]
        filename = f"{stem}.{file_extension}"
        await tool_context.save_artifact(filename=filename, artifact=artifact_part)
        # Format the success message as a tool response
        text = f"Plot created and saved as an artifact: {filename}"
        response = CallToolResult(
            content=[TextContent(type="text", text=text)],
        )
        return response.model_dump(exclude_none=True, mode="json")
    # Passthrough for no matching content
    return None


# Create agent to run R code
run_agent = LlmAgent(
    name="Run",
    description="Runs R code without making plots. Use the `Run` agent for executing code that does not load data or make a plot.",
    model=model,
    instruction=Run,
    tools=[
        McpToolset(
            connection_params=connection_params,
            tool_filter=["run_visible", "run_hidden"],
        )
    ],
    before_model_callback=[preprocess_artifact, preprocess_messages],
    before_tool_callback=catch_tool_errors,
)

# Create agent to load data
data_agent = LlmAgent(
    name="Data",
    description="Loads data into an R data frame and summarizes it. Use the `Data` agent for loading data from a file or URL before making a plot.",
    model=model,
    instruction=Data,
    tools=[
        McpToolset(
            connection_params=connection_params,
            tool_filter=["run_visible"],
        )
    ],
    before_model_callback=[preprocess_artifact, preprocess_messages],
    before_tool_callback=catch_tool_errors,
)

# Create agent to make plots using R code
plot_agent = LlmAgent(
    name="Plot",
    description="Makes plots using R code. Use the `Plot` agent after loading any required data.",
    model=model,
    instruction=Plot,
    tools=[
        McpToolset(
            connection_params=connection_params,
            tool_filter=list(PLOT_TOOLS),
        )
    ],
    before_model_callback=[preprocess_artifact, preprocess_messages],
    before_tool_callback=catch_tool_errors,
    # Order matters: skip_summarization_for_plot_success returns None, then
    # save_plot_artifact may replace the tool response.
    after_tool_callback=[skip_summarization_for_plot_success, save_plot_artifact],
)

# Create agent to install R packages
install_agent = LlmAgent(
    name="Install",
    description="Installs R packages. Use the `Install` agent when an R package needs to be installed.",
    model=model,
    instruction=Install,
    tools=[
        McpToolset(
            connection_params=connection_params,
            tool_filter=["run_visible"],
        )
    ],
    before_model_callback=[preprocess_artifact, preprocess_messages],
    before_tool_callback=catch_tool_errors,
)

# Create parent agent and assign children via sub_agents
root_agent = LlmAgent(
    name="Coordinator",
    # "Use the..." tells sub-agents to transfer to Coordinator for help requests
    description="Multi-agent system for performing actions in R. Use the `Coordinator` agent for getting help on packages, datasets, and functions.",
    model=model,
    instruction=Root,
    # To pass control back to root, the help and run functions should be tools or a ToolAgent (not sub_agent)
    tools=[
        McpToolset(
            connection_params=connection_params,
            tool_filter=["help_package", "help_topic"],
        )
    ],
    sub_agents=[
        run_agent,
        data_agent,
        plot_agent,
        install_agent,
    ],
    # Select R session
    before_agent_callback=select_r_session,
    # Save user-uploaded artifact as a temporary file and modify messages to point to this file
    before_model_callback=[preprocess_artifact, preprocess_messages],
    before_tool_callback=catch_tool_errors,
)

app = App(
    name="PlotMyData",
    root_agent=root_agent,
    # This inserts user messages like '[Uploaded Artifact: "breast-cancer.csv"]'
    plugins=[SaveFilesAsArtifactsPlugin()],
)