import os
from typing import Optional, Union, Any
from immutabledict import immutabledict
from jax import numpy as jnp
from warnings import warn
import numpy as np
from strax import deterministic_hash
from appletree import utils
from appletree.share import _cached_configs
from appletree.utils import (
exporter,
load_json,
get_file_path,
integrate_midpoint,
cumulative_integrate_midpoint,
calculate_sha256,
)
from appletree import interpolation
from appletree.interpolation import FLOAT_POS_MIN, FLOAT_POS_MAX
# ``exporter()`` hands back the ``export`` decorator applied to the public
# definitions below, together with this module's ``__all__`` list.
# NOTE(review): presumably ``export`` appends the decorated name to ``__all__``;
# confirm in ``appletree.utils``.
export, __all__ = exporter()
# Sentinel marking "no value supplied" for a Config's ``type`` and ``default``;
# ``Config.get_default`` compares against it by identity.
OMITTED = "<OMITTED>"
# A plain constant cannot be decorated with ``export``, so it is added to
# ``__all__`` by hand.
__all__.extend(["OMITTED"])
@export
def takes_config(*configs):
    """Class decorator declaring the configs a plugin takes.

    Args:
        configs: Config instances the decorated plugin needs.

    Returns:
        A decorator which stores the configs, keyed by name, in the plugin
        class attribute ``takes_config`` (an immutabledict) and returns the
        class itself.

    Raises:
        RuntimeError: if an argument is not a Config instance, or if a config
            name is already present in the class's existing ``takes_config``.

    """

    def wrapped(plugin_class):
        """Attach the collected configs to ``plugin_class``.

        Args:
            plugin_class: plugin class which needs configuration.

        """
        collected = {}
        for option in configs:
            if not isinstance(option, Config):
                raise RuntimeError("Specify config options by Config objects")
            # Remember which plugin declared the option.
            option.taken_by = plugin_class.__name__
            collected[option.name] = option

        has_inherited = hasattr(plugin_class, "takes_config") and len(plugin_class.takes_config)
        if not has_inherited:
            plugin_class.takes_config = immutabledict(collected)
            return plugin_class

        # Some configs are already set, e.g. because of subclassing where
        # both child and parent carry a takes_config decorator.
        for config in collected.values():
            if config.name in plugin_class.takes_config:
                raise RuntimeError(f"Attempt to specify config {config.name} twice")
        merged = {**plugin_class.takes_config, **collected}
        plugin_class.takes_config = immutabledict(merged)
        return plugin_class

    return wrapped
@export
class Config:
    """Configuration option taken by an appletree plugin."""

    # Name of the likelihood the config is built for; None until assigned.
    llh_name: Optional[str] = None
    # Name of the plugin class that declared this config. ``takes_config``
    # assigns it; the class-level default keeps ``get_default`` from failing
    # with AttributeError on a Config that was never attached to a plugin.
    taken_by: Optional[str] = None

    def __init__(
        self,
        name: str,
        type: Union[type, tuple, list, str] = OMITTED,
        default: Any = OMITTED,
        help: str = "",
    ):
        """Initialization.

        Args:
            name: name of the option.
            type: Expected type of the option's value.
            default: Default value the option takes.
            help: description of the option.

        Raises:
            ValueError: if ``default`` is a dict. Dict values are interpreted
                as per-likelihood mappings when resolved, so they are not
                accepted as defaults.

        """
        self.name = name
        self.type = type
        self.default = default
        self.help = help

        # Sanity check
        if isinstance(self.default, dict):
            raise ValueError(
                f"Do not set {self.name}'s default value as dict!",
            )

    def get_default(self):
        """Get default value of configuration.

        Raises:
            ValueError: if no default was given.

        """
        if self.default is not OMITTED:
            return self.default
        raise ValueError(f"Missing option {self.name} required by {self.taken_by}")

    def build(self, llh_name: Optional[str] = None):
        """Build configuration, set attributes to Config instance.

        Must be implemented by subclasses.

        """
        raise NotImplementedError

    def _resolve_cached_config(self, llh_name, transform=None):
        """Look up config in cache, apply transform, resolve per-likelihood dict.

        If the option is not in ``_cached_configs`` yet, its default value is
        (optionally transformed and) stored there, so the shared cache always
        holds the value in use.

        Args:
            llh_name: name of the likelihood component.
            transform: optional callable applied to the default value
                before caching (e.g. get_file_path for Map configs).
                It is NOT applied to values already present in the cache.

        Returns:
            The cached value itself, or, if the cached value is a dict,
            its entry for ``llh_name``.

        Raises:
            ValueError: if the option has no default and is not cached, or if
                the cached value is a dict without the key ``llh_name``.

        """
        if self.name in _cached_configs:
            value = _cached_configs[self.name]
        else:
            value = self.get_default()
            if transform is not None:
                value = transform(value)
            _cached_configs[self.name] = value

        if not isinstance(value, dict):
            return value

        # A dict is a per-likelihood specification.
        try:
            return value[llh_name]
        except KeyError:
            # The message already names the offending key, so the KeyError
            # context is suppressed.
            raise ValueError(
                f"You specified {self.name} as a dictionary. "
                f"The key of it should be the name of one "
                f"of the likelihood, but it is {llh_name}."
            ) from None

    def required_parameter(self, llh_name=None):
        """Name of the parameter this config needs; None for the base class."""
        return None

    @property
    def lineage(self):
        """Description of the config used for hashing; set by subclasses."""
        raise NotImplementedError

    @property
    def lineage_hash(self):
        """Deterministic hash of ``lineage``."""
        return deterministic_hash(self.lineage)
