# Source code for pbjam.distributions

"""
The distributions module contains a selection of probability densities and methods, 
which are primarily used to build prior probability densities in PBjam. These classes
to some extent mimic those of the scipy.stats package in terms of their inputs and 
functionality.
"""

import jax.numpy as jnp
from functools import partial
import jax
import numpy as np
import jax.scipy.special as jsp
from pbjam import jar
import statsmodels.api as sm

def makeDistObject(data, **kwargs):
    """ Build one distribution object per column of samples

    The quantile functions are computed by getQuantileFuncs and each set
    of (ppf, pdf, logpdf, cdf) is wrapped in a distribution instance.

    Parameters
    ----------
    data : array
        Array of samples, passed on to getQuantileFuncs.
    **kwargs : dict
        Additional keyword arguments passed on to getQuantileFuncs.

    Returns
    -------
    D : list
        List of distribution instances, one per entry in the list of ppfs
        returned by getQuantileFuncs.
    """

    ppfs, pdfs, logpdfs, cdfs = getQuantileFuncs(data, **kwargs)

    return [distribution(ppf, pdfs[idx], logpdfs[idx], cdfs[idx])
            for idx, ppf in enumerate(ppfs)]

def getQuantileFuncs(data, cut=5, densityScale=30, **kwargs):
    """ Compute distribution methods for arbitrary distributions.

    All distributions are treated as separable: each column of data gets its
    own univariate kernel density estimate.

    Parameters
    ----------
    data : array
        Array of samples to compute the distribution functions of. The
        columns (second axis) are treated one at a time.
    cut : float, optional
        Passed to the fit method of the statsmodels KDEUnivariate object.
        The default is 5.
    densityScale : int, optional
        The unit interval used to tabulate the ppf is sampled at
        densityScale times the length of the cdf array of the KDE.
        The default is 30.
    **kwargs : dict
        Accepted but not used by this function.

    Returns
    -------
    ppfs : list
        List of callable functions to evaluate the ppfs of the samples.
    pdfs : list
        List of callable functions to evaluate the pdfs of the samples.
    logpdfs : list
        List of callable functions to evaluate the logpdfs of the samples.
    cdfs : list
        List of the cdf attributes of the fitted KDEs.
    """

    ppfs, pdfs, logpdfs, cdfs = [], [], [], []

    for column in range(data.shape[1]):

        samples = np.array(data[:, column]).real

        kde = sm.nonparametric.KDEUnivariate(samples)

        kde.fit(cut=cut)

        support = kde.support

        density = kde.evaluate(support)

        # TODO is the sampling of the unit interval dense enough? Increasing
        # it doesn't seem to impact evaluation time of the ppf.
        probs = jnp.linspace(0, 1, densityScale * len(kde.cdf))

        # NOTE(review): kde.cdf is stored as-is. It is used with len() above,
        # so it looks like an array rather than a callable - confirm what the
        # consumers of the cdfs expect.
        cdfs.append(kde.cdf)

        # The icdf from statsmodels is only evaluated on the input values,
        # not the complete support of the kde-pdf which may be wider because
        # of the kernel bandwidth.
        grid = np.linspace(support[0], support[-1], len(probs))

        quantiles = jar.getCurvePercentiles(grid,
                                            kde.evaluate(grid),
                                            percentiles=probs)

        ppfs.append(jar.jaxInterp1D(probs, quantiles))

        # TODO should increase resolution on pdf like on the ppf
        pdfs.append(jar.jaxInterp1D(support, density))

        logpdfs.append(jar.jaxInterp1D(support, jnp.log(density)))

    return ppfs, pdfs, logpdfs, cdfs
