# Time-Discounted GAE
In semi-MDPs, each step has an associated duration. Instead of the usual value equation

\begin{equation}
V(s_1) = r_1 + \gamma r_2 + \gamma^2 r_3 + ...
\end{equation}

one discounts each reward by the cumulative duration of the steps up to the time it is received:

\begin{equation}
V_{\Delta t}(s_1) = \gamma^{\Delta t_1} r_1 + \gamma^{\Delta t_1 + \Delta t_2} r_2 + \gamma^{\Delta t_1 + \Delta t_2 + \Delta t_3} r_3 + ...
\end{equation}

using the convention that reward is given at the end of a step.

The generalized advantage estimator (GAE) can be rewritten accordingly. In our implementation,
the GAE decay parameter `lambda` is applied per step (as opposed to per unit of time).

## RLlib Version
RLlib is actively developed and can change significantly from version to version. For this
script, the following version is used:

In [None]:
from importlib.metadata import version

# RLlib ships inside the "ray" distribution, so its version is Ray's version.
version(distribution_name="ray")

## Define the Environment
A simple single-satellite environment is defined, as in :doc:`examples/rllib_training`.

In [None]:
import numpy as np
from bsk_rl import act, data, obs, sats, scene
from bsk_rl.sim import dyn, fsw


class ScanningDownlinkDynModel(
    dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel
):
    """Imaging + ground-station dynamics with extra pointing-error properties.

    The two properties below are custom additions so they can be requested by
    name in an observation spec. Both return an angle in radians in [0, pi].
    """

    @property
    def instrument_pointing_error(self) -> float:
        """Angle [rad] between the instrument axis and the nadir direction.

        ``-r_BN_P`` (normalized) is taken as the nadir direction; presumably
        ``r_BN_P`` is the satellite position in the planet frame — confirm.
        ``c_hat_P`` comes from the FSW model and is assumed to be a unit vector.
        """
        r_BN_P_unit = self.r_BN_P / np.linalg.norm(self.r_BN_P)
        c_hat_P = self.satellite.fsw.c_hat_P
        # Round-off can push the dot product of unit vectors slightly outside
        # [-1, 1], which would make arccos return NaN in the observation.
        cos_angle = np.clip(np.dot(-r_BN_P_unit, c_hat_P), -1.0, 1.0)
        return np.arccos(cos_angle)

    @property
    def solar_pointing_error(self) -> float:
        """Angle [rad] between ``sat_args["nHat_B"]`` and the Sun direction.

        ``nHat_B`` is presumably the solar-panel normal in the body frame —
        confirm. It is rotated into the inertial frame with ``BN`` transposed.
        The Sun direction is the normalized Sun position vector from SPICE.
        """
        a = (
            self.world.gravFactory.spiceObject.planetStateOutMsgs[self.world.sun_index]
            .read()
            .PositionVector
        )
        a_hat_N = a / np.linalg.norm(a)
        nHat_B = self.satellite.sat_args["nHat_B"]
        NB = np.transpose(self.BN)
        nHat_N = NB @ nHat_B
        # Clip for the same NaN-avoidance reason as above.
        cos_angle = np.clip(np.dot(nHat_N, a_hat_N), -1.0, 1.0)
        return np.arccos(cos_angle)


class ScanningSatellite(sats.AccessSatellite):
    """Satellite that can scan, charge, downlink to ground stations, and desaturate."""

    # Observation vector: normalized resource/pointing state, timing of the
    # next ground-station window, and eclipse timing. The 5700 normalizer is
    # roughly one orbit (the episode length below is 5 * 5700 for ~5 orbits).
    observation_spec = [
        obs.SatProperties(
            {"prop": "storage_level_fraction"},
            {"prop": "battery_charge_fraction"},
            {"prop": "wheel_speeds_fraction"},
            {"prop": "instrument_pointing_error", "norm": np.pi},
            {"prop": "solar_pointing_error", "norm": np.pi},
        ),
        obs.OpportunityProperties(
            {"prop": "opportunity_open", "norm": 5700},
            {"prop": "opportunity_close", "norm": 5700},
            type="ground_station",
            n_ahead_observe=1,
        ),
        obs.Eclipse(norm=5700),
    ]

    # Discrete actions, each with a fixed duration in simulation time.
    action_spec = [
        act.Scan(duration=180.0),
        act.Charge(duration=120.0),
        act.Downlink(duration=60.0),
        act.Desat(duration=60.0),
    ]

    dyn_type = ScanningDownlinkDynModel
    fsw_type = fsw.ContinuousImagingFSWModel


