Shielded training with action masking and action replacement

This example script demonstrates how to train and test a policy with shields. Although deep reinforcement learning (DRL) can produce high-performing policies, it provides no safety guarantees. Adding negative rewards for reaching unsafe states can help the agent learn the safety aspects of the problem, but this still offers no guarantees and is itself prone to issues. Shields can provide these guarantees while minimizing interference with the agent. The two main approaches to providing safety during training and testing are action masking and action replacement.

Action replacement: The agent selects an action and the shield checks whether it is safe. If it is deemed unsafe, the action is replaced with a safe one. The replacement may or may not be used in the policy update during training; this script does not use it, so the policy is updated with the action the agent originally selected. Optionally, an interference penalty can be given to the agent for every shield correction.

Action masking: A masking function specifies which actions are allowed and which are not in a given state. This method alters the logits (raw outputs) of the actor network so that the agent can only sample safe actions. In this case, the safe action that is actually selected is the one used in the policy update.
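To make the difference concrete, the following minimal sketch shows how each approach changes the distribution of executed actions for a fixed policy. The probabilities, the safety mask, and the fallback action are hypothetical toy values, not taken from the satellite problem.

```python
import numpy as np

# Toy policy over 4 actions and a toy safety mask (1 = safe, 0 = unsafe).
# Both are illustrative assumptions, not values from the satellite problem.
probs = np.array([0.1, 0.2, 0.3, 0.4])
safe = np.array([1, 1, 0, 1])
replacement_action = 0  # hypothetical safe fallback chosen by the shield


def executed_dist_replacement(probs, safe, replacement_action):
    """Distribution of executed actions under action replacement."""
    executed = probs * safe  # safe choices pass through unchanged
    # All probability mass on unsafe actions is moved to the fallback action
    executed[replacement_action] += probs[safe == 0].sum()
    return executed


def executed_dist_masking(probs, safe):
    """Distribution of executed actions under action masking."""
    masked = probs * safe
    # Unsafe actions get zero probability; safe ones are renormalized
    return masked / masked.sum()


print(executed_dist_replacement(probs, safe, replacement_action))
print(executed_dist_masking(probs, safe))
```

Under replacement, the mass of the unsafe action collapses onto the single fallback action; under masking, it is redistributed proportionally among all safe actions.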

This example is associated with Performance Evaluation of Shielded Neural Networks for Autonomous Agile Earth Observing Satellites in Long Term Scenarios and a future publication. In that work, failure penalties, action masking, and action replacement with different interference penalties were investigated during training, and their performance was analyzed during testing.

[1]:
import gymnasium as gym
from gymnasium import ActionWrapper, ObservationWrapper, RewardWrapper, Wrapper

from Basilisk.architecture import bskLogging
from bsk_rl import act, data, obs, scene, sats
from bsk_rl.sim import dyn, fsw, world
from bsk_rl.utils.orbital import random_orbit
from Basilisk.utilities import orbitalMotion
from typing import Callable, Union, Dict, Any, Iterable, TypeVar
from bsk_rl.gym import SatelliteTasking
import numpy as np

Satellite = TypeVar("Satellite")
SatObs = TypeVar("SatObs")
SatAct = TypeVar("SatAct")
MultiSatObs = tuple[SatObs, ...]
MultiSatAct = Iterable[SatAct]
SatArgRandomizer = Callable[[list[Satellite]], dict[Satellite, dict[str, Any]]]

bskLogging.setDefaultLogLevel(bskLogging.BSK_WARNING)

Different environment configurations

Three different environment options are available:

  • no_failure_penalty: No penalty is assigned for failure

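  • failure_penalty: A penalty of -10 is assigned on failure

  • inf_power: Effectively unlimited power (a battery 1000 times larger) and relaxed reaction wheel speed limits (a tenfold higher maximum wheel speed), with no failure penalty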

[2]:
ALTITUDE = 800  # km

T_ORBIT = (
    2
    * np.pi
    * np.sqrt((orbitalMotion.REQ_EARTH + ALTITUDE) ** 3 / orbitalMotion.MU_EARTH)
)

config = dict(
    general_sat_args=dict(
        oe=lambda: random_orbit(
            alt=ALTITUDE,  # 800 km altitude
            i=45,  # 45 degrees inclination
        ),
        imageAttErrorRequirement=0.01,  # 0.01 MRP normalized
        imageRateErrorRequirement=0.01,  # 0.01 rad/s
        u_max=0.4,  # Maximum control input
        K1=0.25,  # Control gain
        K3=3.0,  # Control gain
        servo_P=150 / 5,  # Servo gain
        omega_max=np.radians(5.0),  # Maximum rate command: 5 deg/s, expressed in rad/s
        imageTargetMinimumElevation=np.arctan(
            800 / 500
        ),  # 58 degrees minimum elevation
        rwBasePower=20,
        transmitterPacketSize=0,
    ),
    sat_args=dict(
        no_failure_penalty=dict(
            batteryStorageCapacity=80.0 * 3600 * 2,
            storedCharge_Init=lambda: np.random.uniform(0.4, 1.0) * 80.0 * 3600 * 2,
            maxWheelSpeed=1500,
            wheelSpeeds=lambda: np.random.uniform(
                -750,
                750,
                3,
            ),
            dataStorageCapacity=2 * 8e6 * 25,  # Stores 50 images
            storageInit=lambda: np.random.randint(
                0,
                0.99 * 2 * 8e6 * 25,
            ),
        ),
        failure_penalty=dict(
            batteryStorageCapacity=80.0 * 3600 * 2,
            storedCharge_Init=lambda: np.random.uniform(0.4, 1.0) * 80.0 * 3600 * 2,
            maxWheelSpeed=1500,
            wheelSpeeds=lambda: np.random.uniform(
                -750,
                750,
                3,
            ),
            dataStorageCapacity=2 * 8e6 * 25,  # Stores 50 images
            storageInit=lambda: np.random.randint(
                0,
                0.99 * 2 * 8e6 * 25,
            ),
        ),
        inf_power=dict(
            batteryStorageCapacity=80.0 * 3600 * 2 * 1000,
            storedCharge_Init=0.99 * 80.0 * 3600 * 2 * 1000,
            maxWheelSpeed=15000,
            wheelSpeeds=lambda: np.random.uniform(
                -1,
                1,
                3,
            ),
            dataStorageCapacity=2 * 8e6 * 25,  # Stores 50 images
            storageInit=lambda: np.random.randint(
                0,
                0.99 * 2 * 8e6 * 25,
            ),
        ),
    ),
    sim_params=dict(
        horizon=3,
        max_step_duration=300.0,
        sim_rate=0.5,
        failure_penalty=dict(
            no_failure_penalty=0.0,
            failure_penalty=-10.0,
            inf_power=0.0,
        ),
    ),
)
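The derived quantities in this configuration can be checked by hand. The sketch below recomputes the orbital period and the minimum imaging elevation without Basilisk; the two Earth constants are assumed to match the values Basilisk's orbitalMotion module uses, so treat the printed numbers as approximate.

```python
import numpy as np

# Assumed constants, believed to match Basilisk's orbitalMotion values
REQ_EARTH = 6378.1366  # Earth equatorial radius [km]
MU_EARTH = 398600.436  # Earth gravitational parameter [km^3/s^2]
ALTITUDE = 800  # [km]

# Circular-orbit period, as in the T_ORBIT expression: T = 2*pi*sqrt(a^3/mu)
T_ORBIT = 2 * np.pi * np.sqrt((REQ_EARTH + ALTITUDE) ** 3 / MU_EARTH)
print(f"T_ORBIT = {T_ORBIT:.0f} s ({T_ORBIT / 60:.1f} min)")

# Minimum imaging elevation used in general_sat_args
min_elevation_deg = np.degrees(np.arctan(800 / 500))
print(f"minimum elevation = {min_elevation_deg:.1f} deg")

# Episode lengths in seconds, computed as time_limit = floor(T_ORBIT * horizon)
for horizon in (3, 90):
    print(horizon, "orbits ->", np.floor(T_ORBIT * horizon), "s")
```

At 800 km the period comes out at roughly 100 minutes, so the 3-orbit training episodes last about five hours of simulated time.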

Defining the satellite

The observation space is based on Learning Policies for Autonomous Earth-Observing Satellite Scheduling over Semi-MDPs.

[3]:
def s_hat_H(sat: Satellite):
    """
    Computes the unit vector from the satellite to the Sun, expressed in the Hill frame.
    """
    r_SN_N = (
        sat.simulator.world.gravFactory.spiceObject.planetStateOutMsgs[
            sat.simulator.world.sun_index
        ]
        .read()
        .PositionVector
    )
    r_BN_N = sat.dynamics.r_BN_N
    r_SB_N = np.array(r_SN_N) - np.array(r_BN_N)
    r_SB_H = sat.dynamics.HN @ r_SB_N
    return r_SB_H / np.linalg.norm(r_SB_H)


class Density(obs.Observation):
    def __init__(
        self,
        interval_duration=60 * 3,
        intervals=10,
        norm=3,
    ):
        self.satellite: "sats.ImagingSatellite"
        super().__init__()
        self.interval_duration = interval_duration
        self.intervals = intervals
        self.norm = norm

    def get_obs(self):
        if self.intervals == 0:
            return []

        self.satellite.calculate_additional_windows(
            self.simulator.sim_time
            + (self.intervals + 1) * self.interval_duration
            - self.satellite.window_calculation_time
        )
        soonest = self.satellite.upcoming_opportunities_dict(types="target")
        rewards = np.array([opportunity.priority for opportunity in soonest])
        times = np.array([opportunities[0][1] for opportunities in soonest.values()])
        time_bins = np.floor((times - self.simulator.sim_time) / self.interval_duration)
        densities = [sum(rewards[time_bins == i]) for i in range(self.intervals)]
        return np.array(densities) / self.norm


class CustomSatComposed(sats.ImagingSatellite):
    observation_spec = [
        obs.SatProperties(
            dict(prop="omega_BN_B", norm=0.03),  # 3
            dict(prop="c_hat_H"),  # 3
            dict(prop="r_BN_P", norm=orbitalMotion.REQ_EARTH * 1e3),  # 3
            dict(prop="v_BN_P", norm=7616.5),  # 3
            dict(prop="battery_charge_fraction"),  # 1
            dict(prop="storage_level_fraction"),  # 1
            dict(prop="wheel_speeds_fraction"),  # 3
            dict(prop="s_hat_H", fn=s_hat_H),  # 3
        ),
        obs.OpportunityProperties(
            dict(prop="opportunity_open", norm=T_ORBIT),
            dict(prop="opportunity_close", norm=T_ORBIT),
            type="ground_station",
            n_ahead_observe=1,
        ),  # 2
        obs.Eclipse(norm=T_ORBIT),  # 2
        Density(intervals=20, norm=5),  # 20
        obs.OpportunityProperties(
            dict(prop="priority"),  # 32
            dict(prop="r_LB_H", norm=orbitalMotion.REQ_EARTH * 1e3),  # 32*3
            dict(prop="target_angle", norm=np.pi / 2),  # 32
            dict(prop="target_angle_rate", norm=0.03),  # 32
            dict(prop="opportunity_open", norm=300.0),  # 32
            dict(prop="opportunity_close", norm=300.0),  # 32
            type="target",
            n_ahead_observe=32,
        ),
    ]

    action_spec = [
        act.Charge(duration=60.0),  # 1
        act.Downlink(),  # 1
        act.Desat(duration=60.0),  # 1
        act.Image(n_ahead_image=32),  # 32
    ]

    dyn_type = dyn.FullFeaturedDynModel
    fsw_type = fsw.SteeringImagerFSWModel
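The sizes of the observation and action spaces follow directly from the per-entry comments in observation_spec and action_spec. A quick tally, assuming each property contributes exactly the number of entries noted in its comment:

```python
# Per-block observation sizes, read off the "# n" comments in observation_spec
sat_props = 3 + 3 + 3 + 3 + 1 + 1 + 3 + 3  # omega, c_hat, r, v, battery, storage, wheels, s_hat
ground_station = 2  # opportunity_open, opportunity_close for 1 station
eclipse = 2  # eclipse start, eclipse end
density = 20  # Density(intervals=20)
n_targets = 32
per_target = 1 + 3 + 1 + 1 + 1 + 1  # priority, r_LB_H (3), angle, angle rate, open, close
targets = n_targets * per_target

obs_size = sat_props + ground_station + eclipse + density + targets
# Charge (index 0), Downlink (1), Desat (2), then one Image action per upcoming target
action_size = 1 + 1 + 1 + n_targets

print(obs_size, action_size)  # 300 35
```

The action count matches the ACTION_SPACE_SIZE = 35 used by the shield functions, and the index order (0 charge, 2 desaturate, 3 and above imaging) is the one assumed by the logging wrapper and the handmade shield.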

The function setup_env creates the environment for testing, or the environment arguments used to create environments for training.

[4]:
def setup_env(
    test: bool = True,
    horizon: float = 90,
    n_targets: Union[tuple, int] = (100, 3000),
    target_distribution: str = "cities",
    env_case: str = "no_failure_penalty",
):
    """
    Set up the environment for the satellite tasking problem.

    Args:
        test: If True, sets up the environment for testing; otherwise builds the
            arguments used to create training environments.
        horizon: The time horizon of an episode, in orbits.
        n_targets: The number of targets, either an int or a (min, max) range.
        target_distribution: The distribution of targets, either "uniform" or "cities".
        env_case: The environment case to use from the configuration.

    Returns:
        env: The configured environment. For training, a short auxiliary
            environment used only to extract observation indexes.
        satellite: The satellite object configured with the specified parameters.
        env_args_dict: Dictionary of environment arguments (for training; None for testing).
        indexes: Dictionary of indexes for specific observations.
    """

    if env_case not in config["sat_args"]:
        raise ValueError(f"Environment case '{env_case}' not found in configuration.")

    if target_distribution == "uniform":
        scene_features = scene.UniformTargets(n_targets=n_targets)
    elif target_distribution == "cities":
        scene_features = scene.CityTargets(n_targets=n_targets)
    else:
        raise ValueError("Invalid distribution type")

    sat_args = dict(
        **config["general_sat_args"],
        **config["sat_args"][env_case],
    )

    satellite = CustomSatComposed(
        "EarthObserving",
        sat_args=sat_args,
    )

    if test:
        env = gym.make(
            "SatelliteTasking-v1",
            satellite=satellite,
            world_type=world.GroundStationWorldModel,
            scenario=scene_features,
            rewarder=data.UniqueImageReward(),
            time_limit=np.floor(T_ORBIT * horizon),
            log_level="WARNING",
            failure_penalty=0.0,  # NO FAILURE PENALTY IN TEST
            sim_rate=config["sim_params"]["sim_rate"],
            max_step_duration=config["sim_params"]["max_step_duration"],
        )
        env_args_dict = None

    else:
        env = None
        env_args_dict = dict(
            satellite=satellite,
            world_type=world.GroundStationWorldModel,
            scenario=scene_features,
            rewarder=data.UniqueImageReward(),
            time_limit=np.floor(T_ORBIT * horizon),
            log_level="WARNING",
            failure_penalty=config["sim_params"]["failure_penalty"][env_case],
            sim_rate=config["sim_params"]["sim_rate"],
            max_step_duration=config["sim_params"]["max_step_duration"],
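        )

        # Short auxiliary environment (0.1 orbit long), used only to extract the
        # observation indexes below; training environments are built from env_args_dict
        env = gym.make(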
            "SatelliteTasking-v1",
            satellite=satellite,
            world_type=world.GroundStationWorldModel,
            scenario=scene_features,
            rewarder=data.UniqueImageReward(),
            time_limit=np.floor(0.1 * T_ORBIT),
            log_level="WARNING",
            failure_penalty=config["sim_params"]["failure_penalty"][env_case],
            sim_rate=config["sim_params"]["sim_rate"],
            max_step_duration=config["sim_params"]["max_step_duration"],
        )

    # Getting observation indexes - useful for shields
    obs_keys = env.unwrapped.satellite.observation_builder.obs_array_keys()
    indexes = {
        "wheel_speeds": [
            obs_keys.index(f"sat_props.wheel_speeds_fraction[{i}]") for i in range(3)
        ],
        "stored_charge": [obs_keys.index("sat_props.battery_charge_fraction")],
        "attitude_rate": [
            obs_keys.index(f"sat_props.omega_BN_B_normd[{i}]") for i in range(3)
        ],
        "eclipse": [obs_keys.index("eclipse[0]"), obs_keys.index("eclipse[1]")],
    }

    return env, satellite, env_args_dict, indexes

Action replacement and action masking wrappers

Wrappers are used to extend the capabilities of the base environment:

  • WrapperActionLogging records action counts and shield statistics that can be easily accessed during training and testing.

  • WrapperPostPosed implements the action replacement logic. It receives a shield function that checks the action selected by the agent and replaces it if necessary. An interference penalty can optionally be specified.

  • WrapperActionMasking extends the observation returned by the environment with the action mask, which is used by a modified RLModule during training.
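Action replacement depends on the order of operations within a single step: the shield must see the agent's action before the environment steps, and the interference penalty must be added to the reward of that same step. The framework-free sketch below illustrates this order; ToyEnv and ToyReplacementWrapper are hypothetical stand-ins, not part of gymnasium or bsk_rl.

```python
from typing import Callable, Optional


class ToyEnv:
    """Hypothetical stand-in for an environment: reward 1.0 for any action."""

    def step(self, action: int):
        return {"last_action": action}, 1.0, False, False, {}


class ToyReplacementWrapper:
    """Minimal sketch of action replacement with an interference penalty."""

    def __init__(self, env, shield_fn: Callable[[int], Optional[int]], penalty: float):
        self.env = env
        self.shield_fn = shield_fn  # returns a safe action, or None if no change
        self.penalty = penalty
        self.interference_count = 0

    def step(self, action: int):
        # 1. Let the shield inspect the agent's action before stepping
        shielded = self.shield_fn(action)
        interfered = shielded is not None and shielded != action
        executed = shielded if interfered else action
        # 2. Step the environment with the (possibly replaced) action
        obs, reward, terminated, truncated, info = self.env.step(executed)
        # 3. Penalize the same step if the shield interfered
        if interfered:
            self.interference_count += 1
            reward += self.penalty  # penalty is negative, e.g. -0.1
        return obs, reward, terminated, truncated, info


# Hypothetical shield: action 2 is never allowed and is replaced with action 0
wrapped = ToyReplacementWrapper(ToyEnv(), lambda a: 0 if a == 2 else None, penalty=-0.1)
print(wrapped.step(1))  # executed as-is
print(wrapped.step(2))  # replaced by 0, reward 1.0 - 0.1
```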

[5]:
class WrapperActionLogging(Wrapper):

    def __init__(
        self,
        env: Satellite,
    ):

        super().__init__(env)
        self._initialize_action_logger()

    def _initialize_action_logger(self):
        self.action_logger = {
            "action_charge_count": 0,
            "action_downlink_count": 0,
            "action_desat_count": 0,
            "action_image_count": 0,
            "actions_total_count": 0,
        }
        self.shield_info = {
            "shield_interference": False,
            "original_action": None,
            "shielded_action": None,
            "shield_interference_count": 0,
            "shield_penalty_total": 0.0,
            "masking_all_actions_available_count": 0,
        }

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        self._initialize_action_logger()
        return self.env.reset(seed=seed, options=options)

    def step(self, action: int):

        if action == 0:
            self.action_logger["action_charge_count"] += 1
        elif action == 1:
            self.action_logger["action_downlink_count"] += 1
        elif action == 2:
            self.action_logger["action_desat_count"] += 1
        elif action >= 3:
            self.action_logger["action_image_count"] += 1
        self.action_logger["actions_total_count"] += 1

        return self.env.step(action)


class WrapperPostPosed(WrapperActionLogging, ActionWrapper, RewardWrapper):
    """
    A wrapper that allows for post-posing shields in a gym environment.
    """

    def __init__(
        self,
        env: Satellite,
        shield_function: Callable[[list[float], int], int] = None,
        shield_penalty: float = 0.0,
    ):
        super().__init__(env)
        self.shield_function = shield_function
        self.shield_penalty = shield_penalty

    def action(self, action: int) -> int:
        original_action = action

        shielded_action = self.shield_function(
            self.env.satellite.get_obs(), original_action
        )

        if shielded_action is None or shielded_action == original_action:
            modified_action = original_action
            self.shield_info["shield_interference"] = False
            self.shield_info["original_action"] = original_action
            self.shield_info["shielded_action"] = original_action

        else:
            modified_action = shielded_action
            self.shield_info["shield_interference"] = True
            self.shield_info["original_action"] = original_action
            self.shield_info["shielded_action"] = shielded_action
            self.shield_info["shield_interference_count"] += 1

        return modified_action

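    def reward(self, reward: float) -> float:

        if self.shield_info["shield_interference"]:
            reward += self.shield_penalty
            self.shield_info["shield_penalty_total"] += self.shield_penalty

        return reward

    def step(self, action: int):
        # WrapperActionLogging.step comes first in the MRO and calls self.env.step
        # directly, which would skip ActionWrapper.action and RewardWrapper.reward.
        # Apply both explicitly: shield the action, log and execute it, then add
        # the interference penalty to the reward.
        shielded_action = self.action(action)
        obs, reward, terminated, truncated, info = super().step(shielded_action)
        return obs, self.reward(reward), terminated, truncated, info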


class WrapperActionMasking(ObservationWrapper):
    """
    A wrapper that allows for action-masking in a gym environment.
    """

    def __init__(
        self, env: Satellite, masking_function: Callable[[list[float]], list[int]]
    ):
        super().__init__(env)
        self.masking_function = masking_function
        self.valid_actions = None

    @property
    def observation_space(self):
        """Return the single satellite observation space."""
        obs_space = gym.spaces.Dict(
            {
                "action_mask": gym.spaces.Box(0.0, 1.0, shape=(self.action_space.n,)),
                "observations": self.unwrapped.satellite.observation_space,
            }
        )
        return obs_space

    def observation(self, observation: list) -> dict:
        self.valid_actions = np.array(
            self.masking_function(observation), dtype=np.float32
        )
        n_available_actions = np.sum(self.valid_actions)
        if n_available_actions == len(self.valid_actions):
            self.shield_info["masking_all_actions_available_count"] += 1
        if n_available_actions == 0:
            # if no actions are available, all actions are allowed
            self.valid_actions = np.ones(len(self.valid_actions), dtype=np.float32)

        observation_with_mask = {
            "action_mask": self.valid_actions,
            "observations": observation,
        }

        return observation_with_mask

Handmade shield

The handmade shield is an example of a simple shield built from heuristics, based on A comparative analysis of reinforcement learning algorithms for earth-observing satellite scheduling and Reinforcement Learning for Earth-Observing Satellite Autonomy with Event-Based Task Intervals. Although it provides a good compromise between performance and safety, it is prone to edge cases that are hard to model and predict. Automated shields with stronger safety guarantees, such as those discussed in Shielded Deep Reinforcement Learning for Complex Spacecraft Tasking, can be used instead.

[6]:
def power_shielding_function(
    obs: Union[list, np.ndarray],
    act: Union[int, None],
    indexes: dict[str, list[int]],
    min_power: float = 0.25,
    charge_rate: float = 1.0,
    discharge_rate: float = 1.0,
    rw_threshold: float = 0.7,
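) -> Union[None, int]:
    """Return a safe action when the current state requires one.

    Forces charging if the satellite is not in eclipse and its battery is below
    the charge needed to get through the next eclipse; otherwise, forces
    desaturation if any reaction wheel speed exceeds the threshold.

    Args:
        obs: Observation vector
        act: unshielded action
        indexes: Dictionary with indexes of the observation vector
        min_power: Minimum battery charge, as a fraction of capacity.
        charge_rate: Rate of charging in charge mode, in fraction of capacity per orbit.
        discharge_rate: Rate of discharge in eclipse, in fraction of capacity per orbit.
        rw_threshold: Threshold on reaction wheel speeds, as a fraction of the maximum speed.

    Returns:
        Return None if shield not activated, else return shielded action.
    """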

    current_power = obs[indexes["stored_charge"][0]]
    eclipse_start = obs[indexes["eclipse"][0]]
    eclipse_end = obs[indexes["eclipse"][1]]
    rw_1 = obs[indexes["wheel_speeds"][0]]
    rw_2 = obs[indexes["wheel_speeds"][1]]
    rw_3 = obs[indexes["wheel_speeds"][2]]

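    # eclipse_start and eclipse_end are the times until the next eclipse start and
    # end (in orbits); if the end comes first, the satellite is currently in eclipse
    in_eclipse = eclipse_end < eclipse_start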

    # Check if current state is unsafe
    if not in_eclipse and current_power < _power_requirement(
        eclipse_start, eclipse_end, min_power, charge_rate, discharge_rate
    ):
        return 0  # Returns charge action

    else:
        if any(np.abs(np.array([rw_1, rw_2, rw_3])) > rw_threshold):
            return 2  # Returns desaturate action

    return None


def _power_requirement(
    eclipse_start: float,
    eclipse_end: float,
    min_power: float,
    charge_rate: float,
    discharge_rate: float,
) -> float:
    eclipse_duration = (eclipse_end - eclipse_start) % 1
    in_eclipse = eclipse_end < eclipse_start
    if in_eclipse:
        return min_power + eclipse_end * discharge_rate
    else:
        eclipse_draw = eclipse_duration * discharge_rate
        charge_time = eclipse_draw / charge_rate
        if charge_time < eclipse_start:
            return min_power
        else:
            return min_power + (charge_time - eclipse_start) * charge_rate
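A worked example helps to see what the power requirement means. The sketch below restates the logic of _power_requirement as a standalone function (times in orbits, charge as a fraction of capacity) and evaluates it for three hypothetical eclipse geometries using the default rates.

```python
def power_requirement(eclipse_start, eclipse_end, min_power, charge_rate, discharge_rate):
    """Standalone restatement of _power_requirement (illustrative copy)."""
    eclipse_duration = (eclipse_end - eclipse_start) % 1
    if eclipse_end < eclipse_start:
        # Already in eclipse: need min_power plus what is drained until it ends
        return min_power + eclipse_end * discharge_rate
    eclipse_draw = eclipse_duration * discharge_rate  # charge lost during next eclipse
    charge_time = eclipse_draw / charge_rate  # sunlit time needed to recover it
    if charge_time < eclipse_start:
        return min_power  # enough sunlit time remains to charge before eclipse
    # Not enough time left: the shortfall must already be stored now
    return min_power + (charge_time - eclipse_start) * charge_rate


# Eclipse starts in 0.1 orbits and lasts 0.35 orbits
print(power_requirement(0.1, 0.45, 0.25, 1.0, 1.0))  # approximately 0.5
# Eclipse starts in 0.6 orbits: plenty of time to recharge first
print(power_requirement(0.6, 0.95, 0.25, 1.0, 1.0))  # 0.25
# Currently in eclipse, which ends in 0.2 orbits
print(power_requirement(0.7, 0.2, 0.25, 1.0, 1.0))  # approximately 0.45
```

In the first case, a sunlit satellite at 40% charge would be forced to charge, because 0.4 is below the roughly 0.5 required; in the second, the same satellite would be left alone.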

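The handmade shield can be used for both action replacement and action masking. The function generate_shield_functions creates the appropriate function for each case: a replacement function for the "postposed" mode and a masking function for the "action_masking" mode.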

[7]:
ACTION_SPACE_SIZE = 35


def generate_shield_functions(
    shield_type: str, shield_mode: str, indexes: dict[str, list]
) -> Union[
    Callable[[list[float], int], Union[int, None]], Callable[[list[float]], list[int]]
]:
    """
    Generates shield functions based on the specified shield type and mode.
    Args:
        shield_type: Type of the shield ("unshielded" or "handmade").
        shield_mode: Mode of the shield ("postposed" or "action_masking").
        indexes: Dictionary containing indexes for specific observations.
    Returns:
        Callable: A function that either shields actions ("postposed") or returns
        an action mask ("action_masking"), depending on the shield type and mode.
    Raises:
        ValueError: If the shield type or mode is invalid.
    """

    if shield_type not in ["unshielded", "handmade"]:
        raise ValueError(f"Invalid shield type: {shield_type}")
    if shield_mode not in ["postposed", "action_masking"]:
        raise ValueError(
            f"Invalid shield mode: {shield_mode} for shield type: {shield_type}"
        )

    if shield_type == "unshielded":

        if shield_mode == "postposed":

            def shield_function(obs: list[float], act: int) -> int:
                return act  # No shielding, return the action as is

            return shield_function

        elif shield_mode == "action_masking":

            def mask_function(obs: list[float]) -> list[int]:
                return [1] * ACTION_SPACE_SIZE  # All actions are valid

            return mask_function

    elif shield_type == "handmade":

        if shield_mode == "postposed":

            def shield_function(obs: list[float], act: int) -> Union[int, None]:
                return power_shielding_function(obs, act, indexes)

            return shield_function

        elif shield_mode == "action_masking":

            def mask_function(obs: list[float]) -> list[int]:
                shielded_action = power_shielding_function(obs, None, indexes)
                if shielded_action is None:
                    return [1] * ACTION_SPACE_SIZE  # All actions are valid
                # Only the forced action (charge or desaturate) is allowed
                mask_vector = [0] * ACTION_SPACE_SIZE
                mask_vector[shielded_action] = 1
                return mask_vector

            return mask_function

A modified version of RLlib's example ActionMaskingTorchRLModule is used for action masking. The only change is to compute_values, which first preprocesses the batch to extract the observations from the masked observation dictionary before computing the values used for GAE. The module applies a logit-level log-infinity mask

\[\mathbf{l}^\varphi=\mathbf{l}+\log(\mathbf{m})\]

in which the log of the mask vector \(\mathbf{m}\) is added element-wise to the original logits \(\mathbf{l}\), resulting in the masked logits \(\mathbf{l}^\varphi\). When the action probabilities are computed using the softmax

\[\pi^\varphi(a\mid s)=\frac{e^{l_a^\varphi}}{\sum_{a'\in\mathcal{A}}e^{l_{a'}^\varphi}}\]

masked-out actions receive zero probability, since \(\log 0=-\infty\). As implemented, the probability given by \(\pi^\varphi\) is used in the policy update.
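The masked softmax can be reproduced in a few lines of NumPy. This is a sketch of the formula, not RLlib's code; as far as we understand it, RLlib's example module clamps \(\log(\mathbf{m})\) to a large negative float instead of using \(-\infty\), and the FLOAT_MIN constant below imitates that assumption. The logits and mask are toy values.

```python
import numpy as np

# Large negative stand-in for log(0) (assumed to mirror RLlib's clamping)
FLOAT_MIN = np.finfo(np.float32).min


def masked_softmax(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """pi^phi = softmax(l + log(m)), with log(0) clamped to FLOAT_MIN."""
    with np.errstate(divide="ignore"):  # log(0) -> -inf before clamping
        log_mask = np.maximum(np.log(mask), FLOAT_MIN)
    masked_logits = logits + log_mask
    z = masked_logits - masked_logits.max()  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()


logits = np.array([1.0, 2.0, 0.5, -1.0])
mask = np.array([1.0, 0.0, 1.0, 1.0])  # action 1 is unsafe
probs = masked_softmax(logits, mask)
print(probs)  # probs[1] is 0; the rest equal a softmax over the safe logits only
```

Note that the unsafe action had the largest logit, yet receives zero probability, and the remaining probabilities are exactly a softmax over the safe logits.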

[8]:
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
    ActionMaskingTorchRLModule as BaseActionMaskingTorchRLModule,
)
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.annotations import override


class ActionMaskingTorchRLModule(BaseActionMaskingTorchRLModule):

    @override(ValueFunctionAPI)
    def compute_values(self, batch: Dict[str, TensorType]):
        # Preprocess the batch to extract the `observations` to `Columns.OBS`.
        _, batch = self._preprocess_batch(batch)
        # Call the super's method to compute values for GAE.
        return super(BaseActionMaskingTorchRLModule, self).compute_values(batch)

Training setup

First, the episode data callback is defined. It collects episode-level metrics, including the action counts and shield statistics recorded by the WrapperActionLogging wrapper.

[9]:
from bsk_rl.utils.rllib.callbacks import EpisodeDataLogger
from numpy import floating
from ray.rllib.core.rl_module.rl_module import RLModuleSpec


def episode_data_callback(env) -> dict[str, float | floating[Any]]:
    """
    Collects data at the end of each episode for the RLlib environment.
    """
    reward = env.rewarder.cum_reward
    reward = sum(reward.values()) / len(reward)
    orbits = env.simulator.sim_time / T_ORBIT
    imaged = env.satellite.imaged
    missed = env.satellite.missed
    count_charge = env.action_logger["action_charge_count"]
    count_downlink = env.action_logger["action_downlink_count"]
    count_desat = env.action_logger["action_desat_count"]
    count_image = env.action_logger["action_image_count"]
    count_total_actions = env.action_logger["actions_total_count"]
    count_shield_interference = env.shield_info["shield_interference_count"]
    shield_penalty_total = env.shield_info["shield_penalty_total"]
    masking_all_actions_available_count = env.shield_info[
        "masking_all_actions_available_count"
    ]
    mask_active_percent = (
        100 - (masking_all_actions_available_count / count_total_actions * 100)
        if count_total_actions > 0
        else 0
    )

    data = dict(
        reward=reward,
        alive_percentage=float(env.satellite.is_alive()),
        imaged=imaged,
        missed=missed,
        orbits_complete=orbits,
        data_storage_capacity=env.satellite.dynamics.storageUnit.storageCapacity,
        battery_capacity=env.satellite.dynamics.powerMonitor.storageCapacity,
        external_torque=np.linalg.norm(
            env.satellite.dynamics.extForceTorqueObject.extTorquePntB_B
        ),
        valid_battery=float(
            env.satellite.dynamics.battery_valid()
        ),  # True if the battery is valid
        valid_rw=float(
            env.satellite.dynamics.rw_speeds_valid()
        ),  # True if RW speeds are valid
        count_charge_action=count_charge,
        count_downlink_action=count_downlink,
        count_desat_action=count_desat,
        count_image_action=count_image,
        count_total_actions=count_total_actions,
        count_shield_interference=count_shield_interference,
        shield_penalty_total=shield_penalty_total,
        masking_all_actions_available_count=masking_all_actions_available_count,
        mask_active_percent=mask_active_percent,
    )

    if orbits > 0:
        data["reward_per_orbit"] = reward / orbits
        data["imaged_per_orbit"] = imaged / orbits
        data["count_charge_action_per_orbit"] = count_charge / orbits
        data["count_downlink_action_per_orbit"] = count_downlink / orbits
        data["count_desat_action_per_orbit"] = count_desat / orbits
        data["count_image_action_per_orbit"] = count_image / orbits
        data["count_total_actions_per_orbit"] = count_total_actions / orbits
        data["attempts_per_orbit"] = (imaged + missed) / orbits

    if not env.satellite.is_alive():
        data["orbits_complete_partial_only"] = orbits

    if imaged == 0:
        data["avg_tgt_val"] = 0
        data["success_rate"] = 0
    else:
        data["avg_tgt_val"] = reward / imaged
        data["success_rate"] = imaged / (imaged + missed)

    data["attempts"] = imaged + missed

    if count_total_actions > 0:
        data["count_shield_interference_percent"] = (
            count_shield_interference / count_total_actions * 100
        )

    return data
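The two shield metrics in this callback are simple ratios over the number of actions taken. The sketch below restates them with hypothetical episode counts. One caveat follows from the wrappers: masking_all_actions_available_count is incremented for every observation, including the one returned by reset, so it can exceed the action count by one, and mask_active_percent can dip slightly below zero in an episode where the mask never activates.

```python
def shield_activity(masking_all_available: int, interference: int, total_actions: int):
    """Restates the two shield metrics computed in episode_data_callback."""
    if total_actions == 0:
        return 0.0, None  # the callback reports 0 and omits the interference percent
    # Percent of steps where the mask removed at least one action
    mask_active_percent = 100 - masking_all_available / total_actions * 100
    # Percent of steps where action replacement changed the agent's action
    interference_percent = interference / total_actions * 100
    return mask_active_percent, interference_percent


# Hypothetical episode: 200 actions, all actions available at 180 observations,
# and 12 replacements
print(shield_activity(180, 12, 200))
```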

Next, training is configured. The environment case, shield type, and shield mode can be selected, along with the interference penalty used for action replacement.

[10]:
env_case = "no_failure_penalty"  # no_failure_penalty, failure_penalty, inf_power
shield_type = "handmade"  # unshielded or handmade
shield_mode = "action_masking"  # postposed or action_masking
shield_penalty = -0.1

_, _, env_args, indexes = setup_env(
    test=False,
    horizon=3,
    n_targets=(100, 3000),
    target_distribution="cities",
    env_case=env_case,
)


training_args = dict(
    lr=0.00003,
    gamma=0.997,
    train_batch_size=int(128),  # Originally 3000
    num_sgd_iter=10,
    lambda_=0.95,
    use_kl_loss=False,
    clip_param=0.2,
    grad_clip=0.5,
    entropy_coeff=0.0,
)
rl_module_args = dict(
    model_config_dict={
        "use_lstm": False,
        "fcnet_hiddens": [2048] * 2,
        "vf_share_layers": False,
    },
)


# Unshielded runs always use the pass-through replacement function,
# matching the env_creation branch selected below
if shield_mode == "postposed" or shield_type == "unshielded":
    shield_function = generate_shield_functions(shield_type, "postposed", indexes)
elif shield_mode == "action_masking":
    mask_function = generate_shield_functions(shield_type, "action_masking", indexes)

# Note: this reassignment discards the model_config_dict defined above
rl_module_args = {}
if shield_mode == "postposed" or shield_type == "unshielded":

    def env_creation(**env_config) -> WrapperPostPosed:
        env = SatelliteTasking(**env_config)
        env = WrapperPostPosed(
            env, shield_function=shield_function, shield_penalty=shield_penalty
        )

        return env

elif shield_mode == "action_masking":
    rl_module_args["rl_module_spec"] = RLModuleSpec(
        module_class=ActionMaskingTorchRLModule,
    )

    def env_creation(**env_config) -> WrapperActionMasking:
        env = SatelliteTasking(**env_config)
        env = WrapperActionLogging(env)
        env = WrapperActionMasking(env, masking_function=mask_function)

        return env


class Env_wrapped(EpisodeDataLogger, Wrapper):
    def __init__(self, env_config):
        episode_data_callback = env_config.pop("episode_data_callback", None)
        satellite_data_callback = env_config.pop("satellite_data_callback", None)
        env = env_creation(**env_config)
        EpisodeDataLogger.__init__(self, episode_data_callback, satellite_data_callback)
        Wrapper.__init__(self, env)


env_args["episode_data_callback"] = episode_data_callback
2025-08-25 18:19:33,681 utils.orbital                  WARNING    Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case.

The training algorithm is then configured and run for 264 environment steps (originally 20M).

[11]:
import ray
from bsk_rl.utils.rllib.callbacks import WrappedEpisodeDataCallbacks
from bsk_rl.utils.rllib.discounting import TimeDiscountedGAEPPOTorchLearner
from ray.rllib.algorithms.ppo import PPOConfig
from ray import tune

N_CPUS = 3  # Originally 32

ppo_config = (
    PPOConfig()
    .training(
        **training_args,
        learner_class=TimeDiscountedGAEPPOTorchLearner,
    )
    .env_runners(num_env_runners=N_CPUS - 1, sample_timeout_s=1000.0)
    .environment(
        env=Env_wrapped,
        env_config=env_args,
    )
    .reporting(
        metrics_num_episodes_for_smoothing=1,
        metrics_episode_collection_timeout_s=180,
    )
    .checkpointing(export_native_model_files=True)
    .framework(framework="torch")
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .callbacks(WrappedEpisodeDataCallbacks)
)
ppo_config.rl_module(**rl_module_args)

ray.init(
    ignore_reinit_error=True,
    num_cpus=N_CPUS,
    object_store_memory=2_000_000_000,  # 2 GB
)

# Run the training
results = tune.run(
    "PPO",
    config=ppo_config.to_dict(),
    stop={
        "num_env_steps_sampled_lifetime": 264
    },  # Total number of steps to train the model. Originally 20M
    checkpoint_freq=1,
    checkpoint_at_end=True,
)

ray.shutdown()
2025-08-25 18:19:37,029 INFO worker.py:1783 -- Started a local Ray instance.
2025-08-25 18:19:40,614 INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/gymnasium/spaces/box.py:130: UserWarning: WARN: Box bound precision lowered by casting to float32
  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/gymnasium/utils/passive_env_checker.py:164: UserWarning: WARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64
  logger.warn(
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/gymnasium/utils/passive_env_checker.py:188: UserWarning: WARN: The obs returned by the `reset()` method is not within the observation space.
  logger.warn(f"{pre} is not within the observation space.")

Tune Status

Current time: 2025-08-25 18:20:38
Running for: 00:00:58.13
Memory: 4.7/15.6 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 3.0/3 CPUs, 0/0 GPUs

Trial Status

Trial name                   status      loc             iter  total time (s)  num_env_steps_sampled_lifetime  num_episodes_lifetime  num_env_steps_trained_lifetime
PPO_Env_wrapped_11007_00000  TERMINATED  10.1.0.68:6946  3     40.7141         384                             2                      384
(SingleAgentEnvRunner pid=6993) 2025-08-25 18:19:54,506 utils.orbital                  WARNING    Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case.
(PPO pid=6946) Trainable.setup took 10.175 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
(PPO pid=6946) Install gputil for GPU system monitoring.

Trial Progress

Trial name  env_runners  fault_tolerance  learners  num_agent_steps_sampled_lifetime  num_env_steps_sampled_lifetime  num_env_steps_trained_lifetime  num_episodes_lifetime  perf  timers
PPO_Env_wrapped_11007_00000{'num_agent_steps_sampled_lifetime': {'default_agent': 768}, 'reward': nan, 'mask_active_percent': nan, 'num_episodes': 0, 'num_agent_steps_sampled': {'default_agent': 128}, 'data_storage_capacity': nan, 'masking_all_actions_available_count': nan, 'sample': np.float64(11.020856162246542), 'count_shield_interference': nan, 'count_desat_action': nan, 'num_env_steps_sampled_lifetime': 1152, 'battery_capacity': nan, 'count_total_actions': nan, 'count_charge_action': nan, 'avg_tgt_val': nan, 'count_image_action': nan, 'imaged': nan, 'missed': nan, 'attempts': nan, 'num_env_steps_sampled': 128, 'num_module_steps_sampled': {'default_policy': 128}, 'external_torque': nan, 'count_downlink_action': nan, 'orbits_complete': nan, 'valid_battery': nan, 'valid_rw': nan, 'success_rate': nan, 'num_module_steps_sampled_lifetime': {'default_policy': 768}, 'shield_penalty_total': nan, 'alive_percentage': nan, 'module_episode_returns_mean': {'default_policy': 5.326449544249592}, 'episode_return_min': 3.1882948939660705, 'agent_episode_returns_mean': {'default_agent': 5.326449544249592}, 'episode_len_max': 116, 'count_total_actions_per_orbit': nan, 'episode_duration_sec_mean': 17.35324782150002, 'time_between_sampling': np.float64(3.839375548681018), 'attempts_per_orbit': nan, 'count_charge_action_per_orbit': nan, 'episode_len_mean': 108.5, 'count_downlink_action_per_orbit': nan, 'count_image_action_per_orbit': nan, 'episode_return_max': 7.464604194533114, 'count_desat_action_per_orbit': nan, 'count_shield_interference_percent': nan, 'imaged_per_orbit': nan, 'episode_len_min': 101, 'reward_per_orbit': nan, 'episode_return_mean': 5.326449544249592}{'num_healthy_workers': 2, 'num_in_flight_async_reqs': 0, 'num_remote_worker_restarts': 0}{'__all_modules__': {'num_env_steps_trained': 128, 'num_module_steps_trained': 128, 'total_loss': 0.3029525876045227, 'num_trainable_parameters': 294948.0, 'num_non_trainable_parameters': 0.0}, 'default_policy': 
{'num_non_trainable_parameters': 0.0, 'curr_entropy_coeff': 0.0, 'vf_loss_unclipped': 0.3207870423793793, 'vf_loss': 0.3207870423793793, 'num_module_steps_trained': 128, 'total_loss': 0.3029525876045227, 'vf_explained_var': 0.24379831552505493, 'num_trainable_parameters': 294948.0, 'default_optimizer_learning_rate': 3e-05, 'gradients_default_optimizer_global_norm': 1.9472508430480957, 'mean_kl_loss': 0.0, 'policy_loss': -0.017834434285759926, 'entropy': 3.389056444168091}}{'default_agent': 384} 384 384 2{'cpu_util_percent': np.float64(42.395), 'ram_util_percent': np.float64(29.799999999999994)}{'env_runner_sampling_timer': 14.623570170496375, 'learner_update_timer': 0.0770635410664295, 'synch_weights': 0.006580553033943306, 'synch_env_connectors': 0.006046280278793051}
(PPO pid=6946) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/runner/ray_results/PPO_2025-08-25_18-19-40/PPO_Env_wrapped_11007_00000_0_2025-08-25_18-19-40/checkpoint_000000)
(SingleAgentEnvRunner pid=6992) 2025-08-25 18:19:57,753 utils.orbital                  WARNING    Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case. [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(SingleAgentEnvRunner pid=6993) 2025-08-25 18:20:22,478 utils.orbital                  WARNING    Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case. [repeated 2x across cluster]
(PPO pid=6946) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/runner/ray_results/PPO_2025-08-25_18-19-40/PPO_Env_wrapped_11007_00000_0_2025-08-25_18-19-40/checkpoint_000001)
(PPO pid=6946) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/runner/ray_results/PPO_2025-08-25_18-19-40/PPO_Env_wrapped_11007_00000_0_2025-08-25_18-19-40/checkpoint_000002)
2025-08-25 18:20:38,783 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/runner/ray_results/PPO_2025-08-25_18-19-40' in 0.0476s.
2025-08-25 18:20:39,120 INFO tune.py:1041 -- Total run time: 58.51 seconds (58.08 seconds for the tuning loop).
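The trial reports 384 sampled (and trained) environment steps rather than 264. Tune evaluates the `stop` condition after each training iteration, and each iteration in this run sampled 128 environment steps (`num_env_steps_sampled` in the progress output), so training ends after the first iteration that crosses the threshold. A quick check of the arithmetic, assuming a fixed number of steps per iteration:

```python
import math

stop_steps = 264      # "num_env_steps_sampled_lifetime" stop criterion
steps_per_iter = 128  # num_env_steps_sampled per iteration, read from the trial output

iterations = math.ceil(stop_steps / steps_per_iter)
total_steps = iterations * steps_per_iter
print(iterations, total_steps)  # 3 384
```

This is also consistent with the three checkpoints (`checkpoint_000000` to `checkpoint_000002`) written with `checkpoint_freq=1`.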

Testing configuration

The original testing environment was 90 orbits long, so the number of targets still available for imaging decreased significantly over time as targets were acquired. To address this issue, a HiddenTargetsMask was introduced, which controls how many targets are visible to the satellite at any given moment through an opportunity (access) filter. Originally, 15,000 targets were created in each environment, but only targets_max of them were available to the satellite at any given time. Whenever the agent successfully imaged a target, a previously hidden target was made available, keeping the number of visible targets constant as long as hidden targets remained.

[12]:
from itertools import compress

import numpy as np


class HiddenTargetsMask:

    def __init__(self, list_targets: list, targets_max: int):
        """Initialize the mask for hidden targets

        Args:
            list_targets: A list of all targets in the environment.
            targets_max: The maximum number of targets available to the agent in a given time step.
        """
        self.list_targets = list_targets
        self.targets_max = targets_max
        self.n_total_targets = len(self.list_targets)

        self.mask = [True] * self.n_total_targets
        self.n_imaged = 0
        self.hidden_set = None

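    def compute_mask(self, n_imaged: int) -> Union[set, None]:
        """Recompute the hidden targets if the number of imaged targets has changed.

        Returns:
            The updated set of hidden targets, or None if the mask is unchanged.
        """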
        if self.hidden_set is None or self.n_imaged != n_imaged:
            self.n_imaged = n_imaged
            self.mask = self.replace_targets(
                self.mask,
                self.n_imaged,
                self.n_total_targets,
                self.targets_max,
            )
            self.hidden_set = set(compress(self.list_targets, self.mask))
            return self.hidden_set

        else:
            return None

    @staticmethod
    def replace_targets(
        mask: list[bool],
        n_imaged: int,
        n_total_targets: int,
        n_max_targets: int,
    ) -> list[bool]:
        """Reveal hidden targets so that up to n_max_targets unimaged targets are available.

        Args:
            mask: The mask of hidden targets (True means hidden).
            n_imaged: The number of targets that have been imaged.
            n_total_targets: The total number of targets in the environment.
            n_max_targets: The maximum number of targets available to the agent in a given time step.

        Returns:
            mask: The updated mask of hidden targets.
        """
        mask = np.array(mask, dtype=bool)
        n_hidden = np.sum(mask)
        n_available = n_total_targets - n_imaged - n_hidden
        if n_available < n_max_targets:
            mask_idxs = np.where(mask)[0]
            # Reveal randomly chosen hidden targets, but never more than are still hidden
            n_new = min(n_max_targets - n_available, len(mask_idxs))
            readd_idxs = np.random.choice(
                mask_idxs,
                size=n_new,
                replace=False,
            )

            mask[readd_idxs] = False

        return mask.tolist()
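To see the refill rule in isolation, the following standalone sketch re-states the `replace_targets` logic with plain lists and the standard `random` module. It does not use the class above, and the simulated imaging step (always imaging the first visible target) is a hypothetical stand-in for the satellite. With 1,000 targets and `targets_max = 100`, the number of visible, not-yet-imaged targets stays at 100 after every refill:

```python
import random


def refill(mask: list[bool], n_imaged: int, n_max: int, rng: random.Random) -> list[bool]:
    """Reveal hidden targets (True -> False) until n_max unimaged targets are visible."""
    mask = list(mask)
    hidden_idxs = [i for i, hidden in enumerate(mask) if hidden]
    n_available = len(mask) - n_imaged - len(hidden_idxs)
    n_new = min(max(n_max - n_available, 0), len(hidden_idxs))
    for i in rng.sample(hidden_idxs, n_new):
        mask[i] = False
    return mask


rng = random.Random(0)
n_total, n_max = 1000, 100
mask = [True] * n_total  # all targets start hidden
imaged: set[int] = set()

mask = refill(mask, len(imaged), n_max, rng)
for _ in range(5):
    # Hypothetical imaging step: image one visible, not-yet-imaged target
    visible = [i for i, hidden in enumerate(mask) if not hidden and i not in imaged]
    imaged.add(visible[0])
    mask = refill(mask, len(imaged), n_max, rng)

visible = [i for i, hidden in enumerate(mask) if not hidden and i not in imaged]
print(len(visible), len(imaged), sum(mask))  # 100 5 895
```

Each imaged target lowers the available count by one, so exactly one hidden target is revealed per refill once the initial 100 are visible.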

The load_policy function returns a policy function compatible with both training methods. If training was performed with action masking, embedded_masking=True must be specified so that the policy can be used in an unmasked environment: the observation is then passed to the RLModule together with an action mask (all actions allowed unless a masking function is given). A different masking function can also be specified for use during testing.

[13]:
from pathlib import Path
from typing import Callable, Union

import numpy as np
import torch
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.numpy import convert_to_numpy, softmax
from ray.rllib.utils.torch_utils import FLOAT_MIN


def load_policy(policy_path_general: Path) -> Callable:
    """Load a PyTorch policy from a saved model.

    Args:
        policy_path_general: The path to the saved model.
    Returns:
        A function that takes observations and returns actions.
    """

    rl_module = RLModule.from_checkpoint(
        policy_path_general
        / "learner_group"
        / "learner"
        / "rl_module"
        / DEFAULT_MODULE_ID,
    )

    def policy(
        obs: list[float],
        deterministic: bool = True,
        embedded_masking: bool = False,
        masking_function: Union[Callable[[np.ndarray], np.ndarray], None] = None,
    ) -> int:
        """Policy function that takes observations and returns actions.

        Args:
            obs: An observation vector, or a dict with "observations" and "action_mask" keys.
            deterministic: If True, use argmax for action selection; otherwise, sample from the action distribution.
            embedded_masking: If True, the model was trained with action masking, so the observation is passed to the RLModule together with an action mask.
            masking_function: A function that takes observations and returns a mask for valid actions. If None, no masking is applied.
        Returns:
            An integer representing the selected action.
        """
        if isinstance(obs, dict):
            obs_vec = obs["observations"]
        else:
            obs_vec = obs
        obs_vec = np.array(obs_vec, dtype=np.float32)
        if not embedded_masking:
            input_dict = {Columns.OBS: torch.from_numpy(obs_vec).unsqueeze(0)}

        else:
            if isinstance(obs, (list, np.ndarray)):
                # Plain observation vector: build the action mask here.
                # 35 is assumed to match the size of the discrete action space.
                mask = np.ones(35, dtype=np.float32)
                if masking_function is not None:
                    # If possible, masking is applied inside the RLModule
                    mask = np.asarray(masking_function(obs_vec), dtype=np.float32)

                input_dict = {
                    Columns.OBS: {
                        "observations": torch.from_numpy(obs_vec).unsqueeze(0),
                        "action_mask": torch.from_numpy(mask).unsqueeze(0),
                    }
                }
            else:
                # Dict observation that already carries an action mask
                action_mask = np.asarray(obs["action_mask"], dtype=np.float32)
                input_dict = {
                    Columns.OBS: {
                        "observations": torch.from_numpy(obs_vec).unsqueeze(0),
                        "action_mask": torch.from_numpy(action_mask).unsqueeze(0),
                    }
                }

        rl_module_out = rl_module.forward_inference(input_dict)
        logits = convert_to_numpy(rl_module_out[Columns.ACTION_DIST_INPUTS])
        if not embedded_masking and masking_function is not None:
            # Mask outside the RLModule: log(1) = 0 keeps a logit,
            # log(0) = -inf (clamped to FLOAT_MIN) removes it
            mask = torch.as_tensor(
                np.asarray(masking_function(obs_vec), dtype=np.float32)
            )
            inf_mask = torch.clamp(torch.log(mask), min=FLOAT_MIN)
            logits[0] += inf_mask.numpy()
        if deterministic:
            action = np.argmax(logits[0])  # Use argmax for deterministic action
        else:
            action = np.random.choice(len(logits[0]), p=softmax(logits[0]))

        return int(action)

    return policy
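The mask is applied to the logits in log space: allowed actions have mask 1 (log 1 = 0, so the logit is unchanged) and forbidden actions have mask 0 (log 0 = -inf, clamped to a very large negative number). A standalone NumPy sketch of the same idea, using a hypothetical 4-action example and `np.finfo(np.float32).min` as a stand-in for RLlib's `FLOAT_MIN`:

```python
import numpy as np

FLOAT_MIN = np.finfo(np.float32).min  # stand-in for ray.rllib.utils.torch_utils.FLOAT_MIN


def masked_logits(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Add log(mask), clamped to FLOAT_MIN, so masked actions get a huge negative logit."""
    with np.errstate(divide="ignore"):  # log(0) = -inf is expected here
        log_mask = np.log(mask.astype(np.float32))
    return logits + np.maximum(log_mask, FLOAT_MIN)


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D array."""
    z = np.exp(x - x.max())
    return z / z.sum()


logits = np.array([2.0, 0.5, 3.0, -1.0], dtype=np.float32)
mask = np.array([1, 1, 0, 1])  # action 2 is unsafe

ml = masked_logits(logits, mask)
print(int(np.argmax(logits)), int(np.argmax(ml)))  # 2 0
print(float(softmax(ml)[2]))  # 0.0
```

The unmasked policy would pick action 2; after masking, argmax falls back to the best allowed action, and sampling assigns the forbidden action zero probability, while all logits stay finite.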

Testing is then performed by specifying the desired shield type and mode. The hidden targets are checked at every step and recomputed whenever the number of imaged targets changes.

[14]:
# Loading the policy produced by tune.run()
policy_path = pathlib.Path(results.get_last_checkpoint().to_directory())

shield_type = "handmade"
shield_mode = "postposed"
targets_max = 100  # Originally varied between 100 and 3000
total_targets = 1000  # Originally 15,000
embedded_masking = True  # Required if the policy was trained with action masking; set to False if trained with action replacement

env_case = "no_failure_penalty"

env, _, _, indexes = setup_env(
    test=True,
    horizon=1.0,  # Orbits. Originally 90
    n_targets=total_targets,
    target_distribution="cities",
    env_case=env_case,
)

shield_function = generate_shield_functions(shield_type, shield_mode, indexes)
policy = load_policy(policy_path)

if shield_mode == "postposed":
    env = WrapperPostPosed(env, shield_function=shield_function)

elif shield_mode == "action_masking":
    raise NotImplementedError(
        "Action masking is not implemented in testing. Cases were tested with postposed shields"
    )

_, _ = env.reset()
reward_cumulative = 0

hidden_targets_mask = HiddenTargetsMask(
    list_targets=env.satellite.data_store.data.known,
    targets_max=targets_max,
)

env.satellite.hidden_targets = hidden_targets_mask.compute_mask(n_imaged=0)


def replace_targets_filter(opp, sat):
    return opp["object"] not in sat.hidden_targets


env.satellite.add_access_filter(
    lambda opp, sat=env.satellite: replace_targets_filter(opp, sat)
)

while True:

    sat = env.satellite

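    # Reveal new targets only when the number of imaged targets has changed
    hidden_targets = hidden_targets_mask.compute_mask(sat.imaged)
    if hidden_targets is not None:
        sat.hidden_targets = hidden_targets
        # Clear the cached observation so it is rebuilt with the updated target set
        sat.observation_builder.obs_dict_cache = None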

    action = policy(sat.get_obs(), embedded_masking=embedded_masking)

    _, reward, terminated, truncated, _ = env.step(action)

    reward_cumulative += reward

    if terminated or truncated:
        break

print(f"Cumulative reward: {reward_cumulative}")
2025-08-25 18:20:40,572                                WARNING    Creating logger for new env on PID=6649. Old environments in process may now log times incorrectly.
2025-08-25 18:20:40,631 utils.orbital                  WARNING    Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case.
2025-08-25 18:20:41,163 utils.orbital                  WARNING    Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case.
Cumulative reward: 0.0