class _DistributionBase():
    """ Shared helpers for the distribution classes in this module

    Subclasses must provide a ppf method and a pdf method that accept a
    scalar input.
    """

    def rv(self):
        """ Draw random variable from distribution

        Returns
        -------
        x : float
            Random variable drawn from the distribution
        """

        u = np.random.uniform(0, 1)

        return self.ppf(u)

    def _set_stdatt(self):
        """ Set mean and median for the distribution

        The mean is computed by trapezoidal integration of x*pdf(x) on 1000
        points between the 1e-6 and 1-1e-6 quantiles. The pdf is evaluated
        one point at a time so that scalar-only pdfs are supported.
        """

        x = jnp.linspace(self.ppf(1e-6), self.ppf(1-1e-6), 1000)

        self.mean = jnp.trapezoid(x * jnp.array([self.pdf(_x) for _x in x]), x)

        self.median = self.ppf(0.5)


class beta(_DistributionBase):

    def __init__(self, a=1, b=1, loc=0, scale=1):
        """ beta distribution class

        Create instances a probability density which follows the beta
        distribution.

        Parameters
        ----------
        a : float
            The first shape parameter of the beta distribution.
        b : float
            The second shape parameter of the beta distribution.
        loc : float
            The lower limit of the beta distribution. The probability at this
            limit and below is 0.
        scale : float
            The width of the beta distribution. Effectively sets the upper
            bound for the distribution, which is loc+scale.
        """

        self.a = a
        self.b = b
        self.loc = loc
        self.scale = scale

        # Log of the normalization constant, 1 / (B(a, b) * scale)
        self.logfac = jsp.gammaln(self.a + self.b) - jsp.gammaln(self.a) - jsp.gammaln(self.b) - jnp.log(self.scale)

        self.fac = jnp.exp(self.logfac)

        self.am1 = self.a - 1

        self.bm1 = self.b - 1

        self._set_stdatt()

    @partial(jax.jit, static_argnums=(0,))
    def _transformx(self, x):
        """ Transform x

        Translates and scales the input x to the unit interval according to
        the loc and scale parameters.

        Parameters
        ----------
        x : float
            Input support for the probability density.

        Returns
        -------
        _x : float
            x translated and scaled to the range 0 to 1.
        """

        return (x - self.loc) / self.scale

    @partial(jax.jit, static_argnums=(0,))
    def _inverse_transform(self, x):
        """ Invert scaling on input

        Parameters
        ----------
        x : float
            Input

        Returns
        -------
        _x : float
            Scaled x.
        """

        return x * self.scale + self.loc

    @partial(jax.jit, static_argnums=(0,))
    def pdf(self, x, norm=True):
        """ Return PDF

        Returns the beta distribution at x. The distribution is normalized to
        unit integral by default so that it may be used as a PDF. Outside the
        interval loc to loc+scale the density is 0.

        Parameters
        ----------
        x : array
            Input support for the probability density.
        norm : bool, optional
            If true, returns the normalized beta distribution. The default is
            True.

        Returns
        -------
        y : array
            The value of the beta distribution at x.
        """

        _x = self._transformx(jnp.asarray(x))

        outside = (_x < 0.) | (1. < _x)

        # Evaluate the power law on a harmless value where x is outside the
        # support, so the unselected branch never produces NaN.
        safe = jnp.where(outside, 0.5, _x)

        y = jnp.where(outside, 0., safe**self.am1 * (1 - safe)**self.bm1)

        # jnp.where also works when norm is passed explicitly to the jitted
        # function and is therefore traced.
        return jnp.where(norm, y * self.fac, y)

    @partial(jax.jit, static_argnums=(0,))
    def logpdf(self, x, norm=True):
        """ Return log-PDF

        Returns the log of the beta distribution at x. The distribution is
        normalized to unit integral (in linear units) by default so that it
        may be used as a PDF. Outside the interval loc to loc+scale the
        log-density is -inf.

        Parameters
        ----------
        x : array
            Input support for the probability density.
        norm : bool, optional
            If true, returns the normalized beta distribution. The default is
            True.

        Returns
        -------
        y : array
            The value of the logarithm of the beta distribution at x.
        """

        _x = self._transformx(jnp.asarray(x))

        outside = (_x < 0.) | (1. < _x)

        # See pdf: keep the unselected branch free of NaN.
        safe = jnp.where(outside, 0.5, _x)

        y = jnp.where(outside,
                      -jnp.inf,
                      self.am1 * jnp.log(safe) + self.bm1 * jnp.log(1 - safe))

        return jnp.where(norm, y + self.logfac, y)

    def cdf(self, x):
        """ The cumulative probability distribution function

        Parameters
        ----------
        x : array
            Evaluate the cdf at x

        Returns
        -------
        y : array
            The cumulative probability at x. This is 0 at and below loc and
            1 at and above loc+scale.
        """

        _x = self._transformx(jnp.asarray(x))

        inside = jsp.betainc(self.a, self.b, jnp.clip(_x, 0., 1.))

        return jnp.where(_x <= 0, 0., jnp.where(_x >= 1, 1., inside))

    @partial(jax.jit, static_argnums=(0,))
    def ppf(self, y):
        """ The point percent (quantile) function.

        Parameters
        ----------
        y : float
            Evaluate the ppf at y.

        Returns
        -------
        x : float
            The value of x at which the cdf equals y.
        """

        _x = self.betaincinv(self.a, self.b, y)

        return self._inverse_transform(_x)

    @partial(jax.jit, static_argnums=(0,))
    def update_x(self, x, a, b, p, a1, b1, afac):
        """ Perform one refinement step of the inverse incomplete beta function

        Parameters
        ----------
        x : array
            Current estimate of the solution to betainc(a, b, x) = p.
        a : float
            First shape parameter.
        b : float
            Second shape parameter.
        p : array
            Target probability.
        a1 : float
            a - 1.
        b1 : float
            b - 1.
        afac : float
            Negative log of the beta function of a and b.

        Returns
        -------
        x : array
            Updated estimate, kept inside the unit interval.
        t : array
            The step that was subtracted from the previous estimate.
        """

        err = jsp.betainc(a, b, x) - p

        # Derivative of betainc with respect to x
        t = jnp.exp(a1 * jnp.log(x) + b1 * jnp.log(1.0 - x) + afac)

        u = err/t

        tmp = u * (a1 / x - b1 / (1.0 - x))

        t = u/(1.0 - 0.5 * jnp.minimum(tmp, 1.0))

        x -= t

        # Bisect towards the boundary if the step left the unit interval
        x = jnp.where(x <= 0., 0.5 * (x + t), x)

        x = jnp.where(x >= 1., 0.5 * (x + t + 1.), x)

        return x, t

    @partial(jax.jit, static_argnums=(0,))
    def func_1(self, a, b, p):
        """ Initial guess for betaincinv, used when a >= 1 and b >= 1

        Parameters
        ----------
        a : float
            First shape parameter.
        b : float
            Second shape parameter.
        p : array
            Target probability.

        Returns
        -------
        x : array
            Starting value for the iterative refinement.
        """

        pp = jnp.where(p < .5, p, 1. - p)

        t = jnp.sqrt(-2. * jnp.log(pp))

        x = (2.30753 + t * 0.27061) / (1.0 + t * (0.99229 + t * 0.04481)) - t

        x = jnp.where(p < .5, -x, x)

        al = (jnp.power(x, 2) - 3.0) / 6.0

        h = 2.0 / (1.0 / (2.0 * a - 1.0) + 1.0 / (2.0 * b - 1.0))

        w = (x * jnp.sqrt(al + h) / h)-(1.0 / (2.0 * b - 1) - 1.0/(2.0 * a - 1.0)) * (al + 5.0 / 6.0 - 2.0 / (3.0 * h))

        return a / (a + b * jnp.exp(2.0 * w))

    @partial(jax.jit, static_argnums=(0,))
    def func_2(self, a, b, p):
        """ Initial guess for betaincinv, used when a < 1 or b < 1

        Parameters
        ----------
        a : float
            First shape parameter.
        b : float
            Second shape parameter.
        p : array
            Target probability.

        Returns
        -------
        x : array
            Starting value for the iterative refinement.
        """

        lna = jnp.log(a / (a + b))

        lnb = jnp.log(b / (a + b))

        t = jnp.exp(a * lna) / a

        u = jnp.exp(b * lnb) / b

        w = t + u

        return jnp.where(p < t/w,
                         jnp.power(a * w * p, 1.0 / a),
                         1. - jnp.power(b * w * (1.0 - p), 1.0/b))

    @partial(jax.jit, static_argnums=(0,))
    def compute_x(self, p, a, b):
        """ Select the initial guess for betaincinv

        Parameters
        ----------
        p : array
            Target probability.
        a : float
            First shape parameter.
        b : float
            Second shape parameter.

        Returns
        -------
        x : array
            The result of func_1 if both a and b are at least 1, otherwise
            the result of func_2.
        """

        return jnp.where(jnp.logical_and(a >= 1.0, b >= 1.0),
                         self.func_1(a, b, p),
                         self.func_2(a, b, p))

    @partial(jax.jit, static_argnums=(0,))
    def betaincinv(self, a, b, p):
        """ Inverse of the regularized incomplete beta function

        Starts from the guess given by compute_x and applies update_x a
        fixed 10 times, freezing entries once the step is small.

        Parameters
        ----------
        a : float
            First shape parameter.
        b : float
            Second shape parameter.
        p : array
            Probability, clipped to the range 0 to 1.

        Returns
        -------
        x : array
            Value between 0 and 1 such that betainc(a, b, x) is
            approximately p.
        """

        a1 = a - 1.0

        b1 = b - 1.0

        ERROR = 1e-8

        p = jnp.clip(p, 0., 1.)

        # p = 0 and p = 1 map directly onto x = 0 and x = 1
        x = jnp.where(jnp.logical_or(p <= 0.0, p >= 1.), p, self.compute_x(p, a, b))

        afac = - jsp.betaln(a, b)

        stop = jnp.logical_or(x == 0.0, x == 1.0)

        for _ in range(10):

            x_new, t = self.update_x(x, a, b, p, a1, b1, afac)

            x = jnp.where(stop, x, x_new)

            stop = jnp.logical_or(jnp.abs(t) < ERROR * x, stop)

        return x


