Source code for appletree.plugins.nestv2

from jax import numpy as jnp
from jax import jit
from functools import partial

from appletree import randgen
from appletree.plugin import Plugin
from appletree.config import takes_config, Constant, ConstantSet
from appletree.utils import exporter

export, __all__ = exporter(export_self=False)

# These scripts are copied from
# https://github.com/NESTCollaboration/nest/releases/tag/v2.3.7
# and https://github.com/NESTCollaboration/nest/blob/v2.3.7/src/NEST.cpp#L715-L794
# Priors of the distributions are copied from https://arxiv.org/abs/2211.10726
# and https://drive.google.com/file/d/1urVT3htFjIC1pQKyaCcFonvWLt74Kgvn/view
# All variables beginning with '_' are expectation values, such as `_Nph`, `_Ne`.


@export
@takes_config(
    ConstantSet(
        name="energy_twohalfnorm",
        default=[
            ["mu", "sigma_pos", "sigma_neg"],
            [[1.0], [0.1], [0.1]],
        ],
        help="Parameterized energy spectrum",
    ),
)
class MonoEnergiesSpectra(Plugin):
    """Sample energies from the ``energy_twohalfnorm`` configuration set.

    Provides the sampled ``energy`` together with ``energy_center``, the
    configured ``mu`` values broadcast to the shape of ``energy``.
    """

    depends_on = ["batch_size"]
    provides = ["energy", "energy_center"]

    @partial(jit, static_argnums=(0, 3))
    def simulate(self, key, parameters, batch_size):
        """Draw ``(batch_size, set_volume)`` energies with ``randgen.twohalfnorm``.

        Args:
            key: Random key handed to ``randgen.twohalfnorm``.
            parameters: Parameter mapping; not read by this plugin.
            batch_size: Number of rows to sample (static under ``jit``).

        Returns:
            Tuple ``(key, energy, energy_center)`` where ``key`` is the key
            returned by the random generator.
        """
        spectrum = self.energy_twohalfnorm
        sample_shape = (batch_size, spectrum.set_volume)
        key, drawn = randgen.twohalfnorm(key, shape=sample_shape, **spectrum.value)
        # Negative draws are floored at zero.
        energy = jnp.clip(drawn, 0.0, jnp.inf)
        centers = jnp.broadcast_to(spectrum.value["mu"], jnp.shape(energy))
        energy_center = centers.astype(float)
        return key, energy, energy_center
@export
@takes_config(
    Constant(
        name="clip_lower_energy",
        type=float,
        default=0.5,
        help="Smallest energy considered in inference",
    ),
    Constant(
        name="clip_upper_energy",
        type=float,
        default=2.5,
        help="Largest energy considered in inference",
    ),
)
class UniformEnergiesSpectra(Plugin):
    """Sample energies with ``randgen.uniform`` between the two clip energies."""

    depends_on = ["batch_size"]
    provides = ["energy"]

    @partial(jit, static_argnums=(0, 3))
    def simulate(self, key, parameters, batch_size):
        """Draw ``batch_size`` energies between the configured clip bounds.

        Args:
            key: Random key handed to ``randgen.uniform``.
            parameters: Parameter mapping; not read by this plugin.
            batch_size: Number of energies to draw (static under ``jit``).

        Returns:
            Tuple ``(key, energy)`` with ``energy`` requested in shape
            ``(batch_size,)``.
        """
        lower = self.clip_lower_energy.value
        upper = self.clip_upper_energy.value
        key, energy = randgen.uniform(key, lower, upper, shape=(batch_size,))
        return key, energy
@export
class TotalQuanta(Plugin):
    """Expected total number of quanta as a power law in energy."""

    depends_on = ["energy"]
    provides = ["_Nq"]
    parameters = ("alpha", "beta")

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, energy):
        """Compute ``_Nq = alpha * energy ** beta``.

        Args:
            key: Random key, passed through unchanged.
            parameters: Mapping providing ``alpha`` and ``beta``.
            energy: Energies to evaluate.

        Returns:
            Tuple ``(key, _Nq)``.
        """
        scale = parameters["alpha"]
        exponent = parameters["beta"]
        return key, scale * energy**exponent
