# Source code for pbjam.plotting

""" 

This module contains a general set of plotting methods that are inherited by
the different classes of PBjam, so that they can be used to show the status of 
each step that has been performed. 

"""

import matplotlib.pyplot as plt
import astropy.convolution as conv
import os, corner, warnings, logging
import numpy as np
import jax

# Default matplotlib colour cycle entry for each angular degree ell (0..3),
# used throughout this module when drawing modes of a given degree.
ellColors = dict(enumerate(('C1', 'C4', 'C3', 'C5')))

def smooth_power(freq, power, smooth_filter_width):
    """Smooth a power spectrum with an astropy Gaussian1DKernel.

    The kernel width is given in frequency units and converted to a number of
    frequency bins using the spacing of the first two entries of `freq`. The
    resulting standard deviation is never allowed to be smaller than one bin.

    Parameters
    ----------
    freq : array-like
        Frequency values. The bin width is taken as ``freq[1] - freq[0]``, so
        the frequency grid is assumed to be uniformly spaced.
    power : array-like
        Array of power values, one per frequency.
    smooth_filter_width : float
        Standard deviation of the Gaussian kernel, in the same units as
        `freq`.

    Returns
    -------
    array-like
        Smoothed power, as returned by ``astropy.convolution.convolve_fft``.
    """
    binWidth = freq[1] - freq[0]

    # Kernel stddev in bins, clipped from below at one bin.
    stddevBins = max(1.0, smooth_filter_width / binWidth)

    kernel = conv.Gaussian1DKernel(stddev=float(stddevBins))
    return conv.convolve_fft(power, kernel)
def echelle(freq, power, dnu, fmin=0.0, fmax=None, offset=0.0, sampling=0.001):
    """Calculates the echelle diagram.

    Use this function if you want to do some more custom plotting.

    Parameters
    ----------
    freq : array-like
        Frequency values. ``freq[-1]`` is used as the default `fmax`, and the
        values are passed to ``np.interp`` as sample points, so they should be
        increasing.
    power : array-like
        Power values for every frequency.
    dnu : float
        Value of deltanu; the width of one echelle row.
    fmin : float, optional
        Minimum frequency to calculate the echelle at, by default 0. After
        the offset is removed, a positive value is rounded down to a multiple
        of `dnu`; a non-positive value is clipped to 0.
    fmax : float, optional
        Maximum frequency to calculate the echelle at. If none is supplied,
        will default to the maximum frequency passed in `freq`, by default
        None.
    offset : float, optional
        An offset to apply to the echelle diagram, by default 0.0. It is
        subtracted from the frequencies and added back to the returned y
        coordinates.
    sampling : float, optional
        Resampling step as a fraction of the median frequency spacing of the
        trimmed data, by default 0.001. The step is then adjusted so that an
        integer number of samples spans exactly one `dnu`.

    Returns
    -------
    xn : ndarray
        ``n_element + 1`` x coordinates running from 0 to `dnu`.
    yn : ndarray
        ``2 * n_stack + 2`` y coordinates (row edges, each repeated twice).
    z : ndarray
        Array of shape ``(2 * n_stack, n_element)``; every row of the
        resampled power is stored twice.
    """
    if fmax is None:
        fmax = freq[-1]

    # Work with offset-shifted frequencies; the offset is added back to yn.
    fmin = fmin - offset
    fmax = fmax - offset
    freq = freq - offset

    if fmin <= 0.0:
        fmin = 0.0
    else:
        fmin = fmin - (fmin % dnu)

    # trim data
    index = (freq >= fmin) & (freq <= fmax)
    trimx = freq[index]

    # Number of resampled points per row. Using the rounded integer directly
    # (instead of recomputing int(dnu / step)) avoids floating-point
    # truncation to n - 1, which would make successive rows drift.
    n_element = int(round(dnu / (np.median(np.diff(trimx)) * sampling)))
    samplinginterval = dnu / n_element

    xp = np.arange(fmin, fmax + dnu, samplinginterval)
    yp = np.interp(xp, freq, power)

    n_stack = int((fmax - fmin) / dnu)  # Number of rows
    morerow = 2  # Each row is stored twice in z

    # Row edges: 0, then each row start repeated, then the top edge.
    arr = np.arange(0, n_stack) * dnu
    yn = np.concatenate(([0.0], np.repeat(arr, morerow), [n_stack * dnu])) + fmin + offset
    xn = np.arange(0, n_element + 1) / n_element * dnu

    # Cut the resampled power into n_stack rows of n_element points and
    # duplicate each row.
    rows = yp[:n_stack * n_element].reshape(n_stack, n_element)
    z = np.repeat(rows, morerow, axis=0)

    return xn, yn, z
def plot_echelle(freq, power, numax, dnu, ax=None, cmap="Blues", scale=None,
                 interpolation=None, smooth=False, smooth_filter_width=50,
                 offset=0.0, **kwargs):
    """Plots the echelle diagram.

    Parameters
    ----------
    freq : numpy array
        Frequency values.
    power : array-like
        Power values for every frequency.
    numax : float
        Not used by this function; accepted so callers can pass it.
    dnu : float
        Value of deltanu.
    ax : matplotlib.axes._subplots.AxesSubplot, optional
        A matplotlib axes to plot into. If no axes is provided, a new one will
        be generated, by default None.
    cmap : str, optional
        A matplotlib colormap, by default 'Blues'.
    scale : str, optional
        Either 'sqrt', 'log' or None. Scales the echelle to bring out more
        features, by default None (no scaling). Note that 'log' yields -inf
        wherever the echelle power is 0.
    interpolation : str, optional
        Type of interpolation passed to ``matplotlib.pyplot.imshow``, by
        default None (matplotlib's default).
    smooth : bool, optional
        If True, smooth `power` with `smooth_power` before building the
        echelle, by default False.
    smooth_filter_width : float, optional
        Standard deviation of the Gaussian smoothing kernel in frequency
        units, used only when `smooth` is True, by default 50.
    offset : float, optional
        Offset passed to `echelle`, by default 0.0.
    **kwargs : dict
        Further keyword arguments passed to `echelle` (e.g. `fmin`, `fmax`,
        `sampling`).

    Returns
    -------
    matplotlib.axes._subplots.AxesSubplot
        The plotted echelle diagram on the axes.
    """
    if smooth:
        power = smooth_power(freq, power, smooth_filter_width)

    echx, echy, echz = echelle(freq, power, dnu, offset=offset, **kwargs)

    if scale == "log":
        echz = np.log10(echz)
    elif scale == "sqrt":
        echz = np.sqrt(echz)

    if ax is None:
        _, ax = plt.subplots()

    ax.imshow(echz,
              aspect="auto",
              extent=(echx.min(), echx.max(), echy.min(), echy.max()),
              origin="lower",
              cmap=cmap,
              interpolation=interpolation,
              )

    ax.set_xlabel(f"Frequency mod {str(np.round(dnu, 2))} " + r"[$\mu$Hz]", fontsize=15)
    ax.set_ylabel(r"Frequency [$\mu$Hz]", fontsize=15)
    ax.set_ylim(freq[0], echy[-1])

    # Faint horizontal guide lines at every multiple of dnu.
    for x in np.arange(echy.min(), echy.max() + dnu, dnu):
        ax.axhline(x, color='k', alpha=0.1)

    return ax
