Source code for jetset.base_model

"""Base model abstractions and shared utilities used across JetSeT models."""


__author__ = "Andrea Tramacere"


import numpy as np
import copy
import dill as pickle
import warnings
import inspect
import numbers
from astropy.table import Table

from .model_parameters import ModelParameterArray, ModelParameter
from .spectral_shapes import SED
from .data_loader import  ObsData
from .utils import  get_info
from .plot_sedfit import  PlotSED
from .utils import check_frame,unexpected_behaviour

from .cosmo_tools import  Cosmo

__all__=['Model','MultiplicativeModel']




class Model(object):
    """Base class for analytical and numerical SED models."""

    def __init__(self, name='no_name', nu_size=200, model_type='base_model', scale='lin-lin', cosmo=None, nu_min=None, nu_max=None):
        """Initialize a model container.

        Parameters
        ----------
        name : str, optional
            Model name.
        nu_size : int, optional
            Number of frequencies used when building an internal evaluation grid.
        model_type : str, optional
            Descriptive model type label.
        scale : str, optional
            Preferred plotting/evaluation scale.
        cosmo : Cosmo, optional
            Cosmology helper. If omitted, a default :class:`Cosmo` instance is used.
        nu_min : float, optional
            Minimum frequency of the model grid in Hz.
        nu_max : float, optional
            Maximum frequency of the model grid in Hz.
        """
        self.model_type = model_type
        self.name = name
        self.SED = SED(name=self.name)
        self.parameters = ModelParameterArray()
        self._scale = scale
        self.nu_size = nu_size
        self.nu_min = nu_min
        self.nu_max = nu_max
        # Flux floor: returned by the default ``lin_func``, substituted for
        # negative fluxes before taking log10, and passed to plots as ``flim``.
        self.flux_plot_lim = 1E-30
        self.cosmo = Cosmo() if cosmo is None else cosmo
        self._set_version(v=None)

    @property
    def version(self):
        """Return the package version used to create this model.

        Returns
        -------
        str
            Version string.
        """
        return self._version

    def _set_version(self, v=None):
        """Store ``v`` as the model version, defaulting to the installed package version."""
        if v is None:
            self._version = get_info()['version']
        else:
            self._version = v

    def _prepare_nu_model(self, nu, loglog):
        """Return ``(lin_nu, log_nu)`` frequency arrays for evaluation.

        Parameters
        ----------
        nu : array-like, float or None
            Frequencies in Hz (or log10(Hz) when ``loglog`` is ``True``).
            ``None`` builds a log-spaced grid from ``nu_min``, ``nu_max`` and
            ``nu_size``. Scalars and sequences are promoted to 1-D arrays.
        loglog : bool
            Whether ``nu`` is given in log10 space.

        Returns
        -------
        tuple of ndarray
            Linear and log10 frequency arrays.
        """
        if nu is None:
            lin_nu = np.logspace(np.log10(self.nu_min), np.log10(self.nu_max), self.nu_size)
            log_nu = np.log10(lin_nu)
        else:
            # atleast_1d promotes scalars and plain lists to arrays (keeping
            # ndarray subclasses such as table columns), so that model
            # functions can rely on array semantics like ``.shape``.
            nu = np.atleast_1d(nu)
            if loglog is True:
                lin_nu = np.power(10., nu)
                log_nu = nu
            else:
                log_nu = np.log10(nu)
                lin_nu = nu
        return lin_nu, log_nu

    def _eval_model(self, lin_nu, log_nu, loglog):
        """Evaluate the model and return ``(lin_model, log_model)``.

        ``log_model`` is ``None`` when ``loglog`` is ``False``.
        """
        log_model = None
        if loglog is False:
            lin_model = self.lin_func(lin_nu)
        elif hasattr(self, 'log_func'):
            log_model = self.log_func(log_nu)
            lin_model = np.power(10., log_model)
        else:
            lin_model = self.lin_func(lin_nu)
            # Negative fluxes are clamped to the floor so log10 stays defined.
            lin_model[lin_model < 0.] = self.flux_plot_lim
            log_model = np.log10(lin_model)
        return lin_model, log_model

    def _fill(self, lin_nu, lin_model):
        """Store evaluated fluxes in ``self.SED`` and in each spectral component.

        If the model has a redshift (a ``'redshift'`` parameter, or a
        ``get_redshift`` method), the source-frame luminosities are filled too.
        A model without an ``SED`` attribute only emits a warning for that step.
        """
        if hasattr(self, 'SED'):
            self.SED.fill(nu=lin_nu, nuFnu=lin_model)
            z_par = self.get_par_by_type('redshift')
            z = z_par.val if z_par is not None else None
            if z is None and hasattr(self, 'get_redshift'):
                z = self.get_redshift()
            if z is not None:
                if hasattr(self, 'get_DL_cm'):
                    # NOTE(review): passes the literal string 'redshift';
                    # confirm this matches the subclass get_DL_cm signature.
                    dl = self.get_DL_cm('redshift')
                else:
                    dl = self.cosmo.get_DL_cm(z)
                self.SED.fill_nuLnu(z=z, dl=dl)
        else:
            # warnings.warn takes a single message string; the original call
            # passed five positional arguments and raised TypeError instead.
            warnings.warn('model %s of type %s has no SED member' % (self.name, type(self)))

        for component in getattr(self, 'spectral_components_list', ()):
            component.fill_SED(lin_nu=lin_nu, skip_zeros=False)

    def eval(self, fill_SED=True, nu=None, get_model=False, loglog=False, label=None, **kwargs):
        """Evaluate model fluxes.

        Parameters
        ----------
        fill_SED : bool, optional
            If ``True``, update the model SED object after evaluation.
        nu : array-like or float, optional
            Frequency grid in Hz (or log10(Hz) if ``loglog`` is ``True``).
            If omitted, the internal ``nu_min``/``nu_max``/``nu_size`` grid is used.
        get_model : bool, optional
            If ``True``, return evaluated values.
        loglog : bool, optional
            If ``True``, treat input/output frequency and model values in log10 space.
        label : str, optional
            Reserved for subclasses/plotting integrations.
        **kwargs
            Extra keyword arguments accepted for subclass compatibility.

        Returns
        -------
        ndarray or None
            Evaluated model array when ``get_model`` is ``True``, otherwise ``None``.
        """
        lin_nu, log_nu = self._prepare_nu_model(nu, loglog)
        lin_model, log_model = self._eval_model(lin_nu, log_nu, loglog)
        if fill_SED is True:
            self._fill(lin_nu, lin_model)
        if get_model is True:
            return log_model if loglog is True else lin_model
        return None

    def _set_up_plot(self, plot_obj, sed_data, frame, density):
        """Return ``plot_obj``, creating a :class:`PlotSED` when it is ``None``."""
        if plot_obj is None:
            plot_obj = PlotSED(frame=frame, density=density, sed_data=sed_data)

        z_sed_data = None
        if frame == 'src' and sed_data is not None:
            z_sed_data = sed_data.z
            z_par = self.get_par_by_type('redshift')
            if z_par is not None:
                sed_data.z = z_par.val

        # NOTE(review): the data plot that used the temporary redshift is
        # disabled, so the override above is undone immediately -- except when
        # sed_data.z was originally None, in which case the model redshift is
        # left on sed_data. Confirm whether that side effect is relied upon.
        if frame == 'src' and z_sed_data is not None:
            sed_data.z = z_sed_data
        return plot_obj

    def plot_model(self, plot_obj=None, clean=False, sed_data=None, frame='obs', skip_components=False, label=None, line_style='-', density=False):
        """Plot model SED and optional spectral components.

        Parameters
        ----------
        plot_obj : PlotSED, optional
            Existing plotting object. If omitted, a new one is created.
        clean : bool, optional
            If ``True``, clear existing model lines from ``plot_obj``.
        sed_data : ObsData, optional
            Optional observed data used by the plotting helper.
        frame : {'obs', 'src'}, optional
            Output frame for plotting.
        skip_components : bool, optional
            If ``True``, do not plot individual spectral components.
        label : str, optional
            Label for the model curve (defaults to the model name).
        line_style : str, optional
            Matplotlib line style for the total model curve; components are
            always drawn dashed.
        density : bool, optional
            If ``True``, use density representation in plotting helper.

        Returns
        -------
        PlotSED
            Plot object with model curves added.
        """
        plot_obj = self._set_up_plot(plot_obj, sed_data, frame, density)
        if clean is True:
            plot_obj.clean_model_lines()
        if label is None:
            label = self.name

        if hasattr(self, 'SED'):
            plot_obj.add_model_plot(self.SED, line_style=line_style, label=label, flim=self.flux_plot_lim, frame=frame)

        if skip_components is False and hasattr(self, 'spectral_components_list'):
            for component in self.spectral_components_list:
                # Components named 'Sum' are skipped (presumably the total,
                # already drawn above).
                if component.name != 'Sum' and hasattr(component, 'SED'):
                    plot_obj.add_model_plot(component.SED, line_style='--', label=' -%s' % component.name,
                                            flim=self.flux_plot_lim, frame=frame)
        return plot_obj

    def set_nu_grid(self, nu_min=None, nu_max=None, nu_size=None):
        """Set model frequency-grid settings; ``None`` leaves a setting unchanged.

        Parameters
        ----------
        nu_min : float, optional
            Minimum frequency in Hz.
        nu_max : float, optional
            Maximum frequency in Hz.
        nu_size : int, optional
            Number of samples in the grid.
        """
        if nu_size is not None:
            self.nu_size = nu_size
        if nu_min is not None:
            self.nu_min = nu_min
        if nu_max is not None:
            self.nu_max = nu_max

    def lin_func(self, lin_nu):
        """Return model values in linear space.

        The base implementation returns the flux floor at every frequency.

        Parameters
        ----------
        lin_nu : array-like
            Frequency array in Hz.

        Returns
        -------
        ndarray
            Model values in linear ``nuFnu`` units.
        """
        # np.shape also works for plain sequences, unlike ``lin_nu.shape``.
        return np.full(np.shape(lin_nu), self.flux_plot_lim)

    def log_func(self, log_nu):
        """Return model values in log10 space.

        Parameters
        ----------
        log_nu : array-like
            Frequency array in log10(Hz).

        Returns
        -------
        ndarray
            Model values in log10 space.
        """
        # Float base: an integer base overflows int64 for log_nu >= 19 and
        # raises on negative integer exponents.
        return np.log10(self.lin_func(np.power(10., log_nu)))

    def get_residuals(self, data, log_log=False, filter_UL=True):
        """Compute residuals between observed data and model prediction.

        Parameters
        ----------
        data : ObsData or table-like
            Input data table containing ``nu_data``, ``nuFnu_data``,
            ``dnuFnu_data`` and ``UL`` columns.
        log_log : bool, optional
            If ``True``, return frequency axis in log10(Hz).
        filter_UL : bool, optional
            If ``True``, exclude upper-limit points.

        Returns
        -------
        tuple of ndarray
            ``(nu_axis, residuals)`` where residuals are
            ``(data - model) / error``.
        """
        if isinstance(data, ObsData):
            data = data.data
        model = self.eval(nu=data['nu_data'], fill_SED=False, get_model=True, loglog=False)

        if filter_UL:
            msk = data['UL'] == False  # noqa: E712 -- elementwise comparison
        else:
            msk = np.ones(len(data), dtype=bool)

        residuals = (data['nuFnu_data'] - model) / data['dnuFnu_data']
        nu_residuals = data['nu_data'][msk]
        if log_log:
            nu_residuals = np.log10(nu_residuals)
        return nu_residuals, residuals[msk]

    def save_model(self, file_name):
        """Serialize the model to disk.

        Parameters
        ----------
        file_name : str
            Output pickle file path.
        """
        # Context manager guarantees the file is flushed and closed.
        with open(file_name, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load_model(cls, file_name_or_obj, from_string=False):
        """Load a serialized model.

        Security note: this unpickles its input; never load untrusted data.

        Parameters
        ----------
        file_name_or_obj : str or bytes or file-like
            Serialized model source accepted by the pickle loader.
        from_string : bool, optional
            If ``True``, interpret ``file_name_or_obj`` as in-memory content.

        Returns
        -------
        Model
            Reconstructed and evaluated model instance.

        Raises
        ------
        RuntimeError
            If deserialization fails or object type is not valid.
        """
        try:
            c = cls._load_pickle(file_name_or_obj, from_string=from_string)
        except Exception as e:
            raise RuntimeError(e) from e

        # Validate the type before touching model-specific attributes.
        if not isinstance(c, Model):
            raise RuntimeError('The model you loaded is not valid please check the file name')

        try:
            c._fix_par_dep_on_load(verbose=True)
            c.eval()
        except Exception as e:
            raise RuntimeError(e) from e
        return c

    def _fix_par_dep_on_load(self, verbose=True):
        """Rebuild dependent-parameter links after unpickling.

        Each dependent, non-linked parameter has its dependencies reset and
        re-created through :meth:`make_dependent_par`.
        """
        for p in self.parameters.par_array:
            if p._is_dependent is True and p._linked is False:
                master_names = list(p._master_par_list)
                par_expr = copy.deepcopy(p._depending_par_expr)
                p.reset_dependencies()
                self.make_dependent_par(p.name, master_names, par_expr, set_par_expr_source_code=True, verbose=verbose)

    @staticmethod
    def _load_pickle(file_name_or_obj, from_string=False):
        """Unpickle from bytes (``from_string``), a file-like object, or a path."""
        if from_string:
            return pickle.loads(file_name_or_obj)
        if hasattr(file_name_or_obj, 'read'):
            return pickle.load(file_name_or_obj)
        with open(file_name_or_obj, 'rb') as f:
            return pickle.load(f)

    def clone(self):
        """Return a deep clone of the model via in-memory serialization.

        Returns
        -------
        Model
            Cloned model instance.
        """
        return self.load_model(pickle.dumps(self, protocol=pickle.HIGHEST_PROTOCOL), from_string=True)

    def show_model(self):
        """Print model summary and parameters."""
        print("")
        print('-' * 80)
        print("model description")
        print('-' * 80)
        print("name: %s " % (self.name))
        print("type: %s " % (self.model_type))
        print('')
        print('-' * 80)
        self.parameters.show_pars()
        print('-' * 80)

    def show_pars(self, sort_key='par type'):
        """Display model parameters.

        Parameters
        ----------
        sort_key : str, optional
            Column used to sort displayed parameters.

        Returns
        -------
        object
            Output returned by ``ModelParameterArray.show_pars``.
        """
        return self.parameters.show_pars(sort_key=sort_key)

    def show_best_fit_pars(self):
        """Display best-fit parameter values."""
        self.parameters.show_best_fit_pars()

    def set_par(self, par_name, val):
        """Set a parameter value by name.

        Parameters
        ----------
        par_name : str
            Parameter name.
        val : float or int
            New parameter value.
        """
        self.parameters.set(par_name, val=val)

    def get_par_by_type(self, par_type):
        """Return first parameter matching a parameter type.

        Parameters
        ----------
        par_type : str
            Parameter type label.

        Returns
        -------
        ModelParameter or None
            Matching parameter, if found.
        """
        return next((p for p in self.parameters.par_array if p.par_type == par_type), None)

    def get_par_by_name(self, par_name):
        """Return parameter by name.

        Parameters
        ----------
        par_name : str
            Parameter name.

        Returns
        -------
        ModelParameter or None
            Matching parameter, if found.
        """
        return next((p for p in self.parameters.par_array if p.name == par_name), None)

    def dep_func_get_default_args(self, par_func):
        """Validate dependency-function arguments against model parameters.

        Parameters
        ----------
        par_func : callable
            Function used as dependency expression.

        Returns
        -------
        list of str
            Ordered list of parameter names accepted by ``par_func``.

        Raises
        ------
        RuntimeError
            If ``par_func`` uses argument names not present in the model.
        """
        arg_names = []
        for arg_name in inspect.signature(par_func).parameters:
            if self.get_par_by_name(arg_name) is None:
                raise RuntimeError("argument '%s' is not valid, should be a model parameter name" % arg_name)
            arg_names.append(arg_name)
        return arg_names

    def _test_par_expr(self, master_par_list, par_expr):
        """Check that a string ``par_expr`` evaluates with every master set to 1.

        Non-string expressions (e.g. callables) are not checked here.

        Raises
        ------
        RuntimeError
            If evaluating the expression raises.
        """
        if not isinstance(par_expr, str):
            return
        trial_values = {p_name: 1 for p_name in master_par_list}
        try:
            # SECURITY: builtin eval on a user-provided expression; only use
            # with trusted input.
            eval(par_expr, {"np": np, "__builtins__": __builtins__}, trial_values)
        except Exception as e:
            raise RuntimeError('the parameter expression is not valid') from e

    def make_dependent_par(self, par, depends_on, par_expr, verbose=True, set_par_expr_source_code=True, master_pars=None):
        """Make a parameter depend on other (master) parameters.

        Parameters
        ----------
        par : str
            Name of the parameter to make dependent, looked up with
            ``self.parameters.get_par_by_name``.
        depends_on : list of str
            Names of master parameters used by the dependency.
        par_expr : str or callable
            Expression defining the dependent parameter. String expressions
            are test-evaluated first.
        verbose : bool, optional
            Forwarded to ``_add_master_par``. When ``True``, the dependent
            parameter's ``par_expression_source_code`` attribute is also
            accessed at the end.
        set_par_expr_source_code : bool, optional
            If ``True``, store source code for the dependency expression.
        master_pars : object, optional
            Currently unused; kept for backward compatibility.

        Raises
        ------
        RuntimeError
            If ``par`` is unknown, appears among its own masters, the
            expression is invalid, or a master parameter cannot be attached.
        """
        master_par_list = depends_on
        dep_par = self.parameters.get_par_by_name(par)
        if dep_par is None:
            raise RuntimeError('parameter %s not found in model %s' % (par, self.name))
        if dep_par.name in master_par_list:
            raise RuntimeError("depending parameter: %s can't be in master par list %s" % (dep_par.name, master_par_list))

        self._test_par_expr(master_par_list, par_expr)

        dep_par.freeze()
        dep_par._is_dependent = True
        dep_par.par_expr = par_expr
        dep_par._func = dep_par._eval_par_func
        dep_par._master_par_list = master_par_list

        for p_name in master_par_list:
            try:
                m = self.parameters.get_par_by_name(p_name)
                dep_par._add_master_par(m, verbose=verbose)
                m._add_depending_par(dep_par)
            except Exception as e:
                message = 'problem with parameter name: %s' % p_name
                message += '\nexception:%s' % str(e)
                raise RuntimeError(message) from e

        for p_name in master_par_list:
            m = self.parameters.get_par_by_name(p_name)
            if m._is_dependent is False and m.par_type == 'user_defined':
                # Self-assignment presumably re-runs the value setter so that
                # dependents refresh; failures are ignored (best effort).
                try:
                    m.val = m.val
                except Exception:
                    pass

        if set_par_expr_source_code is True:
            dep_par._set_par_expr_source_code()
        if verbose is True:
            # NOTE(review): bare attribute access, presumably a property that
            # displays the expression source -- confirm.
            dep_par.par_expression_source_code

    def add_user_par(self, name, val, units='', val_min=None, val_max=None):
        """Add a user-defined parameter to the model.

        Parameters
        ----------
        name : str
            Parameter name.
        val : float
            Initial value.
        units : str, optional
            Parameter units label.
        val_min : float, optional
            Lower physical bound.
        val_max : float, optional
            Upper physical bound.
        """
        self.parameters.add_par(ModelParameter(name=name, units=units, val=val, val_min=val_min, val_max=val_max, par_type='user_defined'))

    def set_fit_range(self, down_tol=0.1, up_tol=100):
        """Set fit ranges for numeric parameters.

        The range spans ``val_lin * down_tol`` to ``val_lin * (1 + up_tol)``.
        Each end is ordered so that min <= max also for negative values, then
        clamped to ``val_min``/``val_max`` when those are set.

        Parameters
        ----------
        down_tol : float, optional
            Multiplicative factor for lower fit bound.
        up_tol : float, optional
            Relative factor for upper fit bound.
        """
        for p in self.parameters.par_array:
            if not isinstance(p.val, numbers.Number):
                continue
            # Sorting keeps min <= max when val_lin is negative; for
            # non-negative values the order is unchanged.
            low, high = sorted((p.val_lin * down_tol, p.val_lin * (1 + up_tol)))
            p.fit_range_min = low if p.val_min is None else max(p.val_min, low)
            p.fit_range_max = high if p.val_max is None else min(p.val_max, high)

    def build_table(self, restframe='obs'):
        """Build the SED table for the current model state into ``self._SED_table``.

        Parameters
        ----------
        restframe : {'obs', 'src'}, optional
            Frame used to build frequency/flux columns.
        """
        if not hasattr(self, 'SED'):
            self._SED_table = None
            return

        check_frame(restframe)
        if restframe == 'obs':
            names = ['nu', 'nuFnu']
            cols = [self.SED.nu, self.SED.nuFnu]
        elif restframe == 'src':
            names = ['nu', 'nuLnu_src']
            cols = [self.SED.nu_src, self.SED.nuLnu_src]
        else:
            unexpected_behaviour()
        meta = dict(model_name=self.name, restframe=restframe)
        self._SED_table = Table(cols, names=names, meta=meta)

    def sed_table(self, restframe='obs'):
        """Return SED table, evaluating the model if needed.

        Parameters
        ----------
        restframe : {'obs', 'src'}, optional
            Frame used to build frequency/flux columns.

        Returns
        -------
        astropy.table.Table or None
            SED table for the current model state.
        """
        try:
            self.build_table(restframe=restframe)
        except Exception:
            # Most likely the SED has not been filled yet: evaluate and retry.
            self.eval()
            self.build_table(restframe=restframe)
        return self._SED_table
class MultiplicativeModel(Model):
    """Model subclass for multiplicative spectral components.

    Built like a regular :class:`Model`, but the ``SED`` attribute created by
    the base initializer is removed from the instance.
    """

    def __init__(self, name='no-name', nu_size=100, model_type='multiplicative_model', scale='lin-lin'):
        """Initialize a multiplicative model.

        Parameters
        ----------
        name : str, optional
            Model name.
        nu_size : int, optional
            Number of evaluation frequencies.
        model_type : str, optional
            Model type label.
        scale : str, optional
            Preferred plotting/evaluation scale.
        """
        super().__init__(name=name, nu_size=nu_size, model_type=model_type, scale=scale)
        # Drop the SED container that Model.__init__ just created.
        del self.SED