{ "cells": [ { "cell_type": "markdown", "id": "ce415bcb", "metadata": {}, "source": [ "# Training with Curriculum Learning\n", "\n", "This example demonstrates how to use curriculum learning (CL) and domain randomization (DR) during training with RLlib.\n", "\n", "Three example curricula are shown, along with the DR case. This example is part of the paper [Improving Robustness of Autonomous Spacecraft Scheduling Using Curriculum Learning](https://hanspeterschaub.info/Papers/QuevedoMantovani2025.pdf) and of a future publication. In CL, a sequence of tasks of increasing difficulty is presented to the agent during training, and each task is treated as a different Markov decision process (MDP). Here, each task is characterized by a satellite with a different battery capacity that is exposed to different external torques, which leads to different transition probabilities in the MDP.\n", "\n", "## Load Modules" ] }, { "cell_type": "code", "execution_count": null, "id": "161c6a53", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from bsk_rl import act, data, obs, scene, sats\n", "from bsk_rl.sim import dyn, fsw, world\n", "from bsk_rl.gym import SatelliteTasking\n", "from typing import Any, Callable, Optional, TypeVar\n", "from bsk_rl.utils.rllib.callbacks import WrappedEpisodeDataCallbacks, EpisodeDataWrapper\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "import time\n", "import ray\n", "from ray import tune\n", "from bsk_rl.sats import Satellite\n", "from ray.tune.registry import register_env\n", "from Basilisk.architecture import bskLogging\n", "\n", "bskLogging.setDefaultLogLevel(bskLogging.BSK_WARNING)\n", "\n", "SatObs = TypeVar(\"SatObs\")\n", "MultiSatObs = tuple[SatObs, ...]\n", "SatArgRandomizer = Callable[[list[Satellite]], dict[Satellite, dict[str, Any]]]" ] }, { "cell_type": "markdown", "id": "40303b15", "metadata": {}, "source": [ "## Creating an Environment with CL\n", "\n", "In 
this example, the [SatelliteTasking](../api_reference/index.rst) environment is modified to allow changes to the spacecraft parameters during training. Two extra methods, `set_task` and `get_task`, are introduced to set and get the difficulty of the environment. Additionally, `update_sat_params` changes specific spacecraft arguments as a function of the difficulty and is called before each environment reset." ] }, { "cell_type": "code", "execution_count": null, "id": "7623a908", "metadata": {}, "outputs": [], "source": [ "class SatelliteTaskingCL(SatelliteTasking):\n", "\n", " def __init__(\n", " self,\n", " satellite: Satellite,\n", " *args,\n", " difficulty=0.0,\n", " CL_params=None, # Maps argument names to functions of the difficulty\n", " **kwargs,\n", " ):\n", "\n", " super().__init__(\n", " satellite,\n", " *args,\n", " **kwargs,\n", " )\n", "\n", " self.difficulty = difficulty\n", " self.CL_params = CL_params\n", "\n", " def reset(\n", " self,\n", " seed: Optional[int] = None,\n", " options=None,\n", " ) -> tuple[MultiSatObs, dict[str, Any]]:\n", "\n", " self.update_sat_params() # Update satellite parameters based on difficulty before resetting\n", " obs, info = super().reset(seed=seed, options=options)\n", " return obs, info\n", "\n", " def update_sat_params(self):\n", " \"\"\"\n", " Update the satellite parameters based on the difficulty level.\n", "\n", " Keys of CL_params found in the satellite arguments are overwritten; any other key\n", " is set as a rounded attribute of the environment.\n", " \"\"\"\n", " if self.CL_params is not None:\n", " for satellite in self.satellites:\n", " for key, value in self.CL_params.items():\n", " if key in satellite.sat_args_generator:\n", " satellite.sat_args_generator[key] = value(self.difficulty)\n", " else:\n", " setattr(self, key, round(value(self.get_task())))\n", "\n", " def set_task(self, task):\n", " \"\"\"\n", " Set the difficulty level.\n", " \"\"\"\n", " self.difficulty = task\n", "\n", " def get_task(self):\n", " \"\"\"\n", " Get the current difficulty level.\n", " \"\"\"\n", " return self.difficulty" ] }, { "cell_type": "markdown", "id": "20f54e0d", "metadata": {}, "source": [ "## 
Registering the Custom Environment\n", "\n", "Since a custom environment is used, it must be registered with RLlib. The creator function below also wraps it in `EpisodeDataWrapper` so that per-episode data can be collected by the callbacks." ] }, { "cell_type": "code", "execution_count": null, "id": "6050760f", "metadata": {}, "outputs": [], "source": [ "def _satellite_tasking_env_creator(env_config):\n", " \"\"\"\n", " Create an environment compatible with RLlib.\n", " \"\"\"\n", "\n", " # Copy the config so that removing the callbacks does not modify the original\n", " env_config = dict(env_config)\n", " episode_data_callback = env_config.pop(\"episode_data_callback\", None)\n", " satellite_data_callback = env_config.pop(\"satellite_data_callback\", None)\n", "\n", " return EpisodeDataWrapper(\n", " SatelliteTaskingCL(**env_config),\n", " episode_data_callback=episode_data_callback,\n", " satellite_data_callback=satellite_data_callback,\n", " )\n", "\n", "\n", "register_env(\"SatelliteTaskingCL-RLlib\", _satellite_tasking_env_creator)" ] }, { "cell_type": "markdown", "id": "9b175a59", "metadata": {}, "source": [ "## Creating the Scanning Satellite\n", "\n", "A nadir scanning satellite is created with custom observation properties, including the angle between the solar panels and the sun and the angle between the instrument and nadir. A custom dynamics model combines [`GroundStationDynModel`](../api_reference/sim/dyn.rst) and [`ContinuousImagingDynModel`](../api_reference/sim/dyn.rst), enabling both scanning and downlink actions."
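, "\n",
"\n",
"Both observation functions in the next cell map an angle in [0, π] to [0, 1] by taking the arccosine of a dot product between unit vectors and dividing by π. A minimal standalone sketch of that normalization, using plain NumPy vectors in place of simulator state, is shown here; the helper `normalized_angle` is hypothetical and not part of `bsk_rl`, and it clips the dot product to [-1, 1] so that round-off cannot produce NaN:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def normalized_angle(u, v):\n",
"    # Hypothetical helper: angle between u and v, mapped from [0, pi] to [0, 1]\n",
"    u_hat = np.asarray(u, dtype=float) / np.linalg.norm(u)\n",
"    v_hat = np.asarray(v, dtype=float) / np.linalg.norm(v)\n",
"    # Clipping guards against round-off pushing the dot product outside [-1, 1]\n",
"    cos_angle = np.clip(np.dot(u_hat, v_hat), -1.0, 1.0)\n",
"    return float(np.arccos(cos_angle) / np.pi)\n",
"\n",
"\n",
"# Aligned, orthogonal, and opposite vectors map to 0, 0.5, and 1\n",
"print(normalized_angle([0, 0, 1], [0, 0, 1]))\n",
"print(normalized_angle([1, 0, 0], [0, 1, 0]))\n",
"print(normalized_angle([0, 0, 1], [0, 0, -1]))\n",
"```"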
] }, { "cell_type": "code", "execution_count": null, "id": "36272859", "metadata": {}, "outputs": [], "source": [ "def attitude_error_norm(sat) -> float:\n", " # Angle between the instrument unit vector (c_hat_P) and the nadir direction\n", " # (-r_BN_P), both expressed in the planet-fixed frame\n", " r_BN_P_unit = sat.dynamics.r_BN_P / np.linalg.norm(sat.dynamics.r_BN_P)\n", " c_hat_P = sat.dynamics.satellite.fsw.c_hat_P # Instrument unit vector in ECEF frame\n", " # Clip to [-1, 1] so that round-off cannot make arccos return NaN\n", " error_angle = np.arccos(np.clip(np.dot(-r_BN_P_unit, c_hat_P), -1.0, 1.0))\n", "\n", " return error_angle / np.pi\n", "\n", "\n", "def solar_angle_norm(sat) -> float:\n", " a = (\n", " sat.dynamics.world.gravFactory.spiceObject.planetStateOutMsgs[\n", " sat.dynamics.world.sun_index\n", " ]\n", " .read()\n", " .PositionVector\n", " )\n", " a_hat = a / np.linalg.norm(a)\n", " b = np.array([0, 0, -1]) # Solar panel opposite to instrument\n", " mat = np.transpose(sat.dynamics.BN) # Body-to-inertial rotation\n", " b_N = np.matmul(mat, b)\n", " error_angle = np.arccos(np.clip(np.dot(b_N, a_hat), -1.0, 1.0))\n", "\n", " return error_angle / np.pi\n", "\n", "\n", "class CustomDynamics(dyn.GroundStationDynModel, dyn.ContinuousImagingDynModel):\n", " pass\n", "\n", "\n", "class ScanningSatellite(sats.AccessSatellite):\n", " observation_spec = [\n", " obs.SatProperties(\n", " dict(prop=\"wheel_speeds_fraction\"),\n", " dict(prop=\"battery_charge_fraction\"),\n", " dict(prop=\"storage_level_fraction\"),\n", " dict(prop=\"attitude_error_norm\", fn=attitude_error_norm),\n", " dict(prop=\"solar_angle_norm\", fn=solar_angle_norm),\n", " ),\n", " obs.Eclipse(norm=5700.0), # Normalize times by ~one orbital period (5700 s = 95 min)\n", " obs.OpportunityProperties(\n", " dict(prop=\"opportunity_open\", norm=5700.0),\n", " dict(prop=\"opportunity_close\", norm=5700.0),\n", " type=\"ground_station\",\n", " n_ahead_observe=1,\n", " ),\n", " ]\n", " action_spec = [\n", " act.Scan(duration=180.0), # Scan for 3 minutes\n", " act.Charge(duration=180.0), # Charge for 3 minutes\n", " act.Downlink(duration=180.0), # Downlink for 3 minutes\n", 
act.Desat(duration=180.0), # Desaturate for 3 minutes\n", " ]\n", " dyn_type = CustomDynamics\n", " fsw_type = fsw.ContinuousImagingFSWModel" ] }, { "cell_type": "markdown", "id": "5c14fdb1", "metadata": {}, "source": [ "## Defining Curriculum Functions\n", "\n", "The following functions define how the satellite parameters vary with the difficulty during training, which is assumed to lie between 0 and 1. In CL mode, each parameter is interpolated linearly from its initial level (difficulty 0) to its final level (difficulty 1), so direct, inverse, and constant curricula follow from the choice of initial and final levels. In DR mode, the value is instead sampled uniformly between the two levels." ] }, { "cell_type": "code", "execution_count": null, "id": "2824840f", "metadata": {}, "outputs": [], "source": [ "def capacity_fn(time_seed, init_val, final_val, difficulty):\n", " \"\"\"\n", " Calculate the capacity of a given satellite property (e.g., battery or storage) based on the difficulty level.\n", "\n", " Args:\n", " time_seed (float, optional): Seed for random number generation. If None, CL will be used. Otherwise, DR will be used.\n", " init_val (float): Initial value of the capacity.\n", " final_val (float): Final value of the capacity.\n", " difficulty (float): Difficulty level.\n", "\n", " Returns:\n", " float: Capacity of the satellite property.\n", " \"\"\"\n", "\n", " if time_seed is not None:\n", " random_generator = np.random.default_rng(\n", " seed=int(time_seed * 100) * int(difficulty * 10000)\n", " )\n", " return random_generator.uniform(init_val, final_val)\n", " else:\n", " return init_val - (init_val - final_val) * difficulty\n", "\n", "\n", "def capacity_init_fn(time_seed, init_val, final_val, difficulty, max_init, min_init):\n", " \"\"\"\n", " Calculate the initial level of a given satellite property (e.g., battery or storage) based on the difficulty level.\n", " This function is needed because the capacity itself changes with the difficulty level.\n", "\n", " Args:\n", " time_seed (float, optional): Seed for random number generation. 
If None, CL will be used. Otherwise, DR will be used.\n", " init_val (float): Initial value of the capacity.\n", " final_val (float): Final value of the capacity.\n", " difficulty (float): Difficulty level.\n", " max_init (float): Maximum initial level, as a fraction of the capacity.\n", " min_init (float): Minimum initial level, as a fraction of the capacity.\n", "\n", " Returns:\n", " float: Initial level of the given satellite property.\n", " \"\"\"\n", "\n", " if time_seed is not None:\n", " random_generator = np.random.default_rng(\n", " seed=int(time_seed * 100) * int(difficulty * 10000)\n", " )\n", " capacity = random_generator.uniform(init_val, final_val)\n", " return np.random.uniform(min_init, max_init) * capacity\n", " else:\n", " capacity = init_val - (init_val - final_val) * difficulty\n", " return np.random.uniform(min_init, max_init) * capacity\n", "\n", "\n", "def random_disturbance_vector(magnitude_disturbance, seed=None):\n", " \"\"\"\n", " Generate a disturbance vector with the given magnitude and a uniformly random direction.\n", "\n", " Args:\n", " magnitude_disturbance (float): Magnitude of the disturbance vector.\n", " seed (int, optional): Seed for random number generation. Defaults to None.\n", "\n", " Returns:\n", " np.ndarray: Random disturbance vector with the given magnitude.\n", " \"\"\"\n", "\n", " # Use a dedicated generator when a seed is given; otherwise use the global NumPy RNG\n", " rng = np.random.default_rng(seed) if seed is not None else np.random\n", " disturbance_rand_vector = rng.normal(size=3)\n", " disturbance_rand_unit_vector = disturbance_rand_vector / np.linalg.norm(\n", " disturbance_rand_vector\n", " )\n", " disturbance_vector = disturbance_rand_unit_vector * magnitude_disturbance\n", " return disturbance_vector\n", "\n", "\n", "def external_disturbance_fn(time_seed, init_val, final_val, difficulty):\n", " \"\"\"\n", " Calculate the external disturbance vector based on the difficulty level.\n", "\n", " Args:\n", " time_seed (float, optional): Seed for random number generation. If None, CL will be used. 
Otherwise, DR will be used.\n", " init_val (float): Initial magnitude of the disturbance vector.\n", " final_val (float): Final magnitude of the disturbance vector.\n", " difficulty (float): Difficulty level.\n", "\n", " Returns:\n", " np.ndarray: External disturbance vector.\n", " \"\"\"\n", "\n", " if time_seed is not None:\n", " random_generator = np.random.default_rng(\n", " seed=int(time_seed * 100) * int(difficulty * 10000)\n", " )\n", " disturbance_mag = random_generator.uniform(init_val, final_val)\n", " return random_disturbance_vector(disturbance_mag)\n", " else:\n", " disturbance_mag = init_val - (init_val - final_val) * difficulty\n", " return random_disturbance_vector(disturbance_mag)" ] }, { "cell_type": "markdown", "id": "888157e6", "metadata": {}, "source": [ "## Custom Callback to Enable CL\n", "\n", "A custom callback class is required to enable CL. At the start of each episode, `CLCallbacks` reads the lifetime number of sampled environment steps from the metrics logger, converts it into the task (difficulty), and sets it on the environment. The mapping used here is linear, with 5M steps corresponding to a difficulty of 1.0; more complex curricula, such as ones driven by spring-mass dynamics, could be implemented by replacing this function.\n", "\n", "A custom `episode_data_callback` is also defined to collect information about the agent and the curriculum during training."
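, "\n",
"\n",
"As a minimal sketch, the linear mapping in `CLCallbacks` can be written as a plain function and compared with a smoother alternative. The names `linear_difficulty` and `smoothstep_difficulty`, the smoothstep shape, and the clipping to [0, 1] are illustrative assumptions, not part of the original callback. Note that the difficulty also seeds the DR sampling in the curriculum functions defined earlier, so clipping it would freeze the DR samples once training exceeds `total_steps`:\n",
"\n",
"```python\n",
"TOTAL_STEPS = 5_000_000  # steps that map to a difficulty of 1.0, as in CLCallbacks\n",
"\n",
"\n",
"def linear_difficulty(n_steps, total_steps=TOTAL_STEPS):\n",
"    # Linear mapping from sampled steps to difficulty, clipped to [0, 1]\n",
"    return min(max(n_steps / total_steps, 0.0), 1.0)\n",
"\n",
"\n",
"def smoothstep_difficulty(n_steps, total_steps=TOTAL_STEPS):\n",
"    # Hypothetical alternative: slow start, faster middle, slow finish\n",
"    x = linear_difficulty(n_steps, total_steps)\n",
"    return x * x * (3.0 - 2.0 * x)\n",
"\n",
"\n",
"for n in (0, 1_250_000, 2_500_000, 5_000_000, 6_000_000):\n",
"    print(n, linear_difficulty(n), smoothstep_difficulty(n))\n",
"```"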
] }, { "cell_type": "code", "execution_count": null, "id": "3a9a2ba3", "metadata": {}, "outputs": [], "source": [ "class CLCallbacks(WrappedEpisodeDataCallbacks):\n", "\n", " def on_episode_start(\n", " self,\n", " *,\n", " episode,\n", " worker=None,\n", " env_runner=None,\n", " metrics_logger=None,\n", " base_env=None,\n", " env=None,\n", " policies=None,\n", " rl_module=None,\n", " env_index,\n", " **kwargs,\n", " ) -> None:\n", "\n", " try:\n", " n_steps = metrics_logger.peek(\"num_env_steps_sampled_lifetime\")\n", " if n_steps is None:\n", " task = 0.0\n", " else:\n", " task = n_steps / 5_000_000 # 5M steps = 1.0 difficulty\n", " except KeyError:\n", " task = 0.0\n", "\n", " env.envs[env_index].unwrapped.set_task(task)\n", "\n", "\n", "def episode_data_callback(env):\n", " reward = env.rewarder.cum_reward\n", " reward = sum(reward.values()) / len(reward)\n", " orbits = env.simulator.sim_time / (95 * 60)\n", "\n", " data_log = dict(\n", " reward=reward,\n", " # Are satellites dying, and how and when?\n", " alive=float(env.satellites[0].is_alive()),\n", " rw_status_valid=float(env.satellites[0].dynamics.rw_speeds_valid()),\n", " battery_status_valid=float(env.satellites[0].dynamics.battery_valid()),\n", " orbits_complete=orbits,\n", " # Is CL working? 
How is it varying during training?\n", " difficulty=env.get_task(),\n", " battery_capacity=env.satellites[0].dynamics.powerMonitor.storageCapacity,\n", " external_torque=np.linalg.norm(\n", " env.satellites[0].dynamics.extForceTorqueObject.extTorquePntB_B\n", " ),\n", " )\n", " if orbits > 0:\n", " data_log[\"reward_per_orbit\"] = reward / orbits\n", " if not env.satellites[0].is_alive():\n", " data_log[\"orbits_complete_partial_only\"] = orbits\n", "\n", " return data_log" ] }, { "cell_type": "markdown", "id": "acbf9874", "metadata": {}, "source": [ "## Defining Satellite, Environment, and CL Options\n", "\n", "Two environment configurations, `standard_90` and `degraded_90`, are defined and can be used for training and testing. Compared with `standard_90`, `degraded_90` halves the battery capacity, triples the disturbance torque magnitude, and reduces the solar panel efficiency to 75% of its nominal value. Different initialization ranges can also be chosen for the parameters at reset: in `nominal`, parameters are initialized near their nominal operating values, while in `wide` they can span their full range (e.g., 0% to 100% of the battery capacity).\n", "\n", "Several CL and DR cases are also defined to choose from. Each case can vary several spacecraft parameters, each with its own initial and final levels."
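, "\n",
"\n",
"In each CL case, `init_val` and `final_val` are multipliers applied to the nominal value from the selected environment configuration (see the curriculum assignment cell further down). As a short worked example, the sketch here evaluates the `direct_BT_high` case with `standard_90` using the same linear interpolation as the CL branch of `capacity_fn`; `curriculum_level` is a hypothetical standalone copy of that formula, and the torque unit (N·m) is an assumption since it is not stated in the configuration:\n",
"\n",
"```python\n",
"def curriculum_level(init_val, final_val, difficulty):\n",
"    # Same interpolation as the CL branch of capacity_fn\n",
"    return init_val - (init_val - final_val) * difficulty\n",
"\n",
"\n",
"battery_nominal = 400 * 3600  # Ws, batteryStorageCapacity in standard_90\n",
"torque_nominal = 0.0002  # disturbance_vector_mag in standard_90 (assumed N*m)\n",
"\n",
"# direct_BT_high: battery multiplier 1.0 -> 0.4, torque multiplier 1.0 -> 8.0\n",
"for difficulty in (0.0, 0.5, 1.0):\n",
"    battery = curriculum_level(battery_nominal * 1.0, battery_nominal * 0.4, difficulty)\n",
"    torque = curriculum_level(torque_nominal * 1.0, torque_nominal * 8.0, difficulty)\n",
"    print(f'difficulty={difficulty:.1f}: battery={battery / 3600:.0f} Wh, torque={torque:.4f}')\n",
"```\n",
"\n",
"The battery capacity therefore shrinks from 400 Wh to 160 Wh (280 Wh halfway), while the disturbance torque magnitude grows from 0.0002 to 0.0016."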
] }, { "cell_type": "code", "execution_count": null, "id": "c818db37", "metadata": {}, "outputs": [], "source": [ "sat_config = dict(\n", " standard_90=dict(\n", " # Nominal env parameters\n", " intervals=90,\n", " batteryStorageCapacity=400 * 3600, # in Ws\n", " disturbance_vector_mag=0.0002,\n", " panelEfficiency=0.2,\n", " ),\n", " degraded_90=dict(\n", " # Degraded env parameters\n", " intervals=90,\n", " batteryStorageCapacity=400 * 3600 * 0.5, # in Ws\n", " disturbance_vector_mag=0.0002 * 3,\n", " panelEfficiency=0.2 * 0.75,\n", " ),\n", " # Other sat parameters common to all\n", " sat_params=dict(\n", " imageAttErrorRequirement=0.1, # MRP norm, ~23 degrees of attitude error\n", " imageRateErrorRequirement=0.1, # norm of angular velocity (rad/s)\n", " dataStorageCapacity=5000 * 8e6, # in bits\n", " instrumentPowerDraw=-30.0, # in Watts\n", " instrumentBaudRate=0.5e6, # bits per second\n", " transmitterPowerDraw=-25.0, # in Watts\n", " transmitterBaudRate=-112.0e6, # bits per second, sized to downlink in a single downlink opportunity\n", " rwMechToElecEfficiency=0.0,\n", " rwElecToMechEfficiency=0.5,\n", " thrusterPowerDraw=-80.0,\n", " rwBasePower=10.0,\n", " maxWheelSpeed=6000, # RPM\n", " desatAttitude=\"nadir\",\n", " K=3.5, # Proportional gain on the MRP attitude error\n", " Ki=-1, # Integral gain (turned off)\n", " P=17.5, # Derivative gain on the angular rate error\n", " ),\n", ")\n", "\n", "init_range_options = dict(\n", " nominal=dict(\n", " battery_init_range=[0.375, 0.625],\n", " data_storage_init_range=[0, 1],\n", " reaction_wheel_init_range=[-4000, 4000], # RPM\n", " ),\n", " wide=dict(\n", " battery_init_range=[0, 1],\n", " data_storage_init_range=[0, 1],\n", " reaction_wheel_init_range=[-6000, 6000], # RPM\n", " ),\n", ")\n", "\n", "CL_options = dict(\n", " constant_BT_high=dict(\n", " battery={\n", " \"name\": \"batteryStorageCapacity\",\n", " \"init_val\": 0.40,\n", " \"final_val\": 0.40,\n", " \"init_range_config\": \"battery_init_range\",\n", " \"name_init\": 
\"storedCharge_Init\",\n", " \"domain_randomization\": False,\n", " },\n", " torque={\n", " \"name\": \"disturbance_vector_mag\",\n", " \"var_name\": \"disturbance_vector\",\n", " \"init_val\": 8.0,\n", " \"final_val\": 8.0,\n", " \"domain_randomization\": False,\n", " },\n", " ),\n", " direct_BT_high=dict(\n", " battery={\n", " \"name\": \"batteryStorageCapacity\",\n", " \"init_val\": 1.00,\n", " \"final_val\": 0.40,\n", " \"init_range_config\": \"battery_init_range\",\n", " \"name_init\": \"storedCharge_Init\",\n", " \"domain_randomization\": False,\n", " },\n", " torque={\n", " \"name\": \"disturbance_vector_mag\",\n", " \"var_name\": \"disturbance_vector\",\n", " \"init_val\": 1.0,\n", " \"final_val\": 8.0,\n", " \"domain_randomization\": False,\n", " },\n", " ),\n", " inverse_BT_high=dict(\n", " battery={\n", " \"name\": \"batteryStorageCapacity\",\n", " \"init_val\": 0.40,\n", " \"final_val\": 1.0,\n", " \"init_range_config\": \"battery_init_range\",\n", " \"name_init\": \"storedCharge_Init\",\n", " \"domain_randomization\": False,\n", " },\n", " torque={\n", " \"name\": \"disturbance_vector_mag\",\n", " \"var_name\": \"disturbance_vector\",\n", " \"init_val\": 8.0,\n", " \"final_val\": 1.0,\n", " \"domain_randomization\": False,\n", " },\n", " ),\n", " DR_BT_high=dict(\n", " battery={\n", " \"name\": \"batteryStorageCapacity\",\n", " \"init_val\": 0.40,\n", " \"final_val\": 1.00,\n", " \"init_range_config\": \"battery_init_range\",\n", " \"name_init\": \"storedCharge_Init\",\n", " \"domain_randomization\": True,\n", " },\n", " torque={\n", " \"name\": \"disturbance_vector_mag\",\n", " \"var_name\": \"disturbance_vector\",\n", " \"init_val\": 1.0,\n", " \"final_val\": 8.0,\n", " \"domain_randomization\": True,\n", " },\n", " ),\n", ")" ] }, { "cell_type": "markdown", "id": "171d449a", "metadata": {}, "source": [ "## Choosing Curriculum for Training\n", "\n", "Here, the `direct_BT_high` is selected with a `nominal` initialization range and `standard` 
environment configuration (`standard_90`), in which each episode lasts at most 90 steps of 180 s (4.5 hours of simulated time)." ] }, { "cell_type": "code", "execution_count": null, "id": "bccbd44f", "metadata": {}, "outputs": [], "source": [ "CL_params = {} # Filled with functions of the difficulty in the curriculum assignment cell\n", "CL_enabled = True\n", "CL_case = \"direct_BT_high\"\n", "initialization_range = \"nominal\"\n", "environment_mode = \"standard_90\"\n", "\n", "sat = ScanningSatellite(\n", " \"Scanner-1\",\n", " sat_args=dict(\n", " **sat_config[\"sat_params\"],\n", " batteryStorageCapacity=sat_config[environment_mode][\"batteryStorageCapacity\"],\n", " disturbance_vector=lambda: random_disturbance_vector(\n", " sat_config[environment_mode][\"disturbance_vector_mag\"]\n", " ),\n", " panelEfficiency=sat_config[environment_mode][\"panelEfficiency\"],\n", " ),\n", ")\n", "\n", "duration = (\n", " sat_config[environment_mode][\"intervals\"] * 180\n", ") # intervals of 180 seconds (3 minutes)" ] }, { "cell_type": "markdown", "id": "8bd9ba26", "metadata": {}, "source": [ "## Assigning Curriculum Functions\n", "\n", "After selecting the curriculum, the next code cell populates the `CL_params` dictionary with functions of the difficulty that specify how each parameter varies during training. In DR cases, the current wall-clock time is used as part of the random seed."
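, "\n",
"\n",
"The lambdas in the next cell pass `capacity`, `init_val`, `final_val`, `init_range`, and `time_seed` as default arguments. This matters because Python closures look up loop variables when they are called, not when they are defined, so without the defaults every curriculum function would use the values from the last loop iteration. A minimal illustration in plain Python, with a hypothetical `scale` variable standing in for the curriculum parameters:\n",
"\n",
"```python\n",
"# Each lambda looks up scale when called, after the loop has finished\n",
"late_bound = [lambda difficulty: scale * difficulty for scale in (1.0, 0.4)]\n",
"\n",
"# A default argument stores the value of scale at definition time\n",
"early_bound = [lambda difficulty, scale=scale: scale * difficulty for scale in (1.0, 0.4)]\n",
"\n",
"print([f(1.0) for f in late_bound])  # [0.4, 0.4]\n",
"print([f(1.0) for f in early_bound])  # [1.0, 0.4]\n",
"```"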
] }, { "cell_type": "code", "execution_count": null, "id": "c223c5af", "metadata": {}, "outputs": [], "source": [ "if CL_enabled:\n", " for key in CL_options[CL_case].keys():\n", " # DR cases seed the sampling with the current time; CL cases use no seed\n", " if CL_options[CL_case][key][\"domain_randomization\"]:\n", " current_time = time.time()\n", " else:\n", " current_time = None\n", "\n", " # Default arguments bind the current loop values to each lambda\n", " if key == \"torque\":\n", " capacity = sat_config[environment_mode][CL_options[CL_case][key][\"name\"]]\n", " init_val = CL_options[CL_case][key][\"init_val\"]\n", " final_val = CL_options[CL_case][key][\"final_val\"]\n", " CL_params[CL_options[CL_case][key][\"var_name\"]] = (\n", " lambda difficulty, capacity=capacity, init_val=init_val, final_val=final_val, time_seed=current_time: external_disturbance_fn(\n", " time_seed,\n", " capacity * init_val,\n", " capacity * final_val,\n", " difficulty,\n", " )\n", " )\n", "\n", " else:\n", " capacity = sat_config[environment_mode][CL_options[CL_case][key][\"name\"]]\n", " init_val = CL_options[CL_case][key][\"init_val\"]\n", " final_val = CL_options[CL_case][key][\"final_val\"]\n", " if \"var_name\" in CL_options[CL_case][key].keys():\n", " temp_name = CL_options[CL_case][key][\"var_name\"]\n", " else:\n", " temp_name = CL_options[CL_case][key][\"name\"]\n", " CL_params[temp_name] = (\n", " lambda difficulty, capacity=capacity, init_val=init_val, final_val=final_val, time_seed=current_time: capacity_fn(\n", " time_seed,\n", " capacity * init_val,\n", " capacity * final_val,\n", " difficulty,\n", " )\n", " )\n", " if \"name_init\" in CL_options[CL_case][key]:\n", " init_range = init_range_options[initialization_range][\n", " CL_options[CL_case][key][\"init_range_config\"]\n", " ]\n", " init_val = CL_options[CL_case][key][\"init_val\"]\n", " final_val = CL_options[CL_case][key][\"final_val\"]\n", " CL_params[CL_options[CL_case][key][\"name_init\"]] = (\n", " lambda difficulty, capacity=capacity, init_val=init_val, final_val=final_val, init_range=init_range, 
time_seed=current_time: capacity_init_fn(\n", " time_seed,\n", " capacity * init_val,\n", " capacity * final_val,\n", " difficulty,\n", " init_range[1],\n", " init_range[0],\n", " )\n", " )" ] }, { "cell_type": "markdown", "id": "bc7869c2", "metadata": {}, "source": [ "## Training\n", "\n", "Training is performed with Ray Tune. The `num_env_steps_sampled_lifetime` stopping criterion should usually match the number of steps that `CLCallbacks` maps to a difficulty of 1.0 (5M here), so that the curriculum reaches its final task by the end of training; the short run below stops after only 750 steps to demonstrate the setup. The original paper, [Improving Robustness of Autonomous Spacecraft Scheduling Using Curriculum Learning](https://hanspeterschaub.info/Papers/QuevedoMantovani2025.pdf), used the APPO algorithm with generalized advantage estimation instead of PPO." ] }, { "cell_type": "code", "execution_count": null, "id": "67701d79", "metadata": {}, "outputs": [], "source": [ "N_CPUS = 3\n", "\n", "env_args = dict(\n", " satellite=sat,\n", " scenario=scene.UniformNadirScanning(value_per_second=1 / duration),\n", " rewarder=data.ScanningTimeReward(),\n", " world_type=world.GroundStationWorldModel,\n", " time_limit=duration,\n", " failure_penalty=-1.0,\n", " difficulty=0.0,\n", " CL_params=CL_params,\n", ")\n", "\n", "training_args = dict(\n", " lr=0.00003,\n", " gamma=0.999,\n", " train_batch_size=250, # originally 10,000\n", " num_sgd_iter=50,\n", " model=dict(fcnet_hiddens=[512, 512], vf_share_layers=False),\n", " lambda_=0.95, # GAE lambda\n", " use_kl_loss=False,\n", " entropy_coeff=0.0,\n", " clip_param=0.2,\n", " grad_clip=0.5,\n", ")\n", "\n", "config = (\n", " PPOConfig()\n", " .training(**training_args)\n", " .env_runners(num_env_runners=N_CPUS - 1, sample_timeout_s=1000.0)\n", " .environment(\n", " env=\"SatelliteTaskingCL-RLlib\",\n", " env_config=dict(**env_args, episode_data_callback=episode_data_callback),\n", " )\n", " .reporting(\n", " metrics_num_episodes_for_smoothing=1,\n", " metrics_episode_collection_timeout_s=180,\n", " )\n", " .checkpointing(export_native_model_files=True)\n", " .framework(framework=\"torch\")\n", " .api_stack(\n", 
" enable_rl_module_and_learner=True,\n", " enable_env_runner_and_connector_v2=True,\n", " )\n", " .callbacks(CLCallbacks)\n", " # An evaluation environment with parameters different from the training environment can be\n", " # configured to evaluate the agent in conditions it was not trained in. The commented example\n", " # below relies on `unpack_config`, `env_class`, and `nominal_env_args`, which are not defined here:\n", " # .evaluation(evaluation_interval=10, evaluation_duration=1, evaluation_parallel_to_training=True, evaluation_config={\"env\": unpack_config(env_class), \"env_config\": nominal_env_args, \"explore\":False}, evaluation_num_workers=1, always_attach_evaluation_results=True)\n", ")\n", "\n", "ray.init(\n", " ignore_reinit_error=True,\n", " num_cpus=N_CPUS,\n", " object_store_memory=2_000_000_000, # 2 GB\n", ")\n", "\n", "# Run the training\n", "results = tune.run(\n", " \"PPO\",\n", " config=config.to_dict(),\n", " stop={\n", " \"num_env_steps_sampled_lifetime\": 750\n", " }, # Total number of environment steps to train for (originally 5M)\n", " checkpoint_freq=10,\n", " checkpoint_at_end=True,\n", ")\n", "\n", "# Shutdown Ray\n", "ray.shutdown()" ] }, { "cell_type": "markdown", "id": "14c62416", "metadata": {}, "source": [ "## Checking Difficulty Over Training\n", "\n", "The difficulty logged by `episode_data_callback` can be inspected to confirm that the curriculum advances during training. With the short 750-step run above, it increases only slightly, since 5M steps correspond to a difficulty of 1.0." ] }, { "cell_type": "code", "execution_count": null, "id": "0f4d6c5a", "metadata": {}, "outputs": [], "source": [ "results.results[list(results.results.keys())[0]][\"env_runners\"][\"difficulty\"]" ] } ], "metadata": { "kernelspec": { "display_name": ".venv_CL_example", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }