Source code for appletree.plugins.efficiency

from jax import numpy as jnp
from jax import jit
from functools import partial

from appletree.plugin import Plugin
from appletree.config import takes_config, SigmaMap
from appletree.utils import exporter

# `exporter` returns the `export` decorator applied to the plugin classes below,
# together with this module's `__all__`.
# NOTE(review): presumably `export` appends each decorated name to `__all__`;
# confirm against appletree.utils.exporter.
export, __all__ = exporter(export_self=False)


@export
class S2Threshold(Plugin):
    """Hard S2-area threshold expressed as a 0/1 acceptance.

    Reads ``s2_area`` and provides ``acc_s2_threshold``; the cut value is
    taken from the ``s2_threshold`` entry of ``parameters``.
    """

    depends_on = ["s2_area"]
    provides = ["acc_s2_threshold"]
    parameters = ("s2_threshold",)

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, s2_area):
        """Flag the entries whose S2 area is strictly above the threshold.

        Args:
            key: passed through and returned unchanged.
            parameters: mapping that must contain ``"s2_threshold"``.
            s2_area: values compared against the threshold.

        Returns:
            ``(key, acceptance)`` where ``acceptance`` is ``1.0`` wherever
            ``s2_area > parameters["s2_threshold"]`` and ``0`` elsewhere.
        """
        threshold = parameters["s2_threshold"]
        above_threshold = s2_area > threshold
        return key, jnp.where(above_threshold, 1.0, 0)
@export
@takes_config(
    SigmaMap(
        name="s1_eff_3f",
        method="NN",
        default="_3fold_recon_eff.json",
        help="3fold S1 reconstruction efficiency",
    ),
)
class S1ReconEff(Plugin):
    """S1 reconstruction efficiency looked up from the ``s1_eff_3f`` map.

    Reads ``num_s1_phd`` and provides ``acc_s1_recon_eff``.
    """

    depends_on = ["num_s1_phd"]
    provides = ["acc_s1_recon_eff"]

    # Left disabled in the original source:
    # parameters = ("s1_eff_3f_sigma",)

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, num_s1_phd):
        """Evaluate the efficiency map and bound the result to ``[0, 1]``.

        Args:
            key: passed through and returned unchanged.
            parameters: forwarded as-is to ``self.s1_eff_3f.apply``.
            num_s1_phd: input at which the map is evaluated.

        Returns:
            ``(key, efficiency)`` with ``efficiency`` clipped to ``[0.0, 1.0]``.
        """
        # Clip because the map output is not guaranteed here to stay in [0, 1].
        return key, jnp.clip(self.s1_eff_3f.apply(num_s1_phd, parameters), 0.0, 1.0)
@export
@takes_config(
    SigmaMap(
        name="s1_cut_acc",
        method="LERP",
        default=["_s1_cut_acc.json", "_s1_cut_acc.json", "_s1_cut_acc.json"],
        help="S1 cut acceptance",
    ),
)
class S1CutAccept(Plugin):
    """S1 cut acceptance looked up from the ``s1_cut_acc`` map.

    Reads ``s1_area`` and provides ``cut_acc_s1``.
    """

    depends_on = ["s1_area"]
    provides = ["cut_acc_s1"]

    # Left disabled in the original source:
    # parameters = ("s1_cut_acc_sigma",)

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, s1_area):
        """Evaluate the acceptance map at ``s1_area`` and bound it to ``[0, 1]``.

        Args:
            key: passed through and returned unchanged.
            parameters: forwarded as-is to ``self.s1_cut_acc.apply``.
            s1_area: input at which the map is evaluated.

        Returns:
            ``(key, acceptance)`` with ``acceptance`` clipped to ``[0.0, 1.0]``.
        """
        raw_acceptance = self.s1_cut_acc.apply(s1_area, parameters)
        return key, jnp.clip(raw_acceptance, 0.0, 1.0)
@export
@takes_config(
    SigmaMap(
        name="s2_cut_acc",
        method="LERP",
        # NOTE(review): unlike the S1 counterpart this list has a fourth entry
        # that looks like a parameter name rather than a file; confirm that
        # SigmaMap accepts this layout.
        default=["_s2_cut_acc.json", "_s2_cut_acc.json", "_s2_cut_acc.json", "s2_cut_acc_sigma"],
        help="S2 cut acceptance",
    ),
)
class S2CutAccept(Plugin):
    """S2 cut acceptance looked up from the ``s2_cut_acc`` map.

    Reads ``s2_area`` and provides ``cut_acc_s2``.
    """

    depends_on = ["s2_area"]
    provides = ["cut_acc_s2"]

    # Left disabled in the original source:
    # parameters = ("s2_cut_acc_sigma",)

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, s2_area):
        """Evaluate the acceptance map at ``s2_area`` and bound it to ``[0, 1]``.

        Args:
            key: passed through and returned unchanged.
            parameters: forwarded as-is to ``self.s2_cut_acc.apply``.
            s2_area: input at which the map is evaluated.

        Returns:
            ``(key, acceptance)`` with ``acceptance`` clipped to ``[0.0, 1.0]``.
        """
        raw_acceptance = self.s2_cut_acc.apply(s2_area, parameters)
        return key, jnp.clip(raw_acceptance, 0.0, 1.0)
@export
class Eff(Plugin):
    """Combined efficiency: the product of the four acceptance terms.

    Reads ``acc_s2_threshold``, ``acc_s1_recon_eff``, ``cut_acc_s1`` and
    ``cut_acc_s2``; provides ``eff``.
    """

    depends_on = ["acc_s2_threshold", "acc_s1_recon_eff", "cut_acc_s1", "cut_acc_s2"]
    provides = ["eff"]

    @partial(jit, static_argnums=(0,))
    def simulate(
        self,
        key,
        parameters,
        acc_s2_threshold,
        acc_s1_recon_eff,
        cut_acc_s1,
        cut_acc_s2,
    ):
        """Multiply the individual acceptances together.

        Args:
            key: passed through and returned unchanged.
            parameters: not used by this plugin.
            acc_s2_threshold: first factor of the product.
            acc_s1_recon_eff: second factor of the product.
            cut_acc_s1: third factor of the product.
            cut_acc_s2: fourth factor of the product.

        Returns:
            ``(key, eff)`` where ``eff`` is the product of the four inputs.
        """
        # Accumulate left to right, in the same order as a single chained product.
        combined = acc_s2_threshold * acc_s1_recon_eff
        combined = combined * cut_acc_s1
        combined = combined * cut_acc_s2
        return key, combined