Source code for pbjam.core

"""
The PBjam core module contains the session and star classes, which take care of
most of the steps involved in setting up a peakbagging pipeline. It is meant as a
high-level interface to the more detailed methods that PBjam uses, and so may not
be the best choice when constructing a custom pipeline. In such cases see the
modeID and peakbag classes.
"""

import numpy as np
import copy
from collections.abc import Iterable
import jax.numpy as jnp
from pbjam.plotting import plotting
from pbjam import IO
from pbjam.modeID import modeID
from pbjam.peakbagging import peakbag

def _convertToList(arg):
    """
    Convert the input argument to a list.

    Parameters
    ----------
    arg : str, tuple, np.ndarray, list
        The input argument to be converted. Can be a string, tuple, NumPy array, or list.

    Returns
    -------
    list
        The input argument converted to a list. If the input is already a list, it is returned as is.
        If the input is a string, it is wrapped in a list.

    Raises
    ------
    TypeError
        If the input argument is not of a supported type (str, tuple, list, or np.ndarray).
    """

    # If the argument is a string, return it wrapped in a list
    if isinstance(arg, str):
        return [arg]
    
    # If it's an iterable but not a list, convert it to a list
    elif isinstance(arg, (tuple, np.ndarray)):
        return list(arg)
    
    # If it's already a list, return as is
    elif isinstance(arg, list):
        return arg
    
    # Otherwise raise an error if the type is not supported
    else:
        raise TypeError("Unsupported type")
    
def _validateObs(obs, name):
    """
    Validate the observational data for a given target, ensuring required keys are present and values are in the form (value, error).

    Parameters
    ----------
    obs : dict
        Dictionary containing observational data with each key having a tuple of (value, error).
    name : str
        Name or identifier of the target being validated.

    Raises
    ------
    ValueError
        If any of the required keys ('numax', 'dnu', 'teff') are missing from the `obs` dictionary.
    AssertionError
        If any value in `obs` is not iterable or is not in the form of a tuple with two elements (value, error).
    """

    for key in ['numax', 'dnu', 'teff']:
        if key not in obs.keys():
            raise ValueError(f'Missing {key} in obs for target {name}')
        
    for key, val in obs.items():
        assert isinstance(val, Iterable), 'Entries in obs must be of the form (value, error)'
    
        assert len(val) == 2, 'Entries in obs must be of the form (value, error)'

class session():
    """
    Main class used to initiate peakbagging for several stars.

    Use this class to initialize a star class instance for one or more targets.
    Once initialized, calling the session class instance will execute a complete
    peakbagging run.

    The observational constraints, such as numax, dnu, teff, bp_rp, must be
    provided through keyword entries in a dictionary, which is then passed via
    the obs argument when initializing the session class.

    Unless you provide the time series or spectrum, PBjam will download it. In
    which case it will do some rudimentary reduction, like removing outliers,
    removing NaN values and running a median filter through the light curve,
    with a width appropriate for the provided numax.

    Parameters
    ----------
    name : str, list, tuple or np.ndarray
        Target name or collection of target names. Most commonly used
        identifiers can be used if you want PBjam to download the data (KIC,
        TIC, HD, Bayer etc.). If you provide data yourself the name can be any
        string.
    obs : dict
        Dictionary of observational inputs: numax, dnu, teff, bp_rp, each of
        the form (value, error). If a target name appears as a key of obs, the
        value stored under that key is used as the obs dictionary of that
        target; otherwise obs itself is used for the target.
    timeseries : object, optional
        Timeseries input. Leave as None for PBjam to download it automatically.
        Otherwise an array of shape (2, N) holding time and flux, or of shape
        (3, N) holding time, flux and flux error, or a dictionary of such
        arrays whose keys match the target names. Only used if spectrum is None.
    spectrum : object, optional
        Spectrum input. Leave as None for PBjam to use timeseries to compute it
        for you. Otherwise an array of shape (2, N) holding frequency and power
        density, or a dictionary of such arrays whose keys match the target
        names. A single array is used for every target.
    lk_kwargs : dict, optional
        Arguments passed to lightkurve to download the time series. Must be a
        non-empty dictionary when neither timeseries nor spectrum is given. If
        a target name appears as a key, the value stored under that key is used
        for that target.
    outpath : str, optional
        Path to store the plots and results for the various stages of the
        peakbagging process.
    downloadDir : str, optional
        Directory to cache lightkurve downloads. Lightkurve will place the fits
        files in the default lightkurve cache path in your home directory.

    Attributes
    ----------
    inputs : dict
        For each target name, the dictionary of keyword arguments ('obs', 'f',
        's') used to create the corresponding star instance.
    stars : list
        The star class instances, one per target.
    """

    def __init__(self, name, obs, timeseries=None, spectrum=None, lk_kwargs={}, outpath=None, downloadDir=None):

        # Keep every constructor argument as an attribute of the same name.
        self.__dict__.update(name=name, obs=obs, timeseries=timeseries,
                             spectrum=spectrum, lk_kwargs=lk_kwargs,
                             outpath=outpath, downloadDir=downloadDir)

        # Setup inputs dictionary, one (initially empty) entry per target.
        self.inputs = {nm: {} for nm in _convertToList(name)}

        # Handle obs
        assert isinstance(obs, dict), 'The obs argument must be a dictionary.'

        # If keys don't match, assume it applies to all targets.
        for key in self.inputs:
            _obs = obs[key] if key in obs else obs

            _validateObs(_obs, key)

            self.inputs[key]['obs'] = _obs

        #
        # Handle time series and spectrum

        # spectrum can be a dictionary with keys corresponding to names,
        if isinstance(spectrum, dict):
            assert spectrum.keys() == self.inputs.keys(), 'The targets in spectrum must match those in names.'

            # The values for each key must be iterable of shape (2, N)
            for key in self.inputs:
                assert spectrum[key].shape[0] == 2, f'Shape of spectrum for {key} must be (2, N)'

                self.inputs[key]['f'] = spectrum[key][0]
                self.inputs[key]['s'] = spectrum[key][1]

        # Spectrum can be a iterable of shape (2, N), shared by all targets
        elif self._isArray(spectrum):
            assert spectrum.shape[0] == 2, 'Shape of spectrum for must be (2, N)'

            for key in self.inputs:
                self.inputs[key]['f'] = spectrum[0]
                self.inputs[key]['s'] = spectrum[1]

        # Spectrum can also be None, in which case we first look for a
        # corresponding item in timeseries
        elif spectrum is None:

            if isinstance(timeseries, dict):
                assert timeseries.keys() == self.inputs.keys(), 'The targets in timeseries must match those in names.'

                for key in self.inputs:
                    freq, powerdensity = self._psdFromTimeseries(key, timeseries[key])

                    self.inputs[key]['f'] = freq
                    self.inputs[key]['s'] = powerdensity

            elif self._isArray(timeseries):
                # The psd object is built inside the loop so that it is
                # created under the name of the target it belongs to, rather
                # than under whichever name an earlier loop happened to leave
                # in `key`.
                for key in self.inputs:
                    freq, powerdensity = self._psdFromTimeseries(key, timeseries)

                    self.inputs[key]['f'] = freq
                    self.inputs[key]['s'] = powerdensity

            elif timeseries is None:
                # Make sure lk_kwargs is not None
                assert isinstance(lk_kwargs, dict), 'To download data lk_kwargs must be a dict.'

                assert len(list(lk_kwargs.keys())) > 0

                # If keys are the same as input, loop through them and assign
                # to input[key]
                for key in self.inputs:
                    _lk_kwargs = lk_kwargs[key] if key in lk_kwargs else lk_kwargs

                    psd = IO.psd(key, lk_kwargs=_lk_kwargs, downloadDir=downloadDir)

                    psd()

                    self.inputs[key]['f'] = psd.freq
                    self.inputs[key]['s'] = psd.powerdensity

            else:
                raise ValueError('Timeseries must be a (2, N) array-like or dictionary with entries like the name argument.')

        else:
            raise ValueError('Spectrum must be a (2, N) array-like or dictionary with entries like the name argument.')

        self.stars = [star(nm, outpath=outpath, **inpt) for nm, inpt in self.inputs.items()]

    @staticmethod
    def _isArray(arr):
        """ Return True if arr is a NumPy array or a JAX array. """

        # The array types are looked up lazily so that no array is created
        # unless this check is actually reached.
        return isinstance(arr, (type(np.array([])), type(jnp.array([]))))

    @staticmethod
    def _psdFromTimeseries(key, ts):
        """
        Build and evaluate an IO.psd instance for one target from a time series.

        Parameters
        ----------
        key : str
            Target name, passed as the first argument to IO.psd.
        ts : array
            Array whose first axis has length 2 (time, flux) or 3 (time, flux,
            flux error). In the latter case IO.psd is created with
            useWeighted=True.

        Returns
        -------
        freq, powerdensity
            The freq and powerdensity attributes of the psd instance after it
            has been called.

        Raises
        ------
        ValueError
            If the first axis of ts has a length other than 2 or 3.
        """

        nRows = ts.shape[0]

        if nRows == 3:
            psd = IO.psd(key, time=ts[0], flux=ts[1], flux_err=ts[2], useWeighted=True)

        elif nRows == 2:
            psd = IO.psd(key, time=ts[0], flux=ts[1])

        else:
            raise ValueError(f'Unhandled timeseries shape for computing psd for {key}')

        psd()

        return psd.freq, psd.powerdensity

    def __call__(self, modeID_kwargs={}, peakbag_kwargs={}):
        """
        Sequentially call all the star class instances.

        Calling the session class instance will loop through all the stars that
        it contains, and call each one. This performs a full peakbagging run on
        each star in the session.

        Parameters
        ----------
        modeID_kwargs : dict
            Arguments passed to the modeID stage of PBjam. If its keys are
            exactly the target names, each value is passed to the matching
            star; otherwise the whole dictionary is passed to every star.
        peakbag_kwargs : dict
            Arguments passed to the peakbag stage of PBjam. Handled in the same
            way as modeID_kwargs.
        """

        # If the top level keys correspond to the target names, the dictionary
        # is used as is. Otherwise the same dictionary is applied to all stars.
        if self.inputs.keys() == modeID_kwargs.keys():
            _modeID_kwargs = modeID_kwargs
        else:
            _modeID_kwargs = {f'{st.name}': modeID_kwargs for st in self.stars}

        if self.inputs.keys() == peakbag_kwargs.keys():
            _peakbag_kwargs = peakbag_kwargs
        else:
            _peakbag_kwargs = {f'{st.name}': peakbag_kwargs for st in self.stars}

        for st in self.stars:
            print()
            print(f'Target: {st.name}')

            st(_modeID_kwargs[f'{st.name}'], _peakbag_kwargs[f'{st.name}'])
