{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Time-Discounted GAE\n", "In semi-MDPs, each step has an associated duration. Instead of the usual value equation\n", "\n", "\\begin{equation}\n", "V(s_1) = r_1 + \\gamma r_2 + \\gamma^2 r_3 + ...\n", "\\end{equation}\n", "\n", "one discount based on step duration\n", "\n", "\\begin{equation}\n", "V_{\\Delta t}(s_1) = \\gamma^{\\Delta t_1} r_1 + \\gamma^{\\Delta t_1 + \\Delta t_2} r_2 + \\gamma^{\\Delta t_1 + \\Delta t_2 + \\Delta t_3} r_3 + ...\n", "\\end{equation}\n", "\n", "using the convention that reward is given at the end of a step.\n", "\n", "The generalized advantage estimator can be rewritten accordingly. In our implementation,\n", "the exponential decay `lambda` is per-step (as opposed to timewise).\n", "\n", "## RLlib Version\n", "RLlib is actively developed and can change significantly from version to version. For this\n", "script, the following version is used:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from importlib.metadata import version\n", "\n", "version(\"ray\") # Parent package of RLlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the Environment\n", "A simple single-satellite environment is defined, as in :doc:`examples/rllib_training`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from bsk_rl import act, data, obs, sats, scene\n", "from bsk_rl.sim import dyn, fsw\n", "\n", "\n", "class ScanningDownlinkDynModel(\n", " dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel\n", "):\n", " # Define some custom properties to be accessed in the state\n", " @property\n", " def instrument_pointing_error(self) -> float:\n", " r_BN_P_unit = self.r_BN_P / np.linalg.norm(self.r_BN_P)\n", " c_hat_P = self.satellite.fsw.c_hat_P\n", " return np.arccos(np.dot(-r_BN_P_unit, c_hat_P))\n", "\n", " @property\n", " def solar_pointing_error(self) -> float:\n", " a = (\n", " self.world.gravFactory.spiceObject.planetStateOutMsgs[self.world.sun_index]\n", " .read()\n", " .PositionVector\n", " )\n", " a_hat_N = a / np.linalg.norm(a)\n", " nHat_B = self.satellite.sat_args[\"nHat_B\"]\n", " NB = np.transpose(self.BN)\n", " nHat_N = NB @ nHat_B\n", " return np.arccos(np.dot(nHat_N, a_hat_N))\n", "\n", "\n", "class ScanningSatellite(sats.AccessSatellite):\n", " observation_spec = [\n", " obs.SatProperties(\n", " dict(prop=\"storage_level_fraction\"),\n", " dict(prop=\"battery_charge_fraction\"),\n", " dict(prop=\"wheel_speeds_fraction\"),\n", " dict(prop=\"instrument_pointing_error\", norm=np.pi),\n", " dict(prop=\"solar_pointing_error\", norm=np.pi),\n", " ),\n", " obs.OpportunityProperties(\n", " dict(prop=\"opportunity_open\", norm=5700),\n", " dict(prop=\"opportunity_close\", norm=5700),\n", " type=\"ground_station\",\n", " n_ahead_observe=1,\n", " ),\n", " obs.Eclipse(norm=5700),\n", " ]\n", " action_spec = [\n", " act.Scan(duration=180.0),\n", " act.Charge(duration=120.0),\n", " act.Downlink(duration=60.0),\n", " act.Desat(duration=60.0),\n", " ]\n", " dyn_type = ScanningDownlinkDynModel\n", " fsw_type = fsw.ContinuousImagingFSWModel\n", "\n", "\n", "sat = ScanningSatellite(\n", " \"Scanner-1\",\n", " sat_args=dict(\n", " # Data\n", " dataStorageCapacity=5000 * 8e6, # bits\n", " storageInit=lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6,\n", " instrumentBaudRate=0.5 * 8e6,\n", " transmitterBaudRate=-50 * 8e6,\n", " # Power\n", " batteryStorageCapacity=200 * 3600, 
# W*s\n", " storedCharge_Init=lambda: np.random.uniform(0.3, 1.0) * 200 * 3600,\n", " basePowerDraw=-10.0, # W\n", " instrumentPowerDraw=-30.0, # W\n", " transmitterPowerDraw=-25.0, # W\n", " thrusterPowerDraw=-80.0, # W\n", " panelArea=0.25,\n", " # Attitude\n", " imageAttErrorRequirement=0.1,\n", " imageRateErrorRequirement=0.1,\n", " disturbance_vector=lambda: np.random.normal(scale=0.0001, size=3), # N*m\n", " maxWheelSpeed=6000.0, # RPM\n", " wheelSpeeds=lambda: np.random.uniform(-3000, 3000, 3),\n", " desatAttitude=\"nadir\",\n", " ),\n", ")\n", "duration = 5 * 5700.0 # About 5 orbits\n", "env_args = dict(\n", " satellite=sat,\n", " scenario=scene.UniformNadirScanning(value_per_second=1 / duration),\n", " rewarder=data.ScanningTimeReward(),\n", " time_limit=duration,\n", " failure_penalty=-1.0,\n", " terminate_on_time_limit=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## RLlib Configuration\n", "\n", "The configuration is mostly the same as in the standard example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import bsk_rl.utils.rllib # noqa To access \"SatelliteTasking-RLlib\"\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "\n", "\n", "N_CPUS = 3\n", "\n", "training_args = dict(\n", " lr=0.00003,\n", " gamma=0.999,\n", " train_batch_size=250,\n", " num_sgd_iter=10,\n", " model=dict(fcnet_hiddens=[512, 512], vf_share_layers=False),\n", " lambda_=0.95,\n", " use_kl_loss=False,\n", " clip_param=0.1,\n", " grad_clip=0.5,\n", " reward_time=\"step_end\",\n", ")\n", "\n", "config = (\n", " PPOConfig()\n", " .env_runners(num_env_runners=N_CPUS - 1, sample_timeout_s=1000.0)\n", " .environment(\n", " env=\"SatelliteTasking-RLlib\",\n", " env_config=env_args,\n", " )\n", " .reporting(\n", " metrics_num_episodes_for_smoothing=1,\n", " metrics_episode_collection_timeout_s=180,\n", " )\n", " .checkpointing(export_native_model_files=True)\n", " .framework(framework=\"torch\")\n", " .api_stack(\n", " enable_rl_module_and_learner=True,\n", " enable_env_runner_and_connector_v2=True,\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rewards can also be distributed at the start of the step by setting ``reward_time=\"step_start\"``.\n", "\n", "The additional setting that must be configured is the appropriate learner class. This \n", "uses the `d_ts` key from the info dict to discount by the step length, not just the step\n", "count." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from bsk_rl.utils.rllib.discounting import TimeDiscountedGAEPPOTorchLearner\n", "\n", "config.training(learner_class=TimeDiscountedGAEPPOTorchLearner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training can then proceed as normal." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import ray\n", "from ray import tune\n", "\n", "ray.init(\n", " ignore_reinit_error=True,\n", " num_cpus=N_CPUS,\n", " object_store_memory=2_000_000_000, # 2 GB\n", ")\n", "\n", "# Run the training\n", "tune.run(\n", " \"PPO\",\n", " config=config.to_dict(),\n", " stop={\"training_iteration\": 2}, # Adjust the number of iterations as needed\n", ")\n", "\n", "# Shutdown Ray\n", "ray.shutdown()" ] } ], "metadata": { "kernelspec": { "display_name": ".venv_refactor", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 2 }