# ISC License
#
# Copyright (c) 2023, Autonomous Vehicle Systems Lab, University of Colorado at Boulder
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# Import some architectural stuff that we will probably always use
import html
import math
import os
import shutil
import subprocess
import tempfile
import webbrowser
import warnings
import xml.etree.ElementTree as ET
from collections import OrderedDict
from typing import Literal
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import FancyArrowPatch
from Basilisk.architecture import alg_contain, bskLogging, sim_model
from Basilisk.utilities import deprecated, simulationArchTypes
from Basilisk.utilities.pythonVariableLogger import PythonVariableLogger
from Basilisk.utilities.simulationProgessBar import SimulationProgressBar
# Point the path to the module storage area
# ANSI escape sequences (SGR codes) used by ShowExecutionOrder to color terminal
# output: 32 = green, 33 = yellow, 36 = cyan, 0 = reset.
processColor = "\x1b[32m"
taskColor = "\x1b[33m"
moduleColor = "\x1b[36m"
endColor = "\x1b[0m"
# Hex colors used when emitting the Graphviz message-connection figure.
connectedMessageColor = "#F58518"
inactiveMessageColor = "#8A8F98"
connectionColor = "#5F6B73"
[docs]
def methodizeCondition(conditionList):
    """Methodize a condition list to a function.

    Each entry of ``conditionList`` is a string holding a Python expression that
    is evaluated with the simulation object bound to the name ``self``. The
    generated function returns ``True`` only when every expression is truthy.

    Args:
        conditionList (list[str] or None): Condition expressions. ``None`` or an
            empty list yields a function that always returns ``False``.

    Returns:
        function: A callable ``f(self)`` returning ``True`` or ``False``.
    """
    if conditionList is None or len(conditionList) == 0:
        return lambda _: False

    # Parenthesize every entry so that a condition such as "a or b" is treated as
    # one requirement instead of binding with the neighbouring "and".
    combinedCondition = " and ".join("(" + condValue + ")" for condValue in conditionList)

    # The generated body must be indented consistently: the "return True" line is
    # nested one level deeper than the "if" line.
    funcString = (
        "def EVENT_check_condition(self):\n"
        "    if " + combinedCondition + ":\n"
        "        return True\n"
        "    return False\n"
    )
    local_namespace = {}
    # Module globals are made available so conditions can reference them.
    exec(funcString, globals(), local_namespace)
    return local_namespace["EVENT_check_condition"]
[docs]
def methodizeAction(actionList):
    """Compile a list of action strings into a single callable.

    Every string is placed, in order, as one statement of a generated function
    whose only parameter is ``self``. Missing or empty lists produce a no-op.

    Args:
        actionList (list[str] or None): Statements to execute.

    Returns:
        function: A callable ``f(self)`` that runs the statements and returns ``None``.
    """
    if actionList is None or len(actionList) == 0:
        return lambda _: None
    sourceLines = ["def EVENT_operate_action(self):"]
    sourceLines.extend(" " + statement for statement in actionList)
    sourceLines.append(" return None")
    generatedScope = {}
    exec("\n".join(sourceLines), globals(), generatedScope)
    return generatedScope["EVENT_operate_action"]
