# Source code for appletree.plugins.reconstruction

from jax import jit
from functools import partial

from jax import numpy as jnp

from appletree import randgen
from appletree.config import takes_config, Map
from appletree.plugin import Plugin
from appletree.utils import exporter

# `export` is applied below as a decorator on every plugin class in this
# module, and `__all__` is bound here as the module's public-name list.
# NOTE(review): presumably `export` registers each decorated name in
# `__all__`, and `export_self=False` keeps `export` itself out of it --
# confirm against `appletree.utils.exporter`.
export, __all__ = exporter(export_self=False)


@export
@takes_config(
    Map(
        name="posrec_reso",
        method="LERP",
        default="_posrec_reso.json",
        help="Position reconstruction resolution",
    ),
)
class PositionRecon(Plugin):
    """Smear the x/y positions to produce reconstructed positions.

    The smearing width is looked up from the ``posrec_reso`` map as a
    function of ``num_electron_drifted``. ``z`` is passed through unchanged.
    """

    depends_on = ["x", "y", "z", "num_electron_drifted"]
    provides = ["rec_x", "rec_y", "rec_z", "rec_r"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, x, y, z, num_electron_drifted):
        """Return the updated key and the reconstructed x, y, z and r.

        Args:
            key: Random key threaded through ``randgen.normal``.
            parameters: Not used by this plugin.
            x: Input x positions.
            y: Input y positions.
            z: Input z positions, returned as ``rec_z`` without change.
            num_electron_drifted: Input to the ``posrec_reso`` map.

        Returns:
            Tuple ``(key, rec_x, rec_y, rec_z, rec_r)`` where
            ``rec_r = sqrt(rec_x**2 + rec_y**2)``.
        """
        # The map value is divided by sqrt(2) before being used as the
        # per-axis width.
        # NOTE(review): presumably the map stores a combined x-y resolution
        # that is split evenly between the two axes -- confirm.
        sigma = self.posrec_reso.apply(num_electron_drifted) / jnp.sqrt(2)
        center = jnp.zeros_like(num_electron_drifted)

        # x is drawn before y; the order fixes how the key is consumed.
        key, shift_x = randgen.normal(key, center, sigma)
        key, shift_y = randgen.normal(key, center, sigma)

        rec_x = x + shift_x
        rec_y = y + shift_y
        return key, rec_x, rec_y, z, jnp.sqrt(rec_x**2 + rec_y**2)
@export
@takes_config(
    Map(
        name="s1_bias_3f",
        method="LERP",
        default="_s1_bias.json",
        help="3fold S1 reconstruction bias",
    ),
    Map(
        name="s1_smear_3f",
        method="LERP",
        default="_s1_smearing.json",
        help="3fold S1 reconstruction smearing",
    ),
)
class S1(Plugin):
    """Turn the S1 PE count into a reconstructed S1 area.

    A relative bias is drawn per event, with mean taken from the
    ``s1_bias_3f`` map and width from the ``s1_smear_3f`` map, both
    evaluated at ``num_s1_pe``.
    """

    depends_on = ["num_s1_pe"]
    provides = ["s1_area"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, num_s1_pe):
        """Return the updated key and ``num_s1_pe * (1 + bias)``.

        Args:
            key: Random key threaded through ``randgen.normal``.
            parameters: Not used by this plugin.
            num_s1_pe: Input S1 PE count.

        Returns:
            Tuple ``(key, s1_area)``.
        """
        bias_mean = self.s1_bias_3f.apply(num_s1_pe)
        bias_width = self.s1_smear_3f.apply(num_s1_pe)
        key, relative_bias = randgen.normal(key, bias_mean, bias_width)
        return key, num_s1_pe * (1.0 + relative_bias)