class distribution(_DistributionBase):

    def __init__(self, ppf, pdf, logpdf, cdf):
        """ Generic distribution object

        Creates wrapper for a set of methods that return the pdf, logpdf,
        ppf and cdf of an arbitrary distribution.

        Parameters
        ----------
        ppf : callable
            Function that, given a value between 0 and 1, returns a sample
            drawn from the pdf.
        pdf : callable
            Function that, given x returns the value of pdf(x)
        logpdf : callable
            Function that, given x return the value of log(pdf(x)).
        cdf : callable
            Function that, given x returns the value of cdf(x).
        """

        self.pdf = pdf

        self.ppf = ppf

        self.logpdf = logpdf

        self.cdf = cdf

        self._set_stdatt()


class uniform(_DistributionBase):

    def __init__(self, loc=0, scale=1):
        """ Uniform distribution

        Emulates the scipy.stats class, but is jaxed.

        Parameters
        ----------
        loc : float
            Left side of the uniform distribution
        scale : float
            Width of the uniform distribution, such that the right side is
            loc+scale

        Attributes
        ----------
        mean : float
            Mean (loc+scale/2) of the distribution.
        median : float
            Median of the distribution.
        """

        self.loc = loc
        self.scale = scale

        self.a = self.loc

        self.b = self.loc + self.scale

        self._set_stdatt()

    def _set_stdatt(self):
        """ Set mean and median for the distribution

        The mean is known analytically, so no numerical integration is done.
        """

        self.mean = jnp.asarray(0.5 * (self.a + self.b))

        self.median = self.ppf(0.5)

    @partial(jax.jit, static_argnums=(0,))
    def pdf(self, x):
        """ The probability density of the distribution

        Parameters
        ----------
        x : array
            Evaluate the pdf at x

        Returns
        -------
        y : array
            The probability density at x
        """

        x = jnp.asarray(x)

        outside = (x < self.a) | (self.b < x)

        return jnp.where(outside, 0., 1./self.scale)

    @partial(jax.jit, static_argnums=(0,))
    def logpdf(self, x):
        """ The log-probability of the distribution

        Parameters
        ----------
        x : array
            Evaluate the log-pdf at x

        Returns
        -------
        y : array
            The log-probability density at x
        """

        x = jnp.asarray(x)

        outside = (x < self.a) | (self.b < x)

        return jnp.where(outside, -jnp.inf, -jnp.log(self.scale))

    @partial(jax.jit, static_argnums=(0,))
    def cdf(self, x):
        """ The cumulative probability distribution function

        Parameters
        ----------
        x : array
            Evaluate the cdf at x

        Returns
        -------
        y : array
            The cumulative probability at x, limited to the range 0 to 1.
        """

        y = (x - self.a) / (self.b - self.a)

        return jnp.clip(y, 0., 1.)

    @partial(jax.jit, static_argnums=(0,))
    def ppf(self, y):
        """ The point percent (quantile) function.

        Parameters
        ----------
        y : float
            Evaluate the ppf at y.

        Returns
        -------
        x : float
            The value of x at which the cdf equals y.
        """

        y = jnp.array(y)

        return y * (self.b - self.a) + self.a


