Asynchronous Multiagent Decision Making

This tutorial demonstrates how to configure and train a multiagent environment in RLlib in which homogeneous agents act asynchronously while learning a single shared policy.

Warning: Part of RLlib’s backend mishandles zero-length episodes, which this method can produce. If the following PR has not been merged into your installed version of Ray, it must be applied manually: https://github.com/ray-project/ray/pull/46721

This example was run with the following version of RLlib:

[1]:
from importlib.metadata import version
version("ray")  # Parent package of RLlib
[1]:
'2.35.0'

Define the Environment

A simple multi-satellite environment is defined. The environment is trivial as a multiagent problem, since there is no communication or interaction between satellites, but it serves to demonstrate asynchronous behavior: the actions defined below have different durations, so satellites finish their actions and require retasking at different times.

[2]:
import numpy as np
from bsk_rl import act, data, obs, sats, scene
from bsk_rl.sim import dyn, fsw

class ScanningDownlinkDynModel(dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel):
    # Define some custom properties to be accessed in the state
    @property
    def instrument_pointing_error(self) -> float:
        r_BN_P_unit = self.r_BN_P/np.linalg.norm(self.r_BN_P)
        c_hat_P = self.satellite.fsw.c_hat_P
        return np.arccos(np.dot(-r_BN_P_unit, c_hat_P))

    @property
    def solar_pointing_error(self) -> float:
        a = self.world.gravFactory.spiceObject.planetStateOutMsgs[
            self.world.sun_index
        ].read().PositionVector
        a_hat_N = a / np.linalg.norm(a)
        nHat_B = self.satellite.sat_args["nHat_B"]
        NB = np.transpose(self.BN)
        nHat_N = NB @ nHat_B
        return np.arccos(np.dot(nHat_N, a_hat_N))

class ScanningSatellite(sats.AccessSatellite):
    observation_spec = [
        obs.SatProperties(
            dict(prop="storage_level_fraction"),
            dict(prop="battery_charge_fraction"),
            dict(prop="wheel_speeds_fraction"),
            dict(prop="instrument_pointing_error", norm=np.pi),
            dict(prop="solar_pointing_error", norm=np.pi)
        ),
        obs.OpportunityProperties(
            dict(prop="opportunity_open", norm=5700),
            dict(prop="opportunity_close", norm=5700),
            type="ground_station",
            n_ahead_observe=1,
        ),
        obs.Eclipse(norm=5700),
    ]
    action_spec = [
        act.Scan(duration=150.0),
        act.Charge(duration=120.0),
        act.Downlink(duration=80.0),
        act.Desat(duration=45.0),
    ]
    dyn_type = ScanningDownlinkDynModel
    fsw_type = fsw.ContinuousImagingFSWModel

sats = [ScanningSatellite(
    f"Scanner-{i+1}",
    sat_args=dict(
        # Data
        dataStorageCapacity=5000 * 8e6,  # bits
        storageInit=lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6,
        instrumentBaudRate=0.5 * 8e6,
        transmitterBaudRate=-50 * 8e6,
        # Power
        batteryStorageCapacity=200 * 3600,  # W*s
        storedCharge_Init=lambda: np.random.uniform(0.3, 1.0) * 200 * 3600,
        basePowerDraw=-10.0,  # W
        instrumentPowerDraw=-30.0,  # W
        transmitterPowerDraw=-25.0,  # W
        thrusterPowerDraw=-80.0,  # W
        panelArea=0.25,
        # Attitude
        imageAttErrorRequirement=0.1,
        imageRateErrorRequirement=0.1,
        disturbance_vector=lambda: np.random.normal(scale=0.0001, size=3),  # N*m
        maxWheelSpeed=6000.0,  # RPM
        wheelSpeeds=lambda: np.random.uniform(-3000, 3000, 3),
        desatAttitude="nadir",
    )
) for i in range(4)]

Correlated Environment Parameters

To construct a constellation with coordinated orbits, a function is generated that maps satellites to orbital elements:

[3]:
from bsk_rl.utils.orbital import walker_delta_args

sat_arg_randomizer = walker_delta_args(n_planes=2, altitude=500)
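
For context, a Walker-delta pattern spaces the orbital planes evenly in right ascension and spaces the satellites evenly within each plane. The sketch below illustrates that spacing for this 4-satellite, 2-plane case; it is a conceptual illustration only, not the walker_delta_args implementation, and the phasing factor f=1 is an assumption.

# Conceptual Walker-delta spacing for t=4 satellites in p=2 planes.
# Not the bsk_rl implementation; the phasing factor f=1 is assumed.
t, p, f = 4, 2, 1
sats_per_plane = t // p
for k in range(t):
    plane, slot = divmod(k, sats_per_plane)
    raan = plane * 360 / p  # plane spacing in right ascension of ascending node
    anomaly = slot * 360 / sats_per_plane + plane * f * 360 / t  # in-plane spacing plus inter-plane phasing
    print(f"Scanner-{k + 1}: RAAN = {raan:.0f} deg, true anomaly = {anomaly:.0f} deg")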

The sat_arg_randomizer is included in the environment arguments.

[4]:
duration = 5 * 5700.0  # About 5 orbits
env_args = dict(
    satellites=sats,
    scenario=scene.UniformNadirScanning(value_per_second=1/duration),
    rewarder=data.ScanningTimeReward(),
    time_limit=duration,
    failure_penalty=-1.0,
    terminate_on_time_limit=True,
    sat_arg_randomizer=sat_arg_randomizer,
)
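
Before handing these arguments to RLlib, the underlying multiagent environment can be constructed directly as a sanity check. This is a hedged sketch: it assumes the ConstellationTasking class is exposed at the top level of bsk_rl, as in the other constellation examples, and that it accepts the same arguments used here.

# Sketch: build the environment directly to inspect the agents' observations.
# Assumes ConstellationTasking is importable from the top-level bsk_rl package.
from bsk_rl import ConstellationTasking

check_env = ConstellationTasking(**env_args)
observations, info = check_env.reset()
print(sorted(observations.keys()))  # one observation per satellite agent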

RLlib Training Configuration

A standard PPO configuration is generated.

[5]:
import bsk_rl.utils.rllib  # noqa To access "ConstellationTasking-RLlib"
from ray.rllib.algorithms.ppo import PPOConfig


N_CPUS = 3

training_args = dict(
    lr=0.00003,
    gamma=0.99997,
    train_batch_size=200 * N_CPUS,
    num_sgd_iter=10,
    lambda_=0.95,
    use_kl_loss=False,
    clip_param=0.1,
    grad_clip=0.5,
    mini_batch_size_per_learner=100,
)

config = (
    PPOConfig()
    .environment(
        "ConstellationTasking-RLlib",
        env_config=env_args,
    )
    .env_runners(
        num_env_runners=N_CPUS - 1,
        sample_timeout_s=1000.0,
    )
    .reporting(
        metrics_num_episodes_for_smoothing=1,
        metrics_episode_collection_timeout_s=180,
    )
    .checkpointing(export_native_model_files=True)
    .framework(framework="torch")
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .training(
        **training_args,
    )
)

To set up multiple agents using the same policy, the following configuration maps all agents to the policy p0.

[6]:
try:
    from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
    from ray.rllib.core.rl_module.rl_module import RLModuleSpec
except (ImportError, ModuleNotFoundError):  # Older versions of RLlib
    from ray.rllib.core.rl_module.marl_module import (
        MultiAgentRLModuleSpec as MultiRLModuleSpec,
    )
    from ray.rllib.core.rl_module.rl_module import (
        SingleAgentRLModuleSpec as RLModuleSpec,
    )

config.multi_agent(
    policies={"p0"},
    policy_mapping_fn=lambda *args, **kwargs: "p0",
).rl_module(
    model_config_dict={
        "use_lstm": False,
        # Use a simpler FCNet when we also have an LSTM.
        "fcnet_hiddens": [2048, 2048],
        "vf_share_layers": False,
    },
    rl_module_spec=MultiRLModuleSpec(
        module_specs={
            "p0": RLModuleSpec(),
        }
    ),
)


[6]:
<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>

Configuring Multiagent Logging Callbacks

Callback functions for the entire environment

[7]:
def env_metrics_callback(env):
    reward = env.rewarder.cum_reward
    reward = sum(reward.values()) / len(reward)
    return dict(reward=reward)

and per satellite

[8]:
def sat_metrics_callback(env, satellite):
    data = dict(
        # Are satellites dying, and how and when?
        alive=float(satellite.is_alive()),
        rw_status_valid=float(satellite.dynamics.rw_speeds_valid()),
        battery_status_valid=float(satellite.dynamics.battery_valid()),
    )
    return data

are defined. The sat_metrics_callback will be reported per-agent and as a mean. If using the predefined "ConstellationTasking-RLlib" environment, only the WrappedEpisodeDataCallbacks class needs to be added to the config, as in the single-agent case.

[9]:
from bsk_rl.utils.rllib.callbacks import WrappedEpisodeDataCallbacks

config.callbacks(WrappedEpisodeDataCallbacks)
[9]:
<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>

Action Continuation and Concatenation

Connector modules are introduced to prevent all agents from retasking whenever any one agent finishes an action. First, the ContinuePreviousAction connector overrides any policy-selected action with bsk_rl.NO_ACTION whenever requires_retasking==False for an agent, causing that agent to continue its current action.

[10]:
from bsk_rl.utils.rllib import discounting

config.env_runners(
    module_to_env_connector=lambda env: (discounting.ContinuePreviousAction(),)
)
[10]:
<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>
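
The effect of this connector can be sketched as follows. This is a conceptual illustration only, not the connector's implementation; the function name and the dict-based interface are assumptions, while NO_ACTION is the sentinel referenced above.

# Conceptual sketch of ContinuePreviousAction's behavior (not the actual connector).
from bsk_rl import NO_ACTION  # sentinel meaning "keep executing the current action"

def continue_previous_action_sketch(policy_actions, requires_retasking):
    # policy_actions: {agent_id: action selected by the policy}
    # requires_retasking: {agent_id: bool flag reported by the environment}
    return {
        agent: action if requires_retasking.get(agent, True) else NO_ACTION
        for agent, action in policy_actions.items()
    }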

Then, two other connectors compress NO_ACTION steps out of episodes of experience, condensing consecutive steps into single steps with super-actions. The d_ts timestep duration is calculated accordingly.

[11]:
config.training(
    learner_connector=lambda obs_space, act_space: (
        discounting.MakeAddedStepActionValid(expected_train_batch_size=config.train_batch_size),
        discounting.CondenseMultiStepActions(),
    ),
)
[11]:
<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>
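
Conceptually, the condensation works roughly as sketched below: transitions in which the agent took NO_ACTION are folded into the preceding real action, accumulating reward and elapsed time. The field names and the use of None to stand in for NO_ACTION are assumptions for illustration, not the connector's data format.

# Conceptual sketch of condensing NO_ACTION steps (not the actual connector logic).
def condense_steps_sketch(transitions):
    # transitions: list of dicts with keys "action", "reward", and "d_ts" (elapsed time).
    condensed = []
    for step in transitions:
        if step["action"] is None and condensed:  # None stands in for NO_ACTION here
            condensed[-1]["reward"] += step["reward"]  # roll reward into the real action
            condensed[-1]["d_ts"] += step["d_ts"]      # accumulate elapsed time
        else:
            condensed.append(dict(step))
    return condensed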

Lastly, the TimeDiscountedGAEPPOTorchLearner is used, as in the time-discounted GAE example (examples/time_discounted_gae).

[12]:
config.training(learner_class=discounting.TimeDiscountedGAEPPOTorchLearner)
[12]:
<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>
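
The idea behind time discounting is to discount returns by elapsed simulation time rather than by step count, so that steps of different durations are weighted consistently. Below is a minimal sketch of a time-discounted return, under the assumption that gamma is interpreted per unit of simulation time; the learner's actual GAE computation is more involved.

# Conceptual sketch of a time-discounted return (not the learner implementation).
def time_discounted_return_sketch(rewards, d_ts, gamma):
    # Discount each reward by gamma raised to the elapsed time, not the step index.
    total, elapsed = 0.0, 0.0
    for reward, dt in zip(rewards, d_ts):
        total += (gamma ** elapsed) * reward
        elapsed += dt
    return total

# Rewards that arrive later in simulation time are discounted more heavily,
# regardless of how many steps preceded them.
print(time_discounted_return_sketch([1.0, 1.0, 1.0], [150.0, 80.0, 45.0], 0.99997))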

Note that when using these connectors, only the requires_retasking flag will cause agents to select a new action. Step timeouts due to max_step_duration will not trigger retasking.

Training the Agent

At this point, the PPO config can be trained as desired.

[13]:
import ray
from ray import tune

ray.init(
    ignore_reinit_error=True,
    num_cpus=N_CPUS,
    object_store_memory=2_000_000_000,  # 2 GB
)

# Run the training
tune.run(
    "PPO",
    config=config.to_dict(),
    stop={"training_iteration": 2},  # Adjust the number of iterations as needed
)

# Shutdown Ray
ray.shutdown()
2024-09-12 15:05:23,877 INFO worker.py:1783 -- Started a local Ray instance.
2024-09-12 15:05:24,190 INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949
/Users/markstephenson/avslab/refactor/.venv_refactor/lib/python3.10/site-packages/gymnasium/spaces/box.py:130: UserWarning: WARN: Box bound precision lowered by casting to float32
  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
/Users/markstephenson/avslab/refactor/.venv_refactor/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:164: UserWarning: WARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64
  logger.warn(
/Users/markstephenson/avslab/refactor/.venv_refactor/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:188: UserWarning: WARN: The obs returned by the `reset()` method is not within the observation space.
  logger.warn(f"{pre} is not within the observation space.")

Tune Status

Current time: 2024-09-12 15:05:54
Running for: 00:00:30.52
Memory: 13.5/16.0 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 3.0/3 CPUs, 0/0 GPUs

Trial Status

Trial name: PPO_ConstellationTasking-RLlib_ba7e4_00000
Status: TERMINATED
Loc: 127.0.0.1:95948
Iter: 2
Total time (s): 16.33
num_env_steps_sampled_lifetime: 1200
num_episodes_lifetime: 0
num_env_steps_trained_lifetime: 1200
(PPO pid=95948) Install gputil for GPU system monitoring.
(MultiAgentEnvRunner pid=95950) 2024-09-12 15:05:44,475 sats.satellite.Scanner-4       WARNING    <6740.00> Scanner-4: failed battery_valid check

Trial Progress

Trial name: PPO_ConstellationTasking-RLlib_ba7e4_00000
env_runners: {'num_env_steps_sampled': 600, 'num_agent_steps_sampled_lifetime': {'Scanner-3': 1800, 'Scanner-2': 1800, 'Scanner-4': 1406, 'Scanner-1': 1765}, 'num_agent_steps_sampled': {'Scanner-2': 600, 'Scanner-4': 300, 'Scanner-1': 565, 'Scanner-3': 600}, 'num_env_steps_sampled_lifetime': 2400, 'num_module_steps_sampled_lifetime': {'p0': 6771}, 'num_episodes': 0, 'num_module_steps_sampled': {'p0': 2065}, 'episode_return_mean': nan, 'episode_return_min': nan, 'episode_return_max': nan}
fault_tolerance: {'num_healthy_workers': 2, 'num_in_flight_async_reqs': 0, 'num_remote_worker_restarts': 0}
learners: {'p0': {'vf_loss': 0.07396450638771057, 'num_non_trainable_parameters': 0.0, 'default_optimizer_learning_rate': 3e-05, 'policy_loss': 1.1960660219192505, 'num_trainable_parameters': 8452101.0, 'mean_kl_loss': 0.0, 'total_loss': 1.2700304985046387, 'vf_explained_var': 0.2418355941772461, 'vf_loss_unclipped': 0.07396450638771057, 'curr_entropy_coeff': 0.0, 'gradients_default_optimizer_global_norm': 3.470919609069824, 'num_module_steps_trained': 631, 'entropy': 1.3778905868530273}, '__all_modules__': {'num_trainable_parameters': 8452101.0, 'total_loss': 1.2700304985046387, 'num_non_trainable_parameters': 0.0, 'num_module_steps_trained': 631, 'num_env_steps_trained': 600}}
num_agent_steps_sampled_lifetime: {'Scanner-1': 1165, 'Scanner-2': 1200, 'Scanner-3': 1200, 'Scanner-4': 853}
num_env_steps_sampled_lifetime: 1200
num_env_steps_trained_lifetime: 1200
num_episodes_lifetime: 0
perf: {'cpu_util_percent': 16.16, 'ram_util_percent': 84.24999999999997}
timers: {'env_runner_sampling_timer': 6.925581849465962, 'learner_update_timer': 2.1881487244431628, 'synch_weights': 0.009959223236655818, 'synch_env_connectors': 0.006511321889702231}
(MultiAgentEnvRunner pid=95949) 2024-09-12 15:05:52,013 sats.satellite.Scanner-1       WARNING    <15360.00> Scanner-1: failed battery_valid check
2024-09-12 15:05:54,737 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/markstephenson/ray_results/PPO_2024-09-12_15-05-24' in 0.0020s.
2024-09-12 15:05:55,777 INFO tune.py:1041 -- Total run time: 31.59 seconds (30.51 seconds for the tuning loop).