burtenshaw HF Staff commited on
Commit
8b900a4
·
verified ·
1 Parent(s): 3db19d9

Upload folder using huggingface_hub

Browse files
Dockerfile CHANGED
@@ -17,7 +17,7 @@ RUN apt-get update && apt-get install -y \
17
  COPY envs/coding_env/ ./envs/coding_env/
18
 
19
  # Install openenv-core first from PyPI, then coding_env
20
- RUN pip install --no-cache-dir "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main" && \
21
  pip install --no-cache-dir ./envs/coding_env/
22
 
23
  # Environment variables
 
17
  COPY envs/coding_env/ ./envs/coding_env/
18
 
19
  # Install openenv-core first from PyPI, then coding_env
20
+ RUN pip install --no-cache-dir "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.3" && \
21
  pip install --no-cache-dir ./envs/coding_env/
22
 
23
  # Environment variables
README.md CHANGED
@@ -8,7 +8,7 @@ pinned: false
8
  app_port: 8000
9
  base_path: /web
10
  tags:
11
- - openenv-0.2.2
12
  - openenv
13
  ---
14
 
@@ -17,7 +17,7 @@ tags:
17
  This Space is built from OpenEnv environment `coding_env`.
18
 
19
  - Space URL: `https://huggingface.co/spaces/openenv/coding_env`
20
- - OpenEnv pinned ref: `0.2.2`
21
  - Hub tag: `openenv`
22
 
23
  ### Connecting from Code
 
8
  app_port: 8000
9
  base_path: /web
10
  tags:
11
+ - openenv-0.2.3
12
  - openenv
13
  ---
14
 
 
17
  This Space is built from OpenEnv environment `coding_env`.
18
 
19
  - Space URL: `https://huggingface.co/spaces/openenv/coding_env`
20
+ - OpenEnv pinned ref: `0.2.3`
21
  - Hub tag: `openenv`
22
 
23
  ### Connecting from Code
envs/coding_env/pyproject.toml CHANGED
@@ -8,7 +8,7 @@ version = "0.1.0"
8
  description = "Coding Environment for OpenEnv"
9
  requires-python = ">=3.10"
