import inspect
from copy import deepcopy
from typing import List, Tuple, Optional
from immutabledict import immutabledict
from strax import deterministic_hash
from appletree import utils
from appletree.utils import exporter
export, __all__ = exporter()
@export
class Plugin:
    """The smallest simulation unit.

    A concrete plugin declares ``depends_on``, ``provides`` and (optionally)
    ``parameters`` as class attributes and implements ``simulate`` with the
    positional signature ``(self, key, parameters, *depends_on)``. That
    signature is verified by ``sanity_check`` when the plugin is instantiated.
    """

    # Do not initialize this class because it is base
    __is_base = True

    # The plugin's dependencies (the arguments of `simulate` after
    # `key` and `parameters`)
    depends_on: List[str] = []

    # What the plugin provides (what `simulate` will return)
    provides: List[str] = []

    # Relevant parameters, will be fitted in MCMC
    parameters: Tuple = ()

    # Set using the takes_config decorator
    takes_config = immutabledict()

    def __init__(self, llh_name: Optional[str] = None):
        """Build the configs, validate the plugin and attach per-instance configs.

        Args:
            llh_name: passed to ``build`` of every config in ``takes_config``;
                it tells the configs which map to use.

        Raises:
            ValueError: if ``depends_on`` or ``provides`` is empty, or if
                ``sanity_check`` rejects the signature of ``simulate``.
        """
        # llh_name will tell us which map to use
        self.llh_name = llh_name

        if not self.depends_on:
            raise ValueError(f"depends_on not provided for {self.__class__.__name__}")
        if not self.provides:
            raise ValueError(f"provides not provided for {self.__class__.__name__}")

        # configs are loaded when a plugin is initialized
        for config in self.takes_config.values():
            config.build(self.llh_name)

        self.sanity_check()

        # Do not set configurations as static! This is very important!!!
        # Every instance receives its own deep copy of each built config
        # instead of sharing the objects stored in the class-level mapping.
        for config in self.takes_config.values():
            setattr(self, config.name, deepcopy(config))

    def __call__(self, *args, **kwargs):
        """Call ``self.simulate`` with the given arguments and return its result."""
        return self.simulate(*args, **kwargs)

    def simulate(self, *args, **kwargs):
        """The main simulation function, to be implemented by subclasses.

        Args:
            key: a jnp.array with length 2, used to generate random variables.
                See https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html
            parameters: a dictionary with key being parameters' names. Plugin will
                get values of self.parameters from this dictionary.
            args: other args following ``key`` and ``parameters`` must be in the
                order of self.depends_on.

        Returns:
            ``key`` and output simulated variables, ordered by self.provides.
            ``key`` will be updated if it's used inside self.simulate to
            generate random variables.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def sanity_check(self):
        """Check that the signature of ``self.simulate`` matches ``depends_on``.

        After ``self``, the positional parameters of ``simulate`` must be named
        ``key``, ``parameters`` and then every entry of ``depends_on`` in order.
        Only the input side is inspected here; ``provides`` is not verified.

        Raises:
            ValueError: if an expected parameter is misnamed, or missing because
                ``simulate`` declares too few positional parameters.
        """
        # For a bound method the list of positional names still starts with `self`
        arguments = inspect.getfullargspec(self.simulate)[0]

        def argument_at(position):
            # A signature that is too short counts as a mismatch, so the caller
            # gets the descriptive ValueError below rather than an IndexError.
            return arguments[position] if position < len(arguments) else None

        if argument_at(1) != "key":
            mesg = f"First argument of {self.__class__.__name__}"
            mesg += ".simulate should be 'key'."
            raise ValueError(mesg)

        if argument_at(2) != "parameters":
            mesg = f"Second argument of {self.__class__.__name__}"
            mesg += ".simulate should be 'parameters'."
            raise ValueError(mesg)

        # Dependencies start at index 3: after self, key and parameters
        for i, depend in enumerate(self.depends_on, start=3):
            if argument_at(i) != depend:
                mesg = f"{i}th argument of {self.__class__.__name__}"
                mesg += f".simulate should be '{depend}'. "
                mesg += f"Plugin {self.__class__.__name__} is insane, check dependency!"
                raise ValueError(mesg)

    @property
    def lineage(self):
        """Dictionary describing this plugin.

        Holds ``depends_on``, ``provides``, ``parameters`` and, under
        ``takes_config``, the ``lineage`` of every config keyed by its name
        in ``takes_config``.
        """
        return {
            "depends_on": self.depends_on,
            "provides": self.provides,
            "parameters": self.parameters,
            "takes_config": {
                name: config.lineage for name, config in self.takes_config.items()
            },
        }

    @property
    def lineage_hash(self):
        """Result of ``deterministic_hash`` applied to ``self.lineage``."""
        return deterministic_hash(self.lineage)
@export
def add_plugin_extensions(module1, module2, force=False):
    """Add plugins of module2 to module1.

    Thin wrapper that delegates to ``utils.add_extensions`` with ``Plugin``
    as the base class argument.

    Args:
        module1: module that receives the plugins.
        module2: module whose plugins are added.
        force: forwarded unchanged to ``utils.add_extensions``
            (presumably controls overwriting of existing names; confirm
            against ``utils.add_extensions``).
    """
    utils.add_extensions(
        module1,
        module2,
        Plugin,
        force=force,
    )
@export
def _add_plugin_extension(module, plugin, force=False):
    """Add plugin to module.

    Thin wrapper that delegates to ``utils._add_extension`` with ``Plugin``
    as the base class argument.

    Args:
        module: module that receives the plugin.
        plugin: the plugin to add.
        force: forwarded unchanged to ``utils._add_extension``
            (presumably controls overwriting of an existing name; confirm
            against ``utils._add_extension``).
    """
    utils._add_extension(
        module,
        plugin,
        Plugin,
        force=force,
    )