import functools
import inspect
import os
from copy import deepcopy
from typing import Type, Dict, Set, Optional, cast
from warnings import warn

import numpy as np
from scipy.stats import norm
from strax import deterministic_hash

from appletree import utils
from appletree import randgen
from appletree.hist import make_hist_mesh_grid, make_hist_irreg_bin_1d, make_hist_irreg_bin_2d
from appletree.utils import (
    get_file_path,
    load_data,
    get_equiprob_bins_1d,
    get_equiprob_bins_2d,
    calculate_sha256,
)
from appletree.component import Component, ComponentSim, ComponentFixed
from appletree.randgen import TwoHalfNorm, BandTwoHalfNorm
def need_replacing_alias(func):
    """Decorator to replace alias in parameters from key to value.

    The decorated method must accept an argument named ``parameters``. Before
    ``func`` runs, that argument is passed through ``self.replace_alias``, so the
    method body only sees the renamed (value) parameter names.

    Args:
        func: method whose signature contains ``self`` and ``parameters``.

    Returns:
        The wrapped method. ``functools.wraps`` preserves ``func``'s name,
        docstring and signature (important for generated documentation).

    Raises:
        ValueError: at call time, if ``parameters`` is not among the bound arguments.
    """
    # The signature never changes, so compute it once per decorated function
    # instead of on every call (these methods sit on the fitting hot path).
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        # Fill in defaults so 'parameters' is present even when not passed explicitly
        bound.apply_defaults()
        if "parameters" not in bound.arguments:
            raise ValueError(f"'parameters' must be in the arguments of {func.__name__}!")
        # Replace alias in 'parameters', then call with the (possibly) modified arguments
        bound.arguments["parameters"] = self.replace_alias(bound.arguments["parameters"])
        return func(*bound.args, **bound.kwargs)

    return wrapper