10
  dependencies = [
11
- "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main",
12
  "fastapi>=0.115.0",
13
  "pydantic>=2.0.0",
14
  "uvicorn[standard]>=0.24.0",
 
8
  description = "Coding Environment for OpenEnv"
9
  requires-python = ">=3.10"
10
  dependencies = [
11
+ "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.3",
12
  "fastapi>=0.115.0",
13
  "pydantic>=2.0.0",
14
  "uvicorn[standard]>=0.24.0",
envs/coding_env/server/Dockerfile CHANGED
@@ -17,7 +17,7 @@ RUN apt-get update && apt-get install -y \
17
  COPY envs/coding_env/ ./envs/coding_env/
18
 
19
  # Install openenv-core first from PyPI, then coding_env
20
- RUN pip install --no-cache-dir "openenv-core[core]>=0.2.1" && \
21
  pip install --no-cache-dir ./envs/coding_env/
22
 
23
  # Environment variables
 
17
  COPY envs/coding_env/ ./envs/coding_env/
18
 
19
  # Install openenv-core first from PyPI, then coding_env
20
+ RUN pip install --no-cache-dir "openenv-core[core]>=0.2.2" && \
21
  pip install --no-cache-dir ./envs/coding_env/
22
 
23
  # Environment variables
pyproject.toml CHANGED
@@ -8,7 +8,7 @@ version = "0.1.0"
8
  description = "Coding Environment for OpenEnv"
9
  requires-python = ">=3.10"
10
  dependencies = [
11
- "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main",
12
  "fastapi>=0.115.0",
13
  "pydantic>=2.0.0",
14
  "uvicorn[standard]>=0.24.0",
 
8
  description = "Coding Environment for OpenEnv"
9
  requires-python = ">=3.10"
10
  dependencies = [
11
+ "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.3",
12
  "fastapi>=0.115.0",
13
  "pydantic>=2.0.0",
14
  "uvicorn[standard]>=0.24.0",
server/Dockerfile CHANGED
@@ -17,7 +17,7 @@ RUN apt-get update && apt-get install -y \
17
  COPY envs/coding_env/ ./envs/coding_env/
18
 
19
  # Install openenv-core first from PyPI, then coding_env
20
- RUN pip install --no-cache-dir "openenv-core[core]>=0.2.1" && \
21
  pip install --no-cache-dir ./envs/coding_env/
22
 
23
  # Environment variables
 
17
  COPY envs/coding_env/ ./envs/coding_env/
18
 
19
  # Install openenv-core first from PyPI, then coding_env
20
+ RUN pip install --no-cache-dir "openenv-core[core]>=0.2.2" && \
21
  pip install --no-cache-dir ./envs/coding_env/
22
 
23
  # Environment variables
src/core/env_server/http_server.py CHANGED
@@ -16,11 +16,15 @@ from __future__ import annotations
16
  import asyncio
17
  import inspect
18
  import json
 
19
  import os
20
  import time
21
  import uuid
22
  from concurrent.futures import ThreadPoolExecutor
23
- from typing import Any, Callable, Dict, Optional, Type
 
 
 
24
 
25
  from fastapi import (
26
  Body,
@@ -204,8 +208,9 @@ class HTTPEnvServer:
204
  self.observation_cls = observation_cls
205
 
206
  # Session management for WebSocket connections
207
- self._sessions: Dict[str, Environment] = {}
208
  self._session_executors: Dict[str, ThreadPoolExecutor] = {}
 
209
  self._session_info: Dict[str, SessionInfo] = {}
210
  self._session_lock = asyncio.Lock()
211
 
@@ -213,6 +218,14 @@ class HTTPEnvServer:
213
  # This is needed for environments using sync libraries (e.g., Playwright)
214
  self._executor = ThreadPoolExecutor(max_workers=32)
215
 
 
 
 
 
 
 
 
 
216
  def _validate_concurrency_safety(self) -> None:
217
  """
218
  Validate that the environment supports the configured concurrency level.
@@ -321,12 +334,37 @@ class HTTPEnvServer:
321
  )
322
  raise EnvironmentFactoryError(factory_name) from e
323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  async with self._session_lock:
325
  self._sessions[session_id] = env
 
 
326
  self._session_info[session_id] = SessionInfo(
327
  session_id=session_id,
328
  created_at=current_time,
329
- last_activity_at=current_time,
330
  step_count=0,
331
  environment_type=type(env).__name__,
332
  )
@@ -343,8 +381,27 @@ class HTTPEnvServer:
343
  async with self._session_lock:
344
  env = self._sessions.pop(session_id, None)
345
  executor = self._session_executors.pop(session_id, None)
 
346
  self._session_info.pop(session_id, None)
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  # Run close() in the same executor where the env was created
349
  # This is required for thread-sensitive libraries like Playwright/greenlet
350
  if env is not None:
@@ -383,6 +440,51 @@ class HTTPEnvServer:
383
  if increment_step:
384
  self._session_info[session_id].step_count += 1
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
387
  """
388
  Get information about a specific session.
@@ -458,6 +560,20 @@ class HTTPEnvServer:
458
  f"Invalid mode: '{mode}'. Must be one of: {valid_modes}"
459
  )
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  # Helper function to handle reset endpoint
462
  async def reset_handler(
463
  request: ResetRequest = Body(default_factory=ResetRequest),
@@ -526,53 +642,214 @@ class HTTPEnvServer:
526
 
527
  # Helper function to handle MCP endpoint
528
  async def mcp_handler(
529
- request: JsonRpcRequest, session_env: Optional[Environment] = None
 
 
530
  ) -> JsonRpcResponse:
531
  """
532
  Handle MCP JSON-RPC requests.
533
 
534
- Supports tools/list and tools/call methods in JSON-RPC 2.0 format.
 
 
 
535
  """
536
  method = request.method
537
  request_id = request.id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
  # Use provided session environment or create temporary one
540
  if session_env is not None:
541
  _env = session_env
542
  should_close = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  else:
544
  _env = self._env_factory()
545
  should_close = True
546
  try:
 
 
 
 
547
  if method == McpMethod.TOOLS_LIST:
548
  # Check if environment is MCP-enabled
549
- if not hasattr(_env, "mcp_client"):
550
  return JsonRpcResponse.error_response(
551
  JsonRpcErrorCode.INTERNAL_ERROR,
552
  "Environment does not support MCP",
553
  request_id=request_id,
554
  )
555
 
556
- # Use async context manager for MCP client
557
- async with _env.mcp_client:
558
- tools = await _env.mcp_client.list_tools()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
 
560
- return JsonRpcResponse.success(
561
- result={
562
- "tools": [
563
- t.model_dump() if hasattr(t, "model_dump") else dict(t)
564
- for t in tools
565
- ]
566
- },
 
 
 
 
 
 
 
 
 
 
 
567
  request_id=request_id,
568
  )
569
 
570
  elif method == McpMethod.TOOLS_CALL:
571
- params = request.params
572
  tool_name = params.get("name")
573
  arguments = params.get("arguments", {})
574
 
575
- if not hasattr(_env, "mcp_client"):
576
  return JsonRpcResponse.error_response(
577
  JsonRpcErrorCode.INTERNAL_ERROR,
578
  "Environment does not support MCP",
@@ -581,15 +858,51 @@ class HTTPEnvServer:
581
 
582
  if not tool_name:
583
  return JsonRpcResponse.error_response(
584
- JsonRpcErrorCode.INVALID_REQUEST,
585
  "Missing 'name' in params",
586
  request_id=request_id,
587
  )
588
 
589
- # Use async context manager for MCP client
590
- async with _env.mcp_client:
591
- result = await _env.mcp_client.call_tool(
592
- name=tool_name, arguments=arguments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  )
594
 
595
  # Ensure result is JSON serializable
@@ -614,6 +927,11 @@ class HTTPEnvServer:
614
  request_id=request_id,
615
  )
616
  finally:
 
 
 
 
 
617
  if should_close:
618
  _env.close()
619
 
@@ -637,42 +955,59 @@ class HTTPEnvServer:
637
  try:
638
  # Create session with dedicated environment
639
  session_id, session_env = await self._create_session()
 
 
 
 
640
 
641
- while True:
642
- # Receive message from client
643
- raw_message = await websocket.receive_text()
644
-
645
- try:
646
- jsonrpc_dict = json.loads(raw_message)
647
- jsonrpc_request = JsonRpcRequest(**jsonrpc_dict)
648
- except json.JSONDecodeError as e:
649
- error_resp = JsonRpcResponse.error_response(
650
- JsonRpcErrorCode.PARSE_ERROR,
651
- f"Parse error: {e}",
652
- )
653
- await websocket.send_text(error_resp.model_dump_json())
654
- continue
655
- except ValidationError as e:
656
- error_resp = JsonRpcResponse.error_response(
657
- JsonRpcErrorCode.INVALID_REQUEST,
658
- f"Invalid request: {e}",
659
- )
660
- await websocket.send_text(error_resp.model_dump_json())
661
- continue
662
 
663
- try:
664
- # Call mcp_handler with session environment
665
- response = await mcp_handler(
666
- jsonrpc_request, session_env=session_env
 
667
  )
668
- await websocket.send_text(response.model_dump_json())
669
- except Exception as e:
670
- error_resp = JsonRpcResponse.error_response(
671
- JsonRpcErrorCode.INTERNAL_ERROR,
672
- str(e),
673
- request_id=jsonrpc_request.id,
674
- )
675
- await websocket.send_text(error_resp.model_dump_json())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
 
677
  except WebSocketDisconnect:
678
  pass
@@ -931,120 +1266,8 @@ all schema information needed to interact with the environment.
931
  JsonRpcErrorCode.PARSE_ERROR
932
  ).model_dump()
933
 
934
- method = request.method
935
- params = request.params
936
- request_id = request.id
937
-
938
- # Create a temporary environment for MCP access
939
- _env = self._env_factory()
940
-
941
- try:
942
- # Check if environment supports MCP
943
- if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"):
944
- return JsonRpcResponse.error_response(
945
- JsonRpcErrorCode.INTERNAL_ERROR,
946
- "Environment does not support MCP",
947
- request_id=request_id,
948
- ).model_dump()
949
-
950
- if method == McpMethod.TOOLS_LIST:
951
- # List tools from MCP server
952
- if hasattr(_env, "mcp_client") and _env.mcp_client:
953
- async with _env.mcp_client:
954
- tools = await _env.mcp_client.list_tools()
955
- return JsonRpcResponse.success(
956
- result={
957
- "tools": [
958
- t.model_dump()
959
- if hasattr(t, "model_dump")
960
- else dict(t)
961
- for t in tools
962
- ]
963
- },
964
- request_id=request_id,
965
- ).model_dump()
966
- elif hasattr(_env, "mcp_server") and _env.mcp_server:
967
- # Use server directly
968
- tools = []
969
- for tool_name, tool in get_server_tools(
970
- _env.mcp_server
971
- ).items():
972
- tool_dict = {
973
- "name": tool.name,
974
- "description": tool.description or "",
975
- "inputSchema": tool.parameters or {},
976
- }
977
- tools.append(tool_dict)
978
- return JsonRpcResponse.success(
979
- result={"tools": tools},
980
- request_id=request_id,
981
- ).model_dump()
982
- else:
983
- return JsonRpcResponse.error_response(
984
- JsonRpcErrorCode.INTERNAL_ERROR,
985
- "MCP server not available",
986
- request_id=request_id,
987
- ).model_dump()
988
-
989
- elif method == McpMethod.TOOLS_CALL:
990
- tool_name = params.get("name")
991
- arguments = params.get("arguments", {})
992
-
993
- if not tool_name:
994
- return JsonRpcResponse.error_response(
995
- JsonRpcErrorCode.INVALID_PARAMS,
996
- "Invalid params - 'name' is required",
997
- request_id=request_id,
998
- ).model_dump()
999
-
1000
- # Call tool via MCP
1001
- if hasattr(_env, "mcp_client") and _env.mcp_client:
1002
- async with _env.mcp_client:
1003
- result = await _env.mcp_client.call_tool(
1004
- name=tool_name, arguments=arguments
1005
- )
1006
- elif hasattr(_env, "mcp_server") and _env.mcp_server:
1007
- # Call tool directly on FastMCP server
1008
- server_tools = get_server_tools(_env.mcp_server)
1009
- if tool_name in server_tools:
1010
- tool = server_tools[tool_name]
1011
- result = tool.fn(**arguments)
1012
- else:
1013
- return JsonRpcResponse.error_response(
1014
- JsonRpcErrorCode.INVALID_PARAMS,
1015
- f"Tool not found: {tool_name}",
1016
- request_id=request_id,
1017
- ).model_dump()
1018
- else:
1019
- return JsonRpcResponse.error_response(
1020
- JsonRpcErrorCode.INTERNAL_ERROR,
1021
- "MCP server not available",
1022
- request_id=request_id,
1023
- ).model_dump()
1024
-
1025
- # Make result JSON serializable
1026
- serializable_result = _make_json_serializable(result)
1027
-
1028
- return JsonRpcResponse.success(
1029
- result=serializable_result,
1030
- request_id=request_id,
1031
- ).model_dump()
1032
-
1033
- else:
1034
- return JsonRpcResponse.error_response(
1035
- JsonRpcErrorCode.METHOD_NOT_FOUND,
1036
- f"Method not found: {method}",
1037
- request_id=request_id,
1038
- ).model_dump()
1039
-
1040
- except Exception as e:
1041
- return JsonRpcResponse.error_response(
1042
- JsonRpcErrorCode.INTERNAL_ERROR,
1043
- str(e),
1044
- request_id=request_id,
1045
- ).model_dump()
1046
- finally:
1047
- _env.close()
1048
 
1049
  # Register WebSocket endpoint for persistent sessions
1050
  @app.websocket("/ws")
@@ -1066,135 +1289,167 @@ all schema information needed to interact with the environment.
1066
  try:
1067
  # Create session with dedicated environment
1068
  session_id, session_env = await self._create_session()
 
 
 
 
1069
 
1070
- while True:
1071
- # Receive message from client
1072
- raw_message = await websocket.receive_text()
1073
 
1074
- try:
1075
- message_dict = json.loads(raw_message)
1076
- except json.JSONDecodeError as e:
1077
- error_resp = WSErrorResponse(
1078
- data={
1079
- "message": f"Invalid JSON: {e}",
1080
- "code": WSErrorCode.INVALID_JSON,
1081
- }
1082
  )
1083
- await websocket.send_text(error_resp.model_dump_json())
1084
- continue
1085
-
1086
- msg_type = message_dict.get("type", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
1087
 
1088
- try:
1089
- match msg_type:
1090
- case "reset":
1091
- msg = WSResetMessage(**message_dict)
1092
 
1093
- is_async = (
1094
- session_env.reset_async.__func__
1095
- is not Environment.reset_async
1096
- )
1097
 
1098
- if is_async:
1099
- sig = inspect.signature(session_env.reset_async)
1100
- valid_kwargs = self._get_valid_kwargs(sig, msg.data)
1101
- observation = await session_env.reset_async(
1102
- **valid_kwargs
1103
  )
1104
- else:
1105
- sig = inspect.signature(session_env.reset)
1106
- valid_kwargs = self._get_valid_kwargs(sig, msg.data)
1107
- observation = await self._run_in_session_executor(
1108
- session_id, session_env.reset, **valid_kwargs
1109
- )
1110
-
1111
- self._update_session_activity(session_id)
1112
-
1113
- response = WSObservationResponse(
1114
- data=serialize_observation(observation),
1115
- )
1116
 
1117
- case "step":
1118
- msg = WSStepMessage(**message_dict)
1119
- action = deserialize_action(msg.data, self.action_cls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1120
 
1121
- is_async = (
1122
- session_env.step_async.__func__
1123
- is not Environment.step_async
1124
- )
 
1125
 
1126
- if is_async:
1127
- observation = await session_env.step_async(action)
1128
- else:
1129
- observation = await self._run_in_session_executor(
1130
- session_id, session_env.step, action
1131
  )
1132
 
1133
- self._update_session_activity(
1134
- session_id, increment_step=True
1135
- )
 
 
 
 
 
 
 
 
 
 
 
1136
 
1137
- response = WSObservationResponse(
1138
- data=serialize_observation(observation)
1139
- )
1140
 
1141
- case "state":
1142
- msg = WSStateMessage(**message_dict)
1143
- state = session_env.state
1144
- if hasattr(state, "model_dump"):
1145
- state_data = state.model_dump()
1146
- else:
1147
- state_data = dict(state) if state else {}
1148
-
1149
- response = WSStateResponse(data=state_data)
1150
-
1151
- case "close":
1152
- msg = WSCloseMessage(**message_dict)
1153
- break
1154
-
1155
- case "mcp":
1156
- msg = WSMCPMessage(**message_dict)
1157
- try:
1158
- rpc_request = JsonRpcRequest(**msg.data)
1159
- except (ValidationError, Exception) as e:
1160
- rpc_response = JsonRpcResponse.error_response(
1161
- JsonRpcErrorCode.INVALID_REQUEST,
1162
- f"Invalid request: {e}",
 
 
 
 
 
 
 
 
 
1163
  )
1164
- else:
1165
- rpc_response = await mcp_handler(
1166
- rpc_request,
1167
- session_env=session_env,
 
 
 
1168
  )
1169
- response = WSMCPResponse(data=rpc_response.model_dump())
1170
-
1171
- case _:
1172
- response = WSErrorResponse(
1173
- data={
1174
- "message": f"Unknown message type: {msg_type}",
1175
- "code": WSErrorCode.UNKNOWN_TYPE,
1176
- }
1177
- )
1178
 
1179
- await websocket.send_text(response.model_dump_json())
1180
 
1181
- except ValidationError as e:
1182
- error_resp = WSErrorResponse(
1183
- data={
1184
- "message": "Invalid message",
1185
- "code": WSErrorCode.VALIDATION_ERROR,
1186
- "errors": e.errors(),
1187
- }
1188
- )
1189
- await websocket.send_text(error_resp.model_dump_json())
1190
- except Exception as e:
1191
- error_resp = WSErrorResponse(
1192
- data={
1193
- "message": str(e),
1194
- "code": WSErrorCode.EXECUTION_ERROR,
1195
- }
1196
- )
1197
- await websocket.send_text(error_resp.model_dump_json())
1198
 
1199
  except WebSocketDisconnect:
1200
  pass
@@ -1276,7 +1531,7 @@ def create_app(
1276
  from .web_interface import create_web_interface_app
1277
 
1278
  return create_web_interface_app(
1279
- env,
1280
  action_cls,
1281
  observation_cls,
1282
  env_name,
 
16
  import asyncio
17
  import inspect
18
  import json
19
+ import logging
20
  import os
21
  import time
22
  import uuid
23
  from concurrent.futures import ThreadPoolExecutor
24
+ from contextlib import AsyncExitStack
25
+ from typing import Any, AsyncContextManager, Callable, cast, Dict, Optional, Type
26
+
27
+ _MISSING = object()
28
 
29
  from fastapi import (
30
  Body,
 
208
  self.observation_cls = observation_cls
209
 
210
  # Session management for WebSocket connections
211
+ self._sessions: Dict[str, Optional[Environment]] = {}
212
  self._session_executors: Dict[str, ThreadPoolExecutor] = {}
213
+ self._session_stacks: Dict[str, AsyncExitStack] = {}
214
  self._session_info: Dict[str, SessionInfo] = {}
215
  self._session_lock = asyncio.Lock()
216
 
 
218
  # This is needed for environments using sync libraries (e.g., Playwright)
219
  self._executor = ThreadPoolExecutor(max_workers=32)
220
 
221
+ # Idle session reaper configuration.
222
+ # Timeout is taken from ConcurrencyConfig.session_timeout;
223
+ # None means no timeout (default — reaper is a no-op).
224
+ self._session_idle_timeout_s: Optional[float] = (
225
+ self._concurrency_config.session_timeout
226
+ )
227
+ self._reaper_task: Optional[asyncio.Task[None]] = None
228
+
229
  def _validate_concurrency_safety(self) -> None:
230
  """
231
  Validate that the environment supports the configured concurrency level.
 
334
  )
335
  raise EnvironmentFactoryError(factory_name) from e
336
 
337
+ # Hold the MCP session open for the lifetime of this session,
338
+ # matching the WebSocket path's AsyncExitStack pattern. This
339
+ # prevents per-request MCP transport teardown/reconnection and
340
+ # preserves FastMCP session state (ctx.set_state / ctx.get_state)
341
+ # across HTTP calls within the same OpenEnv session.
342
+ stack = AsyncExitStack()
343
+ try:
344
+ mcp_session_factory = getattr(env, "mcp_session", None)
345
+ if callable(mcp_session_factory):
346
+ mcp_session_cm = cast(AsyncContextManager[Any], mcp_session_factory())
347
+ await stack.enter_async_context(mcp_session_cm)
348
+ except Exception:
349
+ # MCP transport failed to start — clean up the reserved slot,
350
+ # the env, and the executor so they don't leak permanently
351
+ # against _max_concurrent_envs.
352
+ await stack.aclose() # best-effort
353
+ async with self._session_lock:
354
+ self._sessions.pop(session_id, None)
355
+ self._session_executors.pop(session_id, None)
356
+ self._session_info.pop(session_id, None)
357
+ await self._cleanup_session_resources(env, executor)
358
+ raise
359
+
360
  async with self._session_lock:
361
  self._sessions[session_id] = env
362
+ self._session_stacks[session_id] = stack
363
+ now = time.time()
364
  self._session_info[session_id] = SessionInfo(
365
  session_id=session_id,
366
  created_at=current_time,
367
+ last_activity_at=now,
368
  step_count=0,
369
  environment_type=type(env).__name__,
370
  )
 
381
  async with self._session_lock:
382
  env = self._sessions.pop(session_id, None)
383
  executor = self._session_executors.pop(session_id, None)
384
+ stack = self._session_stacks.pop(session_id, None)
385
  self._session_info.pop(session_id, None)
386
 
387
+ await self._cleanup_session_resources(env, executor, stack)
388
+
389
+ async def _cleanup_session_resources(
390
+ self,
391
+ env: Optional[Environment],
392
+ executor: Optional[ThreadPoolExecutor],
393
+ stack: Optional[AsyncExitStack] = None,
394
+ ) -> None:
395
+ """Close an environment and shut down its executor (best-effort)."""
396
+ # Close the MCP session stack first — this gracefully exits the
397
+ # mcp_session() context (and the underlying FastMCP Client session)
398
+ # before we tear down the environment references.
399
+ if stack is not None:
400
+ try:
401
+ await stack.aclose()
402
+ except Exception:
403
+ pass # Best effort cleanup
404
+
405
  # Run close() in the same executor where the env was created
406
  # This is required for thread-sensitive libraries like Playwright/greenlet
407
  if env is not None:
 
440
  if increment_step:
441
  self._session_info[session_id].step_count += 1
442
 
443
+ async def _reap_idle_sessions(self) -> None:
444
+ """Background task that periodically destroys sessions idle beyond the timeout."""
445
+ timeout = self._session_idle_timeout_s
446
+ if timeout is None:
447
+ return # no timeout configured — noop
448
+ interval = max(timeout / 4, 5.0) # check frequently enough
449
+ while True:
450
+ try:
451
+ await asyncio.sleep(interval)
452
+ now = time.time()
453
+ stale_ids: list[str] = []
454
+ async with self._session_lock:
455
+ for sid, info in self._session_info.items():
456
+ if now - info.last_activity_at > timeout:
457
+ stale_ids.append(sid)
458
+ for sid in stale_ids:
459
+ # Re-check under lock: activity may have arrived since
460
+ # the snapshot was taken, making this session active again.
461
+ # Refresh `now` so slow _destroy_session calls don't cause
462
+ # subsequent entries to be validated against a stale clock.
463
+ now = time.time()
464
+ async with self._session_lock:
465
+ info = self._session_info.get(sid)
466
+ if info is None or (now - info.last_activity_at) <= timeout:
467
+ continue
468
+ await self._destroy_session(sid)
469
+ except asyncio.CancelledError:
470
+ break
471
+ except Exception as exc:
472
+ logging.getLogger(__name__).warning(
473
+ "Idle-session reaper encountered an error (will retry): %s",
474
+ exc,
475
+ )
476
+
477
+ def _start_reaper(self) -> None:
478
+ """Start the idle-session reaper if a timeout is configured."""
479
+ if self._session_idle_timeout_s is not None and self._reaper_task is None:
480
+ self._reaper_task = asyncio.create_task(self._reap_idle_sessions())
481
+
482
+ def _stop_reaper(self) -> None:
483
+ """Cancel the reaper background task."""
484
+ if self._reaper_task is not None:
485
+ self._reaper_task.cancel()
486
+ self._reaper_task = None
487
+
488
  def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
489
  """
490
  Get information about a specific session.
 
560
  f"Invalid mode: '{mode}'. Must be one of: {valid_modes}"
561
  )
562
 
563
+ # Wire up idle-session reaper lifecycle via app events
564
+ server_ref = self
565
+
566
+ async def _start_session_reaper() -> None:
567
+ server_ref._start_reaper()
568
+
569
+ async def _stop_session_reaper() -> None:
570
+ server_ref._stop_reaper()
571
+
572
+ if not getattr(app.router, "_openenv_reaper_registered", False):
573
+ app.router.on_startup.append(_start_session_reaper)
574
+ app.router.on_shutdown.append(_stop_session_reaper)
575
+ app.router._openenv_reaper_registered = True # type: ignore[attr-defined]
576
+
577
  # Helper function to handle reset endpoint
578
  async def reset_handler(
579
  request: ResetRequest = Body(default_factory=ResetRequest),
 
642
 
643
  # Helper function to handle MCP endpoint
644
  async def mcp_handler(
645
+ request: JsonRpcRequest,
646
+ session_env: Optional[Environment] = None,
647
+ session_id: Optional[str] = None,
648
  ) -> JsonRpcResponse:
649
  """
650
  Handle MCP JSON-RPC requests.
651
 
652
+ Supports tools/list and tools/call methods in JSON-RPC 2.0 format,
653
+ plus OpenEnv session lifecycle methods for HTTP MCP:
654
+ - openenv/session/create
655
+ - openenv/session/close
656
  """
657
  method = request.method
658
  request_id = request.id
659
+ params = request.params
660
+ if not isinstance(params, dict):
661
+ return JsonRpcResponse.error_response(
662
+ JsonRpcErrorCode.INVALID_PARAMS,
663
+ "Params must be an object",
664
+ request_id=request_id,
665
+ )
666
+
667
+ # OpenEnv extension methods for explicit MCP session management.
668
+ # This enables persistent MCP lifecycles over HTTP /mcp, matching WebSocket semantics.
669
+ if method == "openenv/session/create":
670
+ if session_env is not None and session_id is not None:
671
+ return JsonRpcResponse.success(
672
+ result={"session_id": session_id},
673
+ request_id=request_id,
674
+ )
675
+ try:
676
+ created_session_id, _ = await self._create_session()
677
+ except SessionCapacityError as e:
678
+ return JsonRpcResponse.error_response(
679
+ JsonRpcErrorCode.SERVER_ERROR,
680
+ str(e),
681
+ request_id=request_id,
682
+ data={
683
+ "active_sessions": e.active_sessions,
684
+ "max_sessions": e.max_sessions,
685
+ },
686
+ )
687
+ except EnvironmentFactoryError as e:
688
+ return JsonRpcResponse.error_response(
689
+ JsonRpcErrorCode.SERVER_ERROR,
690
+ str(e),
691
+ request_id=request_id,
692
+ data={"factory_name": e.factory_name},
693
+ )
694
+ return JsonRpcResponse.success(
695
+ result={"session_id": created_session_id},
696
+ request_id=request_id,
697
+ )
698
+
699
+ if method == "openenv/session/close":
700
+ target_session_id = params.get("session_id")
701
+ if not target_session_id:
702
+ return JsonRpcResponse.error_response(
703
+ JsonRpcErrorCode.INVALID_PARAMS,
704
+ "Invalid params - 'session_id' is required",
705
+ request_id=request_id,
706
+ )
707
+
708
+ if session_id is not None and target_session_id == session_id:
709
+ return JsonRpcResponse.error_response(
710
+ JsonRpcErrorCode.INVALID_REQUEST,
711
+ "Cannot close active WebSocket-managed session via MCP method",
712
+ request_id=request_id,
713
+ )
714
+
715
+ async with self._session_lock:
716
+ env = self._sessions.pop(target_session_id, _MISSING)
717
+ if env is not _MISSING:
718
+ executor = self._session_executors.pop(target_session_id, None)
719
+ stack = self._session_stacks.pop(target_session_id, None)
720
+ self._session_info.pop(target_session_id, None)
721
+ else:
722
+ executor = None
723
+ stack = None
724
+
725
+ if env is _MISSING:
726
+ return JsonRpcResponse.error_response(
727
+ JsonRpcErrorCode.INVALID_PARAMS,
728
+ f"Unknown session_id: {target_session_id}",
729
+ request_id=request_id,
730
+ )
731
+
732
+ if env is None:
733
+ # Session slot reserved but env factory still running;
734
+ # re-insert the placeholder AND the executor so
735
+ # _create_session can finish and the executor remains
736
+ # tracked for eventual shutdown.
737
+ async with self._session_lock:
738
+ self._sessions[target_session_id] = None
739
+ if executor is not None:
740
+ self._session_executors[target_session_id] = executor
741
+ return JsonRpcResponse.error_response(
742
+ JsonRpcErrorCode.INVALID_REQUEST,
743
+ f"Session {target_session_id} is still initializing; retry shortly",
744
+ request_id=request_id,
745
+ )
746
+
747
+ # env/executor/stack cleanup outside the lock
748
+ await self._cleanup_session_resources(env, executor, stack)
749
+ return JsonRpcResponse.success(
750
+ result={"session_id": target_session_id, "closed": True},
751
+ request_id=request_id,
752
+ )
753
+
754
+ requested_session_id = params.get("session_id")
755
+ managed_session_id = session_id
756
 
757
  # Use provided session environment or create temporary one
758
  if session_env is not None:
759
  _env = session_env
760
  should_close = False
761
+ elif requested_session_id:
762
+ async with self._session_lock:
763
+ _env = self._sessions.get(requested_session_id, _MISSING)
764
+
765
+ if _env is _MISSING:
766
+ return JsonRpcResponse.error_response(
767
+ JsonRpcErrorCode.INVALID_PARAMS,
768
+ f"Unknown session_id: {requested_session_id}",
769
+ request_id=request_id,
770
+ )
771
+
772
+ if _env is None:
773
+ return JsonRpcResponse.error_response(
774
+ JsonRpcErrorCode.INVALID_REQUEST,
775
+ f"Session {requested_session_id} is still initializing; retry shortly",
776
+ request_id=request_id,
777
+ )
778
+
779
+ should_close = False
780
+ managed_session_id = requested_session_id
781
  else:
782
  _env = self._env_factory()
783
  should_close = True
784
  try:
785
+ mcp_client = getattr(_env, "mcp_client", None)
786
+ mcp_server = getattr(_env, "mcp_server", None)
787
+ mcp_session_factory = getattr(_env, "mcp_session", None)
788
+
789
  if method == McpMethod.TOOLS_LIST:
790
  # Check if environment is MCP-enabled
791
+ if mcp_client is None and mcp_server is None:
792
  return JsonRpcResponse.error_response(
793
  JsonRpcErrorCode.INTERNAL_ERROR,
794
  "Environment does not support MCP",
795
  request_id=request_id,
796
  )
797
 
798
+ if mcp_client:
799
+ if managed_session_id and mcp_client.is_connected():
800
+ # Session-managed with live transport — call
801
+ # directly, no redundant re-entry.
802
+ tools = await mcp_client.list_tools()
803
+ elif callable(mcp_session_factory):
804
+ # Stateless request, or session-managed but the
805
+ # background transport was lost: (re-)open.
806
+ mcp_session_cm = cast(
807
+ AsyncContextManager[Any], mcp_session_factory()
808
+ )
809
+ async with mcp_session_cm:
810
+ tools = await mcp_client.list_tools()
811
+ else:
812
+ async with mcp_client:
813
+ tools = await mcp_client.list_tools()
814
+
815
+ return JsonRpcResponse.success(
816
+ result={
817
+ "tools": [
818
+ t.model_dump()
819
+ if hasattr(t, "model_dump")
820
+ else dict(t)
821
+ for t in tools
822
+ ]
823
+ },
824
+ request_id=request_id,
825
+ )
826
 
827
+ if mcp_server:
828
+ tools = []
829
+ for _tool_name, tool in get_server_tools(mcp_server).items():
830
+ tools.append(
831
+ {
832
+ "name": tool.name,
833
+ "description": tool.description or "",
834
+ "inputSchema": tool.parameters or {},
835
+ }
836
+ )
837
+ return JsonRpcResponse.success(
838
+ result={"tools": tools},
839
+ request_id=request_id,
840
+ )
841
+
842
+ return JsonRpcResponse.error_response(
843
+ JsonRpcErrorCode.INTERNAL_ERROR,
844
+ "MCP server not available",
845
  request_id=request_id,
846
  )
847
 
848
  elif method == McpMethod.TOOLS_CALL:
 
849
  tool_name = params.get("name")
850
  arguments = params.get("arguments", {})
851
 
852
+ if mcp_client is None and mcp_server is None:
853
  return JsonRpcResponse.error_response(
854
  JsonRpcErrorCode.INTERNAL_ERROR,
855
  "Environment does not support MCP",
 
858
 
859
  if not tool_name:
860
  return JsonRpcResponse.error_response(
861
+ JsonRpcErrorCode.INVALID_PARAMS,
862
  "Missing 'name' in params",
863
  request_id=request_id,
864
  )
865
 
866
+ if mcp_client:
867
+ if managed_session_id and mcp_client.is_connected():
868
+ # Session-managed with live transport.
869
+ result = await mcp_client.call_tool(
870
+ name=tool_name, arguments=arguments
871
+ )
872
+ elif callable(mcp_session_factory):
873
+ # Stateless request, or session-managed but the
874
+ # background transport was lost: (re-)open.
875
+ mcp_session_cm = cast(
876
+ AsyncContextManager[Any], mcp_session_factory()
877
+ )
878
+ async with mcp_session_cm:
879
+ result = await mcp_client.call_tool(
880
+ name=tool_name, arguments=arguments
881
+ )
882
+ else:
883
+ async with mcp_client:
884
+ result = await mcp_client.call_tool(
885
+ name=tool_name, arguments=arguments
886
+ )
887
+ elif mcp_server:
888
+ server_tools = get_server_tools(mcp_server)
889
+ if tool_name in server_tools:
890
+ tool = server_tools[tool_name]
891
+ if inspect.iscoroutinefunction(tool.fn):
892
+ result = await tool.fn(**arguments)
893
+ else:
894
+ result = tool.fn(**arguments)
895
+ else:
896
+ return JsonRpcResponse.error_response(
897
+ JsonRpcErrorCode.INVALID_PARAMS,
898
+ f"Tool not found: {tool_name}",
899
+ request_id=request_id,
900
+ )
901
+ else:
902
+ return JsonRpcResponse.error_response(
903
+ JsonRpcErrorCode.INTERNAL_ERROR,
904
+ "MCP server not available",
905
+ request_id=request_id,
906
  )
907
 
908
  # Ensure result is JSON serializable
 
927
  request_id=request_id,
928
  )
929
  finally:
930
+ if managed_session_id:
931
+ self._update_session_activity(
932
+ managed_session_id,
933
+ increment_step=(method == McpMethod.TOOLS_CALL),
934
+ )
935
  if should_close:
936
  _env.close()
937
 
 
955
  try:
956
  # Create session with dedicated environment
957
  session_id, session_env = await self._create_session()
958
+ if session_env is None:
959
+ raise RuntimeError(
960
+ "Session environment not initialized for MCP websocket"
961
+ )
962
 
963
+ # If environment has an mcp_session context manager, hold it open
964
+ # for the lifetime of the websocket connection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
965
 
966
+ async with AsyncExitStack() as stack:
967
+ mcp_session_factory = getattr(session_env, "mcp_session", None)
968
+ if callable(mcp_session_factory):
969
+ mcp_session_cm = cast(
970
+ AsyncContextManager[Any], mcp_session_factory()
971
  )
972
+ await stack.enter_async_context(mcp_session_cm)
973
+
974
+ while True:
975
+ # Receive message from client
976
+ raw_message = await websocket.receive_text()
977
+
978
+ try:
979
+ jsonrpc_dict = json.loads(raw_message)
980
+ jsonrpc_request = JsonRpcRequest(**jsonrpc_dict)
981
+ except json.JSONDecodeError as e:
982
+ error_resp = JsonRpcResponse.error_response(
983
+ JsonRpcErrorCode.PARSE_ERROR,
984
+ f"Parse error: {e}",
985
+ )
986
+ await websocket.send_text(error_resp.model_dump_json())
987
+ continue
988
+ except ValidationError as e:
989
+ error_resp = JsonRpcResponse.error_response(
990
+ JsonRpcErrorCode.INVALID_REQUEST,
991
+ f"Invalid request: {e}",
992
+ )
993
+ await websocket.send_text(error_resp.model_dump_json())
994
+ continue
995
+
996
+ try:
997
+ # Call mcp_handler with session environment
998
+ response = await mcp_handler(
999
+ jsonrpc_request,
1000
+ session_env=session_env,
1001
+ session_id=session_id,
1002
+ )
1003
+ await websocket.send_text(response.model_dump_json())
1004
+ except Exception as e:
1005
+ error_resp = JsonRpcResponse.error_response(
1006
+ JsonRpcErrorCode.INTERNAL_ERROR,
1007
+ str(e),
1008
+ request_id=jsonrpc_request.id,
1009
+ )
1010
+ await websocket.send_text(error_resp.model_dump_json())
1011
 
1012
  except WebSocketDisconnect:
1013
  pass
 
1266
  JsonRpcErrorCode.PARSE_ERROR
1267
  ).model_dump()
1268
 
1269
+ response = await mcp_handler(request)
1270
+ return response.model_dump()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1271
 
1272
  # Register WebSocket endpoint for persistent sessions
1273
  @app.websocket("/ws")
 
1289
  try:
1290
  # Create session with dedicated environment
1291
  session_id, session_env = await self._create_session()
1292
+ if session_env is None:
1293
+ raise RuntimeError(
1294
+ "Session environment not initialized for websocket"
1295
+ )
1296
 
1297
+ # Keep MCP session open for entire websocket lifetime
1298
+ # (avoids reconnect overhead on every message)
 
1299
 
1300
+ async with AsyncExitStack() as stack:
1301
+ mcp_session_factory = getattr(session_env, "mcp_session", None)
1302
+ if callable(mcp_session_factory):
1303
+ mcp_session_cm = cast(
1304
+ AsyncContextManager[Any], mcp_session_factory()
 
 
 
1305
  )
1306
+ await stack.enter_async_context(mcp_session_cm)
1307
+
1308
+ while True:
1309
+ # Receive message from client
1310
+ raw_message = await websocket.receive_text()
1311
+
1312
+ try:
1313
+ message_dict = json.loads(raw_message)
1314
+ except json.JSONDecodeError as e:
1315
+ error_resp = WSErrorResponse(
1316
+ data={
1317
+ "message": f"Invalid JSON: {e}",
1318
+ "code": WSErrorCode.INVALID_JSON,
1319
+ }
1320
+ )
1321
+ await websocket.send_text(error_resp.model_dump_json())
1322
+ continue
1323
 
1324
+ msg_type = message_dict.get("type", "")
 
 
 
1325
 
1326
+ try:
1327
+ match msg_type:
1328
+ case "reset":
1329
+ msg = WSResetMessage(**message_dict)
1330
 
1331
+ is_async = (
1332
+ session_env.reset_async.__func__
1333
+ is not Environment.reset_async
 
 
1334
  )
 
 
 
 
 
 
 
 
 
 
 
 
1335
 
1336
+ if is_async:
1337
+ sig = inspect.signature(session_env.reset_async)
1338
+ valid_kwargs = self._get_valid_kwargs(
1339
+ sig, msg.data
1340
+ )
1341
+ observation = await session_env.reset_async(
1342
+ **valid_kwargs
1343
+ )
1344
+ else:
1345
+ sig = inspect.signature(session_env.reset)
1346
+ valid_kwargs = self._get_valid_kwargs(
1347
+ sig, msg.data
1348
+ )
1349
+ observation = (
1350
+ await self._run_in_session_executor(
1351
+ session_id,
1352
+ session_env.reset,
1353
+ **valid_kwargs,
1354
+ )
1355
+ )
1356
+
1357
+ self._update_session_activity(session_id)
1358
+
1359
+ response = WSObservationResponse(
1360
+ data=serialize_observation(observation),
1361
+ )
1362
 
1363
+ case "step":
1364
+ msg = WSStepMessage(**message_dict)
1365
+ action = deserialize_action(
1366
+ msg.data, self.action_cls
1367
+ )
1368
 
1369
+ is_async = (
1370
+ session_env.step_async.__func__
1371
+ is not Environment.step_async
 
 
1372
  )
1373
 
1374
+ if is_async:
1375
+ observation = await session_env.step_async(
1376
+ action
1377
+ )
1378
+ else:
1379
+ observation = (
1380
+ await self._run_in_session_executor(
1381
+ session_id, session_env.step, action
1382
+ )
1383
+ )
1384
+
1385
+ self._update_session_activity(
1386
+ session_id, increment_step=True
1387
+ )
1388
 
1389
+ response = WSObservationResponse(
1390
+ data=serialize_observation(observation)
1391
+ )
1392
 
1393
+ case "state":
1394
+ msg = WSStateMessage(**message_dict)
1395
+ state = session_env.state
1396
+ if hasattr(state, "model_dump"):
1397
+ state_data = state.model_dump()
1398
+ else:
1399
+ state_data = dict(state) if state else {}
1400
+
1401
+ response = WSStateResponse(data=state_data)
1402
+
1403
+ case "close":
1404
+ msg = WSCloseMessage(**message_dict)
1405
+ break
1406
+
1407
+ case "mcp":
1408
+ msg = WSMCPMessage(**message_dict)
1409
+ try:
1410
+ rpc_request = JsonRpcRequest(**msg.data)
1411
+ except (ValidationError, Exception) as e:
1412
+ rpc_response = JsonRpcResponse.error_response(
1413
+ JsonRpcErrorCode.INVALID_REQUEST,
1414
+ f"Invalid request: {e}",
1415
+ )
1416
+ else:
1417
+ rpc_response = await mcp_handler(
1418
+ rpc_request,
1419
+ session_env=session_env,
1420
+ session_id=session_id,
1421
+ )
1422
+ response = WSMCPResponse(
1423
+ data=rpc_response.model_dump()
1424
  )
1425
+
1426
+ case _:
1427
+ response = WSErrorResponse(
1428
+ data={
1429
+ "message": f"Unknown message type: {msg_type}",
1430
+ "code": WSErrorCode.UNKNOWN_TYPE,
1431
+ }
1432
  )
 
 
 
 
 
 
 
 
 
1433
 
1434
+ await websocket.send_text(response.model_dump_json())
1435
 
1436
+ except ValidationError as e:
1437
+ error_resp = WSErrorResponse(
1438
+ data={
1439
+ "message": "Invalid message",
1440
+ "code": WSErrorCode.VALIDATION_ERROR,
1441
+ "errors": e.errors(),
1442
+ }
1443
+ )
1444
+ await websocket.send_text(error_resp.model_dump_json())
1445
+ except Exception as e:
1446
+ error_resp = WSErrorResponse(
1447
+ data={
1448
+ "message": str(e),
1449
+ "code": WSErrorCode.EXECUTION_ERROR,
1450
+ }
1451
+ )
1452
+ await websocket.send_text(error_resp.model_dump_json())
1453
 
1454
  except WebSocketDisconnect:
1455
  pass
 
1531
  from .web_interface import create_web_interface_app
1532
 
1533
  return create_web_interface_app(
1534
+ cast(Any, env),
1535
  action_cls,
1536
  observation_cls,
1537
  env_name,
src/core/env_server/mcp_environment.py CHANGED
@@ -56,6 +56,7 @@ import asyncio
56
  import inspect
57
  from abc import abstractmethod
58
  from collections import defaultdict
 
59
  from typing import Any, Callable, Dict, Optional
60
 
61
  from fastmcp import Client
@@ -164,6 +165,52 @@ class MCPEnvironment(Environment):
164
  # Track tool schemas for list_tools: {tool_name: {mode: schema}}
165
  self._mode_tool_schemas = defaultdict(dict)
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  @property
168
  def supports_code_mode(self) -> bool:
169
  """Check if this environment supports code mode (execute_code)."""
@@ -292,7 +339,8 @@ class MCPEnvironment(Environment):
292
 
293
  # If mode is None, register with FastMCP as usual
294
  if mode is None:
295
- decorated_func = self.mcp_server.tool()(func)
 
296
  self._mode_tools[tool_name][None] = func
297
  return decorated_func
298
 
@@ -372,24 +420,49 @@ class MCPEnvironment(Environment):
372
  return self._step_impl(action, timeout_s=timeout_s, **kwargs)
373
 
374
  def _handle_list_tools(self) -> ListToolsObservation:
 
 
 
 
375
  """
376
- Handle a ListToolsAction by querying the MCP server.
377
 
378
  Returns:
379
- ListToolsObservation containing all available tools with their
380
- names, descriptions, and input schemas, filtered by current mode.
381
  """
382
- try:
383
- # Get current mode
384
- current_mode = getattr(self, "_mode", None)
385
 
386
- # Start with tools from FastMCP server (mode=None tools)
387
- tools_result = run_async_safely(self._async_list_tools())
 
 
 
 
 
 
 
388
 
389
- # Build list of Tool objects
390
- tools = []
 
391
 
392
- # Add FastMCP tools that are not mode-specific
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  for tool in tools_result:
394
  if tool.name not in self._mode_tool_schemas:
395
  tools.append(
@@ -401,11 +474,8 @@ class MCPEnvironment(Environment):
401
  else {},
402
  )
403
  )
404
-
405
- # Add mode-specific tools available in current mode
406
  for tool_name, mode_schemas in self._mode_tool_schemas.items():
407
  if None in mode_schemas:
408
- # Tool available in all modes
409
  schema = mode_schemas[None]
410
  tools.append(
411
  Tool(
@@ -415,7 +485,6 @@ class MCPEnvironment(Environment):
415
  )
416
  )
417
  elif current_mode in mode_schemas:
418
- # Tool available in current mode
419
  schema = mode_schemas[current_mode]
420
  tools.append(
421
  Tool(
@@ -424,65 +493,30 @@ class MCPEnvironment(Environment):
424
  input_schema=schema["input_schema"],
425
  )
426
  )
427
-
428
  return ListToolsObservation(tools=tools)
429
-
430
  except Exception as e:
431
- # Return an observation with error in metadata
432
  return ListToolsObservation(
433
  tools=[],
434
- metadata={
435
- "error": str(e),
436
- "error_type": "list_tools_failed",
437
- },
438
  )
439
 
440
- async def _async_list_tools(self) -> list:
441
- """
442
- Async helper to list tools from the MCP client.
443
-
444
- Returns:
445
- List of tool objects from the MCP server.
446
- """
447
- async with self.mcp_client:
448
- return await self.mcp_client.list_tools()
449
-
450
- def _handle_call_tool(
451
  self,
452
  action: CallToolAction,
453
  timeout_s: Optional[float] = None,
454
  ) -> CallToolObservation:
455
- """
456
- Handle a CallToolAction by invoking the specified tool.
457
-
458
- Args:
459
- action: The CallToolAction containing tool_name and arguments.
460
- timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s).
461
-
462
- Returns:
463
- CallToolObservation with the tool's result or an error.
464
- """
465
  timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
466
-
467
- # Check if this is a mode-specific tool
468
  tool_name = action.tool_name
469
  current_mode = getattr(self, "_mode", None)
470
 
471
  if tool_name in self._mode_tools:
472
  mode_info = self._mode_tools[tool_name]
473
-
474
- # Check if tool is available in current mode
475
- # Tool is available if:
476
- # 1. It has a None mode (available in all modes), OR
477
- # 2. It has an implementation for the current mode
478
  if None in mode_info:
479
- # Use the mode-agnostic version
480
  func = mode_info[None]
481
  elif current_mode in mode_info:
482
- # Use the mode-specific version
483
  func = mode_info[current_mode]
484
  else:
485
- # Tool not available in current mode
486
  return CallToolObservation(
487
  tool_name=tool_name,
488
  result=None,
@@ -491,16 +525,11 @@ class MCPEnvironment(Environment):
491
  message=f"Tool '{tool_name}' not available in {current_mode} mode",
492
  ),
493
  )
494
-
495
- # Call the mode-specific function directly
496
  try:
497
- # Check if function is async and await if necessary
498
  if inspect.iscoroutinefunction(func):
499
- result = run_async_safely(func(**action.arguments))
500
  else:
501
  result = func(**action.arguments)
502
-
503
- # Wrap result in CallToolResult format to match FastMCP behavior
504
  return CallToolObservation(
505
  tool_name=tool_name,
506
  result=CallToolResult(
@@ -521,22 +550,12 @@ class MCPEnvironment(Environment):
521
  ),
522
  )
523
 
524
- # Not a mode-specific tool, use FastMCP
525
  try:
526
- # Run the async call_tool with timeout
527
- # Use run_async_safely to handle both sync and async contexts
528
- result = run_async_safely(
529
- asyncio.wait_for(
530
- self._async_call_tool(action.tool_name, action.arguments),
531
- timeout=timeout,
532
- )
533
- )
534
-
535
- return CallToolObservation(
536
- tool_name=action.tool_name,
537
- result=result,
538
  )
539
-
540
  except asyncio.TimeoutError:
541
  return CallToolObservation(
542
  tool_name=action.tool_name,
@@ -546,11 +565,8 @@ class MCPEnvironment(Environment):
546
  message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
547
  ),
548
  )
549
-
550
  except Exception as e:
551
  error_message = str(e)
552
-
553
- # Determine error type based on the exception
554
  if (
555
  "not found" in error_message.lower()
556
  or "unknown tool" in error_message.lower()
@@ -563,29 +579,34 @@ class MCPEnvironment(Environment):
563
  error_type = ToolErrorType.INVALID_ARGS
564
  else:
565
  error_type = ToolErrorType.EXECUTION_ERROR
566
-
567
  return CallToolObservation(
568
  tool_name=action.tool_name,
569
  result=None,
570
- error=ToolError(
571
- error_type=error_type,
572
- message=error_message,
573
- ),
574
  )
575
 
576
- async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
 
 
 
 
 
577
  """
578
- Async helper to call a tool on the MCP server.
579
 
580
- Args:
581
- tool_name: Name of the tool to invoke.
582
- arguments: Dictionary of arguments to pass to the tool.
583
-
584
- Returns:
585
- The result from the tool execution.
586
  """
587
- async with self.mcp_client:
588
- return await self.mcp_client.call_tool(tool_name, arguments)
 
 
 
 
 
 
 
589
 
590
  @abstractmethod
591
  def _step_impl(
 
56
  import inspect
57
  from abc import abstractmethod
58
  from collections import defaultdict
59
+ from contextlib import asynccontextmanager
60
  from typing import Any, Callable, Dict, Optional
61
 
62
  from fastmcp import Client
 
165
  # Track tool schemas for list_tools: {tool_name: {mode: schema}}
166
  self._mode_tool_schemas = defaultdict(dict)
167
 
168
+ def _require_mcp_client(self) -> Any:
169
+ """Return MCP client or raise if environment has been closed."""
170
+ if self.mcp_client is None:
171
+ raise RuntimeError("MCP client is not available; environment is closed")
172
+ return self.mcp_client
173
+
174
+ def _require_mcp_server(self) -> Any:
175
+ """Return MCP server or raise if environment has been closed."""
176
+ if self.mcp_server is None:
177
+ raise RuntimeError("MCP server is not available; environment is closed")
178
+ return self.mcp_server
179
+
180
+ @asynccontextmanager
181
+ async def mcp_session(self):
182
+ """
183
+ Context manager for MCP client sessions.
184
+
185
+ This wrapper serves two purposes:
186
+
187
+ 1. **Null guard** — raises a clear error if ``close()`` has already
188
+ been called (``mcp_client`` is ``None``).
189
+
190
+ 2. **AsyncExitStack adapter** — FastMCP's ``Client.__aenter__``
191
+ creates a background ``asyncio.Task`` for session management.
192
+ When entered directly via ``AsyncExitStack`` in the HTTP session
193
+ path (``_create_session``), this task can be cancelled by ASGI
194
+ harnesses (e.g. Starlette ``TestClient``) between requests,
195
+ corrupting session state. Wrapping in an ``asynccontextmanager``
196
+ generator isolates the task lifecycle: the generator frame keeps
197
+ ``async with client:`` suspended at ``yield``, so cleanup only
198
+ runs when the stack explicitly closes the generator — not when
199
+ the event loop cancels orphaned tasks.
200
+
201
+ Delegates to FastMCP's ``Client`` context manager which is
202
+ reentrant: the first entry opens the transport and subsequent
203
+ (nested) entries simply increment an internal reference counter.
204
+ The transport is closed only when the outermost context exits.
205
+
206
+ No external lock is needed because ``Client._connect`` /
207
+ ``Client._disconnect`` already serialise connection state changes
208
+ through their own ``anyio.Lock``.
209
+ """
210
+ client = self._require_mcp_client()
211
+ async with client:
212
+ yield client
213
+
214
  @property
215
  def supports_code_mode(self) -> bool:
216
  """Check if this environment supports code mode (execute_code)."""
 
339
 
340
  # If mode is None, register with FastMCP as usual
341
  if mode is None:
342
+ mcp_server = self._require_mcp_server()
343
+ decorated_func = mcp_server.tool()(func)
344
  self._mode_tools[tool_name][None] = func
345
  return decorated_func
346
 
 
420
  return self._step_impl(action, timeout_s=timeout_s, **kwargs)
421
 
422
  def _handle_list_tools(self) -> ListToolsObservation:
423
+ """Sync wrapper — delegates to the canonical async implementation."""
424
+ return run_async_safely(self._async_handle_list_tools())
425
+
426
+ async def _async_list_tools(self) -> list:
427
  """
428
+ Async helper to list tools from the MCP client.
429
 
430
  Returns:
431
+ List of tool objects from the MCP server.
 
432
  """
433
+ async with self.mcp_session() as client:
434
+ return await client.list_tools()
 
435
 
436
+ def _handle_call_tool(
437
+ self,
438
+ action: CallToolAction,
439
+ timeout_s: Optional[float] = None,
440
+ ) -> CallToolObservation:
441
+ """Sync wrapper — delegates to the canonical async implementation."""
442
+ return run_async_safely(
443
+ self._async_handle_call_tool(action, timeout_s=timeout_s)
444
+ )
445
 
446
+ async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
447
+ """
448
+ Async helper to call a tool on the MCP server.
449
 
450
+ Args:
451
+ tool_name: Name of the tool to invoke.
452
+ arguments: Dictionary of arguments to pass to the tool.
453
+
454
+ Returns:
455
+ The result from the tool execution.
456
+ """
457
+ async with self.mcp_session() as client:
458
+ return await client.call_tool(tool_name, arguments)
459
+
460
+ async def _async_handle_list_tools(self) -> ListToolsObservation:
461
+ """Async version of _handle_list_tools — avoids run_async_safely."""
462
+ try:
463
+ current_mode = getattr(self, "_mode", None)
464
+ tools_result = await self._async_list_tools()
465
+ tools = []
466
  for tool in tools_result:
467
  if tool.name not in self._mode_tool_schemas:
468
  tools.append(
 
474
  else {},
475
  )
476
  )
 
 
477
  for tool_name, mode_schemas in self._mode_tool_schemas.items():
478
  if None in mode_schemas:
 
479
  schema = mode_schemas[None]
480
  tools.append(
481
  Tool(
 
485
  )
486
  )
487
  elif current_mode in mode_schemas:
 
488
  schema = mode_schemas[current_mode]
489
  tools.append(
490
  Tool(
 
493
  input_schema=schema["input_schema"],
494
  )
495
  )
 
496
  return ListToolsObservation(tools=tools)
 
497
  except Exception as e:
 
498
  return ListToolsObservation(
499
  tools=[],
500
+ metadata={"error": str(e), "error_type": "list_tools_failed"},
 
 
 
501
  )
502
 
503
+ async def _async_handle_call_tool(
 
 
 
 
 
 
 
 
 
 
504
  self,
505
  action: CallToolAction,
506
  timeout_s: Optional[float] = None,
507
  ) -> CallToolObservation:
508
+ """Async version of _handle_call_tool — avoids run_async_safely."""
 
 
 
 
 
 
 
 
 
509
  timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
 
 
510
  tool_name = action.tool_name
511
  current_mode = getattr(self, "_mode", None)
512
 
513
  if tool_name in self._mode_tools:
514
  mode_info = self._mode_tools[tool_name]
 
 
 
 
 
515
  if None in mode_info:
 
516
  func = mode_info[None]
517
  elif current_mode in mode_info:
 
518
  func = mode_info[current_mode]
519
  else:
 
520
  return CallToolObservation(
521
  tool_name=tool_name,
522
  result=None,
 
525
  message=f"Tool '{tool_name}' not available in {current_mode} mode",
526
  ),
527
  )
 
 
528
  try:
 
529
  if inspect.iscoroutinefunction(func):
530
+ result = await func(**action.arguments)
531
  else:
532
  result = func(**action.arguments)
 
 
533
  return CallToolObservation(
534
  tool_name=tool_name,
535
  result=CallToolResult(
 
550
  ),
551
  )
552
 
 
553
  try:
554
+ result = await asyncio.wait_for(
555
+ self._async_call_tool(action.tool_name, action.arguments),
556
+ timeout=timeout,
 
 
 
 
 
 
 
 
 
557
  )
558
+ return CallToolObservation(tool_name=action.tool_name, result=result)
559
  except asyncio.TimeoutError:
560
  return CallToolObservation(
561
  tool_name=action.tool_name,
 
565
  message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
566
  ),
567
  )
 
568
  except Exception as e:
569
  error_message = str(e)
 
 
570
  if (
571
  "not found" in error_message.lower()
572
  or "unknown tool" in error_message.lower()
 
579
  error_type = ToolErrorType.INVALID_ARGS
580
  else:
581
  error_type = ToolErrorType.EXECUTION_ERROR
 
582
  return CallToolObservation(
583
  tool_name=action.tool_name,
584
  result=None,
585
+ error=ToolError(error_type=error_type, message=error_message),
 
 
 
586
  )
587
 
588
+ async def step_async(
589
+ self,
590
+ action: Action,
591
+ timeout_s: Optional[float] = None,
592
+ **kwargs: Any,
593
+ ) -> Observation:
594
  """
595
+ Async step that routes MCP actions without going through run_async_safely.
596
 
597
+ The WebSocket handler calls this directly on the outer event loop, where
598
+ the MCP session is already open, avoiding the thread/event-loop deadlock
599
+ that occurs when the sync step() path is used via run_in_executor.
 
 
 
600
  """
601
+ if isinstance(action, ListToolsAction):
602
+ return await self._async_handle_list_tools()
603
+ elif isinstance(action, CallToolAction):
604
+ return await self._async_handle_call_tool(action, timeout_s=timeout_s)
605
+ else:
606
+ loop = asyncio.get_event_loop()
607
+ return await loop.run_in_executor(
608
+ None, lambda: self._step_impl(action, timeout_s=timeout_s, **kwargs)
609
+ )
610
 
611
  @abstractmethod
612
  def _step_impl(
src/core/env_server/serialization.py CHANGED
@@ -14,14 +14,28 @@ HTTP server and web interface implementations.
14
 
15
  from typing import Any, Dict, Type
16
 
 
17
  from .types import Action, Observation
18
 
 
 
 
 
 
 
 
 
 
19
 
20
  def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
21
  """
22
  Convert JSON dict to Action instance using Pydantic validation.
23
 
24
- This is a basic deserialization that works for most environments.
 
 
 
 
25
  For special cases (e.g., tensor fields, custom type conversions),
26
  use deserialize_action_with_preprocessing().
27
 
@@ -38,6 +52,17 @@ def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) ->
38
  Note:
39
  This uses Pydantic's model_validate() for automatic validation.
40
  """
 
 
 
 
 
 
 
 
 
 
 
41
  return action_cls.model_validate(action_data)
42
 
43
 
@@ -62,6 +87,15 @@ def deserialize_action_with_preprocessing(
62
  Raises:
63
  ValidationError: If action_data is invalid for the action class
64
  """
 
 
 
 
 
 
 
 
 
65
  processed_data = {}
66
 
67
  for key, value in action_data.items():
 
14
 
15
  from typing import Any, Dict, Type
16
 
17
+ from .mcp_types import CallToolAction, ListToolsAction
18
  from .types import Action, Observation
19
 
20
+ # MCP action types keyed by their "type" discriminator value.
21
+ # These are checked before the environment's own action_cls so that
22
+ # ListToolsAction / CallToolAction payloads are never rejected by an
23
+ # unrelated Pydantic model.
24
+ _MCP_ACTION_TYPES: Dict[str, Type[Action]] = {
25
+ "list_tools": ListToolsAction,
26
+ "call_tool": CallToolAction,
27
+ }
28
+
29
 
30
  def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
31
  """
32
  Convert JSON dict to Action instance using Pydantic validation.
33
 
34
+ MCP action types (``list_tools``, ``call_tool``) are recognised
35
+ automatically via the ``"type"`` discriminator field, regardless of
36
+ the environment's configured ``action_cls``. All other payloads
37
+ fall through to ``action_cls.model_validate()``.
38
+
39
  For special cases (e.g., tensor fields, custom type conversions),
40
  use deserialize_action_with_preprocessing().
41
 
 
52
  Note:
53
  This uses Pydantic's model_validate() for automatic validation.
54
  """
55
+ # Route MCP action types before falling through to the env action_cls.
56
+ # Only intercept when action_cls is the generic Action base or itself an
57
+ # MCP type (i.e. the server hosts an MCP environment). This avoids
58
+ # silently bypassing env-specific validation for non-MCP environments
59
+ # that happen to use "call_tool" / "list_tools" as a type discriminator.
60
+ action_type = action_data.get("type")
61
+ if action_type in _MCP_ACTION_TYPES:
62
+ mcp_cls = _MCP_ACTION_TYPES[action_type]
63
+ if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values():
64
+ return mcp_cls.model_validate(action_data)
65
+
66
  return action_cls.model_validate(action_data)
67
 
68
 
 
87
  Raises:
88
  ValidationError: If action_data is invalid for the action class
89
  """
90
+ # Route MCP action types before preprocessing (they don't need it).
91
+ # Same guard as deserialize_action: only intercept when action_cls is
92
+ # the generic Action base or itself an MCP type.
93
+ action_type = action_data.get("type")
94
+ if action_type in _MCP_ACTION_TYPES:
95
+ mcp_cls = _MCP_ACTION_TYPES[action_type]
96
+ if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values():
97
+ return mcp_cls.model_validate(action_data)
98
+
99
  processed_data = {}
100
 
101
  for key, value in action_data.items():
src/core/env_server/web_interface.py CHANGED
@@ -15,13 +15,15 @@ option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var.
15
  from __future__ import annotations
16
 
17
  import asyncio
 
18
  import json
19
  from concurrent.futures import ThreadPoolExecutor
20
  from datetime import datetime
21
  from typing import Any, Callable, Dict, List, Optional, Type
22
 
23
  import gradio as gr
24
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 
25
  from pydantic import BaseModel, ConfigDict, Field
26
 
27
  from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
@@ -269,6 +271,28 @@ class WebInterfaceManager:
269
  # Thread pool for running sync code (e.g., Playwright sync API) in async context
270
  self._executor = ThreadPoolExecutor(max_workers=1)
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
273
  """Run a synchronous function in the thread pool executor.
274
 
@@ -317,11 +341,24 @@ class WebInterfaceManager:
317
  for client in disconnected_clients:
318
  self.connected_clients.remove(client)
319
 
320
- async def reset_environment(self) -> Dict[str, Any]:
 
 
321
  """Reset the environment and update state."""
322
- # Run sync reset in thread pool to avoid blocking event loop
323
- # and to support environments using sync libraries (e.g., Playwright)
324
- observation: Observation = await self._run_sync_in_thread_pool(self.env.reset)
 
 
 
 
 
 
 
 
 
 
 
325
  state: State = self.env.state
326
 
327
  # Serialize observation once using shared utility
@@ -428,6 +465,16 @@ def create_web_interface_app(
428
  web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
429
 
430
  # Web API routes first (so they take precedence over Gradio mount at /web)
 
 
 
 
 
 
 
 
 
 
431
  @app.get("/web/metadata")
432
  async def web_metadata():
433
  """Get environment metadata."""
@@ -449,9 +496,9 @@ def create_web_interface_app(
449
  await web_manager.disconnect_websocket(websocket)
450
 
451
  @app.post("/web/reset")
452
- async def web_reset():
453
  """Reset endpoint for web interface."""
454
- return await web_manager.reset_environment()
455
 
456
  @app.post("/web/step")
457
  async def web_step(request: Dict[str, Any]):
@@ -475,7 +522,13 @@ def create_web_interface_app(
475
  @app.get("/web/state")
476
  async def web_state():
477
  """State endpoint for web interface."""
478
- return web_manager.get_state()
 
 
 
 
 
 
479
 
480
  action_fields = _extract_action_fields(action_cls)
481
  is_chat_env = _is_chat_env(action_cls)
@@ -505,7 +558,7 @@ def create_web_interface_app(
505
  )
506
  gradio_blocks = gr.TabbedInterface(
507
  [default_blocks, custom_blocks],
508
- tab_names=["Playground", "Visualization"],
509
  title=get_gradio_display_title(metadata),
510
  )
511
  else:
 
15
  from __future__ import annotations
16
 
17
  import asyncio
18
+ import inspect
19
  import json
20
  from concurrent.futures import ThreadPoolExecutor
21
  from datetime import datetime
22
  from typing import Any, Callable, Dict, List, Optional, Type
23
 
24
  import gradio as gr
25
+ from fastapi import Body, FastAPI, HTTPException, status, WebSocket, WebSocketDisconnect
26
+ from fastapi.responses import RedirectResponse
27
  from pydantic import BaseModel, ConfigDict, Field
28
 
29
  from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
 
271
  # Thread pool for running sync code (e.g., Playwright sync API) in async context
272
  self._executor = ThreadPoolExecutor(max_workers=1)
273
 
274
+ @staticmethod
275
+ def _get_valid_kwargs(
276
+ sig: inspect.Signature,
277
+ kwargs: Dict[str, Any],
278
+ skip_params: Optional[set[str]] = None,
279
+ ) -> Dict[str, Any]:
280
+ """Filter kwargs to only those accepted by the target function."""
281
+ skip_params = skip_params or set()
282
+ valid_kwargs: Dict[str, Any] = {}
283
+ has_var_kwargs = any(
284
+ param.kind == inspect.Parameter.VAR_KEYWORD
285
+ for param in sig.parameters.values()
286
+ )
287
+
288
+ for key, value in kwargs.items():
289
+ if key in skip_params:
290
+ continue
291
+ if key in sig.parameters or has_var_kwargs:
292
+ valid_kwargs[key] = value
293
+
294
+ return valid_kwargs
295
+
296
  async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
297
  """Run a synchronous function in the thread pool executor.
298
 
 
341
  for client in disconnected_clients:
342
  self.connected_clients.remove(client)
343
 
344
+ async def reset_environment(
345
+ self, reset_kwargs: Optional[Dict[str, Any]] = None
346
+ ) -> Dict[str, Any]:
347
  """Reset the environment and update state."""
348
+ reset_kwargs = reset_kwargs or {}
349
+
350
+ is_async = self.env.reset_async.__func__ is not Environment.reset_async
351
+ sig = inspect.signature(self.env.reset_async if is_async else self.env.reset)
352
+ valid_kwargs = self._get_valid_kwargs(sig, reset_kwargs)
353
+
354
+ if is_async:
355
+ observation = await self.env.reset_async(**valid_kwargs)
356
+ else:
357
+ # Run sync reset in thread pool to avoid blocking event loop
358
+ # and to support environments using sync libraries (e.g., Playwright)
359
+ observation = await self._run_sync_in_thread_pool(
360
+ self.env.reset, **valid_kwargs
361
+ )
362
  state: State = self.env.state
363
 
364
  # Serialize observation once using shared utility
 
465
  web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
466
 
467
  # Web API routes first (so they take precedence over Gradio mount at /web)
468
+ @app.get("/", include_in_schema=False)
469
+ async def web_root():
470
+ """Redirect the app root to the Gradio interface."""
471
+ return RedirectResponse(url="/web/")
472
+
473
+ @app.get("/web", include_in_schema=False)
474
+ async def web_root_no_slash():
475
+ """Redirect /web to /web/ for mounted Gradio deployments behind proxies."""
476
+ return RedirectResponse(url="/web/")
477
+
478
  @app.get("/web/metadata")
479
  async def web_metadata():
480
  """Get environment metadata."""
 
496
  await web_manager.disconnect_websocket(websocket)
497
 
498
  @app.post("/web/reset")
499
+ async def web_reset(request: Optional[Dict[str, Any]] = Body(default=None)):
500
  """Reset endpoint for web interface."""
501
+ return await web_manager.reset_environment(request)
502
 
503
  @app.post("/web/step")
504
  async def web_step(request: Dict[str, Any]):
 
522
  @app.get("/web/state")
523
  async def web_state():
524
  """State endpoint for web interface."""
525
+ try:
526
+ return web_manager.get_state()
527
+ except RuntimeError as exc:
528
+ raise HTTPException(
529
+ status_code=status.HTTP_409_CONFLICT,
530
+ detail=str(exc),
531
+ ) from exc
532
 
533
  action_fields = _extract_action_fields(action_cls)
534
  is_chat_env = _is_chat_env(action_cls)
 
558
  )
559
  gradio_blocks = gr.TabbedInterface(
560
  [default_blocks, custom_blocks],
561
+ tab_names=["Playground", "Custom"],
562
  title=get_gradio_display_title(metadata),
563
  )
564
  else:
src/core/mcp_client.py CHANGED
@@ -52,6 +52,7 @@ Example (sync wrapper):
52
  ... result = env.call_tool("echo_message", message="Hello!")
53
  """
54
 
 
55
  from typing import Any, Dict, List, Optional
56
 
57
  from .client_types import StepResult
@@ -118,6 +119,66 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
118
  )
119
  self._tools_cache: Optional[List[Tool]] = None
120
  self.use_production_mode = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  async def list_tools(self, use_cache: bool = True) -> List[Tool]:
123
  """
@@ -138,26 +199,18 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
138
  if use_cache and self._tools_cache is not None:
139
  return self._tools_cache
140
 
141
- # Use production mode HTTP endpoint if enabled
142
- if self.use_production_mode:
143
- import requests
144
-
145
- # Convert ws:// URL to http:// URL
146
- url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://")
147
- # Remove /ws suffix if present and add /mcp
148
- url = url.rstrip("/ws").rstrip("/") + "/mcp"
149
-
150
  try:
151
- response = requests.post(
152
- url,
153
- json={
154
- "jsonrpc": "2.0",
155
- "method": "tools/list",
156
- "params": {},
157
- "id": 1,
158
- },
159
  )
160
- data = response.json()
 
 
161
  if "result" in data and "tools" in data["result"]:
162
  tools = [
163
  Tool(
@@ -177,7 +230,12 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
177
  return []
178
 
179
  result = await self.step(ListToolsAction())
180
- self._tools_cache = result.observation.tools
 
 
 
 
 
181
  return self._tools_cache
182
 
183
  def _step_payload(self, action: Any) -> Dict[str, Any]:
@@ -251,6 +309,35 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
251
  step_count=payload.get("step_count", 0),
252
  )
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  class MCPToolClient(MCPClientBase):
256
  """
@@ -316,6 +403,26 @@ class MCPToolClient(MCPClientBase):
316
  >>> result = await env.call_tool("greet", name="Claude")
317
  >>> print(result) # "Hello, Claude!"
318
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  action = CallToolAction(tool_name=name, arguments=kwargs)
320
  result = await self.step(action)
321
  obs = result.observation
 
52
  ... result = env.call_tool("echo_message", message="Hello!")
53
  """
54
 
55
+ import asyncio
56
  from typing import Any, Dict, List, Optional
57
 
58
  from .client_types import StepResult
 
119
  )
120
  self._tools_cache: Optional[List[Tool]] = None
121
  self.use_production_mode = False
122
+ self._production_session_id: Optional[str] = None
123
+ self._production_session_lock = asyncio.Lock()
124
+ self._jsonrpc_request_id = 0
125
+ self._http_client: Optional[Any] = None # lazily-created httpx.AsyncClient
126
+
127
+ def _next_request_id(self) -> int:
128
+ """Generate a monotonically increasing JSON-RPC request id."""
129
+ self._jsonrpc_request_id += 1
130
+ return self._jsonrpc_request_id
131
+
132
+ def _production_mcp_url(self) -> str:
133
+ """Build HTTP MCP endpoint URL from the client's websocket URL."""
134
+ url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://")
135
+ if url.endswith("/ws"):
136
+ url = url[: -len("/ws")]
137
+ return url.rstrip("/") + "/mcp"
138
+
139
+ async def _get_http_client(self) -> Any:
140
+ """Return a shared httpx.AsyncClient, creating one lazily."""
141
+ if self._http_client is None:
142
+ import httpx
143
+
144
+ self._http_client = httpx.AsyncClient()
145
+ return self._http_client
146
+
147
+ async def _production_mcp_request(
148
+ self, method: str, params: Optional[Dict[str, Any]] = None
149
+ ) -> Dict[str, Any]:
150
+ """Send a JSON-RPC request to HTTP /mcp and return parsed JSON response."""
151
+ client = await self._get_http_client()
152
+ response = await client.post(
153
+ self._production_mcp_url(),
154
+ json={
155
+ "jsonrpc": "2.0",
156
+ "method": method,
157
+ "params": params or {},
158
+ "id": self._next_request_id(),
159
+ },
160
+ timeout=self._message_timeout,
161
+ )
162
+ response.raise_for_status()
163
+ return response.json()
164
+
165
+ async def _ensure_production_session(self) -> str:
166
+ """Create and cache a persistent HTTP MCP session id if needed."""
167
+ async with self._production_session_lock:
168
+ if self._production_session_id is not None:
169
+ return self._production_session_id
170
+
171
+ data = await self._production_mcp_request("openenv/session/create")
172
+ if "error" in data:
173
+ message = data.get("error", {}).get("message", "unknown error")
174
+ raise RuntimeError(f"Failed to create MCP session: {message}")
175
+
176
+ session_id = data.get("result", {}).get("session_id")
177
+ if not session_id:
178
+ raise RuntimeError("Failed to create MCP session: missing session_id")
179
+
180
+ self._production_session_id = session_id
181
+ return session_id
182
 
183
  async def list_tools(self, use_cache: bool = True) -> List[Tool]:
184
  """
 
199
  if use_cache and self._tools_cache is not None:
200
  return self._tools_cache
201
 
202
+ # Use production mode HTTP endpoint if enabled.
203
+ # Some tests instantiate with __new__ and skip __init__, so default missing flag to False.
204
+ if getattr(self, "use_production_mode", False):
 
 
 
 
 
 
205
  try:
206
+ session_id = await self._ensure_production_session()
207
+ data = await self._production_mcp_request(
208
+ "tools/list",
209
+ {"session_id": session_id},
 
 
 
 
210
  )
211
+ if "error" in data:
212
+ message = data.get("error", {}).get("message", "unknown error")
213
+ raise RuntimeError(f"list_tools failed: {message}")
214
  if "result" in data and "tools" in data["result"]:
215
  tools = [
216
  Tool(
 
230
  return []
231
 
232
  result = await self.step(ListToolsAction())
233
+ if isinstance(result.observation, ListToolsObservation):
234
+ self._tools_cache = result.observation.tools
235
+ return self._tools_cache
236
+
237
+ # Unexpected observation type; keep API stable with an empty tool list.
238
+ self._tools_cache = []
239
  return self._tools_cache
240
 
241
  def _step_payload(self, action: Any) -> Dict[str, Any]:
 
309
  step_count=payload.get("step_count", 0),
310
  )
311
 
312
+ async def close(self) -> None:
313
+ """
314
+ Close client resources.
315
+
316
+ In production MCP mode, this also closes the server-side persistent
317
+ MCP session (best effort) before closing websocket/provider resources.
318
+ """
319
+ if self._production_session_id is not None:
320
+ try:
321
+ await self._production_mcp_request(
322
+ "openenv/session/close",
323
+ {"session_id": self._production_session_id},
324
+ )
325
+ except Exception:
326
+ # Best effort cleanup - do not mask normal close behavior
327
+ pass
328
+ finally:
329
+ self._production_session_id = None
330
+
331
+ if self._http_client is not None:
332
+ try:
333
+ await self._http_client.aclose()
334
+ except Exception:
335
+ pass
336
+ finally:
337
+ self._http_client = None
338
+
339
+ await super().close()
340
+
341
 
342
  class MCPToolClient(MCPClientBase):
343
  """
 
403
  >>> result = await env.call_tool("greet", name="Claude")
404
  >>> print(result) # "Hello, Claude!"
405
  """
406
+ if getattr(self, "use_production_mode", False):
407
+ session_id = await self._ensure_production_session()
408
+ data = await self._production_mcp_request(
409
+ "tools/call",
410
+ {
411
+ "name": name,
412
+ "arguments": kwargs,
413
+ "session_id": session_id,
414
+ },
415
+ )
416
+
417
+ if "error" in data:
418
+ message = data.get("error", {}).get("message", "unknown error")
419
+ raise RuntimeError(f"Tool '{name}' failed: {message}")
420
+
421
+ result = data.get("result")
422
+ if isinstance(result, dict) and "data" in result:
423
+ return result["data"]
424
+ return result
425
+
426
  action = CallToolAction(tool_name=name, arguments=kwargs)
427
  result = await self.step(action)
428
  obs = result.observation
src/core/openenv/__init__.py CHANGED
@@ -14,10 +14,18 @@ __all__ = [
14
  "SyncEnvClient",
15
  ]
16
 
17
- try:
18
- __version__ = metadata.version("openenv") # type: ignore[arg-type]
19
- except metadata.PackageNotFoundError: # pragma: no cover - local dev
20
- __version__ = "0.0.0"
 
 
 
 
 
 
 
 
21
 
22
 
23
  _LAZY_MODULES = {
 
14
  "SyncEnvClient",
15
  ]
16
 
17
+
18
+ def _load_package_version() -> str:
19
+ """Resolve the installed distribution version for the OpenEnv package."""
20
+ for distribution_name in ("openenv-core", "openenv"):
21
+ try:
22
+ return metadata.version(distribution_name)
23
+ except metadata.PackageNotFoundError:
24
+ continue
25
+ return "0.0.0"
26
+
27
+
28
+ __version__ = _load_package_version()
29
 
30
 
31
  _LAZY_MODULES = {
src/core/openenv/cli/templates/openenv_env/pyproject.toml CHANGED
@@ -17,7 +17,7 @@ dependencies = [
17
  # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
  # install from github
19
  # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
- "openenv-core[core]>=0.2.1",
21
  # Environment-specific dependencies
22
  # Add all dependencies needed for your environment here
23
  # Examples:
 
17
  # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
  # install from github
19
  # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ "openenv-core[core]>=0.2.2",
21
  # Environment-specific dependencies
22
  # Add all dependencies needed for your environment here
23
  # Examples:
src/core/openenv/core/env_server/http_server.py CHANGED
@@ -16,11 +16,15 @@ from __future__ import annotations
16
  import asyncio
17
  import inspect
18
  import json
 
19
  import os
20
  import time
21
  import uuid
22
  from concurrent.futures import ThreadPoolExecutor
23
- from typing import Any, Callable, Dict, Optional, Type
 
 
 
24
 
25
  from fastapi import (
26
  Body,
@@ -204,8 +208,9 @@ class HTTPEnvServer:
204
  self.observation_cls = observation_cls
205
 
206
  # Session management for WebSocket connections
207
- self._sessions: Dict[str, Environment] = {}
208
  self._session_executors: Dict[str, ThreadPoolExecutor] = {}
 
209
  self._session_info: Dict[str, SessionInfo] = {}
210
  self._session_lock = asyncio.Lock()
211
 
@@ -213,6 +218,14 @@ class HTTPEnvServer:
213
  # This is needed for environments using sync libraries (e.g., Playwright)
214
  self._executor = ThreadPoolExecutor(max_workers=32)
215
 
 
 
 
 
 
 
 
 
216
  def _validate_concurrency_safety(self) -> None:
217
  """
218
  Validate that the environment supports the configured concurrency level.
@@ -321,12 +334,37 @@ class HTTPEnvServer:
321
  )
322
  raise EnvironmentFactoryError(factory_name) from e
323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  async with self._session_lock:
325
  self._sessions[session_id] = env
 
 
326
  self._session_info[session_id] = SessionInfo(
327
  session_id=session_id,
328
  created_at=current_time,
329
- last_activity_at=current_time,
330
  step_count=0,
331
  environment_type=type(env).__name__,
332
  )
@@ -343,8 +381,27 @@ class HTTPEnvServer:
343
  async with self._session_lock:
344
  env = self._sessions.pop(session_id, None)
345
  executor = self._session_executors.pop(session_id, None)
 
346
  self._session_info.pop(session_id, None)
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  # Run close() in the same executor where the env was created
349
  # This is required for thread-sensitive libraries like Playwright/greenlet
350
  if env is not None:
@@ -383,6 +440,51 @@ class HTTPEnvServer:
383
  if increment_step:
384
  self._session_info[session_id].step_count += 1
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
387
  """
388
  Get information about a specific session.
@@ -458,6 +560,20 @@ class HTTPEnvServer:
458
  f"Invalid mode: '{mode}'. Must be one of: {valid_modes}"
459
  )
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  # Helper function to handle reset endpoint
462
  async def reset_handler(
463
  request: ResetRequest = Body(default_factory=ResetRequest),
@@ -526,53 +642,214 @@ class HTTPEnvServer:
526
 
527
  # Helper function to handle MCP endpoint
528
  async def mcp_handler(
529
- request: JsonRpcRequest, session_env: Optional[Environment] = None
 
 
530
  ) -> JsonRpcResponse:
531
  """
532
  Handle MCP JSON-RPC requests.
533
 
534
- Supports tools/list and tools/call methods in JSON-RPC 2.0 format.
 
 
 
535
  """
536
  method = request.method
537
  request_id = request.id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
  # Use provided session environment or create temporary one
540
  if session_env is not None:
541
  _env = session_env
542
  should_close = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  else:
544
  _env = self._env_factory()
545
  should_close = True
546
  try:
 
 
 
 
547
  if method == McpMethod.TOOLS_LIST:
548
  # Check if environment is MCP-enabled
549
- if not hasattr(_env, "mcp_client"):
550
  return JsonRpcResponse.error_response(
551
  JsonRpcErrorCode.INTERNAL_ERROR,
552
  "Environment does not support MCP",
553
  request_id=request_id,
554
  )
555
 
556
- # Use async context manager for MCP client
557
- async with _env.mcp_client:
558
- tools = await _env.mcp_client.list_tools()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
 
560
- return JsonRpcResponse.success(
561
- result={
562
- "tools": [
563
- t.model_dump() if hasattr(t, "model_dump") else dict(t)
564
- for t in tools
565
- ]
566
- },
 
 
 
 
 
 
 
 
 
 
 
567
  request_id=request_id,
568
  )
569
 
570
  elif method == McpMethod.TOOLS_CALL:
571
- params = request.params
572
  tool_name = params.get("name")
573
  arguments = params.get("arguments", {})
574
 
575
- if not hasattr(_env, "mcp_client"):
576
  return JsonRpcResponse.error_response(
577
  JsonRpcErrorCode.INTERNAL_ERROR,
578
  "Environment does not support MCP",
@@ -581,15 +858,51 @@ class HTTPEnvServer:
581
 
582
  if not tool_name:
583
  return JsonRpcResponse.error_response(
584
- JsonRpcErrorCode.INVALID_REQUEST,
585
  "Missing 'name' in params",
586
  request_id=request_id,
587
  )
588
 
589
- # Use async context manager for MCP client
590
- async with _env.mcp_client:
591
- result = await _env.mcp_client.call_tool(
592
- name=tool_name, arguments=arguments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  )
594
 
595
  # Ensure result is JSON serializable
@@ -614,6 +927,11 @@ class HTTPEnvServer:
614
  request_id=request_id,
615
  )
616
  finally:
 
 
 
 
 
617
  if should_close:
618
  _env.close()
619
 
@@ -637,42 +955,59 @@ class HTTPEnvServer:
637
  try:
638
  # Create session with dedicated environment
639
  session_id, session_env = await self._create_session()
 
 
 
 
640
 
641
- while True:
642
- # Receive message from client
643
- raw_message = await websocket.receive_text()
644
-
645
- try:
646
- jsonrpc_dict = json.loads(raw_message)
647
- jsonrpc_request = JsonRpcRequest(**jsonrpc_dict)
648
- except json.JSONDecodeError as e:
649
- error_resp = JsonRpcResponse.error_response(
650
- JsonRpcErrorCode.PARSE_ERROR,
651
- f"Parse error: {e}",
652
- )
653
- await websocket.send_text(error_resp.model_dump_json())
654
- continue
655
- except ValidationError as e:
656
- error_resp = JsonRpcResponse.error_response(
657
- JsonRpcErrorCode.INVALID_REQUEST,
658
- f"Invalid request: {e}",
659
- )
660
- await websocket.send_text(error_resp.model_dump_json())
661
- continue
662
 
663
- try:
664
- # Call mcp_handler with session environment
665
- response = await mcp_handler(
666
- jsonrpc_request, session_env=session_env
 
667
  )
668
- await websocket.send_text(response.model_dump_json())
669
- except Exception as e:
670
- error_resp = JsonRpcResponse.error_response(
671
- JsonRpcErrorCode.INTERNAL_ERROR,
672
- str(e),
673
- request_id=jsonrpc_request.id,
674
- )
675
- await websocket.send_text(error_resp.model_dump_json())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
 
677
  except WebSocketDisconnect:
678
  pass
@@ -931,120 +1266,8 @@ all schema information needed to interact with the environment.
931
  JsonRpcErrorCode.PARSE_ERROR
932
  ).model_dump()
933
 
934
- method = request.method
935
- params = request.params
936
- request_id = request.id
937
-
938
- # Create a temporary environment for MCP access
939
- _env = self._env_factory()
940
-
941
- try:
942
- # Check if environment supports MCP
943
- if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"):
944
- return JsonRpcResponse.error_response(
945
- JsonRpcErrorCode.INTERNAL_ERROR,
946
- "Environment does not support MCP",
947
- request_id=request_id,
948
- ).model_dump()
949
-
950
- if method == McpMethod.TOOLS_LIST:
951
- # List tools from MCP server
952
- if hasattr(_env, "mcp_client") and _env.mcp_client:
953
- async with _env.mcp_client:
954
- tools = await _env.mcp_client.list_tools()
955
- return JsonRpcResponse.success(
956
- result={
957
- "tools": [
958
- t.model_dump()
959
- if hasattr(t, "model_dump")
960
- else dict(t)
961
- for t in tools
962
- ]
963
- },
964
- request_id=request_id,
965
- ).model_dump()
966
- elif hasattr(_env, "mcp_server") and _env.mcp_server:
967
- # Use server directly
968
- tools = []
969
- for tool_name, tool in get_server_tools(
970
- _env.mcp_server
971
- ).items():
972
- tool_dict = {
973
- "name": tool.name,
974
- "description": tool.description or "",
975
- "inputSchema": tool.parameters or {},
976
- }
977
- tools.append(tool_dict)
978
- return JsonRpcResponse.success(
979
- result={"tools": tools},
980
- request_id=request_id,
981
- ).model_dump()
982
- else:
983
- return JsonRpcResponse.error_response(
984
- JsonRpcErrorCode.INTERNAL_ERROR,
985
- "MCP server not available",
986
- request_id=request_id,
987
- ).model_dump()
988
-
989
- elif method == McpMethod.TOOLS_CALL:
990
- tool_name = params.get("name")
991
- arguments = params.get("arguments", {})
992
-
993
- if not tool_name:
994
- return JsonRpcResponse.error_response(
995
- JsonRpcErrorCode.INVALID_PARAMS,
996
- "Invalid params - 'name' is required",
997
- request_id=request_id,
998
- ).model_dump()
999
-
1000
- # Call tool via MCP
1001
- if hasattr(_env, "mcp_client") and _env.mcp_client:
1002
- async with _env.mcp_client:
1003
- result = await _env.mcp_client.call_tool(
1004
- name=tool_name, arguments=arguments
1005
- )
1006
- elif hasattr(_env, "mcp_server") and _env.mcp_server:
1007
- # Call tool directly on FastMCP server
1008
- server_tools = get_server_tools(_env.mcp_server)
1009
- if tool_name in server_tools:
1010
- tool = server_tools[tool_name]
1011
- result = tool.fn(**arguments)
1012
- else:
1013
- return JsonRpcResponse.error_response(
1014
- JsonRpcErrorCode.INVALID_PARAMS,
1015
- f"Tool not found: {tool_name}",
1016
- request_id=request_id,
1017
- ).model_dump()
1018
- else:
1019
- return JsonRpcResponse.error_response(
1020
- JsonRpcErrorCode.INTERNAL_ERROR,
1021
- "MCP server not available",
1022
- request_id=request_id,
1023
- ).model_dump()
1024
-
1025
- # Make result JSON serializable
1026
- serializable_result = _make_json_serializable(result)
1027
-
1028
- return JsonRpcResponse.success(
1029
- result=serializable_result,
1030
- request_id=request_id,
1031
- ).model_dump()
1032
-
1033
- else:
1034
- return JsonRpcResponse.error_response(
1035
- JsonRpcErrorCode.METHOD_NOT_FOUND,
1036
- f"Method not found: {method}",
1037
- request_id=request_id,
1038
- ).model_dump()
1039
-
1040
- except Exception as e:
1041
- return JsonRpcResponse.error_response(
1042
- JsonRpcErrorCode.INTERNAL_ERROR,
1043
- str(e),
1044
- request_id=request_id,
1045
- ).model_dump()
1046
- finally:
1047
- _env.close()
1048
 
1049
  # Register WebSocket endpoint for persistent sessions
1050
  @app.websocket("/ws")
@@ -1066,135 +1289,167 @@ all schema information needed to interact with the environment.
1066
  try:
1067
  # Create session with dedicated environment
1068
  session_id, session_env = await self._create_session()
 
 
 
 
1069
 
1070
- while True:
1071
- # Receive message from client
1072
- raw_message = await websocket.receive_text()
1073
 
1074
- try:
1075
- message_dict = json.loads(raw_message)
1076
- except json.JSONDecodeError as e:
1077
- error_resp = WSErrorResponse(
1078
- data={
1079
- "message": f"Invalid JSON: {e}",
1080
- "code": WSErrorCode.INVALID_JSON,
1081
- }
1082
  )
1083
- await websocket.send_text(error_resp.model_dump_json())
1084
- continue
1085
-
1086
- msg_type = message_dict.get("type", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
1087
 
1088
- try:
1089
- match msg_type:
1090
- case "reset":
1091
- msg = WSResetMessage(**message_dict)
1092
 
1093
- is_async = (
1094
- session_env.reset_async.__func__
1095
- is not Environment.reset_async
1096
- )
1097
 
1098
- if is_async:
1099
- sig = inspect.signature(session_env.reset_async)
1100
- valid_kwargs = self._get_valid_kwargs(sig, msg.data)
1101
- observation = await session_env.reset_async(
1102
- **valid_kwargs
1103
  )
1104
- else:
1105
- sig = inspect.signature(session_env.reset)
1106
- valid_kwargs = self._get_valid_kwargs(sig, msg.data)
1107
- observation = await self._run_in_session_executor(
1108
- session_id, session_env.reset, **valid_kwargs
1109
- )
1110
-
1111
- self._update_session_activity(session_id)
1112
-
1113
- response = WSObservationResponse(
1114
- data=serialize_observation(observation),
1115
- )
1116
 
1117
- case "step":
1118
- msg = WSStepMessage(**message_dict)
1119
- action = deserialize_action(msg.data, self.action_cls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1120
 
1121
- is_async = (
1122
- session_env.step_async.__func__
1123
- is not Environment.step_async
1124
- )
 
1125
 
1126
- if is_async:
1127
- observation = await session_env.step_async(action)
1128
- else:
1129
- observation = await self._run_in_session_executor(
1130
- session_id, session_env.step, action
1131
  )
1132
 
1133
- self._update_session_activity(
1134
- session_id, increment_step=True
1135
- )
 
 
 
 
 
 
 
 
 
 
 
1136
 
1137
- response = WSObservationResponse(
1138
- data=serialize_observation(observation)
1139
- )
1140
 
1141
- case "state":
1142
- msg = WSStateMessage(**message_dict)
1143
- state = session_env.state
1144
- if hasattr(state, "model_dump"):
1145
- state_data = state.model_dump()
1146
- else:
1147
- state_data = dict(state) if state else {}
1148
-
1149
- response = WSStateResponse(data=state_data)
1150
-
1151
- case "close":
1152
- msg = WSCloseMessage(**message_dict)
1153
- break
1154
-
1155
- case "mcp":
1156
- msg = WSMCPMessage(**message_dict)
1157
- try:
1158
- rpc_request = JsonRpcRequest(**msg.data)
1159
- except (ValidationError, Exception) as e:
1160
- rpc_response = JsonRpcResponse.error_response(
1161
- JsonRpcErrorCode.INVALID_REQUEST,
1162
- f"Invalid request: {e}",
 
 
 
 
 
 
 
 
 
1163
  )
1164
- else:
1165
- rpc_response = await mcp_handler(
1166
- rpc_request,
1167
- session_env=session_env,
 
 
 
1168
  )
1169
- response = WSMCPResponse(data=rpc_response.model_dump())
1170
-
1171
- case _:
1172
- response = WSErrorResponse(
1173
- data={
1174
- "message": f"Unknown message type: {msg_type}",
1175
- "code": WSErrorCode.UNKNOWN_TYPE,
1176
- }
1177
- )
1178
 
1179
- await websocket.send_text(response.model_dump_json())
1180
 
1181
- except ValidationError as e:
1182
- error_resp = WSErrorResponse(
1183
- data={
1184
- "message": "Invalid message",
1185
- "code": WSErrorCode.VALIDATION_ERROR,
1186
- "errors": e.errors(),
1187
- }
1188
- )
1189
- await websocket.send_text(error_resp.model_dump_json())
1190
- except Exception as e:
1191
- error_resp = WSErrorResponse(
1192
- data={
1193
- "message": str(e),
1194
- "code": WSErrorCode.EXECUTION_ERROR,
1195
- }
1196
- )
1197
- await websocket.send_text(error_resp.model_dump_json())
1198
 
1199
  except WebSocketDisconnect:
1200
  pass
@@ -1276,7 +1531,7 @@ def create_app(
1276
  from .web_interface import create_web_interface_app
1277
 
1278
  return create_web_interface_app(
1279
- env,
1280
  action_cls,
1281
  observation_cls,
1282
  env_name,
 
16
  import asyncio
17
  import inspect
18
  import json
19
+ import logging
20
  import os
21
  import time
22
  import uuid
23
  from concurrent.futures import ThreadPoolExecutor
24
+ from contextlib import AsyncExitStack
25
+ from typing import Any, AsyncContextManager, Callable, cast, Dict, Optional, Type
26
+
27
+ _MISSING = object()
28
 
29
  from fastapi import (
30
  Body,
 
208
  self.observation_cls = observation_cls
209
 
210
  # Session management for WebSocket connections
211
+ self._sessions: Dict[str, Optional[Environment]] = {}
212
  self._session_executors: Dict[str, ThreadPoolExecutor] = {}
213
+ self._session_stacks: Dict[str, AsyncExitStack] = {}
214
  self._session_info: Dict[str, SessionInfo] = {}
215
  self._session_lock = asyncio.Lock()
216
 
 
218
  # This is needed for environments using sync libraries (e.g., Playwright)
219
  self._executor = ThreadPoolExecutor(max_workers=32)
220
 
221
+ # Idle session reaper configuration.
222
+ # Timeout is taken from ConcurrencyConfig.session_timeout;
223
+ # None means no timeout (default — reaper is a no-op).
224
+ self._session_idle_timeout_s: Optional[float] = (
225
+ self._concurrency_config.session_timeout
226
+ )
227
+ self._reaper_task: Optional[asyncio.Task[None]] = None
228
+
229
  def _validate_concurrency_safety(self) -> None:
230
  """
231
  Validate that the environment supports the configured concurrency level.
 
334
  )
335
  raise EnvironmentFactoryError(factory_name) from e
336
 
337
+ # Hold the MCP session open for the lifetime of this session,
338
+ # matching the WebSocket path's AsyncExitStack pattern. This
339
+ # prevents per-request MCP transport teardown/reconnection and
340
+ # preserves FastMCP session state (ctx.set_state / ctx.get_state)
341
+ # across HTTP calls within the same OpenEnv session.
342
+ stack = AsyncExitStack()
343
+ try:
344
+ mcp_session_factory = getattr(env, "mcp_session", None)
345
+ if callable(mcp_session_factory):
346
+ mcp_session_cm = cast(AsyncContextManager[Any], mcp_session_factory())
347
+ await stack.enter_async_context(mcp_session_cm)
348
+ except Exception:
349
+ # MCP transport failed to start — clean up the reserved slot,
350
+ # the env, and the executor so they don't leak permanently
351
+ # against _max_concurrent_envs.
352
+ await stack.aclose() # best-effort
353
+ async with self._session_lock:
354
+ self._sessions.pop(session_id, None)
355
+ self._session_executors.pop(session_id, None)
356
+ self._session_info.pop(session_id, None)
357
+ await self._cleanup_session_resources(env, executor)
358
+ raise
359
+
360
  async with self._session_lock:
361
  self._sessions[session_id] = env
362
+ self._session_stacks[session_id] = stack
363
+ now = time.time()
364
  self._session_info[session_id] = SessionInfo(
365
  session_id=session_id,
366
  created_at=current_time,
367
+ last_activity_at=now,
368
  step_count=0,
369
  environment_type=type(env).__name__,
370
  )
 
381
  async with self._session_lock:
382
  env = self._sessions.pop(session_id, None)
383
  executor = self._session_executors.pop(session_id, None)
384
+ stack = self._session_stacks.pop(session_id, None)
385
  self._session_info.pop(session_id, None)
386
 
387
+ await self._cleanup_session_resources(env, executor, stack)
388
+
389
+ async def _cleanup_session_resources(
390
+ self,
391
+ env: Optional[Environment],
392
+ executor: Optional[ThreadPoolExecutor],
393
+ stack: Optional[AsyncExitStack] = None,
394
+ ) -> None:
395
+ """Close an environment and shut down its executor (best-effort)."""
396
+ # Close the MCP session stack first — this gracefully exits the
397
+ # mcp_session() context (and the underlying FastMCP Client session)
398
+ # before we tear down the environment references.
399
+ if stack is not None:
400
+ try:
401
+ await stack.aclose()
402
+ except Exception:
403
+ pass # Best effort cleanup
404
+
405
  # Run close() in the same executor where the env was created
406
  # This is required for thread-sensitive libraries like Playwright/greenlet
407
  if env is not None:
 
440
  if increment_step:
441
  self._session_info[session_id].step_count += 1
442
 
443
+ async def _reap_idle_sessions(self) -> None:
444
+ """Background task that periodically destroys sessions idle beyond the timeout."""
445
+ timeout = self._session_idle_timeout_s
446
+ if timeout is None:
447
+ return # no timeout configured — noop
448
+ interval = max(timeout / 4, 5.0) # check frequently enough
449
+ while True:
450
+ try:
451
+ await asyncio.sleep(interval)
452
+ now = time.time()
453
+ stale_ids: list[str] = []
454
+ async with self._session_lock:
455
+ for sid, info in self._session_info.items():
456
+ if now - info.last_activity_at > timeout:
457
+ stale_ids.append(sid)
458
+ for sid in stale_ids:
459
+ # Re-check under lock: activity may have arrived since
460
+ # the snapshot was taken, making this session active again.
461
+ # Refresh `now` so slow _destroy_session calls don't cause
462
+ # subsequent entries to be validated against a stale clock.
463
+ now = time.time()
464
+ async with self._session_lock:
465
+ info = self._session_info.get(sid)
466
+ if info is None or (now - info.last_activity_at) <= timeout:
467
+ continue
468
+ await self._destroy_session(sid)
469
+ except asyncio.CancelledError:
470
+ break
471
+ except Exception as exc:
472
+ logging.getLogger(__name__).warning(
473
+ "Idle-session reaper encountered an error (will retry): %s",
474
+ exc,
475
+ )
476
+
477
+ def _start_reaper(self) -> None:
478
+ """Start the idle-session reaper if a timeout is configured."""
479
+ if self._session_idle_timeout_s is not None and self._reaper_task is None:
480
+ self._reaper_task = asyncio.create_task(self._reap_idle_sessions())
481
+
482
+ def _stop_reaper(self) -> None:
483
+ """Cancel the reaper background task."""
484
+ if self._reaper_task is not None:
485
+ self._reaper_task.cancel()
486
+ self._reaper_task = None
487
+
488
  def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
489
  """
490
  Get information about a specific session.
 
560
  f"Invalid mode: '{mode}'. Must be one of: {valid_modes}"
561
  )
562
 
563
+ # Wire up idle-session reaper lifecycle via app events
564
+ server_ref = self
565
+
566
+ async def _start_session_reaper() -> None:
567
+ server_ref._start_reaper()
568
+
569
+ async def _stop_session_reaper() -> None:
570
+ server_ref._stop_reaper()
571
+
572
+ if not getattr(app.router, "_openenv_reaper_registered", False):
573
+ app.router.on_startup.append(_start_session_reaper)
574
+ app.router.on_shutdown.append(_stop_session_reaper)
575
+ app.router._openenv_reaper_registered = True # type: ignore[attr-defined]
576
+
577
  # Helper function to handle reset endpoint
578
  async def reset_handler(
579
  request: ResetRequest = Body(default_factory=ResetRequest),
 
642
 
643
  # Helper function to handle MCP endpoint
644
  async def mcp_handler(
645
+ request: JsonRpcRequest,
646
+ session_env: Optional[Environment] = None,
647
+ session_id: Optional[str] = None,
648
  ) -> JsonRpcResponse:
649
  """
650
  Handle MCP JSON-RPC requests.
651
 
652
+ Supports tools/list and tools/call methods in JSON-RPC 2.0 format,
653
+ plus OpenEnv session lifecycle methods for HTTP MCP:
654
+ - openenv/session/create
655
+ - openenv/session/close
656
  """
657
  method = request.method
658
  request_id = request.id
659
+ params = request.params
660
+ if not isinstance(params, dict):
661
+ return JsonRpcResponse.error_response(
662
+ JsonRpcErrorCode.INVALID_PARAMS,
663
+ "Params must be an object",
664
+ request_id=request_id,
665
+ )
666
+
667
+ # OpenEnv extension methods for explicit MCP session management.
668
+ # This enables persistent MCP lifecycles over HTTP /mcp, matching WebSocket semantics.
669
+ if method == "openenv/session/create":
670
+ if session_env is not None and session_id is not None:
671
+ return JsonRpcResponse.success(
672
+ result={"session_id": session_id},
673
+ request_id=request_id,
674
+ )
675
+ try:
676
+ created_session_id, _ = await self._create_session()
677
+ except SessionCapacityError as e:
678
+ return JsonRpcResponse.error_response(
679
+ JsonRpcErrorCode.SERVER_ERROR,
680
+ str(e),
681
+ request_id=request_id,
682
+ data={
683
+ "active_sessions": e.active_sessions,
684
+ "max_sessions": e.max_sessions,
685
+ },
686
+ )
687
+ except EnvironmentFactoryError as e:
688
+ return JsonRpcResponse.error_response(
689
+ JsonRpcErrorCode.SERVER_ERROR,
690
+ str(e),
691
+ request_id=request_id,
692
+ data={"factory_name": e.factory_name},
693
+ )
694
+ return JsonRpcResponse.success(
695
+ result={"session_id": created_session_id},
696
+ request_id=request_id,
697
+ )
698
+
699
+ if method == "openenv/session/close":
700
+ target_session_id = params.get("session_id")
701
+ if not target_session_id:
702
+ return JsonRpcResponse.error_response(
703
+ JsonRpcErrorCode.INVALID_PARAMS,
704
+ "Invalid params - 'session_id' is required",
705
+ request_id=request_id,
706
+ )
707
+
708
+ if session_id is not None and target_session_id == session_id:
709
+ return JsonRpcResponse.error_response(
710
+ JsonRpcErrorCode.INVALID_REQUEST,
711
+ "Cannot close active WebSocket-managed session via MCP method",
712
+ request_id=request_id,
713
+ )
714
+
715
+ async with self._session_lock:
716
+ env = self._sessions.pop(target_session_id, _MISSING)
717
+ if env is not _MISSING:
718
+ executor = self._session_executors.pop(target_session_id, None)
719
+ stack = self._session_stacks.pop(target_session_id, None)
720
+ self._session_info.pop(target_session_id, None)
721
+ else:
722
+ executor = None
723
+ stack = None
724
+
725
+ if env is _MISSING:
726
+ return JsonRpcResponse.error_response(
727
+ JsonRpcErrorCode.INVALID_PARAMS,
728
+ f"Unknown session_id: {target_session_id}",
729
+ request_id=request_id,
730
+ )
731
+
732
+ if env is None:
733
+ # Session slot reserved but env factory still running;
734
+ # re-insert the placeholder AND the executor so
735
+ # _create_session can finish and the executor remains
736
+ # tracked for eventual shutdown.
737
+ async with self._session_lock:
738
+ self._sessions[target_session_id] = None
739
+ if executor is not None:
740
+ self._session_executors[target_session_id] = executor
741
+ return JsonRpcResponse.error_response(
742
+ JsonRpcErrorCode.INVALID_REQUEST,
743
+ f"Session {target_session_id} is still initializing; retry shortly",
744
+ request_id=request_id,
745
+ )
746
+
747
+ # env/executor/stack cleanup outside the lock
748
+ await self._cleanup_session_resources(env, executor, stack)
749
+ return JsonRpcResponse.success(
750
+ result={"session_id": target_session_id, "closed": True},
751
+ request_id=request_id,
752
+ )
753
+
754
+ requested_session_id = params.get("session_id")
755
+ managed_session_id = session_id
756
 
757
  # Use provided session environment or create temporary one
758
  if session_env is not None:
759
  _env = session_env
760
  should_close = False
761
+ elif requested_session_id:
762
+ async with self._session_lock:
763
+ _env = self._sessions.get(requested_session_id, _MISSING)
764
+
765
+ if _env is _MISSING:
766
+ return JsonRpcResponse.error_response(
767
+ JsonRpcErrorCode.INVALID_PARAMS,
768
+ f"Unknown session_id: {requested_session_id}",
769
+ request_id=request_id,
770
+ )
771
+
772
+ if _env is None:
773
+ return JsonRpcResponse.error_response(
774
+ JsonRpcErrorCode.INVALID_REQUEST,
775
+ f"Session {requested_session_id} is still initializing; retry shortly",
776
+ request_id=request_id,
777
+ )
778
+
779
+ should_close = False
780
+ managed_session_id = requested_session_id
781
  else:
782
  _env = self._env_factory()
783
  should_close = True
784
  try:
785
+ mcp_client = getattr(_env, "mcp_client", None)
786
+ mcp_server = getattr(_env, "mcp_server", None)
787
+ mcp_session_factory = getattr(_env, "mcp_session", None)
788
+
789
  if method == McpMethod.TOOLS_LIST:
790
  # Check if environment is MCP-enabled
791
+ if mcp_client is None and mcp_server is None:
792
  return JsonRpcResponse.error_response(
793
  JsonRpcErrorCode.INTERNAL_ERROR,
794
  "Environment does not support MCP",
795
  request_id=request_id,
796
  )
797
 
798
+ if mcp_client:
799
+ if managed_session_id and mcp_client.is_connected():
800
+ # Session-managed with live transport — call
801
+ # directly, no redundant re-entry.
802
+ tools = await mcp_client.list_tools()
803
+ elif callable(mcp_session_factory):
804
+ # Stateless request, or session-managed but the
805
+ # background transport was lost: (re-)open.
806
+ mcp_session_cm = cast(
807
+ AsyncContextManager[Any], mcp_session_factory()
808
+ )
809
+ async with mcp_session_cm:
810
+ tools = await mcp_client.list_tools()
811
+ else:
812
+ async with mcp_client:
813
+ tools = await mcp_client.list_tools()
814
+
815
+ return JsonRpcResponse.success(
816
+ result={
817
+ "tools": [
818
+ t.model_dump()
819
+ if hasattr(t, "model_dump")
820
+ else dict(t)
821
+ for t in tools
822
+ ]
823
+ },
824
+ request_id=request_id,
825
+ )
826
 
827
+ if mcp_server:
828
+ tools = []
829
+ for _tool_name, tool in get_server_tools(mcp_server).items():
830
+ tools.append(
831
+ {
832
+ "name": tool.name,
833
+ "description": tool.description or "",
834
+ "inputSchema": tool.parameters or {},
835
+ }
836
+ )
837
+ return JsonRpcResponse.success(
838
+ result={"tools": tools},
839
+ request_id=request_id,
840
+ )
841
+
842
+ return JsonRpcResponse.error_response(
843
+ JsonRpcErrorCode.INTERNAL_ERROR,
844
+ "MCP server not available",
845
  request_id=request_id,
846
  )
847
 
848
  elif method == McpMethod.TOOLS_CALL:
 
849
  tool_name = params.get("name")
850
  arguments = params.get("arguments", {})
851
 
852
+ if mcp_client is None and mcp_server is None:
853
  return JsonRpcResponse.error_response(
854
  JsonRpcErrorCode.INTERNAL_ERROR,
855
  "Environment does not support MCP",
 
858
 
859
  if not tool_name:
860
  return JsonRpcResponse.error_response(
861
+ JsonRpcErrorCode.INVALID_PARAMS,
862
  "Missing 'name' in params",
863
  request_id=request_id,
864
  )
865
 
866
+ if mcp_client:
867
+ if managed_session_id and mcp_client.is_connected():
868
+ # Session-managed with live transport.
869
+ result = await mcp_client.call_tool(
870
+ name=tool_name, arguments=arguments
871
+ )
872
+ elif callable(mcp_session_factory):
873
+ # Stateless request, or session-managed but the
874
+ # background transport was lost: (re-)open.
875
+ mcp_session_cm = cast(
876
+ AsyncContextManager[Any], mcp_session_factory()
877
+ )
878
+ async with mcp_session_cm:
879
+ result = await mcp_client.call_tool(
880
+ name=tool_name, arguments=arguments
881
+ )
882
+ else:
883
+ async with mcp_client:
884
+ result = await mcp_client.call_tool(
885
+ name=tool_name, arguments=arguments
886
+ )
887
+ elif mcp_server:
888
+ server_tools = get_server_tools(mcp_server)
889
+ if tool_name in server_tools:
890
+ tool = server_tools[tool_name]
891
+ if inspect.iscoroutinefunction(tool.fn):
892
+ result = await tool.fn(**arguments)
893
+ else:
894
+ result = tool.fn(**arguments)
895
+ else:
896
+ return JsonRpcResponse.error_response(
897
+ JsonRpcErrorCode.INVALID_PARAMS,
898
+ f"Tool not found: {tool_name}",
899
+ request_id=request_id,
900
+ )
901
+ else:
902
+ return JsonRpcResponse.error_response(
903
+ JsonRpcErrorCode.INTERNAL_ERROR,
904
+ "MCP server not available",
905
+ request_id=request_id,
906
  )
907
 
908
  # Ensure result is JSON serializable
 
927
  request_id=request_id,
928
  )
929
  finally:
930
+ if managed_session_id:
931
+ self._update_session_activity(
932
+ managed_session_id,
933
+ increment_step=(method == McpMethod.TOOLS_CALL),
934
+ )
935
  if should_close:
936
  _env.close()
937
 
 
955
  try:
956
  # Create session with dedicated environment
957
  session_id, session_env = await self._create_session()
958
+ if session_env is None:
959
+ raise RuntimeError(
960
+ "Session environment not initialized for MCP websocket"
961
+ )
962
 
963
+ # If environment has an mcp_session context manager, hold it open
964
+ # for the lifetime of the websocket connection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
965
 
966
+ async with AsyncExitStack() as stack:
967
+ mcp_session_factory = getattr(session_env, "mcp_session", None)
968
+ if callable(mcp_session_factory):
969
+ mcp_session_cm = cast(
970
+ AsyncContextManager[Any], mcp_session_factory()
971
  )
972
+ await stack.enter_async_context(mcp_session_cm)
973
+
974
+ while True:
975
+ # Receive message from client
976
+ raw_message = await websocket.receive_text()
977
+
978
+ try:
979
+ jsonrpc_dict = json.loads(raw_message)
980
+ jsonrpc_request = JsonRpcRequest(**jsonrpc_dict)
981
+ except json.JSONDecodeError as e:
982
+ error_resp = JsonRpcResponse.error_response(
983
+ JsonRpcErrorCode.PARSE_ERROR,
984
+ f"Parse error: {e}",
985
+ )
986
+ await websocket.send_text(error_resp.model_dump_json())
987
+ continue
988
+ except ValidationError as e:
989
+ error_resp = JsonRpcResponse.error_response(
990
+ JsonRpcErrorCode.INVALID_REQUEST,
991
+ f"Invalid request: {e}",
992
+ )
993
+ await websocket.send_text(error_resp.model_dump_json())
994
+ continue
995
+
996
+ try:
997
+ # Call mcp_handler with session environment
998
+ response = await mcp_handler(
999
+ jsonrpc_request,
1000
+ session_env=session_env,
1001
+ session_id=session_id,
1002
+ )
1003
+ await websocket.send_text(response.model_dump_json())
1004
+ except Exception as e:
1005
+ error_resp = JsonRpcResponse.error_response(
1006
+ JsonRpcErrorCode.INTERNAL_ERROR,
1007
+ str(e),
1008
+ request_id=jsonrpc_request.id,
1009
+ )
1010
+ await websocket.send_text(error_resp.model_dump_json())
1011
 
1012
  except WebSocketDisconnect:
1013
  pass
 
1266
  JsonRpcErrorCode.PARSE_ERROR
1267
  ).model_dump()
1268
 
1269
+ response = await mcp_handler(request)
1270
+ return response.model_dump()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1271
 
1272
  # Register WebSocket endpoint for persistent sessions
1273
  @app.websocket("/ws")
 
1289
  try:
1290
  # Create session with dedicated environment
1291
  session_id, session_env = await self._create_session()
1292
+ if session_env is None:
1293
+ raise RuntimeError(
1294
+ "Session environment not initialized for websocket"
1295
+ )
1296
 
1297
+ # Keep MCP session open for entire websocket lifetime
1298
+ # (avoids reconnect overhead on every message)
 
1299
 
1300
+ async with AsyncExitStack() as stack:
1301
+ mcp_session_factory = getattr(session_env, "mcp_session", None)
1302
+ if callable(mcp_session_factory):
1303
+ mcp_session_cm = cast(
1304
+ AsyncContextManager[Any], mcp_session_factory()
 
 
 
1305
  )
1306
+ await stack.enter_async_context(mcp_session_cm)
1307
+
1308
+ while True:
1309
+ # Receive message from client
1310
+ raw_message = await websocket.receive_text()
1311
+
1312
+ try:
1313
+ message_dict = json.loads(raw_message)
1314
+ except json.JSONDecodeError as e:
1315
+ error_resp = WSErrorResponse(
1316
+ data={
1317
+ "message": f"Invalid JSON: {e}",
1318
+ "code": WSErrorCode.INVALID_JSON,
1319
+ }
1320
+ )
1321
+ await websocket.send_text(error_resp.model_dump_json())
1322
+ continue
1323
 
1324
+ msg_type = message_dict.get("type", "")
 
 
 
1325
 
1326
+ try:
1327
+ match msg_type:
1328
+ case "reset":
1329
+ msg = WSResetMessage(**message_dict)
1330
 
1331
+ is_async = (
1332
+ session_env.reset_async.__func__
1333
+ is not Environment.reset_async
 
 
1334
  )
 
 
 
 
 
 
 
 
 
 
 
 
1335
 
1336
+ if is_async:
1337
+ sig = inspect.signature(session_env.reset_async)
1338
+ valid_kwargs = self._get_valid_kwargs(
1339
+ sig, msg.data
1340
+ )
1341
+ observation = await session_env.reset_async(
1342
+ **valid_kwargs
1343
+ )
1344
+ else:
1345
+ sig = inspect.signature(session_env.reset)
1346
+ valid_kwargs = self._get_valid_kwargs(
1347
+ sig, msg.data
1348
+ )
1349
+ observation = (
1350
+ await self._run_in_session_executor(
1351
+ session_id,
1352
+ session_env.reset,
1353
+ **valid_kwargs,
1354
+ )
1355
+ )
1356
+
1357
+ self._update_session_activity(session_id)
1358
+
1359
+ response = WSObservationResponse(
1360
+ data=serialize_observation(observation),
1361
+ )
1362
 
1363
+ case "step":
1364
+ msg = WSStepMessage(**message_dict)
1365
+ action = deserialize_action(
1366
+ msg.data, self.action_cls
1367
+ )
1368
 
1369
+ is_async = (
1370
+ session_env.step_async.__func__
1371
+ is not Environment.step_async
 
 
1372
  )
1373
 
1374
+ if is_async:
1375
+ observation = await session_env.step_async(
1376
+ action
1377
+ )
1378
+ else:
1379
+ observation = (
1380
+ await self._run_in_session_executor(
1381
+ session_id, session_env.step, action
1382
+ )
1383
+ )
1384
+
1385
+ self._update_session_activity(
1386
+ session_id, increment_step=True
1387
+ )
1388
 
1389
+ response = WSObservationResponse(
1390
+ data=serialize_observation(observation)
1391
+ )
1392
 
1393
+ case "state":
1394
+ msg = WSStateMessage(**message_dict)
1395
+ state = session_env.state
1396
+ if hasattr(state, "model_dump"):
1397
+ state_data = state.model_dump()
1398
+ else:
1399
+ state_data = dict(state) if state else {}
1400
+
1401
+ response = WSStateResponse(data=state_data)
1402
+
1403
+ case "close":
1404
+ msg = WSCloseMessage(**message_dict)
1405
+ break
1406
+
1407
+ case "mcp":
1408
+ msg = WSMCPMessage(**message_dict)
1409
+ try:
1410
+ rpc_request = JsonRpcRequest(**msg.data)
1411
+ except (ValidationError, Exception) as e:
1412
+ rpc_response = JsonRpcResponse.error_response(
1413
+ JsonRpcErrorCode.INVALID_REQUEST,
1414
+ f"Invalid request: {e}",
1415
+ )
1416
+ else:
1417
+ rpc_response = await mcp_handler(
1418
+ rpc_request,
1419
+ session_env=session_env,
1420
+ session_id=session_id,
1421
+ )
1422
+ response = WSMCPResponse(
1423
+ data=rpc_response.model_dump()
1424
  )
1425
+
1426
+ case _:
1427
+ response = WSErrorResponse(
1428
+ data={
1429
+ "message": f"Unknown message type: {msg_type}",
1430
+ "code": WSErrorCode.UNKNOWN_TYPE,
1431
+ }
1432
  )
 
 
 
 
 
 
 
 
 
1433
 
1434
+ await websocket.send_text(response.model_dump_json())
1435
 
1436
+ except ValidationError as e:
1437
+ error_resp = WSErrorResponse(
1438
+ data={
1439
+ "message": "Invalid message",
1440
+ "code": WSErrorCode.VALIDATION_ERROR,
1441
+ "errors": e.errors(),
1442
+ }
1443
+ )
1444
+ await websocket.send_text(error_resp.model_dump_json())
1445
+ except Exception as e:
1446
+ error_resp = WSErrorResponse(
1447
+ data={
1448
+ "message": str(e),
1449
+ "code": WSErrorCode.EXECUTION_ERROR,
1450
+ }
1451
+ )
1452
+ await websocket.send_text(error_resp.model_dump_json())
1453
 
1454
  except WebSocketDisconnect:
1455
  pass
 
1531
  from .web_interface import create_web_interface_app
1532
 
1533
  return create_web_interface_app(
1534
+ cast(Any, env),
1535
  action_cls,
1536
  observation_cls,
1537
  env_name,
src/core/openenv/core/env_server/mcp_environment.py CHANGED
@@ -56,6 +56,7 @@ import asyncio
56
  import inspect
57
  from abc import abstractmethod
58
  from collections import defaultdict
 
59
  from typing import Any, Callable, Dict, Optional
60
 
61
  from fastmcp import Client
@@ -164,6 +165,52 @@ class MCPEnvironment(Environment):
164
  # Track tool schemas for list_tools: {tool_name: {mode: schema}}
165
  self._mode_tool_schemas = defaultdict(dict)
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  @property
168
  def supports_code_mode(self) -> bool:
169
  """Check if this environment supports code mode (execute_code)."""
@@ -292,7 +339,8 @@ class MCPEnvironment(Environment):
292
 
293
  # If mode is None, register with FastMCP as usual
294
  if mode is None:
295
- decorated_func = self.mcp_server.tool()(func)
 
296
  self._mode_tools[tool_name][None] = func
297
  return decorated_func
298
 
@@ -372,24 +420,49 @@ class MCPEnvironment(Environment):
372
  return self._step_impl(action, timeout_s=timeout_s, **kwargs)
373
 
374
  def _handle_list_tools(self) -> ListToolsObservation:
 
 
 
 
375
  """
376
- Handle a ListToolsAction by querying the MCP server.
377
 
378
  Returns:
379
- ListToolsObservation containing all available tools with their
380
- names, descriptions, and input schemas, filtered by current mode.
381
  """
382
- try:
383
- # Get current mode
384
- current_mode = getattr(self, "_mode", None)
385
 
386
- # Start with tools from FastMCP server (mode=None tools)
387
- tools_result = run_async_safely(self._async_list_tools())
 
 
 
 
 
 
 
388
 
389
- # Build list of Tool objects
390
- tools = []
 
391
 
392
- # Add FastMCP tools that are not mode-specific
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  for tool in tools_result:
394
  if tool.name not in self._mode_tool_schemas:
395
  tools.append(
@@ -401,11 +474,8 @@ class MCPEnvironment(Environment):
401
  else {},
402
  )
403
  )
404
-
405
- # Add mode-specific tools available in current mode
406
  for tool_name, mode_schemas in self._mode_tool_schemas.items():
407
  if None in mode_schemas:
408
- # Tool available in all modes
409
  schema = mode_schemas[None]
410
  tools.append(
411
  Tool(
@@ -415,7 +485,6 @@ class MCPEnvironment(Environment):
415
  )
416
  )
417
  elif current_mode in mode_schemas:
418
- # Tool available in current mode
419
  schema = mode_schemas[current_mode]
420
  tools.append(
421
  Tool(
@@ -424,65 +493,30 @@ class MCPEnvironment(Environment):
424
  input_schema=schema["input_schema"],
425
  )
426
  )
427
-
428
  return ListToolsObservation(tools=tools)
429
-
430
  except Exception as e:
431
- # Return an observation with error in metadata
432
  return ListToolsObservation(
433
  tools=[],
434
- metadata={
435
- "error": str(e),
436
- "error_type": "list_tools_failed",
437
- },
438
  )
439
 
440
- async def _async_list_tools(self) -> list:
441
- """
442
- Async helper to list tools from the MCP client.
443
-
444
- Returns:
445
- List of tool objects from the MCP server.
446
- """
447
- async with self.mcp_client:
448
- return await self.mcp_client.list_tools()
449
-
450
- def _handle_call_tool(
451
  self,
452
  action: CallToolAction,
453
  timeout_s: Optional[float] = None,
454
  ) -> CallToolObservation:
455
- """
456
- Handle a CallToolAction by invoking the specified tool.
457
-
458
- Args:
459
- action: The CallToolAction containing tool_name and arguments.
460
- timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s).
461
-
462
- Returns:
463
- CallToolObservation with the tool's result or an error.
464
- """
465
  timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
466
-
467
- # Check if this is a mode-specific tool
468
  tool_name = action.tool_name
469
  current_mode = getattr(self, "_mode", None)
470
 
471
  if tool_name in self._mode_tools:
472
  mode_info = self._mode_tools[tool_name]
473
-
474
- # Check if tool is available in current mode
475
- # Tool is available if:
476
- # 1. It has a None mode (available in all modes), OR
477
- # 2. It has an implementation for the current mode
478
  if None in mode_info:
479
- # Use the mode-agnostic version
480
  func = mode_info[None]
481
  elif current_mode in mode_info:
482
- # Use the mode-specific version
483
  func = mode_info[current_mode]
484
  else:
485
- # Tool not available in current mode
486
  return CallToolObservation(
487
  tool_name=tool_name,
488
  result=None,
@@ -491,16 +525,11 @@ class MCPEnvironment(Environment):
491
  message=f"Tool '{tool_name}' not available in {current_mode} mode",
492
  ),
493
  )
494
-
495
- # Call the mode-specific function directly
496
  try:
497
- # Check if function is async and await if necessary
498
  if inspect.iscoroutinefunction(func):
499
- result = run_async_safely(func(**action.arguments))
500
  else:
501
  result = func(**action.arguments)
502
-
503
- # Wrap result in CallToolResult format to match FastMCP behavior
504
  return CallToolObservation(
505
  tool_name=tool_name,
506
  result=CallToolResult(
@@ -521,22 +550,12 @@ class MCPEnvironment(Environment):
521
  ),
522
  )
523
 
524
- # Not a mode-specific tool, use FastMCP
525
  try:
526
- # Run the async call_tool with timeout
527
- # Use run_async_safely to handle both sync and async contexts
528
- result = run_async_safely(
529
- asyncio.wait_for(
530
- self._async_call_tool(action.tool_name, action.arguments),
531
- timeout=timeout,
532
- )
533
- )
534
-
535
- return CallToolObservation(
536
- tool_name=action.tool_name,
537
- result=result,
538
  )
539
-
540
  except asyncio.TimeoutError:
541
  return CallToolObservation(
542
  tool_name=action.tool_name,
@@ -546,11 +565,8 @@ class MCPEnvironment(Environment):
546
  message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
547
  ),
548
  )
549
-
550
  except Exception as e:
551
  error_message = str(e)
552
-
553
- # Determine error type based on the exception
554
  if (
555
  "not found" in error_message.lower()
556
  or "unknown tool" in error_message.lower()
@@ -563,29 +579,34 @@ class MCPEnvironment(Environment):
563
  error_type = ToolErrorType.INVALID_ARGS
564
  else:
565
  error_type = ToolErrorType.EXECUTION_ERROR
566
-
567
  return CallToolObservation(
568
  tool_name=action.tool_name,
569
  result=None,
570
- error=ToolError(
571
- error_type=error_type,
572
- message=error_message,
573
- ),
574
  )
575
 
576
- async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
 
 
 
 
 
577
  """
578
- Async helper to call a tool on the MCP server.
579
 
580
- Args:
581
- tool_name: Name of the tool to invoke.
582
- arguments: Dictionary of arguments to pass to the tool.
583
-
584
- Returns:
585
- The result from the tool execution.
586
  """
587
- async with self.mcp_client:
588
- return await self.mcp_client.call_tool(tool_name, arguments)
 
 
 
 
 
 
 
589
 
590
  @abstractmethod
591
  def _step_impl(
 
56
  import inspect
57
  from abc import abstractmethod
58
  from collections import defaultdict
59
+ from contextlib import asynccontextmanager
60
  from typing import Any, Callable, Dict, Optional
61
 
62
  from fastmcp import Client
 
165
  # Track tool schemas for list_tools: {tool_name: {mode: schema}}
166
  self._mode_tool_schemas = defaultdict(dict)
167
 
168
+ def _require_mcp_client(self) -> Any:
169
+ """Return MCP client or raise if environment has been closed."""
170
+ if self.mcp_client is None:
171
+ raise RuntimeError("MCP client is not available; environment is closed")
172
+ return self.mcp_client
173
+
174
+ def _require_mcp_server(self) -> Any:
175
+ """Return MCP server or raise if environment has been closed."""
176
+ if self.mcp_server is None:
177
+ raise RuntimeError("MCP server is not available; environment is closed")
178
+ return self.mcp_server
179
+
180
+ @asynccontextmanager
181
+ async def mcp_session(self):
182
+ """
183
+ Context manager for MCP client sessions.
184
+
185
+ This wrapper serves two purposes:
186
+
187
+ 1. **Null guard** — raises a clear error if ``close()`` has already
188
+ been called (``mcp_client`` is ``None``).
189
+
190
+ 2. **AsyncExitStack adapter** — FastMCP's ``Client.__aenter__``
191
+ creates a background ``asyncio.Task`` for session management.
192
+ When entered directly via ``AsyncExitStack`` in the HTTP session
193
+ path (``_create_session``), this task can be cancelled by ASGI
194
+ harnesses (e.g. Starlette ``TestClient``) between requests,
195
+ corrupting session state. Wrapping in an ``asynccontextmanager``
196
+ generator isolates the task lifecycle: the generator frame keeps
197
+ ``async with client:`` suspended at ``yield``, so cleanup only
198
+ runs when the stack explicitly closes the generator — not when
199
+ the event loop cancels orphaned tasks.
200
+
201
+ Delegates to FastMCP's ``Client`` context manager which is
202
+ reentrant: the first entry opens the transport and subsequent
203
+ (nested) entries simply increment an internal reference counter.
204
+ The transport is closed only when the outermost context exits.
205
+
206
+ No external lock is needed because ``Client._connect`` /
207
+ ``Client._disconnect`` already serialise connection state changes
208
+ through their own ``anyio.Lock``.
209
+ """
210
+ client = self._require_mcp_client()
211
+ async with client:
212
+ yield client
213
+
214
  @property
215
  def supports_code_mode(self) -> bool:
216
  """Check if this environment supports code mode (execute_code)."""
 
339
 
340
  # If mode is None, register with FastMCP as usual
341
  if mode is None:
342
+ mcp_server = self._require_mcp_server()
343
+ decorated_func = mcp_server.tool()(func)
344
  self._mode_tools[tool_name][None] = func
345
  return decorated_func
346
 
 
420
  return self._step_impl(action, timeout_s=timeout_s, **kwargs)
421
 
422
  def _handle_list_tools(self) -> ListToolsObservation:
423
+ """Sync wrapper — delegates to the canonical async implementation."""
424
+ return run_async_safely(self._async_handle_list_tools())
425
+
426
+ async def _async_list_tools(self) -> list:
427
  """
428
+ Async helper to list tools from the MCP client.
429
 
430
  Returns:
431
+ List of tool objects from the MCP server.
 
432
  """
433
+ async with self.mcp_session() as client:
434
+ return await client.list_tools()
 
435
 
436
+ def _handle_call_tool(
437
+ self,
438
+ action: CallToolAction,
439
+ timeout_s: Optional[float] = None,
440
+ ) -> CallToolObservation:
441
+ """Sync wrapper — delegates to the canonical async implementation."""
442
+ return run_async_safely(
443
+ self._async_handle_call_tool(action, timeout_s=timeout_s)
444
+ )
445
 
446
+ async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
447
+ """
448
+ Async helper to call a tool on the MCP server.
449
 
450
+ Args:
451
+ tool_name: Name of the tool to invoke.
452
+ arguments: Dictionary of arguments to pass to the tool.
453
+
454
+ Returns:
455
+ The result from the tool execution.
456
+ """
457
+ async with self.mcp_session() as client:
458
+ return await client.call_tool(tool_name, arguments)
459
+
460
+ async def _async_handle_list_tools(self) -> ListToolsObservation:
461
+ """Async version of _handle_list_tools — avoids run_async_safely."""
462
+ try:
463
+ current_mode = getattr(self, "_mode", None)
464
+ tools_result = await self._async_list_tools()
465
+ tools = []
466
  for tool in tools_result:
467
  if tool.name not in self._mode_tool_schemas:
468
  tools.append(
 
474
  else {},
475
  )
476
  )
 
 
477
  for tool_name, mode_schemas in self._mode_tool_schemas.items():
478
  if None in mode_schemas:
 
479
  schema = mode_schemas[None]
480
  tools.append(
481
  Tool(
 
485
  )
486
  )
487
  elif current_mode in mode_schemas:
 
488
  schema = mode_schemas[current_mode]
489
  tools.append(
490
  Tool(
 
493
  input_schema=schema["input_schema"],
494
  )
495
  )
 
496
  return ListToolsObservation(tools=tools)
 
497
  except Exception as e:
 
498
  return ListToolsObservation(
499
  tools=[],
500
+ metadata={"error": str(e), "error_type": "list_tools_failed"},
 
 
 
501
  )
502
 
503
+ async def _async_handle_call_tool(
 
 
 
 
 
 
 
 
 
 
504
  self,
505
  action: CallToolAction,
506
  timeout_s: Optional[float] = None,
507
  ) -> CallToolObservation:
508
+ """Async version of _handle_call_tool — avoids run_async_safely."""
 
 
 
 
 
 
 
 
 
509
  timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
 
 
510
  tool_name = action.tool_name
511
  current_mode = getattr(self, "_mode", None)
512
 
513
  if tool_name in self._mode_tools:
514
  mode_info = self._mode_tools[tool_name]
 
 
 
 
 
515
  if None in mode_info:
 
516
  func = mode_info[None]
517
  elif current_mode in mode_info:
 
518
  func = mode_info[current_mode]
519
  else:
 
520
  return CallToolObservation(
521
  tool_name=tool_name,
522
  result=None,
 
525
  message=f"Tool '{tool_name}' not available in {current_mode} mode",
526
  ),
527
  )
 
 
528
  try:
 
529
  if inspect.iscoroutinefunction(func):
530
+ result = await func(**action.arguments)
531
  else:
532
  result = func(**action.arguments)
 
 
533
  return CallToolObservation(
534
  tool_name=tool_name,
535
  result=CallToolResult(
 
550
  ),
551
  )
552
 
 
553
  try:
554
+ result = await asyncio.wait_for(
555
+ self._async_call_tool(action.tool_name, action.arguments),
556
+ timeout=timeout,
 
 
 
 
 
 
 
 
 
557
  )
558
+ return CallToolObservation(tool_name=action.tool_name, result=result)
559
  except asyncio.TimeoutError:
560
  return CallToolObservation(
561
  tool_name=action.tool_name,
 
565
  message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
566
  ),
567
  )
 
568
  except Exception as e:
569
  error_message = str(e)
 
 
570
  if (
571
  "not found" in error_message.lower()
572
  or "unknown tool" in error_message.lower()
 
579
  error_type = ToolErrorType.INVALID_ARGS
580
  else:
581
  error_type = ToolErrorType.EXECUTION_ERROR
 
582
  return CallToolObservation(
583
  tool_name=action.tool_name,
584
  result=None,
585
+ error=ToolError(error_type=error_type, message=error_message),
 
 
 
586
  )
587
 
588
+ async def step_async(
589
+ self,
590
+ action: Action,
591
+ timeout_s: Optional[float] = None,
592
+ **kwargs: Any,
593
+ ) -> Observation:
594
  """
595
+ Async step that routes MCP actions without going through run_async_safely.
596
 
597
+ The WebSocket handler calls this directly on the outer event loop, where
598
+ the MCP session is already open, avoiding the thread/event-loop deadlock
599
+ that occurs when the sync step() path is used via run_in_executor.
 
 
 
600
  """
601
+ if isinstance(action, ListToolsAction):
602
+ return await self._async_handle_list_tools()
603
+ elif isinstance(action, CallToolAction):
604
+ return await self._async_handle_call_tool(action, timeout_s=timeout_s)
605
+ else:
606
+ loop = asyncio.get_event_loop()
607
+ return await loop.run_in_executor(
608
+ None, lambda: self._step_impl(action, timeout_s=timeout_s, **kwargs)
609
+ )
610
 
611
  @abstractmethod
612
  def _step_impl(
src/core/openenv/core/env_server/serialization.py CHANGED
@@ -14,14 +14,28 @@ HTTP server and web interface implementations.
14
 
15
  from typing import Any, Dict, Type
16
 
 
17
  from .types import Action, Observation
18
 
 
 
 
 
 
 
 
 
 
19
 
20
  def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
21
  """
22
  Convert JSON dict to Action instance using Pydantic validation.
23
 
24
- This is a basic deserialization that works for most environments.
 
 
 
 
25
  For special cases (e.g., tensor fields, custom type conversions),
26
  use deserialize_action_with_preprocessing().
27
 
@@ -38,6 +52,17 @@ def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) ->
38
  Note:
39
  This uses Pydantic's model_validate() for automatic validation.
40
  """
 
 
 
 
 
 
 
 
 
 
 
41
  return action_cls.model_validate(action_data)
42
 
43
 
@@ -62,6 +87,15 @@ def deserialize_action_with_preprocessing(
62
  Raises:
63
  ValidationError: If action_data is invalid for the action class
64
  """
 
 
 
 
 
 
 
 
 
65
  processed_data = {}
66
 
67
  for key, value in action_data.items():
 
14
 
15
  from typing import Any, Dict, Type
16
 
17
+ from .mcp_types import CallToolAction, ListToolsAction
18
  from .types import Action, Observation
19
 
20
+ # MCP action types keyed by their "type" discriminator value.
21
+ # These are checked before the environment's own action_cls so that
22
+ # ListToolsAction / CallToolAction payloads are never rejected by an
23
+ # unrelated Pydantic model.
24
+ _MCP_ACTION_TYPES: Dict[str, Type[Action]] = {
25
+ "list_tools": ListToolsAction,
26
+ "call_tool": CallToolAction,
27
+ }
28
+
29
 
30
  def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
31
  """
32
  Convert JSON dict to Action instance using Pydantic validation.
33
 
34
+ MCP action types (``list_tools``, ``call_tool``) are recognised
35
+ automatically via the ``"type"`` discriminator field, regardless of
36
+ the environment's configured ``action_cls``. All other payloads
37
+ fall through to ``action_cls.model_validate()``.
38
+
39
  For special cases (e.g., tensor fields, custom type conversions),
40
  use deserialize_action_with_preprocessing().
41
 
 
52
  Note:
53
  This uses Pydantic's model_validate() for automatic validation.
54
  """
55
+ # Route MCP action types before falling through to the env action_cls.
56
+ # Only intercept when action_cls is the generic Action base or itself an
57
+ # MCP type (i.e. the server hosts an MCP environment). This avoids
58
+ # silently bypassing env-specific validation for non-MCP environments
59
+ # that happen to use "call_tool" / "list_tools" as a type discriminator.
60
+ action_type = action_data.get("type")
61
+ if action_type in _MCP_ACTION_TYPES:
62
+ mcp_cls = _MCP_ACTION_TYPES[action_type]
63
+ if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values():
64
+ return mcp_cls.model_validate(action_data)
65
+
66
  return action_cls.model_validate(action_data)
67
 
68
 
 
87
  Raises:
88
  ValidationError: If action_data is invalid for the action class
89
  """
90
+ # Route MCP action types before preprocessing (they don't need it).
91
+ # Same guard as deserialize_action: only intercept when action_cls is
92
+ # the generic Action base or itself an MCP type.
93
+ action_type = action_data.get("type")
94
+ if action_type in _MCP_ACTION_TYPES:
95
+ mcp_cls = _MCP_ACTION_TYPES[action_type]
96
+ if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values():
97
+ return mcp_cls.model_validate(action_data)
98
+
99
  processed_data = {}
100
 
101
  for key, value in action_data.items():
src/core/openenv/core/env_server/web_interface.py CHANGED
@@ -15,13 +15,15 @@ option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var.
15
  from __future__ import annotations
16
 
17
  import asyncio
 
18
  import json
19
  from concurrent.futures import ThreadPoolExecutor
20
  from datetime import datetime
21
  from typing import Any, Callable, Dict, List, Optional, Type
22
 
23
  import gradio as gr
24
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 
25
  from pydantic import BaseModel, ConfigDict, Field
26
 
27
  from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
@@ -269,6 +271,28 @@ class WebInterfaceManager:
269
  # Thread pool for running sync code (e.g., Playwright sync API) in async context
270
  self._executor = ThreadPoolExecutor(max_workers=1)
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
273
  """Run a synchronous function in the thread pool executor.
274
 
@@ -317,11 +341,24 @@ class WebInterfaceManager:
317
  for client in disconnected_clients:
318
  self.connected_clients.remove(client)
319
 
320
- async def reset_environment(self) -> Dict[str, Any]:
 
 
321
  """Reset the environment and update state."""
322
- # Run sync reset in thread pool to avoid blocking event loop
323
- # and to support environments using sync libraries (e.g., Playwright)
324
- observation: Observation = await self._run_sync_in_thread_pool(self.env.reset)
 
 
 
 
 
 
 
 
 
 
 
325
  state: State = self.env.state
326
 
327
  # Serialize observation once using shared utility
@@ -428,6 +465,16 @@ def create_web_interface_app(
428
  web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
429
 
430
  # Web API routes first (so they take precedence over Gradio mount at /web)
 
 
 
 
 
 
 
 
 
 
431
  @app.get("/web/metadata")
432
  async def web_metadata():
433
  """Get environment metadata."""
@@ -449,9 +496,9 @@ def create_web_interface_app(
449
  await web_manager.disconnect_websocket(websocket)
450
 
451
  @app.post("/web/reset")
452
- async def web_reset():
453
  """Reset endpoint for web interface."""
454
- return await web_manager.reset_environment()
455
 
456
  @app.post("/web/step")
457
  async def web_step(request: Dict[str, Any]):
@@ -475,7 +522,13 @@ def create_web_interface_app(
475
  @app.get("/web/state")
476
  async def web_state():
477
  """State endpoint for web interface."""
478
- return web_manager.get_state()
 
 
 
 
 
 
479
 
480
  action_fields = _extract_action_fields(action_cls)
481
  is_chat_env = _is_chat_env(action_cls)
@@ -505,7 +558,7 @@ def create_web_interface_app(
505
  )
506
  gradio_blocks = gr.TabbedInterface(
507
  [default_blocks, custom_blocks],
508
- tab_names=["Playground", "Visualization"],
509
  title=get_gradio_display_title(metadata),
510
  )
511
  else:
 
15
  from __future__ import annotations
16
 
17
  import asyncio
18
+ import inspect
19
  import json
20
  from concurrent.futures import ThreadPoolExecutor
21
  from datetime import datetime
22
  from typing import Any, Callable, Dict, List, Optional, Type
23
 
24
  import gradio as gr
25
+ from fastapi import Body, FastAPI, HTTPException, status, WebSocket, WebSocketDisconnect
26
+ from fastapi.responses import RedirectResponse
27
  from pydantic import BaseModel, ConfigDict, Field
28
 
29
  from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
 
271
  # Thread pool for running sync code (e.g., Playwright sync API) in async context
272
  self._executor = ThreadPoolExecutor(max_workers=1)
273
 
274
+ @staticmethod
275
+ def _get_valid_kwargs(
276
+ sig: inspect.Signature,
277
+ kwargs: Dict[str, Any],
278
+ skip_params: Optional[set[str]] = None,
279
+ ) -> Dict[str, Any]:
280
+ """Filter kwargs to only those accepted by the target function."""
281
+ skip_params = skip_params or set()
282
+ valid_kwargs: Dict[str, Any] = {}
283
+ has_var_kwargs = any(
284
+ param.kind == inspect.Parameter.VAR_KEYWORD
285
+ for param in sig.parameters.values()
286
+ )
287
+
288
+ for key, value in kwargs.items():
289
+ if key in skip_params:
290
+ continue
291
+ if key in sig.parameters or has_var_kwargs:
292
+ valid_kwargs[key] = value
293
+
294
+ return valid_kwargs
295
+
296
  async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
297
  """Run a synchronous function in the thread pool executor.
298
 
 
341
  for client in disconnected_clients:
342
  self.connected_clients.remove(client)
343
 
344
+ async def reset_environment(
345
+ self, reset_kwargs: Optional[Dict[str, Any]] = None
346
+ ) -> Dict[str, Any]:
347
  """Reset the environment and update state."""
348
+ reset_kwargs = reset_kwargs or {}
349
+
350
+ is_async = self.env.reset_async.__func__ is not Environment.reset_async
351
+ sig = inspect.signature(self.env.reset_async if is_async else self.env.reset)
352
+ valid_kwargs = self._get_valid_kwargs(sig, reset_kwargs)
353
+
354
+ if is_async:
355
+ observation = await self.env.reset_async(**valid_kwargs)
356
+ else:
357
+ # Run sync reset in thread pool to avoid blocking event loop
358
+ # and to support environments using sync libraries (e.g., Playwright)
359
+ observation = await self._run_sync_in_thread_pool(
360
+ self.env.reset, **valid_kwargs
361
+ )
362
  state: State = self.env.state
363
 
364
  # Serialize observation once using shared utility
 
465
  web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
466
 
467
  # Web API routes first (so they take precedence over Gradio mount at /web)
468
+ @app.get("/", include_in_schema=False)
469
+ async def web_root():
470
+ """Redirect the app root to the Gradio interface."""
471
+ return RedirectResponse(url="/web/")
472
+
473
+ @app.get("/web", include_in_schema=False)
474
+ async def web_root_no_slash():
475
+ """Redirect /web to /web/ for mounted Gradio deployments behind proxies."""
476
+ return RedirectResponse(url="/web/")
477
+
478
  @app.get("/web/metadata")
479
  async def web_metadata():
480
  """Get environment metadata."""
 
496
  await web_manager.disconnect_websocket(websocket)
497
 
498
  @app.post("/web/reset")
499
+ async def web_reset(request: Optional[Dict[str, Any]] = Body(default=None)):
500
  """Reset endpoint for web interface."""
501
+ return await web_manager.reset_environment(request)
502
 
503
  @app.post("/web/step")
504
  async def web_step(request: Dict[str, Any]):
 
522
  @app.get("/web/state")
523
  async def web_state():
524
  """State endpoint for web interface."""
525
+ try:
526
+ return web_manager.get_state()
527
+ except RuntimeError as exc:
528
+ raise HTTPException(
529
+ status_code=status.HTTP_409_CONFLICT,
530
+ detail=str(exc),
531
+ ) from exc
532
 
533
  action_fields = _extract_action_fields(action_cls)
534
  is_chat_env = _is_chat_env(action_cls)
 
558
  )
559
  gradio_blocks = gr.TabbedInterface(
560
  [default_blocks, custom_blocks],
561
+ tab_names=["Playground", "Custom"],
562
  title=get_gradio_display_title(metadata),
563
  )
564
  else:
src/core/openenv/core/mcp_client.py CHANGED
@@ -52,6 +52,7 @@ Example (sync wrapper):
52
  ... result = env.call_tool("echo_message", message="Hello!")
53
  """
54
 
 
55
  from typing import Any, Dict, List, Optional
56
 
57
  from .client_types import StepResult
@@ -118,6 +119,66 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
118
  )
119
  self._tools_cache: Optional[List[Tool]] = None
120
  self.use_production_mode = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  async def list_tools(self, use_cache: bool = True) -> List[Tool]:
123
  """
@@ -138,26 +199,18 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
138
  if use_cache and self._tools_cache is not None:
139
  return self._tools_cache
140
 
141
- # Use production mode HTTP endpoint if enabled
142
- if self.use_production_mode:
143
- import requests
144
-
145
- # Convert ws:// URL to http:// URL
146
- url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://")
147
- # Remove /ws suffix if present and add /mcp
148
- url = url.rstrip("/ws").rstrip("/") + "/mcp"
149
-
150
  try:
151
- response = requests.post(
152
- url,
153
- json={
154
- "jsonrpc": "2.0",
155
- "method": "tools/list",
156
- "params": {},
157
- "id": 1,
158
- },
159
  )
160
- data = response.json()
 
 
161
  if "result" in data and "tools" in data["result"]:
162
  tools = [
163
  Tool(
@@ -177,7 +230,12 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
177
  return []
178
 
179
  result = await self.step(ListToolsAction())
180
- self._tools_cache = result.observation.tools
 
 
 
 
 
181
  return self._tools_cache
182
 
183
  def _step_payload(self, action: Any) -> Dict[str, Any]:
@@ -251,6 +309,35 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
251
  step_count=payload.get("step_count", 0),
252
  )
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  class MCPToolClient(MCPClientBase):
256
  """
@@ -316,6 +403,26 @@ class MCPToolClient(MCPClientBase):
316
  >>> result = await env.call_tool("greet", name="Claude")
317
  >>> print(result) # "Hello, Claude!"
318
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  action = CallToolAction(tool_name=name, arguments=kwargs)
320
  result = await self.step(action)
321
  obs = result.observation
 
52
  ... result = env.call_tool("echo_message", message="Hello!")
53
  """
54
 
55
+ import asyncio
56
  from typing import Any, Dict, List, Optional
57
 
58
  from .client_types import StepResult
 
119
  )
120
  self._tools_cache: Optional[List[Tool]] = None
121
  self.use_production_mode = False
122
+ self._production_session_id: Optional[str] = None
123
+ self._production_session_lock = asyncio.Lock()
124
+ self._jsonrpc_request_id = 0
125
+ self._http_client: Optional[Any] = None # lazily-created httpx.AsyncClient
126
+
127
+ def _next_request_id(self) -> int:
128
+ """Generate a monotonically increasing JSON-RPC request id."""
129
+ self._jsonrpc_request_id += 1
130
+ return self._jsonrpc_request_id
131
+
132
+ def _production_mcp_url(self) -> str:
133
+ """Build HTTP MCP endpoint URL from the client's websocket URL."""
134
+ url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://")
135
+ if url.endswith("/ws"):
136
+ url = url[: -len("/ws")]
137
+ return url.rstrip("/") + "/mcp"
138
+
139
+ async def _get_http_client(self) -> Any:
140
+ """Return a shared httpx.AsyncClient, creating one lazily."""
141
+ if self._http_client is None:
142
+ import httpx
143
+
144
+ self._http_client = httpx.AsyncClient()
145
+ return self._http_client
146
+
147
+ async def _production_mcp_request(
148
+ self, method: str, params: Optional[Dict[str, Any]] = None
149
+ ) -> Dict[str, Any]:
150
+ """Send a JSON-RPC request to HTTP /mcp and return parsed JSON response."""
151
+ client = await self._get_http_client()
152
+ response = await client.post(
153
+ self._production_mcp_url(),
154
+ json={
155
+ "jsonrpc": "2.0",
156
+ "method": method,
157
+ "params": params or {},
158
+ "id": self._next_request_id(),
159
+ },
160
+ timeout=self._message_timeout,
161
+ )
162
+ response.raise_for_status()
163
+ return response.json()
164
+
165
+ async def _ensure_production_session(self) -> str:
166
+ """Create and cache a persistent HTTP MCP session id if needed."""
167
+ async with self._production_session_lock:
168
+ if self._production_session_id is not None:
169
+ return self._production_session_id
170
+
171
+ data = await self._production_mcp_request("openenv/session/create")
172
+ if "error" in data:
173
+ message = data.get("error", {}).get("message", "unknown error")
174
+ raise RuntimeError(f"Failed to create MCP session: {message}")
175
+
176
+ session_id = data.get("result", {}).get("session_id")
177
+ if not session_id:
178
+ raise RuntimeError("Failed to create MCP session: missing session_id")
179
+
180
+ self._production_session_id = session_id
181
+ return session_id
182
 
183
  async def list_tools(self, use_cache: bool = True) -> List[Tool]:
184
  """
 
199
  if use_cache and self._tools_cache is not None:
200
  return self._tools_cache
201
 
202
+ # Use production mode HTTP endpoint if enabled.
203
+ # Some tests instantiate with __new__ and skip __init__, so default missing flag to False.
204
+ if getattr(self, "use_production_mode", False):
 
 
 
 
 
 
205
  try:
206
+ session_id = await self._ensure_production_session()
207
+ data = await self._production_mcp_request(
208
+ "tools/list",
209
+ {"session_id": session_id},
 
 
 
 
210
  )
211
+ if "error" in data:
212
+ message = data.get("error", {}).get("message", "unknown error")
213
+ raise RuntimeError(f"list_tools failed: {message}")
214
  if "result" in data and "tools" in data["result"]:
215
  tools = [
216
  Tool(
 
230
  return []
231
 
232
  result = await self.step(ListToolsAction())
233
+ if isinstance(result.observation, ListToolsObservation):
234
+ self._tools_cache = result.observation.tools
235
+ return self._tools_cache
236
+
237
+ # Unexpected observation type; keep API stable with an empty tool list.
238
+ self._tools_cache = []
239
  return self._tools_cache
240
 
241
  def _step_payload(self, action: Any) -> Dict[str, Any]:
 
309
  step_count=payload.get("step_count", 0),
310
  )
311
 
312
+ async def close(self) -> None:
313
+ """
314
+ Close client resources.
315
+
316
+ In production MCP mode, this also closes the server-side persistent
317
+ MCP session (best effort) before closing websocket/provider resources.
318
+ """
319
+ if self._production_session_id is not None:
320
+ try:
321
+ await self._production_mcp_request(
322
+ "openenv/session/close",
323
+ {"session_id": self._production_session_id},
324
+ )
325
+ except Exception:
326
+ # Best effort cleanup - do not mask normal close behavior
327
+ pass
328
+ finally:
329
+ self._production_session_id = None
330
+
331
+ if self._http_client is not None:
332
+ try:
333
+ await self._http_client.aclose()
334
+ except Exception:
335
+ pass
336
+ finally:
337
+ self._http_client = None
338
+
339
+ await super().close()
340
+
341
 
342
  class MCPToolClient(MCPClientBase):
343
  """
 
403
  >>> result = await env.call_tool("greet", name="Claude")
404
  >>> print(result) # "Hello, Claude!"
405
  """
406
+ if getattr(self, "use_production_mode", False):
407
+ session_id = await self._ensure_production_session()
408
+ data = await self._production_mcp_request(
409
+ "tools/call",
410
+ {
411
+ "name": name,
412
+ "arguments": kwargs,
413
+ "session_id": session_id,
414
+ },
415
+ )
416
+
417
+ if "error" in data:
418
+ message = data.get("error", {}).get("message", "unknown error")
419
+ raise RuntimeError(f"Tool '{name}' failed: {message}")
420
+
421
+ result = data.get("result")
422
+ if isinstance(result, dict) and "data" in result:
423
+ return result["data"]
424
+ return result
425
+
426
  action = CallToolAction(tool_name=name, arguments=kwargs)
427
  result = await self.step(action)
428
  obs = result.observation
src/openenv/__init__.py CHANGED
@@ -14,10 +14,18 @@ __all__ = [
14
  "SyncEnvClient",
15
  ]
16
 
17
- try:
18
- __version__ = metadata.version("openenv") # type: ignore[arg-type]
19
- except metadata.PackageNotFoundError: # pragma: no cover - local dev
20
- __version__ = "0.0.0"
 
 
 
 
 
 
 
 
21
 
22
 
23
  _LAZY_MODULES = {
 
14
  "SyncEnvClient",
15
  ]
16
 
17
+
18
+ def _load_package_version() -> str:
19
+ """Resolve the installed distribution version for the OpenEnv package."""
20
+ for distribution_name in ("openenv-core", "openenv"):
21
+ try:
22
+ return metadata.version(distribution_name)
23
+ except metadata.PackageNotFoundError:
24
+ continue
25
+ return "0.0.0"
26
+
27
+
28
+ __version__ = _load_package_version()
29
 
30
 
31
  _LAZY_MODULES = {
src/openenv/cli/templates/openenv_env/pyproject.toml CHANGED
@@ -17,7 +17,7 @@ dependencies = [
17
  # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
  # install from github
19
  # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
- "openenv-core[core]>=0.2.1",
21
  # Environment-specific dependencies
22
  # Add all dependencies needed for your environment here
23
  # Examples:
 
17
  # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
  # install from github
19
  # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ "openenv-core[core]>=0.2.2",
21
  # Environment-specific dependencies
22
  # Add all dependencies needed for your environment here
23
  # Examples:
src/openenv/core/env_server/http_server.py CHANGED
@@ -16,11 +16,15 @@ from __future__ import annotations
16
  import asyncio
17
  import inspect
18
  import json
 
19
  import os
20
  import time
21
  import uuid
22
  from concurrent.futures import ThreadPoolExecutor
23
- from typing import Any, Callable, Dict, Optional, Type
 
 
 
24
 
25
  from fastapi import (
26
  Body,
@@ -204,8 +208,9 @@ class HTTPEnvServer:
204
  self.observation_cls = observation_cls
205
 
206
  # Session management for WebSocket connections
207
- self._sessions: Dict[str, Environment] = {}
208
  self._session_executors: Dict[str, ThreadPoolExecutor] = {}
 
209
  self._session_info: Dict[str, SessionInfo] = {}
210
  self._session_lock = asyncio.Lock()
211
 
@@ -213,6 +218,14 @@ class HTTPEnvServer:
213
  # This is needed for environments using sync libraries (e.g., Playwright)
214
  self._executor = ThreadPoolExecutor(max_workers=32)
215
 
 
 
 
 
 
 
 
 
216
  def _validate_concurrency_safety(self) -> None:
217
  """
218
  Validate that the environment supports the configured concurrency level.
@@ -321,12 +334,37 @@ class HTTPEnvServer:
321
  )
322
  raise EnvironmentFactoryError(factory_name) from e
323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  async with self._session_lock:
325
  self._sessions[session_id] = env
 
 
326
  self._session_info[session_id] = SessionInfo(
327
  session_id=session_id,
328
  created_at=current_time,
329
- last_activity_at=current_time,
330
  step_count=0,
331
  environment_type=type(env).__name__,
332
  )
@@ -343,8 +381,27 @@ class HTTPEnvServer:
343
  async with self._session_lock:
344
  env = self._sessions.pop(session_id, None)
345
  executor = self._session_executors.pop(session_id, None)
 
346
  self._session_info.pop(session_id, None)
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  # Run close() in the same executor where the env was created
349
  # This is required for thread-sensitive libraries like Playwright/greenlet
350
  if env is not None:
@@ -383,6 +440,51 @@ class HTTPEnvServer:
383
  if increment_step:
384
  self._session_info[session_id].step_count += 1
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
387
  """
388
  Get information about a specific session.
@@ -458,6 +560,20 @@ class HTTPEnvServer:
458
  f"Invalid mode: '{mode}'. Must be one of: {valid_modes}"
459
  )
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  # Helper function to handle reset endpoint
462
  async def reset_handler(
463
  request: ResetRequest = Body(default_factory=ResetRequest),
@@ -526,53 +642,214 @@ class HTTPEnvServer:
526
 
527
  # Helper function to handle MCP endpoint
528
  async def mcp_handler(
529
- request: JsonRpcRequest, session_env: Optional[Environment] = None
 
 
530
  ) -> JsonRpcResponse:
531
  """
532
  Handle MCP JSON-RPC requests.
533
 
534
- Supports tools/list and tools/call methods in JSON-RPC 2.0 format.
 
 
 
535
  """
536
  method = request.method
537
  request_id = request.id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
 
539
  # Use provided session environment or create temporary one
540
  if session_env is not None:
541
  _env = session_env
542
  should_close = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  else:
544
  _env = self._env_factory()
545
  should_close = True
546
  try:
 
 
 
 
547
  if method == McpMethod.TOOLS_LIST:
548
  # Check if environment is MCP-enabled
549
- if not hasattr(_env, "mcp_client"):
550
  return JsonRpcResponse.error_response(
551
  JsonRpcErrorCode.INTERNAL_ERROR,
552
  "Environment does not support MCP",
553
  request_id=request_id,
554
  )
555
 
556
- # Use async context manager for MCP client
557
- async with _env.mcp_client:
558
- tools = await _env.mcp_client.list_tools()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
 
560
- return JsonRpcResponse.success(
561
- result={
562
- "tools": [
563
- t.model_dump() if hasattr(t, "model_dump") else dict(t)
564
- for t in tools
565
- ]
566
- },
 
 
 
 
 
 
 
 
 
 
 
567
  request_id=request_id,
568
  )
569
 
570
  elif method == McpMethod.TOOLS_CALL:
571
- params = request.params
572
  tool_name = params.get("name")
573
  arguments = params.get("arguments", {})
574
 
575
- if not hasattr(_env, "mcp_client"):
576
  return JsonRpcResponse.error_response(
577
  JsonRpcErrorCode.INTERNAL_ERROR,
578
  "Environment does not support MCP",
@@ -581,15 +858,51 @@ class HTTPEnvServer:
581
 
582
  if not tool_name:
583
  return JsonRpcResponse.error_response(
584
- JsonRpcErrorCode.INVALID_REQUEST,
585
  "Missing 'name' in params",
586
  request_id=request_id,
587
  )
588
 
589
- # Use async context manager for MCP client
590
- async with _env.mcp_client:
591
- result = await _env.mcp_client.call_tool(
592
- name=tool_name, arguments=arguments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  )
594
 
595
  # Ensure result is JSON serializable
@@ -614,6 +927,11 @@ class HTTPEnvServer:
614
  request_id=request_id,
615
  )
616
  finally:
 
 
 
 
 
617
  if should_close:
618
  _env.close()
619
 
@@ -637,42 +955,59 @@ class HTTPEnvServer:
637
  try:
638
  # Create session with dedicated environment
639
  session_id, session_env = await self._create_session()
 
 
 
 
640
 
641
- while True:
642
- # Receive message from client
643
- raw_message = await websocket.receive_text()
644
-
645
- try:
646
- jsonrpc_dict = json.loads(raw_message)
647
- jsonrpc_request = JsonRpcRequest(**jsonrpc_dict)
648
- except json.JSONDecodeError as e:
649
- error_resp = JsonRpcResponse.error_response(
650
- JsonRpcErrorCode.PARSE_ERROR,
651
- f"Parse error: {e}",
652
- )
653
- await websocket.send_text(error_resp.model_dump_json())
654
- continue
655
- except ValidationError as e:
656
- error_resp = JsonRpcResponse.error_response(
657
- JsonRpcErrorCode.INVALID_REQUEST,
658
- f"Invalid request: {e}",
659
- )
660
- await websocket.send_text(error_resp.model_dump_json())
661
- continue
662
 
663
- try:
664
- # Call mcp_handler with session environment
665
- response = await mcp_handler(
666
- jsonrpc_request, session_env=session_env
 
667
  )
668
- await websocket.send_text(response.model_dump_json())
669
- except Exception as e:
670
- error_resp = JsonRpcResponse.error_response(
671
- JsonRpcErrorCode.INTERNAL_ERROR,
672
- str(e),
673
- request_id=jsonrpc_request.id,
674
- )
675
- await websocket.send_text(error_resp.model_dump_json())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
 
677
  except WebSocketDisconnect:
678
  pass
@@ -931,120 +1266,8 @@ all schema information needed to interact with the environment.
931
  JsonRpcErrorCode.PARSE_ERROR
932
  ).model_dump()
933
 
934
- method = request.method
935
- params = request.params
936
- request_id = request.id
937
-
938
- # Create a temporary environment for MCP access
939
- _env = self._env_factory()
940
-
941
- try:
942
- # Check if environment supports MCP
943
- if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"):
944
- return JsonRpcResponse.error_response(
945
- JsonRpcErrorCode.INTERNAL_ERROR,
946
- "Environment does not support MCP",
947
- request_id=request_id,
948
- ).model_dump()
949
-
950
- if method == McpMethod.TOOLS_LIST:
951
- # List tools from MCP server
952
- if hasattr(_env, "mcp_client") and _env.mcp_client:
953
- async with _env.mcp_client:
954
- tools = await _env.mcp_client.list_tools()
955
- return JsonRpcResponse.success(
956
- result={
957
- "tools": [
958
- t.model_dump()
959
- if hasattr(t, "model_dump")
960
- else dict(t)
961
- for t in tools
962
- ]
963
- },
964
- request_id=request_id,
965
- ).model_dump()
966
- elif hasattr(_env, "mcp_server") and _env.mcp_server:
967
- # Use server directly
968
- tools = []
969
- for tool_name, tool in get_server_tools(
970
- _env.mcp_server
971
- ).items():
972
- tool_dict = {
973
- "name": tool.name,
974
- "description": tool.description or "",
975
- "inputSchema": tool.parameters or {},
976
- }
977
- tools.append(tool_dict)
978
- return JsonRpcResponse.success(
979
- result={"tools": tools},
980
- request_id=request_id,
981
- ).model_dump()
982
- else:
983
- return JsonRpcResponse.error_response(
984
- JsonRpcErrorCode.INTERNAL_ERROR,
985
- "MCP server not available",
986
- request_id=request_id,
987
- ).model_dump()
988
-
989
- elif method == McpMethod.TOOLS_CALL:
990
- tool_name = params.get("name")
991
- arguments = params.get("arguments", {})
992
-
993
- if not tool_name:
994
- return JsonRpcResponse.error_response(
995
- JsonRpcErrorCode.INVALID_PARAMS,
996
- "Invalid params - 'name' is required",
997
- request_id=request_id,
998
- ).model_dump()
999
-
1000
- # Call tool via MCP
1001
- if hasattr(_env, "mcp_client") and _env.mcp_client:
1002
- async with _env.mcp_client:
1003
- result = await _env.mcp_client.call_tool(
1004
- name=tool_name, arguments=arguments
1005
- )
1006
- elif hasattr(_env, "mcp_server") and _env.mcp_server:
1007
- # Call tool directly on FastMCP server
1008
- server_tools = get_server_tools(_env.mcp_server)
1009
- if tool_name in server_tools:
1010
- tool = server_tools[tool_name]
1011
- result = tool.fn(**arguments)
1012
- else:
1013
- return JsonRpcResponse.error_response(
1014
- JsonRpcErrorCode.INVALID_PARAMS,
1015
- f"Tool not found: {tool_name}",
1016
- request_id=request_id,
1017
- ).model_dump()
1018
- else:
1019
- return JsonRpcResponse.error_response(
1020
- JsonRpcErrorCode.INTERNAL_ERROR,
1021
- "MCP server not available",
1022
- request_id=request_id,
1023
- ).model_dump()
1024
-
1025
- # Make result JSON serializable
1026
- serializable_result = _make_json_serializable(result)
1027
-
1028
- return JsonRpcResponse.success(
1029
- result=serializable_result,
1030
- request_id=request_id,
1031
- ).model_dump()
1032
-
1033
- else:
1034
- return JsonRpcResponse.error_response(
1035
- JsonRpcErrorCode.METHOD_NOT_FOUND,
1036
- f"Method not found: {method}",
1037
- request_id=request_id,
1038
- ).model_dump()
1039
-
1040
- except Exception as e:
1041
- return JsonRpcResponse.error_response(
1042
- JsonRpcErrorCode.INTERNAL_ERROR,
1043
- str(e),
1044
- request_id=request_id,
1045
- ).model_dump()
1046
- finally:
1047
- _env.close()
1048
 
1049
  # Register WebSocket endpoint for persistent sessions
1050
  @app.websocket("/ws")
@@ -1066,135 +1289,167 @@ all schema information needed to interact with the environment.
1066
  try:
1067
  # Create session with dedicated environment
1068
  session_id, session_env = await self._create_session()
 
 
 
 
1069
 
1070
- while True:
1071
- # Receive message from client
1072
- raw_message = await websocket.receive_text()
1073
 
1074
- try:
1075
- message_dict = json.loads(raw_message)
1076
- except json.JSONDecodeError as e:
1077
- error_resp = WSErrorResponse(
1078
- data={
1079
- "message": f"Invalid JSON: {e}",
1080
- "code": WSErrorCode.INVALID_JSON,
1081
- }
1082
  )
1083
- await websocket.send_text(error_resp.model_dump_json())
1084
- continue
1085
-
1086
- msg_type = message_dict.get("type", "")
 
 
 
 
 
 
 
 
 
 
 
 
 
1087
 
1088
- try:
1089
- match msg_type:
1090
- case "reset":
1091
- msg = WSResetMessage(**message_dict)
1092
 
1093
- is_async = (
1094
- session_env.reset_async.__func__
1095
- is not Environment.reset_async
1096
- )
1097
 
1098
- if is_async:
1099
- sig = inspect.signature(session_env.reset_async)
1100
- valid_kwargs = self._get_valid_kwargs(sig, msg.data)
1101
- observation = await session_env.reset_async(
1102
- **valid_kwargs
1103
  )
1104
- else:
1105
- sig = inspect.signature(session_env.reset)
1106
- valid_kwargs = self._get_valid_kwargs(sig, msg.data)
1107
- observation = await self._run_in_session_executor(
1108
- session_id, session_env.reset, **valid_kwargs
1109
- )
1110
-
1111
- self._update_session_activity(session_id)
1112
-
1113
- response = WSObservationResponse(
1114
- data=serialize_observation(observation),
1115
- )
1116
 
1117
- case "step":
1118
- msg = WSStepMessage(**message_dict)
1119
- action = deserialize_action(msg.data, self.action_cls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1120
 
1121
- is_async = (
1122
- session_env.step_async.__func__
1123
- is not Environment.step_async
1124
- )
 
1125
 
1126
- if is_async:
1127
- observation = await session_env.step_async(action)
1128
- else:
1129
- observation = await self._run_in_session_executor(
1130
- session_id, session_env.step, action
1131
  )
1132
 
1133
- self._update_session_activity(
1134
- session_id, increment_step=True
1135
- )
 
 
 
 
 
 
 
 
 
 
 
1136
 
1137
- response = WSObservationResponse(
1138
- data=serialize_observation(observation)
1139
- )
1140
 
1141
- case "state":
1142
- msg = WSStateMessage(**message_dict)
1143
- state = session_env.state
1144
- if hasattr(state, "model_dump"):
1145
- state_data = state.model_dump()
1146
- else:
1147
- state_data = dict(state) if state else {}
1148
-
1149
- response = WSStateResponse(data=state_data)
1150
-
1151
- case "close":
1152
- msg = WSCloseMessage(**message_dict)
1153
- break
1154
-
1155
- case "mcp":
1156
- msg = WSMCPMessage(**message_dict)
1157
- try:
1158
- rpc_request = JsonRpcRequest(**msg.data)
1159
- except (ValidationError, Exception) as e:
1160
- rpc_response = JsonRpcResponse.error_response(
1161
- JsonRpcErrorCode.INVALID_REQUEST,
1162
- f"Invalid request: {e}",
 
 
 
 
 
 
 
 
 
1163
  )
1164
- else:
1165
- rpc_response = await mcp_handler(
1166
- rpc_request,
1167
- session_env=session_env,
 
 
 
1168
  )
1169
- response = WSMCPResponse(data=rpc_response.model_dump())
1170
-
1171
- case _:
1172
- response = WSErrorResponse(
1173
- data={
1174
- "message": f"Unknown message type: {msg_type}",
1175
- "code": WSErrorCode.UNKNOWN_TYPE,
1176
- }
1177
- )
1178
 
1179
- await websocket.send_text(response.model_dump_json())
1180
 
1181
- except ValidationError as e:
1182
- error_resp = WSErrorResponse(
1183
- data={
1184
- "message": "Invalid message",
1185
- "code": WSErrorCode.VALIDATION_ERROR,
1186
- "errors": e.errors(),
1187
- }
1188
- )
1189
- await websocket.send_text(error_resp.model_dump_json())
1190
- except Exception as e:
1191
- error_resp = WSErrorResponse(
1192
- data={
1193
- "message": str(e),
1194
- "code": WSErrorCode.EXECUTION_ERROR,
1195
- }
1196
- )
1197
- await websocket.send_text(error_resp.model_dump_json())
1198
 
1199
  except WebSocketDisconnect:
1200
  pass
@@ -1276,7 +1531,7 @@ def create_app(
1276
  from .web_interface import create_web_interface_app
1277
 
1278
  return create_web_interface_app(
1279
- env,
1280
  action_cls,
1281
  observation_cls,
1282
  env_name,
 
16
  import asyncio
17
  import inspect
18
  import json
19
+ import logging
20
  import os
21
  import time
22
  import uuid
23
  from concurrent.futures import ThreadPoolExecutor
24
+ from contextlib import AsyncExitStack
25
+ from typing import Any, AsyncContextManager, Callable, cast, Dict, Optional, Type
26
+
27
+ _MISSING = object()
28
 
29
  from fastapi import (
30
  Body,
 
208
  self.observation_cls = observation_cls
209
 
210
  # Session management for WebSocket connections
211
+ self._sessions: Dict[str, Optional[Environment]] = {}
212
  self._session_executors: Dict[str, ThreadPoolExecutor] = {}
213
+ self._session_stacks: Dict[str, AsyncExitStack] = {}
214
  self._session_info: Dict[str, SessionInfo] = {}
215
  self._session_lock = asyncio.Lock()
216
 
 
218
  # This is needed for environments using sync libraries (e.g., Playwright)
219
  self._executor = ThreadPoolExecutor(max_workers=32)
220
 
221
+ # Idle session reaper configuration.
222
+ # Timeout is taken from ConcurrencyConfig.session_timeout;
223
+ # None means no timeout (default — reaper is a no-op).
224
+ self._session_idle_timeout_s: Optional[float] = (
225
+ self._concurrency_config.session_timeout
226
+ )
227
+ self._reaper_task: Optional[asyncio.Task[None]] = None
228
+
229
  def _validate_concurrency_safety(self) -> None:
230
  """
231
  Validate that the environment supports the configured concurrency level.
 
334
  )
335
  raise EnvironmentFactoryError(factory_name) from e
336
 
337
+ # Hold the MCP session open for the lifetime of this session,
338
+ # matching the WebSocket path's AsyncExitStack pattern. This
339
+ # prevents per-request MCP transport teardown/reconnection and
340
+ # preserves FastMCP session state (ctx.set_state / ctx.get_state)
341
+ # across HTTP calls within the same OpenEnv session.
342
+ stack = AsyncExitStack()
343
+ try:
344
+ mcp_session_factory = getattr(env, "mcp_session", None)
345
+ if callable(mcp_session_factory):
346
+ mcp_session_cm = cast(AsyncContextManager[Any], mcp_session_factory())
347
+ await stack.enter_async_context(mcp_session_cm)
348
+ except Exception:
349
+ # MCP transport failed to start — clean up the reserved slot,
350
+ # the env, and the executor so they don't leak permanently
351
+ # against _max_concurrent_envs.
352
+ await stack.aclose() # best-effort
353
+ async with self._session_lock:
354
+ self._sessions.pop(session_id, None)
355
+ self._session_executors.pop(session_id, None)
356
+ self._session_info.pop(session_id, None)
357
+ await self._cleanup_session_resources(env, executor)
358
+ raise
359
+
360
  async with self._session_lock:
361
  self._sessions[session_id] = env
362
+ self._session_stacks[session_id] = stack
363
+ now = time.time()
364
  self._session_info[session_id] = SessionInfo(
365
  session_id=session_id,
366
  created_at=current_time,
367
+ last_activity_at=now,
368
  step_count=0,
369
  environment_type=type(env).__name__,
370
  )
 
381
  async with self._session_lock:
382
  env = self._sessions.pop(session_id, None)
383
  executor = self._session_executors.pop(session_id, None)
384
+ stack = self._session_stacks.pop(session_id, None)
385
  self._session_info.pop(session_id, None)
386
 
387
+ await self._cleanup_session_resources(env, executor, stack)
388
+
389
+ async def _cleanup_session_resources(
390
+ self,
391
+ env: Optional[Environment],
392
+ executor: Optional[ThreadPoolExecutor],
393
+ stack: Optional[AsyncExitStack] = None,
394
+ ) -> None:
395
+ """Close an environment and shut down its executor (best-effort)."""
396
+ # Close the MCP session stack first — this gracefully exits the
397
+ # mcp_session() context (and the underlying FastMCP Client session)
398
+ # before we tear down the environment references.
399
+ if stack is not None:
400
+ try:
401
+ await stack.aclose()
402
+ except Exception:
403
+ pass # Best effort cleanup
404
+
405
  # Run close() in the same executor where the env was created
406
  # This is required for thread-sensitive libraries like Playwright/greenlet
407
  if env is not None:
 
440
  if increment_step:
441
  self._session_info[session_id].step_count += 1
442
 
443
+ async def _reap_idle_sessions(self) -> None:
444
+ """Background task that periodically destroys sessions idle beyond the timeout."""
445
+ timeout = self._session_idle_timeout_s
446
+ if timeout is None:
447
+ return # no timeout configured — noop
448
+ interval = max(timeout / 4, 5.0) # check frequently enough
449
+ while True:
450
+ try:
451
+ await asyncio.sleep(interval)
452
+ now = time.time()
453
+ stale_ids: list[str] = []
454
+ async with self._session_lock:
455
+ for sid, info in self._session_info.items():
456
+ if now - info.last_activity_at > timeout:
457
+ stale_ids.append(sid)
458
+ for sid in stale_ids:
459
+ # Re-check under lock: activity may have arrived since
460
+ # the snapshot was taken, making this session active again.
461
+ # Refresh `now` so slow _destroy_session calls don't cause
462
+ # subsequent entries to be validated against a stale clock.
463
+ now = time.time()
464
+ async with self._session_lock:
465
+ info = self._session_info.get(sid)
466
+ if info is None or (now - info.last_activity_at) <= timeout:
467
+ continue
468
+ await self._destroy_session(sid)
469
+ except asyncio.CancelledError:
470
+ break
471
+ except Exception as exc:
472
+ logging.getLogger(__name__).warning(
473
+ "Idle-session reaper encountered an error (will retry): %s",
474
+ exc,
475
+ )
476
+
477
+ def _start_reaper(self) -> None:
478
+ """Start the idle-session reaper if a timeout is configured."""
479
+ if self._session_idle_timeout_s is not None and self._reaper_task is None:
480
+ self._reaper_task = asyncio.create_task(self._reap_idle_sessions())
481
+
482
+ def _stop_reaper(self) -> None:
483
+ """Cancel the reaper background task."""
484
+ if self._reaper_task is not None:
485
+ self._reaper_task.cancel()
486
+ self._reaper_task = None
487
+
488
  def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
489
  """
490
  Get information about a specific session.
 
560
  f"Invalid mode: '{mode}'. Must be one of: {valid_modes}"
561
  )
562
 
563
+ # Wire up idle-session reaper lifecycle via app events
564
+ server_ref = self
565
+
566
+ async def _start_session_reaper() -> None:
567
+ server_ref._start_reaper()
568
+
569
+ async def _stop_session_reaper() -> None:
570
+ server_ref._stop_reaper()
571
+
572
+ if not getattr(app.router, "_openenv_reaper_registered", False):
573
+ app.router.on_startup.append(_start_session_reaper)
574
+ app.router.on_shutdown.append(_stop_session_reaper)
575
+ app.router._openenv_reaper_registered = True # type: ignore[attr-defined]
576
+
577
  # Helper function to handle reset endpoint
578
  async def reset_handler(
579
  request: ResetRequest = Body(default_factory=ResetRequest),
 
642
 
643
  # Helper function to handle MCP endpoint
644
  async def mcp_handler(
645
+ request: JsonRpcRequest,
646
+ session_env: Optional[Environment] = None,
647
+ session_id: Optional[str] = None,
648
  ) -> JsonRpcResponse:
649
  """
650
  Handle MCP JSON-RPC requests.
651
 
652
+ Supports tools/list and tools/call methods in JSON-RPC 2.0 format,
653
+ plus OpenEnv session lifecycle methods for HTTP MCP:
654
+ - openenv/session/create
655
+ - openenv/session/close
656
  """
657
  method = request.method
658
  request_id = request.id
659
+ params = request.params
660
+ if not isinstance(params, dict):
661
+ return JsonRpcResponse.error_response(
662
+ JsonRpcErrorCode.INVALID_PARAMS,
663
+ "Params must be an object",
664
+ request_id=request_id,
665
+ )
666
+
667
+ # OpenEnv extension methods for explicit MCP session management.
668
+ # This enables persistent MCP lifecycles over HTTP /mcp, matching WebSocket semantics.
669
+ if method == "openenv/session/create":
670
+ if session_env is not None and session_id is not None:
671
+ return JsonRpcResponse.success(
672
+ result={"session_id": session_id},
673
+ request_id=request_id,
674
+ )
675
+ try:
676
+ created_session_id, _ = await self._create_session()
677
+ except SessionCapacityError as e:
678
+ return JsonRpcResponse.error_response(
679
+ JsonRpcErrorCode.SERVER_ERROR,
680
+ str(e),
681
+ request_id=request_id,
682
+ data={
683
+ "active_sessions": e.active_sessions,
684
+ "max_sessions": e.max_sessions,
685
+ },
686
+ )
687
+ except EnvironmentFactoryError as e:
688
+ return JsonRpcResponse.error_response(
689
+ JsonRpcErrorCode.SERVER_ERROR,
690
+ str(e),
691
+ request_id=request_id,
692
+ data={"factory_name": e.factory_name},
693
+ )
694
+ return JsonRpcResponse.success(
695
+ result={"session_id": created_session_id},
696
+ request_id=request_id,
697
+ )
698
+
699
+ if method == "openenv/session/close":
700
+ target_session_id = params.get("session_id")
701
+ if not target_session_id:
702
+ return JsonRpcResponse.error_response(
703
+ JsonRpcErrorCode.INVALID_PARAMS,
704
+ "Invalid params - 'session_id' is required",
705
+ request_id=request_id,
706
+ )
707
+
708
+ if session_id is not None and target_session_id == session_id:
709
+ return JsonRpcResponse.error_response(
710
+ JsonRpcErrorCode.INVALID_REQUEST,
711
+ "Cannot close active WebSocket-managed session via MCP method",
712
+ request_id=request_id,
713
+ )
714
+
715
+ async with self._session_lock:
716
+ env = self._sessions.pop(target_session_id, _MISSING)
717
+ if env is not _MISSING:
718
+ executor = self._session_executors.pop(target_session_id, None)
719
+ stack = self._session_stacks.pop(target_session_id, None)
720
+ self._session_info.pop(target_session_id, None)
721
+ else:
722
+ executor = None
723
+ stack = None
724
+
725
+ if env is _MISSING:
726
+ return JsonRpcResponse.error_response(
727
+ JsonRpcErrorCode.INVALID_PARAMS,
728
+ f"Unknown session_id: {target_session_id}",
729
+ request_id=request_id,
730
+ )
731
+
732
+ if env is None:
733
+ # Session slot reserved but env factory still running;
734
+ # re-insert the placeholder AND the executor so
735
+ # _create_session can finish and the executor remains
736
+ # tracked for eventual shutdown.
737
+ async with self._session_lock:
738
+ self._sessions[target_session_id] = None
739
+ if executor is not None:
740
+ self._session_executors[target_session_id] = executor
741
+ return JsonRpcResponse.error_response(
742
+ JsonRpcErrorCode.INVALID_REQUEST,
743
+ f"Session {target_session_id} is still initializing; retry shortly",
744
+ request_id=request_id,
745
+ )
746
+
747
+ # env/executor/stack cleanup outside the lock
748
+ await self._cleanup_session_resources(env, executor, stack)
749
+ return JsonRpcResponse.success(
750
+ result={"session_id": target_session_id, "closed": True},
751
+ request_id=request_id,
752
+ )
753
+
754
+ requested_session_id = params.get("session_id")
755
+ managed_session_id = session_id
756
 
757
  # Use provided session environment or create temporary one
758
  if session_env is not None:
759
  _env = session_env
760
  should_close = False
761
+ elif requested_session_id:
762
+ async with self._session_lock:
763
+ _env = self._sessions.get(requested_session_id, _MISSING)
764
+
765
+ if _env is _MISSING:
766
+ return JsonRpcResponse.error_response(
767
+ JsonRpcErrorCode.INVALID_PARAMS,
768
+ f"Unknown session_id: {requested_session_id}",
769
+ request_id=request_id,
770
+ )
771
+
772
+ if _env is None:
773
+ return JsonRpcResponse.error_response(
774
+ JsonRpcErrorCode.INVALID_REQUEST,
775
+ f"Session {requested_session_id} is still initializing; retry shortly",
776
+ request_id=request_id,
777
+ )
778
+
779
+ should_close = False
780
+ managed_session_id = requested_session_id
781
  else:
782
  _env = self._env_factory()
783
  should_close = True
784
  try:
785
+ mcp_client = getattr(_env, "mcp_client", None)
786
+ mcp_server = getattr(_env, "mcp_server", None)
787
+ mcp_session_factory = getattr(_env, "mcp_session", None)
788
+
789
  if method == McpMethod.TOOLS_LIST:
790
  # Check if environment is MCP-enabled
791
+ if mcp_client is None and mcp_server is None:
792
  return JsonRpcResponse.error_response(
793
  JsonRpcErrorCode.INTERNAL_ERROR,
794
  "Environment does not support MCP",
795
  request_id=request_id,
796
  )
797
 
798
+ if mcp_client:
799
+ if managed_session_id and mcp_client.is_connected():
800
+ # Session-managed with live transport — call
801
+ # directly, no redundant re-entry.
802
+ tools = await mcp_client.list_tools()
803
+ elif callable(mcp_session_factory):
804
+ # Stateless request, or session-managed but the
805
+ # background transport was lost: (re-)open.
806
+ mcp_session_cm = cast(
807
+ AsyncContextManager[Any], mcp_session_factory()
808
+ )
809
+ async with mcp_session_cm:
810
+ tools = await mcp_client.list_tools()
811
+ else:
812
+ async with mcp_client:
813
+ tools = await mcp_client.list_tools()
814
+
815
+ return JsonRpcResponse.success(
816
+ result={
817
+ "tools": [
818
+ t.model_dump()
819
+ if hasattr(t, "model_dump")
820
+ else dict(t)
821
+ for t in tools
822
+ ]
823
+ },
824
+ request_id=request_id,
825
+ )
826
 
827
+ if mcp_server:
828
+ tools = []
829
+ for _tool_name, tool in get_server_tools(mcp_server).items():
830
+ tools.append(
831
+ {
832
+ "name": tool.name,
833
+ "description": tool.description or "",
834
+ "inputSchema": tool.parameters or {},
835
+ }
836
+ )
837
+ return JsonRpcResponse.success(
838
+ result={"tools": tools},
839
+ request_id=request_id,
840
+ )
841
+
842
+ return JsonRpcResponse.error_response(
843
+ JsonRpcErrorCode.INTERNAL_ERROR,
844
+ "MCP server not available",
845
  request_id=request_id,
846
  )
847
 
848
  elif method == McpMethod.TOOLS_CALL:
 
849
  tool_name = params.get("name")
850
  arguments = params.get("arguments", {})
851
 
852
+ if mcp_client is None and mcp_server is None:
853
  return JsonRpcResponse.error_response(
854
  JsonRpcErrorCode.INTERNAL_ERROR,
855
  "Environment does not support MCP",
 
858
 
859
  if not tool_name:
860
  return JsonRpcResponse.error_response(
861
+ JsonRpcErrorCode.INVALID_PARAMS,
862
  "Missing 'name' in params",
863
  request_id=request_id,
864
  )
865
 
866
+ if mcp_client:
867
+ if managed_session_id and mcp_client.is_connected():
868
+ # Session-managed with live transport.
869
+ result = await mcp_client.call_tool(
870
+ name=tool_name, arguments=arguments
871
+ )
872
+ elif callable(mcp_session_factory):
873
+ # Stateless request, or session-managed but the
874
+ # background transport was lost: (re-)open.
875
+ mcp_session_cm = cast(
876
+ AsyncContextManager[Any], mcp_session_factory()
877
+ )
878
+ async with mcp_session_cm:
879
+ result = await mcp_client.call_tool(
880
+ name=tool_name, arguments=arguments
881
+ )
882
+ else:
883
+ async with mcp_client:
884
+ result = await mcp_client.call_tool(
885
+ name=tool_name, arguments=arguments
886
+ )
887
+ elif mcp_server:
888
+ server_tools = get_server_tools(mcp_server)
889
+ if tool_name in server_tools:
890
+ tool = server_tools[tool_name]
891
+ if inspect.iscoroutinefunction(tool.fn):
892
+ result = await tool.fn(**arguments)
893
+ else:
894
+ result = tool.fn(**arguments)
895
+ else:
896
+ return JsonRpcResponse.error_response(
897
+ JsonRpcErrorCode.INVALID_PARAMS,
898
+ f"Tool not found: {tool_name}",
899
+ request_id=request_id,
900
+ )
901
+ else:
902
+ return JsonRpcResponse.error_response(
903
+ JsonRpcErrorCode.INTERNAL_ERROR,
904
+ "MCP server not available",
905
+ request_id=request_id,
906
  )
907
 
908
  # Ensure result is JSON serializable
 
927
  request_id=request_id,
928
  )
929
  finally:
930
+ if managed_session_id:
931
+ self._update_session_activity(
932
+ managed_session_id,
933
+ increment_step=(method == McpMethod.TOOLS_CALL),
934
+ )
935
  if should_close:
936
  _env.close()
937
 
 
955
  try:
956
  # Create session with dedicated environment
957
  session_id, session_env = await self._create_session()
958
+ if session_env is None:
959
+ raise RuntimeError(
960
+ "Session environment not initialized for MCP websocket"
961
+ )
962
 
963
+ # If environment has an mcp_session context manager, hold it open
964
+ # for the lifetime of the websocket connection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
965
 
966
+ async with AsyncExitStack() as stack:
967
+ mcp_session_factory = getattr(session_env, "mcp_session", None)
968
+ if callable(mcp_session_factory):
969
+ mcp_session_cm = cast(
970
+ AsyncContextManager[Any], mcp_session_factory()
971
  )
972
+ await stack.enter_async_context(mcp_session_cm)
973
+
974
+ while True:
975
+ # Receive message from client
976
+ raw_message = await websocket.receive_text()
977
+
978
+ try:
979
+ jsonrpc_dict = json.loads(raw_message)
980
+ jsonrpc_request = JsonRpcRequest(**jsonrpc_dict)
981
+ except json.JSONDecodeError as e:
982
+ error_resp = JsonRpcResponse.error_response(
983
+ JsonRpcErrorCode.PARSE_ERROR,
984
+ f"Parse error: {e}",
985
+ )
986
+ await websocket.send_text(error_resp.model_dump_json())
987
+ continue
988
+ except ValidationError as e:
989
+ error_resp = JsonRpcResponse.error_response(
990
+ JsonRpcErrorCode.INVALID_REQUEST,
991
+ f"Invalid request: {e}",
992
+ )
993
+ await websocket.send_text(error_resp.model_dump_json())
994
+ continue
995
+
996
+ try:
997
+ # Call mcp_handler with session environment
998
+ response = await mcp_handler(
999
+ jsonrpc_request,
1000
+ session_env=session_env,
1001
+ session_id=session_id,
1002
+ )
1003
+ await websocket.send_text(response.model_dump_json())
1004
+ except Exception as e:
1005
+ error_resp = JsonRpcResponse.error_response(
1006
+ JsonRpcErrorCode.INTERNAL_ERROR,
1007
+ str(e),
1008
+ request_id=jsonrpc_request.id,
1009
+ )
1010
+ await websocket.send_text(error_resp.model_dump_json())
1011
 
1012
  except WebSocketDisconnect:
1013
  pass
 
1266
  JsonRpcErrorCode.PARSE_ERROR
1267
  ).model_dump()
1268
 
1269
+ response = await mcp_handler(request)
1270
+ return response.model_dump()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1271
 
1272
  # Register WebSocket endpoint for persistent sessions
1273
  @app.websocket("/ws")
 
1289
  try:
1290
  # Create session with dedicated environment
1291
  session_id, session_env = await self._create_session()
1292
+ if session_env is None:
1293
+ raise RuntimeError(
1294
+ "Session environment not initialized for websocket"
1295
+ )
1296
 
1297
+ # Keep MCP session open for entire websocket lifetime
1298
+ # (avoids reconnect overhead on every message)
 
1299
 
1300
+ async with AsyncExitStack() as stack:
1301
+ mcp_session_factory = getattr(session_env, "mcp_session", None)
1302
+ if callable(mcp_session_factory):
1303
+ mcp_session_cm = cast(
1304
+ AsyncContextManager[Any], mcp_session_factory()
 
 
 
1305
  )
1306
+ await stack.enter_async_context(mcp_session_cm)
1307
+
1308
+ while True:
1309
+ # Receive message from client
1310
+ raw_message = await websocket.receive_text()
1311
+
1312
+ try:
1313
+ message_dict = json.loads(raw_message)
1314
+ except json.JSONDecodeError as e:
1315
+ error_resp = WSErrorResponse(
1316
+ data={
1317
+ "message": f"Invalid JSON: {e}",
1318
+ "code": WSErrorCode.INVALID_JSON,
1319
+ }
1320
+ )
1321
+ await websocket.send_text(error_resp.model_dump_json())
1322
+ continue
1323
 
1324
+ msg_type = message_dict.get("type", "")
 
 
 
1325
 
1326
+ try:
1327
+ match msg_type:
1328
+ case "reset":
1329
+ msg = WSResetMessage(**message_dict)
1330
 
1331
+ is_async = (
1332
+ session_env.reset_async.__func__
1333
+ is not Environment.reset_async
 
 
1334
  )
 
 
 
 
 
 
 
 
 
 
 
 
1335
 
1336
+ if is_async:
1337
+ sig = inspect.signature(session_env.reset_async)
1338
+ valid_kwargs = self._get_valid_kwargs(
1339
+ sig, msg.data
1340
+ )
1341
+ observation = await session_env.reset_async(
1342
+ **valid_kwargs
1343
+ )
1344
+ else:
1345
+ sig = inspect.signature(session_env.reset)
1346
+ valid_kwargs = self._get_valid_kwargs(
1347
+ sig, msg.data
1348
+ )
1349
+ observation = (
1350
+ await self._run_in_session_executor(
1351
+ session_id,
1352
+ session_env.reset,
1353
+ **valid_kwargs,
1354
+ )
1355
+ )
1356
+
1357
+ self._update_session_activity(session_id)
1358
+
1359
+ response = WSObservationResponse(
1360
+ data=serialize_observation(observation),
1361
+ )
1362
 
1363
+ case "step":
1364
+ msg = WSStepMessage(**message_dict)
1365
+ action = deserialize_action(
1366
+ msg.data, self.action_cls
1367
+ )
1368
 
1369
+ is_async = (
1370
+ session_env.step_async.__func__
1371
+ is not Environment.step_async
 
 
1372
  )
1373
 
1374
+ if is_async:
1375
+ observation = await session_env.step_async(
1376
+ action
1377
+ )
1378
+ else:
1379
+ observation = (
1380
+ await self._run_in_session_executor(
1381
+ session_id, session_env.step, action
1382
+ )
1383
+ )
1384
+
1385
+ self._update_session_activity(
1386
+ session_id, increment_step=True
1387
+ )
1388
 
1389
+ response = WSObservationResponse(
1390
+ data=serialize_observation(observation)
1391
+ )
1392
 
1393
+ case "state":
1394
+ msg = WSStateMessage(**message_dict)
1395
+ state = session_env.state
1396
+ if hasattr(state, "model_dump"):
1397
+ state_data = state.model_dump()
1398
+ else:
1399
+ state_data = dict(state) if state else {}
1400
+
1401
+ response = WSStateResponse(data=state_data)
1402
+
1403
+ case "close":
1404
+ msg = WSCloseMessage(**message_dict)
1405
+ break
1406
+
1407
+ case "mcp":
1408
+ msg = WSMCPMessage(**message_dict)
1409
+ try:
1410
+ rpc_request = JsonRpcRequest(**msg.data)
1411
+ except (ValidationError, Exception) as e:
1412
+ rpc_response = JsonRpcResponse.error_response(
1413
+ JsonRpcErrorCode.INVALID_REQUEST,
1414
+ f"Invalid request: {e}",
1415
+ )
1416
+ else:
1417
+ rpc_response = await mcp_handler(
1418
+ rpc_request,
1419
+ session_env=session_env,
1420
+ session_id=session_id,
1421
+ )
1422
+ response = WSMCPResponse(
1423
+ data=rpc_response.model_dump()
1424
  )
1425
+
1426
+ case _:
1427
+ response = WSErrorResponse(
1428
+ data={
1429
+ "message": f"Unknown message type: {msg_type}",
1430
+ "code": WSErrorCode.UNKNOWN_TYPE,
1431
+ }
1432
  )
 
 
 
 
 
 
 
 
 
1433
 
1434
+ await websocket.send_text(response.model_dump_json())
1435
 
1436
+ except ValidationError as e:
1437
+ error_resp = WSErrorResponse(
1438
+ data={
1439
+ "message": "Invalid message",
1440
+ "code": WSErrorCode.VALIDATION_ERROR,
1441
+ "errors": e.errors(),
1442
+ }
1443
+ )
1444
+ await websocket.send_text(error_resp.model_dump_json())
1445
+ except Exception as e:
1446
+ error_resp = WSErrorResponse(
1447
+ data={
1448
+ "message": str(e),
1449
+ "code": WSErrorCode.EXECUTION_ERROR,
1450
+ }
1451
+ )
1452
+ await websocket.send_text(error_resp.model_dump_json())
1453
 
1454
  except WebSocketDisconnect:
1455
  pass
 
1531
  from .web_interface import create_web_interface_app
1532
 
1533
  return create_web_interface_app(
1534
+ cast(Any, env),
1535
  action_cls,
1536
  observation_cls,
1537
  env_name,
src/openenv/core/env_server/mcp_environment.py CHANGED
@@ -56,6 +56,7 @@ import asyncio
56
  import inspect
57
  from abc import abstractmethod
58
  from collections import defaultdict
 
59
  from typing import Any, Callable, Dict, Optional
60
 
61
  from fastmcp import Client
@@ -164,6 +165,52 @@ class MCPEnvironment(Environment):
164
  # Track tool schemas for list_tools: {tool_name: {mode: schema}}
165
  self._mode_tool_schemas = defaultdict(dict)
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  @property
168
  def supports_code_mode(self) -> bool:
169
  """Check if this environment supports code mode (execute_code)."""
@@ -292,7 +339,8 @@ class MCPEnvironment(Environment):
292
 
293
  # If mode is None, register with FastMCP as usual
294
  if mode is None:
295
- decorated_func = self.mcp_server.tool()(func)
 
296
  self._mode_tools[tool_name][None] = func
297
  return decorated_func
298
 
@@ -372,24 +420,49 @@ class MCPEnvironment(Environment):
372
  return self._step_impl(action, timeout_s=timeout_s, **kwargs)
373
 
374
  def _handle_list_tools(self) -> ListToolsObservation:
 
 
 
 
375
  """
376
- Handle a ListToolsAction by querying the MCP server.
377
 
378
  Returns:
379
- ListToolsObservation containing all available tools with their
380
- names, descriptions, and input schemas, filtered by current mode.
381
  """
382
- try:
383
- # Get current mode
384
- current_mode = getattr(self, "_mode", None)
385
 
386
- # Start with tools from FastMCP server (mode=None tools)
387
- tools_result = run_async_safely(self._async_list_tools())
 
 
 
 
 
 
 
388
 
389
- # Build list of Tool objects
390
- tools = []
 
391
 
392
- # Add FastMCP tools that are not mode-specific
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  for tool in tools_result:
394
  if tool.name not in self._mode_tool_schemas:
395
  tools.append(
@@ -401,11 +474,8 @@ class MCPEnvironment(Environment):
401
  else {},
402
  )
403
  )
404
-
405
- # Add mode-specific tools available in current mode
406
  for tool_name, mode_schemas in self._mode_tool_schemas.items():
407
  if None in mode_schemas:
408
- # Tool available in all modes
409
  schema = mode_schemas[None]
410
  tools.append(
411
  Tool(
@@ -415,7 +485,6 @@ class MCPEnvironment(Environment):
415
  )
416
  )
417
  elif current_mode in mode_schemas:
418
- # Tool available in current mode
419
  schema = mode_schemas[current_mode]
420
  tools.append(
421
  Tool(
@@ -424,65 +493,30 @@ class MCPEnvironment(Environment):
424
  input_schema=schema["input_schema"],
425
  )
426
  )
427
-
428
  return ListToolsObservation(tools=tools)
429
-
430
  except Exception as e:
431
- # Return an observation with error in metadata
432
  return ListToolsObservation(
433
  tools=[],
434
- metadata={
435
- "error": str(e),
436
- "error_type": "list_tools_failed",
437
- },
438
  )
439
 
440
- async def _async_list_tools(self) -> list:
441
- """
442
- Async helper to list tools from the MCP client.
443
-
444
- Returns:
445
- List of tool objects from the MCP server.
446
- """
447
- async with self.mcp_client:
448
- return await self.mcp_client.list_tools()
449
-
450
- def _handle_call_tool(
451
  self,
452
  action: CallToolAction,
453
  timeout_s: Optional[float] = None,
454
  ) -> CallToolObservation:
455
- """
456
- Handle a CallToolAction by invoking the specified tool.
457
-
458
- Args:
459
- action: The CallToolAction containing tool_name and arguments.
460
- timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s).
461
-
462
- Returns:
463
- CallToolObservation with the tool's result or an error.
464
- """
465
  timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
466
-
467
- # Check if this is a mode-specific tool
468
  tool_name = action.tool_name
469
  current_mode = getattr(self, "_mode", None)
470
 
471
  if tool_name in self._mode_tools:
472
  mode_info = self._mode_tools[tool_name]
473
-
474
- # Check if tool is available in current mode
475
- # Tool is available if:
476
- # 1. It has a None mode (available in all modes), OR
477
- # 2. It has an implementation for the current mode
478
  if None in mode_info:
479
- # Use the mode-agnostic version
480
  func = mode_info[None]
481
  elif current_mode in mode_info:
482
- # Use the mode-specific version
483
  func = mode_info[current_mode]
484
  else:
485
- # Tool not available in current mode
486
  return CallToolObservation(
487
  tool_name=tool_name,
488
  result=None,
@@ -491,16 +525,11 @@ class MCPEnvironment(Environment):
491
  message=f"Tool '{tool_name}' not available in {current_mode} mode",
492
  ),
493
  )
494
-
495
- # Call the mode-specific function directly
496
  try:
497
- # Check if function is async and await if necessary
498
  if inspect.iscoroutinefunction(func):
499
- result = run_async_safely(func(**action.arguments))
500
  else:
501
  result = func(**action.arguments)
502
-
503
- # Wrap result in CallToolResult format to match FastMCP behavior
504
  return CallToolObservation(
505
  tool_name=tool_name,
506
  result=CallToolResult(
@@ -521,22 +550,12 @@ class MCPEnvironment(Environment):
521
  ),
522
  )
523
 
524
- # Not a mode-specific tool, use FastMCP
525
  try:
526
- # Run the async call_tool with timeout
527
- # Use run_async_safely to handle both sync and async contexts
528
- result = run_async_safely(
529
- asyncio.wait_for(
530
- self._async_call_tool(action.tool_name, action.arguments),
531
- timeout=timeout,
532
- )
533
- )
534
-
535
- return CallToolObservation(
536
- tool_name=action.tool_name,
537
- result=result,
538
  )
539
-
540
  except asyncio.TimeoutError:
541
  return CallToolObservation(
542
  tool_name=action.tool_name,
@@ -546,11 +565,8 @@ class MCPEnvironment(Environment):
546
  message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
547
  ),
548
  )
549
-
550
  except Exception as e:
551
  error_message = str(e)
552
-
553
- # Determine error type based on the exception
554
  if (
555
  "not found" in error_message.lower()
556
  or "unknown tool" in error_message.lower()
@@ -563,29 +579,34 @@ class MCPEnvironment(Environment):
563
  error_type = ToolErrorType.INVALID_ARGS
564
  else:
565
  error_type = ToolErrorType.EXECUTION_ERROR
566
-
567
  return CallToolObservation(
568
  tool_name=action.tool_name,
569
  result=None,
570
- error=ToolError(
571
- error_type=error_type,
572
- message=error_message,
573
- ),
574
  )
575
 
576
- async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
 
 
 
 
 
577
  """
578
- Async helper to call a tool on the MCP server.
579
 
580
- Args:
581
- tool_name: Name of the tool to invoke.
582
- arguments: Dictionary of arguments to pass to the tool.
583
-
584
- Returns:
585
- The result from the tool execution.
586
  """
587
- async with self.mcp_client:
588
- return await self.mcp_client.call_tool(tool_name, arguments)
 
 
 
 
 
 
 
589
 
590
  @abstractmethod
591
  def _step_impl(
 
56
  import inspect
57
  from abc import abstractmethod
58
  from collections import defaultdict
59
+ from contextlib import asynccontextmanager
60
  from typing import Any, Callable, Dict, Optional
61
 
62
  from fastmcp import Client
 
165
  # Track tool schemas for list_tools: {tool_name: {mode: schema}}
166
  self._mode_tool_schemas = defaultdict(dict)
167
 
168
+ def _require_mcp_client(self) -> Any:
169
+ """Return MCP client or raise if environment has been closed."""
170
+ if self.mcp_client is None:
171
+ raise RuntimeError("MCP client is not available; environment is closed")
172
+ return self.mcp_client
173
+
174
+ def _require_mcp_server(self) -> Any:
175
+ """Return MCP server or raise if environment has been closed."""
176
+ if self.mcp_server is None:
177
+ raise RuntimeError("MCP server is not available; environment is closed")
178
+ return self.mcp_server
179
+
180
+ @asynccontextmanager
181
+ async def mcp_session(self):
182
+ """
183
+ Context manager for MCP client sessions.
184
+
185
+ This wrapper serves two purposes:
186
+
187
+ 1. **Null guard** — raises a clear error if ``close()`` has already
188
+ been called (``mcp_client`` is ``None``).
189
+
190
+ 2. **AsyncExitStack adapter** — FastMCP's ``Client.__aenter__``
191
+ creates a background ``asyncio.Task`` for session management.
192
+ When entered directly via ``AsyncExitStack`` in the HTTP session
193
+ path (``_create_session``), this task can be cancelled by ASGI
194
+ harnesses (e.g. Starlette ``TestClient``) between requests,
195
+ corrupting session state. Wrapping in an ``asynccontextmanager``
196
+ generator isolates the task lifecycle: the generator frame keeps
197
+ ``async with client:`` suspended at ``yield``, so cleanup only
198
+ runs when the stack explicitly closes the generator — not when
199
+ the event loop cancels orphaned tasks.
200
+
201
+ Delegates to FastMCP's ``Client`` context manager which is
202
+ reentrant: the first entry opens the transport and subsequent
203
+ (nested) entries simply increment an internal reference counter.
204
+ The transport is closed only when the outermost context exits.
205
+
206
+ No external lock is needed because ``Client._connect`` /
207
+ ``Client._disconnect`` already serialise connection state changes
208
+ through their own ``anyio.Lock``.
209
+ """
210
+ client = self._require_mcp_client()
211
+ async with client:
212
+ yield client
213
+
214
  @property
215
  def supports_code_mode(self) -> bool:
216
  """Check if this environment supports code mode (execute_code)."""
 
339
 
340
  # If mode is None, register with FastMCP as usual
341
  if mode is None:
342
+ mcp_server = self._require_mcp_server()
343
+ decorated_func = mcp_server.tool()(func)
344
  self._mode_tools[tool_name][None] = func
345
  return decorated_func
346
 
 
420
  return self._step_impl(action, timeout_s=timeout_s, **kwargs)
421
 
422
  def _handle_list_tools(self) -> ListToolsObservation:
423
+ """Sync wrapper — delegates to the canonical async implementation."""
424
+ return run_async_safely(self._async_handle_list_tools())
425
+
426
+ async def _async_list_tools(self) -> list:
427
  """
428
+ Async helper to list tools from the MCP client.
429
 
430
  Returns:
431
+ List of tool objects from the MCP server.
 
432
  """
433
+ async with self.mcp_session() as client:
434
+ return await client.list_tools()
 
435
 
436
+ def _handle_call_tool(
437
+ self,
438
+ action: CallToolAction,
439
+ timeout_s: Optional[float] = None,
440
+ ) -> CallToolObservation:
441
+ """Sync wrapper — delegates to the canonical async implementation."""
442
+ return run_async_safely(
443
+ self._async_handle_call_tool(action, timeout_s=timeout_s)
444
+ )
445
 
446
+ async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
447
+ """
448
+ Async helper to call a tool on the MCP server.
449
 
450
+ Args:
451
+ tool_name: Name of the tool to invoke.
452
+ arguments: Dictionary of arguments to pass to the tool.
453
+
454
+ Returns:
455
+ The result from the tool execution.
456
+ """
457
+ async with self.mcp_session() as client:
458
+ return await client.call_tool(tool_name, arguments)
459
+
460
+ async def _async_handle_list_tools(self) -> ListToolsObservation:
461
+ """Async version of _handle_list_tools — avoids run_async_safely."""
462
+ try:
463
+ current_mode = getattr(self, "_mode", None)
464
+ tools_result = await self._async_list_tools()
465
+ tools = []
466
  for tool in tools_result:
467
  if tool.name not in self._mode_tool_schemas:
468
  tools.append(
 
474
  else {},
475
  )
476
  )
 
 
477
  for tool_name, mode_schemas in self._mode_tool_schemas.items():
478
  if None in mode_schemas:
 
479
  schema = mode_schemas[None]
480
  tools.append(
481
  Tool(
 
485
  )
486
  )
487
  elif current_mode in mode_schemas:
 
488
  schema = mode_schemas[current_mode]
489
  tools.append(
490
  Tool(
 
493
  input_schema=schema["input_schema"],
494
  )
495
  )
 
496
  return ListToolsObservation(tools=tools)
 
497
  except Exception as e:
 
498
  return ListToolsObservation(
499
  tools=[],
500
+ metadata={"error": str(e), "error_type": "list_tools_failed"},
 
 
 
501
  )
502
 
503
+ async def _async_handle_call_tool(
 
 
 
 
 
 
 
 
 
 
504
  self,
505
  action: CallToolAction,
506
  timeout_s: Optional[float] = None,
507
  ) -> CallToolObservation:
508
+ """Async version of _handle_call_tool — avoids run_async_safely."""
 
 
 
 
 
 
 
 
 
509
  timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
 
 
510
  tool_name = action.tool_name
511
  current_mode = getattr(self, "_mode", None)
512
 
513
  if tool_name in self._mode_tools:
514
  mode_info = self._mode_tools[tool_name]
 
 
 
 
 
515
  if None in mode_info:
 
516
  func = mode_info[None]
517
  elif current_mode in mode_info:
 
518
  func = mode_info[current_mode]
519
  else:
 
520
  return CallToolObservation(
521
  tool_name=tool_name,
522
  result=None,
 
525
  message=f"Tool '{tool_name}' not available in {current_mode} mode",
526
  ),
527
  )
 
 
528
  try:
 
529
  if inspect.iscoroutinefunction(func):
530
+ result = await func(**action.arguments)
531
  else:
532
  result = func(**action.arguments)
 
 
533
  return CallToolObservation(
534
  tool_name=tool_name,
535
  result=CallToolResult(
 
550
  ),
551
  )
552
 
 
553
  try:
554
+ result = await asyncio.wait_for(
555
+ self._async_call_tool(action.tool_name, action.arguments),
556
+ timeout=timeout,
 
 
 
 
 
 
 
 
 
557
  )
558
+ return CallToolObservation(tool_name=action.tool_name, result=result)
559
  except asyncio.TimeoutError:
560
  return CallToolObservation(
561
  tool_name=action.tool_name,
 
565
  message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
566
  ),
567
  )
 
568
  except Exception as e:
569
  error_message = str(e)
 
 
570
  if (
571
  "not found" in error_message.lower()
572
  or "unknown tool" in error_message.lower()
 
579
  error_type = ToolErrorType.INVALID_ARGS
580
  else:
581
  error_type = ToolErrorType.EXECUTION_ERROR
 
582
  return CallToolObservation(
583
  tool_name=action.tool_name,
584
  result=None,
585
+ error=ToolError(error_type=error_type, message=error_message),
 
 
 
586
  )
587
 
588
+ async def step_async(
589
+ self,
590
+ action: Action,
591
+ timeout_s: Optional[float] = None,
592
+ **kwargs: Any,
593
+ ) -> Observation:
594
  """
595
+ Async step that routes MCP actions without going through run_async_safely.
596
 
597
+ The WebSocket handler calls this directly on the outer event loop, where
598
+ the MCP session is already open, avoiding the thread/event-loop deadlock
599
+ that occurs when the sync step() path is used via run_in_executor.
 
 
 
600
  """
601
+ if isinstance(action, ListToolsAction):
602
+ return await self._async_handle_list_tools()
603
+ elif isinstance(action, CallToolAction):
604
+ return await self._async_handle_call_tool(action, timeout_s=timeout_s)
605
+ else:
606
+ loop = asyncio.get_event_loop()
607
+ return await loop.run_in_executor(
608
+ None, lambda: self._step_impl(action, timeout_s=timeout_s, **kwargs)
609
+ )
610
 
611
  @abstractmethod
612
  def _step_impl(
src/openenv/core/env_server/serialization.py CHANGED
@@ -14,14 +14,28 @@ HTTP server and web interface implementations.
14
 
15
  from typing import Any, Dict, Type
16
 
 
17
  from .types import Action, Observation
18
 
 
 
 
 
 
 
 
 
 
19
 
20
  def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
21
  """
22
  Convert JSON dict to Action instance using Pydantic validation.
23
 
24
- This is a basic deserialization that works for most environments.
 
 
 
 
25
  For special cases (e.g., tensor fields, custom type conversions),
26
  use deserialize_action_with_preprocessing().
27
 
@@ -38,6 +52,17 @@ def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) ->
38
  Note:
39
  This uses Pydantic's model_validate() for automatic validation.
40
  """
 
 
 
 
 
 
 
 
 
 
 
41
  return action_cls.model_validate(action_data)
42
 
43
 
@@ -62,6 +87,15 @@ def deserialize_action_with_preprocessing(
62
  Raises:
63
  ValidationError: If action_data is invalid for the action class
64
  """
 
 
 
 
 
 
 
 
 
65
  processed_data = {}
66
 
67
  for key, value in action_data.items():
 
14
 
15
  from typing import Any, Dict, Type
16
 
17
+ from .mcp_types import CallToolAction, ListToolsAction
18
  from .types import Action, Observation
19
 
20
+ # MCP action types keyed by their "type" discriminator value.
21
+ # These are checked before the environment's own action_cls so that
22
+ # ListToolsAction / CallToolAction payloads are never rejected by an
23
+ # unrelated Pydantic model.
24
+ _MCP_ACTION_TYPES: Dict[str, Type[Action]] = {
25
+ "list_tools": ListToolsAction,
26
+ "call_tool": CallToolAction,
27
+ }
28
+
29
 
30
  def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
31
  """
32
  Convert JSON dict to Action instance using Pydantic validation.
33
 
34
+ MCP action types (``list_tools``, ``call_tool``) are recognised
35
+ automatically via the ``"type"`` discriminator field, regardless of
36
+ the environment's configured ``action_cls``. All other payloads
37
+ fall through to ``action_cls.model_validate()``.
38
+
39
  For special cases (e.g., tensor fields, custom type conversions),
40
  use deserialize_action_with_preprocessing().
41
 
 
52
  Note:
53
  This uses Pydantic's model_validate() for automatic validation.
54
  """
55
+ # Route MCP action types before falling through to the env action_cls.
56
+ # Only intercept when action_cls is the generic Action base or itself an
57
+ # MCP type (i.e. the server hosts an MCP environment). This avoids
58
+ # silently bypassing env-specific validation for non-MCP environments
59
+ # that happen to use "call_tool" / "list_tools" as a type discriminator.
60
+ action_type = action_data.get("type")
61
+ if action_type in _MCP_ACTION_TYPES:
62
+ mcp_cls = _MCP_ACTION_TYPES[action_type]
63
+ if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values():
64
+ return mcp_cls.model_validate(action_data)
65
+
66
  return action_cls.model_validate(action_data)
67
 
68
 
 
87
  Raises:
88
  ValidationError: If action_data is invalid for the action class
89
  """
90
+ # Route MCP action types before preprocessing (they don't need it).
91
+ # Same guard as deserialize_action: only intercept when action_cls is
92
+ # the generic Action base or itself an MCP type.
93
+ action_type = action_data.get("type")
94
+ if action_type in _MCP_ACTION_TYPES:
95
+ mcp_cls = _MCP_ACTION_TYPES[action_type]
96
+ if action_cls is Action or action_cls in _MCP_ACTION_TYPES.values():
97
+ return mcp_cls.model_validate(action_data)
98
+
99
  processed_data = {}
100
 
101
  for key, value in action_data.items():
src/openenv/core/env_server/web_interface.py CHANGED
@@ -15,13 +15,15 @@ option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var.
15
  from __future__ import annotations
16
 
17
  import asyncio
 
18
  import json
19
  from concurrent.futures import ThreadPoolExecutor
20
  from datetime import datetime
21
  from typing import Any, Callable, Dict, List, Optional, Type
22
 
23
  import gradio as gr
24
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 
25
  from pydantic import BaseModel, ConfigDict, Field
26
 
27
  from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
@@ -269,6 +271,28 @@ class WebInterfaceManager:
269
  # Thread pool for running sync code (e.g., Playwright sync API) in async context
270
  self._executor = ThreadPoolExecutor(max_workers=1)
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
273
  """Run a synchronous function in the thread pool executor.
274
 
@@ -317,11 +341,24 @@ class WebInterfaceManager:
317
  for client in disconnected_clients:
318
  self.connected_clients.remove(client)
319
 
320
- async def reset_environment(self) -> Dict[str, Any]:
 
 
321
  """Reset the environment and update state."""
322
- # Run sync reset in thread pool to avoid blocking event loop
323
- # and to support environments using sync libraries (e.g., Playwright)
324
- observation: Observation = await self._run_sync_in_thread_pool(self.env.reset)
 
 
 
 
 
 
 
 
 
 
 
325
  state: State = self.env.state
326
 
327
  # Serialize observation once using shared utility
@@ -428,6 +465,16 @@ def create_web_interface_app(
428
  web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
429
 
430
  # Web API routes first (so they take precedence over Gradio mount at /web)
 
 
 
 
 
 
 
 
 
 
431
  @app.get("/web/metadata")
432
  async def web_metadata():
433
  """Get environment metadata."""
@@ -449,9 +496,9 @@ def create_web_interface_app(
449
  await web_manager.disconnect_websocket(websocket)
450
 
451
  @app.post("/web/reset")
452
- async def web_reset():
453
  """Reset endpoint for web interface."""
454
- return await web_manager.reset_environment()
455
 
456
  @app.post("/web/step")
457
  async def web_step(request: Dict[str, Any]):
@@ -475,7 +522,13 @@ def create_web_interface_app(
475
  @app.get("/web/state")
476
  async def web_state():
477
  """State endpoint for web interface."""
478
- return web_manager.get_state()
 
 
 
 
 
 
479
 
480
  action_fields = _extract_action_fields(action_cls)
481
  is_chat_env = _is_chat_env(action_cls)
@@ -505,7 +558,7 @@ def create_web_interface_app(
505
  )
506
  gradio_blocks = gr.TabbedInterface(
507
  [default_blocks, custom_blocks],
508
- tab_names=["Playground", "Visualization"],
509
  title=get_gradio_display_title(metadata),
510
  )
511
  else:
 
15
  from __future__ import annotations
16
 
17
  import asyncio
18
+ import inspect
19
  import json
20
  from concurrent.futures import ThreadPoolExecutor
21
  from datetime import datetime
22
  from typing import Any, Callable, Dict, List, Optional, Type
23
 
24
  import gradio as gr
25
+ from fastapi import Body, FastAPI, HTTPException, status, WebSocket, WebSocketDisconnect
26
+ from fastapi.responses import RedirectResponse
27
  from pydantic import BaseModel, ConfigDict, Field
28
 
29
  from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
 
271
  # Thread pool for running sync code (e.g., Playwright sync API) in async context
272
  self._executor = ThreadPoolExecutor(max_workers=1)
273
 
274
+ @staticmethod
275
+ def _get_valid_kwargs(
276
+ sig: inspect.Signature,
277
+ kwargs: Dict[str, Any],
278
+ skip_params: Optional[set[str]] = None,
279
+ ) -> Dict[str, Any]:
280
+ """Filter kwargs to only those accepted by the target function."""
281
+ skip_params = skip_params or set()
282
+ valid_kwargs: Dict[str, Any] = {}
283
+ has_var_kwargs = any(
284
+ param.kind == inspect.Parameter.VAR_KEYWORD
285
+ for param in sig.parameters.values()
286
+ )
287
+
288
+ for key, value in kwargs.items():
289
+ if key in skip_params:
290
+ continue
291
+ if key in sig.parameters or has_var_kwargs:
292
+ valid_kwargs[key] = value
293
+
294
+ return valid_kwargs
295
+
296
  async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
297
  """Run a synchronous function in the thread pool executor.
298
 
 
341
  for client in disconnected_clients:
342
  self.connected_clients.remove(client)
343
 
344
+ async def reset_environment(
345
+ self, reset_kwargs: Optional[Dict[str, Any]] = None
346
+ ) -> Dict[str, Any]:
347
  """Reset the environment and update state."""
348
+ reset_kwargs = reset_kwargs or {}
349
+
350
+ is_async = self.env.reset_async.__func__ is not Environment.reset_async
351
+ sig = inspect.signature(self.env.reset_async if is_async else self.env.reset)
352
+ valid_kwargs = self._get_valid_kwargs(sig, reset_kwargs)
353
+
354
+ if is_async:
355
+ observation = await self.env.reset_async(**valid_kwargs)
356
+ else:
357
+ # Run sync reset in thread pool to avoid blocking event loop
358
+ # and to support environments using sync libraries (e.g., Playwright)
359
+ observation = await self._run_sync_in_thread_pool(
360
+ self.env.reset, **valid_kwargs
361
+ )
362
  state: State = self.env.state
363
 
364
  # Serialize observation once using shared utility
 
465
  web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
466
 
467
  # Web API routes first (so they take precedence over Gradio mount at /web)
468
+ @app.get("/", include_in_schema=False)
469
+ async def web_root():
470
+ """Redirect the app root to the Gradio interface."""
471
+ return RedirectResponse(url="/web/")
472
+
473
+ @app.get("/web", include_in_schema=False)
474
+ async def web_root_no_slash():
475
+ """Redirect /web to /web/ for mounted Gradio deployments behind proxies."""
476
+ return RedirectResponse(url="/web/")
477
+
478
  @app.get("/web/metadata")
479
  async def web_metadata():
480
  """Get environment metadata."""
 
496
  await web_manager.disconnect_websocket(websocket)
497
 
498
  @app.post("/web/reset")
499
+ async def web_reset(request: Optional[Dict[str, Any]] = Body(default=None)):
500
  """Reset endpoint for web interface."""
501
+ return await web_manager.reset_environment(request)
502
 
503
  @app.post("/web/step")
504
  async def web_step(request: Dict[str, Any]):
 
522
  @app.get("/web/state")
523
  async def web_state():
524
  """State endpoint for web interface."""
525
+ try:
526
+ return web_manager.get_state()
527
+ except RuntimeError as exc:
528
+ raise HTTPException(
529
+ status_code=status.HTTP_409_CONFLICT,
530
+ detail=str(exc),
531
+ ) from exc
532
 
533
  action_fields = _extract_action_fields(action_cls)
534
  is_chat_env = _is_chat_env(action_cls)
 
558
  )
559
  gradio_blocks = gr.TabbedInterface(
560
  [default_blocks, custom_blocks],
561
+ tab_names=["Playground", "Custom"],
562
  title=get_gradio_display_title(metadata),
563
  )
564
  else:
src/openenv/core/mcp_client.py CHANGED
@@ -52,6 +52,7 @@ Example (sync wrapper):
52
  ... result = env.call_tool("echo_message", message="Hello!")
53
  """
54
 
 
55
  from typing import Any, Dict, List, Optional
56
 
57
  from .client_types import StepResult
@@ -118,6 +119,66 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
118
  )
119
  self._tools_cache: Optional[List[Tool]] = None
120
  self.use_production_mode = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  async def list_tools(self, use_cache: bool = True) -> List[Tool]:
123
  """
@@ -138,26 +199,18 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
138
  if use_cache and self._tools_cache is not None:
139
  return self._tools_cache
140
 
141
- # Use production mode HTTP endpoint if enabled
142
- if self.use_production_mode:
143
- import requests
144
-
145
- # Convert ws:// URL to http:// URL
146
- url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://")
147
- # Remove /ws suffix if present and add /mcp
148
- url = url.rstrip("/ws").rstrip("/") + "/mcp"
149
-
150
  try:
151
- response = requests.post(
152
- url,
153
- json={
154
- "jsonrpc": "2.0",
155
- "method": "tools/list",
156
- "params": {},
157
- "id": 1,
158
- },
159
  )
160
- data = response.json()
 
 
161
  if "result" in data and "tools" in data["result"]:
162
  tools = [
163
  Tool(
@@ -177,7 +230,12 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
177
  return []
178
 
179
  result = await self.step(ListToolsAction())
180
- self._tools_cache = result.observation.tools
 
 
 
 
 
181
  return self._tools_cache
182
 
183
  def _step_payload(self, action: Any) -> Dict[str, Any]:
@@ -251,6 +309,35 @@ class MCPClientBase(EnvClient[Any, Observation, State]):
251
  step_count=payload.get("step_count", 0),
252
  )
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  class MCPToolClient(MCPClientBase):
256
  """
@@ -316,6 +403,26 @@ class MCPToolClient(MCPClientBase):
316
  >>> result = await env.call_tool("greet", name="Claude")
317
  >>> print(result) # "Hello, Claude!"
318
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  action = CallToolAction(tool_name=name, arguments=kwargs)
320
  result = await self.step(action)
321
  obs = result.observation
 
52
  ... result = env.call_tool("echo_message", message="Hello!")
53
  """
54
 
55
+ import asyncio
56
  from typing import Any, Dict, List, Optional
57
 
58
  from .client_types import StepResult
 
119
  )
120
  self._tools_cache: Optional[List[Tool]] = None
121
  self.use_production_mode = False
122
+ self._production_session_id: Optional[str] = None
123
+ self._production_session_lock = asyncio.Lock()
124
+ self._jsonrpc_request_id = 0
125
+ self._http_client: Optional[Any] = None # lazily-created httpx.AsyncClient
126
+
127
+ def _next_request_id(self) -> int:
128
+ """Generate a monotonically increasing JSON-RPC request id."""
129
+ self._jsonrpc_request_id += 1
130
+ return self._jsonrpc_request_id
131
+
132
+ def _production_mcp_url(self) -> str:
133
+ """Build HTTP MCP endpoint URL from the client's websocket URL."""
134
+ url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://")
135
+ if url.endswith("/ws"):
136
+ url = url[: -len("/ws")]
137
+ return url.rstrip("/") + "/mcp"
138
+
139
+ async def _get_http_client(self) -> Any:
140
+ """Return a shared httpx.AsyncClient, creating one lazily."""
141
+ if self._http_client is None:
142
+ import httpx
143
+
144
+ self._http_client = httpx.AsyncClient()
145
+ return self._http_client
146
+
147
+ async def _production_mcp_request(
148
+ self, method: str, params: Optional[Dict[str, Any]] = None
149
+ ) -> Dict[str, Any]:
150
+ """Send a JSON-RPC request to HTTP /mcp and return parsed JSON response."""
151
+ client = await self._get_http_client()
152
+ response = await client.post(
153
+ self._production_mcp_url(),
154
+ json={
155
+ "jsonrpc": "2.0",
156
+ "method": method,
157
+ "params": params or {},
158
+ "id": self._next_request_id(),
159
+ },
160
+ timeout=self._message_timeout,
161
+ )
162
+ response.raise_for_status()
163
+ return response.json()
164
+
165
+ async def _ensure_production_session(self) -> str:
166
+ """Create and cache a persistent HTTP MCP session id if needed."""
167
+ async with self._production_session_lock:
168
+ if self._production_session_id is not None:
169
+ return self._production_session_id
170
+
171
+ data = await self._production_mcp_request("openenv/session/create")
172
+ if "error" in data:
173
+ message = data.get("error", {}).get("message", "unknown error")
174
+ raise RuntimeError(f"Failed to create MCP session: {message}")
175
+
176
+ session_id = data.get("result", {}).get("session_id")
177
+ if not session_id:
178
+ raise RuntimeError("Failed to create MCP session: missing session_id")
179
+
180
+ self._production_session_id = session_id
181
+ return session_id
182
 
183
  async def list_tools(self, use_cache: bool = True) -> List[Tool]:
184
  """
 
199
  if use_cache and self._tools_cache is not None:
200
  return self._tools_cache
201
 
202
+ # Use production mode HTTP endpoint if enabled.
203
+ # Some tests instantiate with __new__ and skip __init__, so default missing flag to False.
204
+ if getattr(self, "use_production_mode", False):
 
 
 
 
 
 
205
  try:
206
+ session_id = await self._ensure_production_session()
207
+ data = await self._production_mcp_request(
208
+ "tools/list",
209
+ {"session_id": session_id},
 
 
 
 
210
  )
211
+ if "error" in data:
212
+ message = data.get("error", {}).get("message", "unknown error")
213
+ raise RuntimeError(f"list_tools failed: {message}")
214
  if "result" in data and "tools" in data["result"]:
215
  tools = [
216
  Tool(
 
230
  return []
231
 
232
  result = await self.step(ListToolsAction())
233
+ if isinstance(result.observation, ListToolsObservation):
234
+ self._tools_cache = result.observation.tools
235
+ return self._tools_cache
236
+
237
+ # Unexpected observation type; keep API stable with an empty tool list.
238
+ self._tools_cache = []
239
  return self._tools_cache
240
 
241
  def _step_payload(self, action: Any) -> Dict[str, Any]:
 
309
  step_count=payload.get("step_count", 0),
310
  )
311
 
312
+ async def close(self) -> None:
313
+ """
314
+ Close client resources.
315
+
316
+ In production MCP mode, this also closes the server-side persistent
317
+ MCP session (best effort) before closing websocket/provider resources.
318
+ """
319
+ if self._production_session_id is not None:
320
+ try:
321
+ await self._production_mcp_request(
322
+ "openenv/session/close",
323
+ {"session_id": self._production_session_id},
324
+ )
325
+ except Exception:
326
+ # Best effort cleanup - do not mask normal close behavior
327
+ pass
328
+ finally:
329
+ self._production_session_id = None
330
+
331
+ if self._http_client is not None:
332
+ try:
333
+ await self._http_client.aclose()
334
+ except Exception:
335
+ pass
336
+ finally:
337
+ self._http_client = None
338
+
339
+ await super().close()
340
+
341
 
342
  class MCPToolClient(MCPClientBase):
343
  """
 
403
  >>> result = await env.call_tool("greet", name="Claude")
404
  >>> print(result) # "Hello, Claude!"
405
  """
406
+ if getattr(self, "use_production_mode", False):
407
+ session_id = await self._ensure_production_session()
408
+ data = await self._production_mcp_request(
409
+ "tools/call",
410
+ {
411
+ "name": name,
412
+ "arguments": kwargs,
413
+ "session_id": session_id,
414
+ },
415
+ )
416
+
417
+ if "error" in data:
418
+ message = data.get("error", {}).get("message", "unknown error")
419
+ raise RuntimeError(f"Tool '{name}' failed: {message}")
420
+
421
+ result = data.get("result")
422
+ if isinstance(result, dict) and "data" in result:
423
+ return result["data"]
424
+ return result
425
+
426
  action = CallToolAction(tool_name=name, arguments=kwargs)
427
  result = await self.step(action)
428
  obs = result.observation
src/openenv_core.egg-info/PKG-INFO CHANGED
@@ -1,6 +1,6 @@
1
  Metadata-Version: 2.4
2
  Name: openenv-core
3
- Version: 0.2.2.dev0
4
  Summary: A unified framework for reinforcement learning environments
5
  Requires-Python: >=3.10
6
  Description-Content-Type: text/markdown
@@ -19,6 +19,7 @@ Requires-Dist: tomli-w>=1.2.0
19
  Requires-Dist: websockets>=15.0.1
20
  Requires-Dist: fastmcp>=3.0.0
21
  Requires-Dist: gradio>=4.0.0
 
22
  Provides-Extra: core
23
  Requires-Dist: fastapi>=0.104.0; extra == "core"
24
  Requires-Dist: pydantic>=2.0.0; extra == "core"
@@ -61,7 +62,7 @@ Dynamic: license-file
61
 
62
  An e2e framework for creating, deploying and using isolated execution environments for agentic RL training, built using Gymnasium style simple APIs.
63
 
64
- [![PyPI](https://img.shields.io/pypi/v/openenv?color=blue)](https://pypi.org/project/openenv/)
65
  [![Discord](https://img.shields.io/badge/Discord-OpenEnv-7289da?style=flat&logo=discord&logoColor=white)](https://discord.gg/YsTYBh6PD9)
66
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-pytorch/OpenEnv/blob/main/examples/OpenEnv_Tutorial.ipynb)
67
  [![Docs](https://img.shields.io/badge/Docs-Explore-blue?logo=readthedocs&logoColor=white)](https://meta-pytorch.org/OpenEnv/)
 
1
  Metadata-Version: 2.4
2
  Name: openenv-core
3
+ Version: 0.2.3
4
  Summary: A unified framework for reinforcement learning environments
5
  Requires-Python: >=3.10
6
  Description-Content-Type: text/markdown
 
19
  Requires-Dist: websockets>=15.0.1
20
  Requires-Dist: fastmcp>=3.0.0
21
  Requires-Dist: gradio>=4.0.0
22
+ Requires-Dist: httpx>=0.28.1
23
  Provides-Extra: core
24
  Requires-Dist: fastapi>=0.104.0; extra == "core"
25
  Requires-Dist: pydantic>=2.0.0; extra == "core"
 
62
 
63
  An e2e framework for creating, deploying and using isolated execution environments for agentic RL training, built using Gymnasium style simple APIs.
64
 
65
+ [![PyPI](https://img.shields.io/pypi/v/openenv-core?color=blue)](https://pypi.org/project/openenv-core/)
66
  [![Discord](https://img.shields.io/badge/Discord-OpenEnv-7289da?style=flat&logo=discord&logoColor=white)](https://discord.gg/YsTYBh6PD9)
67
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-pytorch/OpenEnv/blob/main/examples/OpenEnv_Tutorial.ipynb)
68
  [![Docs](https://img.shields.io/badge/Docs-Explore-blue?logo=readthedocs&logoColor=white)](https://meta-pytorch.org/OpenEnv/)
src/openenv_core.egg-info/SOURCES.txt CHANGED
@@ -1,4 +1,5 @@
1
  LICENSE
 
2
  README.md
3
  pyproject.toml
4
  src/openenv/__init__.py
@@ -19,8 +20,6 @@ src/openenv/cli/commands/serve.py
19
  src/openenv/cli/commands/skills.py
20
  src/openenv/cli/commands/validate.py
21
  src/openenv/cli/templates/__init__.py
22
- src/openenv/cli/templates/__pycache__/__init__.cpython-311.pyc
23
- src/openenv/cli/templates/__pycache__/__init__.cpython-313.pyc
24
  src/openenv/cli/templates/openenv_env/README.md
25
  src/openenv/cli/templates/openenv_env/__init__.py
26
  src/openenv/cli/templates/openenv_env/client.py
 
1
  LICENSE
2
+ MANIFEST.in
3
  README.md
4
  pyproject.toml
5
  src/openenv/__init__.py
 
20
  src/openenv/cli/commands/skills.py
21
  src/openenv/cli/commands/validate.py
22
  src/openenv/cli/templates/__init__.py
 
 
23
  src/openenv/cli/templates/openenv_env/README.md
24
  src/openenv/cli/templates/openenv_env/__init__.py
25
  src/openenv/cli/templates/openenv_env/client.py
src/openenv_core.egg-info/requires.txt CHANGED
@@ -12,6 +12,7 @@ tomli-w>=1.2.0
12
  websockets>=15.0.1
13
  fastmcp>=3.0.0
14
  gradio>=4.0.0
 
15
 
16
  [all]
17
  openenv-core[core]
 
12
  websockets>=15.0.1
13
  fastmcp>=3.0.0
14
  gradio>=4.0.0
15
+ httpx>=0.28.1
16
 
17
  [all]
18
  openenv-core[core]