def _scatterFrame(model, samples, key1, key2, ax,): df = model.DR.priorData # Relevant bit of prior_data.csv prior1 = df[key1] prior2 = df[key2] # Sample used to construct prior select1 = model.DR.selectedSubset[key1] select2 = model.DR.selectedSubset[key2] # Samples from the sampling samplesU = model.unpackSamples(samples) # Compute some limits dx = (select1.max()-select1.min())*0.25 minx = select1.min()-dx maxx = select1.max()+dx dy = (select2.max()-select2.min())*0.25 miny = select2.min()-dy maxy = select2.max()+dy # Plot the things pidx = (minx < prior1) & (prior1 < maxx) & (miny < prior2) & (prior2 < maxy) ax.scatter(prior1[pidx], prior2[pidx], ec='k', alpha=0.25, s=8, fc='None') ax.scatter(select1, select2, c='C3', alpha=0.55, s=15) ax.set_xlim(minx, maxx) ax.set_ylim(miny, maxy) if not np.isnan(samples).all() and (key1 in samplesU.keys()) and (key2 in samplesU.keys()): # Only plot the summary stats result1 = np.percentile(samplesU[key1], [15, 50, 85]) result2 = np.percentile(samplesU[key2], [15, 50, 85]) if key1 in model.logpars: result1 = np.log10(result1) if key2 in model.logpars: result2 = np.log10(result2) ax.errorbar(result1[1], result2[1], xerr=np.array([[result1[1]-result1[0]], [result1[2]-result1[1]]]), yerr=np.array([[result2[1]-result2[0]], [result2[2]-result2[1]]]), fmt='.-', lw=5, ms=25, markeredgecolor='k', color='C0') elif not np.isnan(samples).all() and (key1 in samplesU.keys()) and not (key2 in samplesU.keys()): result1 = np.percentile(samplesU[key1], [15, 50, 85]) if key1 in model.logpars: result1 = np.log10(result1) ax.fill_betweenx([miny, maxy], x1=result1[0], x2=result1[2], color='C0', alpha=0.2) ax.axvline(result1[1], color='C0') def _baseReference(model, samples, fac=3): labels = model.pcaLabels + model.addLabels labelsInfile = [label for label in labels if label in model.DR.priorData.keys()] for key in model.DR.selectLabels: if not key in labelsInfile: labelsInfile.append(key) Nf = len(labelsInfile)-1 fig, axes = plt.subplots(Nf, Nf, 
figsize=(fac*Nf, fac*Nf)) for i in range(Nf): for j in range(Nf): ax = axes[j, i] key1 = labelsInfile[i] key2 = labelsInfile[j+1] if j >= i: _scatterFrame(model, samples, key1, key2, ax) else: ax.set_visible(False) if i != 0: ax.set_yticks([]) else: ax.set_ylabel(key2) if j != Nf-1: ax.set_xticks([]) else: ax.set_xlabel(key1) axes[0, 0].scatter(np.nan, np.nan, ec='k', alpha=1, s=8, fc='None', label=f'Viable prior sample') axes[0, 0].scatter(np.nan, np.nan, c='C3', alpha=0.55, s=15, label='Selected prior sample') if not np.isnan(samples).all(): axes[0, 0].errorbar(np.nan, np.nan, xerr=np.array([[np.nan], [np.nan]]), yerr=np.array([[np.nan], [np.nan]]), fmt='.-', lw=5, ms=25, markeredgecolor='k', color='C0', label='Result summary statistics') axes[0, 0].legend(bbox_to_anchor=(4, 1), fontsize=24, markerscale=2.) return fig, axes def _ModeIDPriorReference(model, N=1000): samples = np.zeros_like(model.samples[:N, :]) * np.nan fig, axes = _baseReference(model, samples) return fig, axes def _ModeIDPosteriorReference(model, N=1000): samples = model.samples[:N, :].copy() fig, axes = _baseReference(model, samples) return fig, axes @jax.jit def _echellify_freqs(nu, dnu, offset=0): x = (nu - offset*dnu) % dnu y = nu return x, y def _baseEchelle(f, s, N_p, numax, dnu, scale, **kwargs): """ Generate a base echelle diagram of the PSD. Parameters ---------- numax : float Central frequency. dnu : float Frequency spacing. scale : float Smoothing scale. Returns ------- matplotlib.figure.Figure The generated matplotlib Figure. matplotlib.axes._axes.Axes The generated matplotlib Axes. Notes ----- - Computes the echelle diagram for the given frequency range. - Smoothes the diagram using the specified smoothing scale. - Returns the generated Figure and Axes objects. 
""" n = max([round(N_p*3/4)+1, 9]) idx = ((numax - (n) * dnu) < f) & (f < (numax + (n+2) * dnu)) f, s = f[idx], s[idx] fig, ax = plt.subplots(figsize=(8,7)) plot_echelle(f, s, numax, dnu, ax=ax, smooth=True, smooth_filter_width=dnu * scale, **kwargs) return fig, ax def _ModeIDClassPriorEchelle(self, Nsamples, scale, colors, dnu=None, numax=None, DPi1=None, eps_g=None, **kwargs): if dnu is None: dnu = self.obs['dnu'][0] if numax is None: numax = self.obs['numax'][0] fig, ax = _baseEchelle(self.f, self.s, self.N_p, numax, dnu, scale) if hasattr(self, 'l20model'): junpackl20 = jax.jit(self.l20model.unpackParams) jptforml20 = jax.jit(self.l20model.ptform) jasy_nu_p = jax.jit(self.l20model.asymptotic_nu_p) for _ in range(Nsamples): u = np.random.uniform(0, 1, size=self.l20model.ndims) theta = jptforml20(u) thetaU = junpackl20(theta) nu0p, _ = jasy_nu_p(**thetaU) nu2p = nu0p + thetaU['d02'] for freqs, ell in zip([nu0p, nu2p], [0, 2]): smp_x, smp_y = _echellify_freqs(freqs, dnu) ax.scatter(smp_x, smp_y, alpha=0.05, color=colors[ell], s=100) for ell in [0, 2]: ax.scatter(np.nan, np.nan, alpha=1, color=colors[ell], s=100, label=r'$\ell=$'+str(ell)) if hasattr(self, 'l1model'): junpackl1 = jax.jit(self.l1model.unpackParams) jptforml1 = jax.jit(self.l1model.ptform) jnu = jax.jit(self.l1model.nu1_frequencies) for _ in range(Nsamples): u = np.random.uniform(0, 1, size=self.l1model.ndims) theta = jptforml1(u) thetaU = junpackl1(theta) nu1 = jnu(thetaU) smp_x, smp_y = _echellify_freqs(nu1, dnu) ax.scatter(smp_x, smp_y, alpha=0.05, color=colors[1], s=100) ax.scatter(np.nan, np.nan, alpha=1, color=colors[1], s=100, label=r'$\ell=$'+str(1)) # if DPi1 is None: # DPi1 = self.l1model.priors['DPi1'].ppf(0.5) # if eps_g is None: # eps_g = self.l1model.priors['eps_g'].ppf(0.5) # # Overplot gmode frequencies # nu_g = self.l1model.asymptotic_nu_g(self.l1model.n_g, DPi1, eps_g,) # curlyN = dnu / (DPi1 *1e-6 * numax**2) # ylims = ax.get_ylim() # if curlyN > 1: # nu_g_x, nu_g_y = 
_echellify_freqs(nu_g, dnu) # ax.scatter(nu_g_x, nu_g_y, color=colors[1]) # else: # for i, nu in enumerate(nu_g): # ax.axhline(nu, color='k', ls='dashed') # if (ylims[0] < nu) & (nu < ylims[1]): # ax.text(dnu, nu + dnu/2, s=r'$n_g$'+f'={self.l1model.n_g[i]}', ha='right', fontsize=11) # ax.axhline(np.nan, color='k', ls='dashed', label='g-modes \n' + r'$\Delta\Pi_1=$'+f'{np.round(DPi1, decimals=0)}s \n' r'$\epsilon_g=$'+f'{np.round(eps_g, decimals=2)}') ax.set_xlim(0, dnu) ax.legend(loc=1) return fig, ax def _ModeIDClassPostEchelle(self, Nsamples, colors, dnu=None, numax=None, **kwargs): gmodes = False if (dnu is None) and hasattr(self, 'result'): dnu = self.result['summary']['dnu'][0] else: dnu = self.obs['dnu'][0] if (numax is None) and hasattr(self, 'result'): numax = self.result['summary']['numax'][0] else: numax = self.obs['numax'][0] Epsilon = self.result['summary']['eps_p'][0] offset = Epsilon - 0.25 fig, ax = _baseEchelle(self.f, self.s, self.N_p, numax, dnu, offset=offset * dnu, **kwargs) axes = np.array([ax]) if hasattr(self, 'result'): for l in np.unique(self.result['ell']).astype(int): idx_ell = (self.result['ell'] == l ) & (self.result['emm'] == 0) freqs = self.result['samples']['freq'][:Nsamples, idx_ell] smp_x, smp_y = _echellify_freqs(freqs, dnu, offset=offset) ax.scatter(smp_x, smp_y, alpha=0.05, color=colors[l], s=10) med_freqs = self.result['summary']['freq'][0, self.result['emm'] == 0] med_x, med_y = _echellify_freqs(med_freqs, dnu, offset) ax.scatter(med_x, med_y, alpha=1, s=100, facecolors='none', edgecolors='k', linestyle='--') # Add to legend ax.scatter(np.nan, np.nan, alpha=1, color=colors[l], s=100, label=r'$\ell=$'+str(l)) ylims = ax.get_ylim() # # If fudge frequencies are used plot those # if hasattr(self, 'l1model') and 'freqError0' in self.result['summary'].keys(): # rect_ax = fig.add_axes([0.92, 0.107, 0.2, 0.775]) # rect_ax.set_xlabel(r'$\sigma_{\nu,\ell=1}$') # rect_ax.set_yticks([]) # rect_ax.set_ylim(ax.get_ylim()) # 
rect_ax.fill_betweenx(ax.get_ylim(), # x1=self.l1model.priors['freqError0'].mean - self.l1model.priors['freqError0'].scale, # x2=self.l1model.priors['freqError0'].mean + self.l1model.priors['freqError0'].scale, color='k', alpha=0.1) # rect_ax.fill_betweenx(ax.get_ylim(), # x1=self.l1model.priors['freqError0'].mean - 2*self.l1model.priors['freqError0'].scale, # x2=self.l1model.priors['freqError0'].mean + 2*self.l1model.priors['freqError0'].scale, color='k', alpha=0.1) # rect_ax.set_xlim(self.l1model.priors['freqError0'].mean - 3*self.l1model.priors['freqError0'].scale, # self.l1model.priors['freqError0'].mean + 3*self.l1model.priors['freqError0'].scale) # rect_ax.axvline(0, alpha=0.5, ls='dotted', color='k') # l1error = np.array([self.result['samples'][key] for key in self.result['samples'].keys() if key.startswith('freqError')]).T # rect_ax.plot(l1error[:Nsamples, :], self.result['samples']['freq'][:Nsamples, (self.result['ell']==1) & (self.result['emm']==0)], 'o', alpha=0.1, color='C4') # axes = np.append(axes, ax) # Overplot gmode frequencies if hasattr(self, 'l1model'): if self.l1model.N_g > 0: gmodes = True curlyN = dnu / (self.result['summary']['DPi1'][0] *1e-6 * numax**2) if curlyN < 1: nu_g = self.l1model.asymptotic_nu_g(self.l1model.n_g, self.result['summary']['DPi1'][0], self.result['summary']['eps_g'][0], ) for i, nu in enumerate(nu_g): if (ylims[0] < nu) & (nu < ylims[1]): ax.text(dnu, nu + dnu/2, s=r'$n_g$'+f'={self.l1model.n_g[i]}', ha='right', fontsize=11) ax.axhline(nu, color='k', ls='dashed') ax.axhline(np.nan, color='k', ls='dashed', label='g-modes') #Overplot l=1 p-modes # nu0_p, _ = self.l20model.asymptotic_nu_p(self.result['summary']['numax'][0], # self.result['summary']['dnu'][0], # self.result['summary']['eps_p'][0], # self.result['summary']['alpha_p'][0],) # nu1_p = nu0_p + self.result['summary']['d01'][0] # nu1_p_x, nu1_p_y = _echellify_freqs(nu1_p, dnu) #ax.scatter(nu1_p_x, nu1_p_y, edgecolors='k', fc='None', s=100, label='p-like $\ell=1$') 
ax.set_xlim(0, dnu) ax.legend(ncols=len(np.unique(self.result['ell'])), loc=1, fontsize=12) return fig, axes def _PeakbagClassPriorEchelle(self, scale, colors, dnu=None, numax=None, **kwargs): if dnu is None: dnu = self.dnu[0] if numax is None: numax = np.median(self.freq) fig, ax = _baseEchelle(self.f, self.s, self.N_p, numax, dnu, scale) maxL = 0 for inst in self.pbInstances: freqPriors = {key:val for key,val in inst.priors.items() if 'freq' in key} for l in np.unique(inst.ell).astype(int): idx_ell = inst.ell == l nu = np.array([freqPriors[key].loc for key in np.array(list(freqPriors.keys()))[idx_ell]]) nu_err = np.array([freqPriors[key].scale for key in np.array(list(freqPriors.keys()))[idx_ell]]) nu_x, nu_y = _echellify_freqs(nu, dnu) ax.errorbar(nu_x, nu_y, xerr=nu_err, color=colors[l], fmt='o') maxL = max([maxL, l]) # Add to legend for l in range(maxL+1): ax.errorbar(-100, -100, xerr=1, color=colors[l], fmt='o', label=r'$\ell=$'+str(l)) ax.set_xlim(0, dnu) ax.legend(loc=1) return fig, ax def _PeakbagClassPostEchelle(self, Nsamples, scale, colors, dnu=None, numax=None, **kwargs): if dnu is None: dnu = self.dnu[0] if numax is None: numax = np.median(self.freq[0, :]) #offset = (self.result['summary']['eps_p'][0]) - 0.25 fig, ax = _baseEchelle(self.f, self.s, self.N_p, numax, dnu, scale) maxL = 0 for inst in self.pbInstances: for l in np.unique(inst.ell).astype(int): idx_ell = inst.ell == l freqs = inst.result['samples']['freq'][:Nsamples, idx_ell] smp_x, smp_y = _echellify_freqs(freqs, dnu) ax.scatter(smp_x, smp_y, alpha=0.05, color=colors[l], s=10) maxL = max([maxL, l]) # Add to legend for l in range(maxL+1): ax.scatter(np.nan, np.nan, alpha=1, color=colors[l], s=100, label=r'$\ell=$'+str(l)) #ax.plot([0, 0.05*dnu], [numax, numax], color='black', linewidth=4, label=rf'$\nu_{{max}}$={round(numax)}$\mu$Hz') ax.legend(ncols=len(np.unique(self.result['ell'])), loc=1, fontsize=12) return fig, ax def _baseSpectrum(ax, f, s, smoothness=0.1, alpha=0.6, xlim=[None, 
None], ylim=[None, None], Legend=None, **kwargs): #ax.plot(f, s, 'k-', label='Data', alpha=0.1) smoo = smooth_power(f, s, smoothness) ax.plot(f, smoo, 'k-', label=Legend, lw=3, alpha=alpha) _ylim = list(ax.get_ylim()) if ylim[0] is None: _ylim[0] = smoo.min()*0.1 else: _ylim[0] = ylim[0] if ylim[1] is None: _ylim[1] = smoo.max()*1.3 else: _ylim[1] = ylim[1] ax.set_ylim(_ylim) _xlim = list(ax.get_xlim()) if xlim[0] is None: _xlim[0] = f.min() else: _xlim[0] = xlim[0] if xlim[1] is None: _xlim[1] = f.max() else: _xlim[1] = xlim[1] ax.set_xlim(_xlim) def _makeBaseFrames(self): if not hasattr(self, 'l20model'): fig, ax = plt.subplots(2, 1, figsize=(16,18)) _baseSpectrum(ax[0], self.f, self.s) _baseSpectrum(ax[1], self.f[self.sel], self.s[self.sel]) elif hasattr(self, 'l20model') and not hasattr(self, 'l1model'): # only l20 has been run fig, ax = plt.subplots(3, 1, figsize=(16,18)) _baseSpectrum(ax[0], self.f, self.s) _baseSpectrum(ax[1], self.f[self.sel], self.s[self.sel]) _baseSpectrum(ax[2], self.f[self.sel], self.s[self.sel] / self.l20model.getMedianModel()) elif hasattr(self, 'l20model') and hasattr(self, 'l1model'): fig, ax = plt.subplots(4, 1, figsize=(16,18)) _baseSpectrum(ax[0], self.f, self.s) _baseSpectrum(ax[1], self.f[self.sel], self.s[self.sel]) _baseSpectrum(ax[2], self.f[self.sel], self.s[self.sel] / self.l20model.getMedianModel()) _baseSpectrum(ax[3], self.f[self.sel], self.l20residual / self.l1model.getMedianModel()) else: raise ValueError('Unable to make plots') ax[0].set_xlim(self.f.min(), self.f.max()) ax[0].set_yscale('log') ax[0].set_xscale('log') for i in range(ax.shape[0]): ax[i].set_ylabel(r'PSD [$\mathrm{ppm}^2/\mu \rm Hz$]') if i > 0: ax[i].set_xlim(self.f[self.sel].min(), self.f[self.sel].max()) if i > 1: ax[i].set_ylabel(r'Residual') ax[-1].set_xlabel(r'Frequency ($\mu \rm Hz$)') return fig, ax def _ModeIDClassPriorSpectrum(self, N): fig, ax = _makeBaseFrames(self) if hasattr(self, 'l20model'): rint = np.random.randint(0, 
len(self.result['samples']['dnu']), size=N) junpackl20 = jax.jit(self.l20model.unpackParams) jmodell20 = jax.jit(self.l20model.model) jptforml20 = jax.jit(self.l20model.ptform) for k in rint: u = np.random.uniform(0, 1, size=self.l20model.ndims) theta = jptforml20(u) thetaU = junpackl20(theta) mod = jmodell20(thetaU) ax[0].plot(self.f[self.sel], mod, color='C3', alpha=0.2) ax[1].plot(self.f[self.sel], mod, color='C3', alpha=0.2) ax[0].plot([-100, -100], [-100, -100], color='C3', label='Prior samples', alpha=1) ax[1].plot([-100, -100], [-100, -100], color='C3', label='Prior samples', alpha=1) if hasattr(self, 'l1model'): rint = np.random.randint(0, len(self.result['samples']['d01']), size=N) junpackl1 = jax.jit(self.l1model.unpackParams) jmodell1 = jax.jit(self.l1model.model) jptforml1 = jax.jit(self.l1model.ptform) for k in rint: u = np.random.uniform(0, 1, size=self.l1model.ndims) theta = jptforml1(u) thetaU = junpackl1(theta) mod = jmodell1(thetaU,) ax[2].plot(self.f[self.sel], mod, color='C3', alpha=0.2) ax[2].plot([-100, -100], [-100, -100], color='C3', label='Prior samples', alpha=1) ax[0].legend(loc=3) return fig, ax def _ModeIDClassPostSpectrum(self, N): rint = np.random.randint(0, np.min([self.result['samples'][key].shape[0] for key in self.result['samples'].keys()]), size=N) fig, ax = _makeBaseFrames(self) if hasattr(self, 'l20model'): junpackl20 = jax.jit(self.l20model.unpackParams) jmodell20 = jax.jit(self.l20model.model) for k in rint: thetaU = junpackl20(self.l20Samples[k, :]) mod = jmodell20(thetaU) ax[0].plot(self.f[self.sel], mod, color='C3', alpha=0.2) ax[1].plot(self.f[self.sel], mod, color='C3', alpha=0.2) if hasattr(self, 'l1model'): junpackl1 = jax.jit(self.l1model.unpackParams) jmodell1 = jax.jit(self.l1model.model) for k in rint: thetaU = junpackl1(self.l1Samples[k, :]) mod = jmodell1(thetaU,) ax[2].plot(self.f[self.sel], mod, color='C3', alpha=0.2) llim, ulim = ax[2].get_ylim() line_bottom = ulim - 0.1*(ulim-llim) nu_l0 = 
self.result['summary']['freq'][0, self.result['ell']==0] for _, nu_l0 in enumerate(nu_l0): ax[2].plot([nu_l0, nu_l0],[line_bottom, ulim], lw=5, color=ellColors[0]) nu_l2 = self.result['summary']['freq'][0, self.result['ell']==2] for _, nu_l2 in enumerate(nu_l2): ax[2].plot([nu_l2, nu_l2],[line_bottom, ulim], lw=5, color=ellColors[2]) ax[0].plot([-100, -100], [-100, -100], color='C3', label='Posterior samples', alpha=1) ax[0].legend(loc=3, fontsize=14) # for i in range(1, ax.shape[0]): # for j, nu in enumerate(self.result['summary']['freq'][0]): # if (i==1 and self.result['ell'][j]==1) or (i==2 and self.result['ell'][j]!=1): # _alpha=0.35 # else: # _alpha=1.0 # ax[i].axvline(nu, c='k', linestyle='--', alpha=_alpha, lw=3) # ax[i].plot([-100, -100], [-100, -100], color='C3', label='Posterior samples', alpha=1) # ax[i].axvline(-100, c='k', linestyle='--', label='Median frequencies') # ax[i].legend(loc=2) return fig, ax def _PeakbagClassPriorSpectrum(self, N): fig, ax = plt.subplots(figsize=(16,9)) _baseSpectrum(ax, self.f, self.snr, smoothness=0.05) for inst in self.pbInstances: junpack = jax.jit(inst.unpackParams) jmodel = jax.jit(inst.model) jptform = jax.jit(inst.ptform) for _ in range(N): u = np.random.uniform(0, 1, size=inst.ndims) theta = jptform(u) theta_u = junpack(theta) m = jmodel(theta_u) ax.plot(inst.f[inst.sel], m, alpha=0.2, color='C3') xlims = [float(min([min(inst.f[inst.sel]) for inst in self.pbInstances])), float(max([max(inst.f[inst.sel]) for inst in self.pbInstances]))] ax.plot([-100, -100], [-100, -100], color='C3', label='Prior samples', alpha=1) ax.set_ylabel(r'PSD [$\mathrm{ppm}^2/\mu \rm Hz$]') ax.set_xlabel(r'Frequency [$\mu \rm Hz$]') ax.set_xlim(xlims) ax.legend(loc=1) return fig, ax def _PeakbagClassPostSpectrum(self, N): fig, ax = plt.subplots(figsize=(16,9)) _baseSpectrum(ax, self.f, self.snr, smoothness=0.01, alpha=0.1, Legend='Smoothed') _baseSpectrum(ax, self.f, self.snr, smoothness=0.5, alpha=0.4, Legend='Super Smoothed') for inst in 
self.pbInstances: randInt = np.random.randint(0, inst.samples.shape[0], size=N) junpack = jax.jit(inst.unpackParams) jmodel = jax.jit(inst.model) for k in randInt: theta = inst.samples[k, :] theta_u = junpack(theta) m = jmodel(theta_u) ax.plot(inst.f[inst.sel], m, color='C3', alpha=0.2) xLim = [float(min([min(inst.f[inst.sel]) for inst in self.pbInstances])), float(max([max(inst.f[inst.sel]) for inst in self.pbInstances]))] Diff = xLim[1]*0.05 xlims = [xLim[0]-Diff, xLim[1]+Diff] ylims = [0, max(m)*1.5] ax.plot([-100, -100], [-100, -100], color='C3', label='Posterior samples', alpha=1) ax.set_ylabel(r'PSD [$\mathrm{ppm}^2/\mu \rm Hz$]', fontsize=15) ax.set_xlabel(r'Frequency [$\mu \rm Hz$]', fontsize=15) ax.set_xlim(xlims) ax.set_ylim(ylims) ax.legend(loc=1, fontsize=15) return fig, ax def _baseCorner(samples, labels): logging.disable(logging.WARNING) fig = corner.corner(np.array([samples[key] for key in labels]).T, hist_kwargs={'density': True}, labels=labels) logging.getLogger().setLevel(logging.WARNING) return fig def _setSampleToPlot(self, N, unpacked=True, stage='prior'): if stage == 'prior': samples = np.array([self.ptform(np.random.uniform(0, 1, size=self.ndims)) for i in range(N)]) else: samples = self.samples if unpacked: _samples = self.unpackSamples(samples) else: _samples = {label: samples[:, i] for i, label in enumerate(self.priors.keys())} return _samples def _ModeIDClassPriorCorner(self, modObj, unpacked, N, **kwargs): _samples = _setSampleToPlot(modObj, N, unpacked=unpacked, stage='prior') labels = list(_samples.keys()) fig = _baseCorner(_samples, labels) axes = np.array(fig.get_axes()).reshape((len(labels), len(labels))) if not unpacked: for i, key in enumerate(labels): if key in modObj.priors.keys(): x = np.linspace(modObj.priors[key].ppf(1e-6), modObj.priors[key].ppf(1-1e-6), 100) pdf = np.array([modObj.priors[key].pdf(x[j]) for j in range(len(x))]) axes[i, i].plot(x, pdf, color='C3', alpha=0.5, lw =5) return fig, axes def 
_ModeIDClassPostCorner(self, modObj, unpacked, N, **kwargs): _samples = _setSampleToPlot(modObj, N, unpacked=unpacked, stage='posterior') labels = list(_samples.keys()) fig = _baseCorner(_samples, labels) axes = np.array(fig.get_axes()).reshape((len(labels), len(labels))) if not unpacked: for i, key in enumerate(labels): if key in modObj.priors.keys(): x = np.linspace(modObj.priors[key].ppf(1e-9), modObj.priors[key].ppf(1-1e-9), 100) pdf = np.array([modObj.priors[key].pdf(x[j]) for j in range(len(x))]) axes[i, i].plot(x, pdf, color='C3', alpha=0.5, lw =5) return fig, axes def _PeakbagClassPriorCorner(self, samples, labelType, **kwargs): _samples = self.unpackSamples(samples) subSamples = {key: v for key, v in _samples.items() if any([l in key for l in labelType])} plotLabels = list(subSamples.keys()) fig = _baseCorner(subSamples, plotLabels) axes = np.array(fig.get_axes()).reshape((len(plotLabels), len(plotLabels))) for i, key in enumerate(subSamples.keys()): x = np.linspace(self.priors[key].ppf(1e-6), self.priors[key].ppf(1-1e-6), 100) pdf = np.array([self.priors[key].pdf(x[j]) for j in range(len(x))]) isLog10 = [self.variables[varKey]['log10'] for varKey in self.variables if key.startswith(varKey)][0] if isLog10: axes[i, i].plot(10**x, pdf/10**x/np.log(10.0), color='C2', alpha=0.5, lw=5) else: axes[i, i].plot(x, pdf, color='C2', alpha=0.5, lw=5) if any([key.startswith(l) for l in ['height', 'freq', 'width']]): axes[i,i].patch.set_facecolor(self.ell[i]) axes[i,i].patch.set_alpha(0.25) return fig, axes def _PeakbagClassPostCorner(self, samples, labelType, colors, **kwargs): _samples = self.unpackSamples(samples) subSamples = {key: v for key, v in _samples.items() if any([l in key for l in labelType])} plotLabels = list(subSamples.keys()) fig = _baseCorner(subSamples, plotLabels) axes = np.array(fig.get_axes()).reshape((len(plotLabels), len(plotLabels))) for i, key in enumerate(subSamples.keys()): x = np.linspace(self.priors[key].ppf(1e-6), 
self.priors[key].ppf(1-1e-6), 100) pdf = np.array([self.priors[key].pdf(x[j]) for j in range(len(x))]) isLog10 = [self.variables[varKey]['log10'] for varKey in self.variables if key.startswith(varKey)][0] if isLog10: axes[i, i].plot(10**x, pdf/10**x/np.log(10.0), color='C2', alpha=0.5, lw=5) else: axes[i, i].plot(x, pdf, color='C2', alpha=0.5, lw=5) if any([key.startswith(l) for l in ['height', 'freq', 'width']]): axes[i,i].patch.set_facecolor(colors[int(self.ell[i])]) axes[i,i].patch.set_alpha(0.25) return fig, axes
[docs] class plotting(): """ Class inherited by PBjam modules to plot results This is used to standardize the plots produced at various steps of the peakbagging process. The methods will plot the relevant result based on the class they are being called from. """ def __init__(self): pass def _save_my_fig(self, fig, figtype, path, ID): """ Save the figure object Saves the figure object with a predefined path name pattern. Parameters ---------- fig : Matplotlib figure object Figure object to be saved. figtype : str The type of figure in question. This is used to set the filename. path : str, optional Used along with savefig, sets the output directory to store the figure. Default is to save the figure to the star directory. ID : str, optional ID of the target to be included in the filename of the figure. """ # TODO there should be a check if path is full filepath or just dir if path and ID: outpath = os.path.join(*[path, type(self).__name__+f'_{figtype}_{str(ID)}.png']) fig.savefig(outpath) def echelle(self, stage='posterior', ID=None, savepath=None, save_kwargs={}, kwargs={}): if not 'colors' in kwargs: kwargs['colors'] = ellColors if not 'scale' in kwargs: kwargs['scale'] = 1/300 if not 'Nsamples' in kwargs: kwargs['Nsamples'] = 200 if self.__class__.__name__ == 'modeID': if stage=='prior': fig, ax = _ModeIDClassPriorEchelle(self, **kwargs) elif stage=='posterior': fig, ax = _ModeIDClassPostEchelle(self, **kwargs) else: raise ValueError('Set stage optional argument to either prior or posterior') elif self.__class__.__name__ == 'peakbag': if stage=='prior': fig, ax = _PeakbagClassPriorEchelle(self, **kwargs) elif stage=='posterior': fig, ax = _PeakbagClassPostEchelle(self, **kwargs) else: raise ValueError('Set stage optional argument to either prior or posterior') else: raise ValueError('Unrecognized class type. 
Only modeID and peakbag have this plotting function built in.') if ID is not None: ax.set_title(ID) if (savepath is not None): fig.savefig(savepath, **save_kwargs) return fig, ax def spectrum(self, stage='posterior', ID=None, savepath=None, kwargs={}, save_kwargs={}, N=30): if self.__class__.__name__ == 'modeID': if stage=='prior': fig, ax = _ModeIDClassPriorSpectrum(self, N, **kwargs) elif stage=='posterior': assert hasattr(self, 'result') fig, ax = _ModeIDClassPostSpectrum(self, N, **kwargs) else: raise ValueError('Set stage optional argument to either prior or posterior') elif self.__class__.__name__ == 'peakbag': if stage=='prior': fig, ax = _PeakbagClassPriorSpectrum(self, N, **kwargs) elif stage=='posterior': fig, ax = _PeakbagClassPostSpectrum(self, N, **kwargs) else: raise ValueError('Set stage optional argument to either prior or posterior') else: raise ValueError('Unrecognized class type. Only modeID and peakbag have this plotting function built in.') if ID is not None: ax.set_title(ID) fig.tight_layout() if savepath is not None: fig.savefig(savepath, **save_kwargs) return fig, ax def corner(self, stage='posterior', ID=None, labels=None, savepath=None, unpacked=False, kwargs={}, save_kwargs={}, N=5000): if not 'colors' in kwargs: kwargs['colors'] = ellColors if self.__class__.__name__ == 'modeID': if stage=='prior': fig, ax = [], [] if hasattr(self, 'l20model'): figl20, axl20 = _ModeIDClassPriorCorner(self, self.l20model, unpacked, N, **kwargs) fig.append(figl20) ax.append(axl20) else: warnings.warn('modeID does not currently have and l20model attribute, use runl20model first.') if hasattr(self, 'l1model'): figl1, axl1 = _ModeIDClassPriorCorner(self, self.l1model, unpacked, N, **kwargs) fig.append(figl1) ax.append(axl1) else: warnings.warn('modeID does not currently have and l1model attribute, use runl1model first.') elif stage=='posterior': fig, ax = [], [] if hasattr(self, 'l20model'): figl20, axl20 = _ModeIDClassPostCorner(self, self.l20model, 
unpacked, N, **kwargs) fig.append(figl20) ax.append(axl20) else: warnings.warn('modeID does not currently have and l20model attribute, use runl20model first.') if hasattr(self, 'l1model'): figl1, axl1 = _ModeIDClassPostCorner(self, self.l1model, unpacked, N, **kwargs) fig.append(figl1) ax.append(axl1) else: warnings.warn('modeID does not currently have and l1model attribute, use runl1model first.') else: raise ValueError('Set stage optional argument to either prior or posterior') elif self.__class__.__name__ == 'peakbag': if stage=='prior': samples = np.array([self.ptform(np.random.uniform(0, 1, size=self.ndims)) for i in range(N)]) fig, ax = _PeakbagClassPriorCorner(self, samples, labels, **kwargs) elif stage=='posterior': samples = self.samples fig, ax = _PeakbagClassPostCorner(self, samples, labels, **kwargs) else: raise ValueError('Set stage optional argument to either prior or posterior') else: raise ValueError('Unrecognized class type. Only modeID and peakbag have this plotting function built in.') return fig, ax
[docs] def reference(self, stage='posterior', ID=None): """Make a corner plot of the prior sample with relevant overplotted values.""" if self.__class__.__name__ == 'modeID': fig, axes = [], [] if stage=='prior': if hasattr(self, 'l20model'): figl20, axl20 = _ModeIDPriorReference(self.l20model) fig.append(figl20) axes.append(axl20) if hasattr(self, 'l1model'): figl1, axl1 = _ModeIDPriorReference(self.l1model) fig.append(figl1) axes.append(axl1) elif stage=='posterior': if hasattr(self, 'l20model'): figl20, axl20 = _ModeIDPosteriorReference(self.l20model) fig.append(figl20) axes.append(axl20) if hasattr(self, 'l1model'): figl1, axl1 = _ModeIDPosteriorReference(self.l1model) fig.append(figl1) axes.append(axl1) else: raise ValueError('Set stage optional argument to either prior or posterior') return fig, axes else: raise ValueError('This kind of plot is only available for the modeIDsampler module.')
# def plotLatentCorner(self, samples, labels=None): # if labels == None: # labels = list(self.priors.keys()) # fig = corner.corner(samples, hist_kwargs = {'density': True}, labels=labels) # axes = np.array(fig.get_axes()).reshape((len(labels), len(labels))) # for i, key in enumerate(labels): # if key in self.priors.keys(): # x = np.linspace(self.priors[key].ppf(1e-6), self.priors[key].ppf(1-1e-6), 100) # pdf = np.array([self.priors[key].pdf(x[j]) for j in range(len(x))]) # axes[i, i].plot(x, pdf, color='C3', alpha=0.5, lw =5) # def plot_corner(self, path=None, ID=None, savefig=False): # """ Make corner plot of result. # Makes a nice corner plot of the fit parameters. # Parameters # ---------- # path : str, optional # Used along with savefig, sets the output directory to store the # figure. Default is to save the figure to the star directory. # ID : str, optional # ID of the target to be included in the filename of the figure. # savefig : bool # Whether or not to save the figure to disk. Default is False. # Returns # ------- # fig : Matplotlib figure object # Figure object with the corner plot. # """ # if not hasattr(self, 'samples'): # warnings.warn(f"'{self.__class__.__name__}' has no attribute 'samples'. Can't plot a corner plot.") # return None # fig = corner.corner(self.samples, labels=self.par_names, # show_titles=True, quantiles=[0.16, 0.5, 0.84], # title_kwargs={"fontsize": 12}) # if savefig: # self._save_my_fig(fig, 'corner', path, ID) # return fig # def _fill_diag(self, axes, vals, vals_err, idxs): # """ Overplot diagonal values along a corner plot diagonal. # Plots a set of specified values over the 1D histograms in the diagonal # frames of a corner plot. # Parameters # ---------- # axes : Matplotlib axis object # The particular axis element to be plotted in. # vals : float # Mean values to plot. # vals_err : float # Error estimates for the value to be plotted. # idxs : list # List of 2D indices that represent the diagonal. 
# """ # N = int(np.sqrt(len(axes))) # axs = np.array(axes).reshape((N,N)).T # for i,j in enumerate(idxs): # yrng = axs[j,j].get_ylim() # v, ve = vals[i], vals_err[i] # axs[j,j].fill_betweenx(y=yrng, x1= v-ve[0], x2 = v+ve[-1], color = 'C3', alpha = 0.5) # def _plot_offdiag(self, axes, vals, vals_err, idxs): # """ Overplot offdiagonal values in a corner plot. # Plots a set of specified values over the 2D histograms or scatter in the # off-diagonal frames of a corner plot. # Parameters # ---------- # axes : Matplotlib axis object # The particular axis element to be plotted in. # vals : float # Mean values to plot. # vals_err : float # Error estimates for the value to be plotted. # idxs : list # List of 2D indices that represent the diagonal. # """ # N = int(np.sqrt(len(axes))) # axs = np.array(axes).reshape((N,N)).T # for i, j in enumerate(idxs): # for m, k in enumerate(idxs): # if j >= k: # continue # v, ve = vals[i], vals_err[i] # w, we = vals[m], vals_err[m] # axs[j,k].errorbar(v, w, xerr=ve.reshape((2,1)), yerr=we.reshape((2,1)), fmt = 'o', ms = 10, color = 'C3') # def _make_prior_corner(self, df, numax_rng = 100): # """ Show dataframe contents in a corner plot. # This is meant to be used to show the contents of the prior_data that is # used by KDE and Asy_peakbag. # Parameters # ---------- # df : pandas.Dataframe object # Dataframe of the data to be shown in the corner plot. # numax_rng : float, optional # Range in muHz around the input numax to be shown in the corner plot. # The default is 100. # Returns # ------- # crnr : matplotlib figure object # Corner plot figure object containing NxN axis objects. # crnr_axs : list # List of axis objects from the crnr object. # """ # idx = abs(10**df['numax'] - self._obs['numax'][0]) <= numax_rng # # This is a temporary fix for disabling the 'Too few samples' warning # # from corner.hist2d. The github bleeding edge version has # # hist2d_kwargs = {'quiet': True}, but this isn't in the pip # # installable version yet. 
March 2020. # logging.disable(logging.WARNING) # crnr = corner.corner(df.to_numpy()[idx,:-1], data_kwargs = {'alpha': 0.5}, labels = df.keys()); # logging.getLogger().setLevel(logging.WARNING) # return crnr, crnr.get_axes() # def plot_prior(self, path=None, ID=None, savefig=False): # """ Corner of result in relation to prior sample. # Create a corner plot showing the location of the star in relation to # the rest of the prior. # Parameters # ---------- # path : str, optional # Used along with savefig, sets the output directory to store the # figure. Default is to save the figure to the star directory. # ID : str, optional # ID of the target to be included in the filename of the figure. # savefig : bool # Whether or not to save the figure to disk. Default is False. # Returns # ------- # crnr : matplotlib figure object # Corner plot figure object containing NxN axis objects. # """ # df = pd.read_csv(self.prior_file) # crnr, axes = self._make_prior_corner(df) # if type(self) == pbjam.star: # vals, vals_err = np.array([self._log_obs['dnu'], # self._log_obs['numax'], # self._log_obs['teff'], # self._obs['bp_rp']]).T # vals_err = np.vstack((vals_err, vals_err)).T # self._fill_diag(axes, vals, vals_err, [0, 1, 8, 9]) # self._plot_offdiag(axes, vals, vals_err, [0, 1, 8, 9]) # # break this if statement if asy_fit should plot something else # elif (type(self) == pbjam.priors.kde) or (type(self) == pbjam.asy_peakbag.asymptotic_fit): # percs = np.percentile(self.samples, [16, 50, 84], axis=0) # vals = percs[1,:] # vals_err = np.diff(percs, axis = 0).T # self._fill_diag(axes, vals, vals_err, range(len(vals))) # self._plot_offdiag(axes, vals, vals_err, range(len(vals))) # elif type(self) == pbjam.peakbag: # raise AttributeError('The result of the peakbag run cannot be plotted in relation to the prior, since it does not know what the prior is anymore. plot_corner is only available for star, kde and asy_peakbag.') # return crnr