@export
@takes_config(
    Map(
        name="s2_bias",
        method="LERP",
        default="_s2_bias.json",
        help="S2 reconstruction bias",
    ),
    Map(
        name="s2_smear",
        method="LERP",
        default="_s2_smearing.json",
        help="S2 reconstruction smearing",
    ),
)
class S2(Plugin):
    """Turn the S2 PE count into a reconstructed S2 area.

    A relative bias is drawn per event, with mean taken from the ``s2_bias``
    map and width from the ``s2_smear`` map, both evaluated at ``num_s2_pe``.
    """

    depends_on = ["num_s2_pe"]
    provides = ["s2_area"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, num_s2_pe):
        """Return the updated key and ``num_s2_pe * (1 + bias)``.

        Args:
            key: Random key threaded through ``randgen.normal``.
            parameters: Not used by this plugin.
            num_s2_pe: Input S2 PE count.

        Returns:
            Tuple ``(key, s2_area)``.
        """
        bias_mean = self.s2_bias.apply(num_s2_pe)
        bias_width = self.s2_smear.apply(num_s2_pe)
        key, relative_bias = randgen.normal(key, bias_mean, bias_width)
        return key, num_s2_pe * (1.0 + relative_bias)
@export
@takes_config(
    Map(
        name="s1_correction",
        default="_s1_correction.json",
        help="S1 xyz correction on reconstructed positions",
    ),
)
class S1Correction(Plugin):
    """Look up the S1 correction at the reconstructed (x, y, z) position."""

    depends_on = ["rec_x", "rec_y", "rec_z"]
    provides = ["s1_correction"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, rec_x, rec_y, rec_z):
        """Return the key unchanged and the S1 correction.

        Args:
            key: Random key; no random draw is made here.
            parameters: Not used by this plugin.
            rec_x: Reconstructed x positions.
            rec_y: Reconstructed y positions.
            rec_z: Reconstructed z positions.

        Returns:
            Tuple ``(key, s1_correction)``.
        """
        # Stack the coordinates along a new leading axis, then transpose so
        # the coordinate axis comes last before querying the map.
        coords = jnp.transpose(jnp.stack([rec_x, rec_y, rec_z]))
        return key, self.s1_correction.apply(coords)
@export
@takes_config(
    Map(
        name="s2_correction",
        default="_s2_correction.json",
        # Fixed typo: was "constructed"; the plugin reads rec_x / rec_y.
        help="S2 xy correction on reconstructed positions",
    ),
)
class S2Correction(Plugin):
    """Look up the S2 correction at the reconstructed (x, y) position."""

    depends_on = ["rec_x", "rec_y"]
    provides = ["s2_correction"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, rec_x, rec_y):
        """Return the key unchanged and the S2 correction.

        Args:
            key: Random key; no random draw is made here.
            parameters: Not used by this plugin.
            rec_x: Reconstructed x positions.
            rec_y: Reconstructed y positions.

        Returns:
            Tuple ``(key, s2_correction)``.
        """
        # Stack the coordinates along a new leading axis, then transpose so
        # the coordinate axis comes last before querying the map.
        pos_rec = jnp.stack([rec_x, rec_y]).T
        s2_correction = self.s2_correction.apply(pos_rec)
        return key, s2_correction
@export
class cS1(Plugin):
    """Compute the corrected S1 as ``s1_area / s1_correction``."""

    depends_on = ["s1_area", "s1_correction"]
    provides = ["cs1"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, s1_area, s1_correction):
        """Return the key unchanged and the corrected S1.

        Args:
            key: Random key; no random draw is made here.
            parameters: Not used by this plugin.
            s1_area: S1 area to correct.
            s1_correction: Correction the area is divided by.

        Returns:
            Tuple ``(key, cs1)``.
        """
        return key, s1_area / s1_correction
@export
class cS2(Plugin):
    """Compute the corrected S2.

    The S2 area is divided by the S2 correction and then by the drift
    survival probability.
    """

    depends_on = ["s2_area", "s2_correction", "drift_survive_prob"]
    provides = ["cs2"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, s2_area, s2_correction, drift_survive_prob):
        """Return the key unchanged and the corrected S2.

        Args:
            key: Random key; no random draw is made here.
            parameters: Not used by this plugin.
            s2_area: S2 area to correct.
            s2_correction: Correction the area is divided by.
            drift_survive_prob: Drift survival probability, divided out last.

        Returns:
            Tuple ``(key, cs2)``.
        """
        # Two successive divisions, in this order, so the floating-point
        # result is unchanged.
        position_corrected = s2_area / s2_correction
        return key, position_corrected / drift_survive_prob