class normal(_DistributionBase):

    def __init__(self, loc=0, scale=1):
        """ normal distribution class

        Create instances a probability density which follows the normal
        distribution.

        Parameters
        ----------
        loc : float
            The mean of the normal distribution.
        scale : float
            The standard deviation of the normal distribution.
        """

        self.loc = loc
        self.scale = scale

        self.fac = -0.5 / self.scale**2

        self.norm = 1 / (jnp.sqrt(2*jnp.pi) * self.scale)

        self.lognorm = jnp.log(self.norm)

        self._set_stdatt()

    @partial(jax.jit, static_argnums=(0,))
    def pdf(self, x, norm=True):
        """ Return PDF

        Returns the normal distribution at x. The distribution is normalized
        to unit integral by default so that it may be used as a PDF.

        Parameters
        ----------
        x : array
            Input support for the probability density.
        norm : bool, optional
            If true, returns the normalized normal distribution. The default
            is True.

        Returns
        -------
        y : array
            The value of the normal distribution at x.
        """

        y = jnp.exp(self.fac * (x - self.loc)**2)

        return jnp.where(norm, y * self.norm, y)

    @partial(jax.jit, static_argnums=(0,))
    def logpdf(self, x, norm=True):
        """ Return log-PDF

        Returns the log of the normal distribution at x. The distribution is
        normalized to unit integral (in linear units) by default so that it
        may be used as a PDF.

        Parameters
        ----------
        x : array
            Input support for the probability density.
        norm : bool, optional
            If true, returns the normalized normal distribution. The default
            is True.

        Returns
        -------
        y : array
            The value of the logarithm of the normal distribution at x.
        """

        y = self.fac * (x - self.loc)**2

        return jnp.where(norm, y + self.lognorm, y)

    @partial(jax.jit, static_argnums=(0,))
    def cdf(self, x):
        """ The cumulative probability distribution function

        Parameters
        ----------
        x : array
            Evaluate the cdf at x

        Returns
        -------
        y : array
            The cumulative probability at x
        """

        return 0.5 * (1 + jsp.erf((x-self.loc)/(jnp.sqrt(2)*self.scale)))

    @partial(jax.jit, static_argnums=(0,))
    def ppf(self, y):
        """ The point percent (quantile) function.

        Parameters
        ----------
        y : float
            Evaluate the ppf at y.

        Returns
        -------
        x : float
            The value of x at which the cdf equals y.
        """

        return self.loc + self.scale*jnp.sqrt(2)*jsp.erfinv(2*y-1)