@export
@takes_config(
    Constant(
        name="literature_field", type=float, default=23.0, help="Drift field in each literature"
    ),
)
class ThomasImelBox(Plugin):
    """Compute the ``ThomasImel`` factor from field and liquid xenon density."""

    depends_on = ["energy"]
    provides = ["ThomasImel"]
    parameters = ("gamma", "delta", "liquid_xe_density")

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, energy):
        """Return an array shaped like ``energy`` filled with the ThomasImel value.

        The value is ``gamma * literature_field ** delta`` multiplied by
        ``(liquid_xe_density / 2.9) ** 0.3``; ``energy`` only sets the shape.

        Args:
            key: Random key, passed through unchanged.
            parameters: Mapping providing ``gamma``, ``delta`` and
                ``liquid_xe_density``.
            energy: Array whose shape the output copies.

        Returns:
            Tuple ``(key, ThomasImel)``.
        """
        field_factor = parameters["gamma"] * self.literature_field.value ** parameters["delta"]
        density_factor = (parameters["liquid_xe_density"] / 2.9) ** 0.3
        # Multiply in the same order as before: ones * field term * density term.
        ThomasImel = jnp.ones(shape=jnp.shape(energy)) * field_factor * density_factor
        return key, ThomasImel
@export
class QyNR(Plugin):
    """Compute ``charge_yield`` from ``energy`` and ``ThomasImel``."""

    depends_on = ["energy", "ThomasImel"]
    provides = ["charge_yield"]
    parameters = ("epsilon", "zeta", "eta")

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, energy, ThomasImel):
        """Evaluate the charge yield and floor it at zero.

        Args:
            key: Random key, passed through unchanged.
            parameters: Mapping providing ``epsilon``, ``zeta`` and ``eta``.
            energy: Energies to evaluate.
            ThomasImel: ThomasImel factor for each energy.

        Returns:
            Tuple ``(key, charge_yield)`` with ``charge_yield >= 0``.
        """
        base_yield = 1 / ThomasImel / jnp.sqrt(energy + parameters["epsilon"])
        # Factor that goes to 0 as energy -> 0 (for positive eta) and to 1 at high energy.
        turn_on = 1 - 1 / (1 + (energy / parameters["zeta"]) ** parameters["eta"])
        charge_yield = jnp.clip(base_yield * turn_on, 0, jnp.inf)
        return key, charge_yield
@export
class LyNR(Plugin):
    """Compute ``light_yield`` from the total quanta and the charge yield."""

    depends_on = ["energy", "_Nq", "charge_yield"]
    provides = ["light_yield"]
    parameters = ("theta", "iota")

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, energy, _Nq, charge_yield):
        """Evaluate the light yield and floor it at zero.

        Args:
            key: Random key, passed through unchanged.
            parameters: Mapping providing ``theta`` and ``iota``.
            energy: Energies to evaluate.
            _Nq: Expected total quanta for each energy.
            charge_yield: Charge yield for each energy.

        Returns:
            Tuple ``(key, light_yield)`` with ``light_yield >= 0``.
        """
        # Quanta per unit energy that are not accounted for by charge.
        remaining_yield = _Nq / energy - charge_yield
        turn_on = 1 - 1 / (1 + (energy / parameters["theta"]) ** parameters["iota"])
        light_yield = jnp.clip(remaining_yield * turn_on, 0, jnp.inf)
        return key, light_yield
@export
class MeanNphNe(Plugin):
    """Turn light and charge yields into expected photon and electron counts."""

    depends_on = ["light_yield", "charge_yield", "energy"]
    provides = ["_Nph", "_Ne"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, light_yield, charge_yield, energy):
        """Multiply each yield by ``energy``.

        Args:
            key: Random key, passed through unchanged.
            parameters: Parameter mapping; not read by this plugin.
            light_yield: Light yield for each energy.
            charge_yield: Charge yield for each energy.
            energy: Energies to evaluate.

        Returns:
            Tuple ``(key, _Nph, _Ne)``.
        """
        return key, light_yield * energy, charge_yield * energy
