{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Time-Discounted GAE\n", "In semi-MDPs, each step has an associated duration. Instead of the usual value equation\n", "\n", "\\begin{equation}\n", "V(s_1) = r_1 + \\gamma r_2 + \\gamma^2 r_3 + ...\n", "\\end{equation}\n", "\n", "one can discount based on the duration of each step\n", "\n", "\\begin{equation}\n", "V_{\\Delta t}(s_1) = \\gamma^{\\Delta t_1} r_1 + \\gamma^{\\Delta t_1 + \\Delta t_2} r_2 + \\gamma^{\\Delta t_1 + \\Delta t_2 + \\Delta t_3} r_3 + ...\n", "\\end{equation}\n", "\n", "using the convention that reward is given at the end of a step.\n", "\n", "The generalized advantage estimator can be rewritten accordingly. In our implementation,\n", "the exponential decay `lambda` is applied per step (as opposed to per unit of time).\n", "\n", "## RLlib Version\n", "RLlib is actively developed and can change significantly from version to version. For this\n", "script, the following version is used:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2024-09-12T21:07:07.293341Z", "iopub.status.busy": "2024-09-12T21:07:07.293238Z", "iopub.status.idle": "2024-09-12T21:07:07.302306Z", "shell.execute_reply": "2024-09-12T21:07:07.302008Z" } }, "outputs": [ { "data": { "text/plain": [ "'2.35.0'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from importlib.metadata import version\n", "version(\"ray\") # Parent package of RLlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the Environment\n", "A simple single-satellite environment is defined, as in :doc:`examples/rllib_training`." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-09-12T21:07:07.304007Z", "iopub.status.busy": "2024-09-12T21:07:07.303852Z", "iopub.status.idle": "2024-09-12T21:07:08.429665Z", "shell.execute_reply": "2024-09-12T21:07:08.429259Z" } }, "outputs": [], "source": [ "import numpy as np\n", "from bsk_rl import act, data, obs, sats, scene\n", "from bsk_rl.sim import dyn, fsw\n", "\n", "class ScanningDownlinkDynModel(dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel):\n", " # Define some custom properties to be accessed in the state\n", " @property\n", " def instrument_pointing_error(self) -> float:\n", " r_BN_P_unit = self.r_BN_P/np.linalg.norm(self.r_BN_P) \n", " c_hat_P = self.satellite.fsw.c_hat_P\n", " return np.arccos(np.dot(-r_BN_P_unit, c_hat_P))\n", " \n", " @property\n", " def solar_pointing_error(self) -> float:\n", " a = self.world.gravFactory.spiceObject.planetStateOutMsgs[\n", " self.world.sun_index\n", " ].read().PositionVector\n", " a_hat_N = a / np.linalg.norm(a)\n", " nHat_B = self.satellite.sat_args[\"nHat_B\"]\n", " NB = np.transpose(self.BN)\n", " nHat_N = NB @ nHat_B\n", " return np.arccos(np.dot(nHat_N, a_hat_N))\n", "\n", "class ScanningSatellite(sats.AccessSatellite):\n", " observation_spec = [\n", " obs.SatProperties(\n", " dict(prop=\"storage_level_fraction\"),\n", " dict(prop=\"battery_charge_fraction\"),\n", " dict(prop=\"wheel_speeds_fraction\"),\n", " dict(prop=\"instrument_pointing_error\", norm=np.pi),\n", " dict(prop=\"solar_pointing_error\", norm=np.pi)\n", " ),\n", " obs.OpportunityProperties(\n", " dict(prop=\"opportunity_open\", norm=5700),\n", " dict(prop=\"opportunity_close\", norm=5700),\n", " type=\"ground_station\",\n", " n_ahead_observe=1,\n", " ),\n", " obs.Eclipse(norm=5700),\n", " ]\n", " action_spec = [\n", " act.Scan(duration=180.0),\n", " act.Charge(duration=120.0),\n", " act.Downlink(duration=60.0),\n", " act.Desat(duration=60.0),\n", " ]\n", " dyn_type = 
ScanningDownlinkDynModel\n", " fsw_type = fsw.ContinuousImagingFSWModel\n", "\n", "sat = ScanningSatellite(\n", " \"Scanner-1\",\n", " sat_args=dict(\n", " # Data\n", " dataStorageCapacity=5000 * 8e6, # bits\n", " storageInit=lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6,\n", " instrumentBaudRate=0.5 * 8e6,\n", " transmitterBaudRate=-50 * 8e6,\n", " # Power\n", " batteryStorageCapacity=200 * 3600, # W*s\n", " storedCharge_Init=lambda: np.random.uniform(0.3, 1.0) * 200 * 3600,\n", " basePowerDraw=-10.0, # W\n", " instrumentPowerDraw=-30.0, # W\n", " transmitterPowerDraw=-25.0, # W\n", " thrusterPowerDraw=-80.0, # W\n", " panelArea=0.25,\n", " # Attitude\n", " imageAttErrorRequirement=0.1,\n", " imageRateErrorRequirement=0.1,\n", " disturbance_vector=lambda: np.random.normal(scale=0.0001, size=3), # N*m\n", " maxWheelSpeed=6000.0, # RPM\n", " wheelSpeeds=lambda: np.random.uniform(-3000, 3000, 3),\n", " desatAttitude=\"nadir\",\n", " )\n", ")\n", "duration = 5 * 5700.0 # About 5 orbits\n", "env_args = dict(\n", " satellite=sat,\n", " scenario=scene.UniformNadirScanning(value_per_second=1/duration),\n", " rewarder=data.ScanningTimeReward(),\n", " time_limit=duration,\n", " failure_penalty=-1.0,\n", " terminate_on_time_limit=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## RLlib Configuration\n", "\n", "The configuration is mostly the same as in the standard example." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-09-12T21:07:08.431676Z", "iopub.status.busy": "2024-09-12T21:07:08.431471Z", "iopub.status.idle": "2024-09-12T21:07:15.818479Z", "shell.execute_reply": "2024-09-12T21:07:15.818175Z" } }, "outputs": [], "source": [ "import bsk_rl.utils.rllib # noqa To access \"SatelliteTasking-RLlib\"\n", "from ray.rllib.algorithms.ppo import PPOConfig\n", "\n", "\n", "N_CPUS = 3\n", "\n", "training_args = dict(\n", " lr=0.00003,\n", " gamma=0.999,\n", " train_batch_size=250,\n", " num_sgd_iter=10,\n", " model=dict(fcnet_hiddens=[512, 512], vf_share_layers=False),\n", " lambda_=0.95,\n", " use_kl_loss=False,\n", " clip_param=0.1,\n", " grad_clip=0.5,\n", ")\n", "\n", "config = (\n", " PPOConfig()\n", " .env_runners(num_env_runners=N_CPUS-1, sample_timeout_s=1000.0)\n", " .environment(\n", " env=\"SatelliteTasking-RLlib\",\n", " env_config=env_args,\n", " )\n", " .reporting(\n", " metrics_num_episodes_for_smoothing=1,\n", " metrics_episode_collection_timeout_s=180,\n", " )\n", " .checkpointing(export_native_model_files=True)\n", " .framework(framework=\"torch\")\n", " .api_stack(\n", " enable_rl_module_and_learner=True,\n", " enable_env_runner_and_connector_v2=True,\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The additional setting that must be configured is the appropriate learner class. This \n", "uses the `d_ts` key from the info dict to discount by the step length, not just the step\n", "count." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-09-12T21:07:15.820796Z", "iopub.status.busy": "2024-09-12T21:07:15.820541Z", "iopub.status.idle": "2024-09-12T21:07:15.829862Z", "shell.execute_reply": "2024-09-12T21:07:15.829549Z" } }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from bsk_rl.utils.rllib.discounting import TimeDiscountedGAEPPOTorchLearner\n", "\n", "config.training(learner_class=TimeDiscountedGAEPPOTorchLearner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training can then proceed as normal." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-09-12T21:07:15.831346Z", "iopub.status.busy": "2024-09-12T21:07:15.831248Z", "iopub.status.idle": "2024-09-12T21:08:14.353619Z", "shell.execute_reply": "2024-09-12T21:08:14.352896Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-09-12 15:07:16,701\tINFO worker.py:1783 -- Started a local Ray instance.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-09-12 15:07:17,093\tINFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. 
For more information, please see https://github.com/ray-project/ray/issues/36949\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/markstephenson/avslab/refactor/.venv_refactor/lib/python3.10/site-packages/gymnasium/spaces/box.py:130: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", " gym.logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n", "/Users/markstephenson/avslab/refactor/.venv_refactor/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:164: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n", " logger.warn(\n", "/Users/markstephenson/avslab/refactor/.venv_refactor/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:188: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method is not within the observation space.\u001b[0m\n", " logger.warn(f\"{pre} is not within the observation space.\")\n" ] }, { "data": { "text/html": [ "
\n", "
\n", "
\n", "

Tune Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "
Current time:2024-09-12 15:08:12
Running for: 00:00:55.44
Memory: 13.4/16.0 GiB
\n", "
\n", "
\n", "
\n", "

System Info

\n", " Using FIFO scheduling algorithm.
Logical resource usage: 3.0/3 CPUs, 0/0 GPUs\n", "
\n", " \n", "
\n", "
\n", "
\n", "

Trial Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc iter total time (s) num_env_steps_sample\n", "d_lifetime num_episodes_lifetim\n", "e num_env_steps_traine\n", "d_lifetime
PPO_SatelliteTasking-RLlib_fdcaf_00000TERMINATED127.0.0.1:96400 2 44.98948000468000
\n", "
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(PPO pid=96400)\u001b[0m Install gputil for GPU system monitoring.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96403)\u001b[0m \u001b[90;3m2024-09-12 15:07:28,637 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<11460.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96403)\u001b[0m \u001b[90;3m2024-09-12 15:07:29,352 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<6300.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96403)\u001b[0m \u001b[90;3m2024-09-12 15:07:30,782 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<16620.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96402)\u001b[0m \u001b[90;3m2024-09-12 15:07:33,644 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<20760.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\u001b[32m [repeated 3x across cluster] (Ray deduplicates logs by default. 
Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96402)\u001b[0m \u001b[90;3m2024-09-12 15:07:38,804 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<26880.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\u001b[32m [repeated 6x across cluster]\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96402)\u001b[0m \u001b[90;3m2024-09-12 15:07:43,887 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<28500.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\u001b[32m [repeated 4x across cluster]\u001b[0m\n" ] }, { "data": { "text/html": [ "
\n", "

Trial Progress

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name env_runners fault_tolerance learners num_agent_steps_sampled_lifetime num_env_steps_sampled_lifetime num_env_steps_trained_lifetime num_episodes_lifetimeperf timers
PPO_SatelliteTasking-RLlib_fdcaf_00000{'num_episodes': 22, 'episode_len_max': 267, 'episode_len_min': 117, 'episode_return_max': 0.3484912280701757, 'num_env_steps_sampled_lifetime': 16000, 'episode_return_min': -0.7496140350877192, 'num_module_steps_sampled_lifetime': {'default_policy': 12000}, 'episode_duration_sec_mean': 1.9101845209952444, 'module_episode_returns_mean': {'default_policy': -0.20056140350877175}, 'num_env_steps_sampled': 4000, 'num_module_steps_sampled': {'default_policy': 4000}, 'sample': 19.97263330376537, 'agent_episode_returns_mean': {'default_agent': -0.20056140350877175}, 'num_agent_steps_sampled_lifetime': {'default_agent': 12000}, 'episode_return_mean': -0.20056140350877175, 'num_agent_steps_sampled': {'default_agent': 4000}, 'episode_len_mean': 192.0, 'time_between_sampling': 2.675149833521573}{'num_healthy_workers': 2, 'num_in_flight_async_reqs': 0, 'num_remote_worker_restarts': 0}{'default_policy': {'num_module_steps_trained': 4000, 'total_loss': 0.15218333899974823, 'vf_loss': 0.0030167356599122286, 'curr_entropy_coeff': 0.0, 'mean_kl_loss': 0.013256911188364029, 'policy_loss': 0.14651519060134888, 'vf_explained_var': 0.027288198471069336, 'vf_loss_unclipped': 0.0030167356599122286, 'num_trainable_parameters': 139013.0, 'curr_kl_coeff': 0.20000000298023224, 'num_non_trainable_parameters': 0.0, 'default_optimizer_learning_rate': 5e-05, 'entropy': 1.3614425659179688}, '__all_modules__': {'num_trainable_parameters': 139013.0, 'num_non_trainable_parameters': 0.0, 'total_loss': 0.15218333899974823, 'num_module_steps_trained': 4000, 'num_env_steps_trained': 4000}}{'default_agent': 8000} 8000 8000 46{'cpu_util_percent': 26.38064516129032, 'ram_util_percent': 82.59032258064516}{'env_runner_sampling_timer': 20.58649573632865, 'learner_update_timer': 2.029126086170436, 'synch_weights': 0.009999817171483301, 'synch_env_connectors': 0.0037881858070613816}
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96403)\u001b[0m \u001b[90;3m2024-09-12 15:07:52,128 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<3600.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\u001b[32m [repeated 5x across cluster]\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96403)\u001b[0m \u001b[90;3m2024-09-12 15:07:58,377 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<19080.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\u001b[32m [repeated 6x across cluster]\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96402)\u001b[0m \u001b[90;3m2024-09-12 15:08:03,504 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<26100.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\u001b[32m [repeated 3x across cluster]\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SingleAgentEnvRunner pid=96403)\u001b[0m \u001b[90;3m2024-09-12 15:08:10,010 \u001b[0m\u001b[36msats.satellite.Scanner-1 \u001b[0m\u001b[93mWARNING \u001b[0m\u001b[33m<14160.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\u001b[32m [repeated 2x across cluster]\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-09-12 15:08:12,577\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/markstephenson/ray_results/PPO_2024-09-12_15-07-17' in 0.0171s.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-09-12 15:08:12,922\tINFO tune.py:1041 -- Total run time: 
55.83 seconds (55.42 seconds for the tuning loop).\n" ] } ], "source": [ "import ray\n", "from ray import tune\n", "\n", "ray.init(\n", " ignore_reinit_error=True,\n", " num_cpus=N_CPUS,\n", " object_store_memory=2_000_000_000, # 2 GB\n", ")\n", "\n", "# Run the training\n", "tune.run(\n", " \"PPO\",\n", " config=config.to_dict(),\n", " stop={\"training_iteration\": 2}, # Adjust the number of iterations as needed\n", ")\n", "\n", "# Shutdown Ray\n", "ray.shutdown()" ] } ], "metadata": { "kernelspec": { "display_name": ".venv_refactor", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 2 }