class truncsine(_DistributionBase):

    def __init__(self,):
        """ Sine truncated between 0 and pi/2
        """

        self._set_stdatt()

    @partial(jax.jit, static_argnums=(0,))
    def pdf(self, x):
        """ The probability density of the distribution

        Parameters
        ----------
        x : array
            Evaluate the pdf at x

        Returns
        -------
        y : array
            The probability density at x
        """

        x = jnp.asarray(x)

        outside = (x < 0.) | (jnp.pi/2 < x)

        return jnp.where(outside, 0., jnp.sin(x))

    @partial(jax.jit, static_argnums=(0,))
    def logpdf(self, x):
        """ The log-probability of the distribution

        Parameters
        ----------
        x : array
            Evaluate the log-pdf at x

        Returns
        -------
        y : array
            The log-probability density at x
        """

        x = jnp.asarray(x)

        outside = (x < 0.) | (jnp.pi/2. < x)

        # Keep the unselected branch free of NaN from log of a negative sine
        safe = jnp.where(outside, jnp.pi/4., x)

        return jnp.where(outside, -jnp.inf, jnp.log(jnp.sin(safe)))

    @partial(jax.jit, static_argnums=(0,))
    def cdf(self, x):
        """ The cumulative probability distribution function

        Parameters
        ----------
        x : array
            Evaluate the cdf at x

        Returns
        -------
        y : array
            The cumulative probability at x. This is 0 at and below 0 and 1
            at and above pi/2.
        """

        x = jnp.asarray(x)

        y = 1 + jnp.cos(x-jnp.pi)

        return jnp.where(x <= 0., 0., jnp.where(x >= jnp.pi/2., 1., y))

    @partial(jax.jit, static_argnums=(0,))
    def ppf(self, y):
        """ The point percent (quantile) function.

        Parameters
        ----------
        y : float
            Evaluate the ppf at y.

        Returns
        -------
        x : float
            The value of x at which the cdf equals y.
        """

        return jnp.arccos(1-y)