@export
class Constant(Config):
    """Constant is a special config which takes only certain value."""

    value = None

    def build(self, llh_name: Optional[str] = None):
        """Set value of Constant.

        Args:
            llh_name: name of the likelihood; used as the key when the
                configured value is a per-likelihood dict.

        """
        # Record the likelihood so that ``lineage`` reports the one the
        # value was actually resolved for instead of the class default None.
        self.llh_name = llh_name
        self.value = self._resolve_cached_config(llh_name)

    @property
    def lineage(self):
        """Likelihood name and resolved value of the constant."""
        return {
            "llh_name": self.llh_name,
            "value": self.value,
        }
@export
class Map(Config):
    """Map is a special config that takes input files.

    The method ``apply`` is dynamically assigned.
    When using points, the ``apply`` will be ``map_point``,
    while using regular binning, the ``apply`` will be ``map_regbin``.
    When using log-binning, we will first convert the positions to log space.

    """

    # Interpolators available for point-like maps, keyed by method name.
    _POINT_INTERPOLATORS = {
        "IDW": interpolation.curve_interpolator,
        "NN": interpolation.map_interpolator_nearest_neighbor_1d,
        "LERP": interpolation.map_interpolator_linear_1d,
    }

    # Interpolators for 2D/3D regular binning, keyed by (ndim, method).
    # 1D regular binning is handled separately in _get_regbin_interpolator.
    _REGBIN_INTERPOLATORS = {
        (2, "IDW"): interpolation.map_interpolator_regular_binning_2d,
        (2, "NN"): interpolation.map_interpolator_regular_binning_nearest_neighbor_2d,
        (3, "IDW"): interpolation.map_interpolator_regular_binning_3d,
        (3, "NN"): interpolation.map_interpolator_regular_binning_nearest_neighbor_3d,
    }

    def __init__(self, method="IDW", **kwargs):
        """Initialization.

        Args:
            method: interpolation method name, a key of the interpolator tables.
            **kwargs: forwarded to ``Config.__init__``.

        """
        super().__init__(**kwargs)
        self.method = method

    def build(self, llh_name: Optional[str] = None):
        """Cache the map to jnp.array.

        Args:
            llh_name: name of the likelihood; used as the key when the
                configured file is a per-likelihood dict.

        Raises:
            ValueError: if the file cannot be loaded or its
                ``coordinate_type`` is not supported.

        """
        # Record the likelihood so that ``lineage`` reports it.
        self.llh_name = llh_name
        self.file_path = self._resolve_cached_config(
            llh_name,
            transform=get_file_path,
        )

        # try to find the path first
        _file_path = get_file_path(self.file_path)
        try:
            data = load_json(self.file_path)
        except Exception as err:
            # Keep the underlying cause attached for debugging.
            raise ValueError(f"Cannot load {self.name} from {_file_path}!") from err

        coordinate_type = data["coordinate_type"]
        if isinstance(coordinate_type, list):
            # Per-axis specification is only supported for regular binning.
            for ct in coordinate_type:
                if ct not in ("regbin", "log_regbin"):
                    raise ValueError(
                        f"Per-axis coordinate_type entries must be 'regbin' or "
                        f"'log_regbin', got '{ct}'"
                    )
            self.build_regbin(data)
        elif coordinate_type in ("point", "log_point"):
            self.build_point(data)
        elif coordinate_type in ("regbin", "log_regbin"):
            self.build_regbin(data)
        else:
            raise ValueError("map_type must be either 'point' or 'regbin'!")

    def build_point(self, data):
        """Cache the map to jnp.array if bins_type is point.

        A map given as ``(x, pdf)`` is converted in place to ``(cdf, x)``.

        Args:
            data: content of the map file.

        Raises:
            ValueError: for log-scaled pdf/cdf maps, an unknown interpolation
                method, or non-positive coordinates of a log-scaled map.

        """
        if data["coordinate_name"] == "pdf" or data["coordinate_name"] == "cdf":
            if data["coordinate_type"] == "log_point":
                raise ValueError(
                    f"It is not a good idea to use log pdf nor cdf "
                    f"in map {self.file_path}. "
                    f"Because its coordinate type is log-binned. "
                )

        if data["coordinate_name"] == "pdf":
            warn(f"Convert {self.name} from (x, pdf) to (cdf, x).")
            x, cdf = self.pdf_to_cdf(data["coordinate_system"], data["map"])
            data["coordinate_name"] = "cdf"
            data["coordinate_system"] = cdf
            data["map"] = x

        self.coordinate_type = data["coordinate_type"]
        self.coordinate_name = data["coordinate_name"]
        self.coordinate_system = jnp.asarray(data["coordinate_system"], dtype=float)
        self.map = jnp.asarray(data["map"], dtype=float)

        # A point map has a single axis, hence a one-element flag list.
        self._is_log_axis = ["log" in data["coordinate_type"]]
        self._log_mask = jnp.array(self._is_log_axis)

        if self.method not in self._POINT_INTERPOLATORS:
            raise ValueError(f"Unknown method {self.method} for point interpolation.")
        self.interpolator = self._POINT_INTERPOLATORS[self.method]

        self._validate_log_coords(self.coordinate_system)
        self.apply = self.map_point

    def map_point(self, pos):
        """Interpolate a point map at ``pos``."""
        # NOTE(review): the map values go through ``preprocess`` as well, so
        # for a log_point map the interpolation presumably runs on log10 of
        # the map values and nothing here converts the result back -- confirm
        # this is intended.
        val = self.interpolator(
            self.preprocess(pos),
            self.preprocess(self.coordinate_system),
            self.preprocess(self.map),
        )
        return val

    def build_regbin(self, data):
        """Cache the map to jnp.array if bins_type is regbin.

        ``coordinate_type`` may be a single string (``"regbin"`` or
        ``"log_regbin"``) applied to every axis, or a list of such
        strings with one entry per axis for mixed linear/log grids
        (e.g. ``["regbin", "log_regbin"]``).

        Args:
            data: content of the map file.

        Raises:
            ValueError: if a ``coordinate_type`` list does not match the
                number of axes, a pdf/cdf axis is log-scaled, the method is
                unknown for this dimensionality, or a log-scaled axis has
                non-positive bounds.

        """
        coordinate_type = data["coordinate_type"]
        ndim = len(data["coordinate_lowers"])

        # Determine per-axis log flags
        if isinstance(coordinate_type, list):
            if len(coordinate_type) != ndim:
                raise ValueError(
                    f"coordinate_type list length ({len(coordinate_type)}) "
                    f"must match number of axes ({ndim})"
                )
            self._is_log_axis = [ct == "log_regbin" for ct in coordinate_type]
        elif coordinate_type == "log_regbin":
            self._is_log_axis = [True] * ndim
        else:
            self._is_log_axis = [False] * ndim

        # Validate: don't use log-scaled pdf/cdf axes.
        # A bare string is treated as a single axis name; iterating it
        # directly would compare single characters and never match.
        axis_names = data["coordinate_name"]
        if isinstance(axis_names, str):
            axis_names = [axis_names]
        for name, is_log in zip(axis_names, self._is_log_axis):
            if is_log and name in ("pdf", "cdf"):
                raise ValueError(
                    f"It is not a good idea to use log pdf nor cdf "
                    f"in map {self.file_path}. "
                    f"Because its coordinate type is log-binned. "
                )

        self.coordinate_type = coordinate_type
        self.coordinate_name = data["coordinate_name"]
        self.coordinate_lowers = jnp.asarray(data["coordinate_lowers"], dtype=float)
        self.coordinate_uppers = jnp.asarray(data["coordinate_uppers"], dtype=float)
        self.map = jnp.asarray(data["map"], dtype=float)
        self._log_mask = jnp.array(self._is_log_axis)

        self.interpolator = self._get_regbin_interpolator(ndim)

        self._validate_log_coords(self.coordinate_lowers, self.coordinate_uppers)
        self.apply = self.map_regbin

    def _validate_log_coords(self, *coord_arrays):
        """Raise if any log-scaled coordinate is non-positive.

        Args:
            coord_arrays: coordinate arrays to check.

        Raises:
            ValueError: if a coordinate on a log-scaled axis is <= 0.

        """
        if not any(self._is_log_axis):
            return

        mesg = (
            f"Find non-positive coordinate system in map "
            f"{self.file_path}, "
            f"which is specified as {self.coordinate_type}"
        )
        single_axis = len(self._is_log_axis) == 1
        for a in coord_arrays:
            if a.ndim == 0 or (a.ndim == 1 and single_axis):
                # Point map or 1D regbin: check entire array
                if jnp.any(a <= 0):
                    raise ValueError(mesg)
            else:
                # Multi-D regbin: check only log axes
                for i, is_log in enumerate(self._is_log_axis):
                    if is_log and jnp.any(a[i] <= 0):
                        raise ValueError(mesg)

    def _get_regbin_interpolator(self, ndim):
        """Return the regular-binning interpolator for the given dimensionality.

        Raises:
            ValueError: if no interpolator is registered for
                ``(ndim, self.method)``.

        """
        if ndim == 1:
            # The 1D interpolator is used whatever the method is.
            return interpolation.map_interpolator_regular_binning_1d
        key = (ndim, self.method)
        try:
            return self._REGBIN_INTERPOLATORS[key]
        except KeyError:
            raise ValueError(
                f"Unknown method {self.method} " f"for {ndim}D regular binning."
            ) from None

    def map_regbin(self, pos):
        """Interpolate a regular-binning map at ``pos``."""
        val = self.interpolator(
            self.preprocess(pos),
            self.preprocess(self.coordinate_lowers),
            self.preprocess(self.coordinate_uppers),
            self.map,
        )
        return val

    def preprocess(self, pos):
        """Apply log10 to axes marked as log in ``_log_mask``."""
        # Selectively transform only log axes so that linear axes with
        # non-positive values are never passed through log10/clip, which
        # would corrupt forward values or gradients.
        safe_pos = jnp.where(self._log_mask, pos, 1.0)
        log_vals = jnp.log10(jnp.clip(safe_pos, FLOAT_POS_MIN, FLOAT_POS_MAX))
        return jnp.where(self._log_mask, log_vals, pos)

    def pdf_to_cdf(self, x, pdf):
        """Convert pdf map to cdf map.

        Returns:
            Tuple ``(x, cdf)`` with the cdf divided by the integral of the pdf.

        """
        norm = integrate_midpoint(x, pdf)
        x, cdf = cumulative_integrate_midpoint(x, pdf)
        cdf /= norm
        return x, cdf

    @property
    def lineage(self):
        """Likelihood name, method, file name (or full path) and file hash."""
        return {
            "llh_name": self.llh_name,
            "method": self.method,
            "file_path": (
                os.path.basename(self.file_path)
                if not utils.FULL_PATH_LINEAGE
                else get_file_path(self.file_path)
            ),
            "sha256": calculate_sha256(get_file_path(self.file_path)),
        }
