import os
from warnings import warn
from functools import partial
from typing import Tuple, List, Dict, Optional, Union, Set
import numpy as np
import pandas as pd
from jax import numpy as jnp
from strax import deterministic_hash
from appletree import utils
from appletree.config import OMITTED
from appletree.plugin import Plugin
from appletree.share import _cached_configs, _cached_functions, set_global_config
from appletree.utils import exporter, get_file_path, load_data, calculate_sha256
from appletree.hist import make_hist_mesh_grid, make_hist_irreg_bin_1d, make_hist_irreg_bin_2d
export, __all__ = exporter()
[docs]@export
class Component:
"""Base class of component."""
# Do not initialize this class because it is base
__is_base = True
rate_name: str = ""
norm_type: str = ""
# add_eps_to_hist==True was introduced as only a workaround
# for likelihood blowup problem when using meshgrid binning
add_eps_to_hist: bool = True
force_no_eff: bool = False
[docs] def __init__(self, name: Optional[str] = None, llh_name: Optional[str] = None, **kwargs):
"""Initialization.
Args:
bins: bins to generate the histogram.
* For irreg bins_type, bins must be bin edges of the two dimensions.
* For meshgrid bins_type, bins are sent to jnp.histogramdd.
bins_type: binning scheme, can be either irreg or meshgrid.
"""
if name is None:
self.name = self.__class__.__name__
else:
self.name = name
if llh_name is None:
self.llh_name = self.__class__.__name__ + "_llh"
else:
self.llh_name = llh_name
self.needed_parameters: Set[str] = set()
if "bins" in kwargs.keys() and "bins_type" in kwargs.keys():
self.set_binning(**kwargs)
[docs] def set_binning(self, **kwargs):
"""Set binning of component."""
if "bins" not in kwargs.keys() or "bins_type" not in kwargs.keys():
raise ValueError("bins and bins_type must be set!")
self.bins = kwargs.get("bins")
self.bins_type = kwargs.get("bins_type")
if self.bins_type not in ["irreg", "meshgrid", None]:
raise ValueError(f"Unsupported bins_type {self.bins_type}!")
if self.bins_type == "meshgrid":
warning = "The usage of meshgrid binning is highly discouraged."
warn(warning)
[docs] def _clip(self, result: list):
"""Clip simulated result."""
mask = np.ones(len(result[-1]), dtype=bool)
for i in range(len(result) - 1):
mask &= result[i] > np.array(self.bins[i]).min()
mask &= result[i] < np.array(self.bins[i]).max()
for i in range(len(result)):
result[i] = result[i][mask]
return result
@property
def _use_mcinput(self):
return "Bootstrap" in self._plugin_class_registry["energy"].__name__
[docs] def simulate_hist(self, *args, **kwargs):
"""Hook for simulation with histogram output."""
raise NotImplementedError
[docs] def multiple_simulations(self, key, batch_size, parameters, times, apply_eff=False):
"""Simulate many times and move results to CPU because the memory limit of GPU."""
results_pile = []
assert times > 0, "times of multiple simulations must be greater than 0!"
for _ in range(times):
key, results = self.simulate(key, batch_size, parameters)
if apply_eff:
if self.force_no_eff:
raise RuntimeError(
"You are forcing to apply efficiency! "
"But component was set to not returning efficiency when "
f"running {self.name}.deduce!"
)
mask = np.array(results[-1]) > 0
for i in range(len(results)):
results[i] = np.array(results[i])[mask]
results_pile.append(results)
results_pile = [
np.hstack([results_pile[j][i] for j in range(times)]) for i in range(len(results))
]
return key, results_pile
[docs] def multiple_simulations_compile(self, key, batch_size, parameters, times, apply_eff=False):
"""Simulate many times after new compilation and move results to CPU because the memory
limit of GPU."""
results_pile = []
for _ in range(times):
if _cached_configs["g4"] and self._use_mcinput:
if isinstance(_cached_configs["g4"], dict):
g4_file_name = _cached_configs["g4"][self.llh_name][0]
_cached_configs["g4"][self.llh_name] = [
g4_file_name,
batch_size,
key.sum().item(),
]
else:
g4_file_name = _cached_configs["g4"][0]
_cached_configs["g4"] = [g4_file_name, batch_size, key.sum().item()]
self.compile()
key, results = self.multiple_simulations(key, batch_size, parameters, 1, apply_eff)
results_pile.append(results)
results_pile = [
np.hstack([results_pile[j][i] for j in range(times)]) for i in range(len(results))
]
return key, results_pile
[docs] def implement_binning(self, mc, eff):
"""Apply binning to MC data.
Args:
mc: data from simulation.
eff: efficiency of each event, as the weight when making a histogram.
"""
if self.bins_type == "irreg":
if len(self.bins) == 1:
hist = make_hist_irreg_bin_1d(mc[:, 0], self.bins[0], weights=eff)
elif len(self.bins) == 2:
hist = make_hist_irreg_bin_2d(mc, *self.bins, weights=eff)
else:
raise ValueError(f"Currently only support 1D and 2D, but got {len(self.bins)}D!")
elif self.bins_type == "meshgrid":
hist = make_hist_mesh_grid(mc, bins=self.bins, weights=eff)
else:
raise ValueError(f"Unsupported bins_type {self.bins_type}!")
if self.add_eps_to_hist:
# as an uncertainty to prevent blowing up
# uncertainty = 1e-10 + jnp.mean(eff)
hist = jnp.clip(hist, 1e-10 + jnp.mean(eff), jnp.inf)
return hist
[docs] def get_normalization(self, hist, parameters, batch_size=None):
"""Return the normalization factor of the histogram."""
if self.norm_type == "on_pdf":
normalization_factor = 1 / jnp.sum(hist) * parameters[self.rate_name]
elif self.norm_type == "on_sim":
if self._use_mcinput:
bootstrap_name = self._plugin_class_registry["energy"].__name__
bootstrap_name = bootstrap_name + "_" + self.name
n_events_selected = _cached_functions[self.llh_name][
bootstrap_name
].g4.n_events_selected
normalization_factor = 1 / n_events_selected * parameters[self.rate_name]
else:
normalization_factor = 1 / batch_size * parameters[self.rate_name]
else:
raise ValueError(f"Unsupported norm_type {self.norm_type}!")
return normalization_factor
[docs] def deduce(self, *args, **kwargs):
"""Hook for workflow deduction."""
raise NotImplementedError
[docs] def compile(self):
"""Hook for compiling simulation code."""
pass
@property
def lineage(self):
raise NotImplementedError
@property
def lineage_hash(self):
return deterministic_hash(self.lineage)
[docs]@export
class ComponentSim(Component):
"""Component that needs MC simulations."""
# Do not initialize this class because it is base
__is_base = True
[docs] def __init__(self, *args, **kwargs):
"""Initialization."""
super().__init__(*args, **kwargs)
self._plugin_class_registry = dict()
[docs] def register(self, plugin_class):
"""Register a plugin to the component."""
if isinstance(plugin_class, (tuple, list)):
# Shortcut for multiple registration
for x in plugin_class:
self.register(x)
return
# Ensure plugin_class.provides is a tuple
if isinstance(plugin_class.provides, str):
plugin_class.provides = tuple([plugin_class.provides])
for p in plugin_class.provides:
self._plugin_class_registry[p] = plugin_class
already_seen = []
for plugin in self._plugin_class_registry.values():
if plugin in already_seen:
continue
already_seen.append(plugin)
for config_name, items in plugin.takes_config.items():
# Looping over the configs of the new plugin and check if
# they can be found in the already registered plugins:
for new_config, new_items in plugin_class.takes_config.items():
if new_config != config_name:
continue
if items.default == new_items.default:
continue
else:
mes = (
f"Two plugins have a different file name "
f"for the same config. The config "
f"'{new_config}' in '{plugin.__name__}' takes "
f"the file name as '{new_items.default}' while in "
f"'{plugin_class.__name__}' the file name "
f"is set to '{items.default}'. Please change "
f"one of the file names."
)
raise ValueError(mes)
[docs] def register_all(self, module):
"""Register all plugins defined in module.
Can pass a list/tuple of modules to register all in each.
"""
if isinstance(module, (tuple, list)):
# Shortcut for multiple registration
for x in module:
self.register_all(x)
return
for x in dir(module):
x = getattr(module, x)
if not isinstance(x, type(type)):
continue
if issubclass(x, Plugin):
self.register(x)
[docs] def dependencies_deduce(
self,
data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", "eff"],
dependencies: Optional[List[Dict]] = None,
nodep_data_name: str = "batch_size",
) -> list:
"""Deduce dependencies.
Args:
data_names: data names that simulation will output.
dependencies: dependency tree.
nodep_data_name: data_name without dependency will not be deduced.
"""
if dependencies is None:
dependencies = []
for data_name in data_names:
# usually `batch_size` have no dependency
if data_name == nodep_data_name:
continue
try:
dependencies.append(
{
"plugin": self._plugin_class_registry[data_name],
"provides": data_name,
"depends_on": self._plugin_class_registry[data_name].depends_on,
}
)
except KeyError:
raise ValueError(f"Can not find dependency for {data_name}")
for data_name in data_names:
# `batch_size` has no dependency
if data_name == nodep_data_name:
continue
dependencies = self.dependencies_deduce(
data_names=self._plugin_class_registry[data_name].depends_on,
dependencies=dependencies,
nodep_data_name=nodep_data_name,
)
return dependencies
[docs] def dependencies_simplify(self, dependencies):
"""Simplify the dependencies."""
already_seen = []
self.worksheet = []
# Reinitialize needed_parameters
# because sometimes user will deduce(& compile) after changing configs
self.needed_parameters: Set[str] = set()
# Add rate_name to needed_parameters only when it's not empty
if self.rate_name != "":
self.needed_parameters.add(self.rate_name)
for _plugin in dependencies[::-1]:
plugin = _plugin["plugin"]
if plugin.__name__ in already_seen:
continue
self.worksheet.append([plugin.__name__, plugin.provides, plugin.depends_on])
already_seen.append(plugin.__name__)
self.needed_parameters |= set(plugin.parameters)
# Add needed_parameters from config
for config in plugin.takes_config.values():
required_parameter = config.required_parameter(self.llh_name)
if required_parameter is not None:
self.needed_parameters |= {required_parameter}
[docs] def flush_source_code(
self,
data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", "eff"],
func_name: str = "simulate",
nodep_data_name: str = "batch_size",
):
"""Infer the simulation code from the dependency tree."""
self.func_name = func_name
if not isinstance(data_names, (list, str)):
raise RuntimeError(f"data_names must be list or str, but given {type(data_names)}")
if isinstance(data_names, str):
data_names = [data_names]
instances = set()
code = ""
indent = " " * 4
code += "from functools import partial\n"
code += "from jax import jit\n"
# import needed plugins
for work in self.worksheet:
plugin = work[0]
code += f"from appletree.plugins import {plugin}\n"
# initialize new instances
for work in self.worksheet:
plugin = work[0]
instance = plugin + "_" + self.name
instances.add(instance)
code += f"{instance} = {plugin}('{self.llh_name}')\n"
# define functions
code += "\n"
if nodep_data_name == "batch_size":
code += "@partial(jit, static_argnums=(1, ))\n"
else:
code += "@jit\n"
code += f"def {func_name}(key, {nodep_data_name}, parameters):\n"
for work in self.worksheet:
provides = "key, " + ", ".join(work[1])
depends_on = ", ".join(work[2])
instance = work[0] + "_" + self.name
code += f"{indent}{provides} = {instance}(key, parameters, {depends_on})\n"
output = "key, " + "[" + ", ".join(data_names) + "]"
code += f"{indent}return {output}\n"
self.code = code
self.instances = instances
if func_name in _cached_functions[self.llh_name].keys():
warning = f"Function name {func_name} is already cached. "
warning += "Running compile() will overwrite it."
warn(warning)
@property
def code(self):
"""Code of simulation function."""
return self._code
@code.setter
def code(self, code):
self._code = code
_cached_functions[self.llh_name] = dict()
self._compile = partial(exec, self.code, _cached_functions[self.llh_name])
[docs] def deduce(
self,
data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2"],
func_name: str = "simulate",
nodep_data_name: str = "batch_size",
force_no_eff: bool = False,
):
"""Deduce workflow and code.
Args:
data_names: data names that simulation will output.
func_name: name of the simulation function, used to cache it.
nodep_data_name: data_name without dependency will not be deduced.
force_no_eff: force to ignore the efficiency, used in yield prediction.
"""
if not isinstance(data_names, (list, tuple)):
raise ValueError(f"Unsupported data_names type {type(data_names)}!")
# make sure that 'eff' is the last data_name
data_names = list(data_names)
if "eff" in data_names:
data_names.remove("eff")
if not force_no_eff:
data_names += ["eff"]
else:
# track status of component
self.force_no_eff = True
dependencies = self.dependencies_deduce(data_names, nodep_data_name=nodep_data_name)
self.dependencies_simplify(dependencies)
self.flush_source_code(data_names, func_name, nodep_data_name)
[docs] def compile(self):
"""Build simulation function and cache it to share._cached_functions."""
self._compile()
self.simulate = _cached_functions[self.llh_name][self.func_name]
[docs] def simulate_hist(self, key, batch_size, parameters):
"""Simulate and return histogram.
Args:
key: key used for pseudorandom generator.
batch_size: number of events to be simulated.
parameters: a dictionary that contains all parameters needed in simulation.
"""
key, result = self.simulate(key, batch_size, parameters)
if self.force_no_eff:
mc = jnp.asarray(result).T
eff = jnp.ones(mc.shape[0])
else:
mc = jnp.asarray(result[:-1]).T
eff = result[-1] # we guarantee that the last output is efficiency in self.deduce
assert mc.shape[1] == len(
self.bins
), "Length of bins must be the same as length of bins_on!"
hist = self.implement_binning(mc, eff)
normalization_factor = self.get_normalization(hist, parameters, batch_size)
hist *= normalization_factor
return key, hist
[docs] def simulate_weighted_data(self, key, batch_size, parameters):
"""Simulate and return histogram."""
key, result = self.simulate(key, batch_size, parameters)
# Move data to CPU
result = [np.array(r) for r in result]
# Clip data points out of ROI
result = self._clip(result)
mc = result[:-1]
assert len(mc) == len(self.bins), "Length of bins must be the same as length of bins_on!"
mc = jnp.asarray(mc).T
eff = jnp.asarray(
result[-1]
) # we guarantee that the last output is efficiency in self.deduce
hist = self.implement_binning(mc, eff)
normalization_factor = self.get_normalization(hist, parameters, batch_size)
result[-1] *= normalization_factor
return key, result
[docs] def save_code(self, file_path):
"""Save the code to file."""
with open(file_path, "w") as f:
f.write(self.code)
[docs] def set_config(self, configs):
"""Set new global configuration options.
Args:
configs: dict, configuration file name or dictionary
"""
set_global_config(configs)
warn(
"New config is set, please run deduce() "
"and compile() again to update the simulation code."
)
[docs] def show_config(self, data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2", "eff"]):
"""Return configuration options that affect data_names.
Args:
data_names: Data type name
"""
dependencies = self.dependencies_deduce(
data_names,
nodep_data_name="batch_size",
)
r = []
seen = []
for dep in dependencies:
p = dep["plugin"]
# Track plugins we already saw, so options from
# multi-output plugins don't come up several times
if p in seen:
continue
seen.append(p)
for config in p.takes_config.values():
try:
default = config.get_default()
except ValueError:
default = OMITTED
current = _cached_configs.get(config.name, None)
if isinstance(current, dict):
current = current[self.llh_name]
r.append(
dict(
option=config.name,
default=default,
current=current,
applies_to=p.provides,
help=config.help,
)
)
if len(r):
df = pd.DataFrame(r, columns=r[0].keys())
else:
df = pd.DataFrame([])
# Then you can print the dataframe like:
# straxen.dataframe_to_wiki(df, title=f'{data_names}', float_digits=1)
return df
[docs] def new_component(self, llh_name: Optional[str] = None, pass_binning: bool = True):
"""Generate new component with same binning, usually used on predicting yields."""
if pass_binning:
if hasattr(self, "bins") and hasattr(self, "bins_type"):
component = self.__class__(
name=self.name + "_copy",
llh_name=llh_name,
bins=self.bins,
bins_type=self.bins_type,
)
else:
raise ValueError("Should provide bins and bins_type if you want to pass binning!")
else:
component = self.__class__(
name=self.name + "_copy",
llh_name=llh_name,
)
return component
@property
def lineage(self):
bins_dict = dict()
if hasattr(self, "bins") or hasattr(self, "bins_type"):
bins_dict = {
"bins": (
tuple(b.tolist() for b in self.bins) if self.bins is not None else self.bins
),
"bins_type": self.bins_type,
}
return {
"rate_name": self.rate_name,
"norm_type": self.norm_type,
"code": self.code,
"instances": dict(
zip(
self.instances,
[_cached_functions[self.llh_name][p].lineage for p in self.instances],
)
),
**bins_dict,
}
[docs]@export
class ComponentFixed(Component):
"""Component whose shape is fixed."""
# Do not initialize this class because it is base
__is_base = True
[docs] def __init__(self, *args, **kwargs):
"""Initialization."""
if not kwargs.get("file_name", None):
raise ValueError("Should provide file_name for ComponentFixed!")
else:
self._file_name = kwargs.get("file_name", None)
super().__init__(*args, **kwargs)
[docs] def deduce(self, data_names: Union[List[str], Tuple[str]] = ["cs1", "cs2"]):
"""Deduce the needed parameters and make the fixed histogram."""
self.data = load_data(self._file_name)[list(data_names)].to_numpy()
self.eff = jnp.ones(len(self.data))
self.hist = self.implement_binning(self.data, self.eff)
self.needed_parameters.add(self.rate_name)
[docs] def simulate(self):
"""Fixed component does not need to simulate."""
raise NotImplementedError
[docs] def simulate_hist(self, parameters, *args, **kwargs):
"""Return the fixed histogram."""
normalization_factor = self.get_normalization(self.hist, parameters, len(self.data))
return self.hist * normalization_factor
[docs] def simulate_weighted_data(self, parameters, *args, **kwargs):
"""Simulate and return histogram."""
result = [r for r in self.data.T]
result.append(np.array(self.eff))
# Clip all simulated data points
result = self._clip(result)
normalization_factor = self.get_normalization(self.hist, parameters, len(self.data))
result[-1] *= normalization_factor
return result
@property
def lineage(self):
return {
"rate_name": self.rate_name,
"norm_type": self.norm_type,
"bins": tuple(b.tolist() for b in self.bins) if self.bins is not None else self.bins,
"bins_type": self.bins_type,
"file_path": (
os.path.basename(self._file_name)
if not utils.FULL_PATH_LINEAGE
else get_file_path(self._file_name)
),
"sha256": calculate_sha256(get_file_path(self._file_name)),
}
[docs]@export
def add_component_extensions(module1, module2, force=False):
"""Add components of module2 to module1."""
utils.add_extensions(module1, module2, Component, force=force)
[docs]@export
def _add_component_extension(module, component, force=False):
"""Add component to module."""
utils._add_extension(module, component, Component, force=force)