class star(plotting):
    """
    Class used to run peakbagging for a single star.

    Create one instance per target. Calling the instance runs the mode
    identification stage followed by the peakbag stage.

    The observational constraints, such as numax, dnu, teff, bp_rp, are given as
    entries of the form (value, error) in the obs dictionary.

    The star class only accepts a power density spectrum, given as the
    frequency bins 'f' and the power density 's'.

    Parameters
    ----------
    name : str
        Target name, most commonly used identifiers can be used if you want
        PBjam to download the data (KIC, TIC, HD, Bayer etc.). If you provide
        data yourself the name can be any string.
    f : array-like
        Frequency bins of the power density spectrum.
    s : array-like
        Power density spectrum with the same shape as 'f'.
    obs : dict
        Dictionary of observational inputs: numax, dnu, teff, bp_rp.
    outpath : str, optional
        Path to store the plots and results for the various stages of the
        peakbagging process. The value is passed through IO._setOutpath
        together with the target name before being stored.
    kwargs : dict
        Additional keyword arguments. They are stored as attributes of the
        instance and are thereby forwarded to both modeID and peakbag.
    """

    def __init__(self, name, f, s, obs, outpath=None, **kwargs):

        # Write straight into the instance dictionary: first the named
        # arguments, then any extra keyword arguments.
        attrs = self.__dict__
        attrs.update(name=name, f=f, s=s, obs=obs, outpath=outpath)
        attrs.update(kwargs)

        # The keyword-argument container itself is never kept as an attribute.
        attrs.pop('kwargs', None)

        self.outpath = IO._setOutpath(self.name, self.outpath)

        for entry in self.obs.values():
            assert isinstance(entry, Iterable), 'Entries in obs must be of the form (value, error)'

            assert len(entry) == 2, 'Entries in obs must be of the form (value, error)'

    def runModeID(self, modeID_kwargs={}):
        """
        Run the mode identification stage.

        A deep copy of the instance attributes, updated with modeID_kwargs, is
        used as keyword arguments both to create the modeID instance and to
        call it. The instance is stored as self.modeID.

        If 'priorpath' is not among those arguments, it is obtained from
        IO._getPriorPath and also stored as self.priorpath.

        Parameters
        ----------
        modeID_kwargs : dict, optional
            Keyword arguments that are added to, or override, the instance
            attributes. Default is an empty dictionary.
        """

        settings = copy.deepcopy(self.__dict__)
        settings.update(modeID_kwargs)

        if 'priorpath' not in settings:
            self.priorpath = IO._getPriorPath()
            settings['priorpath'] = self.priorpath

        self.modeID = modeID(**settings)

        self.modeID(**settings)

    def runPeakbag(self, peakbag_kwargs={}):
        """
        Run the peakbag stage.

        A deep copy of the instance attributes, updated with peakbag_kwargs, is
        used as keyword arguments both to create the peakbag instance and to
        call it. The instance is stored as self.peakbag.

        If 'ell' is not among those arguments, 'ell' and the entries of the
        'summary' dictionary are taken from self.modeID.result, so runModeID
        must have been run beforehand in that case. If obs contains 'RV', it is
        passed on as well.

        Parameters
        ----------
        peakbag_kwargs : dict, optional
            Keyword arguments that are added to, or override, the instance
            attributes. Default is an empty dictionary.
        """

        settings = copy.deepcopy(self.__dict__)
        settings.update(peakbag_kwargs)

        # Fall back on the mode identification output when no mode list is given.
        if 'ell' not in settings:
            settings['ell'] = self.modeID.result['ell']
            settings.update(self.modeID.result['summary'])

        if 'RV' in self.obs:
            settings['RV'] = self.obs['RV']

        self.peakbag = peakbag(**settings)

        self.peakbag(**settings)

    def __call__(self, modeID_kwargs={}, peakbag_kwargs={}):
        """
        Execute the modeID and peakbag stages.

        The relevant output from modeID is passed on to peakbag. Arguments
        given here override the ones given when initializing the star class.
        This can be used to, for example, change the resolution of the spectrum
        from the modeID to the peakbag stage, which may sometimes save time
        since the modeID doesn't require such high resolution.

        Parameters
        ----------
        modeID_kwargs : dict, optional
            Arguments to be passed to the modeID module.
        peakbag_kwargs : dict, optional
            Arguments to be passed to the peakbag module.

        Returns
        -------
        modeID_result : dict
            Results from the modeID stage.
        peakbag_result : dict
            Results from the peakbag stage.
        """

        # Mode identification first; its result feeds the peakbag stage.
        self.runModeID(modeID_kwargs)

        self.runPeakbag(peakbag_kwargs)

        return self.modeID.result, self.peakbag.result