import os
from typing import Optional, Union, Any
from immutabledict import immutabledict
from jax import numpy as jnp
from warnings import warn
import numpy as np
from strax import deterministic_hash
from appletree import utils
from appletree.share import _cached_configs
from appletree.utils import (
exporter,
load_json,
get_file_path,
integrate_midpoint,
cumulative_integrate_midpoint,
calculate_sha256,
)
from appletree import interpolation
from appletree.interpolation import FLOAT_POS_MIN, FLOAT_POS_MAX
export, __all__ = exporter()
OMITTED = "<OMITTED>"
__all__.extend(["OMITTED"])
[docs]@export
def takes_config(*configs):
"""Decorator for plugin classes, to specify which configs it takes.
Args:
configs: Config instances of configs this plugin takes.
"""
def wrapped(plugin_class):
"""
Args:
plugin_class: plugin needs configuration
"""
result = dict()
for config in configs:
if not isinstance(config, Config):
raise RuntimeError("Specify config options by Config objects")
config.taken_by = plugin_class.__name__
result[config.name] = config
if hasattr(plugin_class, "takes_config") and len(plugin_class.takes_config):
# Already have some configs set, e.g. because of subclassing
# where both child and parent have a takes_config decorator
for config in result.values():
if config.name in plugin_class.takes_config:
raise RuntimeError(f"Attempt to specify config {config.name} twice")
plugin_class.takes_config = immutabledict({**plugin_class.takes_config, **result})
else:
plugin_class.takes_config = immutabledict(result)
# Should set the configurations as the attributes of Plugin
return plugin_class
return wrapped
[docs]@export
class Config:
"""Configuration option taken by a appletree plugin."""
llh_name: Optional[str] = None
[docs] def __init__(
self,
name: str,
type: Union[type, tuple, list, str] = OMITTED,
default: Any = OMITTED,
help: str = "",
):
"""Initialization.
Args:
name: name of the map.
type: Excepted type of the option's value.
default: Default value the option takes.
help: description of the map.
"""
self.name = name
self.type = type
self.default = default
self.help = help
# Sanity check
if isinstance(self.default, dict):
raise ValueError(
f"Do not set {self.name}'s default value as dict!",
)
[docs] def get_default(self):
"""Get default value of configuration."""
if self.default is not OMITTED:
return self.default
raise ValueError(f"Missing option {self.name} required by {self.taken_by}")
[docs] def build(self, llh_name: Optional[str] = None):
"""Build configuration, set attributes to Config instance."""
raise NotImplementedError
[docs] def required_parameter(self, llh_name=None):
return None
@property
def lineage(self):
raise NotImplementedError
@property
def lineage_hash(self):
return deterministic_hash(self.lineage)
[docs]@export
class Constant(Config):
"""Constant is a special config which takes only certain value."""
value = None
[docs] def build(self, llh_name: Optional[str] = None):
"""Set value of Constant."""
if self.name in _cached_configs:
value = _cached_configs[self.name]
else:
value = self.get_default()
# Update values to sharing dictionary
_cached_configs[self.name] = value
if isinstance(value, dict):
try:
self.value = value[llh_name]
except KeyError:
mesg = (
f"You specified {self.name} as a dictionary. "
f"The key of it should be the name of one "
f"of the likelihood, but it is {llh_name}."
)
raise ValueError(mesg)
else:
self.value = value
@property
def lineage(self):
return {
"llh_name": self.llh_name,
"value": self.value,
}
[docs]@export
class Map(Config):
"""Map is a special config that takes input files.
The method ``apply`` is dynamically assigned.
When using points, the ``apply`` will be ``map_point``,
while using regular binning, the ``apply`` will be ``map_regbin``.
When using log-binning, we will first convert the positions to log space.
"""
def __init__(self, method="IDW", **kwargs):
super().__init__(**kwargs)
self.method = method
[docs] def build(self, llh_name: Optional[str] = None):
"""Cache the map to jnp.array."""
if self.name in _cached_configs:
file_path = _cached_configs[self.name]
else:
file_path = get_file_path(self.get_default())
# Update values to sharing dictionary
_cached_configs[self.name] = file_path
if isinstance(file_path, dict):
try:
self.file_path = file_path[llh_name]
except KeyError:
mesg = (
f"You specified {self.name} as a dictionary. "
f"The key of it should be the name of one "
f"of the likelihood, but it is {llh_name}."
)
raise ValueError(mesg)
else:
self.file_path = file_path
# try to find the path first
_file_path = get_file_path(self.file_path)
try:
data = load_json(self.file_path)
except Exception:
raise ValueError(f"Cannot load {self.name} from {_file_path}!")
coordinate_type = data["coordinate_type"]
if coordinate_type == "point" or coordinate_type == "log_point":
self.build_point(data)
elif coordinate_type == "regbin" or coordinate_type == "log_regbin":
self.build_regbin(data)
else:
raise ValueError("map_type must be either 'point' or 'regbin'!")
[docs] def build_point(self, data):
"""Cache the map to jnp.array if bins_type is point."""
if data["coordinate_name"] == "pdf" or data["coordinate_name"] == "cdf":
if data["coordinate_type"] == "log_point":
raise ValueError(
f"It is not a good idea to use log pdf nor cdf "
f"in map {self.file_path}. "
f"Because its coordinate type is log-binned. "
)
if data["coordinate_name"] == "pdf":
warn(f"Convert {self.name} from (x, pdf) to (cdf, x).")
x, cdf = self.pdf_to_cdf(data["coordinate_system"], data["map"])
data["coordinate_name"] = "cdf"
data["coordinate_system"] = cdf
data["map"] = x
self.coordinate_type = data["coordinate_type"]
self.coordinate_name = data["coordinate_name"]
self.coordinate_system = jnp.asarray(data["coordinate_system"], dtype=float)
self.map = jnp.asarray(data["map"], dtype=float)
if self.method == "IDW":
setattr(self, "interpolator", interpolation.curve_interpolator)
elif self.method == "NN":
setattr(
self,
"interpolator",
interpolation.map_interpolator_nearest_neighbor_1d,
)
elif self.method == "LERP":
setattr(
self,
"interpolator",
interpolation.map_interpolator_linear_1d,
)
else:
raise ValueError(f"Unknown method {self.method} for 1D regular binning.")
if self.coordinate_type == "log_point":
if jnp.any(self.coordinate_system <= 0):
raise ValueError(
f"Find non-positive coordinate system in map {self.file_path}, "
f"which is specified as {self.coordinate_type}"
)
setattr(self, "preprocess", self.log_pos)
else:
setattr(self, "preprocess", self.linear_pos)
setattr(self, "apply", self.map_point)
[docs] def map_point(self, pos):
val = self.interpolator(
self.preprocess(pos),
self.preprocess(self.coordinate_system),
self.preprocess(self.map),
)
return val
[docs] def build_regbin(self, data):
"""Cache the map to jnp.array if bins_type is regbin."""
if "pdf" in data["coordinate_name"] or "cdf" in data["coordinate_name"]:
if data["coordinate_type"] == "log_regbin":
raise ValueError(
f"It is not a good idea to use log pdf nor cdf "
f"in map {self.file_path}. "
f"Because its coordinate type is log-binned. "
)
self.coordinate_type = data["coordinate_type"]
self.coordinate_name = data["coordinate_name"]
self.coordinate_lowers = jnp.asarray(data["coordinate_lowers"], dtype=float)
self.coordinate_uppers = jnp.asarray(data["coordinate_uppers"], dtype=float)
self.map = jnp.asarray(data["map"], dtype=float)
if len(self.coordinate_lowers) == 1:
setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_1d)
elif len(self.coordinate_lowers) == 2:
if self.method == "IDW":
setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_2d)
elif self.method == "NN":
setattr(
self,
"interpolator",
interpolation.map_interpolator_regular_binning_nearest_neighbor_2d,
)
else:
raise ValueError(f"Unknown method {self.method} for 2D regular binning.")
elif len(self.coordinate_lowers) == 3:
if self.method == "IDW":
setattr(self, "interpolator", interpolation.map_interpolator_regular_binning_3d)
elif self.method == "NN":
setattr(
self,
"interpolator",
interpolation.map_interpolator_regular_binning_nearest_neighbor_3d,
)
else:
raise ValueError(f"Unknown method {self.method} for 3D regular binning.")
if self.coordinate_type == "log_regbin":
if jnp.any(self.coordinate_lowers <= 0) or jnp.any(self.coordinate_uppers <= 0):
raise ValueError(
f"Find non-positive coordinate system in map {self.file_path}, "
f"which is specified as {self.coordinate_type}"
)
setattr(self, "preprocess", self.log_pos)
else:
setattr(self, "preprocess", self.linear_pos)
setattr(self, "apply", self.map_regbin)
[docs] def map_regbin(self, pos):
val = self.interpolator(
self.preprocess(pos),
self.preprocess(self.coordinate_lowers),
self.preprocess(self.coordinate_uppers),
self.map,
)
return val
[docs] def linear_pos(self, pos):
return pos
[docs] def log_pos(self, pos):
return jnp.log10(jnp.clip(pos, a_min=FLOAT_POS_MIN, a_max=FLOAT_POS_MAX))
[docs] def pdf_to_cdf(self, x, pdf):
"""Convert pdf map to cdf map."""
norm = integrate_midpoint(x, pdf)
x, cdf = cumulative_integrate_midpoint(x, pdf)
cdf /= norm
return x, cdf
@property
def lineage(self):
return {
"llh_name": self.llh_name,
"method": self.method,
"file_path": (
os.path.basename(self.file_path)
if not utils.FULL_PATH_LINEAGE
else get_file_path(self.file_path)
),
"sha256": calculate_sha256(get_file_path(self.file_path)),
}
[docs]@export
class SigmaMap(Config):
"""Maps with uncertainty.
The value of a SigmaMap can be:
* a list with four elements,
which are the file names of median, lower, upper maps and the name of the scaler.
* a list with three elements,
which are the file names of median, lower and upper maps. The name of the scaler
is the default one f"{self.name}_sigma".
* a string,
which is the file name of the map for median, lower, upper.
In the first and second case, the name of the scaler will appear
in Component.needed_parameters.
"""
def __init__(self, method="IDW", **kwargs):
super().__init__(**kwargs)
self.method = method
[docs] def build(self, llh_name: Optional[str] = None):
"""Read maps."""
self.llh_name = llh_name
_configs = self.get_configs()
_configs_default = self.get_default()
if isinstance(_configs, list) and len(_configs) > 4:
raise ValueError(f"You give too much information in {self.name}'s configs.")
if isinstance(_configs_default, list) and len(_configs_default) > 4:
raise ValueError(f"You give too much information in {self.name}'s default configs.")
maps = dict()
sigmas = ["median", "lower", "upper"]
for i, sigma in enumerate(sigmas):
# propagate _configs_default to Map instances
if isinstance(_configs_default, list):
default = _configs_default[i]
else:
if not isinstance(_configs_default, str):
raise ValueError(
f"If {self.name}'s default configuration is not a list, "
"then it should be a string."
)
# If only one file is given, then use the same file for all sigmas
default = _configs_default
maps[sigma] = Map(method=self.method, name=self.name + f"_{sigma}", default=default)
setattr(self, sigma, maps[sigma])
if self.llh_name is None:
# if llh_name is not specified, no need to update _cached_configs
continue
# In case some plugins only use the median
# and may already update the map name in `_cached_configs`
if maps[sigma].name not in _cached_configs.keys():
_cached_configs[maps[sigma].name] = dict()
if isinstance(_cached_configs[maps[sigma].name], dict):
if isinstance(_configs, list):
value = _configs[i]
else:
if not isinstance(_configs, str):
raise ValueError(
f"If {self.name}'s configuration is not a list, "
"then it should be a string."
)
# If only one file is given, then use the same file for all sigmas
value = _configs
_value = _cached_configs[maps[sigma].name].get(self.llh_name, value)
if _value != value:
raise ValueError(
f"You give different values for {self.name} in "
f"configs, find {_value} and {value}."
)
_cached_configs[maps[sigma].name].update({self.llh_name: value})
self.median.build(llh_name=self.llh_name) # type: ignore
self.lower.build(llh_name=self.llh_name) # type: ignore
self.upper.build(llh_name=self.llh_name) # type: ignore
required_parameter = self.required_parameter()
if required_parameter is not None:
print(
f"{self.llh_name}'s map {self.name} is using "
f"the parameter {required_parameter}."
)
else:
print(f"{self.llh_name}'s map {self.name} is static and not using any parameter.")
[docs] def get_configs(self, llh_name=None):
"""Get configs of SigmaMap."""
# if llh_name is not specified, use the attribute of SigmaMap
if llh_name is None and self.llh_name is not None:
llh_name = self.llh_name
if self.name in _cached_configs:
_configs = _cached_configs[self.name]
else:
_configs = self.get_default()
# Update values to sharing dictionary
_cached_configs[self.name] = _configs
if isinstance(_configs, dict):
if llh_name is None:
raise ValueError(
f"You specified {self.name} as a dictionary in _cached_configs. "
"The key of it should be the name of one of the likelihood, but it is None."
)
try:
return _configs[llh_name]
except KeyError:
mesg = (
f"You specified {self.name} as a dictionary. "
f"The key of it should be the name of one "
f"of the likelihood, but it is {llh_name}."
)
raise ValueError(mesg)
else:
return _configs
[docs] def required_parameter(self, llh_name=None):
"""Get required parameter of SigmaMap."""
_configs = self.get_configs(llh_name=llh_name)
# Find required parameter
if isinstance(_configs, list):
if len(_configs) == 4:
return _configs[-1]
else:
return self.name + "_sigma"
else:
return None
[docs] def apply(self, pos, parameters):
"""Apply SigmaMap with sigma and position."""
if self.required_parameter() is None:
sigma = 1.0
else:
sigma = parameters[self.required_parameter()]
median = self.median.apply(pos)
lower = self.lower.apply(pos)
upper = self.upper.apply(pos)
add_pos = (upper - median) * sigma
add_neg = (median - lower) * sigma
add = jnp.where(sigma > 0, add_pos, add_neg)
return median + add
@property
def lineage(self):
return {
"llh_name": self.llh_name,
"method": self.method,
"median": self.median.lineage,
"lower": self.lower.lineage,
"upper": self.upper.lineage,
}
[docs]@export
class ConstantSet(Config):
"""ConstantSet is a special config which takes a set of values.
We will not specify any hard-coded distribution or function here. User should be careful with
the actual function implemented. Fortunately, we only use these values as keyword arguments, so
mismatch will be catched when running.
"""
[docs] def build(self, llh_name: Optional[str] = None):
"""Set value of Constant."""
if self.name in _cached_configs:
value = _cached_configs[self.name]
else:
value = self.get_default()
# Update values to sharing dictionary
_cached_configs[self.name] = value
if isinstance(value, dict):
try:
self.value = value[llh_name]
except KeyError:
mesg = (
f"You specified {self.name} as a dictionary. "
f"The key of it should be the name of one "
f"of the likelihood, but it is {llh_name}."
)
raise ValueError(mesg)
else:
self.value = value
self._sanity_check()
self.set_volume = len(self.value[1][0])
self.value = {k: np.array(v) for k, v in zip(*self.value)}
[docs] def _sanity_check(self):
"""Check if parameter set lengths are same."""
mesg = "The given values should follow [names, values] format."
assert len(self.value) == 2, mesg
mesg = "Parameters and their names should have same length"
assert len(self.value[0]) == len(self.value[1]), mesg
volumes = [len(v) for v in self.value[1]]
mesg = "Parameter set lengths should be the same"
assert np.all(np.isclose(volumes, volumes[0])), mesg
@property
def lineage(self):
return {
"llh_name": self.llh_name,
"value": self.value,
}