{ "cells": [ { "cell_type": "markdown", "id": "3f8d5d2d", "metadata": {}, "source": [ "# Shielded training with action masking and action replacement\n", "\n", "This example script demonstrates how to train and test a policy with shields. Although DRL can lead to a high-performing policy, it still lacks safety guarantees. Further, adding negative rewards for reaching unsafe states can help the agent learn the safety aspects of the problems, but still lacks guarantees while being prone to issues. Shields can be used to provide these guarantees while minimizing interference. Two main approaches to provide safety during training and testing are action masking and action replacement.\n", "\n", "**Action replacement**: The agent selects an action and the shield checks if it is a safe action. If it is deemed unsafe, the action is replaced with a safe action. This information might or might not be used for the policy update during training (the latter is implemented in this script). Optionally, an interference penalty can be provided to the agent for every shield correction.\n", "\n", "**Action masking**: Provides a masking function, containing allowed and not-allowed actions at a given state. These method alters the logits (raw outputs) of the actor network, such that the agent can only sample safe actions. In this case, the safe selected action is accounted for in the policy update.\n", "\n", "This paper is associated with [Performance Evaluation of Shielded Neural Networks for Autonomous Agile\n", "Earth Observing Satellites in Long Term Scenarios](https://hanspeterschaub.info/Papers/QuevedoMantovani2025a.pdf) and a future publication. Failure penalties, action masking and action replacement with different interference penalties were investigated for training and their performance analyzed during testing.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7ff54bca", "metadata": {}, "outputs": [], "source": [ "import gymnasium as gym\n", "from gymnasium import ActionWrapper, ObservationWrapper, RewardWrapper, Wrapper\n", "\n", "from Basilisk.architecture import bskLogging\n", "from bsk_rl import act, data, obs, scene, sats\n", "from bsk_rl.sim import dyn, fsw, world\n", "from bsk_rl.utils.orbital import random_orbit\n", "from Basilisk.utilities import orbitalMotion\n", "from typing import Callable, Union, Dict, Any, Iterable, TypeVar\n", "from pathlib import Path\n", "from bsk_rl.gym import SatelliteTasking\n", "import pathlib\n", "import numpy as np\n", "\n", "Satellite = TypeVar(\"Satellite\")\n", "SatObs = TypeVar(\"SatObs\")\n", "SatAct = TypeVar(\"SatAct\")\n", "MultiSatObs = tuple[SatObs, ...]\n", "MultiSatAct = Iterable[SatAct]\n", "SatArgRandomizer = Callable[[list[Satellite]], dict[Satellite, dict[str, Any]]]\n", "\n", "bskLogging.setDefaultLogLevel(bskLogging.BSK_WARNING)" ] }, { "cell_type": "markdown", "id": "37512123", "metadata": {}, "source": [ "## Different environment configurations\n", "\n", "Three different environment options are available: \n", "- ``no_failure_penalty``: No penalty is assigned for failure\n", "- ``failure_penalty``: Penalty of -10 is assigned for failure\n", "- ``inf_power``: Case with unlimited power and no reaction wheel speed limits" ] }, { "cell_type": "code", "execution_count": null, "id": "4a5b9170", "metadata": {}, "outputs": [], "source": [ "ALTITUDE = 800 # km\n", "\n", "T_ORBIT = (\n", " 2\n", " * np.pi\n", " * np.sqrt((orbitalMotion.REQ_EARTH + ALTITUDE) ** 3 / orbitalMotion.MU_EARTH)\n", ")\n", "\n", "config = dict(\n", " 
general_sat_args=dict(\n", " oe=lambda: random_orbit(\n", " alt=ALTITUDE, # 800 km altitude\n", " i=45, # 45 degrees inclination\n", " ),\n", " imageAttErrorRequirement=0.01, # 0.01 MRP normalized\n", " imageRateErrorRequirement=0.01, # 0.01 rad/s\n", " u_max=0.4, # Maximum control input\n", " K1=0.25, # Control gain\n", " K3=3.0, # Control gain\n", " servo_P=150 / 5, # Servo gain\n", " omega_max=np.radians(5.0), # Maximum rate command in degrees per second\n", " imageTargetMinimumElevation=np.arctan(\n", " 800 / 500\n", " ), # 58 degrees minimum elevation\n", " rwBasePower=20,\n", " transmitterPacketSize=0,\n", " ),\n", " sat_args=dict(\n", " no_failure_penalty=dict(\n", " batteryStorageCapacity=80.0 * 3600 * 2,\n", " storedCharge_Init=lambda: np.random.uniform(0.4, 1.0) * 80.0 * 3600 * 2,\n", " maxWheelSpeed=1500,\n", " wheelSpeeds=lambda: np.random.uniform(\n", " -750,\n", " 750,\n", " 3,\n", " ),\n", " dataStorageCapacity=2 * 8e6 * 25, # Stores 50 images\n", " storageInit=lambda: np.random.randint(\n", " 0,\n", " 0.99 * 2 * 8e6 * 25,\n", " ),\n", " ),\n", " failure_penalty=dict(\n", " batteryStorageCapacity=80.0 * 3600 * 2,\n", " storedCharge_Init=lambda: np.random.uniform(0.4, 1.0) * 80.0 * 3600 * 2,\n", " maxWheelSpeed=1500,\n", " wheelSpeeds=lambda: np.random.uniform(\n", " -750,\n", " 750,\n", " 3,\n", " ),\n", " dataStorageCapacity=2 * 8e6 * 25, # Stores 50 images\n", " storageInit=lambda: np.random.randint(\n", " 0,\n", " 0.99 * 2 * 8e6 * 25,\n", " ),\n", " ),\n", " inf_power=dict(\n", " batteryStorageCapacity=80.0 * 3600 * 2 * 1000,\n", " storedCharge_Init=0.99 * 80.0 * 3600 * 2 * 1000,\n", " maxWheelSpeed=15000,\n", " wheelSpeeds=lambda: np.random.uniform(\n", " -1,\n", " 1,\n", " 3,\n", " ),\n", " dataStorageCapacity=2 * 8e6 * 25, # Stores 50 images\n", " storageInit=lambda: np.random.randint(\n", " 0,\n", " 0.99 * 2 * 8e6 * 25,\n", " ),\n", " ),\n", " ),\n", " sim_params=dict(\n", " horizon=3,\n", " max_step_duration=300.0,\n", " sim_rate=0.5,\n", " failure_penalty=dict(\n", " no_failure_penalty=0.0,\n", " failure_penalty=-10.0,\n", " inf_power=0.0,\n", " ),\n", " ),\n", ")" ] }, { "cell_type": "markdown", "id": "624c3285", "metadata": {}, "source": [ "## Defining the satellite\n", "\n", "The observation space is based on [Learning Policies for Autonomous Earth-Observing Satellite Scheduling over Semi-MDPs](https://arc.aiaa.org/doi/10.2514/1.I011649)." 
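, "\n", "\n", "Before specifying the observation and action spaces, it helps to see how the shared and case-specific satellite arguments above are combined. The cell below is an illustration only (``example_sat_args`` is introduced here purely for demonstration); the same dictionary merge is performed inside ``setup_env`` further below." ] }, { "cell_type": "code", "execution_count": null, "id": "a1b2c3d0", "metadata": {}, "outputs": [], "source": [ "# Illustration only: merge the shared arguments with one environment case.\n", "# The same merge is performed inside setup_env() later in this notebook.\n", "example_sat_args = dict(\n", "    **config[\"general_sat_args\"],\n", "    **config[\"sat_args\"][\"failure_penalty\"],\n", ")\n", "print(f\"Number of satellite arguments: {len(example_sat_args)}\")\n", "print(f\"Failure penalty for this case: {config['sim_params']['failure_penalty']['failure_penalty']}\")" ] }, { "cell_type": "markdown", "id": "a1b2c3d1", "metadata": {}, "source": [ "The satellite class below then defines the observation and action spaces used with these arguments."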
] }, { "cell_type": "code", "execution_count": null, "id": "fbb55747", "metadata": {}, "outputs": [], "source": [ "def s_hat_H(sat: Satellite):\n", " \"\"\"\n", " Computes the unit vector from the satellite body frame to the Sun in the Hill frame.\n", " \"\"\"\n", " r_SN_N = (\n", " sat.simulator.world.gravFactory.spiceObject.planetStateOutMsgs[\n", " sat.simulator.world.sun_index\n", " ]\n", " .read()\n", " .PositionVector\n", " )\n", " r_BN_N = sat.dynamics.r_BN_N\n", " r_SB_N = np.array(r_SN_N) - np.array(r_BN_N)\n", " r_SB_H = sat.dynamics.HN @ r_SB_N\n", " return r_SB_H / np.linalg.norm(r_SB_H)\n", "\n", "\n", "class Density(obs.Observation):\n", " def __init__(\n", " self,\n", " interval_duration=60 * 3,\n", " intervals=10,\n", " norm=3,\n", " ):\n", " self.satellite: \"sats.ImagingSatellite\"\n", " super().__init__()\n", " self.interval_duration = interval_duration\n", " self.intervals = intervals\n", " self.norm = norm\n", "\n", " def get_obs(self):\n", " if self.intervals == 0:\n", " return []\n", "\n", " self.satellite.calculate_additional_windows(\n", " self.simulator.sim_time\n", " + (self.intervals + 1) * self.interval_duration\n", " - self.satellite.window_calculation_time\n", " )\n", " soonest = self.satellite.upcoming_opportunities_dict(types=\"target\")\n", " rewards = np.array([opportunity.priority for opportunity in soonest])\n", " times = np.array([opportunities[0][1] for opportunities in soonest.values()])\n", " time_bins = np.floor((times - self.simulator.sim_time) / self.interval_duration)\n", " densities = [sum(rewards[time_bins == i]) for i in range(self.intervals)]\n", " return np.array(densities) / self.norm\n", "\n", "\n", "class CustomSatComposed(sats.ImagingSatellite):\n", " observation_spec = [\n", " obs.SatProperties(\n", " dict(prop=\"omega_BN_B\", norm=0.03), # 3\n", " dict(prop=\"c_hat_H\"), # 3\n", " dict(prop=\"r_BN_P\", norm=orbitalMotion.REQ_EARTH * 1e3), # 3\n", " dict(prop=\"v_BN_P\", norm=7616.5), # 3\n", " dict(prop=\"battery_charge_fraction\"), # 1\n", " dict(prop=\"storage_level_fraction\"), # 1\n", " dict(prop=\"wheel_speeds_fraction\"), # 3\n", " dict(prop=\"s_hat_H\", fn=s_hat_H), # 3\n", " ),\n", " obs.OpportunityProperties(\n", " dict(prop=\"opportunity_open\", norm=T_ORBIT),\n", " dict(prop=\"opportunity_close\", norm=T_ORBIT),\n", " type=\"ground_station\",\n", " n_ahead_observe=1,\n", " ), # 2\n", " obs.Eclipse(norm=T_ORBIT), # 2\n", " Density(intervals=20, norm=5), # 20\n", " obs.OpportunityProperties(\n", " dict(prop=\"priority\"), # 32\n", " dict(prop=\"r_LB_H\", norm=orbitalMotion.REQ_EARTH * 1e3), # 32*3\n", " dict(prop=\"target_angle\", norm=np.pi / 2), # 32\n", " dict(prop=\"target_angle_rate\", norm=0.03), # 32\n", " dict(prop=\"opportunity_open\", norm=300.0), # 32\n", " dict(prop=\"opportunity_close\", norm=300.0), # 32\n", " type=\"target\",\n", " n_ahead_observe=32,\n", " ),\n", " ]\n", "\n", " action_spec = [\n", " act.Charge(duration=60.0), # 1\n", " act.Downlink(), # 1\n", " act.Desat(duration=60.0), # 1\n", " act.Image(n_ahead_image=32), # 32\n", " ]\n", "\n", " dyn_type = dyn.FullFeaturedDynModel\n", " fsw_type = fsw.SteeringImagerFSWModel" ] }, { "cell_type": "markdown", "id": "0c38ce94", "metadata": {}, "source": [ "Function ``setup_env`` is defined such that environment and the environment configurations can created for training and testing." 
] }, { "cell_type": "code", "execution_count": null, "id": "09b5937c", "metadata": {}, "outputs": [], "source": [ "def setup_env(\n", " test: bool = True,\n", " horizon: float = 90,\n", " n_targets: Union[tuple, int] = (100, 3000),\n", " target_distribution: str = \"cities\",\n", " env_case: str = \"no_failure_penalty\",\n", "):\n", " \"\"\"\n", " Setup the environment for the satellite tasking problem.\n", "\n", " Args:\n", " test: If True, sets up the environment for testing.\n", " horizon: The time horizon for the simulation in orbits.\n", " n_targets: The number of targets in the environment.\n", " target_distribution: The distribution of targets, either \"uniform\" or \"cities\".\n", " env_case: The environment case to use from the configuration.\n", "\n", " Returns:\n", " env: The configured environment (for testing).\n", " satellite: The satellite object configured with the specified parameters.\n", " env_args_dict: Dictionary of environment arguments (for training).\n", " indexes: Dictionary of indexes for specific observations.\n", " \"\"\"\n", "\n", " if env_case not in config[\"sat_args\"]:\n", " raise ValueError(f\"Environment case '{env_case}' not found in configuration.\")\n", "\n", " if target_distribution == \"uniform\":\n", " scene_features = scene.UniformTargets(n_targets=n_targets)\n", " elif target_distribution == \"cities\":\n", " scene_features = scene.CityTargets(n_targets=n_targets)\n", " else:\n", " raise (ValueError(\"Invalid distribution type\"))\n", "\n", " sat_args = dict(\n", " **config[\"general_sat_args\"],\n", " **config[\"sat_args\"][env_case],\n", " )\n", "\n", " satellite = CustomSatComposed(\n", " \"EarthObserving\",\n", " sat_args=sat_args,\n", " )\n", "\n", " if test:\n", " env = gym.make(\n", " \"SatelliteTasking-v1\",\n", " satellite=satellite,\n", " world_type=world.GroundStationWorldModel,\n", " scenario=scene_features,\n", " rewarder=data.UniqueImageReward(),\n", " time_limit=np.floor(T_ORBIT * horizon),\n", " log_level=\"WARNING\",\n", " failure_penalty=0.0, # NO FAILURE PENALTY IN TEST\n", " sim_rate=config[\"sim_params\"][\"sim_rate\"],\n", " max_step_duration=config[\"sim_params\"][\"max_step_duration\"],\n", " )\n", " env_args_dict = None\n", "\n", " else:\n", " env = None\n", " env_args_dict = dict(\n", " satellite=satellite,\n", " world_type=world.GroundStationWorldModel,\n", " scenario=scene_features,\n", " rewarder=data.UniqueImageReward(),\n", " time_limit=np.floor(T_ORBIT * horizon),\n", " log_level=\"WARNING\",\n", " failure_penalty=config[\"sim_params\"][\"failure_penalty\"][env_case],\n", " sim_rate=config[\"sim_params\"][\"sim_rate\"],\n", " max_step_duration=config[\"sim_params\"][\"max_step_duration\"],\n", " )\n", "\n", " env = gym.make(\n", " \"SatelliteTasking-v1\",\n", " satellite=satellite,\n", " world_type=world.GroundStationWorldModel,\n", " scenario=scene_features,\n", " rewarder=data.UniqueImageReward(),\n", " time_limit=np.floor(0.1 * T_ORBIT),\n", " log_level=\"WARNING\",\n", " failure_penalty=config[\"sim_params\"][\"failure_penalty\"][env_case],\n", " sim_rate=config[\"sim_params\"][\"sim_rate\"],\n", " max_step_duration=config[\"sim_params\"][\"max_step_duration\"],\n", " )\n", "\n", " # Getting observation indexes - useful for shields\n", " indexes = {\n", " \"wheel_speeds\": [\n", " env.satellite.observation_builder.obs_array_keys().index(\n", " \"sat_props.wheel_speeds_fraction[0]\"\n", " ),\n", " env.satellite.observation_builder.obs_array_keys().index(\n", " \"sat_props.wheel_speeds_fraction[1]\"\n", " ),\n", " 
env.satellite.observation_builder.obs_array_keys().index(\n", " \"sat_props.wheel_speeds_fraction[2]\"\n", " ),\n", " ],\n", " \"stored_charge\": [\n", " env.satellite.observation_builder.obs_array_keys().index(\n", " \"sat_props.battery_charge_fraction\"\n", " )\n", " ],\n", " \"attitude_rate\": [\n", " env.satellite.observation_builder.obs_array_keys().index(\n", " \"sat_props.omega_BN_B_normd[0]\"\n", " ),\n", " env.satellite.observation_builder.obs_array_keys().index(\n", " \"sat_props.omega_BN_B_normd[1]\"\n", " ),\n", " env.satellite.observation_builder.obs_array_keys().index(\n", " \"sat_props.omega_BN_B_normd[2]\"\n", " ),\n", " ],\n", " \"eclipse\": [\n", " env.satellite.observation_builder.obs_array_keys().index(\"eclipse[0]\"),\n", " env.satellite.observation_builder.obs_array_keys().index(\"eclipse[1]\"),\n", " ],\n", " }\n", "\n", " return env, satellite, env_args_dict, indexes" ] }, { "cell_type": "markdown", "id": "4c88f468", "metadata": {}, "source": [ "## Action replacement and action masking wrappers\n", "\n", "Wrappers are used to extend the capabilities of the base environment: \n", "- ``WrapperActionLogging`` provides useful logging metrics that can be easily accessed during training and testing.\n", "- ``WrapperPostPosed`` implements the action replacement logic. The wrapper receives a function that is used to monitor the action selected by the agent and modifies it if necessary. It is possible to specify an interference penalty.\n", "- ``WrapperActionMasking`` extends the observation returned by the environment to incorporate the action mask, which will be used with a modified RLModule during training.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "69402a6d", "metadata": {}, "outputs": [], "source": [ "class WrapperActionLogging(Wrapper):\n", "\n", " def __init__(\n", " self,\n", " env: Satellite,\n", " ):\n", "\n", " super().__init__(env)\n", " self._initialize_action_logger()\n", "\n", " def _initialize_action_logger(self):\n", " self.action_logger = {\n", " \"action_charge_count\": 0,\n", " \"action_downlink_count\": 0,\n", " \"action_desat_count\": 0,\n", " \"action_image_count\": 0,\n", " \"actions_total_count\": 0,\n", " }\n", " self.shield_info = {\n", " \"shield_interference\": False,\n", " \"original_action\": None,\n", " \"shielded_action\": None,\n", " \"shield_interference_count\": 0,\n", " \"shield_penalty_total\": 0.0,\n", " \"masking_all_actions_available_count\": 0,\n", " }\n", "\n", " def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):\n", " self._initialize_action_logger()\n", " return self.env.reset(seed=seed, options=options)\n", "\n", " def step(self, action: int):\n", "\n", " if action == 0:\n", " self.action_logger[\"action_charge_count\"] += 1\n", " elif action == 1:\n", " self.action_logger[\"action_downlink_count\"] += 1\n", " elif action == 2:\n", " self.action_logger[\"action_desat_count\"] += 1\n", " elif action >= 3:\n", " self.action_logger[\"action_image_count\"] += 1\n", " self.action_logger[\"actions_total_count\"] += 1\n", "\n", " return self.env.step(action)\n", "\n", "\n", "class WrapperPostPosed(WrapperActionLogging, ActionWrapper, RewardWrapper):\n", " \"\"\"\n", " A wrapper that allows for post-posing shields in a gym environment.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " env: Satellite,\n", " shield_function: Callable[[list[float], int], int] = None,\n", " shield_penalty: float = 0.0,\n", " ):\n", " super().__init__(env)\n", " self.shield_function = shield_function\n", " 
self.shield_penalty = shield_penalty\n", "\n", " def action(self, action: int) -> int:\n", " original_action = action\n", "\n", " shielded_action = self.shield_function(\n", " self.env.satellite.get_obs(), original_action\n", " )\n", "\n", " if shielded_action is None or shielded_action == original_action:\n", " modified_action = original_action\n", " self.shield_info[\"shield_interference\"] = False\n", " self.shield_info[\"original_action\"] = original_action\n", " self.shield_info[\"shielded_action\"] = original_action\n", "\n", " else:\n", " modified_action = shielded_action\n", " self.shield_info[\"shield_interference\"] = True\n", " self.shield_info[\"original_action\"] = original_action\n", " self.shield_info[\"shielded_action\"] = shielded_action\n", " self.shield_info[\"shield_interference_count\"] += 1\n", "\n", " return modified_action\n", "\n", " def reward(self, reward: float) -> float:\n", "\n", " if self.shield_info[\"shield_interference\"]:\n", " reward += self.shield_penalty\n", " self.shield_info[\"shield_penalty_total\"] += self.shield_penalty\n", "\n", " return reward\n", "\n", "\n", "class WrapperActionMasking(ObservationWrapper):\n", " \"\"\"\n", " A wrapper that allows for action-masking in a gym environment.\n", " \"\"\"\n", "\n", " def __init__(\n", " self, env: Satellite, masking_function: Callable[[list[float]], list[int]]\n", " ):\n", " super().__init__(env)\n", " self.masking_function = masking_function\n", " self.valid_actions = None\n", "\n", " @property\n", " def observation_space(self):\n", " \"\"\"Return the single satellite observation space.\"\"\"\n", " self.unwrapped.observation_space\n", " obs_space = gym.spaces.Dict(\n", " {\n", " \"action_mask\": gym.spaces.Box(0.0, 1.0, shape=(self.action_space.n,)),\n", " \"observations\": self.unwrapped.satellite.observation_space,\n", " }\n", " )\n", " return obs_space\n", "\n", " def observation(self, observation: list) -> dict:\n", " self.valid_actions = np.array(\n", " self.masking_function(observation), dtype=np.float32\n", " )\n", " n_available_actions = np.sum(self.valid_actions)\n", " if n_available_actions == len(self.valid_actions):\n", " self.shield_info[\"masking_all_actions_available_count\"] += 1\n", " if n_available_actions == 0:\n", " # if no actions are available, all actions are allowed\n", " self.valid_actions = np.ones(len(self.valid_actions), dtype=np.float32)\n", "\n", " observation_with_mask = {\n", " \"action_mask\": self.valid_actions,\n", " \"observations\": observation,\n", " }\n", "\n", " return observation_with_mask" ] }, { "cell_type": "markdown", "id": "0cbb2bd0", "metadata": {}, "source": [ "## Handmade shield\n", "\n", "The handmade shield is an example of a simple shield constructed based on heuristics based on [A comparative analysis of reinforcement learning algorithms for earth-observing satellite scheduling](https://www.frontiersin.org/journals/space-technologies/articles/10.3389/frspt.2023.1263489/full) and [Reinforcement Learning for Earth-Observing Satellite Autonomy with Event-Based Task Intervals](https://hanspeterschaub.info/Papers/Stephenson2024a.pdf). Although it provides a good compromise between performance and safety, it is prone to edge cases that are hard to model and predict. Instead, automated shields with stronger safety guarantees can be used, such as those discussed in [Shielded Deep Reinforcement Learning for Complex Spacecraft Tasking](https://hanspeterschaub.info/Papers/Reed2024.pdf)." 
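, "\n", "\n", "Before defining the heuristic, the environment-free sketch below illustrates the observation transformation performed by ``WrapperActionMasking`` above, using a toy masking function and a dummy observation introduced purely for illustration: the mask is attached to the observation and, if the masking function disallowed every action, the wrapper would fall back to allowing all of them." ] }, { "cell_type": "code", "execution_count": null, "id": "c3d4e5f0", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Environment-free sketch of the WrapperActionMasking observation transform\n", "n_actions = 35  # charge + downlink + desat + 32 imaging actions\n", "\n", "\n", "def toy_masking_function(observation):\n", "    # Hypothetical rule for illustration: only allow the charge action (index 0)\n", "    mask = [0] * n_actions\n", "    mask[0] = 1\n", "    return mask\n", "\n", "\n", "dummy_observation = np.zeros(10)\n", "valid_actions = np.array(toy_masking_function(dummy_observation), dtype=np.float32)\n", "if valid_actions.sum() == 0:\n", "    # Fallback used by WrapperActionMasking: if nothing is allowed, allow everything\n", "    valid_actions = np.ones(n_actions, dtype=np.float32)\n", "\n", "observation_with_mask = {\"action_mask\": valid_actions, \"observations\": dummy_observation}\n", "print(observation_with_mask[\"action_mask\"][:5])  # only the charge action is allowed" ] }, { "cell_type": "markdown", "id": "c3d4e5f1", "metadata": {}, "source": [ "The heuristic power and reaction wheel shield itself is defined next."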
] }, { "cell_type": "code", "execution_count": null, "id": "c86aac1d", "metadata": {}, "outputs": [], "source": [ "def power_shielding_function(\n", " obs: Union[list, np.ndarray],\n", " act: Union[int, None],\n", " indexes: dict[str, list[int]],\n", " min_power: float = 0.25,\n", " charge_rate: float = 1.0,\n", " discharge_rate: float = 1.0,\n", " rw_threshold: float = 0.7,\n", ") -> Union[None, int]:\n", " \"\"\"Force charging if not in eclipse and below min_power or the time\n", "\n", " Args:\n", " obs: Observation vector\n", " act: unshielded action\n", " indexes: Dictionary with indexes of the observation vector\n", " min_power: Minimum battery percentage. [%]\n", " charge_rate: Rate of charging in charge mode. [%/orbit]\n", " discharge_rate: Rate of discharge in eclipse. [%/orbit]\n", " rw_threshold: Threshold for reaction wheel speeds. [%]\n", "\n", " Returns:\n", " Return None if shield not activated, else return shielded action.\n", " \"\"\"\n", "\n", " current_power = obs[indexes[\"stored_charge\"][0]]\n", " eclipse_start = obs[indexes[\"eclipse\"][0]]\n", " eclipse_end = obs[indexes[\"eclipse\"][1]]\n", " rw_1 = obs[indexes[\"wheel_speeds\"][0]]\n", " rw_2 = obs[indexes[\"wheel_speeds\"][1]]\n", " rw_3 = obs[indexes[\"wheel_speeds\"][2]]\n", "\n", " in_eclipse = eclipse_end < eclipse_start\n", "\n", " # Check if current state is unsafe\n", " if not in_eclipse and current_power < _power_requirement(\n", " eclipse_start, eclipse_end, min_power, charge_rate, discharge_rate\n", " ):\n", " return 0 # Returns charge action\n", "\n", " else:\n", " if any(np.abs(np.array([rw_1, rw_2, rw_3])) > rw_threshold):\n", " return 2 # Returns desaturate action\n", "\n", " return None\n", "\n", "\n", "def _power_requirement(\n", " eclipse_start: float,\n", " eclipse_end: float,\n", " min_power: float,\n", " charge_rate: float,\n", " discharge_rate: float,\n", ") -> float:\n", " eclipse_duration = (eclipse_end - eclipse_start) % 1\n", " in_eclipse = eclipse_end < eclipse_start\n", " if in_eclipse:\n", " return min_power + eclipse_end * discharge_rate\n", " else:\n", " eclipse_draw = eclipse_duration * discharge_rate\n", " charge_time = eclipse_draw / charge_rate\n", " if charge_time < eclipse_start:\n", " return min_power\n", " else:\n", " return min_power + (charge_time - eclipse_start) * charge_rate" ] }, { "cell_type": "markdown", "id": "1ddb67ce", "metadata": {}, "source": [ "The handmade shield can be used in both action replacement and masking. Function ``generate_shield_functions`` creates the adequate function for each case." 
] }, { "cell_type": "code", "execution_count": null, "id": "c2707de3", "metadata": {}, "outputs": [], "source": [ "ACTION_SPACE_SIZE = 35\n", "\n", "\n", "def generate_shield_functions(\n", " shield_type: str, shield_mode: str, indexes: dict[str, list]\n", ") -> Union[\n", " Callable[[list[float], int], Union[int, None]], Callable[[list[float]], list[int]]\n", "]:\n", " \"\"\"\n", " Generates shield functions based on the specified shield type and mode.\n", " Args:\n", " shield_type: Type of the shield (-1 for handmade, 0 for optimal shielding, 1 for two-step strategy, 2 for value function).\n", " shield_mode: Mode of the shield (\"postposed\" or \"action_masking\").\n", " indexes: Dictionary containing indexes for specific observations.\n", " Returns:\n", " Callable: A function that either shields actions or masks actions based on the shield type and mode.\n", " Raises:\n", " ValueError: If the shield type or mode is invalid.\n", " \"\"\"\n", "\n", " if shield_type not in [\"unshielded\", \"handmade\"]:\n", " raise ValueError(f\"Invalid shield type: {shield_type}\")\n", " if shield_mode not in [\"postposed\", \"action_masking\"]:\n", " raise ValueError(\n", " f\"Invalid shield mode: {shield_mode} for shield type: {shield_type}\"\n", " )\n", "\n", " if shield_type == \"unshielded\":\n", "\n", " if shield_mode == \"postposed\":\n", "\n", " def shield_function(obs: list[float], act: int) -> int:\n", " return act # No shielding, return the action as is\n", "\n", " return shield_function\n", "\n", " elif shield_mode == \"action_masking\":\n", "\n", " def mask_function(obs: list[float]) -> list[int]:\n", " return [1] * ACTION_SPACE_SIZE # All actions are valid\n", "\n", " return mask_function\n", "\n", " elif shield_type == \"handmade\":\n", "\n", " if shield_mode == \"postposed\":\n", "\n", " def shield_function(obs: list[float], act: int) -> Union[int, None]:\n", " return power_shielding_function(obs, act, indexes)\n", "\n", " return shield_function\n", "\n", " elif shield_mode == \"action_masking\":\n", "\n", " def mask_function(obs: list[float]) -> list[int]:\n", " shielded_action = power_shielding_function(obs, None, indexes)\n", " if shielded_action is None:\n", " return [1] * ACTION_SPACE_SIZE\n", " else:\n", " if shielded_action == 0:\n", " mask_vector = [0] * ACTION_SPACE_SIZE\n", " mask_vector[0] = 1\n", " return mask_vector\n", " elif shielded_action == 2:\n", " mask_vector = [0] * ACTION_SPACE_SIZE\n", " mask_vector[2] = 1\n", " return mask_vector\n", "\n", " return mask_function" ] }, { "cell_type": "markdown", "id": "738c76a9", "metadata": {}, "source": [ "An modified RLModule from RLLib is used for the action masking module with a small modification. ``ActionMaskingTorchRLModule`` uses a logit-level log-infinity mask where\n", "\n", "$$\\mathbf{l}^\\varphi=\\mathbf{l}+log(\\mathbf{m})$$\n", "\n", "such that the original logits $\\mathbf{l}$ are added to the log of the mask vector $\\mathbf{m}$ (element wise) resulting in the $\\mathbf{l}^\\varphi$. When the action probabilities are computed using softmax\n", "\n", "$$\\pi^\\varphi(a,s)=\\frac{e^{l_a^\\varphi}}{\\sum_{a'\\in\\mathcal{A}}e^{l_a'^\\varphi}}$$\n", "\n", "actions masked out receive zero probability. As implemented, the probability provided by $\\pi^\\varphi$ is used in the policy update." 
] }, { "cell_type": "code", "execution_count": null, "id": "c008be36", "metadata": {}, "outputs": [], "source": [ "from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (\n", " ActionMaskingTorchRLModule as BaseActionMaskingTorchRLModule,\n", ")\n", "from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI\n", "from ray.rllib.utils.typing import TensorType\n", "from ray.rllib.utils.annotations import override\n", "\n", "\n", "class ActionMaskingTorchRLModule(BaseActionMaskingTorchRLModule):\n", "\n", " @override(ValueFunctionAPI)\n", " def compute_values(self, batch: Dict[str, TensorType]):\n", " # Preprocess the batch to extract the `observations` to `Columns.OBS`.\n", " _, batch = self._preprocess_batch(batch)\n", " # Call the super's method to compute values for GAE.\n", " return super(BaseActionMaskingTorchRLModule, self).compute_values(batch)" ] }, { "cell_type": "markdown", "id": "afb90554", "metadata": {}, "source": [ "## Training setup\n", "\n", "First, the data callback is defined and metrics provided by the ``WrapperActionLogging`` wrapper stored." ] }, { "cell_type": "code", "execution_count": null, "id": "a0e43ae3", "metadata": {}, "outputs": [], "source": [ "from bsk_rl.utils.rllib.callbacks import EpisodeDataLogger\n", "from numpy import floating\n", "from ray.rllib.core.rl_module.rl_module import RLModuleSpec\n", "\n", "\n", "def episode_data_callback(env) -> dict[str, float | floating[Any]]:\n", " \"\"\"\n", " Collects data at the end of each episode for the RLlib environment.\n", " \"\"\"\n", " reward = env.rewarder.cum_reward\n", " reward = sum(reward.values()) / len(reward)\n", " orbits = env.simulator.sim_time / T_ORBIT\n", " imaged = env.satellite.imaged\n", " missed = env.satellite.missed\n", " count_charge = env.action_logger[\"action_charge_count\"]\n", " count_downlink = env.action_logger[\"action_downlink_count\"]\n", " count_desat = env.action_logger[\"action_desat_count\"]\n", " count_image = env.action_logger[\"action_image_count\"]\n", " count_total_actions = env.action_logger[\"actions_total_count\"]\n", " count_shield_interference = env.shield_info[\"shield_interference_count\"]\n", " shield_penalty_total = env.shield_info[\"shield_penalty_total\"]\n", " masking_all_actions_available_count = env.shield_info[\n", " \"masking_all_actions_available_count\"\n", " ]\n", " mask_active_percent = (\n", " 100 - (masking_all_actions_available_count / count_total_actions * 100)\n", " if count_total_actions > 0\n", " else 0\n", " )\n", "\n", " data = dict(\n", " reward=reward,\n", " alive_percentage=float(env.satellite.is_alive()),\n", " imaged=imaged,\n", " missed=missed,\n", " orbits_complete=orbits,\n", " data_storage_capacity=env.satellite.dynamics.storageUnit.storageCapacity,\n", " battery_capacity=env.satellite.dynamics.powerMonitor.storageCapacity,\n", " external_torque=np.linalg.norm(\n", " env.satellite.dynamics.extForceTorqueObject.extTorquePntB_B\n", " ),\n", " valid_battery=float(\n", " env.satellite.dynamics.battery_valid()\n", " ), # True if the battery is valid\n", " valid_rw=float(\n", " env.satellite.dynamics.rw_speeds_valid()\n", " ), # True if RW speeds are valid\n", " count_charge_action=count_charge,\n", " count_downlink_action=count_downlink,\n", " count_desat_action=count_desat,\n", " count_image_action=count_image,\n", " count_total_actions=count_total_actions,\n", " count_shield_interference=count_shield_interference,\n", " shield_penalty_total=shield_penalty_total,\n", " 
masking_all_actions_available_count=masking_all_actions_available_count,\n", " mask_active_percent=mask_active_percent,\n", " )\n", "\n", " if orbits > 0:\n", " data[\"reward_per_orbit\"] = reward / orbits\n", " data[\"imaged_per_orbit\"] = imaged / orbits\n", " data[\"count_charge_action_per_orbit\"] = count_charge / orbits\n", " data[\"count_downlink_action_per_orbit\"] = count_downlink / orbits\n", " data[\"count_desat_action_per_orbit\"] = count_desat / orbits\n", " data[\"count_image_action_per_orbit\"] = count_image / orbits\n", " data[\"count_total_actions_per_orbit\"] = count_total_actions / orbits\n", " data[\"attempts_per_orbit\"] = (imaged + missed) / orbits\n", "\n", " if not env.satellite.is_alive():\n", " data[\"orbits_complete_partial_only\"] = orbits\n", "\n", " if imaged == 0:\n", " data[\"avg_tgt_val\"] = 0\n", " data[\"success_rate\"] = 0\n", " else:\n", " data[\"avg_tgt_val\"] = reward / imaged\n", " data[\"success_rate\"] = imaged / (imaged + missed)\n", "\n", " data[\"attempts\"] = imaged + missed\n", "\n", " if count_total_actions > 0:\n", " data[\"count_shield_interference_percent\"] = (\n", " count_shield_interference / count_total_actions * 100\n", " )\n", "\n", " return data" ] }, { "cell_type": "markdown", "id": "449e064c", "metadata": {}, "source": [ "Next, the training is configured with the options to use different environment cases, shields, and shield modes." ] }, { "cell_type": "code", "execution_count": null, "id": "c412a53f", "metadata": {}, "outputs": [], "source": [ "env_case = \"no_failure_penalty\" # no_failure_penalty, failure_penalty, inf_power\n", "shield_type = \"handmade\" # unshielded or handmade\n", "shield_mode = \"action_masking\" # postposed or action_masking\n", "shield_penalty = -0.1\n", "\n", "_, _, env_args, indexes = setup_env(\n", " test=False,\n", " horizon=3,\n", " n_targets=(100, 3000),\n", " target_distribution=\"cities\",\n", " env_case=env_case,\n", ")\n", "\n", "\n", "training_args = dict(\n", " lr=0.00003,\n", " gamma=0.997,\n", " train_batch_size=int(128), # Originally 3000\n", " num_sgd_iter=10,\n", " lambda_=0.95,\n", " use_kl_loss=False,\n", " clip_param=0.2,\n", " grad_clip=0.5,\n", " entropy_coeff=0.0,\n", ")\n", "rl_module_args = dict(\n", " model_config_dict={\n", " \"use_lstm\": False,\n", " \"fcnet_hiddens\": [2048] * 2,\n", " \"vf_share_layers\": False,\n", " },\n", ")\n", "\n", "\n", "shield_function = generate_shield_functions(shield_type, shield_mode, indexes)\n", "\n", "if shield_mode == \"postposed\":\n", " shield_function = generate_shield_functions(shield_type, \"postposed\", indexes)\n", "elif shield_mode == \"action_masking\":\n", " mask_function = generate_shield_functions(shield_type, \"action_masking\", indexes)\n", "\n", "rl_module_args = {}\n", "if shield_mode == \"postposed\" or shield_type == \"unshielded\":\n", "\n", " def env_creation(**env_config) -> WrapperPostPosed:\n", " env = SatelliteTasking(**env_config)\n", " env = WrapperPostPosed(\n", " env, shield_function=shield_function, shield_penalty=shield_penalty\n", " )\n", "\n", " return env\n", "\n", "elif shield_mode == \"action_masking\":\n", " rl_module_args[\"rl_module_spec\"] = RLModuleSpec(\n", " module_class=ActionMaskingTorchRLModule,\n", " )\n", "\n", " def env_creation(**env_config) -> WrapperActionMasking:\n", " env = SatelliteTasking(**env_config)\n", " env = WrapperActionLogging(env)\n", " env = WrapperActionMasking(env, masking_function=mask_function)\n", "\n", " return env\n", "\n", "\n", "class Env_wrapped(EpisodeDataLogger, 
Wrapper):\n", " def __init__(self, env_config):\n", " episode_data_callback = env_config.pop(\"episode_data_callback\", None)\n", " satellite_data_callback = env_config.pop(\"satellite_data_callback\", None)\n", " env = env_creation(**env_config)\n", " EpisodeDataLogger.__init__(self, episode_data_callback, satellite_data_callback)\n", " Wrapper.__init__(self, env)\n", "\n", "\n", "env_args[\"episode_data_callback\"] = episode_data_callback" ] }, { "cell_type": "markdown", "id": "ed9f2019", "metadata": {}, "source": [ "The training algorithm is configured and run for a maximum of 264 environment steps." ] }, { "cell_type": "code", "execution_count": null, "id": "6d8363a3", "metadata": {}, "outputs": [], "source": [ "import ray\n", "from bsk_rl.utils.rllib.callbacks import WrappedEpisodeDataCallbacks\n", "from bsk_rl.utils.rllib.discounting import TimeDiscountedGAEPPOTorchLearner\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "from ray import tune\n", "\n", "N_CPUS = 3 # Originally 32\n", "\n", "ppo_config = (\n", " PPOConfig()\n", " .training(\n", " **training_args,\n", " learner_class=TimeDiscountedGAEPPOTorchLearner,\n", " )\n", " .env_runners(num_env_runners=N_CPUS - 1, sample_timeout_s=1000.0)\n", " .environment(\n", " env=Env_wrapped,\n", " env_config=env_args,\n", " )\n", " .reporting(\n", " metrics_num_episodes_for_smoothing=1,\n", " metrics_episode_collection_timeout_s=180,\n", " )\n", " .checkpointing(export_native_model_files=True)\n", " .framework(framework=\"torch\")\n", " .api_stack(\n", " enable_rl_module_and_learner=True,\n", " enable_env_runner_and_connector_v2=True,\n", " )\n", " .callbacks(WrappedEpisodeDataCallbacks)\n", ")\n", "ppo_config.rl_module(**rl_module_args)\n", "\n", "ray.init(\n", " ignore_reinit_error=True,\n", " num_cpus=N_CPUS,\n", " object_store_memory=2_000_000_000, # 2 GB\n", ")\n", "\n", "# Run the training\n", "results = tune.run(\n", " \"PPO\",\n", " config=ppo_config.to_dict(),\n", " stop={\n", " \"num_env_steps_sampled_lifetime\": 264\n", " }, # Total number of steps to train the model. Originally 20M\n", " checkpoint_freq=1,\n", " checkpoint_at_end=True,\n", ")\n", "\n", "ray.shutdown()" ] }, { "cell_type": "markdown", "id": "a7e13876", "metadata": {}, "source": [ "## Testing configuration\n", "\n", "The original testing environment was 90 orbits long, leading to a significant decrease in the number of available targets over time. To solve this issue, a ``HiddenTargetsMask`` was introduced, which manages the number of targets visible to the satellite at any given moment through an opportunity filter. Originally, 15,000 targets were created in each environment, but only ``targets_max`` targets were available to the satellite at any given time. When the agent successfully acquired a target, a previously hidden target would be made available."
] }, { "cell_type": "code", "execution_count": null, "id": "4acb43fa", "metadata": {}, "outputs": [], "source": [ "from itertools import compress\n", "\n", "\n", "class HiddenTargetsMask:\n", "\n", " def __init__(self, list_targets: list, targets_max: int):\n", " \"\"\"Initialize the mask for hidden targets\n", "\n", " Args:\n", " list_targets: A list of all targets in the environment.\n", " targets_max: The maximum number of targets available to the agent in a given time step.\n", " verbose: If True, prints information about the targets.\n", " \"\"\"\n", " self.list_targets = list_targets\n", " self.targets_max = targets_max\n", " self.n_total_targets = len(self.list_targets)\n", "\n", " self.mask = [True] * self.n_total_targets\n", " self.n_imaged = 0\n", " self.hidden_set = None\n", "\n", " def compute_mask(self, n_imaged: int):\n", "\n", " if self.hidden_set is None or self.n_imaged != n_imaged:\n", " self.n_imaged = n_imaged\n", " self.mask = self.replace_targets(\n", " self.mask,\n", " self.n_imaged,\n", " self.n_total_targets,\n", " self.targets_max,\n", " )\n", " self.hidden_set = set(compress(self.list_targets, self.mask))\n", " return self.hidden_set\n", "\n", " else:\n", " return None\n", "\n", " @staticmethod\n", " def replace_targets(\n", " mask: list[bool],\n", " n_imaged: int,\n", " n_total_targets: int,\n", " n_max_targets: int,\n", " ) -> list[bool]:\n", " \"\"\"Add targets to the environment\n", "\n", " Args:\n", " mask: The mask of hidden targets.\n", " n_imaged: The number of targets that have been imaged.\n", " n_total_targets: The total number of targets in the environment.\n", " n_max_targets: The maximum number of targets available to the agent in a given time step.\n", " verbose: If True, prints information about the targets.\n", "\n", " Returns:\n", " mask: The updated mask of hidden targets.\n", " \"\"\"\n", " mask = np.array(mask, dtype=bool)\n", " n_hidden = np.sum(mask)\n", " n_available = n_total_targets - n_imaged - n_hidden\n", " if n_available < n_max_targets:\n", " n_new = n_max_targets - n_available\n", " mask_idxs = np.where(mask)[0]\n", " readd_idxs = np.random.choice(\n", " mask_idxs,\n", " size=n_new,\n", " replace=False,\n", " )\n", "\n", " mask[readd_idxs] = False\n", "\n", " return mask.tolist()" ] }, { "cell_type": "markdown", "id": "5c2d5bd1", "metadata": {}, "source": [ "``load_policy`` function returns a policy function compatible with the different training methods. If training was performed with action masking, it is necessary to specify ``embedded_masking=True`` so that an unmasked environment becomes compatible with it. A different masking function can also be specified to be used during testing." 
] }, { "cell_type": "code", "execution_count": null, "id": "9f77a235", "metadata": {}, "outputs": [], "source": [ "from ray.rllib.core.rl_module.rl_module import RLModule\n", "from ray.rllib.core import DEFAULT_MODULE_ID\n", "import torch\n", "from ray.rllib.core.columns import Columns\n", "from ray.rllib.utils.numpy import convert_to_numpy, softmax\n", "from ray.rllib.utils.torch_utils import FLOAT_MIN\n", "\n", "\n", "def load_policy(policy_path_general: Path) -> Callable:\n", " \"\"\"Load a PyTorch policy from a saved model.\n", "\n", " Args:\n", " policy_path_general: The path to the saved model.\n", " Returns:\n", " A function that takes observations and returns actions.\n", " \"\"\"\n", "\n", " rl_module = RLModule.from_checkpoint(\n", " policy_path_general\n", " / \"learner_group\"\n", " / \"learner\"\n", " / \"rl_module\"\n", " / DEFAULT_MODULE_ID,\n", " )\n", "\n", " def policy(\n", " obs: list[float],\n", " deterministic: bool = True,\n", " embedded_masking: bool = False,\n", " masking_function: Union[Callable[list[float], list[float]], None] = None,\n", " ) -> int:\n", " \"\"\"Policy function that takes observations and returns actions.\n", "\n", " Args:\n", " obs: A list of observations.\n", " deterministic: If True, use argmax for action selection; otherwise, sample from the action distribution.\n", " embedded_masking: If True, model was trained with embedded masking and observation needs to be modified.\n", " masking_function: A function that takes observations and returns a mask for valid actions. If None, no masking is applied.\n", " Returns:\n", " An integer representing the selected action.\n", " \"\"\"\n", " if isinstance(obs, dict):\n", " obs_vec = obs[\"observations\"]\n", " else:\n", " obs_vec = obs\n", " obs_vec = np.array(obs_vec, dtype=np.float32)\n", " if not embedded_masking:\n", " input_dict = {Columns.OBS: torch.from_numpy(obs_vec).unsqueeze(0)}\n", "\n", " else:\n", " if isinstance(obs, list) or isinstance(obs, np.ndarray):\n", " mask = np.ones(35, dtype=np.float32)\n", " if masking_function is not None:\n", " # If possible, masking is applies inside the RLModule\n", " mask = masking_function(obs_vec)\n", "\n", " input_dict = {\n", " Columns.OBS: {\n", " \"observations\": torch.from_numpy(obs).unsqueeze(0),\n", " \"action_mask\": torch.from_numpy(mask).unsqueeze(0),\n", " }\n", " }\n", " else:\n", " input_dict = {\n", " Columns.OBS: {\n", " \"observations\": torch.from_numpy(obs[\"observations\"]).unsqueeze(\n", " 0\n", " ),\n", " \"action_mask\": torch.from_numpy(obs[\"action_mask\"]).unsqueeze(\n", " 0\n", " ),\n", " }\n", " }\n", "\n", " rl_module_out = rl_module.forward_inference(input_dict)\n", " logits = convert_to_numpy(rl_module_out[Columns.ACTION_DIST_INPUTS])\n", " if not embedded_masking and masking_function is not None:\n", " mask = masking_function(obs)\n", " inf_mask = torch.clamp(torch.log(mask), min=FLOAT_MIN)\n", " logits[0] += inf_mask.numpy()\n", " if deterministic:\n", " action = np.argmax(logits[0]) # Use argmax for deterministic action\n", " else:\n", " action = np.random.choice(len(logits[0]), p=softmax(logits[0]))\n", "\n", " return int(action)\n", "\n", " return policy" ] }, { "cell_type": "markdown", "id": "cb2367ae", "metadata": {}, "source": [ "Testing is then performed specifying the desired shield method and type. The hidden targets are also re-computed at every step." 
] }, { "cell_type": "code", "execution_count": null, "id": "d1ee4d25", "metadata": {}, "outputs": [], "source": [ "# Loading the policy produced by tune.run()\n", "policy_path = pathlib.Path(results.get_last_checkpoint().to_directory())\n", "\n", "shield_type = \"handmade\"\n", "shield_mode = \"postposed\"\n", "targets_max = 100 # Originally varying from (100, 3000)\n", "total_targets = 1000 # Originally 15,000\n", "embedded_masking = True # Necessary when training with action masking. Set to False is trained with action replacement\n", "\n", "env_case = \"no_failure_penalty\"\n", "\n", "env, _, _, indexes = setup_env(\n", " test=True,\n", " horizon=1.0, # Orbits. Originally 90\n", " n_targets=total_targets,\n", " target_distribution=\"cities\",\n", " env_case=env_case,\n", ")\n", "\n", "shield_function = generate_shield_functions(shield_type, shield_mode, indexes)\n", "policy = load_policy(policy_path)\n", "\n", "if shield_mode == \"postposed\":\n", " env = WrapperPostPosed(env, shield_function=shield_function)\n", "\n", "elif shield_mode == \"action_masking\":\n", " raise NotImplementedError(\n", " \"Action masking is not implemented in testing. Cases were tested with postposed shields\"\n", " )\n", "\n", "_, _ = env.reset()\n", "reward_cumulative = 0\n", "\n", "hidden_targets_mask = HiddenTargetsMask(\n", " list_targets=env.satellite.data_store.data.known,\n", " targets_max=targets_max,\n", ")\n", "\n", "env.satellite.hidden_targets = hidden_targets_mask.compute_mask(n_imaged=0)\n", "\n", "\n", "def replace_targets_filter(opp, sat):\n", " return opp[\"object\"] not in sat.hidden_targets\n", "\n", "\n", "env.satellite.add_access_filter(\n", " lambda opp, sat=env.satellite: replace_targets_filter(opp, sat)\n", ")\n", "\n", "while True:\n", "\n", " sat = env.satellite\n", "\n", " hidden_targets = hidden_targets_mask.compute_mask(sat.imaged)\n", " if hidden_targets is not None:\n", " sat.hidden_targets = hidden_targets\n", " sat.observation_builder.obs_dict_cache = None\n", "\n", " action = policy(sat.get_obs(), embedded_masking=embedded_masking)\n", "\n", " _, reward, terminated, truncated, _ = env.step(action)\n", "\n", " reward_cumulative += reward\n", "\n", " if terminated or truncated:\n", " break\n", "\n", "print(f\"Cumulative reward: {reward_cumulative}\")" ] } ], "metadata": { "kernelspec": { "display_name": ".venv_shields_example", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }