Shielded training with action masking and action replacement
This example script demonstrates how to train and test a policy with shields. Although DRL can lead to a high-performing policy, it lacks safety guarantees. Adding negative rewards for reaching unsafe states can help the agent learn the safety aspects of the problem, but this still provides no guarantees and is prone to its own issues. Shields can be used to provide these guarantees while minimizing interference. The two main approaches to providing safety during training and testing are action masking and action replacement.
Action replacement: The agent selects an action and the shield checks whether it is safe. If it is deemed unsafe, the action is replaced with a safe action. This correction may or may not be used for the policy update during training (the latter is implemented in this script). Optionally, an interference penalty can be given to the agent for every shield correction.
Action masking: Provides a masking function containing the allowed and disallowed actions at a given state. This method alters the logits (raw outputs) of the actor network such that the agent can only sample safe actions. In this case, the selected safe action is accounted for in the policy update.
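As a minimal, framework-agnostic sketch of where each approach intervenes (not part of the original script; is_safe, safe_fallback, sample, and mask_fn are illustrative placeholders):

# Action replacement: the agent acts first, then the shield corrects unsafe actions.
def step_with_replacement(env, agent, is_safe, safe_fallback):
    obs = env.get_obs()
    action = agent(obs)
    if not is_safe(obs, action):
        action = safe_fallback(obs)  # optionally apply an interference penalty here
    return env.step(action)

# Action masking: the shield restricts the action distribution before sampling.
def step_with_masking(env, agent_logits, mask_fn, sample):
    obs = env.get_obs()
    allowed = mask_fn(obs)  # e.g., [1, 0, 1, ...]
    action = sample(agent_logits(obs), allowed)  # disallowed actions get zero probability
    return env.step(action)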
This example is associated with Performance Evaluation of Shielded Neural Networks for Autonomous Agile Earth Observing Satellites in Long Term Scenarios and a future publication. Failure penalties, action masking, and action replacement with different interference penalties were investigated for training, and their performance was analyzed during testing.
[1]:
import gymnasium as gym
from gymnasium import ActionWrapper, ObservationWrapper, RewardWrapper, Wrapper
from Basilisk.architecture import bskLogging
from bsk_rl import act, data, obs, scene, sats
from bsk_rl.sim import dyn, fsw, world
from bsk_rl.utils.orbital import random_orbit
from Basilisk.utilities import orbitalMotion
from typing import Callable, Union, Dict, Any, Iterable, TypeVar
from pathlib import Path
from bsk_rl.gym import SatelliteTasking
import pathlib
import numpy as np
Satellite = TypeVar("Satellite")
SatObs = TypeVar("SatObs")
SatAct = TypeVar("SatAct")
MultiSatObs = tuple[SatObs, ...]
MultiSatAct = Iterable[SatAct]
SatArgRandomizer = Callable[[list[Satellite]], dict[Satellite, dict[str, Any]]]
bskLogging.setDefaultLogLevel(bskLogging.BSK_WARNING)
Different environment configurations
Three different environment options are available:
no_failure_penalty: No penalty is assigned for failure.
failure_penalty: A penalty of -10 is assigned for failure.
inf_power: Unlimited power and no reaction wheel speed limits.
[2]:
ALTITUDE = 800 # km
T_ORBIT = (
2
* np.pi
* np.sqrt((orbitalMotion.REQ_EARTH + ALTITUDE) ** 3 / orbitalMotion.MU_EARTH)
)
config = dict(
general_sat_args=dict(
oe=lambda: random_orbit(
alt=ALTITUDE, # 800 km altitude
i=45, # 45 degrees inclination
),
imageAttErrorRequirement=0.01, # 0.01 MRP normalized
imageRateErrorRequirement=0.01, # 0.01 rad/s
u_max=0.4, # Maximum control input
K1=0.25, # Control gain
K3=3.0, # Control gain
servo_P=150 / 5, # Servo gain
        omega_max=np.radians(5.0),  # Maximum rate command of 5 deg/s, in rad/s
imageTargetMinimumElevation=np.arctan(
800 / 500
), # 58 degrees minimum elevation
rwBasePower=20,
transmitterPacketSize=0,
),
sat_args=dict(
no_failure_penalty=dict(
batteryStorageCapacity=80.0 * 3600 * 2,
storedCharge_Init=lambda: np.random.uniform(0.4, 1.0) * 80.0 * 3600 * 2,
maxWheelSpeed=1500,
wheelSpeeds=lambda: np.random.uniform(
-750,
750,
3,
),
dataStorageCapacity=2 * 8e6 * 25, # Stores 50 images
storageInit=lambda: np.random.randint(
0,
0.99 * 2 * 8e6 * 25,
),
),
failure_penalty=dict(
batteryStorageCapacity=80.0 * 3600 * 2,
storedCharge_Init=lambda: np.random.uniform(0.4, 1.0) * 80.0 * 3600 * 2,
maxWheelSpeed=1500,
wheelSpeeds=lambda: np.random.uniform(
-750,
750,
3,
),
dataStorageCapacity=2 * 8e6 * 25, # Stores 50 images
storageInit=lambda: np.random.randint(
0,
0.99 * 2 * 8e6 * 25,
),
),
inf_power=dict(
batteryStorageCapacity=80.0 * 3600 * 2 * 1000,
storedCharge_Init=0.99 * 80.0 * 3600 * 2 * 1000,
maxWheelSpeed=15000,
wheelSpeeds=lambda: np.random.uniform(
-1,
1,
3,
),
dataStorageCapacity=2 * 8e6 * 25, # Stores 50 images
storageInit=lambda: np.random.randint(
0,
0.99 * 2 * 8e6 * 25,
),
),
),
sim_params=dict(
horizon=3,
max_step_duration=300.0,
sim_rate=0.5,
failure_penalty=dict(
no_failure_penalty=0.0,
failure_penalty=-10.0,
inf_power=0.0,
),
),
)
Defining the satellite
The observation space is based on Learning Policies for Autonomous Earth-Observing Satellite Scheduling over Semi-MDPs.
[3]:
def s_hat_H(sat: Satellite):
"""
Computes the unit vector from the satellite body frame to the Sun in the Hill frame.
"""
r_SN_N = (
sat.simulator.world.gravFactory.spiceObject.planetStateOutMsgs[
sat.simulator.world.sun_index
]
.read()
.PositionVector
)
r_BN_N = sat.dynamics.r_BN_N
r_SB_N = np.array(r_SN_N) - np.array(r_BN_N)
r_SB_H = sat.dynamics.HN @ r_SB_N
return r_SB_H / np.linalg.norm(r_SB_H)
class Density(obs.Observation):
def __init__(
self,
interval_duration=60 * 3,
intervals=10,
norm=3,
):
self.satellite: "sats.ImagingSatellite"
super().__init__()
self.interval_duration = interval_duration
self.intervals = intervals
self.norm = norm
def get_obs(self):
if self.intervals == 0:
return []
self.satellite.calculate_additional_windows(
self.simulator.sim_time
+ (self.intervals + 1) * self.interval_duration
- self.satellite.window_calculation_time
)
soonest = self.satellite.upcoming_opportunities_dict(types="target")
rewards = np.array([opportunity.priority for opportunity in soonest])
times = np.array([opportunities[0][1] for opportunities in soonest.values()])
time_bins = np.floor((times - self.simulator.sim_time) / self.interval_duration)
densities = [sum(rewards[time_bins == i]) for i in range(self.intervals)]
return np.array(densities) / self.norm
class CustomSatComposed(sats.ImagingSatellite):
observation_spec = [
obs.SatProperties(
dict(prop="omega_BN_B", norm=0.03), # 3
dict(prop="c_hat_H"), # 3
dict(prop="r_BN_P", norm=orbitalMotion.REQ_EARTH * 1e3), # 3
dict(prop="v_BN_P", norm=7616.5), # 3
dict(prop="battery_charge_fraction"), # 1
dict(prop="storage_level_fraction"), # 1
dict(prop="wheel_speeds_fraction"), # 3
dict(prop="s_hat_H", fn=s_hat_H), # 3
),
obs.OpportunityProperties(
dict(prop="opportunity_open", norm=T_ORBIT),
dict(prop="opportunity_close", norm=T_ORBIT),
type="ground_station",
n_ahead_observe=1,
), # 2
obs.Eclipse(norm=T_ORBIT), # 2
Density(intervals=20, norm=5), # 20
obs.OpportunityProperties(
dict(prop="priority"), # 32
dict(prop="r_LB_H", norm=orbitalMotion.REQ_EARTH * 1e3), # 32*3
dict(prop="target_angle", norm=np.pi / 2), # 32
dict(prop="target_angle_rate", norm=0.03), # 32
dict(prop="opportunity_open", norm=300.0), # 32
dict(prop="opportunity_close", norm=300.0), # 32
type="target",
n_ahead_observe=32,
),
]
action_spec = [
act.Charge(duration=60.0), # 1
act.Downlink(), # 1
act.Desat(duration=60.0), # 1
act.Image(n_ahead_image=32), # 32
]
dyn_type = dyn.FullFeaturedDynModel
fsw_type = fsw.SteeringImagerFSWModel
The function setup_env is defined so that the environment and the environment configurations can be created for training and testing.
[4]:
def setup_env(
test: bool = True,
horizon: float = 90,
n_targets: Union[tuple, int] = (100, 3000),
target_distribution: str = "cities",
env_case: str = "no_failure_penalty",
):
"""
Setup the environment for the satellite tasking problem.
Args:
test: If True, sets up the environment for testing.
horizon: The time horizon for the simulation in orbits.
n_targets: The number of targets in the environment.
target_distribution: The distribution of targets, either "uniform" or "cities".
env_case: The environment case to use from the configuration.
Returns:
env: The configured environment (for testing).
satellite: The satellite object configured with the specified parameters.
env_args_dict: Dictionary of environment arguments (for training).
indexes: Dictionary of indexes for specific observations.
"""
if env_case not in config["sat_args"]:
raise ValueError(f"Environment case '{env_case}' not found in configuration.")
if target_distribution == "uniform":
scene_features = scene.UniformTargets(n_targets=n_targets)
elif target_distribution == "cities":
scene_features = scene.CityTargets(n_targets=n_targets)
else:
raise (ValueError("Invalid distribution type"))
sat_args = dict(
**config["general_sat_args"],
**config["sat_args"][env_case],
)
satellite = CustomSatComposed(
"EarthObserving",
sat_args=sat_args,
)
if test:
env = gym.make(
"SatelliteTasking-v1",
satellite=satellite,
world_type=world.GroundStationWorldModel,
scenario=scene_features,
rewarder=data.UniqueImageReward(),
time_limit=np.floor(T_ORBIT * horizon),
log_level="WARNING",
failure_penalty=0.0, # NO FAILURE PENALTY IN TEST
sim_rate=config["sim_params"]["sim_rate"],
max_step_duration=config["sim_params"]["max_step_duration"],
)
env_args_dict = None
else:
env = None
env_args_dict = dict(
satellite=satellite,
world_type=world.GroundStationWorldModel,
scenario=scene_features,
rewarder=data.UniqueImageReward(),
time_limit=np.floor(T_ORBIT * horizon),
log_level="WARNING",
failure_penalty=config["sim_params"]["failure_penalty"][env_case],
sim_rate=config["sim_params"]["sim_rate"],
max_step_duration=config["sim_params"]["max_step_duration"],
)
env = gym.make(
"SatelliteTasking-v1",
satellite=satellite,
world_type=world.GroundStationWorldModel,
scenario=scene_features,
rewarder=data.UniqueImageReward(),
time_limit=np.floor(0.1 * T_ORBIT),
log_level="WARNING",
failure_penalty=config["sim_params"]["failure_penalty"][env_case],
sim_rate=config["sim_params"]["sim_rate"],
max_step_duration=config["sim_params"]["max_step_duration"],
)
# Getting observation indexes - useful for shields
indexes = {
"wheel_speeds": [
env.satellite.observation_builder.obs_array_keys().index(
"sat_props.wheel_speeds_fraction[0]"
),
env.satellite.observation_builder.obs_array_keys().index(
"sat_props.wheel_speeds_fraction[1]"
),
env.satellite.observation_builder.obs_array_keys().index(
"sat_props.wheel_speeds_fraction[2]"
),
],
"stored_charge": [
env.satellite.observation_builder.obs_array_keys().index(
"sat_props.battery_charge_fraction"
)
],
"attitude_rate": [
env.satellite.observation_builder.obs_array_keys().index(
"sat_props.omega_BN_B_normd[0]"
),
env.satellite.observation_builder.obs_array_keys().index(
"sat_props.omega_BN_B_normd[1]"
),
env.satellite.observation_builder.obs_array_keys().index(
"sat_props.omega_BN_B_normd[2]"
),
],
"eclipse": [
env.satellite.observation_builder.obs_array_keys().index("eclipse[0]"),
env.satellite.observation_builder.obs_array_keys().index("eclipse[1]"),
],
}
return env, satellite, env_args_dict, indexes
Action replacement and action masking wrappers
Wrappers are used to extend the capabilities of the base environment:
WrapperActionLogging: provides useful logging metrics that can be easily accessed during training and testing.
WrapperPostPosed: implements the action replacement logic. The wrapper receives a function that monitors the action selected by the agent and modifies it if necessary. An interference penalty can optionally be specified.
WrapperActionMasking: extends the observation returned by the environment to incorporate the mask, which is used with a modified RLModule during training.
[5]:
class WrapperActionLogging(Wrapper):
def __init__(
self,
env: Satellite,
):
super().__init__(env)
self._initialize_action_logger()
def _initialize_action_logger(self):
self.action_logger = {
"action_charge_count": 0,
"action_downlink_count": 0,
"action_desat_count": 0,
"action_image_count": 0,
"actions_total_count": 0,
}
self.shield_info = {
"shield_interference": False,
"original_action": None,
"shielded_action": None,
"shield_interference_count": 0,
"shield_penalty_total": 0.0,
"masking_all_actions_available_count": 0,
}
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
self._initialize_action_logger()
return self.env.reset(seed=seed, options=options)
def step(self, action: int):
if action == 0:
self.action_logger["action_charge_count"] += 1
elif action == 1:
self.action_logger["action_downlink_count"] += 1
elif action == 2:
self.action_logger["action_desat_count"] += 1
elif action >= 3:
self.action_logger["action_image_count"] += 1
self.action_logger["actions_total_count"] += 1
return self.env.step(action)
class WrapperPostPosed(WrapperActionLogging, ActionWrapper, RewardWrapper):
"""
A wrapper that allows for post-posing shields in a gym environment.
"""
def __init__(
self,
env: Satellite,
shield_function: Callable[[list[float], int], int] = None,
shield_penalty: float = 0.0,
):
super().__init__(env)
self.shield_function = shield_function
self.shield_penalty = shield_penalty
def action(self, action: int) -> int:
original_action = action
shielded_action = self.shield_function(
self.env.satellite.get_obs(), original_action
)
if shielded_action is None or shielded_action == original_action:
modified_action = original_action
self.shield_info["shield_interference"] = False
self.shield_info["original_action"] = original_action
self.shield_info["shielded_action"] = original_action
else:
modified_action = shielded_action
self.shield_info["shield_interference"] = True
self.shield_info["original_action"] = original_action
self.shield_info["shielded_action"] = shielded_action
self.shield_info["shield_interference_count"] += 1
return modified_action
def reward(self, reward: float) -> float:
if self.shield_info["shield_interference"]:
reward += self.shield_penalty
self.shield_info["shield_penalty_total"] += self.shield_penalty
return reward
class WrapperActionMasking(ObservationWrapper):
"""
A wrapper that allows for action-masking in a gym environment.
"""
def __init__(
self, env: Satellite, masking_function: Callable[[list[float]], list[int]]
):
super().__init__(env)
self.masking_function = masking_function
self.valid_actions = None
@property
def observation_space(self):
"""Return the single satellite observation space."""
self.unwrapped.observation_space
obs_space = gym.spaces.Dict(
{
"action_mask": gym.spaces.Box(0.0, 1.0, shape=(self.action_space.n,)),
"observations": self.unwrapped.satellite.observation_space,
}
)
return obs_space
def observation(self, observation: list) -> dict:
self.valid_actions = np.array(
self.masking_function(observation), dtype=np.float32
)
n_available_actions = np.sum(self.valid_actions)
if n_available_actions == len(self.valid_actions):
self.shield_info["masking_all_actions_available_count"] += 1
if n_available_actions == 0:
# if no actions are available, all actions are allowed
self.valid_actions = np.ones(len(self.valid_actions), dtype=np.float32)
observation_with_mask = {
"action_mask": self.valid_actions,
"observations": observation,
}
return observation_with_mask
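As a brief usage sketch (not from the original notebook; it mirrors the env_creation logic used during training below and assumes the setup_env helper above and the generate_shield_functions helper introduced in the following sections), the wrappers are composed around the base environment roughly as follows:

# Action replacement: WrapperPostPosed already includes the logging wrapper.
env, _, _, indexes = setup_env(test=True, horizon=1.0, env_case="no_failure_penalty")
shield_fn = generate_shield_functions("handmade", "postposed", indexes)
env_replaced = WrapperPostPosed(env, shield_function=shield_fn, shield_penalty=-0.1)

# Action masking: logging first, then the mask wrapper that augments the observation.
env, _, _, indexes = setup_env(test=True, horizon=1.0, env_case="no_failure_penalty")
mask_fn = generate_shield_functions("handmade", "action_masking", indexes)
env_masked = WrapperActionMasking(WrapperActionLogging(env), masking_function=mask_fn)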
Handmade shield
The handmade shield is an example of a simple shield constructed from heuristics, drawing on A comparative analysis of reinforcement learning algorithms for earth-observing satellite scheduling and Reinforcement Learning for Earth-Observing Satellite Autonomy with Event-Based Task Intervals. Although it provides a good compromise between performance and safety, it is prone to edge cases that are hard to model and predict. Instead, automated shields with stronger safety guarantees can be used, such as those discussed in Shielded Deep Reinforcement Learning for Complex Spacecraft Tasking.
[6]:
def power_shielding_function(
obs: Union[list, np.ndarray],
act: Union[int, None],
indexes: dict[str, list[int]],
min_power: float = 0.25,
charge_rate: float = 1.0,
discharge_rate: float = 1.0,
rw_threshold: float = 0.7,
) -> Union[None, int]:
"""Force charging if not in eclipse and below min_power or the time
Args:
obs: Observation vector
act: unshielded action
indexes: Dictionary with indexes of the observation vector
min_power: Minimum battery percentage. [%]
charge_rate: Rate of charging in charge mode. [%/orbit]
discharge_rate: Rate of discharge in eclipse. [%/orbit]
rw_threshold: Threshold for reaction wheel speeds. [%]
Returns:
Return None if shield not activated, else return shielded action.
"""
current_power = obs[indexes["stored_charge"][0]]
eclipse_start = obs[indexes["eclipse"][0]]
eclipse_end = obs[indexes["eclipse"][1]]
rw_1 = obs[indexes["wheel_speeds"][0]]
rw_2 = obs[indexes["wheel_speeds"][1]]
rw_3 = obs[indexes["wheel_speeds"][2]]
in_eclipse = eclipse_end < eclipse_start
# Check if current state is unsafe
if not in_eclipse and current_power < _power_requirement(
eclipse_start, eclipse_end, min_power, charge_rate, discharge_rate
):
return 0 # Returns charge action
else:
if any(np.abs(np.array([rw_1, rw_2, rw_3])) > rw_threshold):
return 2 # Returns desaturate action
return None
def _power_requirement(
eclipse_start: float,
eclipse_end: float,
min_power: float,
charge_rate: float,
discharge_rate: float,
) -> float:
eclipse_duration = (eclipse_end - eclipse_start) % 1
in_eclipse = eclipse_end < eclipse_start
if in_eclipse:
return min_power + eclipse_end * discharge_rate
else:
eclipse_draw = eclipse_duration * discharge_rate
charge_time = eclipse_draw / charge_rate
if charge_time < eclipse_start:
return min_power
else:
return min_power + (charge_time - eclipse_start) * charge_rate
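A quick sanity check of the heuristic (a toy sketch that is not part of the original script; the observation vector and index mapping are hand-built for illustration only):

# Hypothetical index mapping: battery at 0, eclipse start/end at 1-2, wheel speeds at 3-5.
toy_indexes = {"stored_charge": [0], "eclipse": [1, 2], "wheel_speeds": [3, 4, 5]}

# Low battery outside eclipse -> the shield forces charging (action 0).
assert power_shielding_function([0.10, 0.3, 0.6, 0.1, 0.1, 0.1], None, toy_indexes) == 0

# Healthy battery but one saturated wheel -> the shield forces desaturation (action 2).
assert power_shielding_function([0.90, 0.3, 0.6, 0.1, 0.8, 0.1], None, toy_indexes) == 2

# Nominal state -> the shield does not interfere.
assert power_shielding_function([0.90, 0.3, 0.6, 0.1, 0.1, 0.1], None, toy_indexes) is None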
The handmade shield can be used for both action replacement and action masking. The function generate_shield_functions creates the appropriate function for each case.
[7]:
ACTION_SPACE_SIZE = 35
def generate_shield_functions(
shield_type: str, shield_mode: str, indexes: dict[str, list]
) -> Union[
Callable[[list[float], int], Union[int, None]], Callable[[list[float]], list[int]]
]:
"""
Generates shield functions based on the specified shield type and mode.
Args:
        shield_type: Type of the shield ("unshielded" or "handmade").
shield_mode: Mode of the shield ("postposed" or "action_masking").
indexes: Dictionary containing indexes for specific observations.
Returns:
Callable: A function that either shields actions or masks actions based on the shield type and mode.
Raises:
ValueError: If the shield type or mode is invalid.
"""
if shield_type not in ["unshielded", "handmade"]:
raise ValueError(f"Invalid shield type: {shield_type}")
if shield_mode not in ["postposed", "action_masking"]:
raise ValueError(
f"Invalid shield mode: {shield_mode} for shield type: {shield_type}"
)
if shield_type == "unshielded":
if shield_mode == "postposed":
def shield_function(obs: list[float], act: int) -> int:
return act # No shielding, return the action as is
return shield_function
elif shield_mode == "action_masking":
def mask_function(obs: list[float]) -> list[int]:
return [1] * ACTION_SPACE_SIZE # All actions are valid
return mask_function
elif shield_type == "handmade":
if shield_mode == "postposed":
def shield_function(obs: list[float], act: int) -> Union[int, None]:
return power_shielding_function(obs, act, indexes)
return shield_function
elif shield_mode == "action_masking":
def mask_function(obs: list[float]) -> list[int]:
shielded_action = power_shielding_function(obs, None, indexes)
if shielded_action is None:
return [1] * ACTION_SPACE_SIZE
else:
if shielded_action == 0:
mask_vector = [0] * ACTION_SPACE_SIZE
mask_vector[0] = 1
return mask_vector
elif shielded_action == 2:
mask_vector = [0] * ACTION_SPACE_SIZE
mask_vector[2] = 1
return mask_vector
return mask_function
A modified RLModule from RLlib is used for the action masking module, with the small modification shown below. ActionMaskingTorchRLModule uses a logit-level log-infinity mask,

\[\mathbf{l}^\varphi = \mathbf{l} + \log(\mathbf{m}),\]

such that the original logits \(\mathbf{l}\) are added (element-wise) to the log of the mask vector \(\mathbf{m}\), resulting in the masked logits \(\mathbf{l}^\varphi\). When the action probabilities \(\pi^\varphi(\cdot \mid s) = \mathrm{softmax}(\mathbf{l}^\varphi)\) are computed, the masked-out actions receive zero probability. As implemented, the probability provided by \(\pi^\varphi\) is used in the policy update.
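For intuition, a small standalone NumPy example of the log-infinity mask (values chosen arbitrarily; not part of the original script):

import numpy as np

def softmax(x):
    z = np.exp(x - np.max(x))
    return z / z.sum()

logits = np.array([1.2, 0.4, -0.3, 2.0])  # raw actor outputs l
mask = np.array([1.0, 0.0, 1.0, 0.0])  # m: 1 = allowed, 0 = masked out
masked_logits = logits + np.log(np.clip(mask, 1e-38, 1.0))  # l^phi = l + log(m)

print(softmax(logits))  # all four actions have nonzero probability
print(softmax(masked_logits))  # masked actions (indices 1 and 3) get ~0 probability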
[8]:
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
ActionMaskingTorchRLModule as BaseActionMaskingTorchRLModule,
)
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.annotations import override
class ActionMaskingTorchRLModule(BaseActionMaskingTorchRLModule):
@override(ValueFunctionAPI)
def compute_values(self, batch: Dict[str, TensorType]):
# Preprocess the batch to extract the `observations` to `Columns.OBS`.
_, batch = self._preprocess_batch(batch)
# Call the super's method to compute values for GAE.
return super(BaseActionMaskingTorchRLModule, self).compute_values(batch)
Training setup
First, the data callback is defined, and the metrics provided by the WrapperActionLogging wrapper are stored.
[9]:
from bsk_rl.utils.rllib.callbacks import EpisodeDataLogger
from numpy import floating
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
def episode_data_callback(env) -> dict[str, float | floating[Any]]:
"""
Collects data at the end of each episode for the RLlib environment.
"""
reward = env.rewarder.cum_reward
reward = sum(reward.values()) / len(reward)
orbits = env.simulator.sim_time / T_ORBIT
imaged = env.satellite.imaged
missed = env.satellite.missed
count_charge = env.action_logger["action_charge_count"]
count_downlink = env.action_logger["action_downlink_count"]
count_desat = env.action_logger["action_desat_count"]
count_image = env.action_logger["action_image_count"]
count_total_actions = env.action_logger["actions_total_count"]
count_shield_interference = env.shield_info["shield_interference_count"]
shield_penalty_total = env.shield_info["shield_penalty_total"]
masking_all_actions_available_count = env.shield_info[
"masking_all_actions_available_count"
]
mask_active_percent = (
100 - (masking_all_actions_available_count / count_total_actions * 100)
if count_total_actions > 0
else 0
)
data = dict(
reward=reward,
alive_percentage=float(env.satellite.is_alive()),
imaged=imaged,
missed=missed,
orbits_complete=orbits,
data_storage_capacity=env.satellite.dynamics.storageUnit.storageCapacity,
battery_capacity=env.satellite.dynamics.powerMonitor.storageCapacity,
external_torque=np.linalg.norm(
env.satellite.dynamics.extForceTorqueObject.extTorquePntB_B
),
valid_battery=float(
env.satellite.dynamics.battery_valid()
), # True if the battery is valid
valid_rw=float(
env.satellite.dynamics.rw_speeds_valid()
), # True if RW speeds are valid
count_charge_action=count_charge,
count_downlink_action=count_downlink,
count_desat_action=count_desat,
count_image_action=count_image,
count_total_actions=count_total_actions,
count_shield_interference=count_shield_interference,
shield_penalty_total=shield_penalty_total,
masking_all_actions_available_count=masking_all_actions_available_count,
mask_active_percent=mask_active_percent,
)
if orbits > 0:
data["reward_per_orbit"] = reward / orbits
data["imaged_per_orbit"] = imaged / orbits
data["count_charge_action_per_orbit"] = count_charge / orbits
data["count_downlink_action_per_orbit"] = count_downlink / orbits
data["count_desat_action_per_orbit"] = count_desat / orbits
data["count_image_action_per_orbit"] = count_image / orbits
data["count_total_actions_per_orbit"] = count_total_actions / orbits
data["attempts_per_orbit"] = (imaged + missed) / orbits
if not env.satellite.is_alive():
data["orbits_complete_partial_only"] = orbits
if imaged == 0:
data["avg_tgt_val"] = 0
data["success_rate"] = 0
else:
data["avg_tgt_val"] = reward / imaged
data["success_rate"] = imaged / (imaged + missed)
data["attempts"] = imaged + missed
if count_total_actions > 0:
data["count_shield_interference_percent"] = (
count_shield_interference / count_total_actions * 100
)
return data
Next, training is configured, with options to use the different environment cases, shield types, and shield modes.
[10]:
env_case = "no_failure_penalty" # no_failure_penalty, failure_penalty, inf_power
shield_type = "handmade" # unshielded or handmade
shield_mode = "action_masking" # postposed or action_masking
shield_penalty = -0.1
_, _, env_args, indexes = setup_env(
test=False,
horizon=3,
n_targets=(100, 3000),
target_distribution="cities",
env_case=env_case,
)
training_args = dict(
lr=0.00003,
gamma=0.997,
train_batch_size=int(128), # Originally 3000
num_sgd_iter=10,
lambda_=0.95,
use_kl_loss=False,
clip_param=0.2,
grad_clip=0.5,
entropy_coeff=0.0,
)
rl_module_args = dict(
model_config_dict={
"use_lstm": False,
"fcnet_hiddens": [2048] * 2,
"vf_share_layers": False,
},
)
shield_function = generate_shield_functions(shield_type, shield_mode, indexes)
if shield_mode == "postposed":
shield_function = generate_shield_functions(shield_type, "postposed", indexes)
elif shield_mode == "action_masking":
mask_function = generate_shield_functions(shield_type, "action_masking", indexes)
rl_module_args = {}
if shield_mode == "postposed" or shield_type == "unshielded":
def env_creation(**env_config) -> WrapperPostPosed:
env = SatelliteTasking(**env_config)
env = WrapperPostPosed(
env, shield_function=shield_function, shield_penalty=shield_penalty
)
return env
elif shield_mode == "action_masking":
rl_module_args["rl_module_spec"] = RLModuleSpec(
module_class=ActionMaskingTorchRLModule,
)
def env_creation(**env_config) -> WrapperActionMasking:
env = SatelliteTasking(**env_config)
env = WrapperActionLogging(env)
env = WrapperActionMasking(env, masking_function=mask_function)
return env
class Env_wrapped(EpisodeDataLogger, Wrapper):
def __init__(self, env_config):
episode_data_callback = env_config.pop("episode_data_callback", None)
satellite_data_callback = env_config.pop("satellite_data_callback", None)
env = env_creation(**env_config)
EpisodeDataLogger.__init__(self, episode_data_callback, satellite_data_callback)
Wrapper.__init__(self, env)
env_args["episode_data_callback"] = episode_data_callback
2025-08-25 18:19:33,681 utils.orbital WARNING Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case.
The training algorithm is configured and run for a maximum of 264 environment steps.
[11]:
import ray
from bsk_rl.utils.rllib.callbacks import WrappedEpisodeDataCallbacks
from bsk_rl.utils.rllib.discounting import TimeDiscountedGAEPPOTorchLearner
from ray.rllib.algorithms.ppo import PPOConfig
from ray import tune
N_CPUS = 3 # Originally 32
ppo_config = (
PPOConfig()
.training(
**training_args,
learner_class=TimeDiscountedGAEPPOTorchLearner,
)
.env_runners(num_env_runners=N_CPUS - 1, sample_timeout_s=1000.0)
.environment(
env=Env_wrapped,
env_config=env_args,
)
.reporting(
metrics_num_episodes_for_smoothing=1,
metrics_episode_collection_timeout_s=180,
)
.checkpointing(export_native_model_files=True)
.framework(framework="torch")
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.callbacks(WrappedEpisodeDataCallbacks)
)
ppo_config.rl_module(**rl_module_args)
ray.init(
ignore_reinit_error=True,
num_cpus=N_CPUS,
object_store_memory=2_000_000_000, # 2 GB
)
# Run the training
results = tune.run(
"PPO",
config=ppo_config.to_dict(),
stop={
"num_env_steps_sampled_lifetime": 264
}, # Total number of steps to train the model. Originally 20M
checkpoint_freq=1,
checkpoint_at_end=True,
)
ray.shutdown()
2025-08-25 18:19:37,029 INFO worker.py:1783 -- Started a local Ray instance.
2025-08-25 18:19:40,614 INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/gymnasium/spaces/box.py:130: UserWarning: WARN: Box bound precision lowered by casting to float32
gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/gymnasium/utils/passive_env_checker.py:164: UserWarning: WARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64
logger.warn(
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/gymnasium/utils/passive_env_checker.py:188: UserWarning: WARN: The obs returned by the `reset()` method is not within the observation space.
logger.warn(f"{pre} is not within the observation space.")
Tune Status
Current time: | 2025-08-25 18:20:38 |
Running for: | 00:00:58.13 |
Memory: | 4.7/15.6 GiB |
System Info
Using FIFO scheduling algorithm. Logical resource usage: 3.0/3 CPUs, 0/0 GPUs
Trial Status
Trial name | status | loc | iter | total time (s) | num_env_steps_sample d_lifetime | num_episodes_lifetim e | num_env_steps_traine d_lifetime |
---|---|---|---|---|---|---|---|
PPO_Env_wrapped_11007_00000 | TERMINATED | 10.1.0.68:6946 | 3 | 40.7141 | 384 | 2 | 384 |
(SingleAgentEnvRunner pid=6993) 2025-08-25 18:19:54,506 utils.orbital WARNING Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case.
(PPO pid=6946) Trainable.setup took 10.175 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
(PPO pid=6946) Install gputil for GPU system monitoring.
Trial Progress
Trial name | env_runners | fault_tolerance | learners | num_agent_steps_sampled_lifetime | num_env_steps_sampled_lifetime | num_env_steps_trained_lifetime | num_episodes_lifetime | perf | timers |
---|---|---|---|---|---|---|---|---|---|
PPO_Env_wrapped_11007_00000 | {'num_agent_steps_sampled_lifetime': {'default_agent': 768}, 'reward': nan, 'mask_active_percent': nan, 'num_episodes': 0, 'num_agent_steps_sampled': {'default_agent': 128}, 'data_storage_capacity': nan, 'masking_all_actions_available_count': nan, 'sample': np.float64(11.020856162246542), 'count_shield_interference': nan, 'count_desat_action': nan, 'num_env_steps_sampled_lifetime': 1152, 'battery_capacity': nan, 'count_total_actions': nan, 'count_charge_action': nan, 'avg_tgt_val': nan, 'count_image_action': nan, 'imaged': nan, 'missed': nan, 'attempts': nan, 'num_env_steps_sampled': 128, 'num_module_steps_sampled': {'default_policy': 128}, 'external_torque': nan, 'count_downlink_action': nan, 'orbits_complete': nan, 'valid_battery': nan, 'valid_rw': nan, 'success_rate': nan, 'num_module_steps_sampled_lifetime': {'default_policy': 768}, 'shield_penalty_total': nan, 'alive_percentage': nan, 'module_episode_returns_mean': {'default_policy': 5.326449544249592}, 'episode_return_min': 3.1882948939660705, 'agent_episode_returns_mean': {'default_agent': 5.326449544249592}, 'episode_len_max': 116, 'count_total_actions_per_orbit': nan, 'episode_duration_sec_mean': 17.35324782150002, 'time_between_sampling': np.float64(3.839375548681018), 'attempts_per_orbit': nan, 'count_charge_action_per_orbit': nan, 'episode_len_mean': 108.5, 'count_downlink_action_per_orbit': nan, 'count_image_action_per_orbit': nan, 'episode_return_max': 7.464604194533114, 'count_desat_action_per_orbit': nan, 'count_shield_interference_percent': nan, 'imaged_per_orbit': nan, 'episode_len_min': 101, 'reward_per_orbit': nan, 'episode_return_mean': 5.326449544249592} | {'num_healthy_workers': 2, 'num_in_flight_async_reqs': 0, 'num_remote_worker_restarts': 0} | {'__all_modules__': {'num_env_steps_trained': 128, 'num_module_steps_trained': 128, 'total_loss': 0.3029525876045227, 'num_trainable_parameters': 294948.0, 'num_non_trainable_parameters': 0.0}, 'default_policy': {'num_non_trainable_parameters': 0.0, 'curr_entropy_coeff': 0.0, 'vf_loss_unclipped': 0.3207870423793793, 'vf_loss': 0.3207870423793793, 'num_module_steps_trained': 128, 'total_loss': 0.3029525876045227, 'vf_explained_var': 0.24379831552505493, 'num_trainable_parameters': 294948.0, 'default_optimizer_learning_rate': 3e-05, 'gradients_default_optimizer_global_norm': 1.9472508430480957, 'mean_kl_loss': 0.0, 'policy_loss': -0.017834434285759926, 'entropy': 3.389056444168091}} | {'default_agent': 384} | 384 | 384 | 2 | {'cpu_util_percent': np.float64(42.395), 'ram_util_percent': np.float64(29.799999999999994)} | {'env_runner_sampling_timer': 14.623570170496375, 'learner_update_timer': 0.0770635410664295, 'synch_weights': 0.006580553033943306, 'synch_env_connectors': 0.006046280278793051} |
(PPO pid=6946) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/runner/ray_results/PPO_2025-08-25_18-19-40/PPO_Env_wrapped_11007_00000_0_2025-08-25_18-19-40/checkpoint_000000)
(SingleAgentEnvRunner pid=6992) 2025-08-25 18:19:57,753 utils.orbital WARNING Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case. [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(SingleAgentEnvRunner pid=6993) 2025-08-25 18:20:22,478 utils.orbital WARNING Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case. [repeated 2x across cluster]
(PPO pid=6946) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/runner/ray_results/PPO_2025-08-25_18-19-40/PPO_Env_wrapped_11007_00000_0_2025-08-25_18-19-40/checkpoint_000001)
(PPO pid=6946) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/runner/ray_results/PPO_2025-08-25_18-19-40/PPO_Env_wrapped_11007_00000_0_2025-08-25_18-19-40/checkpoint_000002)
2025-08-25 18:20:38,783 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/runner/ray_results/PPO_2025-08-25_18-19-40' in 0.0476s.
2025-08-25 18:20:39,120 INFO tune.py:1041 -- Total run time: 58.51 seconds (58.08 seconds for the tuning loop).
Testing configuration
The original testing environment was 90 orbits long, leading to a significant decrease in the number of available targets over time. To solve this issue, a HiddenTargetsMask was introduced, which manages the number of targets visible to the satellite at any given moment through an opportunity filter. Originally, 15,000 targets were created in each environment, but only targets_max targets were available to the satellite at any given time. When the agent successfully acquired a target, a previously hidden target was made available.
[12]:
from itertools import compress
class HiddenTargetsMask:
def __init__(self, list_targets: list, targets_max: int):
"""Initialize the mask for hidden targets
Args:
list_targets: A list of all targets in the environment.
targets_max: The maximum number of targets available to the agent in a given time step.
"""
self.list_targets = list_targets
self.targets_max = targets_max
self.n_total_targets = len(self.list_targets)
self.mask = [True] * self.n_total_targets
self.n_imaged = 0
self.hidden_set = None
def compute_mask(self, n_imaged: int):
if self.hidden_set is None or self.n_imaged != n_imaged:
self.n_imaged = n_imaged
self.mask = self.replace_targets(
self.mask,
self.n_imaged,
self.n_total_targets,
self.targets_max,
)
self.hidden_set = set(compress(self.list_targets, self.mask))
return self.hidden_set
else:
return None
@staticmethod
def replace_targets(
mask: list[bool],
n_imaged: int,
n_total_targets: int,
n_max_targets: int,
) -> list[bool]:
"""Add targets to the environment
Args:
mask: The mask of hidden targets.
n_imaged: The number of targets that have been imaged.
n_total_targets: The total number of targets in the environment.
n_max_targets: The maximum number of targets available to the agent in a given time step.
Returns:
mask: The updated mask of hidden targets.
"""
mask = np.array(mask, dtype=bool)
n_hidden = np.sum(mask)
n_available = n_total_targets - n_imaged - n_hidden
if n_available < n_max_targets:
n_new = n_max_targets - n_available
mask_idxs = np.where(mask)[0]
readd_idxs = np.random.choice(
mask_idxs,
size=n_new,
replace=False,
)
mask[readd_idxs] = False
return mask.tolist()
The load_policy function returns a policy function compatible with the different training methods. If training was performed with action masking, it is necessary to specify embedded_masking=True so that an unmasked environment becomes compatible with it. A different masking function can also be specified for use during testing.
[13]:
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core import DEFAULT_MODULE_ID
import torch
from ray.rllib.core.columns import Columns
from ray.rllib.utils.numpy import convert_to_numpy, softmax
from ray.rllib.utils.torch_utils import FLOAT_MIN
def load_policy(policy_path_general: Path) -> Callable:
"""Load a PyTorch policy from a saved model.
Args:
policy_path_general: The path to the saved model.
Returns:
A function that takes observations and returns actions.
"""
rl_module = RLModule.from_checkpoint(
policy_path_general
/ "learner_group"
/ "learner"
/ "rl_module"
/ DEFAULT_MODULE_ID,
)
def policy(
obs: list[float],
deterministic: bool = True,
embedded_masking: bool = False,
masking_function: Union[Callable[list[float], list[float]], None] = None,
) -> int:
"""Policy function that takes observations and returns actions.
Args:
obs: A list of observations.
deterministic: If True, use argmax for action selection; otherwise, sample from the action distribution.
embedded_masking: If True, model was trained with embedded masking and observation needs to be modified.
masking_function: A function that takes observations and returns a mask for valid actions. If None, no masking is applied.
Returns:
An integer representing the selected action.
"""
if isinstance(obs, dict):
obs_vec = obs["observations"]
else:
obs_vec = obs
obs_vec = np.array(obs_vec, dtype=np.float32)
if not embedded_masking:
input_dict = {Columns.OBS: torch.from_numpy(obs_vec).unsqueeze(0)}
else:
if isinstance(obs, list) or isinstance(obs, np.ndarray):
mask = np.ones(35, dtype=np.float32)
if masking_function is not None:
                    # If possible, masking is applied inside the RLModule
                    mask = np.array(masking_function(obs_vec), dtype=np.float32)
input_dict = {
Columns.OBS: {
"observations": torch.from_numpy(obs).unsqueeze(0),
"action_mask": torch.from_numpy(mask).unsqueeze(0),
}
}
else:
input_dict = {
Columns.OBS: {
"observations": torch.from_numpy(obs["observations"]).unsqueeze(
0
),
"action_mask": torch.from_numpy(obs["action_mask"]).unsqueeze(
0
),
}
}
rl_module_out = rl_module.forward_inference(input_dict)
logits = convert_to_numpy(rl_module_out[Columns.ACTION_DIST_INPUTS])
if not embedded_masking and masking_function is not None:
            mask = torch.tensor(masking_function(obs_vec), dtype=torch.float32)
            inf_mask = torch.clamp(torch.log(mask), min=FLOAT_MIN)
logits[0] += inf_mask.numpy()
if deterministic:
action = np.argmax(logits[0]) # Use argmax for deterministic action
else:
action = np.random.choice(len(logits[0]), p=softmax(logits[0]))
return int(action)
return policy
Testing is then performed by specifying the desired shield type and mode. The hidden targets are also re-computed at every step.
[14]:
# Loading the policy produced by tune.run()
policy_path = pathlib.Path(results.get_last_checkpoint().to_directory())
shield_type = "handmade"
shield_mode = "postposed"
targets_max = 100 # Originally varying from (100, 3000)
total_targets = 1000 # Originally 15,000
embedded_masking = True  # Necessary when training with action masking. Set to False if trained with action replacement
env_case = "no_failure_penalty"
env, _, _, indexes = setup_env(
test=True,
horizon=1.0, # Orbits. Originally 90
n_targets=total_targets,
target_distribution="cities",
env_case=env_case,
)
shield_function = generate_shield_functions(shield_type, shield_mode, indexes)
policy = load_policy(policy_path)
if shield_mode == "postposed":
env = WrapperPostPosed(env, shield_function=shield_function)
elif shield_mode == "action_masking":
raise NotImplementedError(
"Action masking is not implemented in testing. Cases were tested with postposed shields"
)
_, _ = env.reset()
reward_cumulative = 0
hidden_targets_mask = HiddenTargetsMask(
list_targets=env.satellite.data_store.data.known,
targets_max=targets_max,
)
env.satellite.hidden_targets = hidden_targets_mask.compute_mask(n_imaged=0)
def replace_targets_filter(opp, sat):
return opp["object"] not in sat.hidden_targets
env.satellite.add_access_filter(
lambda opp, sat=env.satellite: replace_targets_filter(opp, sat)
)
while True:
sat = env.satellite
hidden_targets = hidden_targets_mask.compute_mask(sat.imaged)
if hidden_targets is not None:
sat.hidden_targets = hidden_targets
sat.observation_builder.obs_dict_cache = None
action = policy(sat.get_obs(), embedded_masking=embedded_masking)
_, reward, terminated, truncated, _ = env.step(action)
reward_cumulative += reward
if terminated or truncated:
break
print(f"Cumulative reward: {reward_cumulative}")
2025-08-25 18:20:40,572 WARNING Creating logger for new env on PID=6649. Old environments in process may now log times incorrectly.
2025-08-25 18:20:40,631 utils.orbital WARNING Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case.
2025-08-25 18:20:41,163 utils.orbital WARNING Ignoring a, e, and omega and using alt and r_body to generate a circular orbit.random_circular_orbit is preferred for this use case.
Cumulative reward: 0.0