# A single scanning satellite. Entries given as lambdas are callables rather
# than fixed values (presumably re-sampled per episode by bsk_rl — confirm).
sat = ScanningSatellite(
    "Scanner-1",
    sat_args={
        # Data
        "dataStorageCapacity": 5000 * 8e6,  # bits
        "storageInit": lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6,
        "instrumentBaudRate": 0.5 * 8e6,
        "transmitterBaudRate": -50 * 8e6,
        # Power
        "batteryStorageCapacity": 200 * 3600,  # W*s
        "storedCharge_Init": lambda: np.random.uniform(0.3, 1.0) * 200 * 3600,
        "basePowerDraw": -10.0,  # W
        "instrumentPowerDraw": -30.0,  # W
        "transmitterPowerDraw": -25.0,  # W
        "thrusterPowerDraw": -80.0,  # W
        "panelArea": 0.25,
        # Attitude
        "imageAttErrorRequirement": 0.1,
        "imageRateErrorRequirement": 0.1,
        "disturbance_vector": lambda: np.random.normal(scale=0.0001, size=3),  # N*m
        "maxWheelSpeed": 6000.0,  # RPM
        "wheelSpeeds": lambda: np.random.uniform(-3000, 3000, 3),
        "desatAttitude": "nadir",
    },
)
duration = 5 * 5700.0  # About 5 orbits

# value_per_second = 1 / duration means scanning for the entire episode is
# worth a total of 1.
env_args = {
    "satellite": sat,
    "scenario": scene.UniformNadirScanning(value_per_second=1 / duration),
    "rewarder": data.ScanningTimeReward(),
    "time_limit": duration,
    "failure_penalty": -1.0,
    "terminate_on_time_limit": True,
}

## RLlib Configuration

The configuration is mostly the same as in the standard example.

In [None]:
import bsk_rl.utils.rllib # noqa To access "SatelliteTasking-RLlib"
from ray.rllib.algorithms.ppo import PPOConfig


N_CPUS = 3

# PPO hyperparameters, applied to the config via ``.training(**training_args)``.
# With the time-discounted learner, the discount is gamma ** (step duration)
# rather than gamma per step, so gamma=0.999 acts per unit of simulation time.
training_args = dict(
    lr=0.00003,
    gamma=0.999,
    train_batch_size=250,
    num_sgd_iter=10,
    model=dict(fcnet_hiddens=[512, 512], vf_share_layers=False),
    lambda_=0.95,
    use_kl_loss=False,
    clip_param=0.1,
    grad_clip=0.5,
    # NOTE(review): ``reward_time`` is not a stock RLlib training option; it is
    # presumably consumed by bsk_rl's time-discounted learner — confirm the
    # installed bsk_rl/RLlib versions accept it.
    reward_time="step_end",
)

config = (
    PPOConfig()
    # Without this call the hyperparameters above would have no effect and
    # RLlib's defaults (e.g. for gamma and lambda_) would be used instead.
    .training(**training_args)
    .env_runners(num_env_runners=N_CPUS - 1, sample_timeout_s=1000.0)
    .environment(
        env="SatelliteTasking-RLlib",
        env_config=env_args,
    )
    .reporting(
        metrics_num_episodes_for_smoothing=1,
        metrics_episode_collection_timeout_s=180,
    )
    .checkpointing(export_native_model_files=True)
    .framework(framework="torch")
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
)

Rewards can also be distributed at the start of the step by setting ``reward_time="step_start"``.

The only additional setting that must be configured is the learner class. The
`TimeDiscountedGAEPPOTorchLearner` uses the `d_ts` key from the info dict to discount by
each step's duration, not just by the step count.

In [None]:
from bsk_rl.utils.rllib.discounting import TimeDiscountedGAEPPOTorchLearner

# Use bsk_rl's learner, which discounts by each step's duration (the ``d_ts``
# info key) instead of by step count. ``training`` updates ``config`` in place,
# so the later ``config.to_dict()`` picks this up without reassignment.
config.training(learner_class=TimeDiscountedGAEPPOTorchLearner)

Training can then proceed as normal.

In [None]:
import ray
from ray import tune

# ignore_reinit_error lets this cell be re-run when Ray is already initialized.
ray.init(
    ignore_reinit_error=True,
    num_cpus=N_CPUS,
    object_store_memory=2_000_000_000,  # 2 GB
)

try:
    # Run the training
    tune.run(
        "PPO",
        config=config.to_dict(),
        stop={"training_iteration": 2},  # Adjust the number of iterations as needed
    )
finally:
    # Shut Ray down even if training raises, so the runtime is not left running.
    ray.shutdown()