Source code for appletree.plugins.detector

from jax import numpy as jnp
from jax import jit
from functools import partial

from appletree import randgen
from appletree.plugin import Plugin
from appletree.config import takes_config, Map
from appletree.utils import exporter

export, __all__ = exporter(export_self=False)


[docs]@export @takes_config( Map( name="s1_lce", default="_s1_correction.json", help="S1 light collection efficiency", ), ) class S1LCE(Plugin): depends_on = ["x", "y", "z"] provides = ["s1_lce"]
[docs] @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, x, y, z): pos_true = jnp.stack([x, y, z]).T s1_lce = self.s1_lce.apply(pos_true) return key, s1_lce
[docs]@export @takes_config( Map( name="s2_lce", default="_s2_correction.json", help="S2 light collection efficiency", ), ) class S2LCE(Plugin): depends_on = ["x", "y"] provides = ["s2_lce"]
[docs] @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, x, y): pos_true = jnp.stack([x, y]).T s2_lce = self.s2_lce.apply(pos_true) return key, s2_lce
[docs]@export class PhotonDetection(Plugin): depends_on = ["num_photon", "s1_lce"] provides = ["num_s1_phd"] parameters = ("g1", "p_dpe")
[docs] @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, num_photon, s1_lce): g1_true_no_dpe = jnp.clip(parameters["g1"] * s1_lce / (1.0 + parameters["p_dpe"]), 0, 1.0) key, num_s1_phd = randgen.binomial(key, g1_true_no_dpe, num_photon) return key, num_s1_phd
[docs]@export class S1PE(Plugin): depends_on = ["num_s1_phd"] provides = ["num_s1_pe"] parameters = ("p_dpe",)
[docs] @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, num_s1_phd): key, num_s1_dpe = randgen.binomial(key, parameters["p_dpe"], num_s1_phd) num_s1_pe = num_s1_dpe + num_s1_phd return key, num_s1_pe
[docs]@export @takes_config( Map( name="elife", method="LERP", default="_elife.json", help="Electron lifetime correction", ), ) class DriftLoss(Plugin): depends_on = ["z"] provides = ["drift_survive_prob"] parameters = ("drift_velocity", "elife_sigma")
[docs] @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, z): key, p = randgen.uniform(key, 0, 1.0, shape=jnp.shape(z)) lifetime = self.elife.apply(p) * (1 + parameters["elife_sigma"]) drift_survive_prob = jnp.exp(-jnp.abs(z) / parameters["drift_velocity"] / lifetime) return key, drift_survive_prob
[docs]@export class ElectronDrifted(Plugin): depends_on = ["num_electron", "drift_survive_prob"] provides = ["num_electron_drifted"]
[docs] @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, num_electron, drift_survive_prob): key, num_electron_drifted = randgen.binomial(key, drift_survive_prob, num_electron) return key, num_electron_drifted
[docs]@takes_config( Map( name="gas_gain_relative", default="_gas_gain_relative.json", help="Gas gain as a function of (x, y) / mean gas gain", ), ) @export class S2PE(Plugin): depends_on = ["num_electron_drifted", "s2_lce", "x", "y"] provides = ["num_s2_pe"] parameters = ("g2", "gas_gain")
[docs] @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, num_electron_drifted, s2_lce, x, y): pos_true = jnp.stack([x, y]).T gas_gain = parameters["gas_gain"] * self.gas_gain_relative.apply((pos_true)) extraction_eff = parameters["g2"] * s2_lce / gas_gain key, num_electron_extracted = randgen.binomial(key, extraction_eff, num_electron_drifted) mean_s2_pe = num_electron_extracted * gas_gain key, num_s2_pe = randgen.truncate_normal(key, mean_s2_pe, jnp.sqrt(mean_s2_pe), vmin=0) return key, num_s2_pe