class Likelihood:
    """Combine all components (e.g. ER, AC, Wall), and calculate the binned log-likelihood.

    The data histogram is built once at construction. Each call to
    :meth:`get_log_likelihood` sums the histograms of all registered components
    into a model histogram and compares it with the data via a Poisson likelihood.
    """

    def __init__(self, name: Optional[str] = None, **config):
        """Create an appletree likelihood.

        Args:
            name: name of the likelihood, defaults to the class name.
            config: Dictionary with configuration options that will be applied, should include:
                * data_file_name: the data used in fitting, usually calibration data
                * bins_type: either meshgrid, equiprob or irreg
                * bins_on: observables where we will perform inference on, usually [cs1, cs2]
                * bins: number of bins (int) or explicit bin edges, one entry per observable
                * clip (1D) or x_clip, y_clip (2D): ROI of the fitting, should be list of
                  lower and upper boundary
                * parameter_alias: alias of parameters that will be renamed from key to value

        Raises:
            ValueError: if bins_on is neither a str nor a list.
            AssertionError: if bins is not a list for a fit given bins_on as a list.
            RuntimeError: if the numbers of bin specs and observables differ.
        """
        self.name = self.__class__.__name__ if name is None else name
        self.components: Dict[str, Component] = dict()
        self._config = config
        self._data_file_name = config["data_file_name"]
        self._bins_type = config["bins_type"]
        self._bins_on = config["bins_on"]
        self._bins = config["bins"]
        self._parameter_alias = config.get("parameter_alias", dict())

        if isinstance(self._bins_on, str):
            # Single observable given as a plain string: normalize to lists
            self._dim = 1
            self._bins_on = [self._bins_on]
            if isinstance(self._bins, int):
                self._bins = [self._bins]
        elif isinstance(self._bins_on, list):
            self._dim = len(self._bins_on)
            # Raised explicitly rather than via ``assert`` so the check survives ``python -O``
            if not isinstance(self._bins, list):
                raise AssertionError(
                    f"bins should be list if not 1D fitting, but got {self._bins}!"
                )
        else:
            raise ValueError(
                f"bins_on should be either str or list of str, but got {self._bins_on}."
            )

        self.needed_parameters: Set[str] = set()
        self._sanity_check()

        # Load the observables and keep only the events strictly inside the ROI
        self.data = load_data(self._data_file_name)[self._bins_on].to_numpy()
        if self._dim == 1:
            mask = self.data[:, 0] > config["clip"][0]
            mask &= self.data[:, 0] < config["clip"][1]
        else:
            mask = self.data[:, 0] > config["x_clip"][0]
            mask &= self.data[:, 0] < config["x_clip"][1]
            mask &= self.data[:, 1] > config["y_clip"][0]
            mask &= self.data[:, 1] < config["y_clip"][1]
        self.data = self.data[mask]

        self.set_binning(config)

    def __getitem__(self, keys):
        """Get component in likelihood."""
        return self.components[keys]

    def _resolve_bin_edges(self, bin_spec, clip_range, config, warn_key):
        """Resolve a single axis bin specification into bin edges.

        Args:
            bin_spec: int (number of bins) or array-like (explicit edges).
            clip_range: tuple of (lo, hi) for linspace when bin_spec is int.
            config: the full config dict.
            warn_key: config key (e.g. "x_clip") whose presence triggers a warning
                that it does not influence explicit bin edges.

        Returns:
            numpy array of bin edges.
        """
        if isinstance(bin_spec, int):
            # Equal-width bins spanning the clip range
            return np.linspace(*clip_range, bin_spec + 1)
        edges = np.array(bin_spec)
        if warn_key in config:
            warn(
                f"{warn_key} is ignored when bins_type is "
                f"{self._bins_type} and bins is not int"
            )
        return edges

    def _resolve_explicit_bins(self, config):
        """Resolve explicit bin specs (meshgrid or irreg) into a tuple of arrays.

        In 1D the ROI lives under the "clip" key (as in ``__init__``). In 2D it
        lives under "x_clip" / "y_clip", so each axis checks its own key.
        """
        if self._dim == 1:
            bins = self._resolve_bin_edges(
                self._bins[0],
                config["clip"],
                config,
                "clip",
            )
            return (bins,)
        x_bins = self._resolve_bin_edges(
            self._bins[0],
            config["x_clip"],
            config,
            "x_clip",
        )
        y_bins = self._resolve_bin_edges(
            self._bins[1],
            config["y_clip"],
            config,
            "y_clip",
        )
        return (x_bins, y_bins)

    def _make_data_hist(self, use_meshgrid=False):
        """Create the data histogram from self.data and self._bins (unit weights)."""
        weights = np.ones(len(self.data))
        if use_meshgrid:
            self.data_hist = make_hist_mesh_grid(
                self.data,
                bins=self._bins,
                weights=weights,
            )
        elif self._dim == 1:
            self.data_hist = make_hist_irreg_bin_1d(
                self.data[:, 0],
                bins=self._bins[0],
                weights=weights,
            )
        else:
            self.data_hist = make_hist_irreg_bin_2d(
                self.data,
                bins_x=self._bins[0],
                bins_y=self._bins[1],
                weights=weights,
            )

    def _validate_irreg_bins_2d(self):
        """Validate irregular 2D bins: x-bins length and y-bins uniformity.

        Only ``len`` is used, so this works on raw (nested list) input as well as arrays.

        Raises:
            ValueError: if there is not one y-binning per x-bin, or the y-binnings
                have different lengths.
        """
        if len(self._bins[0]) != len(self._bins[1]) + 1:
            raise ValueError(
                f"The x-binning should 1 longer than y-binning, "
                f"please check the binning in {self.name}!"
            )
        if not all(len(b) == len(self._bins[1][0]) for b in self._bins[1]):
            raise ValueError(
                f"All y-binning should have the same length, "
                f"please check the binning in {self.name}!"
            )

    def set_binning(self, config):
        """Set binning of likelihood and build the data histogram.

        After this call ``self._bins`` is a tuple of edge arrays and
        ``self.component_bins_type`` tells components which histogram method to use.

        Raises:
            ValueError: for unsupported dimension, unknown bins_type or invalid irreg bins.
            RuntimeError: if equiprob is requested with non-int bins.
        """
        if self._dim not in (1, 2):
            raise ValueError(f"Currently only support 1D and 2D, but got {self._dim}D!")
        if self._bins_type == "meshgrid":
            warn("The usage of meshgrid binning is highly discouraged.")
            self.component_bins_type = "meshgrid"
            self._bins = self._resolve_explicit_bins(config)
            self._make_data_hist(use_meshgrid=True)
        elif self._bins_type == "equiprob":
            if not all(isinstance(b, int) for b in self._bins):
                raise RuntimeError("bins can only be int if bins_type is equiprob")
            if self._dim == 1:
                self._bins = (
                    get_equiprob_bins_1d(
                        self.data[:, 0],
                        self._bins[0],
                        clip=config["clip"],
                        integer=config.get("integer", False),
                        left=config.get("left", True),
                        which_np=np,
                    ),
                )
            else:
                self._bins = get_equiprob_bins_2d(
                    self.data,
                    self._bins,
                    x_clip=config["x_clip"],
                    y_clip=config["y_clip"],
                    integer=config.get("integer", [False, False]),
                    left=config.get("left", True),
                    which_np=np,
                )
            self.component_bins_type = "irreg"
            self._make_data_hist()
        elif self._bins_type == "irreg":
            self.component_bins_type = "irreg"
            if self._dim == 2:
                # Validate the raw input first: ragged y-binning would otherwise make
                # np.array below fail with an uninformative error before this check runs.
                self._validate_irreg_bins_2d()
            self._bins = [np.array(b) for b in self._bins]
            self._bins = self._resolve_explicit_bins(config)
            self._make_data_hist()
        else:
            raise ValueError("'bins_type' should either be meshgrid, equiprob or irreg")
        assert isinstance(self._bins, tuple), "bins should be tuple after setting binning!"

    def register_component(
        self, component_cls: Type[Component], component_name: str, file_name: Optional[str] = None
    ):
        """Register a component in this likelihood.

        The component is instantiated with this likelihood's binning, deduced and
        compiled. Its needed parameters are merged into ``self.needed_parameters``,
        with aliased names (alias values) reported under their alias keys.

        Args:
            component_cls: class of Component.
            component_name: name of Component.
            file_name: file used in ComponentFixed.

        Raises:
            ValueError: if a component with the same name is already registered.
        """
        if component_name in self.components:
            raise ValueError(f"Component named {component_name} already existed!")

        # Initialize component
        component = component_cls(
            name=component_name,
            llh_name=self.name,
            bins=self._bins,
            bins_type=self.component_bins_type,
            file_name=file_name,
        )
        component.rate_name = component_name + "_rate"
        kwargs = {"data_names": self._bins_on}
        if isinstance(component, ComponentSim):
            kwargs["func_name"] = self.name + "_" + component_name + "_sim"
            # Simulated components also return an efficiency column
            kwargs["data_names"] = self._bins_on + ["eff"]
        component.deduce(**kwargs)
        component.compile()

        # Update components sheet
        self.components[component_name] = component

        # Update needed parameters
        self.needed_parameters |= self.components[component_name].needed_parameters

        # Replace alias in needed parameters: expose the key (external name) instead of value
        for k, v in self._parameter_alias.items():
            if v in self.needed_parameters:
                self.needed_parameters.add(k)
                self.needed_parameters.remove(v)

    def replace_alias(self, parameters):
        """Replace alias in parameters from key to value.

        Returns a deep copy of ``parameters`` in which every alias key present is
        popped out and its value stored under the alias value instead. The input
        is not modified.
        """
        _parameters = deepcopy(parameters)
        for k, v in self._parameter_alias.items():
            if k in _parameters:
                _parameters[v] = _parameters.pop(k)
        return _parameters

    def _sanity_check(self):
        """Check equality between number of bins group and observables.

        Raises:
            RuntimeError: if len(bins) != len(bins_on).
        """
        if len(self._bins_on) != len(self._bins):
            raise RuntimeError("Length of bins must be the same as length of bins_on!")

    @need_replacing_alias
    def _simulate_model_hist(self, key, batch_size, parameters):
        """Histogram of simulated observables, summed over all components.

        Args:
            key: a pseudo-random number generator (PRNG) key.
            batch_size: int of number of simulated events.
            parameters: dict of parameters used in simulation.

        Returns:
            (key, hist) where key is the updated PRNG key and hist has the shape of data_hist.

        Raises:
            TypeError: if a component is neither ComponentSim nor ComponentFixed.
        """
        hist = np.zeros_like(self.data_hist)
        for component_name, component in self.components.items():
            if isinstance(component, ComponentSim):
                key, _hist = component.simulate_hist(key, batch_size, parameters)
            elif isinstance(component, ComponentFixed):
                _hist = component.simulate_hist(parameters)
            else:
                raise TypeError(f"unsupported component type for {component_name}!")
            hist += _hist
        return key, hist

    @need_replacing_alias
    def simulate_weighted_data(self, key, batch_size, parameters):
        """Simulate weighted data of all components.

        Args:
            key: a pseudo-random number generator (PRNG) key.
            batch_size: int of number of simulated events.
            parameters: dict of parameters used in simulation.

        Returns:
            (key, result) where result is the list of rows of the components'
            outputs stacked horizontally with ``np.hstack``.

        Raises:
            TypeError: if a component is neither ComponentSim nor ComponentFixed.
        """
        result = []
        for component_name, component in self.components.items():
            if isinstance(component, ComponentSim):
                key, _result = component.simulate_weighted_data(key, batch_size, parameters)
            elif isinstance(component, ComponentFixed):
                _result = component.simulate_weighted_data(parameters)
            else:
                raise TypeError(f"unsupported component type for {component_name}!")
            result.append(_result)
        result = list(np.hstack(result))
        return key, result

    @need_replacing_alias
    def get_log_likelihood(self, key, batch_size, parameters):
        """Get log likelihood of given parameters.

        Args:
            key: a pseudo-random number generator (PRNG) key.
            batch_size: int of number of simulated events.
            parameters: dict of parameters used in simulation.

        Returns:
            (key, llh): the updated PRNG key and the Poisson log-likelihood (without
            the data-only log-factorial term), or -inf if the model histogram is
            not strictly positive and finite.

        Raises:
            ValueError: if the resulting log likelihood is NaN.
        """
        # NOTE: _simulate_model_hist is itself decorated, so alias replacement runs a
        # second time on already-replaced parameters (a no-op unless an alias value is
        # also an alias key).
        key, model_hist = self._simulate_model_hist(key, batch_size, parameters)

        # More stable if we first check zeros in model_hist
        # If zeros exist, return -inf directly
        if np.any(model_hist <= 0):
            warn(
                "Zero or negative bin(s) in model histogram encountered! "
                "Consider increasing batch_size or placing more stringent bounds "
                "on parameters to avoid this issue."
            )
            return key, -float("inf")
        if not np.all(np.isfinite(model_hist)):
            warn(
                "NaN or infinite bin(s) in model histogram encountered! "
                "This usually means there is something wrong in your plugins, "
                "or the parameters are out of reasonable range."
            )
            return key, -float("inf")

        # Poisson likelihood
        llh = float(np.sum(self.data_hist * np.log(model_hist) - model_hist))
        if np.isnan(llh):
            raise ValueError("NaN log likelihood encountered!")
        return key, llh

    @need_replacing_alias
    def get_num_events_accepted(self, batch_size, parameters):
        """Get number of events in the histogram under given parameters.

        A fresh PRNG key is drawn from ``randgen.get_key()`` for the simulation.

        Args:
            batch_size: int of number of simulated events.
            parameters: dict of parameters used in simulation.
        """
        key = randgen.get_key()
        _, model_hist = self._simulate_model_hist(key, batch_size, parameters)
        return model_hist.sum()

    def _print_components(self, indent: str = " " * 4, short: bool = True):
        """Print component details shared by Likelihood and LikelihoodLit."""
        for i, component_name in enumerate(self.components):
            component = self[component_name]
            need = component.needed_parameters
            print(f"{indent}COMPONENT {i}: {component_name}")
            if isinstance(component, ComponentSim):
                print(f"{indent * 2}type: simulation")
                print(f"{indent * 2}rate_par: {component.rate_name}")
                print(f"{indent * 2}pars: {need}")
                if not short:
                    print(f"{indent * 2}worksheet: {component.worksheet}")
            elif isinstance(component, ComponentFixed):
                print(f"{indent * 2}type: fixed")
                print(f"{indent * 2}file_name: {component._file_name}")
                print(f"{indent * 2}rate_par: {component.rate_name}")
                print(f"{indent * 2}pars: {need}")
                if not short:
                    print(f"{indent * 2}from_file: {component._file_name}")
            print()

    def print_likelihood_summary(self, indent: str = " " * 4, short: bool = True):
        """Print likelihood summary: components, bins, file names.

        Args:
            indent: str of indent.
            short: bool, whether only print short summary.
        """
        print("\n" + "-" * 40)

        print("BINNING\n")
        print(f"{indent}bins_type: {self._bins_type}")
        print(f"{indent}bins_on: {self._bins_on}")
        if not short:
            print(f"{indent}bins: {self._bins}")
        print("\n" + "-" * 40)

        print("DATA\n")
        print(f"{indent}file_name: {self._data_file_name}")
        print(f"{indent}data_rate: {float(self.data_hist.sum())}")
        print("\n" + "-" * 40)

        print("MODEL\n")
        self._print_components(indent, short)

        print("-" * 40)

    @property
    def lineage(self):
        """Provenance record: config, data file path and sha256, and component lineages."""
        return {
            "config": self._config,
            "file_path": (
                os.path.basename(self._data_file_name)
                if not utils.FULL_PATH_LINEAGE
                else get_file_path(self._data_file_name)
            ),
            "sha256": calculate_sha256(get_file_path(self._data_file_name)),
            "components": dict(
                zip(
                    self.components.keys(),
                    [v.lineage for v in self.components.values()],
                )
            ),
        }

    @property
    def lineage_hash(self):
        """Deterministic hash of :attr:`lineage`."""
        return deterministic_hash(self.lineage)
