Complete overview of jax-aht repo structure and organization?

Commit: 2796ec6 (Fix wandb commit call in LBRDiv)
Montek (Montekkundan) posted:

Tell me about this repo: its structure and what it's about. If I plan to contribute, show me how the files are organised, how the folder structure works, what test cases there are, etc. A full overview.

Gemini 3 Flash (google/gemini-3-flash) commented:

JaxAHT is a JAX-based benchmark for Ad Hoc Teamwork (AHT). It provides a unified framework for researching the entire lifecycle of AHT, from teammate generation to training "ego" agents (the main agents being evaluated) and evaluating them in cooperative environments.

### 1. Core Purpose and Design Philosophy

The repository is built to facilitate research where multiple algorithms must interact. For example:

  • Teammate Generation: Creating diverse partners to train against.
  • Ego Agent Training: Training a single agent to cooperate with various partners.
  • MARL: Standard multi-agent reinforcement learning (e.g., IPPO).
  • Open-Ended Training: Algorithms like ROTATE that evolve teammates and tasks.

Key Design Principles:

  • Single-File Model: Inspired by CleanRL and JaxMARL, algorithms are largely implemented in single files (e.g., marl/ippo.py) to make them easy to read, modify, and JIT-compile as a single end-to-end loop.
  • Unified Interfaces: Despite the single-file approach, agents and populations follow common interfaces (AgentPolicy, AgentPopulation) so that agents trained by one method (e.g., MARL) can be used as training partners in another (e.g., Ego Training).
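
To make the "unified interfaces" point concrete, here is a minimal sketch in plain Python. It assumes hypothetical signatures (`init_hstate()`, `get_action(params, obs, hstate)` returning `(action, hstate)`) and a toy `AgentPopulation` that just holds `(policy, params)` pairs; the real classes in the repo work on JAX arrays and RNG keys and will differ in detail. The point is only that a learned policy and a hardcoded heuristic can sit in the same population and be used as partners through one interface.

```python
import random
from abc import ABC, abstractmethod


class AgentPolicy(ABC):
    """Hypothetical stand-in for the repo's AgentPolicy interface."""

    @abstractmethod
    def init_hstate(self):
        """Return the initial hidden state (None for feed-forward agents)."""

    @abstractmethod
    def get_action(self, params, obs, hstate):
        """Return (action, new_hstate) for a single observation."""


class LinearPolicy(AgentPolicy):
    """Hypothetical stand-in for a learned policy, e.g. one trained with IPPO."""

    def init_hstate(self):
        return None

    def get_action(self, params, obs, hstate):
        # Score each action as weight * obs and pick the highest-scoring one.
        scores = [w * obs for w in params]
        return scores.index(max(scores)), hstate


class FixedActionPolicy(AgentPolicy):
    """Hypothetical stand-in for a hardcoded heuristic agent."""

    def init_hstate(self):
        return None

    def get_action(self, params, obs, hstate):
        return params["action"], hstate


class AgentPopulation:
    """Hypothetical population: (policy, params) pairs behind one interface."""

    def __init__(self, members):
        self.members = members

    def sample(self, rng):
        # Draw one partner uniformly at random.
        return rng.choice(self.members)


population = AgentPopulation([
    (LinearPolicy(), [0.1, -0.5, 0.9]),
    (FixedActionPolicy(), {"action": 1}),
])

rng = random.Random(0)
actions = []
for _ in range(5):
    # The caller never needs to know which method produced the partner.
    policy, params = population.sample(rng)
    hstate = policy.init_hstate()
    action, hstate = policy.get_action(params, obs=2.0, hstate=hstate)
    actions.append(action)
print(actions)
```

An ego-training loop written against this interface only calls `sample`, `init_hstate`, and `get_action`, which is why partners produced by MARL, teammate generation, or heuristics are interchangeable.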

### 2. Project Structure

The codebase is organized by the "role" of the algorithm or component:

| Directory | Description |
|---|---|
| agents/ | Abstract interfaces and specific agent architectures (MLP, RNN, S5, and heuristics for specific envs). |
| envs/ | Environment implementations (Level-Based Foraging, Overcooked-v1, Hanabi) and wrappers (logging, reward shaping). |
| marl/ | Standard Multi-Agent RL algorithms (currently ippo.py). |
| ego_agent_training/ | Algorithms to train an ego agent against a fixed/pre-specified population (PPO, LIAM, MeLIBA). |
| teammate_generation/ | Algorithms to generate diverse teammate populations (FCP, BRDiv, CoMeDi). |
| open_ended_training/ | Open-ended AHT methods like ROTATE and PAIRED. |
| evaluation/ | Scripts for cross-play (XP) matrices and held-out evaluation. |
| common/ | Shared utilities for tree manipulation, plotting, and loading agents from configs. |
| tests/ | Unit tests for environments and agents. |

### 3. How Files are Organized (for Contributors)

If you plan to contribute, follow these patterns:

  • Algorithms: Add new algorithms to their respective folder (marl/, ego_agent_training/, etc.). Aim for a "self-contained" file that implements the training loop.
  • Configurations: Managed with Hydra. Each algorithm folder has a configs/ subfolder:
    • configs/algorithm/: Algorithm-specific hyperparameters.
    • configs/task/: Environment/task-specific settings.
  • Entry Points: Each category directory has a run.py that serves as the main entry point using Hydra.
    • Example: python teammate_generation/run.py task=lbf algorithm=fcp/lbf
  • Agent Interface: A new agent type should inherit from AgentPolicy in agents/agent_interface.py, so that it implements get_action (JIT-compatible) and init_hstate.
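
A minimal sketch of adding a new agent type, assuming a hypothetical `AgentPolicy` with `init_hstate(batch_size)` and `get_action(params, obs, hstate)` returning `(actions, new_hstate)`; check agents/agent_interface.py for the real signatures. The hypothetical `CyclingHeuristicAgent` replays a fixed action sequence and keeps its per-environment step counter in the hidden state instead of on `self`, so `get_action` stays a pure function of its inputs, which is the property `jax.jit` relies on.

```python
from abc import ABC, abstractmethod


class AgentPolicy(ABC):
    """Hypothetical stand-in for AgentPolicy in agents/agent_interface.py.

    Only the two methods named in the post are mirrored here; the real
    signatures (JAX arrays, RNG keys, etc.) may differ.
    """

    @abstractmethod
    def init_hstate(self, batch_size):
        """Return the initial hidden state for batch_size parallel envs."""

    @abstractmethod
    def get_action(self, params, obs, hstate):
        """Return (actions, new_hstate) without mutating object state."""


class CyclingHeuristicAgent(AgentPolicy):
    """Hypothetical heuristic agent that replays a fixed action sequence."""

    def __init__(self, sequence):
        self.sequence = tuple(sequence)  # read-only configuration

    def init_hstate(self, batch_size):
        # One step counter per parallel environment.
        return tuple(0 for _ in range(batch_size))

    def get_action(self, params, obs, hstate):
        # The action depends only on the inputs, never on mutable self state.
        actions = tuple(self.sequence[c % len(self.sequence)] for c in hstate)
        new_hstate = tuple(c + 1 for c in hstate)
        return actions, new_hstate


agent = CyclingHeuristicAgent(sequence=[0, 2, 4])
hstate = agent.init_hstate(batch_size=2)
history = []
for _ in range(4):
    # The caller threads hstate through each call, as a rollout loop would.
    actions, hstate = agent.get_action(params=None, obs=None, hstate=hstate)
    history.append(actions)
print(history)
```

Because `AgentPolicy` is an abstract base class here, forgetting either method makes the subclass impossible to instantiate, which catches interface mistakes early.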

### 4. Testing

The tests/ directory contains several test scripts to verify the core components:

  • Environment Wrappers: test_lbf_wrapper.py, test_overcooked_wrapper.py.
  • Heuristic Agents: test_lbf_agents.py, test_overcooked_agents.py (check that the hardcoded baseline agents behave correctly).
  • Environment Logic: test_hanabi.py.

You can run these using pytest or directly with python tests/test_overcooked_agents.py.
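
If you add a test file, one way to support both ways of running it is sketched below. The file name and the `always_move_right` heuristic are hypothetical placeholders; the pattern itself is standard: pytest collects functions named `test_*` and reports failed bare `assert`s, while the `__main__` guard lets the same file run as a plain script.

```python
"""Hypothetical tests/test_example_agent.py (names are illustrative only)."""


def always_move_right(obs):
    """Hypothetical heuristic that always returns action 1 ("right")."""
    return 1


def test_heuristic_returns_valid_action():
    # pytest discovers functions prefixed with test_ automatically.
    num_actions = 4  # hypothetical size of the action space
    action = always_move_right(obs=None)
    assert 0 <= action < num_actions


def test_heuristic_is_deterministic():
    # A hardcoded agent should give the same action for the same input.
    assert always_move_right(obs=None) == always_move_right(obs=None)


if __name__ == "__main__":
    # Allows `python tests/test_example_agent.py` in addition to pytest.
    test_heuristic_returns_valid_action()
    test_heuristic_is_deterministic()
    print("all tests passed")
```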

### 5. Environment Details

  • Level-Based Foraging (LBF): Agents must coordinate to collect food items; an item can be loaded only when the combined levels of the adjacent agents loading it are at least the item's level.
  • Overcooked-v1: High-coordination cooking task (based on the JaxMARL implementation).
  • Auto-reset: Environments are assumed to auto-reset within the step function, a common pattern in JAX-based RL that avoids Python-level branching inside the JIT-compiled loop.
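
A minimal sketch of the auto-reset idea on a hypothetical toy environment (the names `reset`, `step_env`, `where`, and `EPISODE_LEN` are illustrative, not the repo's API). The usual JAX approach is to compute both the stepped state and a fresh reset state on every call and pick one with a select instead of an `if`; in JAX that select would be `jnp.where` or `jax.lax.select`, mapped over the state with `jax.tree_util.tree_map`. Plain Python integers stand in for arrays here.

```python
EPISODE_LEN = 3  # hypothetical horizon for the toy environment


def reset():
    """Return the initial state as a dict of integer fields."""
    return {"t": 0, "pos": 0}


def step_env(state, action):
    """Advance the toy environment one step, with no reset logic."""
    nxt = {"t": state["t"] + 1, "pos": state["pos"] + action}
    reward = action
    done = int(nxt["t"] >= EPISODE_LEN)  # 1 when the episode ends, else 0
    return nxt, reward, done


def where(cond, x, y):
    """Branch-free select on integers: x if cond == 1 else y.

    In JAX this would be jnp.where(cond, x, y) or jax.lax.select.
    """
    return cond * x + (1 - cond) * y


def step(state, action):
    """Step with auto-reset: compute both candidates, then select."""
    nxt, reward, done = step_env(state, action)
    fresh = reset()  # computed unconditionally, as a jitted step would
    # Select field by field; JAX code would use jax.tree_util.tree_map.
    new_state = {k: where(done, fresh[k], nxt[k]) for k in nxt}
    return new_state, reward, done


state = reset()
trace = []
for _ in range(4):
    state, reward, done = step(state, 1)
    trace.append((state["t"], state["pos"], reward, done))
print(trace)
```

On the step where `done` is 1, the reward and done flag belong to the episode that just ended, while the returned state already belongs to the next episode, so the training loop never has to call `reset` itself.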


Complete overview of jax-aht repo structure and organization? — LARG/jax-aht