atari_env / server /app.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
cf850a8 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
FastAPI application for the Atari Environment.
This module creates an HTTP server that exposes Atari games
over HTTP and WebSocket endpoints, compatible with EnvClient.
Usage:
# Development (with auto-reload):
uvicorn envs.atari_env.server.app:app --reload --host 0.0.0.0 --port 8000
# Production:
uvicorn envs.atari_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4
# Or run directly:
python -m envs.atari_env.server.app
Environment variables:
ATARI_GAME: Game name to serve (default: "pong")
ATARI_OBS_TYPE: Observation type (default: "rgb")
ATARI_FULL_ACTION_SPACE: Use full action space (default: "false")
ATARI_MODE: Game mode (optional)
ATARI_DIFFICULTY: Game difficulty (optional)
ATARI_REPEAT_ACTION_PROB: Sticky action probability (default: "0.0")
ATARI_FRAMESKIP: Frameskip (default: "4")
"""
import os
from openenv.core.env_server import create_app
# Support both in-repo and standalone imports
try:
# In-repo imports (when running from OpenEnv repository)
from ..models import AtariAction, AtariObservation
from .atari_environment import AtariEnvironment
except ImportError as e:
if "relative import" not in str(e) and "no known parent package" not in str(e):
raise
# Standalone imports (when running via uvicorn server.app:app)
from models import AtariAction, AtariObservation
from server.atari_environment import AtariEnvironment
# Get configuration from environment variables
game_name = os.getenv("ATARI_GAME", "pong")
obs_type = os.getenv("ATARI_OBS_TYPE", "rgb")
full_action_space = os.getenv("ATARI_FULL_ACTION_SPACE", "false").lower() == "true"
repeat_action_prob = float(os.getenv("ATARI_REPEAT_ACTION_PROB", "0.0"))
frameskip = int(os.getenv("ATARI_FRAMESKIP", "4"))
# Optional parameters
mode = os.getenv("ATARI_MODE")
difficulty = os.getenv("ATARI_DIFFICULTY")
# Convert to int if specified
mode = int(mode) if mode is not None else None
difficulty = int(difficulty) if difficulty is not None else None
# Factory function to create AtariEnvironment instances
def create_atari_environment():
"""Factory function that creates AtariEnvironment with config."""
return AtariEnvironment(
game_name=game_name,
obs_type=obs_type,
full_action_space=full_action_space,
mode=mode,
difficulty=difficulty,
repeat_action_probability=repeat_action_prob,
frameskip=frameskip,
)
# Create the FastAPI app with web interface and README integration
# Pass the factory function instead of an instance for WebSocket session support
app = create_app(
create_atari_environment, AtariAction, AtariObservation, env_name="atari_env"
)
def main():
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
if __name__ == "__main__":
main()