@export
class SigmaMap(Config):
    """Maps with uncertainty.

    The value of a SigmaMap can be:
      * a list with four elements,
        which are the file names of median, lower, upper maps and the name of the scaler.
      * a list with three elements,
        which are the file names of median, lower and upper maps. The name of the scaler
        is the default one f"{self.name}_sigma".
      * a string,
        which is the file name of the map for median, lower, upper.

    In the first and second case, the name of the scaler will appear
    in Component.needed_parameters.

    """

    median: Map
    lower: Map
    upper: Map

    def __init__(self, method="IDW", **kwargs):
        """Initialization.

        Args:
            method: interpolation method handed to the three underlying maps.
            **kwargs: forwarded to ``Config.__init__``.

        """
        super().__init__(**kwargs)
        self.method = method

    def _resolve_sigma_value(self, configs, index, label):
        """Extract the i-th sigma value from a list-or-string config.

        Args:
            configs: list of file paths or a single file path string.
            index: which sigma (0=median, 1=lower, 2=upper).
            label: "configuration" or "default configuration", only used
                in error messages.

        Returns:
            ``configs[index]`` for a list, ``configs`` itself for a string.

        Raises:
            ValueError: if ``configs`` is a list with fewer than three
                entries, or is neither a list nor a string.

        """
        if isinstance(configs, list):
            # Median, lower and upper must all be present; fail with a clear
            # message instead of a bare IndexError.
            if len(configs) < 3:
                raise ValueError(
                    f"If {self.name}'s {label} is a list, it should contain "
                    f"at least the median, lower and upper maps, "
                    f"but it has {len(configs)} element(s)."
                )
            return configs[index]
        if not isinstance(configs, str):
            raise ValueError(
                f"If {self.name}'s {label} is not a list, " "then it should be a string."
            )
        return configs

    def _update_sigma_cache(self, map_name, value):
        """Update _cached_configs for a sigma sub-map, checking for conflicts.

        Raises:
            ValueError: if a different value is already cached for this
                sub-map and likelihood.

        """
        # In case some plugins only use the median
        # and may already update the map name in `_cached_configs`
        if map_name not in _cached_configs:
            _cached_configs[map_name] = dict()
        # A non-dict entry is left untouched.
        if isinstance(_cached_configs[map_name], dict):
            existing = _cached_configs[map_name].get(
                self.llh_name,
                value,
            )
            if existing != value:
                raise ValueError(
                    f"You give different values for {self.name} in "
                    f"configs, find {existing} and {value}."
                )
            _cached_configs[map_name][self.llh_name] = value

    def build(self, llh_name: Optional[str] = None):
        """Read maps.

        Creates the ``median``, ``lower`` and ``upper`` Map configs, registers
        their files in the shared cache and builds them.

        Args:
            llh_name: name of the likelihood this SigmaMap belongs to.

        Raises:
            ValueError: if the (default) configuration is malformed.

        """
        self.llh_name = llh_name
        _configs = self.get_configs()
        _configs_default = self.get_default()

        if isinstance(_configs, list) and len(_configs) > 4:
            raise ValueError(f"You give too much information in {self.name}'s configs.")
        if isinstance(_configs_default, list) and len(_configs_default) > 4:
            raise ValueError(f"You give too much information in {self.name}'s default configs.")

        sigmas = ["median", "lower", "upper"]
        for i, sigma in enumerate(sigmas):
            default = self._resolve_sigma_value(
                _configs_default,
                i,
                "default configuration",
            )
            m = Map(
                method=self.method,
                name=f"{self.name}_{sigma}",
                default=default,
            )
            setattr(self, sigma, m)

            if self.llh_name is not None:
                value = self._resolve_sigma_value(
                    _configs,
                    i,
                    "configuration",
                )
                self._update_sigma_cache(m.name, value)

        self.median.build(llh_name=self.llh_name)
        self.lower.build(llh_name=self.llh_name)
        self.upper.build(llh_name=self.llh_name)

        required_parameter = self.required_parameter()
        if required_parameter is not None:
            print(
                f"{self.llh_name}'s map {self.name} is using "
                f"the parameter {required_parameter}."
            )
        else:
            print(f"{self.llh_name}'s map {self.name} is static " f"and not using any parameter.")

    def get_configs(self, llh_name=None):
        """Get configs of SigmaMap.

        Args:
            llh_name: name of the likelihood; falls back to ``self.llh_name``.

        Returns:
            The cached (or default) value, resolved for ``llh_name`` if it is
            a per-likelihood dict.

        Raises:
            ValueError: if the value is a dict and ``llh_name`` is None or
                not one of its keys.

        """
        # if llh_name is not specified, use the attribute of SigmaMap
        if llh_name is None and self.llh_name is not None:
            llh_name = self.llh_name

        if self.name in _cached_configs:
            _configs = _cached_configs[self.name]
        else:
            _configs = self.get_default()
            # Update values to sharing dictionary
            _cached_configs[self.name] = _configs

        if not isinstance(_configs, dict):
            return _configs

        if llh_name is None:
            raise ValueError(
                f"You specified {self.name} as a dictionary in _cached_configs. "
                "The key of it should be the name of one of the likelihood, but it is None."
            )
        try:
            return _configs[llh_name]
        except KeyError:
            mesg = (
                f"You specified {self.name} as a dictionary. "
                f"The key of it should be the name of one "
                f"of the likelihood, but it is {llh_name}."
            )
            raise ValueError(mesg) from None

    def required_parameter(self, llh_name=None):
        """Get required parameter of SigmaMap.

        Returns:
            The fourth element of a four-element list config,
            ``f"{self.name}_sigma"`` for any other list config, and None
            when the config is not a list.

        """
        _configs = self.get_configs(llh_name=llh_name)
        # Find required parameter
        if isinstance(_configs, list):
            if len(_configs) == 4:
                return _configs[-1]
            return self.name + "_sigma"
        return None

    def apply(self, pos, parameters):
        """Apply SigmaMap with sigma and position.

        Args:
            pos: positions handed to the underlying maps.
            parameters: mapping from parameter name to value; only read when
                the map requires a parameter.

        Returns:
            ``median + (upper - median) * sigma`` for positive sigma and
            ``median + (median - lower) * sigma`` otherwise.

        """
        # Resolve the parameter name once; each call walks the config cache.
        parameter_name = self.required_parameter()
        sigma = 1.0 if parameter_name is None else parameters[parameter_name]

        median = self.median.apply(pos)
        lower = self.lower.apply(pos)
        upper = self.upper.apply(pos)
        add_pos = (upper - median) * sigma
        add_neg = (median - lower) * sigma
        add = jnp.where(sigma > 0, add_pos, add_neg)
        return median + add

    @property
    def lineage(self):
        """Likelihood name, method and the lineages of the three maps."""
        return {
            "llh_name": self.llh_name,
            "method": self.method,
            "median": self.median.lineage,
            "lower": self.lower.lineage,
            "upper": self.upper.lineage,
        }
