Time-Discounted GAE

In semi-MDPs, each step has an associated duration. Instead of the usual value equation

\begin{equation} V(s_1) = r_1 + \gamma r_2 + \gamma^2 r_3 + ... \end{equation}

one discounts each reward by the time elapsed up to the end of its step:

\begin{equation} V_{\Delta t}(s_1) = \gamma^{\Delta t_1} r_1 + \gamma^{\Delta t_1 + \Delta t_2} r_2 + \gamma^{\Delta t_1 + \Delta t_2 + \Delta t_3} r_3 + ... \end{equation}

using the convention that reward is given at the end of a step.
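As a small numerical illustration of the time-discounted value above, the following sketch sums rewards discounted by cumulative elapsed time. The function name `time_discounted_return` and the trajectory values are hypothetical and chosen only for this example; they are not part of bsk_rl.

```python
import numpy as np


def time_discounted_return(rewards, durations, gamma):
    """Sum rewards discounted by elapsed time, with each reward paid at the end of its step."""
    rewards = np.asarray(rewards, dtype=float)
    durations = np.asarray(durations, dtype=float)
    elapsed = np.cumsum(durations)  # time at which each reward is received
    return float(np.sum(gamma**elapsed * rewards))


# Hypothetical three-step trajectory with unequal step durations
rewards = [1.0, 1.0, 1.0]
durations = [2.0, 1.0, 3.0]
value = time_discounted_return(rewards, durations, gamma=0.5)
# 0.5**2 + 0.5**3 + 0.5**6 = 0.390625
```

Note that with all durations equal to one, the end-of-step convention yields $\gamma r_1 + \gamma^2 r_2 + \dots$, which differs from the usual value equation by one factor of $\gamma$.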

The generalized advantage estimator (GAE) can be rewritten accordingly. In our implementation, the exponential decay parameter lambda is applied per step rather than per unit of time, while gamma is applied per unit of time.
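One way to write such an estimator is sketched below. It is a minimal illustration under stated assumptions, not the bsk_rl implementation: rewards arrive at the end of each step, so the temporal-difference residual is taken as `gamma**dt * (r + V_next) - V`, and the backward recursion multiplies by lambda once per step and by gamma once per unit of time. The function name `time_discounted_gae` and the `bootstrap_value` argument are hypothetical.

```python
import numpy as np


def time_discounted_gae(rewards, values, durations, gamma, lambda_, bootstrap_value=0.0):
    """Backward-recursive GAE with time-based gamma and per-step lambda.

    Assumes the reward for step t arrives at the end of the step, so both the
    reward and the next value are discounted by gamma ** durations[t].
    """
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    durations = np.asarray(durations, dtype=float)
    next_values = np.append(values[1:], bootstrap_value)
    step_discount = gamma**durations  # discount across each individual step

    # Temporal-difference residuals under the end-of-step reward convention
    deltas = step_discount * (rewards + next_values) - values

    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + lambda_ * step_discount[t] * running
        advantages[t] = running
    return advantages


# Hypothetical trajectory: with lambda_=1 and zero value estimates, the first
# advantage equals the time-discounted return of the whole trajectory.
adv = time_discounted_gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    durations=[2.0, 1.0, 3.0],
    gamma=0.5,
    lambda_=1.0,
)
```

With `lambda_=0.0` the same function returns the one-step residuals, which makes the two limits easy to check by hand.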

RLlib Version

RLlib is actively developed and can change significantly from version to version. For this script, the following version is used:

[1]:
from importlib.metadata import version

version("ray")  # Parent package of RLlib
[1]:
'2.35.0'

Define the Environment

A simple single-satellite environment is defined, as in the RLlib training example (examples/rllib_training).

[2]:
import numpy as np
from bsk_rl import act, data, obs, sats, scene
from bsk_rl.sim import dyn, fsw


class ScanningDownlinkDynModel(
    dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel
):
    # Define some custom properties to be accessed in the state
    @property
    def instrument_pointing_error(self) -> float:
        r_BN_P_unit = self.r_BN_P / np.linalg.norm(self.r_BN_P)
        c_hat_P = self.satellite.fsw.c_hat_P
        return np.arccos(np.dot(-r_BN_P_unit, c_hat_P))

    @property
    def solar_pointing_error(self) -> float:
        a = (
            self.world.gravFactory.spiceObject.planetStateOutMsgs[self.world.sun_index]
            .read()
            .PositionVector
        )
        a_hat_N = a / np.linalg.norm(a)
        nHat_B = self.satellite.sat_args["nHat_B"]
        NB = np.transpose(self.BN)
        nHat_N = NB @ nHat_B
        return np.arccos(np.dot(nHat_N, a_hat_N))


class ScanningSatellite(sats.AccessSatellite):
    observation_spec = [
        obs.SatProperties(
            dict(prop="storage_level_fraction"),
            dict(prop="battery_charge_fraction"),
            dict(prop="wheel_speeds_fraction"),
            dict(prop="instrument_pointing_error", norm=np.pi),
            dict(prop="solar_pointing_error", norm=np.pi),
        ),
        obs.OpportunityProperties(
            dict(prop="opportunity_open", norm=5700),
            dict(prop="opportunity_close", norm=5700),
            type="ground_station",
            n_ahead_observe=1,
        ),
        obs.Eclipse(norm=5700),
    ]
    action_spec = [
        act.Scan(duration=180.0),
        act.Charge(duration=120.0),
        act.Downlink(duration=60.0),
        act.Desat(duration=60.0),
    ]
    dyn_type = ScanningDownlinkDynModel
    fsw_type = fsw.ContinuousImagingFSWModel


sat = ScanningSatellite(
    "Scanner-1",
    sat_args=dict(
        # Data
        dataStorageCapacity=5000 * 8e6,  # bits
        storageInit=lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6,
        instrumentBaudRate=0.5 * 8e6,
        transmitterBaudRate=-50 * 8e6,
        # Power
        batteryStorageCapacity=200 * 3600,  # W*s
        storedCharge_Init=lambda: np.random.uniform(0.3, 1.0) * 200 * 3600,
        basePowerDraw=-10.0,  # W
        instrumentPowerDraw=-30.0,  # W
        transmitterPowerDraw=-25.0,  # W
        thrusterPowerDraw=-80.0,  # W
        panelArea=0.25,
        # Attitude
        imageAttErrorRequirement=0.1,
        imageRateErrorRequirement=0.1,
        disturbance_vector=lambda: np.random.normal(scale=0.0001, size=3),  # N*m
        maxWheelSpeed=6000.0,  # RPM
        wheelSpeeds=lambda: np.random.uniform(-3000, 3000, 3),
        desatAttitude="nadir",
    ),
)
duration = 5 * 5700.0  # About 5 orbits
env_args = dict(
    satellite=sat,
    scenario=scene.UniformNadirScanning(value_per_second=1 / duration),
    rewarder=data.ScanningTimeReward(),
    time_limit=duration,
    failure_penalty=-1.0,
    terminate_on_time_limit=True,
)

RLlib Configuration

The configuration is mostly the same as in the standard example.

[3]:
import bsk_rl.utils.rllib  # noqa To access "SatelliteTasking-RLlib"
from ray.rllib.algorithms.ppo import PPOConfig


N_CPUS = 3

training_args = dict(
    lr=0.00003,
    gamma=0.999,
    train_batch_size=250,
    num_sgd_iter=10,
    model=dict(fcnet_hiddens=[512, 512], vf_share_layers=False),
    lambda_=0.95,
    use_kl_loss=False,
    clip_param=0.1,
    grad_clip=0.5,
    reward_time="step_end",
)

config = (
    PPOConfig()
    .env_runners(num_env_runners=N_CPUS - 1, sample_timeout_s=1000.0)
    .environment(
        env="SatelliteTasking-RLlib",
        env_config=env_args,
    )
    .reporting(
        metrics_num_episodes_for_smoothing=1,
        metrics_episode_collection_timeout_s=180,
    )
    .checkpointing(export_native_model_files=True)
    .framework(framework="torch")
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
)

Rewards can also be distributed at the start of the step by setting reward_time="step_start".

The one additional setting that must be configured is the learner class. The time-discounted learner reads the d_ts key from the info dict so that discounting depends on the step length, not just the step count.

[4]:
from bsk_rl.utils.rllib.discounting import TimeDiscountedGAEPPOTorchLearner

config.training(learner_class=TimeDiscountedGAEPPOTorchLearner)
[4]:
<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x7f1b48212f50>
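To give a sense of what discounting by step length means with the values used here, the sketch below turns a few step durations into discount factors. The `infos` list is hypothetical and reuses the action durations from the satellite's action_spec; it also assumes that d_ts is measured in seconds and that gamma is therefore a per-second discount.

```python
import numpy as np

gamma = 0.999  # per-second discount; assumes d_ts is measured in seconds

# Hypothetical info dicts, one per environment step (scan, charge, downlink)
infos = [{"d_ts": 180.0}, {"d_ts": 120.0}, {"d_ts": 60.0}]

d_ts = np.array([info["d_ts"] for info in infos])
step_discounts = gamma**d_ts  # discount applied across each individual step
cumulative_discounts = gamma ** np.cumsum(d_ts)  # discount measured from the first step
```

Under these assumptions a 180-second scan is discounted by roughly 0.84, whereas discounting by step count alone would apply 0.999 regardless of how long the action took.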

Training can then proceed as normal.

[5]:
import ray
from ray import tune

ray.init(
    ignore_reinit_error=True,
    num_cpus=N_CPUS,
    object_store_memory=2_000_000_000,  # 2 GB
)

# Run the training
tune.run(
    "PPO",
    config=config.to_dict(),
    stop={"training_iteration": 2},  # Adjust the number of iterations as needed
)

# Shutdown Ray
ray.shutdown()
2025-09-14 21:59:42,170 INFO worker.py:1783 -- Started a local Ray instance.
2025-09-14 21:59:45,736 INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/gymnasium/spaces/box.py:130: UserWarning: WARN: Box bound precision lowered by casting to float32
  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/gymnasium/utils/passive_env_checker.py:164: UserWarning: WARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64
  logger.warn(
/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/site-packages/gymnasium/utils/passive_env_checker.py:188: UserWarning: WARN: The obs returned by the `reset()` method is not within the observation space.
  logger.warn(f"{pre} is not within the observation space.")

Tune Status

Current time: 2025-09-14 22:01:43
Running for: 00:01:57.59
Memory: 4.5/15.6 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 3.0/3 CPUs, 0/0 GPUs

Trial Status

Trial name: PPO_SatelliteTasking-RLlib_2021a_00000
status: TERMINATED
loc: 10.1.0.201:6287
iter: 2
total time (s): 102.293
num_env_steps_sampled_lifetime: 8000
num_episodes_lifetime: 41
num_env_steps_trained_lifetime: 8000
(PPO pid=6287) Install gputil for GPU system monitoring.
(SingleAgentEnvRunner pid=6334) 2025-09-14 22:00:02,669 sats.satellite.Scanner-1       WARNING    <7740.00> Scanner-1: failed battery_valid check
(SingleAgentEnvRunner pid=6334) 2025-09-14 22:00:11,660 sats.satellite.Scanner-1       WARNING    <24900.00> Scanner-1: failed battery_valid check [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(SingleAgentEnvRunner pid=6335) 2025-09-14 22:00:17,349 sats.satellite.Scanner-1       WARNING    <13980.00> Scanner-1: failed battery_valid check [repeated 2x across cluster]
(SingleAgentEnvRunner pid=6335) 2025-09-14 22:00:25,337 sats.satellite.Scanner-1       WARNING    <15300.00> Scanner-1: failed battery_valid check [repeated 2x across cluster]
(SingleAgentEnvRunner pid=6334) 2025-09-14 22:00:30,832 sats.satellite.Scanner-1       WARNING    <13080.00> Scanner-1: failed battery_valid check [repeated 4x across cluster]
(SingleAgentEnvRunner pid=6334) 2025-09-14 22:00:37,953 sats.satellite.Scanner-1       WARNING    <11340.00> Scanner-1: failed battery_valid check [repeated 2x across cluster]
(SingleAgentEnvRunner pid=6335) 2025-09-14 22:00:43,440 sats.satellite.Scanner-1       WARNING    <19800.00> Scanner-1: failed battery_valid check

Trial Progress

Trial name: PPO_SatelliteTasking-RLlib_2021a_00000
env_runners: {'agent_episode_returns_mean': {'default_agent': -0.7247543859649122}, 'num_episodes': 21, 'num_env_steps_sampled_lifetime': 16000, 'episode_return_min': -0.8071228070175439, 'episode_duration_sec_mean': 8.847802487499962, 'num_module_steps_sampled_lifetime': {'default_policy': 12000}, 'num_agent_steps_sampled': {'default_agent': 4000}, 'num_env_steps_sampled': 4000, 'episode_len_max': 250, 'num_agent_steps_sampled_lifetime': {'default_agent': 12000}, 'num_module_steps_sampled': {'default_policy': 4000}, 'sample': np.float64(41.77654663319187), 'episode_len_mean': 198.0, 'episode_return_mean': -0.7247543859649122, 'module_episode_returns_mean': {'default_policy': -0.7247543859649122}, 'episode_len_min': 146, 'episode_return_max': -0.6423859649122805, 'time_between_sampling': np.float64(7.754538097000022)}
fault_tolerance: {'num_healthy_workers': 2, 'num_in_flight_async_reqs': 0, 'num_remote_worker_restarts': 0}
learners: {'__all_modules__': {'num_module_steps_trained': 4000, 'num_trainable_parameters': 139013.0, 'num_env_steps_trained': 4000, 'total_loss': -0.036448147147893906, 'num_non_trainable_parameters': 0.0}, 'default_policy': {'curr_kl_coeff': 0.20000000298023224, 'policy_loss': -0.03920312970876694, 'vf_loss': 0.000287031230982393, 'vf_loss_unclipped': 0.000287031230982393, 'entropy': 1.366040587425232, 'vf_explained_var': 0.031376779079437256, 'num_trainable_parameters': 139013.0, 'default_optimizer_learning_rate': 5e-05, 'mean_kl_loss': 0.012339756824076176, 'total_loss': -0.036448147147893906, 'num_module_steps_trained': 4000, 'num_non_trainable_parameters': 0.0, 'curr_entropy_coeff': 0.0}}
num_agent_steps_sampled_lifetime: {'default_agent': 8000}
num_env_steps_sampled_lifetime: 8000
num_env_steps_trained_lifetime: 8000
num_episodes_lifetime: 41
perf: {'cpu_util_percent': np.float64(43.20921052631579), 'ram_util_percent': np.float64(29.002631578947366)}
timers: {'env_runner_sampling_timer': 44.92216026339999, 'learner_update_timer': 4.607165536629968, 'synch_weights': 0.005811738079985389, 'synch_env_connectors': 0.006100494909892404}
(SingleAgentEnvRunner pid=6334) 2025-09-14 22:00:51,244 sats.satellite.Scanner-1       WARNING    <13140.00> Scanner-1: failed battery_valid check
(SingleAgentEnvRunner pid=6334) 2025-09-14 22:01:03,831 sats.satellite.Scanner-1       WARNING    <12360.00> Scanner-1: failed battery_valid check [repeated 5x across cluster]
(SingleAgentEnvRunner pid=6335) 2025-09-14 22:01:09,310 sats.satellite.Scanner-1       WARNING    <28440.00> Scanner-1: failed battery_valid check [repeated 2x across cluster]
(SingleAgentEnvRunner pid=6334) 2025-09-14 22:01:15,684 sats.satellite.Scanner-1       WARNING    <23700.00> Scanner-1: failed battery_valid check [repeated 4x across cluster]
(SingleAgentEnvRunner pid=6334) 2025-09-14 22:01:21,599 sats.satellite.Scanner-1       WARNING    <13080.00> Scanner-1: failed battery_valid check [repeated 2x across cluster]
(SingleAgentEnvRunner pid=6334) 2025-09-14 22:01:27,626 sats.satellite.Scanner-1       WARNING    <15240.00> Scanner-1: failed battery_valid check [repeated 2x across cluster]
(SingleAgentEnvRunner pid=6335) 2025-09-14 22:01:34,488 sats.satellite.Scanner-1       WARNING    <28380.00> Scanner-1: failed battery_valid check
2025-09-14 22:01:43,371 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/runner/ray_results/PPO_2025-09-14_21-59-45' in 0.0157s.
2025-09-14 22:01:44,065 INFO tune.py:1041 -- Total run time: 118.33 seconds (117.58 seconds for the tuning loop).