@export
class MeanExcitonIon(Plugin):
    """Derive expected exciton/ion numbers and related ratios."""

    depends_on = ["ThomasImel", "_Nph", "_Ne"]
    provides = ["_Nex", "_Ni", "nex_ni_ratio", "alf", "elecFrac", "recombProb"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, ThomasImel, _Nph, _Ne):
        """Compute exciton and ion expectations from ``ThomasImel``, ``_Nph``, ``_Ne``.

        Args:
            key: Random key, passed through unchanged.
            parameters: Parameter mapping; not read by this plugin.
            ThomasImel: ThomasImel factor.
            _Nph: Expected number of photons.
            _Ne: Expected number of electrons.

        Returns:
            Tuple ``(key, _Nex, _Ni, nex_ni_ratio, alf, elecFrac, recombProb)``
            where ``nex_ni_ratio = _Nex / _Ni``, ``alf = 1 / (1 + nex_ni_ratio)``,
            ``elecFrac = _Ne / (_Nph + _Ne)`` and
            ``recombProb = 1 - (nex_ni_ratio + 1) * elecFrac``.
        """
        # Both _Nex and _Ni use the same exponential term.
        growth = jnp.exp(_Ne * ThomasImel / 4.0)

        mean_excitons = (-1.0 / ThomasImel) * (4.0 * growth - (_Ne + _Nph) * ThomasImel - 4.0)
        mean_ions = (4.0 / ThomasImel) * (growth - 1.0)

        nex_ni_ratio = mean_excitons / mean_ions
        alf = 1.0 / (1.0 + nex_ni_ratio)
        elecFrac = _Ne / (_Nph + _Ne)
        recombProb = 1.0 - (nex_ni_ratio + 1.0) * elecFrac
        return key, mean_excitons, mean_ions, nex_ni_ratio, alf, elecFrac, recombProb
@export
class TrueExcitonIonNR(Plugin):
    """Sample integer ion and exciton counts around their expectations."""

    depends_on = ["_Nph", "_Ne", "nex_ni_ratio", "alf"]
    provides = ["Ni", "Nex", "Nq"]
    parameters = ("fano_ni", "fano_nex")

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, _Nph, _Ne, nex_ni_ratio, alf):
        """Draw ``Ni`` and ``Nex`` with ``randgen.truncate_normal`` (``vmin=0``).

        Ions are drawn first, then excitons. Each draw is rounded and cast to
        ``int``.

        Args:
            key: Random key threaded through both draws.
            parameters: Mapping providing ``fano_ni`` and ``fano_nex``.
            _Nph: Expected number of photons.
            _Ne: Expected number of electrons.
            nex_ni_ratio: Ratio of expected excitons to expected ions.
            alf: ``1 / (1 + nex_ni_ratio)``, as produced by ``MeanExcitonIon``.

        Returns:
            Tuple ``(key, Ni, Nex, Nq)`` with ``Nq = Nex + Ni``.
        """
        mean_quanta = _Nph + _Ne

        ion_mean = mean_quanta * alf
        ion_std = jnp.sqrt(parameters["fano_ni"] * mean_quanta * alf)
        key, ion_draw = randgen.truncate_normal(key, ion_mean, ion_std, vmin=0)
        Ni = ion_draw.round().astype(int)

        exciton_mean = mean_quanta * nex_ni_ratio * alf
        exciton_std = jnp.sqrt(parameters["fano_nex"] * mean_quanta * nex_ni_ratio * alf)
        key, exciton_draw = randgen.truncate_normal(key, exciton_mean, exciton_std, vmin=0)
        Nex = exciton_draw.round().astype(int)

        return key, Ni, Nex, Nex + Ni
@export
class OmegaNR(Plugin):
    """Compute the ``omega`` fluctuation term and the resulting ``Variance``."""

    depends_on = ["elecFrac", "recombProb", "Ni"]
    provides = ["omega", "Variance"]
    parameters = ("A", "xi", "omega")

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, elecFrac, recombProb, Ni):
        """Evaluate a Gaussian-shaped ``omega`` in ``elecFrac`` and the variance.

        ``omega`` has amplitude ``A``, center ``xi`` and width given by the
        ``omega`` parameter. The variance is
        ``recombProb * (1 - recombProb) * Ni + omega**2 * Ni**2``.

        Args:
            key: Random key, passed through unchanged.
            parameters: Mapping providing ``A``, ``xi`` and ``omega``.
            elecFrac: Electron fraction.
            recombProb: Recombination probability.
            Ni: Number of ions.

        Returns:
            Tuple ``(key, omega, Variance)``.
        """
        offset = elecFrac - parameters["xi"]
        exponent = -0.5 * offset**2.0 / (parameters["omega"] ** 2)
        omega = parameters["A"] * jnp.exp(exponent)

        binomial_part = recombProb * (1.0 - recombProb) * Ni
        extra_part = omega * omega * Ni * Ni
        return key, omega, binomial_part + extra_part