class randint(_DistributionBase):

    def __init__(self, low, high):
        """ Discrete uniform distribution

        Equal probability is assigned to each of the values
        low, low+1, ..., high-1.

        Parameters
        ----------
        low : int
            Lowest value with non-zero probability.
        high : int
            Upper limit of the distribution (exclusive).
        """

        self.low = low

        self.high = high

        self.diff = self.high - self.low

        self.ints = jnp.arange(low, high, 1)

        self._set_stdatt()

    def _set_stdatt(self):
        """ Set mean and median for the distribution

        The distribution is discrete, so the mean is computed directly from
        the equally probable values instead of integrating the pdf.
        """

        self.mean = jnp.mean(self.ints)

        self.median = self.ppf(0.5)

    @partial(jax.jit, static_argnums=(0,))
    def pdf(self, x):
        """ The probability mass of the distribution

        Parameters
        ----------
        x : array
            Evaluate the probability at x

        Returns
        -------
        y : array
            1/(high-low) where x equals one of the values of the
            distribution, otherwise 0.
        """

        match = jnp.any(self.ints == jnp.asarray(x)[..., None], axis=-1)

        return jnp.where(match, 1/self.diff, 0.)

    @partial(jax.jit, static_argnums=(0,))
    def logpdf(self, x):
        """ The log-probability mass of the distribution

        Parameters
        ----------
        x : array
            Evaluate the log-probability at x

        Returns
        -------
        y : array
            -log(high-low) where x equals one of the values of the
            distribution, otherwise -inf.
        """

        match = jnp.any(self.ints == jnp.asarray(x)[..., None], axis=-1)

        return jnp.where(match, -jnp.log(self.diff), -jnp.inf)

    @partial(jax.jit, static_argnums=(0,))
    def cdf(self, x):
        """ The cumulative probability distribution function

        Parameters
        ----------
        x : array
            Evaluate the cdf at x

        Returns
        -------
        y : array
            The cumulative probability at x, limited to the range 0 to 1.
        """

        k = jnp.floor(x)

        return jnp.clip((k - self.low + 1.) / self.diff, 0., 1.)

    @partial(jax.jit, static_argnums=(0,))
    def ppf(self, q):
        """ The point percent (quantile) function.

        Parameters
        ----------
        q : float
            Evaluate the ppf at q.

        Returns
        -------
        x : float
            The smallest of the two candidate values whose cdf is at least q.
        """

        vals = jnp.ceil(q * self.diff + self.low) - 1

        # Candidate one step below, used if it already reaches q
        vals1 = (vals - 1).clip(self.low, self.high)

        temp = self.cdf(vals1)

        return jnp.floor(jnp.where(temp >= q, vals1, vals))