class LikelihoodLit(Likelihood):
    """Using literature constraint to build LLH.

    The idea is to simulate light and charge yields directly with given energy distribution. And
    then fit the result with provided literature measurement points. The log-pdf used to compare
    the simulated yields with the measurements is twohalfnorm (TwoHalfNorm), norm or band
    (BandTwoHalfNorm), selected by the ``variable_type`` option of the config.
    """

    def __init__(self, name: Optional[str] = None, **config):
        """Create an appletree likelihood.

        ``Likelihood.__init__`` is intentionally not called: no data file or binning is used.

        Args:
            name: name of the likelihood, defaults to the class name.
            config: Dictionary with configuration options that will be applied, should include:
                * bins_on: list of exactly two variable names
                * variable_type: one of twohalfnorm, norm or band
                * logpdf_args: two sequences, argument names and argument values, which are
                  zipped into keyword arguments of the log-pdf
                * parameter_alias (optional): alias of parameters, renamed from key to value

        Raises:
            RuntimeError: if variable_type is not supported.
            AssertionError: if bins_on does not have exactly two entries.
        """
        self.name = self.__class__.__name__ if name is None else name
        self.components: Dict[str, Component] = dict()
        self._config = config
        self._bins = None
        self._bins_type = None
        self._bins_on = config["bins_on"]
        self._dim = len(self._bins_on)
        self._parameter_alias = config.get("parameter_alias", dict())
        self.needed_parameters: Set[str] = set()
        self.component_bins_type = None

        logpdf_args = self._config["logpdf_args"]
        self.logpdf_args = {k: np.array(v) for k, v in zip(*logpdf_args)}

        self.variable_type = config["variable_type"]
        self.warning = "Currently only support two dimensional inference"
        self._sanity_check()

        # The lambdas read self.logpdf_args at call time; logpdf(x, y) always takes
        # (energies, yields), and only "band" uses x.
        if self.variable_type == "twohalfnorm":
            self.logpdf = lambda x, y: TwoHalfNorm.logpdf(x=y, **self.logpdf_args)
        elif self.variable_type == "norm":
            self.logpdf = lambda x, y: norm.logpdf(x=y, **self.logpdf_args)
        elif self.variable_type == "band":
            self.bandtwohalfnorm = BandTwoHalfNorm(**self.logpdf_args)
            self.logpdf = lambda x, y: self.bandtwohalfnorm.logpdf(x=x, y=y)
        else:
            # Unreachable after _sanity_check; kept as a guard for future types
            raise NotImplementedError

    def _sanity_check(self):
        """Check sanities of supported distribution and dimension.

        Raises:
            RuntimeError: if variable_type is not twohalfnorm, norm or band.
            AssertionError: if the number of variables is not two.
        """
        if self.variable_type not in ["twohalfnorm", "norm", "band"]:
            raise RuntimeError("Currently only twohalfnorm, norm and band are supported")
        if self._dim != 2:
            raise AssertionError(self.warning)

    @need_replacing_alias
    def _simulate_yields(self, key, batch_size, parameters):
        """Simulate the observables of the single registered component.

        Args:
            key: a pseudo-random number generator (PRNG) key.
            batch_size: int of number of simulated events.
            parameters: dict of parameters used in simulation.

        Returns:
            (key, result) where result is a list of numpy arrays (moved off the
            device); ``get_log_likelihood`` unpacks it as energies, yields, eff.

        Raises:
            AssertionError: if not exactly one component is registered.
        """
        if len(self.components) != 1:
            raise AssertionError(
                f"Exactly one component is supported in {self.name}, "
                f"but got {len(self.components)}!"
            )
        key, result = self.components[self.only_component].simulate(key, batch_size, parameters)
        # Move data to CPU
        result = [np.array(r) for r in result]
        return key, result

    def register_component(self, *args, **kwargs):
        """Register the single component of a literature likelihood.

        Accepts the same arguments as ``Likelihood.register_component`` and caches the
        component name in ``self.only_component``.

        Raises:
            AssertionError: if a component has already been registered.
        """
        if len(self.components) != 0:
            raise AssertionError(
                f"Only one component is supported in {self.name}, "
                f"but {list(self.components)} is already registered!"
            )
        super().register_component(*args, **kwargs)
        # cache the component name
        self.only_component = list(self.components.keys())[0]

    @need_replacing_alias
    def get_log_likelihood(self, key, batch_size, parameters):
        """Get log likelihood of given parameters.

        Args:
            key: a pseudo-random number generator (PRNG) key.
            batch_size: int of number of simulated events; always replaced by 1 here.
            parameters: dict of parameters used in simulation.

        Returns:
            (key, llh): the updated PRNG key and the eff-weighted sum of log-pdf values,
            or -inf if that sum is NaN.
        """
        if batch_size != 1:
            warn(
                "You specified the batch_size larger than 1, "
                "but it should and will be changed to 1 in literature fitting!"
            )
        key, result = self._simulate_yields(key, 1, parameters)
        energies, yields, eff = result
        llh = float((self.logpdf(energies, yields) * eff).sum())
        if np.isnan(llh):
            llh = -np.inf
        return key, llh

    def print_likelihood_summary(self, indent: str = " " * 4, short: bool = True):
        """Print likelihood summary: components, bins, file names.

        Args:
            indent: str of indent.
            short: bool, whether only print short summary.
        """
        print("\n" + "-" * 40)

        print("BINNING\n")
        print(f"{indent}variable_type: {self.variable_type}")
        print(f"{indent}variable: {self._bins_on}")
        print("\n" + "-" * 40)

        print("LOGPDF\n")
        print(f"{indent}logpdf_args:")
        for k, v in self.logpdf_args.items():
            print(f"{indent * 2}{k}: {v}")
        print("\n" + "-" * 40)

        print("MODEL\n")
        self._print_components(indent, short)

        print("-" * 40)

    @property
    def lineage(self):
        """Provenance record: config and the lineage hash of each component."""
        return {
            "config": self._config,
            "components": dict(
                zip(
                    self.components.keys(),
                    [v.lineage_hash for v in self.components.values()],
                )
            ),
        }