@export
class TruePhotonElectronNR(Plugin):
    """Sample the numbers of photons and electrons after recombination."""

    depends_on = ["recombProb", "Variance", "Ni", "Nq"]
    provides = ["num_photon", "num_electron"]
    parameters = ("alpha2",)

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, recombProb, Variance, Ni, Nq):
        """Draw ``num_electron`` with ``randgen.skewnormal`` and derive ``num_photon``.

        Args:
            key: Random key handed to ``randgen.skewnormal``.
            parameters: Mapping providing the skewness ``alpha2``.
            recombProb: Recombination probability.
            Variance: Variance of the electron-number fluctuation.
            Ni: Number of ions.
            Nq: Total number of quanta.

        Returns:
            Tuple ``(key, num_photon, num_electron)``; both counts are floored
            at zero and ``num_photon`` is ``Nq - num_electron`` before flooring.
        """
        alpha2 = parameters["alpha2"]

        # Width and location corrections are chosen so that the mean of the
        # skew-normal draw is just (1 - recombProb) * Ni.
        widthCorrection = (1.0 - (2.0 / jnp.pi) * alpha2**2 / (1.0 + alpha2**2)) ** 0.5
        scale = jnp.sqrt(Variance) / widthCorrection
        muCorrection = (
            scale * (alpha2 / (1.0 + alpha2**2) ** 0.5) * 2.0 * (1.0 / (2.0 * jnp.pi) ** 0.5)
        )

        # Use the full shape (not ``len``) so the skewness array matches
        # ``loc``/``scale`` when the inputs are not 1-D, e.g. the
        # (batch_size, set_volume) arrays produced by MonoEnergiesSpectra.
        skewness = jnp.full(jnp.shape(recombProb), alpha2)

        key, num_electron = randgen.skewnormal(
            key,
            skewness,
            (1.0 - recombProb) * Ni - muCorrection,
            scale,
        )
        num_electron = jnp.clip(num_electron.round().astype(int), 0, jnp.inf)
        num_photon = jnp.clip(Nq - num_electron, 0, jnp.inf)
        return key, num_photon, num_electron
@export
@takes_config(
    Constant(
        name="clip_lower_energy",
        type=float,
        default=0.5,
        help="Smallest energy considered in inference",
    ),
    Constant(
        name="clip_upper_energy",
        type=float,
        default=2.5,
        help="Largest energy considered in inference",
    ),
)
class MonoEnergiesClipEff(Plugin):
    """Efficiency mask for mono-energy-like yield constraints.

    Energies whose center lies outside the configured range are filtered out
    by setting their weight to 0.
    """

    depends_on = ["energy_center"]
    provides = ["eff"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, energy_center):
        """Return weight 1.0 inside ``[clip_lower_energy, clip_upper_energy]``, else 0.0.

        Args:
            key: Random key, passed through unchanged.
            parameters: Parameter mapping; not read by this plugin.
            energy_center: Energy centers to test (both bounds inclusive).

        Returns:
            Tuple ``(key, eff)``.
        """
        lower = self.clip_lower_energy.value
        upper = self.clip_upper_energy.value
        in_range = (energy_center >= lower) & (energy_center <= upper)
        return key, jnp.where(in_range, 1.0, 0.0)
@export
class BandEnergiesClipEff(Plugin):
    """Placeholder efficiency for band-like yield constraints.

    Only a placeholder is needed here, because BandEnergySpectra has already
    selected the energies for us.
    """

    depends_on = ["energy"]
    provides = ["eff"]

    @partial(jit, static_argnums=(0,))
    def simulate(self, key, parameters, energy):
        """Return a weight of 1 for every entry of ``energy``.

        Args:
            key: Random key, passed through unchanged.
            parameters: Parameter mapping; not read by this plugin.
            energy: Energies; only the shape is used.

        Returns:
            Tuple ``(key, eff)`` with ``eff`` an array of ones shaped like
            ``energy``.
        """
        # Use the full shape rather than ``len`` so the weights line up with
        # ``energy`` for any number of dimensions; identical for 1-D input.
        eff = jnp.ones(jnp.shape(energy))
        return key, eff