@export
class ConstantSet(Config):
    """ConstantSet is a special config which takes a set of values.

    We will not specify any hard-coded distribution or function here. User should be careful with
    the actual function implemented. Fortunately, we only use these values as keyword arguments, so
    mismatch will be caught when running.

    """

    # None until ``build`` is called; afterwards a dict of numpy arrays
    # keyed by parameter name.
    value = None

    def build(self, llh_name: Optional[str] = None):
        """Set value of ConstantSet.

        The configured value must follow the ``[names, values]`` format. After
        building, ``value`` maps each name to a numpy array of its values and
        ``set_volume`` holds the common length of the value lists.

        Args:
            llh_name: name of the likelihood; used as the key when the
                configured value is a per-likelihood dict.

        Raises:
            AssertionError: if the configured value is malformed.

        """
        # Record the likelihood so that ``lineage`` reports it.
        self.llh_name = llh_name
        self.value = self._resolve_cached_config(llh_name)
        self._sanity_check()
        names, values = self.value
        self.set_volume = len(values[0])
        self.value = {k: np.array(v) for k, v in zip(names, values)}

    def _sanity_check(self):
        """Check if parameter set lengths are same.

        The checks raise AssertionError explicitly rather than through
        ``assert`` statements, so they keep the same exception type but are
        not stripped when Python runs with ``-O``.

        Raises:
            AssertionError: if ``value`` is not ``[names, values]``, names and
                values differ in length, no parameter is given, or the value
                lists have different lengths.

        """
        if len(self.value) != 2:
            raise AssertionError("The given values should follow [names, values] format.")
        names, values = self.value
        if len(names) != len(values):
            raise AssertionError("Parameters and their names should have same length")
        if len(values) == 0:
            # Without this check an empty set surfaces as a bare IndexError.
            raise AssertionError("At least one parameter should be given")
        volumes = [len(v) for v in values]
        if any(volume != volumes[0] for volume in volumes):
            raise AssertionError("Parameter set lengths should be the same")

    @property
    def lineage(self):
        """Likelihood name and the built value of the set."""
        return {
            "llh_name": self.llh_name,
            "value": self.value,
        }