Source code for appletree.hist

from jax import numpy as jnp
from jax import jit, vmap

from appletree.utils import exporter

export, __all__ = exporter(export_self=False)


[docs]@export @jit def make_hist_mesh_grid(sample, bins=10, weights=None): """Same as jnp.histogramdd.""" hist, _ = jnp.histogramdd(sample, bins=bins, weights=weights) return hist
[docs]@export @jit def make_hist_irreg_bin_1d(sample, bins, weights): """Make a histogram with irregular binning. Args: sample: array with shape N. bins: array with shape M. weights: array with shape (N,). """ ind = jnp.searchsorted(bins, sample) hist = jnp.zeros(len(bins) + 1) hist = hist.at[ind].add(weights) return hist[1:-1]
[docs]@export @jit def make_hist_irreg_bin_2d(sample, bins_x, bins_y, weights): """Make a histogram with irregular binning. Args: sample: array with shape (N, 2). bins_x: array with shape (M1, ). bins_y: array with shape (M1-1, M2). weights: array with shape (N,). """ x = sample[:, 0] y = sample[:, 1] ind_x = jnp.searchsorted(bins_x, x) ind_y = vmap(jnp.searchsorted, (0, 0), 0)(bins_y[ind_x - 1], y) bin_ind = jnp.stack((ind_x, ind_y)) output_shape = (len(bins_x) + 1, bins_y.shape[-1] + 1) hist = jnp.zeros(output_shape) hist = hist.at[tuple(bin_ind)].add(weights) return hist[1:-1, 1:-1]