"""
This module contains general purpose functions that are used throughout PBjam.
"""
from . import PACKAGEDIR
import os, jax, json
import jax.numpy as jnp
import numpy as np
from scipy.special import erf
from functools import partial
import scipy.special as sc
import scipy.integrate as si
from dataclasses import dataclass
import pandas as pd
import pbjam.distributions as dist
[docs]
class generalModelFuncs():
"""
A class containing general model functions for various models in PBjam.
This class in inherited by the model classes.
"""
def __init__(self):
pass
[docs]
@partial(jax.jit, static_argnums=(0,))
def obsOnlylnlikelihood(self, theta):
"""
Computes the log-likelihood using just the obs parameters.
This ignores all spectrum information.
Parameters
----------
theta : array-like
Parameter vector.
Returns
-------
lnlike : float
Log-likelihood value.
"""
thetaU = self.unpackParams(theta)
lnlike = self.addAddObsLike(thetaU)
return lnlike
[docs]
def setAddObs(self, keys):
"""
Set attribute containing additional observational data
Additional observational data other than the power spectrum goes here.
Can be Teff or bp_rp color, but may also be additional constraints on
e.g., numax, dnu.
Parameters
----------
keys : list
List of keys for additional observational data.
"""
self.addObs = {}
for key in keys:
self.addObs[key] = dist.normal(loc=self.obs[key][0],
scale=self.obs[key][1])
[docs]
def chi_sqr(self, mod):
""" Chi^2 2 dof likelihood
Evaluates the likelihood of observing the data given the model.
Parameters
----------
mod : jax device array
Spectrum model.
Returns
-------
L : float
Likelihood of the data given the model
"""
L = -jnp.sum(jnp.log(mod) + self.s / mod)
return L
[docs]
@partial(jax.jit, static_argnums=(0))
def lnlikelihood(self, theta):
"""
Calculate the log likelihood of the model given parameters and data.
Parameters
----------
theta : numpy.ndarray
Parameter values.
nu : numpy.ndarray
Array of frequency values.
Returns
-------
float :
Log-likelihood value.
"""
thetaU = self.unpackParams(theta)
lnlike = self.addAddObsLike(thetaU)
# Constraint from the periodogram
mod = self.model(thetaU)
lnlike += self.chi_sqr(mod)
return lnlike
[docs]
def addAddObsLike(self, thetaU):
""" Add the additional probabilities to likelihood
Adds the additional observational data likelihoods to the PSD likelihood.
Parameters
----------
p : list
Sampling parameters.
Returns
-------
lnp : float
The likelihood of a sample given the parameter PDFs.
"""
lnp = 0
for key in self.addObs.keys():
lnp += self.addObs[key].logpdf(thetaU[key])
return lnp
[docs]
def setLabels(self, addPriors, modelParLabels):
"""
Set parameter labels and categorize them based on priors.
Parameters
----------
priors : dict
Dictionary containing prior information for specific parameters.
Notes
-----
- Initializes default PCA and additional parameter lists.
- Checks if parameters are marked for PCA and not in priors; if so,
adds to PCA list.
- Otherwise, adds parameters to the additional list.
- Combines PCA and additional lists to create the final labels list.
- Identifies parameters that use a logarithmic scale and adds them to
logpars list.
"""
Path = os.path.join(*[PACKAGEDIR, 'data', 'parameters.json'])
with open(Path, "r") as read_file:
availableParams = json.load(read_file)
#with open("pbjam/data/parameters.json", "r") as read_file:
# availableParams = json.load(read_file)
self.variables = {key: availableParams[key] for key in modelParLabels}
# Default PCA parameters
self.pcaLabels = []
# Default additional parameters
self.addLabels = []
# If key appears in priors dict, override default and move it to add.
for key in self.variables.keys():
if self.variables[key]['pca'] and (key not in addPriors.keys()):
self.pcaLabels.append(key)
else:
self.addLabels.append(key)
self.logpars = [key for key in self.variables.keys() if self.variables[key]['log10']]
[docs]
def unpackSamples(self, samples=None):
"""
Unpack a set of parameter samples into a dictionary of arrays.
Parameters
----------
samples : array-like
A 2D array of shape (n, m), where n is the number of samples and
m is the number of parameters.
Returns
-------
S : dict
A dictionary containing the parameter values for each parameter
label.
Notes
-----
This method takes a 2D numpy array of parameter samples and unpacks each
sample into a dictionary of parameter values. The keys of the dictionary
are the parameter labels and the values are 1D numpy arrays containing
the parameter values for each sample.
Examples
--------
>>> class MyModel:
... def __init__(self):
... self.labels = ['a', 'b', 'c']
... def unpackParams(self, theta):
... return {'a': theta[0], 'b': theta[1], 'c': theta[2]}
... def unpackSamples(self, samples):
... S = {key: np.zeros(samples.shape[0]) for key in self.labels}
... for i, theta in enumerate(samples):
... theta_u = self.unpackParams(theta)
... for key in self.labels:
... S[key][i] = theta_u[key]
... return S
...
>>> model = MyModel()
>>> samples = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> S = model.unpackSamples(samples)
>>> print(S)
{'a': array([1., 4., 7.]), 'b': array([2., 5., 8.]), 'c': array([3., 6., 9.])}
"""
if samples is None:
samples = self.samples
S = {key: np.zeros(samples.shape[0]) for key in self.pcaLabels + self.addLabels}
jUnpack = jax.jit(self.unpackParams)
for i, theta in enumerate(samples):
thetaU = jUnpack(theta)
for key in thetaU.keys():
S[key][i] = thetaU[key]
return S
[docs]
def envelope(nu, env_height, numax, env_width, **kwargs):
""" Power of the seismic p-mode envelope
Computes the power at frequency nu in the oscillation envelope from a
Gaussian distribution. Used for computing mode heights.
Parameters
----------
nu : float
Frequency (in muHz).
hmax : float
Height of p-mode envelope (in SNR).
numax : float
Frequency of maximum power of the p-mode envelope (in muHz).
width : float
Width of the p-mode envelope (in muHz).
Returns
-------
h : float
Power at frequency nu (in SNR)
"""
return gaussian(nu, 2*env_height, numax, env_width)
[docs]
def modeUpdoot(result, sample, key, Nmodes):
"""
Updates the `result` dictionary with summary statistics and samples for a given key.
Parameters
----------
result : dict
The result dictionary to be updated. It should contain 'summary' and 'samples' sub-dictionaries.
sample : ndarray
The sample data to be added to the result. It is an array of shape (Nsamples, Nmodes).
key : str
The key under which the summary statistics and samples are to be stored in the result dictionary.
Nmodes : int
The number of modes (columns) in the sample array.
"""
result['summary'][key] = np.hstack((result['summary'][key], np.array([smryStats(sample[:, j]) for j in range(Nmodes)]).T))
result['samples'][key] = np.hstack((result['samples'][key], sample))
[docs]
def visell1(emm, inc):
"""
Computes the visibility for l=1 modes based on the azimuthal order (m) and inclination angle.
Parameters
----------
emm : int
Absolute value of the azimuthal order (m).
inc : float
Inclination angle in radians.
Returns
-------
float
Visibility for the l=1 modes.
"""
y = jax.lax.cond(emm == 0,
lambda : jnp.cos(inc)**2, # m = 0
lambda : jax.lax.cond(emm == 1,
lambda : 0.5*jnp.sin(inc)**2, # m = 1
lambda : jnp.nan # m > 1
))
return y
[docs]
def visell2(emm, inc):
"""
Computes the visibility for l=2 modes based on the azimuthal order (m) and inclination angle.
Parameters
----------
emm : int
Absolute value of the azimuthal order (m).
inc : float
Inclination angle in radians.
Returns
-------
float
Visibility for the l=2 modes.
"""
y = jax.lax.cond(emm == 0,
lambda : 1/4 * (3 * jnp.cos(inc)**2 - 1)**2, # m = 0
lambda : jax.lax.cond(emm == 1,
lambda : 3/8 * jnp.sin(2 * inc)**2, # m = 1
lambda : jax.lax.cond(emm == 2,
lambda : 3/8 * jnp.sin(inc)**4, # m = 2
lambda : jnp.nan # m > 2
)))
return y
[docs]
def visell3(emm, inc):
"""
Computes the visibility for l=3 modes based on the azimuthal order (m) and inclination angle.
Parameters
----------
emm : int
Absolute value of the azimuthal order (m).
inc : float
Inclination angle in radians.
Returns
-------
float
Visibility for the l=3 modes.
"""
y = jax.lax.cond(emm == 0,
lambda : 1/64 * (5 * jnp.cos(3 * inc) + 3 * jnp.cos(inc))**2, # m = 0
lambda : jax.lax.cond(emm == 1,
lambda : 3/64 * (5 * jnp.cos(2 * inc) + 3)**2 * jnp.sin(inc)**2, # m =1
lambda : jax.lax.cond(emm == 2,
lambda : 15/8 * jnp.cos(inc)**2 * jnp.sin(inc)**4, # m = 2
lambda : jax.lax.cond(emm == 3,
lambda : 5/16 * jnp.sin(inc)**6, # m = 3
lambda : np.nan # m > 3
))))
return y
[docs]
def visibility(ell, m, inc):
"""
Computes the visibility of a mode based on its degree (l), azimuthal order (m), and inclination angle.
Parameters
----------
ell : int
The degree of the mode.
m : int
The azimuthal order of the mode.
inc : float
The inclination angle in radians.
Returns
-------
float
Visibility for the given mode.
"""
emm = abs(m)
y = jax.lax.cond(ell == 0,
lambda : 1.,
lambda : jax.lax.cond(ell == 1,
lambda : visell1(emm, inc),
lambda : jax.lax.cond(ell == 2,
lambda : visell2(emm, inc),
lambda : jax.lax.cond(ell == 3,
lambda : visell3(emm, inc),
lambda : jnp.nan))))
return y
[docs]
def updatePrior(ID, R, addObs):
"""
Updates the prior data by adding a new entry based on the provided results and additional observations.
Parameters
----------
ID : str
The identifier for the new entry.
R : dict
A dictionary containing the results. Keys should correspond to parameter names, and values are typically arrays or lists where the first element is used.
addObs : dict
A dictionary containing additional observational data, such as 'teff' and 'bp_rp'.
Returns
-------
DataFrame
A pandas DataFrame containing the updated prior data.
Notes
-----
- The function reads the existing prior data from `pbjam/data/prior_data.csv`.
- It filters out certain keys from the results that are not meant to be updated.
- It applies a log10 transformation to certain parameters in the results before updating the prior data.
- The new entry is then added to the prior DataFrame and returned.
"""
prior = pd.read_csv('pbjam/data/prior_data.csv')
badkeys = ['freq', 'height', 'width', 'teff', 'bp_rp', 'nurot_c', 'inc', 'H3_power', 'H3_nu', 'H3_exp', 'shot']
r = {key: [R[key][0]] for key in R.keys() if key not in badkeys}
for key in r.keys():
if key in ['eps_p', 'eps_g', 'bp_rp', 'H1_exp', 'H2_exp']:
continue
else:
r[key] = np.log10(r[key])
r['ID'] = ID
r['teff'] = np.log10(addObs['teff'][0])
r['bp_rp'] = addObs['bp_rp'][0]
row = pd.DataFrame.from_dict(r)
prior = prior.append(row, ignore_index=True)
return prior
[docs]
@dataclass
class constants:
"""
A dataclass for storing astrophysical constants and conversion factors.
Attributes
----------
nu_to_omega : float
Conversion factor from frequency (muHz) to angular frequency (radians/muHz). Default is `2 * jnp.pi / 1e6`.
"""
# Teff0: float = 5777 # K
# TeffRed0: float = 8907 # K
# numax0: float = 3090 # muHz
# Delta_Teff: float = 1550 # K
# Henv0: float = 0.1 # ppm^2/muHz
nu_to_omega: float = 2 * jnp.pi / 1e6 # radians/muHz
c : float = 299792.458 #km/s
# dnu0: float = 135.9 # muHz
# logg0 : float = 4.43775 # log10(2.74e4)
[docs]
def smryStats(y):
"""
Computes summary statistics (median and uncertainty) for a given array of samples.
Parameters
----------
y : array-like
The input array of samples.
Returns
-------
ndarray
An array containing the median and the average absolute deviation.
Notes
-----
- The function computes percentiles corresponding to the 16th, 50th, and 84th percentiles.
- The uncertainty is the mean of the differences between the median and the 16th and 84th percentiles.
"""
p = np.array([0.5 - sc.erf(n/np.sqrt(2))/2 for n in range(-1, 2)])[::-1]*100
u = np.percentile(y, p)
return np.array([u[1], np.mean(np.diff(u))])
[docs]
def attenuation(f, nyq):
""" The sampling attenuation
Determine the attenuation of the PSD due to the discrete sampling of the
variability.
Parameters
----------
f : np.array
Frequency axis of the PSD.
nyq : float
The Nyquist frequency of the observations.
Returns
-------
eta : np.array
The attenuation at each frequency.
"""
eta = jnp.sinc(0.5 * f/nyq)
return eta
[docs]
def lor(nu, nu0, h, w):
""" Lorentzian to describe an oscillation mode.
Parameters
----------
nu0 : float
Frequency of lorentzian (muHz).
h : float
Height of the lorentizan (SNR).
w : float
Full width of the lorentzian (muHz).
Returns
-------
mode : ndarray
The SNR as a function frequency for a lorentzian.
"""
return h / (1.0 + 4.0/w**2*(nu - nu0)**2)
[docs]
def getCurvePercentiles(x, y, cdf=None, percentiles=None):
""" Compute percentiles of value along a curve
Computes the cumulative sum of y, normalized to unit maximum. The returned
percentiles values are where the cumulative sum exceeds the requested
percentiles.
Parameters
----------
x : array
Support for y.
y : array
Array
percentiles: array
Returns
-------
percs : array
Values of y at the requested percentiles.
"""
if percentiles is None:
percentiles = [0.5 - sc.erf(n/np.sqrt(2))/2 for n in range(-2, 3)][::-1]
y /= np.trapz(y, x)
if cdf is None:
cdf = si.cumulative_trapezoid(y, x, initial=0)
cdf /= cdf.max()
percs = np.zeros(len(percentiles))
for i, p in enumerate(percentiles):
q = x[cdf >= p]
percs[i] = q[0]
return np.sort(percs)
class jaxInterp1D():
def __init__(self, xp, fp, left=None, right=None, period=None):
""" Replacement for scipy.interpolate.interp1d in jax
Wraps the jax.numpy.interp in a callable class instance.
Parameters
----------
xp : jax device array
The x-coordinates of the data points, must be increasing if argument
period is not specified. Otherwise, xp is internally sorted after
normalizing the periodic boundaries with xp = xp % period.
fp : jax device array
The y-coordinates of the data points, same length as xp.
left : float
Value to return for x < xp[0], default is fp[0].
right: float
Value to return for x > xp[-1], default is fp[-1].
period : float
A period for the x-coordinates. This parameter allows the proper
interpolation of angular x-coordinates. Parameters left and right
are ignored if period is specified.
"""
self.xp = xp
self.fp = fp
self.left = left
self.right = right
self.period = period
@partial(jax.jit, static_argnums=(0,))
def __call__(self, x):
""" Interpolate onto new axis
Parameters
----------
x : jax device array
The x-coordinates at which to evaluate the interpolated values.
Returns
-------
y : jax device array
The interpolated values, same shape as x.
"""
return jnp.interp(x, self.xp, self.fp, self.left, self.right, self.period)
[docs]
class bibliography():
""" A class for managing references used when running PBjam.
This is inherited by session and star.
Attributes
----------
bibfile : str
The pathname to the pbjam references list.
_reflist : list
List of references that is updated when new functions are used.
bibdict : dict
Dictionary of bib items from the PBjam reference list.
"""
def __init__(self):
self.bibfile = os.path.join(*[PACKAGEDIR, 'data', 'pbjam_references.bib'])
self._reflist = []
self.bibdict = self._parseBibFile()
def _findBlockEnd(self, string, idx):
""" Find block of {}
Go through string starting at idx, and find the index corresponding to
the curly bracket that closes the opening curly bracket.
So { will be closed by } even if there are more curly brackets in
between.
Note
----
This also works in reverse, so opening with } will be closed by {.
Parameters
----------
string : str
The string to parse.
idx : int
The index in string to start at.
"""
a = 0
for i, char in enumerate(string[idx:]):
if char == '{':
a -= 1
elif char == '}':
a += 1
if (i >= len(string[idx:])-1) and (a != 0):
print('Warning: Reached end of bibtex file with no closing curly bracket. Your .bib file may be formatted incorrectly. The reference list may be garbled.')
if a ==0:
break
if string[idx+i] == '{':
print('Warning: Ended on an opening bracket. Your .bib file may be formatted incorrectly.')
return idx+i
def _parseBibFile(self):
""" Put contents of a bibtex file into a dictionary.
Takes the contents of the PBjam bib file and stores it as a dictionary
of bib items.
Article shorthand names (e.g., @Article{shorthand_name) become the
dictionary key, similar to the way LaTeX handles citations.
Returns
-------
bibdict : dict
Dictionary of bib items from the PBjam reference list.
"""
with open(self.bibfile, 'r') as bib:
bib = bib.read()
openers = ['@ARTICLE', '@article', '@Article'
'@MISC', '@misc',
'@BOOK', '@book',
'@SOFTWARE', '@software',
'@INPROCEEDINGS', '@inproceedings'] #Update this if other types of entries are added to the bib file.
bibitems = []
safety = 0
while any([x in bib for x in openers]):
for opener in openers:
try:
start = bib.index(opener)
end = self._findBlockEnd(bib, start+len(opener))
bibitems.append([bib[start:end+1]])
bib = bib[:start] + bib[end+1:]
except:
pass
safety += 1
if safety > 1000:
break
bibitems = np.unique(bibitems)
bibdict = {}
for i, item in enumerate(bibitems):
key = item[item.index('{')+1:item.index(',')]
bibdict[key] = item
return bibdict
def _addRef(self, ref):
""" Add reference from bibdict to active list
The reference must be listed in the PBjam bibfile.
Parameters
----------
ref : str
Bib entry to add to the list
"""
if isinstance(ref, list):
for r in ref:
self._reflist.append(self.bibdict[r])
else:
self._reflist.append(self.bibdict[ref])
def __call__(self, bibfile=None):
""" Print the list of references used.
Parameters
----------
bibfile : str
Filepath to print the list of bib items.
"""
out = '\n\n'.join(np.unique(self._reflist))
print('References used in this run.')
print(out)
if bibfile is not None:
with open(bibfile, mode='w') as file_object: #robustify the filepath so it goes to the right place all the time.
print(out, file=file_object)
[docs]
def getNormalPercentiles(X, nsigma=2, **kwargs):
""" Get percentiles of an distribution
Compute the percentiles corresponding to sigma=1,2,3.. including the
median (50th), of an array.
Parameters
----------
X : numpy.array()
Array to find percentiles of
sigma : int, optional.
Sigma values to compute the percentiles of, e.g. 68% 95% are 1 and 2
sigma, etc. Default is 2.
kwargs : dict
Arguments to be passed to numpy.percentile
returns
-------
percentiles : numpy.array()
Numpy array of percentile values of X.
"""
a = np.array([0.5*(1+erf(z/np.sqrt(2))) for z in range(nsigma+1)])
percs = np.append((1-a[::-1][:-1]),a)*100
return np.percentile(X, percs, **kwargs)
[docs]
def to_log10(x, xerr):
""" Transform to value to log10
Takes a value and related uncertainty and converts them to logscale.
Approximate.
Parameters
----------
x : float
Value to transform to logscale
xerr : float
Value uncertainty
Returns
-------
logval : list
logscaled value and uncertainty
"""
if xerr > 0:
return [np.log10(x), xerr/x/np.log(10.0)]
return [x, xerr]
[docs]
def normal(x, mu, sigma):
""" Evaluate logarithm of normal distribution (not normalized!!)
Evaluates the logarithm of a normal distribution at x.
Inputs
------
x : float
Values to evaluate the normal distribution at.
mu : float
Distribution mean.
sigma : float
Distribution standard deviation.
Returns
-------
y : float
Logarithm of the normal distribution at x
"""
return gaussian(x, 1/jnp.sqrt(2*jnp.pi*sigma**2), mu, sigma)
[docs]
def gaussian(x, A, mu, sigma):
"""
Computes the Gaussian function.
Parameters
----------
x : array-like
Input array of x values.
A : float
Amplitude of the Gaussian.
mu : float
Mean (center) of the Gaussian.
sigma : float
Standard deviation (width) of the Gaussian.
Returns
-------
array-like
The computed Gaussian function values.
"""
return A*jnp.exp(-(x-mu)**2/(2*sigma**2))