Source code for appletree.utils

import os
import json
from warnings import warn
import hashlib
from time import time
from importlib.resources import files as _files

import jax
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib.patches import Rectangle
from matplotlib import pyplot as plt
from scipy.special import erf
from scipy.optimize import root
from scipy.stats import chi2

import GOFevaluation
from appletree.share import _cached_configs

NT_AUX_INSTALLED = False
try:
    import ntauxfiles

    NT_AUX_INSTALLED = True
except ImportError:
    pass

SKIP_MONGO_DB = True

JSON_OPTIONS = dict(sort_keys=True, indent=4)

FULL_PATH_LINEAGE = False


[docs]def exporter(export_self=False): """Export utility modified from https://stackoverflow.com/a/41895194 Returns export decorator, __all__ list """ all_ = [] if export_self: all_.append("exporter") def decorator(obj): all_.append(obj.__name__) return obj return decorator, all_
export, __all__ = exporter(export_self=True)
[docs]@export def use_xenon_plot_style(): """Set matplotlib plot style.""" params = { "font.family": "serif", "font.size": 24, "axes.titlesize": 24, "axes.labelsize": 24, "axes.linewidth": 2, # ticks "xtick.labelsize": 22, "ytick.labelsize": 22, "xtick.major.size": 16, "xtick.minor.size": 8, "ytick.major.size": 16, "ytick.minor.size": 8, "xtick.major.width": 2, "xtick.minor.width": 2, "ytick.major.width": 2, "ytick.minor.width": 2, "xtick.direction": "in", "ytick.direction": "in", # markers "lines.markersize": 12, "lines.markeredgewidth": 3, "errorbar.capsize": 8, "lines.linewidth": 3, "savefig.bbox": "tight", "legend.fontsize": 24, "backend": "Agg", "mathtext.fontset": "dejavuserif", "legend.frameon": False, # figure "figure.facecolor": "w", "figure.figsize": (12, 8), # pad "axes.labelpad": 12, # ticks "xtick.major.pad": 6, "xtick.minor.pad": 6, "ytick.major.pad": 3.5, "ytick.minor.pad": 3.5, # colormap } plt.rcParams.update(params)
[docs]@export def load_data(file_name: str): """Load data from file. The suffix can be ".csv", ".pkl". """ file_name = get_file_path(file_name) fmt = file_name.split(".")[-1] if fmt == "csv": data = pd.read_csv(file_name) elif fmt == "pkl": data = pd.read_pickle(file_name) else: raise ValueError(f"unsupported file format {fmt}!") return data
[docs]@export def load_json(file_name: str): """Load data from json file.""" with open(get_file_path(file_name), "r") as file: data = json.load(file) return data
[docs]@export def _get_abspath(file_name): """Get the abspath of the file. Raise FileNotFoundError when not found in any subfolder """ for sub_dir in ("maps", "data", "parameters", "instructs"): p = os.path.join(_package_path(sub_dir), file_name) if os.path.exists(p): return p raise FileNotFoundError(f"Cannot find {file_name}")
[docs]def _package_path(sub_directory): """Get the abs path of the requested sub folder.""" return _files("appletree") / sub_directory
[docs]@export def get_file_path(fname): """Find the full path to the resource file. Try 5 methods in the following order. * fname begin with '/', return absolute path * url_base begin with '/', return url_base + name * can get file from _get_abspath, return appletree internal file path * can be found in local installed ntauxfiles, return ntauxfiles absolute path * can be downloaded from MongoDB, download and return cached path """ # 1. From absolute path if file exists # Usually Config.default is a absolute path if os.path.isfile(fname): return fname # 2. From local folder # Use url_base as prefix if "url_base" in _cached_configs.keys(): url_base = _cached_configs["url_base"] if url_base.startswith("/"): fpath = os.path.join(url_base, fname) if os.path.exists(fpath): warn(f"Load {fname} successfully from {fpath}") return fpath # 3. From appletree internal files try: return _get_abspath(fname) except FileNotFoundError: pass # 4. From local installed ntauxfiles if NT_AUX_INSTALLED: # You might want to use this, for example if you are a developer if fname in ntauxfiles.list_private_files(): fpath = ntauxfiles.get_abspath(fname) warn(f"Load {fname} successfully from {fpath}") return fpath # 5. From MongoDB if not SKIP_MONGO_DB: try: import utilix # Mongo downloader is implemented in utilix downloader = utilix.mongo_storage.MongoDownloader() # FileNotFoundError, ValueErrors can be raised if we # cannot load the requested config fpath = downloader.download_single(fname) warn(f"Loading {fname} from mongo downloader to {fpath}") return fpath except (FileNotFoundError, ValueError, NameError, AttributeError): warn(f"Mongo downloader not possible or does not have {fname}") # raise error when can not find corresponding file raise RuntimeError(f"Can not find {fname}, please check your file system")
[docs]@export def calculate_sha256(file_path): """Get sha256 hash of the file.""" sha256_hash = hashlib.sha256() with open(file_path, "rb") as f: for byte_block in iter(lambda: f.read(4096), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest()
[docs]@export def dump_lineage(file_path, entity): """Dump lineage of whatever level into .json file.""" with open(file_path, "w") as f: f.write(json.dumps(entity.lineage, **JSON_OPTIONS))
[docs]@export def timeit(indent=""): """Use timeit as a decorator. It will print out the running time of the decorated function. """ def _timeit(func, indent): name = func.__name__ def _func(*args, **kwargs): print(indent + f" Function <{name}> starts.") start = time() res = func(*args, **kwargs) time_ = (time() - start) * 1e3 print(indent + f" Function <{name}> ends! Time cost = {time_:.2f} msec.") return res return _func if isinstance(indent, str): return lambda func: _timeit(func, indent) else: return _timeit(indent, "")
[docs]@export def get_platform() -> str: """Show the platform we are using, e.g. 'cpu', 'gpu', or 'tpu'.""" if hasattr(jax, "default_backend"): return jax.default_backend() # Fallback for extremely old JAX versions: from jax.lib import xla_bridge return xla_bridge.get_backend().platform
[docs]@export def set_gpu_memory_usage(fraction=0.3): """Set GPU memory usage. See more on https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html """ if fraction > 1: fraction = 0.99 if fraction <= 0: raise ValueError("fraction must be positive!") os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f".{int(fraction * 100):d}"
[docs]@export def get_equiprob_bins_1d( data, n_partitions, clip=(-np.inf, +np.inf), integer=False, left=True, which_np=np, ): """Get 2D equiprobable binning edges. Args: data: array with shape N. n_partitions: M1 which is the number of bins. clip: lower and upper binning edges on the 0th dimension. Data outside the clip will be dropped. Data outside the y_clip will be dropped. integer: bool, whether the corresponding dimension is integer. left: bool, whether start searching for bin edges from the left side. which_np: can be numpy or jax.numpy, determining the returned array type. """ mask = data > clip[0] mask &= data < clip[1] bins = GOFevaluation.utils.get_equiprobable_binning( data[mask], n_partitions, integer=integer, left=left, ) # To be strict, clip the inf(s) bins = np.clip(bins, *clip) return which_np.array(bins)
[docs]@export def get_equiprob_bins_2d( data, n_partitions, order=(0, 1), x_clip=(-np.inf, +np.inf), y_clip=(-np.inf, +np.inf), integer=[False, False], left=True, which_np=np, ): """Get 2D equiprobable binning edges. Args: data: array with shape (N, 2). n_partitions: [M1, M2] where M1 M2 are the number of bins on each dimension. x_clip: lower and upper binning edges on the 0th dimension. Data outside the x_clip will be dropped. y_clip: lower and upper binning edges on the 1st dimension. integer: list of bool with length 2, whether the corresponding dimension is integer. left: bool, whether start searching for bin edges from the left side. Data outside the y_clip will be dropped. which_np: can be numpy or jax.numpy, determining the returned array type. """ mask = data[:, 0] > x_clip[0] mask &= data[:, 0] < x_clip[1] mask &= data[:, 1] > y_clip[0] mask &= data[:, 1] < y_clip[1] x_bins, y_bins = GOFevaluation.utils.get_equiprobable_binning( data[mask], n_partitions, order=order, integer=integer, left=left, ) # To be strict, clip the inf(s) x_bins = np.clip(x_bins, *x_clip) y_bins = np.clip(y_bins, *y_clip) return which_np.array(x_bins), which_np.array(y_bins)
[docs]@export def plot_irreg_histogram_2d(bins_x, bins_y, hist, **kwargs): """Plot histogram defined by irregular binning. Args: bins_x: array with shape (M1,). bins_y: array with shape (M1-1, M2). hist: array with shape (M1-1, M2-1). density: boolean. """ hist = np.asarray(hist) bins_x = np.asarray(bins_x) bins_y = np.asarray(bins_y) density = kwargs.get("density", False) cmap = mpl.cm.RdBu_r loc = [] width = [] height = [] area = [] n = [] for i, _ in enumerate(hist): for j, _ in enumerate(hist[i]): x_lower = bins_x[i] x_upper = bins_x[i + 1] y_lower = bins_y[i, j] y_upper = bins_y[i, j + 1] loc.append((x_lower, y_lower)) width.append(x_upper - x_lower) height.append(y_upper - y_lower) area.append((x_upper - x_lower) * (y_upper - y_lower)) n.append(hist[i, j]) loc = np.asarray(loc) width = np.asarray(width) height = np.asarray(height) area = np.asarray(area) n = np.asarray(n) if density: norm = mpl.colors.Normalize( vmin=kwargs.get("vmin", np.min(n / area)), vmax=kwargs.get("vmax", np.max(n / area)), clip=False, ) else: norm = mpl.colors.Normalize( vmin=kwargs.get("vmin", np.min(n)), vmax=kwargs.get("vmax", np.max(n)), clip=False, ) ax = plt.subplot() for i, _ in enumerate(loc): c = n[i] / area[i] if density else n[i] rec = Rectangle( loc[i], width[i], height[i], facecolor=cmap(norm(c)), edgecolor="k", ) ax.add_patch(rec) fig = plt.gcf() fig.colorbar( mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label=("# events / bin size" if density else "# events"), ) ax.set_xlim(np.min(bins_x), np.max(bins_x)) ax.set_ylim(np.min(bins_y), np.max(bins_y)) return ax
[docs]@export def add_spaces(x): """Add four spaces to every line in x This is needed to make html raw blocks in rst format correctly.""" y = "" if isinstance(x, str): x = x.split("\n") for q in x: y += " " + q return y
[docs]@export def tree_to_svg(graph_tree, save_as="data_types", view=True): """Where to save this node. Args: graph_tree: Digraph instance. save_as: str, file name. view: bool, Open the rendered result with the default application. """ graph_tree.render(save_as, view=view) with open(f"{save_as}.svg", mode="r") as f: svg = add_spaces(f.readlines()[5:]) os.remove(save_as) return svg
[docs]@export def add_deps_to_graph_tree( component, graph_tree, data_names: list = ["cs1", "cs2", "eff"], _seen=None ): """Recursively add nodes to graph base on plugin.deps. Args: context: Context instance. graph_tree: Digraph instance. data_names: Data type name. _seen: list or None, the seen data_name should not be plot. """ if _seen is None: _seen = [] for data_name in data_names: if data_name in _seen: continue # Add new one graph_tree.node( data_name, style="filled", href="#" + data_name.replace("_", "-"), fillcolor="white" ) if data_name == "batch_size": continue dep_plugin = component._plugin_class_registry[data_name] for dep in dep_plugin.depends_on: graph_tree.edge(data_name, dep) graph_tree, _seen = add_deps_to_graph_tree( component, graph_tree, dep_plugin.depends_on, _seen, ) _seen.append(data_name) return graph_tree, _seen
[docs]@export def add_plugins_to_graph_tree( component, graph_tree, data_names: list = ["cs1", "cs2", "eff"], _seen=None, with_data_names=False, ): """Recursively add nodes to graph base on plugin.deps. Args: context: Context instance. graph_tree: Digraph instance. data_names: Data type name. _seen: list or None, the seen data_name should not be plot. with_data_names: bool, whether plot even with messy data_names """ if _seen is None: _seen = [] for data_name in data_names: if data_name == "batch_size": continue plugin = component._plugin_class_registry[data_name] plugin_name = plugin.__name__ if plugin_name in _seen: continue # Add new one label = f"{plugin_name}" if with_data_names: label += f"\n{', '.join(plugin.depends_on)}\n{', '.join(plugin.provides)}" graph_tree.node( plugin_name, label=label, style="filled", href="#" + plugin_name.replace("_", "-"), fillcolor="white", ) for dep in plugin.depends_on: if dep == "batch_size": continue dep_plugin = component._plugin_class_registry[dep] graph_tree.edge(plugin_name, dep_plugin.__name__) graph_tree, _seen = add_plugins_to_graph_tree( component, graph_tree, plugin.depends_on, _seen, ) _seen.append(plugin_name) return graph_tree, _seen
[docs]@export def add_extensions(module1, module2, base, force=False): """Add subclasses of module2 to module1. When ComponentSim compiles the dependency tree, it will search in the appletree.plugins module for Plugin(as attributes). When building Likelihood, it will also search for corresponding Component(s) specified in the instructions(e.g. NRBand). So we need to assign the attributes before compilation. These plugins are mostly user defined. """ # Assign the module2 as attribute of module1 is_exists = module2.__name__ in dir(module1) if is_exists and not force: raise ValueError( f"{module2.__name__} already existed in {module1.__name__}, " f"do not re-register a module with same name", ) else: if is_exists: print(f"You have forcibly registered {module2.__name__} to {module1.__name__}") setattr(module1, module2.__name__.split(".")[-1], module2) # Iterate the module2 and assign the single Plugin(s) as attribute(s) for x in dir(module2): x = getattr(module2, x) if not isinstance(x, type(type)): continue _add_extension(module1, x, base, force=force)
[docs]def _add_extension(module, subclass, base, force=False): """Add subclass to module Skip the class when it is base class. It is no allowed to assign a class which has same name to an already assigned class. We do not allowed class name covering! Please change the name of your class when Error shows itself. """ if getattr(subclass, "_" + subclass.__name__ + "__is_base", False): return if issubclass(subclass, base) and subclass != base: is_exists = subclass.__name__ in dir(module) if is_exists and not force: raise ValueError( f"{subclass.__name__} already existed in {module.__name__}, " f"do not re-register a {base.__name__} with same name", ) else: if is_exists: print(f"You have forcibly registered {subclass.__name__} to {module.__name__}") setattr(module, subclass.__name__, subclass)
def integrate_midpoint(x, y): """Calculate the integral using midpoint method. Args: x: 1D array-like. y: 1D array-like, with the same length as x. """ _, res = cumulative_integrate_midpoint(x, y) return res[-1] def cumulative_integrate_midpoint(x, y): """Calculate the cumulative integral using midpoint method. Args: x: 1D array-like. y: 1D array-like, with the same length as x. """ x = np.array(x) y = np.array(y) dx = x[1:] - x[:-1] x_mid = 0.5 * (x[1:] + x[:-1]) y_mid = 0.5 * (y[1:] + y[:-1]) return x_mid, np.cumsum(dx * y_mid)
[docs]@export def check_unused_configs(): """Check if there are unused configs.""" unused_configs = set(_cached_configs.keys()) - _cached_configs.accessed_keys if unused_configs: warn(f"Detected unused configs: {unused_configs}, you might set the configs incorrectly.")
[docs]@export def errors_to_two_half_norm_sigmas(errors): """This function solves the sigmas for a two-half-norm distribution, such that the 16 and 84 percentile corresponds to the given errors. In the two-half-norm distribution, the positive and negative errors are assumed to be the std of the glued normal distributions. While we interpret the 16 and 84 percentile as the input errors, thus we need to solve the sigmas for the two-half-norm distribution. The solution is determined by the following conditions: - The 16 percentile of the two-half-norm distribution should be the negative error. - The 84 percentile of the two-half-norm distribution should be the positive error. - The mode of the two-half-norm distribution should be 0. """ def _to_solve(x, errors, p): return [ x[0] / (x[0] + x[1]) * (1 - erf(errors[0] / x[0] / np.sqrt(2))) - p / 2, x[1] / (x[0] + x[1]) * (1 - erf(errors[1] / x[1] / np.sqrt(2))) - p / 2, ] res = root(_to_solve, errors, args=(errors, 1 - chi2.cdf(1, 1))) assert res.success, f"Cannot solve sigmas of TwoHalfNorm for errors {errors}!" return res.x