{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Asynchronous Multiagent Decision Making\n",
    "This tutorial demonstrates how to configure and train a multiagent environment in RLlib\n",
    "in which homogeneous agents act asynchronously while learning a single policy.\n",
    "\n",
    "<div class=\"alert alert-warning\">\n",
    "\n",
    "**Warning:** Part of RLlib's backend mishandles the potential for zero-length episodes,\n",
    "which this method may produce. If this PR is unmerged, it must be manually applied to your \n",
    "installation: https://github.com/ray-project/ray/pull/46721\n",
    "\n",
    "</div>\n",
    "\n",
    "This example was run with the following version of RLlib:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:14.764955Z",
     "iopub.status.busy": "2024-09-12T21:05:14.764852Z",
     "iopub.status.idle": "2024-09-12T21:05:14.775762Z",
     "shell.execute_reply": "2024-09-12T21:05:14.775450Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.35.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "version(\"ray\")  # Parent package of RLlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Environment\n",
    "A simple multi-satellite environment is defined. This environment is trivial for the \n",
    "multiagent case since there is no communication or interaction between satellites, but it \n",
    "serves to demonstrate asynchronous behavior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:14.796164Z",
     "iopub.status.busy": "2024-09-12T21:05:14.796015Z",
     "iopub.status.idle": "2024-09-12T21:05:15.851841Z",
     "shell.execute_reply": "2024-09-12T21:05:15.851546Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from bsk_rl import act, data, obs, sats, scene\n",
    "from bsk_rl.sim import dyn, fsw\n",
    "\n",
    "class ScanningDownlinkDynModel(dyn.ContinuousImagingDynModel, dyn.GroundStationDynModel):\n",
    "    # Define some custom properties to be accessed in the state\n",
    "    @property\n",
    "    def instrument_pointing_error(self) -> float:\n",
    "        r_BN_P_unit = self.r_BN_P/np.linalg.norm(self.r_BN_P) \n",
    "        c_hat_P = self.satellite.fsw.c_hat_P\n",
    "        return np.arccos(np.dot(-r_BN_P_unit, c_hat_P))\n",
    "    \n",
    "    @property\n",
    "    def solar_pointing_error(self) -> float:\n",
    "        a = self.world.gravFactory.spiceObject.planetStateOutMsgs[\n",
    "            self.world.sun_index\n",
    "        ].read().PositionVector\n",
    "        a_hat_N = a / np.linalg.norm(a)\n",
    "        nHat_B = self.satellite.sat_args[\"nHat_B\"]\n",
    "        NB = np.transpose(self.BN)\n",
    "        nHat_N = NB @ nHat_B\n",
    "        return np.arccos(np.dot(nHat_N, a_hat_N))\n",
    "\n",
    "class ScanningSatellite(sats.AccessSatellite):\n",
    "    observation_spec = [\n",
    "        obs.SatProperties(\n",
    "            dict(prop=\"storage_level_fraction\"),\n",
    "            dict(prop=\"battery_charge_fraction\"),\n",
    "            dict(prop=\"wheel_speeds_fraction\"),\n",
    "            dict(prop=\"instrument_pointing_error\", norm=np.pi),\n",
    "            dict(prop=\"solar_pointing_error\", norm=np.pi)\n",
    "        ),\n",
    "        obs.OpportunityProperties(\n",
    "            dict(prop=\"opportunity_open\", norm=5700),\n",
    "            dict(prop=\"opportunity_close\", norm=5700),\n",
    "            type=\"ground_station\",\n",
    "            n_ahead_observe=1,\n",
    "        ),\n",
    "        obs.Eclipse(norm=5700),\n",
    "    ]\n",
    "    action_spec = [\n",
    "        act.Scan(duration=150.0),\n",
    "        act.Charge(duration=120.0),\n",
    "        act.Downlink(duration=80.0),\n",
    "        act.Desat(duration=45.0),\n",
    "    ]\n",
    "    dyn_type = ScanningDownlinkDynModel\n",
    "    fsw_type = fsw.ContinuousImagingFSWModel\n",
    "\n",
    "sats = [ScanningSatellite(\n",
    "    f\"Scanner-{i+1}\",\n",
    "    sat_args=dict(\n",
    "        # Data\n",
    "        dataStorageCapacity=5000 * 8e6,  # bits\n",
    "        storageInit=lambda: np.random.uniform(0.0, 0.8) * 5000 * 8e6,\n",
    "        instrumentBaudRate=0.5 * 8e6,\n",
    "        transmitterBaudRate=-50 * 8e6,\n",
    "        # Power\n",
    "        batteryStorageCapacity=200 * 3600,  # W*s\n",
    "        storedCharge_Init=lambda: np.random.uniform(0.3, 1.0) * 200 * 3600,\n",
    "        basePowerDraw=-10.0,  # W\n",
    "        instrumentPowerDraw=-30.0,  # W\n",
    "        transmitterPowerDraw=-25.0,  # W\n",
    "        thrusterPowerDraw=-80.0,  # W\n",
    "        panelArea=0.25,\n",
    "        # Attitude\n",
    "        imageAttErrorRequirement=0.1,\n",
    "        imageRateErrorRequirement=0.1,\n",
    "        disturbance_vector=lambda: np.random.normal(scale=0.0001, size=3),  # N*m\n",
    "        maxWheelSpeed=6000.0,  # RPM\n",
    "        wheelSpeeds=lambda: np.random.uniform(-3000, 3000, 3),\n",
    "        desatAttitude=\"nadir\",\n",
    "    )\n",
    ") for i in range(4)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Correlated Environment Parameters\n",
    "To construct a constellation with some coordination, a function is generated to map satellites\n",
    "to orbital elements:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:15.854023Z",
     "iopub.status.busy": "2024-09-12T21:05:15.853862Z",
     "iopub.status.idle": "2024-09-12T21:05:15.855864Z",
     "shell.execute_reply": "2024-09-12T21:05:15.855611Z"
    }
   },
   "outputs": [],
   "source": [
    "from bsk_rl.utils.orbital import walker_delta_args\n",
    "\n",
    "sat_arg_randomizer = walker_delta_args(n_planes=2, altitude=500)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `sat_arg_randomizer` is included in the environment arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:15.857328Z",
     "iopub.status.busy": "2024-09-12T21:05:15.857238Z",
     "iopub.status.idle": "2024-09-12T21:05:15.859036Z",
     "shell.execute_reply": "2024-09-12T21:05:15.858795Z"
    }
   },
   "outputs": [],
   "source": [
    "duration = 5 * 5700.0  # About 5 orbits\n",
    "env_args = dict(\n",
    "    satellites=sats,\n",
    "    scenario=scene.UniformNadirScanning(value_per_second=1/duration),\n",
    "    rewarder=data.ScanningTimeReward(),\n",
    "    time_limit=duration,\n",
    "    failure_penalty=-1.0,\n",
    "    terminate_on_time_limit=True,\n",
    "    sat_arg_randomizer=sat_arg_randomizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RLlib Training Configuration\n",
    "\n",
    "A standard PPO configuration is generated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:15.860327Z",
     "iopub.status.busy": "2024-09-12T21:05:15.860239Z",
     "iopub.status.idle": "2024-09-12T21:05:23.100036Z",
     "shell.execute_reply": "2024-09-12T21:05:23.099712Z"
    }
   },
   "outputs": [],
   "source": [
    "import bsk_rl.utils.rllib  # noqa To access \"ConstellationTasking-RLlib\"\n",
    "from ray.rllib.algorithms.ppo import PPOConfig\n",
    "\n",
    "\n",
    "N_CPUS = 3\n",
    "\n",
    "training_args = dict(\n",
    "    lr=0.00003,\n",
    "    gamma=0.99997,\n",
    "    train_batch_size=200 * N_CPUS,\n",
    "    num_sgd_iter=10,\n",
    "    lambda_=0.95,\n",
    "    use_kl_loss=False,\n",
    "    clip_param=0.1,\n",
    "    grad_clip=0.5,\n",
    "    mini_batch_size_per_learner=100,\n",
    ")\n",
    "\n",
    "config = (\n",
    "    PPOConfig()\n",
    "    .environment(\n",
    "        \"ConstellationTasking-RLlib\",\n",
    "        env_config=env_args,\n",
    "    )\n",
    "    .env_runners(\n",
    "        num_env_runners=N_CPUS - 1,\n",
    "        sample_timeout_s=1000.0,\n",
    "    )\n",
    "    .reporting(\n",
    "        metrics_num_episodes_for_smoothing=1,\n",
    "        metrics_episode_collection_timeout_s=180,\n",
    "    )\n",
    "    .checkpointing(export_native_model_files=True)\n",
    "    .framework(framework=\"torch\")\n",
    "    .api_stack(\n",
    "        enable_rl_module_and_learner=True,\n",
    "        enable_env_runner_and_connector_v2=True,\n",
    "    )\n",
    "    .training(\n",
    "        **training_args,\n",
    "    )\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To set up multiple agents using the same policy, the following configurations are set to\n",
    "map all agents to the policy `p0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:23.102219Z",
     "iopub.status.busy": "2024-09-12T21:05:23.101948Z",
     "iopub.status.idle": "2024-09-12T21:05:23.105180Z",
     "shell.execute_reply": "2024-09-12T21:05:23.104900Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "try:\n",
    "    from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec\n",
    "    from ray.rllib.core.rl_module.rl_module import RLModuleSpec\n",
    "except (ImportError, ModuleNotFoundError):  # Older versions of RLlib\n",
    "    from ray.rllib.core.rl_module.marl_module import (\n",
    "        MultiAgentRLModuleSpec as MultiRLModuleSpec,\n",
    "    )\n",
    "    from ray.rllib.core.rl_module.rl_module import (\n",
    "        SingleAgentRLModuleSpec as RLModuleSpec,\n",
    "    )\n",
    "\n",
    "config.multi_agent(\n",
    "    policies={\"p0\"},\n",
    "    policy_mapping_fn=lambda *args, **kwargs: \"p0\",\n",
    ").rl_module(\n",
    "    model_config_dict={\n",
    "        \"use_lstm\": False,\n",
    "        # Hidden layer sizes of the fully connected network; no LSTM is used here.\n",
    "        \"fcnet_hiddens\": [2048, 2048],\n",
    "        \"vf_share_layers\": False,\n",
    "    },\n",
    "    rl_module_spec=MultiRLModuleSpec(\n",
    "        module_specs={\n",
    "            \"p0\": RLModuleSpec(),\n",
    "        }\n",
    "    ),\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuring Multiagent Logging Callbacks\n",
    "A callback function for the entire environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:23.106597Z",
     "iopub.status.busy": "2024-09-12T21:05:23.106506Z",
     "iopub.status.idle": "2024-09-12T21:05:23.108201Z",
     "shell.execute_reply": "2024-09-12T21:05:23.107947Z"
    }
   },
   "outputs": [],
   "source": [
    "def env_metrics_callback(env):\n",
    "    reward = env.rewarder.cum_reward\n",
    "    reward = sum(reward.values()) / len(reward)\n",
    "    return dict(reward=reward)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and per satellite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:23.109460Z",
     "iopub.status.busy": "2024-09-12T21:05:23.109361Z",
     "iopub.status.idle": "2024-09-12T21:05:23.111033Z",
     "shell.execute_reply": "2024-09-12T21:05:23.110807Z"
    }
   },
   "outputs": [],
   "source": [
    "def sat_metrics_callback(env, satellite):\n",
    "    data = dict(\n",
    "        # Are satellites dying, and how and when?\n",
    "        alive=float(satellite.is_alive()),\n",
    "        rw_status_valid=float(satellite.dynamics.rw_speeds_valid()),\n",
    "        battery_status_valid=float(satellite.dynamics.battery_valid()),\n",
    "    )\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "are defined. The `sat_metrics_callback` will be reported per-agent and as a mean. If\n",
    "using the predefined `\"ConstellationTasking-RLlib\"`, only the `WrappedEpisodeDataCallbacks`\n",
    "need to be added to the config, as in the single-agent case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:23.112280Z",
     "iopub.status.busy": "2024-09-12T21:05:23.112187Z",
     "iopub.status.idle": "2024-09-12T21:05:23.114109Z",
     "shell.execute_reply": "2024-09-12T21:05:23.113888Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from bsk_rl.utils.rllib.callbacks import WrappedEpisodeDataCallbacks\n",
    "\n",
    "config.callbacks(WrappedEpisodeDataCallbacks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Action Continuation and Concatenation\n",
    "Logic to prevent all agents from retasking whenever any agent finishes an action\n",
    "is introduced, in the form of connector modules. First, the `ContinuePreviousAction` connector\n",
    "overrides any policy-selected action with the `bsk_rl.NO_ACTION` whenever `requires_retasking==False`\n",
    "for an agent, causing the agent to continue its current action."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:23.115449Z",
     "iopub.status.busy": "2024-09-12T21:05:23.115373Z",
     "iopub.status.idle": "2024-09-12T21:05:23.122877Z",
     "shell.execute_reply": "2024-09-12T21:05:23.122657Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from bsk_rl.utils.rllib import discounting\n",
    "\n",
    "config.env_runners(\n",
    "    module_to_env_connector=lambda env: (discounting.ContinuePreviousAction(),)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, two other connectors compress `NO_ACTION` out of episodes of experience, combining\n",
    "steps into those with super-actions. The `d_ts` timestep flag is calculated accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:23.124137Z",
     "iopub.status.busy": "2024-09-12T21:05:23.124041Z",
     "iopub.status.idle": "2024-09-12T21:05:23.126169Z",
     "shell.execute_reply": "2024-09-12T21:05:23.125956Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config.training(\n",
    "    learner_connector=lambda obs_space, act_space: (\n",
    "        discounting.MakeAddedStepActionValid(expected_train_batch_size=config.train_batch_size),\n",
    "        discounting.CondenseMultiStepActions(),\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, the `TimeDiscountedGAEPPOTorchLearner` is used, as in :doc:`examples/time_discounted_gae`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:23.127391Z",
     "iopub.status.busy": "2024-09-12T21:05:23.127304Z",
     "iopub.status.idle": "2024-09-12T21:05:23.129225Z",
     "shell.execute_reply": "2024-09-12T21:05:23.129031Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ray.rllib.algorithms.ppo.ppo.PPOConfig at 0x11006dc30>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config.training(learner_class=discounting.TimeDiscountedGAEPPOTorchLearner)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that when using these connectors, only the `requires_retasking` flag will cause agents\n",
    "to select a new action. Step timeouts due to `max_step_duration` will not trigger retasking."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the Agent\n",
    "At this point, the PPO config can be trained as desired."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-12T21:05:23.130450Z",
     "iopub.status.busy": "2024-09-12T21:05:23.130378Z",
     "iopub.status.idle": "2024-09-12T21:05:57.157548Z",
     "shell.execute_reply": "2024-09-12T21:05:57.156788Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-09-12 15:05:23,877\tINFO worker.py:1783 -- Started a local Ray instance.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-09-12 15:05:24,190\tINFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/markstephenson/avslab/refactor/.venv_refactor/lib/python3.10/site-packages/gymnasium/spaces/box.py:130: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n",
      "  gym.logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n",
      "/Users/markstephenson/avslab/refactor/.venv_refactor/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:164: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method was expecting numpy array dtype to be float32, actual type: float64\u001b[0m\n",
      "  logger.warn(\n",
      "/Users/markstephenson/avslab/refactor/.venv_refactor/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:188: UserWarning: \u001b[33mWARN: The obs returned by the `reset()` method is not within the observation space.\u001b[0m\n",
      "  logger.warn(f\"{pre} is not within the observation space.\")\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div class=\"tuneStatus\">\n",
       "  <div style=\"display: flex;flex-direction: row\">\n",
       "    <div style=\"display: flex;flex-direction: column;\">\n",
       "      <h3>Tune Status</h3>\n",
       "      <table>\n",
       "<tbody>\n",
       "<tr><td>Current time:</td><td>2024-09-12 15:05:54</td></tr>\n",
       "<tr><td>Running for: </td><td>00:00:30.52        </td></tr>\n",
       "<tr><td>Memory:      </td><td>13.5/16.0 GiB      </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    </div>\n",
       "    <div class=\"vDivider\"></div>\n",
       "    <div class=\"systemInfo\">\n",
       "      <h3>System Info</h3>\n",
       "      Using FIFO scheduling algorithm.<br>Logical resource usage: 3.0/3 CPUs, 0/0 GPUs\n",
       "    </div>\n",
       "    \n",
       "  </div>\n",
       "  <div class=\"hDivider\"></div>\n",
       "  <div class=\"trialStatus\">\n",
       "    <h3>Trial Status</h3>\n",
       "    <table>\n",
       "<thead>\n",
       "<tr><th>Trial name                                </th><th>status    </th><th>loc            </th><th style=\"text-align: right;\">  iter</th><th style=\"text-align: right;\">  total time (s)</th><th style=\"text-align: right;\">     num_env_steps_sample\n",
       "d_lifetime</th><th style=\"text-align: right;\">  num_episodes_lifetim\n",
       "e</th><th style=\"text-align: right;\">     num_env_steps_traine\n",
       "d_lifetime</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>PPO_ConstellationTasking-RLlib_ba7e4_00000</td><td>TERMINATED</td><td>127.0.0.1:95948</td><td style=\"text-align: right;\">     2</td><td style=\"text-align: right;\">           16.33</td><td style=\"text-align: right;\">1200</td><td style=\"text-align: right;\">0</td><td style=\"text-align: right;\">1200</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "  </div>\n",
       "</div>\n",
       "<style>\n",
       ".tuneStatus {\n",
       "  color: var(--jp-ui-font-color1);\n",
       "}\n",
       ".tuneStatus .systemInfo {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus td {\n",
       "  white-space: nowrap;\n",
       "}\n",
       ".tuneStatus .trialStatus {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       ".tuneStatus h3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       ".tuneStatus .hDivider {\n",
       "  border-bottom-width: var(--jp-border-width);\n",
       "  border-bottom-color: var(--jp-border-color0);\n",
       "  border-bottom-style: solid;\n",
       "}\n",
       ".tuneStatus .vDivider {\n",
       "  border-left-width: var(--jp-border-width);\n",
       "  border-left-color: var(--jp-border-color0);\n",
       "  border-left-style: solid;\n",
       "  margin: 0.5em 1em 0.5em 1em;\n",
       "}\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(PPO pid=95948)\u001b[0m Install gputil for GPU system monitoring.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(MultiAgentEnvRunner pid=95950)\u001b[0m \u001b[90;3m2024-09-12 15:05:44,475 \u001b[0m\u001b[36msats.satellite.Scanner-4       \u001b[0m\u001b[93mWARNING    \u001b[0m\u001b[33m<6740.00> \u001b[0m\u001b[36mScanner-4: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div class=\"trialProgress\">\n",
       "  <h3>Trial Progress</h3>\n",
       "  <table>\n",
       "<thead>\n",
       "<tr><th>Trial name                                </th><th>env_runners                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    </th><th>fault_tolerance                                                                           </th><th>learners                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             </th><th>num_agent_steps_sampled_lifetime                                           </th><th style=\"text-align: right;\">  num_env_steps_sampled_lifetime</th><th style=\"text-align: right;\">  num_env_steps_trained_lifetime</th><th style=\"text-align: right;\">  num_episodes_lifetime</th><th>perf                                                              </th><th>timers                                                                                                                                                                           </th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>PPO_ConstellationTasking-RLlib_ba7e4_00000</td><td>{&#x27;num_env_steps_sampled&#x27;: 600, &#x27;num_agent_steps_sampled_lifetime&#x27;: {&#x27;Scanner-3&#x27;: 1800, &#x27;Scanner-2&#x27;: 1800, &#x27;Scanner-4&#x27;: 1406, &#x27;Scanner-1&#x27;: 1765}, &#x27;num_agent_steps_sampled&#x27;: {&#x27;Scanner-2&#x27;: 600, &#x27;Scanner-4&#x27;: 300, &#x27;Scanner-1&#x27;: 565, &#x27;Scanner-3&#x27;: 600}, &#x27;num_env_steps_sampled_lifetime&#x27;: 2400, &#x27;num_module_steps_sampled_lifetime&#x27;: {&#x27;p0&#x27;: 6771}, &#x27;num_episodes&#x27;: 0, &#x27;num_module_steps_sampled&#x27;: {&#x27;p0&#x27;: 2065}, &#x27;episode_return_mean&#x27;: nan, &#x27;episode_return_min&#x27;: nan, &#x27;episode_return_max&#x27;: nan}</td><td>{&#x27;num_healthy_workers&#x27;: 2, &#x27;num_in_flight_async_reqs&#x27;: 0, &#x27;num_remote_worker_restarts&#x27;: 0}</td><td>{&#x27;p0&#x27;: {&#x27;vf_loss&#x27;: 0.07396450638771057, &#x27;num_non_trainable_parameters&#x27;: 0.0, &#x27;default_optimizer_learning_rate&#x27;: 3e-05, &#x27;policy_loss&#x27;: 1.1960660219192505, &#x27;num_trainable_parameters&#x27;: 8452101.0, &#x27;mean_kl_loss&#x27;: 0.0, &#x27;total_loss&#x27;: 1.2700304985046387, &#x27;vf_explained_var&#x27;: 0.2418355941772461, &#x27;vf_loss_unclipped&#x27;: 0.07396450638771057, &#x27;curr_entropy_coeff&#x27;: 0.0, &#x27;gradients_default_optimizer_global_norm&#x27;: 3.470919609069824, &#x27;num_module_steps_trained&#x27;: 631, &#x27;entropy&#x27;: 1.3778905868530273}, &#x27;__all_modules__&#x27;: {&#x27;num_trainable_parameters&#x27;: 8452101.0, &#x27;total_loss&#x27;: 1.2700304985046387, &#x27;num_non_trainable_parameters&#x27;: 0.0, &#x27;num_module_steps_trained&#x27;: 631, &#x27;num_env_steps_trained&#x27;: 600}}</td><td>{&#x27;Scanner-1&#x27;: 1165, &#x27;Scanner-2&#x27;: 1200, &#x27;Scanner-3&#x27;: 1200, &#x27;Scanner-4&#x27;: 853}</td><td style=\"text-align: right;\">                            1200</td><td style=\"text-align: right;\">       
                     1200</td><td style=\"text-align: right;\">                      0</td><td>{&#x27;cpu_util_percent&#x27;: 16.16, &#x27;ram_util_percent&#x27;: 84.24999999999997}</td><td>{&#x27;env_runner_sampling_timer&#x27;: 6.925581849465962, &#x27;learner_update_timer&#x27;: 2.1881487244431628, &#x27;synch_weights&#x27;: 0.009959223236655818, &#x27;synch_env_connectors&#x27;: 0.006511321889702231}</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "</div>\n",
       "<style>\n",
       ".trialProgress {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  color: var(--jp-ui-font-color1);\n",
       "}\n",
       ".trialProgress h3 {\n",
       "  font-weight: bold;\n",
       "}\n",
       ".trialProgress td {\n",
       "  white-space: nowrap;\n",
       "}\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[36m(MultiAgentEnvRunner pid=95949)\u001b[0m \u001b[90;3m2024-09-12 15:05:52,013 \u001b[0m\u001b[36msats.satellite.Scanner-1       \u001b[0m\u001b[93mWARNING    \u001b[0m\u001b[33m<15360.00> \u001b[0m\u001b[36mScanner-1: \u001b[0m\u001b[93mfailed battery_valid check\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-09-12 15:05:54,737\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/markstephenson/ray_results/PPO_2024-09-12_15-05-24' in 0.0020s.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-09-12 15:05:55,777\tINFO tune.py:1041 -- Total run time: 31.59 seconds (30.51 seconds for the tuning loop).\n"
     ]
    }
   ],
   "source": [
    "import ray\n",
    "from ray import tune\n",
    "\n",
    "ray.init(\n",
    "    ignore_reinit_error=True,\n",
    "    num_cpus=N_CPUS,\n",
    "    object_store_memory=2_000_000_000,  # 2 GB\n",
    ")\n",
    "\n",
    "# Run the training\n",
    "tune.run(\n",
    "    \"PPO\",\n",
    "    config=config.to_dict(),\n",
    "    stop={\"training_iteration\": 2},  # Adjust the number of iterations as needed\n",
    ")\n",
    "\n",
    "# Shutdown Ray\n",
    "ray.shutdown()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv_refactor",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}