from functools import partial
from jax import jit
from jax import numpy as jnp
from appletree import randgen
from appletree.plugin import Plugin
from appletree.config import takes_config, Constant, Map
from appletree.utils import exporter
# `exporter()` returns the `export` decorator applied to the plugin classes
# below, together with this module's `__all__`; presumably `@export` appends
# each decorated name to that list (see appletree.utils.exporter to confirm).
export, __all__ = exporter()
@export
@takes_config(
    Map(
        name="energy_spectrum",
        method="LERP",
        default="_nr_spectrum.json",
        help="Recoil energy spectrum",
    ),
)
class FixedEnergySpectra(Plugin):
    """Sample recoil energies from the tabulated ``energy_spectrum`` map.

    A uniform variate between 0 and 1 is pushed through the map, which is
    evaluated with linear interpolation (``method="LERP"``). The map therefore
    presumably tabulates the inverse CDF (cumulative probability -> energy);
    confirm against the spectrum file.
    """

    depends_on = ["batch_size"]
    provides = ["energy"]

    # `self` (argnum 0) and `batch_size` (argnum 3) are static under jit:
    # the plugin instance must be hashable, and the output shape is fixed
    # per distinct batch size.
    @partial(jit, static_argnums=(0, 3))
    def simulate(self, key, parameters, batch_size):
        """Draw ``batch_size`` energies.

        Args:
            key: random key; passed to ``randgen.uniform`` and the advanced
                key it returns is handed back to the caller.
            parameters: not read by this plugin (presumably kept for the
                common ``Plugin.simulate`` signature).
            batch_size: number of events to simulate.

        Returns:
            Tuple ``(key, energy)``, where ``energy`` is the map applied to a
            uniform draw of shape ``(batch_size,)``.
        """
        key, p = randgen.uniform(key, 0, 1.0, shape=(batch_size,))
        energy = self.energy_spectrum.apply(p)
        return key, energy
@export
@takes_config(
    # Default energy is the Ar37 K-shell line.
    Constant(name="mono_energy", type=float, default=2.82, help="Mono energy delta function"),
)
class MonoEnergySpectra(Plugin):
    """Assign every event the single energy ``mono_energy`` (delta-function spectrum)."""

    depends_on = ["batch_size"]
    provides = ["energy"]

    # `self` (argnum 0) and `batch_size` (argnum 3) are static under jit,
    # so the output shape is fixed per distinct batch size.
    @partial(jit, static_argnums=(0, 3))
    def simulate(self, key, parameters, batch_size):
        """Return ``batch_size`` copies of the configured mono energy.

        Args:
            key: random key; returned unchanged because no randomness is used.
            parameters: not read by this plugin (presumably kept for the
                common ``Plugin.simulate`` signature).
            batch_size: number of events to simulate.

        Returns:
            Tuple ``(key, energy)`` with ``energy`` a length-``batch_size``
            array filled with ``self.mono_energy.value``.
        """
        energy = jnp.full(batch_size, self.mono_energy.value)
        return key, energy
@export
@takes_config(
    Constant(
        name="z_min",
        type=float,
        default=-133.97,
        help="Z lower limit simulated in uniformly distribution",
    ),
    Constant(
        name="z_max",
        type=float,
        default=-13.35,
        help="Z upper limit simulated in uniformly distribution",
    ),
    Constant(
        name="r_max",
        type=float,
        default=60.0,
        help="Radius upper limit simulated in uniformly distribution",
    ),
)
class PositionSpectra(Plugin):
    """Sample event positions uniformly inside a cylinder.

    ``z`` is uniform in ``[z_min, z_max]``. ``(x, y)`` is uniform over the
    disk of radius ``r_max`` (uniform per unit area, not per unit radius).
    """

    depends_on = ["batch_size"]
    provides = ["x", "y", "z"]

    # `self` (argnum 0) and `batch_size` (argnum 3) are static under jit,
    # so the output shapes are fixed per distinct batch size.
    @partial(jit, static_argnums=(0, 3))
    def simulate(self, key, parameters, batch_size):
        """Draw ``batch_size`` positions.

        Args:
            key: random key; threaded through each ``randgen.uniform`` call
                and the final advanced key is returned.
            parameters: not read by this plugin (presumably kept for the
                common ``Plugin.simulate`` signature).
            batch_size: number of events to simulate.

        Returns:
            Tuple ``(key, x, y, z)``, each coordinate of shape ``(batch_size,)``.
        """
        key, z = randgen.uniform(key, self.z_min.value, self.z_max.value, shape=(batch_size,))
        # Sampling r**2 uniformly and taking the square root gives a radius
        # density proportional to r, i.e. points uniform over the disk area.
        key, r2 = randgen.uniform(key, 0, self.r_max.value**2, shape=(batch_size,))
        key, theta = randgen.uniform(key, 0, 2 * jnp.pi, shape=(batch_size,))
        r = jnp.sqrt(r2)
        x = r * jnp.cos(theta)
        y = r * jnp.sin(theta)
        return key, x, y, z