[docs]
class EventHandlerClass:
    """
    Class for defining event checking behavior, conditions, and actions.

    Three checking strategies are supported:

    1. **Exact Interval Checking**: (default) The event is checked only when the
       current time is an exact multiple of the ``eventRate``. This behavior is similar
       to how tasks are scheduled in Basilisk. Note that if no task leads to a timestep
       at a checking time, the event will not be checked.
    2. **Elapsed Interval Checking**: The event is checked whenever the ``eventRate``
       has elapsed since the last check. This is enabled by setting ``exactRateMatch``
       to ``False``. This behavior is similar to how Basilisk loggers operate.
    3. **Condition Time Checking**: An alternative to interval-based checking when
       an event should occur at a specific time. This is enabled by setting
       ``conditionTime``, and will lead to the event being triggered at the first
       timestep at or after the specified time.

    When an event is checked, the ``conditionFunction`` is called to determine if the
    event should occur. If the condition returns ``True``, the event is deactivated and
    the ``actionFunction`` is executed. To continue checking the event, it must be
    reactivated (the action itself may do so, since deactivation happens first).
    If the event is marked as ``terminal``, the simulation will be instructed to
    terminate when the event condition occurs.

    Args:
        eventName (str): Name of the event
        eventRate (int): Rate at which the event is checked in nanoseconds
        eventActive (bool): Whether the event is active or not
        terminal (bool): Whether this event should terminate the simulation when it occurs
        conditionFunction (function): Function to check if the event should occur. The
            function should take the simulation object as an argument and return a boolean.
            This is the preferred manner to set conditions as it enables the use of arbitrary
            packages and objects in events and allows for event code to be parsed by IDE tools.
        conditionTime (int): Alternative to conditionFunction, a time in nanoseconds to trigger
            the event. Does not depend on eventRate for checking.
        actionFunction (function): Function to execute when the event occurs. The
            function should take the simulation object as an argument.
            This is the preferred manner to set actions as it enables the use of arbitrary
            packages and objects in events and allows for event code to be parsed by IDE tools.
        exactRateMatch (bool): If True, the event is checked only when the current time is an
            exact multiple of the eventRate. If False, the event is checked whenever the
            ``eventRate`` has elapsed since the last check.
        conditionList (list): (deprecated) List of conditions to check for the event,
            expressed as strings of code to execute within the class.
        actionList (list): (deprecated) List of actions to perform when the event occurs,
            expressed as strings of code to execute within the class.

    Raises:
        ValueError: If more than one of ``conditionFunction``, ``conditionTime`` and
            ``conditionList`` is given, or if both ``actionFunction`` and ``actionList``
            are given.
    """

    def __init__(
        self,
        eventName,
        eventRate=int(1e9),
        eventActive=False,
        conditionList=None,
        actionList=None,
        conditionFunction=None,
        actionFunction=None,
        conditionTime=None,
        terminal=False,
        exactRateMatch=True,
    ):
        self.eventName = eventName
        self.eventActive = eventActive
        self.eventRate = eventRate
        self.occurCounter = 0
        self.prevCheckTime = None
        self.terminal = terminal
        self.exactRateMatch = exactRateMatch
        self.conditionTime = conditionTime

        if conditionTime is not None:
            if conditionFunction is not None:
                raise ValueError(
                    "Only specify a conditionFunction or a conditionTime, not both"
                )
            if conditionList is not None:
                raise ValueError(
                    "Only specify a conditionList or a conditionTime, not both"
                )
            # Reads self.conditionTime at call time, so later changes are honored.
            conditionFunction = (
                lambda sim: sim.TotalSim.CurrentNanos >= self.conditionTime
            )

        if conditionList is not None:
            if conditionFunction is not None:
                raise ValueError(
                    "Only specify a conditionFunction or a conditionList, not both"
                )
            conditionFunction = methodizeCondition(conditionList)

        if actionList is not None:
            if actionFunction is not None:
                raise ValueError(
                    "Only specify an actionFunction or an actionList, not both."
                )
            actionFunction = methodizeAction(actionList)

        self.conditionFunction = conditionFunction or (lambda _: False)
        self.actionFunction = actionFunction or (lambda _: None)

    def shouldBeChecked(self, currentTime):
        """See if the event should be checked at the current time.

        Args:
            currentTime (int): Current simulation time in nanoseconds.

        Returns:
            bool: ``False`` for inactive events; otherwise the result of the
            configured checking strategy.
        """
        if not self.eventActive:
            return False
        if self.conditionTime is not None:
            return currentTime >= self.conditionTime
        if self.exactRateMatch:
            return currentTime % self.eventRate == 0
        return (
            self.prevCheckTime is None
            or currentTime >= self.prevCheckTime + self.eventRate
        )

    def checkEvent(self, parentSim):
        """Check the condition and execute the action if the condition is met.

        The event is deactivated before the action runs. ``prevCheckTime`` is
        updated on every call, whether or not the condition held.

        Args:
            parentSim: Simulation object passed to the condition and action; its
                ``terminate`` flag is set for terminal events.
        """
        if self.conditionFunction(parentSim):
            self.eventActive = False
            self.actionFunction(parentSim)
            self.occurCounter += 1
            if self.terminal:
                parentSim.terminate = True
        self.prevCheckTime = parentSim.TotalSim.CurrentNanos

    def nextCheckTime(self, currentTime):
        """Get the earliest upcoming time this event should be checked.

        Args:
            currentTime (int): Current simulation time in nanoseconds.

        Returns:
            int: ``conditionTime`` if set; otherwise the next multiple of
            ``eventRate`` strictly after ``currentTime`` (exact matching), or
            ``prevCheckTime + eventRate`` (``currentTime`` if never checked).
        """
        if self.conditionTime is not None:
            return self.conditionTime
        if self.exactRateMatch:
            return (currentTime // self.eventRate + 1) * self.eventRate
        if self.prevCheckTime is None:
            return currentTime
        return self.prevCheckTime + self.eventRate
[docs]
class StructDocData:
    """Structure data documentation class.

    Holds the variable members of a single structure, read from the XML file
    ``<xmlSearchPath>/<strName>.xml`` whose ``compounddef`` element has the id
    ``strName``.
    """

    class StructElementDef:
        """Type, name, array suffix (``argsstring``) and description of one member."""

        def __init__(self, type, name, argstring, desc=""):
            self.type = type
            self.name = name
            self.argstring = argstring
            self.desc = desc

    def __init__(self, strName):
        self.strName = strName
        self.structPopulated = False
        self.structElements = {}

    def clearItem(self):
        """Forget any parsed members so the structure can be populated again."""
        self.structPopulated = False
        self.structElements = {}

    @staticmethod
    def _childText(element, path):
        """Return the text of the first child matching ``path``, or ``None``."""
        child = element.find(path)
        return child.text if child is not None else None

    def populateElem(self, xmlSearchPath):
        """Parse the structure's XML file and record its variable members.

        Does nothing if the structure was already populated. When the file cannot
        be read or parsed, or holds no matching ``compounddef``, a message is
        printed and the structure stays unpopulated.

        Args:
            xmlSearchPath (str): Directory containing ``<strName>.xml``.
        """
        if self.structPopulated:
            return
        xmlFileUse = xmlSearchPath + "/" + self.strName + ".xml"
        try:
            xmlData = ET.parse(xmlFileUse)
        except (OSError, ET.ParseError):
            print("Failed to parse the XML structure for: " + self.strName)
            print("This file does not exist most likely: " + xmlFileUse)
            return
        root = xmlData.getroot()
        validElement = root.find("./compounddef[@id='" + self.strName + "']")
        if validElement is None:
            print("No structure definition found for: " + self.strName + " in " + xmlFileUse)
            return
        for newVariable in validElement.findall(".//memberdef[@kind='variable']"):
            typeUse = self._childText(newVariable, "type")
            nameUse = self._childText(newVariable, "name")
            argstringUse = self._childText(newVariable, "argsstring")
            # Prefer the detailed description, fall back to the brief one.
            descUse = self._childText(newVariable, "./detaileddescription/para")
            if descUse is None:
                descUse = self._childText(newVariable, "./briefdescription/para")
            self.structElements[nameUse] = StructDocData.StructElementDef(
                typeUse, nameUse, argstringUse, descUse
            )
        self.structPopulated = True

    def printElem(self):
        """Print one line per member: ``type name[argsstring]: description``."""
        print(" " + self.strName + " Structure Elements:")
        for value in self.structElements.values():
            # Missing type/name are printed as empty text instead of raising.
            outputString = (value.type or "") + " " + (value.name or "")
            if value.argstring is not None:
                outputString += value.argstring
            if value.desc is not None:
                outputString += ": " + value.desc
            print(" " + outputString)
class DataPairClass:
    """Simple record holding a name, input/output message sets and an output dictionary."""

    def __init__(self):
        self.name = ""
        self.inputMessages = set()
        self.outputMessages = set()
        self.outputDict = {}
def _hasCallable(obj, name):
    """Tell whether ``obj`` exposes a callable attribute called ``name``."""
    candidateAttribute = getattr(obj, name, None)
    return callable(candidateAttribute)
def _hasBskMessageShape(obj):
    """Tell whether ``obj`` looks like a Basilisk message wrapper.

    A message wrapper has a ``this`` attribute and at least one callable among
    the subscription/recording methods listed below.
    """
    if hasattr(obj, "this"):
        for methodName in (
            "subscribeTo",
            "isSubscribedTo",
            "addSubscriber",
            "getMsgPointers",
            "recorder",
        ):
            if callable(getattr(obj, methodName, None)):
                return True
    return False
def _isSimpleValue(obj):
    """Tell whether ``obj`` is ``None`` or a scalar that need not be searched."""
    if obj is None:
        return True
    return isinstance(obj, (str, bytes, int, float, bool, complex))
def _shouldInspectMessageAttribute(obj, name):
    """Decide whether attribute ``name`` of ``obj`` may hold a message endpoint.

    Private names and SWIG bookkeeping attributes are never inspected. On
    objects with a ``this`` attribute only names containing ``"Msg"`` and a few
    known message containers are inspected; on other objects every remaining
    name is.
    """
    isExcluded = name.startswith("_") or name in ("this", "thisown", "bskLogger")
    if isExcluded:
        return False
    if hasattr(obj, "this"):
        return "Msg" in name or name in {"gravBodies", "gravField"}
    return True
def _iterCollection(obj):
    """Yield ``("[key]", item)`` pairs for the items held by a container.

    Dictionaries yield their keys; lists, tuples, sets and any other iterable
    yield positional indices. Scalars and message wrappers yield nothing, and
    objects whose ``iter()`` raises are treated as empty.
    """
    if isinstance(obj, dict):
        yield from ((f"[{key}]", item) for key, item in obj.items())
        return
    if isinstance(obj, (list, tuple, set)):
        yield from ((f"[{position}]", item) for position, item in enumerate(obj))
        return
    if _isSimpleValue(obj) or _hasBskMessageShape(obj):
        return
    try:
        itemIterator = iter(obj)
    except Exception:
        return
    for position, item in enumerate(itemIterator):
        yield f"[{position}]", item
def _messageIdentity(message):
    """Build a hashable identity key ``(module, typeName, address)`` for a message.

    For wrappers with a ``this`` attribute the address is ``int(message.this)``
    (or its string form if that conversion fails); otherwise it is
    ``str(id(message))``.
    """
    messageType = type(message)
    if not hasattr(message, "this"):
        return messageType.__module__, messageType.__name__, str(id(message))
    try:
        address = int(message.this)
    except Exception:
        address = str(message.this)
    return messageType.__module__, messageType.__name__, address
def _messagePayloadType(message):
    """Guess the payload type name of a Basilisk message object.

    The last component of the type's module is used when it ends in
    ``"Payload"``. Otherwise the type name is used, with the first matching
    suffix among ``Reader``, ``_C`` and ``Msg`` removed and ``"Payload"``
    appended when not already present.
    """
    messageType = type(message)
    moduleTail = messageType.__module__.rsplit(".", 1)[-1]
    if moduleTail.endswith("Payload"):
        return moduleTail
    baseName = messageType.__name__
    matchedSuffix = next(
        (suffix for suffix in ("Reader", "_C", "Msg") if baseName.endswith(suffix)),
        None,
    )
    if matchedSuffix is not None:
        baseName = baseName[: -len(matchedSuffix)]
    return baseName if baseName.endswith("Payload") else baseName + "Payload"
def _messageIsLinked(message):
    """Report whether a message endpoint says it is linked.

    ``message.isLinked()`` is consulted first; if it is missing or raises, the
    ``header.isLinked`` attribute is tried. Anything else yields ``False``.
    """
    if callable(getattr(message, "isLinked", None)):
        try:
            return bool(message.isLinked())
        except Exception:
            pass
    header = getattr(message, "header", None)
    if header is None or not hasattr(header, "isLinked"):
        return False
    try:
        return bool(header.isLinked)
    except Exception:
        return False
def _messageDirection(name, message):
    """Classify a message endpoint as ``"input"``, ``"output"`` or ``None``.

    The last dotted component of ``name`` is checked for ``InMsg`` and then
    ``OutMsg``. Failing that, ``Reader`` types are inputs, and objects with
    a callable ``addSubscriber`` or ``getMsgPointers`` are outputs.
    """
    lastComponent = name.rsplit(".", 1)[-1]
    for marker, direction in (("InMsg", "input"), ("OutMsg", "output")):
        if marker in lastComponent:
            return direction
    if type(message).__name__.endswith("Reader"):
        return "input"
    if any(
        callable(getattr(message, methodName, None))
        for methodName in ("addSubscriber", "getMsgPointers")
    ):
        return "output"
    return None
def _safeGetAttribute(obj, name):
    """Fetch ``obj.name`` quietly, returning ``None`` if the access raises.

    Basilisk deprecation warnings triggered by the access are suppressed.
    """
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", deprecated.BSKDeprecationWarning)
            attributeValue = getattr(obj, name)
    except Exception:
        return None
    return attributeValue
def _walkMessageLeaves(obj, name, depthRemaining, visited, keepAlive=None):
    """Yield ``(path, message)`` pairs found below ``obj``.

    Message wrappers are yielded immediately. Other objects are searched through
    their collection items and then through their inspectable attributes, down
    to ``depthRemaining`` levels.

    Args:
        obj: Object to search.
        name (str): Path of ``obj``; collection suffixes and ``.attr`` parts are
            appended to it to build leaf paths.
        depthRemaining (int): Number of further levels that may be descended.
        visited (set): ``id()`` values of objects already walked.
        keepAlive (list, optional): References to every walked object. An
            ``id()`` is only unique among objects alive at the same time, and
            attribute access may return temporary objects that are freed as the
            walk moves on. Without holding them, a later unrelated object could
            reuse a freed id and be skipped as already visited.
    """
    if keepAlive is None:
        keepAlive = []
    if _isSimpleValue(obj):
        return
    if _hasBskMessageShape(obj):
        yield name, obj
        return
    objIdentity = id(obj)
    if objIdentity in visited:
        return
    visited.add(objIdentity)
    # Pin the object so its id stays reserved for the rest of this walk.
    keepAlive.append(obj)
    if depthRemaining <= 0:
        return
    for collectionName, value in _iterCollection(obj):
        yield from _walkMessageLeaves(
            value,
            name + collectionName,
            depthRemaining - 1,
            visited,
            keepAlive,
        )
    if not hasattr(obj, "__dict__") and not hasattr(obj, "this"):
        return
    for attrName in dir(obj):
        if not _shouldInspectMessageAttribute(obj, attrName):
            continue
        value = _safeGetAttribute(obj, attrName)
        if value is None:
            continue
        # Skip methods, but keep callable objects that are themselves messages.
        if callable(value) and not _hasBskMessageShape(value):
            continue
        yield from _walkMessageLeaves(
            value,
            attrName if name == "" else name + "." + attrName,
            depthRemaining - 1,
            visited,
            keepAlive,
        )


def _iterMessageLeaves(obj, name="", maxDepth=4):
    """Yield ``(path, message)`` pairs found below ``obj``, at most ``maxDepth`` levels deep."""
    yield from _walkMessageLeaves(obj, name, maxDepth, set(), [])
def _isSubscribedTo(inputMessage, sourceMessage):
    """Ask ``inputMessage`` whether it subscribes to ``sourceMessage``.

    Returns ``False`` when the input has no ``isSubscribedTo`` method or the
    call raises.
    """
    if callable(getattr(inputMessage, "isSubscribedTo", None)):
        try:
            return bool(inputMessage.isSubscribedTo(sourceMessage))
        except Exception:
            pass
    return False
def _shortText(text, maxLength=26):
    """Truncate ``text`` to ``maxLength`` characters, ending with ``...`` when cut."""
    return text if len(text) <= maxLength else text[: maxLength - 3] + "..."
def _graphvizId(text):
    """Map every non-alphanumeric character of ``text`` to ``_`` for DOT ids."""
    safeCharacters = []
    for character in text:
        safeCharacters.append(character if character.isalnum() else "_")
    return "".join(safeCharacters)
def _graphvizEscape(text):
    """HTML-escape ``text`` for a Graphviz label, using a single space for blank text."""
    escaped = html.escape(str(text), quote=True)
    return escaped if escaped.strip() else " "
def _attachRecorderSource(recorder, sourceMessage):
    """Best-effort tag ``recorder`` with the message it records, then return it.

    Objects that refuse the attribute assignment are returned untouched.
    """
    try:
        setattr(recorder, "_bskRecordedMessage", sourceMessage)
    except Exception:
        pass
    return recorder
def _ensureRecorderSourceHooks():
    """Patch ``recorder()`` on ``messaging`` classes so recorders remember their source.

    Each patched class is marked with ``_bskRecorderSourceHooked`` so repeated
    calls leave it alone. Nothing happens when
    ``Basilisk.architecture.messaging`` cannot be imported.
    """
    try:
        from Basilisk.architecture import messaging
    except Exception:
        return

    for attrName in dir(messaging):
        messageClass = getattr(messaging, attrName)
        recorderMethod = getattr(messageClass, "recorder", None)
        if not callable(recorderMethod) or getattr(
            messageClass, "_bskRecorderSourceHooked", False
        ):
            continue

        # The default argument freezes this iteration's recorder method; a plain
        # closure would see only the last value of the loop variable.
        def recorderWithSource(self, *args, _recorderMethod=recorderMethod, **kwargs):
            return _attachRecorderSource(_recorderMethod(self, *args, **kwargs), self)

        try:
            messageClass.recorder = recorderWithSource
            messageClass._bskRecorderSourceHooked = True
        except Exception:
            # Leave objects that reject attribute assignment unpatched.
            pass
# Install the recorder() source hooks at import time so recorders created after
# this module is imported remember which message they record. The call is also
# repeated in SimBaseClass.__init__; already-hooked classes are skipped.
_ensureRecorderSourceHooks()
[docs]
class SimBaseClass:
"""Simulation Base Class"""
def __init__(self):
    """Create an empty simulation container with default bookkeeping state."""
    # Install recorder source hooks; classes that are already hooked are skipped.
    _ensureRecorderSourceHooks()
    self.TotalSim = sim_model.SimModel()

    # Python-side registries of tasks, processes and modules
    self.TaskList = []
    self.procList = []
    self.allModules = set()

    # Run-control and event state
    self.StopTime = 0
    self.nextEventTime = 0
    self.terminate = False
    self.eventMap = {}

    # Locations used to find the XML structure index
    self.simBasePath = os.path.dirname(os.path.realpath(__file__)) + "/../"
    self.dataStructIndex = self.simBasePath + "/xml/index.xml"
    self.indexParsed = False

    # Lifecycle flags, logging and terminal output options
    self.simulationInitialized = False
    self.simulationFinished = False
    self.bskLogger = bskLogging.BSKLogger()
    self.showProgressBar = False
[docs]
def SetProgressBar(self, value):
    """
    Enable or disable the dynamic terminal progress display shown while the
    simulation is executing.

    Args:
        value (bool): Stored as ``showProgressBar``.
    """
    self.showProgressBar = value
[docs]
def ShowExecutionOrder(self):
    """
    Print, with terminal colors, each process followed by its tasks and their
    modules in the order held by ``self.TotalSim``, with a blank line after
    every process.
    """
    for processData in self.TotalSim.processList:
        processLine = (
            f"{processColor}Process Name: {endColor}{processData.processName}, "
            f"{processColor}priority: {endColor}{processData.processPriority}"
        )
        print(processLine)
        for task in processData.processTasks:
            taskLine = (
                f"{taskColor}Task Name: {endColor}{task.TaskPtr.TaskName}, "
                f"{taskColor}priority: {endColor}{task.taskPriority}, "
                f"{taskColor}TaskPeriod: {endColor}{task.TaskPtr.TaskPeriod / 1.0e9}s"
            )
            print(taskLine)
            for module in task.TaskPtr.TaskModels:
                moduleLine = (
                    f"{moduleColor}ModuleTag: {endColor}{module.ModelPtr.ModelTag}, "
                    f"{moduleColor}priority: {endColor}{module.CurrentModelPriority}"
                )
                print(moduleLine)
        print()
def _normalizeExtraMessages(self, extraMessages):
    """Flatten stand-alone messages into an ordered ``{str(name): message}`` mapping.

    Args:
        extraMessages (dict, iterable of dict, or None): Named messages.

    Returns:
        OrderedDict: Names converted with ``str``. A repeated name keeps its
        first position and takes the last value.

    Raises:
        TypeError: If a collection that is not a dictionary is supplied.
    """
    normalizedMessages = OrderedDict()
    if extraMessages is None:
        return normalizedMessages
    messageGroups = (extraMessages,) if isinstance(extraMessages, dict) else extraMessages
    for messageGroup in messageGroups:
        if not isinstance(messageGroup, dict):
            raise TypeError("extraMessages must be a dictionary or a list of dictionaries.")
        for label, messageObject in messageGroup.items():
            normalizedMessages[str(label)] = messageObject
    return normalizedMessages
def _getOrderedTaskModels(self, task):
    """Return ``task.TaskModels`` sorted by descending priority, ties in list order.

    An empty list is returned when ``TaskModelPriorities`` is missing or its
    length differs from ``TaskModels``.
    """
    priorities = getattr(task, "TaskModelPriorities", None)
    if priorities is None or len(priorities) != len(task.TaskModels):
        return []
    modelPriorityPairs = list(zip(task.TaskModels, priorities))
    executionOrder = sorted(
        range(len(modelPriorityPairs)),
        key=lambda position: (-modelPriorityPairs[position][1], position),
    )
    return [modelPriorityPairs[position][0] for position in executionOrder]
def _getTaskModelCandidates(self, taskName, modelTag, modelPointer, taskModelIndex):
    """List objects that may own message attributes for one scheduled model.

    The result always starts with ``modelPointer``. For the Python task named
    ``taskName``, the model at the same execution position is added when that
    position exists; otherwise every model with a matching ``ModelTag`` is added.
    """
    candidates = [modelPointer]
    for task in self.TaskList:
        if task.Name != taskName:
            continue
        orderedModels = self._getOrderedTaskModels(task)
        if taskModelIndex < len(orderedModels):
            positionalModel = orderedModels[taskModelIndex]
            if positionalModel is not modelPointer:
                candidates.append(positionalModel)
            return candidates
        candidates.extend(
            taskModel
            for taskModel in task.TaskModels
            if taskModel is not modelPointer
            and getattr(taskModel, "ModelTag", "") == modelTag
        )
    return candidates
def _candidateListHasRecorder(self, candidates):
    """Tell whether any candidate is a message recorder.

    A recorder either carries ``_bskRecordedMessage`` or has a string
    ``ModelTag`` starting with ``"Rec:"``.
    """
    def looksLikeRecorder(candidate):
        if getattr(candidate, "_bskRecordedMessage", None) is not None:
            return True
        tag = getattr(candidate, "ModelTag", "")
        return isinstance(tag, str) and tag.startswith("Rec:")

    return any(looksLikeRecorder(candidate) for candidate in candidates)
def _collectModuleMessageEndpoints(self, moduleId, moduleRecord, candidates):
    """Collect input and output message endpoints from module candidates.

    Endpoints are appended in place to ``moduleRecord["inputs"]`` and
    ``moduleRecord["outputs"]``; each id is ``"<moduleId>:<direction>:<n>"``
    with ``n`` the endpoint's position in its list.

    Args:
        moduleId (str): Graph id of the module owning the endpoints.
        moduleRecord (dict): Module record with ``"tag"``, ``"inputs"`` and
            ``"outputs"`` entries.
        candidates (list): Objects to search for message attributes.
    """
    # Several candidates may wrap the same scheduled model, so every endpoint,
    # including a recorder's recorded-message input, is de-duplicated by key.
    seenEndpoints = set()
    for candidate in candidates:
        recordedMessage = getattr(candidate, "_bskRecordedMessage", None)
        if recordedMessage is not None:
            recorderKey = ("recordedMsg", _messageIdentity(recordedMessage))
            if recorderKey not in seenEndpoints:
                seenEndpoints.add(recorderKey)
                moduleRecord["inputs"].append(
                    {
                        "id": f"{moduleId}:input:{len(moduleRecord['inputs'])}",
                        "ownerId": moduleId,
                        "ownerType": "module",
                        "moduleTag": moduleRecord["tag"],
                        "name": "recordedMsg",
                        "direction": "input",
                        "payloadType": _messagePayloadType(recordedMessage),
                        "message": candidate,
                        "recordedMessage": recordedMessage,
                        "isLinked": True,
                    }
                )
        for messageName, message in _iterMessageLeaves(candidate):
            direction = _messageDirection(messageName, message)
            if direction is None:
                continue
            endpointKey = (direction, _messageIdentity(message), messageName)
            if endpointKey in seenEndpoints:
                continue
            seenEndpoints.add(endpointKey)
            endpointId = f"{moduleId}:{direction}:{len(moduleRecord[direction + 's'])}"
            moduleRecord[direction + "s"].append(
                {
                    "id": endpointId,
                    "ownerId": moduleId,
                    "ownerType": "module",
                    "moduleTag": moduleRecord["tag"],
                    "name": messageName,
                    "direction": direction,
                    "payloadType": _messagePayloadType(message),
                    "message": message,
                    "isLinked": _messageIsLinked(message),
                }
            )
[docs]
def GetMessageConnectionGraph(
    self,
    extraMessages=None,
    includeUnlinked=True,
    includeRecorders=True,
):
    """
    Extract Basilisk message connections from the configured simulation.

    Modules are visited in the order given by ``self.TotalSim.processList``
    (process, then task, then module). For every module the objects returned by
    ``_getTaskModelCandidates`` are searched for input and output message
    endpoints. Each regular input endpoint is then connected to every output
    endpoint (module output or stand-alone message) it reports being subscribed
    to via ``isSubscribedTo``. A recorder input is instead connected to the
    endpoint whose message identity matches the recorded message.

    Args:
        extraMessages (dict or list[dict], optional): Stand-alone messages to
            include as possible source messages. The dictionary keys are used
            as labels in the returned graph and any generated figure.
        includeUnlinked (bool): If ``True``, include unlinked input message
            endpoints in the returned graph.
        includeRecorders (bool): If ``True``, include message recorder modules
            in the returned graph.

    Returns:
        dict: A graph dictionary with the keys ``"modules"``,
        ``"standaloneMessages"``, ``"inputs"``, ``"outputs"``, ``"edges"``,
        ``"unlinkedInputs"`` (unconnected inputs that report no link) and
        ``"unresolvedInputs"`` (linked inputs whose source was not found among
        the collected outputs).
    """
    graph = {
        "modules": [],
        "standaloneMessages": [],
        "inputs": [],
        "outputs": [],
        "edges": [],
        "unlinkedInputs": [],
        "unresolvedInputs": [],
    }
    # Index of the next module added to the graph; modules skipped below do not
    # consume an index, so module ids stay contiguous.
    moduleIndex = 0
    for processData in self.TotalSim.processList:
        for task in processData.processTasks:
            taskPeriod = task.TaskPtr.TaskPeriod / 1.0e9  # [s]
            for taskModelIndex, module in enumerate(task.TaskPtr.TaskModels):
                modelPointer = module.ModelPtr
                modelTag = modelPointer.ModelTag
                moduleId = f"module:{moduleIndex}"
                moduleRecord = {
                    "id": moduleId,
                    "tag": modelTag,
                    "processName": processData.processName,
                    "processPriority": processData.processPriority,
                    "taskName": task.TaskPtr.TaskName,
                    "taskPriority": task.taskPriority,
                    "taskPeriod": taskPeriod,
                    "modelPriority": module.CurrentModelPriority,
                    "executionIndex": moduleIndex,
                    "inputs": [],
                    "outputs": [],
                }
                # Objects that may hold this module's message attributes: the
                # model pointer plus matching Python-side task model(s).
                candidates = self._getTaskModelCandidates(
                    task.TaskPtr.TaskName,
                    modelTag,
                    modelPointer,
                    taskModelIndex,
                )
                if not includeRecorders and self._candidateListHasRecorder(candidates):
                    continue
                self._collectModuleMessageEndpoints(moduleId, moduleRecord, candidates)
                graph["modules"].append(moduleRecord)
                graph["inputs"].extend(moduleRecord["inputs"])
                graph["outputs"].extend(moduleRecord["outputs"])
                moduleIndex += 1

    # Stand-alone messages become source-only nodes. One named entry may expand
    # into several message leaves when it is a container of messages.
    for messageIndex, (messageName, messageObject) in enumerate(
        self._normalizeExtraMessages(extraMessages).items()
    ):
        for sourceName, sourceMessage in _iterMessageLeaves(messageObject, messageName):
            sourceId = f"standalone:{messageIndex}:{len(graph['standaloneMessages'])}"
            endpoint = {
                "id": sourceId + ":output:0",
                "ownerId": sourceId,
                "ownerType": "standalone",
                "moduleTag": "",
                "name": sourceName,
                "direction": "output",
                "payloadType": _messagePayloadType(sourceMessage),
                "message": sourceMessage,
                "isLinked": _messageIsLinked(sourceMessage),
            }
            standaloneRecord = {
                "id": sourceId,
                "name": sourceName,
                "outputs": [endpoint],
            }
            graph["standaloneMessages"].append(standaloneRecord)
            graph["outputs"].append(endpoint)

    seenEdges = set()

    def addEdge(sourceEndpoint, targetEndpoint):
        """Add a graph edge unless it was already recorded."""
        edgeKey = (sourceEndpoint["id"], targetEndpoint["id"])
        if edgeKey in seenEdges:
            return
        seenEdges.add(edgeKey)
        graph["edges"].append(
            {
                "source": sourceEndpoint["id"],
                "target": targetEndpoint["id"],
                "sourceName": sourceEndpoint["name"],
                "targetName": targetEndpoint["name"],
                "sourceOwner": sourceEndpoint["ownerId"],
                "targetOwner": targetEndpoint["ownerId"],
                "payloadType": targetEndpoint["payloadType"],
                "sourceType": sourceEndpoint["ownerType"],
            }
        )

    for inputEndpoint in graph["inputs"]:
        recordedMessage = inputEndpoint.get("recordedMessage", None)
        if recordedMessage is not None:
            # Recorder input: link it to the first endpoint wrapping the recorded
            # message, searching other inputs first and then outputs (for-else).
            for sourceEndpoint in graph["inputs"]:
                if inputEndpoint["id"] == sourceEndpoint["id"]:
                    continue
                if _messageIdentity(recordedMessage) == _messageIdentity(
                    sourceEndpoint["message"]
                ):
                    addEdge(sourceEndpoint, inputEndpoint)
                    break
            else:
                for outputEndpoint in graph["outputs"]:
                    if _messageIdentity(recordedMessage) == _messageIdentity(
                        outputEndpoint["message"]
                    ):
                        addEdge(outputEndpoint, inputEndpoint)
                        break
            continue
        # Regular input: connect it to every output it reports subscribing to.
        for outputEndpoint in graph["outputs"]:
            if inputEndpoint["message"] is outputEndpoint["message"]:
                continue
            if not _isSubscribedTo(inputEndpoint["message"], outputEndpoint["message"]):
                continue
            addEdge(outputEndpoint, inputEndpoint)

    # Classify inputs that received no edge: linked ones have a source that was
    # not among the collected outputs; unlinked ones are reported only on request.
    connectedInputIds = {edge["target"] for edge in graph["edges"]}
    for inputEndpoint in graph["inputs"]:
        if inputEndpoint["id"] in connectedInputIds:
            continue
        if inputEndpoint["isLinked"]:
            graph["unresolvedInputs"].append(inputEndpoint)
        elif includeUnlinked:
            graph["unlinkedInputs"].append(inputEndpoint)

    if not includeUnlinked:
        # Drop unconnected, unlinked inputs from the flat list and from every
        # module record so figures do not draw them.
        hiddenInputIds = {
            endpoint["id"]
            for endpoint in graph["inputs"]
            if endpoint["id"] not in connectedInputIds and not endpoint["isLinked"]
        }
        graph["inputs"] = [
            endpoint for endpoint in graph["inputs"] if endpoint["id"] not in hiddenInputIds
        ]
        for moduleRecord in graph["modules"]:
            moduleRecord["inputs"] = [
                endpoint
                for endpoint in moduleRecord["inputs"]
                if endpoint["id"] not in hiddenInputIds
            ]
    return graph
[docs]
def GetMessageConnectionDot(
    self,
    extraMessages=None,
    includeUnlinked=True,
    includeRecorders=True,
    graphvizLayout="vertical",
):
    """
    Build the message connection graph and return it as Graphviz DOT source.

    Args:
        extraMessages (dict or list[dict], optional): Stand-alone messages to
            include as possible source messages.
        includeUnlinked (bool): If ``True``, include unlinked input message
            endpoints.
        includeRecorders (bool): If ``True``, include message recorder modules.
        graphvizLayout (str): Graphviz module layout direction. Use
            ``"vertical"`` for a top-to-bottom layout or ``"horizontal"``
            for a left-to-right layout.

    Returns:
        str: Graphviz DOT source text.
    """
    connectionGraph = self.GetMessageConnectionGraph(
        extraMessages=extraMessages,
        includeUnlinked=includeUnlinked,
        includeRecorders=includeRecorders,
    )
    dotSource = self._messageConnectionGraphToDot(
        connectionGraph,
        graphvizLayout=graphvizLayout,
    )
    return dotSource
def _renderMessageConnectionGraphviz(
    self,
    graph,
    fileName=None,
    graphvizFormat="svg",
    graphvizLayout="vertical",
    show_plots=False,
):
    """Write the graph as DOT and, unless DOT itself was requested, render it with ``dot``.

    Args:
        graph (dict): Message connection graph.
        fileName (str, optional): Output path; see ``_resolveGraphvizFileNames``.
        graphvizFormat (str): Output format (leading dots ignored, case-insensitive;
            empty means ``"svg"``).
        graphvizLayout (str): Layout passed to the DOT conversion.
        show_plots (bool): If ``True``, open the result in a web browser.

    Returns:
        str: Path of the produced file (the DOT file for format ``"dot"``).

    Raises:
        RuntimeError: If the ``dot`` executable is missing or fails.
    """
    requestedFormat = graphvizFormat.lower().lstrip(".") or "svg"
    outputFileName, dotFileName, outputFormat = self._resolveGraphvizFileNames(
        fileName,
        requestedFormat,
    )
    dotText = self._messageConnectionGraphToDot(
        graph,
        graphvizLayout=graphvizLayout,
    )

    dotDirectory = os.path.dirname(dotFileName)
    if dotDirectory:
        os.makedirs(dotDirectory, exist_ok=True)
    with open(dotFileName, "w") as dotFile:
        dotFile.write(dotText)

    if outputFormat != "dot":
        dotExecutable = shutil.which("dot")
        if dotExecutable is None:
            raise RuntimeError(
                "Graphviz renderer requires the 'dot' executable to be available."
            )
        outputDirectory = os.path.dirname(outputFileName)
        if outputDirectory:
            os.makedirs(outputDirectory, exist_ok=True)
        renderCommand = [dotExecutable, "-T" + outputFormat, dotFileName, "-o", outputFileName]
        try:
            subprocess.run(renderCommand, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as error:
            raise RuntimeError(
                "Graphviz 'dot' failed to render the message connection graph: "
                + error.stderr
            ) from error
    else:
        outputFileName = dotFileName

    if show_plots:
        webbrowser.open("file://" + os.path.abspath(outputFileName))
    return outputFileName
def _resolveGraphvizFileNames(self, fileName, graphvizFormat):
    """Work out ``(outputFileName, dotFileName, format)`` for a Graphviz export.

    Rules:
      * no ``fileName``: both files go into a new temporary directory;
      * ``fileName`` without extension: ``.<format>`` and ``.dot`` are appended;
      * ``fileName`` ending in ``.dot``: it is the DOT file, output uses ``format``;
      * any other extension: it overrides ``format`` (lower-cased).
    """
    if fileName is None:
        scratchDirectory = tempfile.mkdtemp(prefix="bskMessageConnections_")
        return (
            os.path.join(scratchDirectory, "messageConnections." + graphvizFormat),
            os.path.join(scratchDirectory, "messageConnections.dot"),
            graphvizFormat,
        )
    fileRoot, fileExtension = os.path.splitext(fileName)
    if not fileExtension:
        return fileName + "." + graphvizFormat, fileName + ".dot", graphvizFormat
    extensionFormat = fileExtension[1:].lower()
    if extensionFormat == "dot":
        return fileRoot + "." + graphvizFormat, fileName, graphvizFormat
    return fileName, fileRoot + ".dot", extensionFormat
def _resolveGraphvizLayout(self, graphvizLayout):
    """Map a case-insensitive layout alias to ``(layoutName, rankdir)``.

    Raises:
        ValueError: If the alias is not a known vertical or horizontal name.
    """
    layoutAliases = {
        "vertical": ("vertical", "TB"),
        "tb": ("vertical", "TB"),
        "top-bottom": ("vertical", "TB"),
        "top_to_bottom": ("vertical", "TB"),
        "horizontal": ("horizontal", "LR"),
        "lr": ("horizontal", "LR"),
        "left-right": ("horizontal", "LR"),
        "left_to_right": ("horizontal", "LR"),
    }
    resolved = layoutAliases.get(str(graphvizLayout).lower())
    if resolved is None:
        raise ValueError(
            "graphvizLayout must be either 'vertical' or 'horizontal'."
        )
    return resolved
def _messageConnectionGraphToDot(self, graph, graphvizLayout="vertical"):
    """Convert a message connection graph dictionary to Graphviz DOT text.

    Stand-alone messages and modules become HTML-label nodes whose ports are the
    message endpoints. Edges from stand-alone sources are dashed. Invisible,
    weighted edges chain the modules in list order, and a legend node is
    appended.

    Args:
        graph (dict): Graph as produced by ``GetMessageConnectionGraph``.
        graphvizLayout (str): Layout alias accepted by ``_resolveGraphvizLayout``.

    Returns:
        str: DOT source ending with a newline.
    """
    layoutName, rankDirection = self._resolveGraphvizLayout(graphvizLayout)
    if layoutName == "vertical":
        nodeSeparation, rankSeparation = "0.28", "0.45"
        sourceCompass, targetCompass = "s", "n"
    else:
        nodeSeparation, rankSeparation = "0.35", "0.65"
        sourceCompass, targetCompass = "e", "w"

    # Endpoints that take part in at least one edge are drawn highlighted.
    connectedEndpointIds = set()
    for edge in graph["edges"]:
        connectedEndpointIds.update((edge["source"], edge["target"]))

    dotLines = [
        "digraph BSKMessageConnections {",
        f' graph [rankdir="{rankDirection}", bgcolor="transparent", pad="0.15",',
        f' nodesep="{nodeSeparation}", ranksep="{rankSeparation}", splines="spline"];',
        ' node [shape="plain", fontname="Helvetica"];',
        ' edge [fontname="Helvetica", fontsize="9", arrowsize="0.7",',
        f' color="{connectionColor}"];',
    ]

    # endpoint id -> (node id, port id) used when drawing edges
    portLookup = {}
    for standaloneRecord in graph["standaloneMessages"]:
        standaloneNode = _graphvizId(standaloneRecord["id"])
        firstOutput = standaloneRecord["outputs"][0]
        outputPort = _graphvizId(firstOutput["id"])
        portLookup[firstOutput["id"]] = (standaloneNode, outputPort)
        dotLines.extend(
            self._makeGraphvizStandaloneNode(
                standaloneNode,
                standaloneRecord,
                firstOutput,
                outputPort,
                firstOutput["id"] in connectedEndpointIds,
            )
        )

    for moduleRecord in graph["modules"]:
        moduleNode = _graphvizId(moduleRecord["id"])
        for endpoint in moduleRecord["inputs"] + moduleRecord["outputs"]:
            portLookup[endpoint["id"]] = (moduleNode, _graphvizId(endpoint["id"]))
        dotLines.extend(
            self._makeGraphvizModuleNode(
                moduleNode,
                moduleRecord,
                connectedEndpointIds,
            )
        )

    for edge in graph["edges"]:
        sourcePair = portLookup.get(edge["source"])
        targetPair = portLookup.get(edge["target"])
        if sourcePair is None or targetPair is None:
            continue
        edgeStyle = "dashed" if edge["sourceType"] == "standalone" else "solid"
        dotLines.append(
            f' "{sourcePair[0]}":"{sourcePair[1]}":{sourceCompass} -> '
            f'"{targetPair[0]}":"{targetPair[1]}":{targetCompass} [style="{edgeStyle}"];'
        )

    # Invisible, heavily weighted edges encourage Graphviz to keep modules in list order.
    moduleNodes = [_graphvizId(moduleRecord["id"]) for moduleRecord in graph["modules"]]
    for upstreamNode, downstreamNode in zip(moduleNodes, moduleNodes[1:]):
        dotLines.append(
            f' "{upstreamNode}" -> "{downstreamNode}" [style="invis", weight="8"];'
        )

    dotLines.extend(self._makeGraphvizLegendNode(layoutName))
    if layoutName == "vertical" and moduleNodes:
        dotLines.append(
            f' "{moduleNodes[-1]}" -> "legend" [style="invis", weight="1"];'
        )
    dotLines.append("}")
    return "\n".join(dotLines) + "\n"
def _makeGraphvizStandaloneNode(
    self,
    nodeId,
    standaloneRecord,
    outputEndpoint,
    outputPort,
    connected,
):
    """Build the DOT lines for one stand-alone message node.

    The node is an HTML-like table with two rows: the message name, placed in
    the cell that serves as the node's output port, and the payload type
    underneath it.

    Args:
        nodeId (str): DOT identifier of the node.
        standaloneRecord (dict): record whose ``"name"`` is displayed.
        outputEndpoint (dict): endpoint whose ``"payloadType"`` is displayed.
        outputPort (str): DOT port name given to the name cell.
        connected (bool): chooses the connected (highlighted) or the inactive
            color scheme.

    Returns:
        list[str]: DOT source lines describing the node.
    """
    if connected:
        color = connectedMessageColor
        fontColor = "white"
    else:
        color = inactiveMessageColor
        fontColor = "#222222"
    nameText = _graphvizEscape(_shortText(standaloneRecord["name"], 28))
    payloadText = _graphvizEscape(_shortText(outputEndpoint["payloadType"], 34))
    nameCell = (
        f' <TD PORT="{outputPort}" BGCOLOR="{color}">'
        f'<FONT POINT-SIZE="9" COLOR="{fontColor}"><B>'
        f"{nameText}</B></FONT></TD>"
    )
    payloadCell = ' <TD><FONT POINT-SIZE="8" COLOR="#555555">' f"{payloadText}</FONT></TD>"
    tableHeader = [
        f' "{nodeId}" [label=<',
        ' <TABLE BORDER="1" CELLBORDER="0" CELLSPACING="0" CELLPADDING="5"',
        f' COLOR="{color}" BGCOLOR="#FFFFFF">',
    ]
    tableBody = [
        " <TR>",
        nameCell,
        " </TR>",
        " <TR>",
        payloadCell,
        " </TR>",
        " </TABLE>",
        " >];",
    ]
    return tableHeader + tableBody
def _makeGraphvizModuleNode(self, nodeId, moduleRecord, connectedEndpointIds):
    """Build the DOT lines for one module node and its message ports.

    The table holds a title row (the module tag) and a subtitle row
    (``processName / taskName``), followed by one row per port slot: input
    port cell, a narrow spacer cell, output port cell. The shorter port list
    is padded with empty cells, and a module without ports still gets one
    empty row.

    Args:
        nodeId (str): DOT identifier of the node.
        moduleRecord (dict): record with ``"tag"``, ``"processName"``,
            ``"taskName"``, ``"inputs"`` and ``"outputs"``.
        connectedEndpointIds (set): ids of endpoints that take part in an edge.

    Returns:
        list[str]: DOT source lines describing the node.
    """
    titleText = _graphvizEscape(_shortText(moduleRecord["tag"], 30))
    subtitleText = _graphvizEscape(
        _shortText(moduleRecord["processName"] + " / " + moduleRecord["taskName"], 42)
    )
    lines = [
        f' "{nodeId}" [label=<',
        ' <TABLE BORDER="1" CELLBORDER="0" CELLSPACING="0" CELLPADDING="4"',
        ' COLOR="#2F3B45" BGCOLOR="#F8FAFF">',
        " <TR>",
        f' <TD COLSPAN="3"><FONT POINT-SIZE="10"><B>{titleText}</B></FONT></TD>',
        " </TR>",
        " <TR>",
        f' <TD COLSPAN="3"><FONT POINT-SIZE="8" COLOR="#555555">{subtitleText}</FONT></TD>',
        " </TR>",
    ]
    leftPorts = list(moduleRecord["inputs"])
    rightPorts = list(moduleRecord["outputs"])
    slotCount = max(len(leftPorts), len(rightPorts), 1)
    # Pad with None so that every row has both a left and a right cell.
    leftPorts += [None] * (slotCount - len(leftPorts))
    rightPorts += [None] * (slotCount - len(rightPorts))
    for leftEndpoint, rightEndpoint in zip(leftPorts, rightPorts):
        leftCell = self._makeGraphvizPortCell(leftEndpoint, connectedEndpointIds, "LEFT")
        rightCell = self._makeGraphvizPortCell(rightEndpoint, connectedEndpointIds, "RIGHT")
        lines += [
            " <TR>",
            " " + leftCell,
            ' <TD WIDTH="12"></TD>',
            " " + rightCell,
            " </TR>",
        ]
    lines += [" </TABLE>", " >];"]
    return lines
def _makeGraphvizPortCell(self, endpoint, connectedEndpointIds, align):
    """Return one HTML-like table cell for an input or output port.

    ``None`` yields a fixed-width empty placeholder cell. Otherwise the cell
    is named after the endpoint id (so edges can attach to it) and shows the
    last dotted component of the endpoint name. It uses the connected colors
    when the endpoint id is in ``connectedEndpointIds``.

    Args:
        endpoint (dict | None): endpoint with ``"id"`` and ``"name"``.
        connectedEndpointIds (set): ids of endpoints that take part in an edge.
        align (str): Graphviz ``ALIGN`` value for the cell.

    Returns:
        str: the ``<TD>`` markup.
    """
    if endpoint is None:
        return '<TD WIDTH="80"></TD>'
    isLinked = endpoint["id"] in connectedEndpointIds
    background = connectedMessageColor if isLinked else inactiveMessageColor
    textColor = "white" if isLinked else "#222222"
    portName = _graphvizId(endpoint["id"])
    # rpartition(".")[2] is the text after the last dot (or the whole name).
    shortName = _shortText(endpoint["name"].rpartition(".")[2], 24)
    openTag = f'<TD PORT="{portName}" ALIGN="{align}" BGCOLOR="{background}">'
    fontTag = f'<FONT POINT-SIZE="8" COLOR="{textColor}">'
    return openTag + fontTag + f"{_graphvizEscape(shortName)}</FONT></TD>"
def _makeGraphvizLegendNode(self, graphvizLayout):
    """Return the DOT lines for the legend node.

    The legend shows the two port colors (``message`` and ``connected``) and
    the two edge styles (solid module link, dashed extraMessages link). The
    ``"vertical"`` layout uses two rows (colors, then line styles); every
    other layout uses a single row with all four cells.

    Args:
        graphvizLayout (str): the layout name; only ``"vertical"`` is special.

    Returns:
        list[str]: DOT source lines describing the legend node.
    """
    colorCells = [
        f' <TD BGCOLOR="{inactiveMessageColor}"><FONT POINT-SIZE="8" COLOR="white">message</FONT></TD>',
        f' <TD BGCOLOR="{connectedMessageColor}"><FONT POINT-SIZE="8" COLOR="white">connected</FONT></TD>',
    ]
    styleCells = [
        ' <TD><FONT POINT-SIZE="8">solid: module link</FONT></TD>',
        ' <TD><FONT POINT-SIZE="8">dashed: extraMessages</FONT></TD>',
    ]
    if graphvizLayout == "vertical":
        rows = [" <TR>", *colorCells, " </TR>", " <TR>", *styleCells, " </TR>"]
    else:
        rows = [" <TR>", *colorCells, *styleCells, " </TR>"]
    header = [
        ' "legend" [label=<',
        ' <TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="4">',
    ]
    footer = [" </TABLE>", " >];"]
    return header + rows + footer
def _drawMessageConnectionFigure(self, graph):
    """Render a message-connection graph dictionary as a matplotlib figure.

    Modules are placed on a grid of at most four columns, by their
    ``executionIndex``. Stand-alone messages occupy a band above the grid,
    message links are drawn as arrows between the recorded port positions,
    and a legend is added below the module grid.

    Args:
        graph (dict): graph description with ``"modules"``,
            ``"standaloneMessages"``, ``"edges"``, ``"unresolvedInputs"`` and
            ``"unlinkedInputs"`` entries.

    Returns:
        matplotlib.figure.Figure: the newly created figure.
    """
    modules = graph["modules"]
    standalones = graph["standaloneMessages"]

    # Grid geometry; a graph without modules is still laid out as one slot.
    slotCount = max(len(modules), 1)
    columns = min(slotCount, 4)
    boxWidth = 2.7
    boxGap = 0.95
    columnPitch = boxWidth + boxGap
    # All module boxes share the height needed by the busiest port column.
    busiestPorts = max(
        [1] + [max(len(record["inputs"]), len(record["outputs"])) for record in modules]
    )
    boxHeight = max(1.55, 0.3 * (busiestPorts + 3))
    rowPitch = boxHeight + 1.7
    moduleRows = int(math.ceil(slotCount / columns))
    standaloneRows = int(math.ceil(len(standalones) / columns))
    bandHeight = 1.35 * standaloneRows

    fig = plt.figure(
        figsize=(
            max(6.0, columns * 2.0),
            max(3.2, moduleRows * 1.5 + bandHeight * 0.75 + 0.9),
        )
    )
    ax = fig.add_subplot(111)
    ax.axis("off")

    # Module boxes: (left x, center y) per module, plus every port position.
    moduleLayout = {}
    endpointCoordinates = {}
    for record in modules:
        row, column = divmod(record["executionIndex"], columns)
        left = column * columnPitch
        centerY = -row * rowPitch
        moduleLayout[record["id"]] = (left, centerY)
        self._assignModulePortCoordinates(
            record, left, centerY, boxWidth, boxHeight, endpointCoordinates
        )

    # Stand-alone message sources, stacked downward from the top of the band;
    # each one's single output port sits on its right edge.
    standaloneLayout = {}
    standaloneTop = boxHeight * 0.5 + bandHeight
    for position, record in enumerate(standalones):
        row, column = divmod(position, columns)
        left = column * columnPitch
        centerY = standaloneTop - 1.35 * row
        standaloneLayout[record["id"]] = (left, centerY)
        endpointCoordinates[record["outputs"][0]["id"]] = (left + boxWidth, centerY)

    connectedEndpointIds = set()
    for edge in graph["edges"]:
        connectedEndpointIds.update((edge["source"], edge["target"]))
    unresolvedInputIds = {endpoint["id"] for endpoint in graph["unresolvedInputs"]}
    unlinkedInputIds = {endpoint["id"] for endpoint in graph["unlinkedInputs"]}

    self._drawMessageEdges(ax, graph["edges"], endpointCoordinates)
    self._drawStandaloneMessageNodes(
        ax, standalones, standaloneLayout, boxWidth, connectedEndpointIds
    )
    self._drawModuleMessageNodes(
        ax,
        modules,
        moduleLayout,
        boxWidth,
        boxHeight,
        connectedEndpointIds,
        unresolvedInputIds,
        unlinkedInputIds,
    )
    self._drawMessageFigureLegend(ax, moduleLayout, boxHeight, rowPitch)

    ax.set_xlim(-0.8, columns * columnPitch - boxGap + 0.5)
    yTop = standaloneTop + 1.1 if standalones else boxHeight
    ax.set_ylim(-max(moduleRows - 1, 0) * rowPitch - boxHeight - 0.9, yTop)
    fig.tight_layout(pad=0.2)
    return fig
def _assignModulePortCoordinates(
    self,
    moduleRecord,
    xPosition,
    yPosition,
    moduleWidth,
    moduleHeight,
    endpointCoordinates,
):
    """Record the canvas position of every port of one module.

    Input ports sit on the box's left edge (``xPosition``) and output ports on
    its right edge (``xPosition + moduleWidth``). On each edge, the ports are
    spread evenly over the part of the box below a fixed 0.65-unit header,
    with equal gaps above the first and below the last port.

    Args:
        moduleRecord (dict): record with ``"inputs"`` and ``"outputs"``
            endpoint lists; each endpoint has an ``"id"``.
        xPosition (float): left edge of the module box.
        yPosition (float): vertical center of the module box.
        moduleWidth (float): box width.
        moduleHeight (float): box height.
        endpointCoordinates (dict): updated in place with
            ``endpoint id -> (x, y)``.
    """
    headerSpace = 0.65
    usableHeight = moduleHeight - headerSpace
    for side, edgeOffset in (("inputs", 0.0), ("outputs", moduleWidth)):
        ports = moduleRecord[side]
        if not ports:
            continue
        pitch = usableHeight / (len(ports) + 1)
        # y just below the header; ports are stepped down from here.
        portsTopY = yPosition + moduleHeight * 0.5 - headerSpace
        edgeX = xPosition + edgeOffset
        for slot, port in enumerate(ports, start=1):
            endpointCoordinates[port["id"]] = (edgeX, portsTopY - pitch * slot)
def _drawMessageEdges(self, ax, edges, endpointCoordinates):
    """Draw one curved arrow per message link.

    Edges whose source or target has no recorded coordinate are skipped.
    Links from stand-alone messages are dashed; all other links are solid.
    The arc curvature cycles through three magnitudes with the edge index,
    and its sign depends on whether the target lies above the source.

    Args:
        ax: matplotlib axes to draw on.
        edges (list[dict]): edges with ``"source"``, ``"target"`` and
            ``"sourceType"`` keys.
        endpointCoordinates (dict): endpoint id -> (x, y).
    """
    for position, edge in enumerate(edges):
        haveBothEnds = (
            edge["source"] in endpointCoordinates
            and edge["target"] in endpointCoordinates
        )
        if not haveBothEnds:
            continue
        start = endpointCoordinates[edge["source"]]
        end = endpointCoordinates[edge["target"]]
        bendSign = 1 if start[1] >= end[1] else -1
        bend = 0.12 * bendSign * (1 + position % 3)
        if edge["sourceType"] == "standalone":
            dash = "--"
        else:
            dash = "-"
        ax.add_patch(
            FancyArrowPatch(
                start,
                end,
                arrowstyle="-|>",
                mutation_scale=9,
                linewidth=1.15,
                linestyle=dash,
                color=connectionColor,
                alpha=0.72,
                connectionstyle=f"arc3,rad={bend}",
                zorder=1,
            )
        )
def _drawStandaloneMessageNodes(
    self,
    ax,
    standaloneRecords,
    standaloneLayout,
    moduleWidth,
    connectedEndpointIds,
):
    """Draw one box per stand-alone message source.

    Each box is ``moduleWidth`` wide and 0.72 tall, and is vertically centered
    on the y stored in ``standaloneLayout``. Its output port is a marker on
    the right edge. The message name (bold) and payload type are written
    inside the box. Messages whose output is connected use the highlight
    color; the rest use the inactive color.

    Args:
        ax: matplotlib axes to draw on.
        standaloneRecords (list[dict]): records with ``"id"``, ``"name"``
            and a single-element ``"outputs"`` list.
        standaloneLayout (dict): record id -> (left x, center y).
        moduleWidth (float): box width.
        connectedEndpointIds (set): ids of endpoints that take part in an edge.
    """
    boxHeight = 0.72
    for record in standaloneRecords:
        left, centerY = standaloneLayout[record["id"]]
        endpoint = record["outputs"][0]
        if endpoint["id"] in connectedEndpointIds:
            edgeColor = connectedMessageColor
        else:
            edgeColor = inactiveMessageColor
        ax.add_patch(
            plt.Rectangle(
                (left, centerY - boxHeight * 0.5),
                moduleWidth,
                boxHeight,
                ec=edgeColor,
                fc=(1, 1, 1, 0.95),
                linewidth=1.2,
                zorder=3,
            )
        )
        portX = left + moduleWidth
        ax.scatter(
            [portX],
            [centerY],
            s=55,
            color=edgeColor,
            edgecolors="white",
            linewidths=0.8,
            zorder=4,
        )
        textX = left + 0.14
        ax.text(
            textX,
            centerY + 0.12,
            _shortText(record["name"]),
            fontsize=7,
            fontweight="bold",
            va="center",
            zorder=4,
        )
        ax.text(
            textX,
            centerY - 0.16,
            _shortText(endpoint["payloadType"], 32),
            fontsize=6.5,
            color="#555555",
            va="center",
            zorder=4,
        )
def _drawModuleMessageNodes(
    self,
    ax,
    moduleRecords,
    moduleLayout,
    moduleWidth,
    moduleHeight,
    connectedEndpointIds,
    unresolvedInputIds,
    unlinkedInputIds,
):
    """Draw every module as a box with a title, a subtitle and its ports.

    The title (module tag) and subtitle (``processName / taskName``) are
    centered horizontally near the top of the box. Port drawing is delegated
    to ``_drawModulePorts``, which receives the id sets unchanged.

    Args:
        ax: matplotlib axes to draw on.
        moduleRecords (list[dict]): module records.
        moduleLayout (dict): module id -> (left x, center y).
        moduleWidth (float): box width.
        moduleHeight (float): box height.
        connectedEndpointIds (set): ids of endpoints that take part in an edge.
        unresolvedInputIds (set): forwarded to ``_drawModulePorts``.
        unlinkedInputIds (set): forwarded to ``_drawModulePorts``.
    """
    for record in moduleRecords:
        left, centerY = moduleLayout[record["id"]]
        halfHeight = moduleHeight * 0.5
        topY = centerY + halfHeight
        midX = left + moduleWidth * 0.5
        ax.add_patch(
            plt.Rectangle(
                (left, centerY - halfHeight),
                moduleWidth,
                moduleHeight,
                ec="#2F3B45",
                fc=(0.98, 0.99, 1.0, 0.7),
                linewidth=1.2,
                zorder=3,
            )
        )
        ax.text(
            midX,
            topY - 0.22,
            _shortText(record["tag"], 24),
            fontsize=7.0,
            fontweight="bold",
            ha="center",
            va="center",
            zorder=4,
        )
        subtitle = f"{record['processName']} / {record['taskName']}"
        ax.text(
            midX,
            topY - 0.48,
            _shortText(subtitle, 36),
            fontsize=6.3,
            color="#555555",
            ha="center",
            va="center",
            zorder=4,
        )
        self._drawModulePorts(
            ax,
            record,
            left,
            centerY,
            moduleWidth,
            moduleHeight,
            connectedEndpointIds,
            unresolvedInputIds,
            unlinkedInputIds,
        )
def _drawModulePorts(
    self,
    ax,
    moduleRecord,
    xPosition,
    yPosition,
    moduleWidth,
    moduleHeight,
    connectedEndpointIds,
    unresolvedInputIds,
    unlinkedInputIds,
):
    """Draw the input and output message ports for one module.

    Port positions come from ``_assignModulePortCoordinates``, the helper
    that also places the arrow endpoints. This keeps the port markers and
    the arrows in agreement, instead of duplicating the header/spacing
    arithmetic here. Input labels are left-aligned just inside the left
    edge; output labels are right-aligned just inside the right edge.

    Args:
        ax: matplotlib axes to draw on.
        moduleRecord (dict): record with ``"inputs"``/``"outputs"`` endpoint
            lists; each endpoint has ``"id"`` and ``"name"``.
        xPosition (float): left edge of the module box.
        yPosition (float): vertical center of the module box.
        moduleWidth (float): box width.
        moduleHeight (float): box height.
        connectedEndpointIds (set): endpoint ids drawn in the connected color.
        unresolvedInputIds (set): accepted for interface compatibility; not
            used to style the ports here.
        unlinkedInputIds (set): accepted for interface compatibility; not used
            to style the ports here.
    """
    # Coordinates are looked up by endpoint id, so ids are assumed unique
    # within a module.
    portCoordinates = {}
    self._assignModulePortCoordinates(
        moduleRecord,
        xPosition,
        yPosition,
        moduleWidth,
        moduleHeight,
        portCoordinates,
    )
    portSize = 40
    for endpointListName, labelOffset, textAlign in (
        ("inputs", 0.12, "left"),
        ("outputs", -0.12, "right"),
    ):
        for endpoint in moduleRecord[endpointListName]:
            portX, portY = portCoordinates[endpoint["id"]]
            color = (
                connectedMessageColor
                if endpoint["id"] in connectedEndpointIds
                else inactiveMessageColor
            )
            ax.scatter(
                [portX],
                [portY],
                s=portSize,
                color=color,
                edgecolors="white",
                linewidths=0.75,
                zorder=5,
            )
            ax.text(
                portX + labelOffset,
                portY,
                _shortText(endpoint["name"].split(".")[-1], 22),
                fontsize=4.4,
                color="#222222",
                ha=textAlign,
                va="center",
                zorder=5,
            )
def _drawMessageFigureLegend(self, ax, moduleLayout, moduleHeight, rowStep):
    """Draw a compact color/line-style legend below the module grid.

    One row is drawn: two colored markers (unconnected ``message`` and
    ``connected``) followed by two short line samples (solid module link,
    dashed extraMessages link).

    Args:
        ax: matplotlib axes to draw on.
        moduleLayout (dict): module id -> (left x, center y). Only the lowest
            center y is used.
        moduleHeight (float): height of a module box.
        rowStep (float): vertical distance between module rows.
    """
    # Minimum clearance between the legend and the bottom of the lowest box.
    legendGap = 0.35
    if len(moduleLayout) == 0:
        legendY = -moduleHeight
    else:
        lastRow = min(yPosition for _, yPosition in moduleLayout.values())
        # rowStep * 0.35 alone lands inside the lowest module box when boxes
        # are tall. With rowStep = moduleHeight + 1.7 this happens once
        # moduleHeight exceeds ~4, i.e. for modules with many ports. Keep at
        # least legendGap below the box's bottom edge.
        legendY = min(
            lastRow - rowStep * 0.35,
            lastRow - moduleHeight * 0.5 - legendGap,
        )
    pointEntries = [
        ("message", inactiveMessageColor),
        ("connected", connectedMessageColor),
    ]
    for index, (label, color) in enumerate(pointEntries):
        xPosition = index * 1.45
        ax.scatter(
            [xPosition],
            [legendY],
            s=40,
            color=color,
            edgecolors="white",
            linewidths=0.7,
            zorder=5,
        )
        ax.text(
            xPosition + 0.14,
            legendY,
            label,
            fontsize=6.8,
            va="center",
            zorder=5,
        )
    lineEntries = [
        ("solid: module link", "-"),
        ("dashed: extraMessages", "--"),
    ]
    for index, (label, lineStyle) in enumerate(lineEntries):
        xPosition = 3.25 + index * 2.05
        ax.plot(
            [xPosition, xPosition + 0.42],
            [legendY, legendY],
            linestyle=lineStyle,
            linewidth=1.15,
            color=connectionColor,
            alpha=0.72,
            zorder=5,
        )
        ax.text(
            xPosition + 0.5,
            legendY,
            label,
            fontsize=6.8,
            va="center",
            zorder=5,
        )
[docs]
def AddModelToTask(self, TaskName, NewModel, ModelData=None, ModelPriority=-1):
    """
    Add a model to a task, hand it the simulation logger, and set the
    priority with which it is updated within the task.

    :param TaskName (str): Name of the task
    :param NewModel (obj): Model to add to the task
    :param ModelData: None or the C-module data struct; only used for C BSK modules.
        An ``int`` passed here is treated as ``ModelPriority``.
    :param ModelPriority (int): Priority that determines when the model gets updated. (Higher number = Higher priority)
    :return: None
    :raises ValueError: if no task named ``TaskName`` exists
    """
    # Supports calling AddModelToTask(TaskName, NewModel, ModelPriority)
    if isinstance(ModelData, int):
        ModelPriority = ModelData
        ModelData = None
    for Task in self.TaskList:
        if Task.Name != TaskName:
            continue
        Task.TaskData.AddNewObject(NewModel, ModelPriority)
        # When ModelData is given, it (not NewModel) receives the logger and
        # is the object recorded in TaskModels.
        registeredModel = NewModel if ModelData is None else ModelData
        try:
            registeredModel.bskLogger = self.bskLogger
        except (AttributeError, TypeError):
            # Best effort: some objects reject this attribute (or its type);
            # the model is still registered, just without the shared logger.
            pass
        Task.TaskModels.append(registeredModel)
        Task.TaskModelPriorities.append(ModelPriority)
        return
    raise ValueError(f"Could not find a Task with name: {TaskName}")
[docs]
def CreateNewProcess(self, procName, priority=-1):
    """
    Create a process, register it with this simulation, and return it.

    The process is appended to ``self.procList`` and its ``processData`` is
    handed to ``self.TotalSim.addNewProcess``.

    :param procName (str): Name of process
    :param priority (int): Priority that determines when the model gets updated. (Higher number = Higher priority)
    :return: simulationArchTypes.ProcessBaseClass object
    """
    process = simulationArchTypes.ProcessBaseClass(procName, priority)
    self.procList.append(process)
    simulationCore = self.TotalSim
    simulationCore.addNewProcess(process.processData)
    return process
[docs]
def CreateNewTask(self, TaskName, TaskRate, InputDelay=None, FirstStart=0):
    """
    Creates a simulation task on the C-level with a specific update-frequency (TaskRate), an optional delay, and
    an optional start time.

    Args:
        TaskName (str): Name of Task
        TaskRate (int): Number of nanoseconds to elapse before update() is called
        InputDelay (int): (deprecated, unimplemented) Number of nanoseconds simulating a lag of the particular task.
            Passing any value other than ``None`` emits a deprecation warning.
        FirstStart (int): Number of nanoseconds to elapse before task is officially enabled

    Returns:
        simulationArchTypes.TaskBaseClass object
    """
    # Compare against the literal default directly instead of reading it back
    # from CreateNewTask.__defaults__, which silently breaks if the signature
    # is reordered.
    if InputDelay is not None:
        deprecated.deprecationWarn(
            "InputDelay",
            "2024/12/13",
            "This input variable is non-functional and now deprecated.",
        )
    Task = simulationArchTypes.TaskBaseClass(TaskName, TaskRate, FirstStart)
    self.TaskList.append(Task)
    return Task
def ResetTask(self, taskName):
    """Reset every task named ``taskName`` at the current simulation time.

    Each matching task's ``resetTask`` is called with
    ``self.TotalSim.CurrentNanos``; an unknown name does nothing.

    Args:
        taskName (str): name of the task(s) to reset.
    """
    matchingTasks = (task for task in self.TaskList if task.Name == taskName)
    for matchingTask in matchingTasks:
        matchingTask.resetTask(self.TotalSim.CurrentNanos)
[docs]
def InitializeSimulation(self):
    """
    Initialize the BSK simulation. This runs the SelfInit() and Reset() methods on each module.

    When the simulation was already initialized, ``resetThreads`` is first
    called with the current thread count. Afterwards
    ``simulationInitialized`` is set to True.
    """
    core = self.TotalSim
    if self.simulationInitialized:
        threadCount = core.getThreadCount()
        core.resetThreads(threadCount)
    core.assignRemainingProcs()
    core.ResetSimulation()
    core.selfInitSimulation()
    core.resetInitSimulation()
    self.simulationInitialized = True
def CheckStopCondition(self):
    """Return whether the simulation loop should keep stepping.

    ``self.StopCondition`` selects the rule:

    * ``"<="``: continue while the next task time is at or before ``StopTime``.
    * ``">="``: continue while the current time is before ``StopTime``, or
      while the next task time equals ``StopTime``.

    Returns:
        bool: True to keep stepping.

    Raises:
        ValueError: if ``StopCondition`` is neither ``"<="`` nor ``">="``.
            Raising is preferred to returning None, which would make a
            ``while self.CheckStopCondition()`` loop exit silently.
    """
    condition = self.StopCondition
    if condition == "<=":
        return self.TotalSim.NextTaskTime <= self.StopTime
    if condition == ">=":
        return (
            self.TotalSim.CurrentNanos < self.StopTime
            or self.TotalSim.NextTaskTime == self.StopTime
        )
    raise ValueError(
        f"Unknown StopCondition {condition!r}; expected '<=' or '>='"
    )
[docs]
def ExecuteSimulation(self):
    """
    Run the simulation until the prescribed stop time or termination.

    Each pass of the loop checks the active events, then steps the
    simulation to the earliest next event check time. That step is clamped
    to at least the next task time and at most ``StopTime``; with a ``">="``
    stop condition it always reaches at least the next task time. A truthy
    ``self.terminate`` (presumably set by an event action) ends the run early.

    ``self.terminate`` is cleared and the progress bar closed even if
    stepping raises. The bar is only marked complete on a normal exit.
    """
    progressBar = SimulationProgressBar(self.StopTime, self.showProgressBar)
    try:
        while self.CheckStopCondition():
            # Check events
            for event in self.activeEvents():
                if event.shouldBeChecked(self.TotalSim.CurrentNanos):
                    event.checkEvent(self)
            if self.terminate:
                break
            # Find the next time to stop the sim
            eventCheckTimes = [
                event.nextCheckTime(self.TotalSim.CurrentNanos)
                for event in self.activeEvents()
            ]
            if eventCheckTimes:
                # Stop at next event, but at least reach the next task...
                nextStopTime = max(min(eventCheckTimes), self.TotalSim.NextTaskTime)
                # ...and don't pass stop
                nextStopTime = min(nextStopTime, self.StopTime)
            else:
                nextStopTime = self.StopTime  # Otherwise stop at the stop time
            # Must at least step to the next task time if StopCondition is ">="
            if self.StopCondition == ">=":
                nextStopTime = max(nextStopTime, self.TotalSim.NextTaskTime)
            # Execute the sim
            nextPriority = -1
            self.TotalSim.StepUntilStop(int(nextStopTime), nextPriority)
            progressBar.update(self.TotalSim.NextTaskTime)
        progressBar.markComplete()
    finally:
        # Clear the flag and release the progress bar even if stepping
        # raised, so a later run does not stop immediately or leave a
        # dangling bar behind.
        self.terminate = False
        progressBar.close()
[docs]
def disableTask(self, TaskName):
    """
    Disable this particular task from being executed.

    Every task in ``self.TaskList`` whose ``Name`` equals ``TaskName`` has
    its ``disable()`` called; an unknown name does nothing.
    """
    for matchedTask in filter(lambda task: task.Name == TaskName, self.TaskList):
        matchedTask.disable()
[docs]
def enableTask(self, TaskName):
    """
    Enable this particular task to be executed.

    Every task in ``self.TaskList`` whose ``Name`` equals ``TaskName`` has
    its ``enable()`` called; an unknown name does nothing.
    """
    for matchedTask in filter(lambda task: task.Name == TaskName, self.TaskList):
        matchedTask.enable()
def parseDataIndex(self):
    """Parse the XML index at ``self.dataStructIndex``.

    ``self.dataStructureDictionary`` is reset to an empty dict. Each child
    element of the XML root then adds one entry, keyed by the text of its
    ``<name>`` element, holding ``StructDocData(child.attrib["refid"])``.
    On success ``self.indexParsed`` is set to True.

    If the index cannot be read (``OSError``, e.g. it is missing) or is not
    well-formed XML (``ET.ParseError``), a message is printed and the
    dictionary is left empty. Other errors, including KeyboardInterrupt,
    are no longer swallowed.
    """
    self.dataStructureDictionary = {}
    try:
        xmlData = ET.parse(self.dataStructIndex)
    except (OSError, ET.ParseError):
        print("Failed to parse the XML index. Likely that it isn't present")
        return
    for child in xmlData.getroot():
        newStruct = StructDocData(child.attrib["refid"])
        self.dataStructureDictionary[child.find("name").text] = newStruct
    self.indexParsed = True
[docs]
def createNewEvent(self, eventName, *args, **kwargs):
    """
    Create an event sequence that contains a series of tasks to be executed.

    If an event named ``eventName`` already exists, a warning is issued and
    the existing event is left untouched.

    Args:
        eventName (str): Name of the event
        *args: Arguments to pass to the :class:`EventHandlerClass` constructor
        **kwargs: Keyword arguments to pass to the :class:`EventHandlerClass` constructor
    """
    # Direct dict membership; no need to materialize the key list.
    if eventName in self.eventMap:
        warnings.warn(f"Skipping event creation since {eventName} already exists.")
        return
    newEvent = EventHandlerClass(eventName, *args, **kwargs)
    self.eventMap[eventName] = newEvent
def activeEvents(self):
    """Return a lazy iterator over the events whose ``eventActive`` is truthy.

    The iterator walks ``self.eventMap.values()`` in insertion order; the
    dict view is captured when this method is called.
    """
    def _iterActive(events):
        for event in events:
            if event.eventActive:
                yield event

    return _iterActive(iter(self.eventMap.values()))
def setEventActivity(self, eventName, activityCommand):
    """Set the ``eventActive`` flag of the event named ``eventName``.

    Unknown event names are reported on stdout and otherwise ignored.

    Args:
        eventName (str): name of the event to update.
        activityCommand: value assigned to the event's ``eventActive``.
    """
    # Direct dict membership; no need to materialize the key list.
    if eventName not in self.eventMap:
        print("You asked me to set the status of an event that I don't have.")
        return
    self.eventMap[eventName].eventActive = activityCommand
[docs]
def setAllButCurrentEventActivity(
    self, currentEventName, activityCommand, useIndex=False
):
    """Set all event activity variables except for the currentEventName event. The ``useIndex`` flag can be used to
    prevent enabling or disabling every task, and instead only alter the ones that belong to the same group (for
    example, the same spacecraft). The distinction is made through an index set after the ``_`` symbol in the event
    name. All events of the same group must have the same index."""
    # Text after the first "_" identifies the group; only needed with useIndex.
    groupIndex = currentEventName.partition("_")[2] if useIndex else None
    for eventName in list(self.eventMap):
        if eventName == currentEventName:
            continue
        if useIndex and eventName.partition("_")[2] != groupIndex:
            continue
        self.eventMap[eventName].eventActive = activityCommand
def SetCArray(InputList, VarType, ArrayPointer):
    """Write the items of ``InputList`` into a C array, one element at a time.

    Writes go through ``sim_model.<VarType>Array_setitem`` at indices 0, 1, ...
    (presumably SWIG ``carrays`` accessors -- confirm against the bindings).

    Args:
        InputList: iterable of values to write.
        VarType (str): element type prefix used to pick the accessor.
        ArrayPointer: handle of the target C array.

    Raises:
        TypeError: if ``ArrayPointer`` is a Python list or tuple; such a
            target should simply be assigned the list directly.
    """
    if isinstance(ArrayPointer, (list, tuple)):
        raise TypeError(
            "Cannot set a C array if it is actually a python list. Just assign the variable to the list directly."
        )
    writeItem = getattr(sim_model, VarType + "Array_setitem")
    position = 0
    for value in InputList:
        writeItem(ArrayPointer, position, value)
        position += 1
def getCArray(varType, arrayPointer, arraySize):
    """Read the first ``arraySize`` elements of a C array into a Python list.

    Reads go through ``sim_model.<varType>Array_getitem`` at indices
    0 .. arraySize - 1.

    Args:
        varType (str): element type prefix used to pick the accessor.
        arrayPointer: handle of the source C array.
        arraySize (int): number of elements to read.

    Returns:
        list: the values read, in index order.
    """
    readItem = getattr(sim_model, varType + "Array_getitem")
    values = []
    for position in range(arraySize):
        values.append(readItem(arrayPointer, position))
    return values
def synchronizeTimeHistories(arrayList):
    """Resample several time histories onto the time grid of the first one.

    Each entry of ``arrayList`` is a 2-D array with time in column 0 and data
    in the remaining columns. The first entry is the primary history. Every
    other entry is linearly interpolated at the primary's time stamps.

    First, all histories are trimmed to their common time span. Leading rows
    are dropped while the second row is still before the start boundary, and
    trailing rows while the second-to-last row is still after the end
    boundary, so each array keeps one sample just outside the boundary for
    interpolation.

    Args:
        arrayList (list of numpy.ndarray): time histories, primary first.
            The list passed in is not modified.

    Returns:
        list: ``[primary[0:-2, :], resampled_1, resampled_2, ...]``. Each
        ``resampled_j`` has ``[time, data...]`` rows evaluated at the times in
        column 0 of the returned primary slice (an empty array if there are
        none).
    """
    # Shallow copy: trimming rebinds entries, and the caller's list must keep
    # its original arrays.
    histories = list(arrayList)

    # Leading boundary: first primary time not earlier than any history start.
    primary = histories[0]
    startIndex = 0
    for history in histories:
        while history[0, 0] > primary[startIndex, 0]:
            startIndex += 1
    # Capture the boundary *time* before deleting rows. Deleting rows from the
    # primary shifts its indices, so re-reading primary[startIndex] inside the
    # trimming loop chases a moving target until it runs off the array.
    startTime = primary[startIndex, 0]
    for position, history in enumerate(histories):
        while history[1, 0] < startTime:
            history = np.delete(history, 0, 0)
        histories[position] = history

    # Trailing boundary: last primary time not later than any history end.
    primary = histories[0]
    endIndex = -1
    for history in histories:
        while history[-1, 0] < primary[endIndex, 0]:
            endIndex -= 1
    endTime = primary[endIndex, 0]
    for position, history in enumerate(histories):
        while history[-2, 0] > endTime:
            history = np.delete(history, -1, 0)
        histories[position] = history

    primary = histories[0]
    # Output grid: primary[0] .. primary[-3], matching primary[0:-2].
    sampleTimes = primary[:-2, 0]
    outputArrayList = [primary[0:-2, :]]
    for history in histories[1:]:
        # One fresh row list per history (a shared `[[]] * n` list would mix
        # the rows of every history together).
        rows = []
        bracket = 0
        for sampleTime in sampleTimes:
            # Bracket the time actually being interpolated, so that
            # history[bracket] <= sampleTime <= history[bracket + 1] whenever
            # the history covers it (no backward extrapolation for histories
            # sampled more densely than the primary).
            while history[bracket + 1, 0] < sampleTime:
                bracket += 1
            lowerTime = history[bracket, 0]
            upperTime = history[bracket + 1, 0]
            fraction = (sampleTime - lowerTime) / (upperTime - lowerTime)
            delta = (history[bracket + 1, 1:] - history[bracket, 1:]) * fraction
            values = delta + history[bracket, 1:]
            rows.append([sampleTime] + values.tolist())
        outputArrayList.append(np.array(rows))
    return outputArrayList