Source code for jetset.base_model


# Module authorship metadata.
__author__ = "Andrea Tramacere"


import numpy as np
import copy
import dill as pickle
import warnings
import inspect
import numbers
from astropy.table import Table

from .model_parameters import ModelParameterArray, ModelParameter
from .spectral_shapes import SED
from .data_loader import  ObsData
from .utils import  get_info
from .plot_sedfit import  PlotSED
from .utils import check_frame,unexpected_behaviour

from .cosmo_tools import  Cosmo

# Names exported by ``from <this module> import *``.
__all__=['Model','MultiplicativeModel']




class Model(object):
    """Base class for models evaluated on a frequency grid.

    An instance stores a name, a model type label, a
    :class:`ModelParameterArray`, a frequency-grid definition
    (``nu_min``, ``nu_max``, ``nu_size``), a cosmology object and an
    :class:`SED` container that :meth:`eval` can fill.

    The default :meth:`lin_func` returns a constant equal to
    ``flux_plot_lim`` and :meth:`log_func` is its base-10 logarithm.
    """

    def __init__(self,
                 name='no_name',
                 nu_size=200,
                 model_type='base_model',
                 scale='lin-lin',
                 cosmo=None,
                 nu_min=None,
                 nu_max=None):
        """
        :param name: model name, also used as the name of the SED container
        :param nu_size: number of points of the default frequency grid
        :param model_type: free-form label stored in ``model_type``
        :param scale: stored in ``_scale`` (not used inside this class)
        :param cosmo: cosmology object; a default :class:`Cosmo` is built
            when ``None``
        :param nu_min: lower bound of the default frequency grid
        :param nu_max: upper bound of the default frequency grid
        """
        self.model_type = model_type
        self.name = name
        self.SED = SED(name=self.name)
        self.parameters = ModelParameterArray()
        self._scale = scale
        self.nu_size = nu_size
        self.nu_min = nu_min
        self.nu_max = nu_max
        # Floor value used for plotting and as the default model value.
        self.flux_plot_lim = 1E-30
        self.cosmo = Cosmo() if cosmo is None else cosmo
        self._set_version(v=None)

    @property
    def version(self):
        """Version string recorded for this model instance."""
        return self._version

    def _set_version(self, v=None):
        """Store ``v`` as the version, or the package version when ``v`` is None."""
        if v is None:
            self._version = get_info()['version']
        else:
            self._version = v

    def _prepare_nu_model(self, nu, loglog):
        """Return the ``(lin_nu, log_nu)`` pair used to evaluate the model.

        :param nu: frequencies; ``None`` selects the internal grid built
            from ``nu_min``, ``nu_max`` and ``nu_size``.  A scalar is
            wrapped into a one-element array.
        :param loglog: when ``True`` the given ``nu`` is interpreted as
            log10 of the frequency, otherwise as the frequency itself
        :return: tuple ``(lin_nu, log_nu)``
        """
        if nu is None:
            x1 = np.log10(self.nu_min)
            x2 = np.log10(self.nu_max)
            lin_nu = np.logspace(x1, x2, self.nu_size)
            log_nu = np.log10(lin_nu)
            return lin_nu, log_nu

        if np.shape(nu) == ():
            nu = np.array([nu])

        if loglog is True:
            log_nu = nu
            lin_nu = np.power(10., nu)
        else:
            lin_nu = nu
            log_nu = np.log10(nu)

        return lin_nu, log_nu

    def _eval_model(self, lin_nu, log_nu, loglog):
        """Evaluate the model and return ``(lin_model, log_model)``.

        ``log_model`` is ``None`` when ``loglog`` is ``False``.  In log-log
        mode ``log_func`` is used when available; otherwise ``lin_func`` is
        evaluated, negative values are replaced by ``flux_plot_lim`` and the
        base-10 logarithm is taken.
        """
        log_model = None

        if loglog is False:
            lin_model = self.lin_func(lin_nu)
        elif hasattr(self, 'log_func'):
            log_model = self.log_func(log_nu)
            lin_model = np.power(10., log_model)
        else:
            lin_model = self.lin_func(lin_nu)
            lin_model[lin_model < 0.] = self.flux_plot_lim
            log_model = np.log10(lin_model)

        return lin_model, log_model

    def _fill(self, lin_nu, lin_model):
        """Store the evaluated model into the SED container(s).

        When the instance has no ``SED`` attribute a warning is issued
        instead.  Spectral components, when present, are always asked to
        fill their own SED.
        """
        if hasattr(self, 'SED'):
            self.SED.fill(nu=lin_nu, nuFnu=lin_model)

            z_par = self.get_par_by_type('redshift')
            z = z_par.val if z_par is not None else None
            if z is None and hasattr(self, 'get_redshift'):
                z = self.get_redshift()

            if z is not None:
                if hasattr(self, 'get_DL_cm'):
                    # NOTE(review): the model-level hook is called with the
                    # literal string 'redshift', not with the value of z.
                    dl = self.get_DL_cm('redshift')
                else:
                    dl = self.cosmo.get_DL_cm(z)
                self.SED.fill_nuLnu(z=z, dl=dl)
        else:
            # warnings.warn takes a single message; the pieces must be joined
            # here, otherwise the call itself raises TypeError.
            warnings.warn('model %s of type %s has no SED member' % (self.name, type(self)))

        if hasattr(self, 'spectral_components_list'):
            for component in self.spectral_components_list:
                component.fill_SED(lin_nu=lin_nu, skip_zeros=False)

    def eval(self, fill_SED=True, nu=None, get_model=False, loglog=False, label=None, **kwargs):
        """Evaluate the model.

        :param fill_SED: when ``True`` the result is stored in the SED container
        :param nu: frequencies (or log10 frequencies if ``loglog`` is ``True``);
            ``None`` selects the internal grid
        :param get_model: when ``True`` the evaluated model is returned
        :param loglog: evaluate and return in log10 space
        :param label: accepted for interface compatibility, unused here
        :param kwargs: accepted for interface compatibility, unused here
        :return: the model array when ``get_model`` is ``True``, else ``None``
        """
        lin_nu, log_nu = self._prepare_nu_model(nu, loglog)
        lin_model, log_model = self._eval_model(lin_nu, log_nu, loglog)

        if fill_SED is True:
            self._fill(lin_nu, lin_model)

        if get_model is not True:
            return None

        return log_model if loglog is True else lin_model

    def _set_up_plot(self, plot_obj, sed_data, frame, density):
        """Return ``plot_obj``, creating a :class:`PlotSED` when it is ``None``."""
        if plot_obj is None:
            plot_obj = PlotSED(frame=frame, density=density, sed_data=sed_data)

        # NOTE(review): in the 'src' frame the data redshift is temporarily
        # replaced by the model one and then restored; nothing uses the data
        # in between in this method. Nesting reconstructed from a source
        # whose indentation was lost.
        z_sed_data = None
        if frame == 'src' and sed_data is not None:
            z_sed_data = sed_data.z
            if self.get_par_by_type('redshift') is not None:
                sed_data.z = self.get_par_by_type('redshift').val

        if frame == 'src' and z_sed_data is not None:
            sed_data.z = z_sed_data

        return plot_obj

    def plot_model(self, plot_obj=None, clean=False, sed_data=None, frame='obs', skip_components=False, label=None, line_style='-', density=False):
        """Plot the model SED and, optionally, its spectral components.

        :param plot_obj: existing plot object; a new one is created when ``None``
        :param clean: when ``True`` previously drawn model lines are removed
        :param sed_data: data object forwarded to the plot set-up
        :param frame: ``'obs'`` or ``'src'``
        :param skip_components: when ``True`` components are not drawn
        :param label: label of the model line; defaults to the model name
        :param line_style: line style of the model line (components use ``'--'``)
        :param density: forwarded to the plot object
        :return: the plot object
        """
        plot_obj = self._set_up_plot(plot_obj, sed_data, frame, density)

        if clean is True:
            plot_obj.clean_model_lines()

        if label is None:
            label = self.name

        if hasattr(self, 'SED'):
            plot_obj.add_model_plot(self.SED,
                                    line_style=line_style,
                                    label=label,
                                    flim=self.flux_plot_lim,
                                    density=density,
                                    frame=frame)

        if skip_components is False and hasattr(self, 'spectral_components_list'):
            for c in self.spectral_components_list:
                comp_label = c.name
                # The 'Sum' component duplicates the total model line.
                if comp_label != 'Sum' and hasattr(c, 'SED'):
                    plot_obj.add_model_plot(c.SED,
                                            line_style='--',
                                            label=' -%s' % comp_label,
                                            flim=self.flux_plot_lim,
                                            density=density,
                                            frame=frame)

        return plot_obj

    def set_nu_grid(self, nu_min=None, nu_max=None, nu_size=None):
        """Update the internal frequency grid; ``None`` arguments are left unchanged."""
        if nu_size is not None:
            self.nu_size = nu_size

        if nu_min is not None:
            self.nu_min = nu_min

        if nu_max is not None:
            self.nu_max = nu_max

    def lin_func(self, lin_nu):
        """Return an array of ``lin_nu.size`` values all equal to ``flux_plot_lim``."""
        return np.ones(lin_nu.size) * self.flux_plot_lim

    def log_func(self, log_nu):
        """Return log10 of :meth:`lin_func` evaluated at ``10**log_nu``."""
        return np.log10(self.lin_func(np.power(10, log_nu)))

    def get_residuals(self, data, log_log=False, filter_UL=True):
        """Return ``(nu, residuals)`` of the model against ``data``.

        Residuals are ``(nuFnu_data - model) / dnuFnu_data``.

        :param data: an :class:`ObsData` instance (its ``data`` member is
            used) or an object indexable by the column names ``nu_data``,
            ``nuFnu_data``, ``dnuFnu_data`` and ``UL``
        :param log_log: when not ``False`` the returned frequencies are log10
        :param filter_UL: when ``True`` rows flagged in ``UL`` are dropped
        """
        if isinstance(data, ObsData):
            data = data.data

        model = self.eval(nu=data['nu_data'], fill_SED=False, get_model=True, loglog=False)

        if filter_UL == True:
            # Element-wise comparison on the column: keep non upper-limit rows.
            msk = data['UL'] == False
        else:
            msk = np.ones(len(data), dtype=bool)

        residuals = (data['nuFnu_data'] - model) / data['dnuFnu_data']
        nu_residuals = data['nu_data']

        if log_log == False:
            return nu_residuals[msk], residuals[msk]

        return np.log10(nu_residuals[msk]), residuals[msk]

    def save_model(self, file_name):
        """Serialize the model to ``file_name`` with the module-level pickler."""
        with open(file_name, 'wb') as fh:
            pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load_model(cls, file_name):
        """Load a model saved with :meth:`save_model`.

        :raises RuntimeError: on any failure, including the loaded object
            not being a :class:`Model`
        """
        # SECURITY: unpickling executes arbitrary code; only load trusted files.
        try:
            with open(file_name, 'rb') as fh:
                c = pickle.load(fh)

            if not isinstance(c, Model):
                raise RuntimeError('The model you loaded is not valid please check the file name')

            c._fix_par_dep_on_load()
            c.eval()
            return c
        except Exception as e:
            raise RuntimeError(e) from e

    def _fix_par_dep_on_load(self):
        """Rebuild parameter dependencies for dependent, non-linked parameters."""
        for p in self.parameters.par_array:
            if p._is_dependent is True and p._linked is False:
                # Copy before reset_dependencies() clears the parameter state.
                _master_par_list = list(p._master_par_list)
                _depending_par_expr = copy.deepcopy(p._depending_par_expr)
                p.reset_dependencies()
                self.make_dependent_par(p.name,
                                        _master_par_list,
                                        _depending_par_expr,
                                        set_par_expr_source_code=True)

    def clone(self):
        """Return a copy of the model obtained by a pickle round-trip."""
        return pickle.loads(pickle.dumps(self))

    def show_model(self):
        """Print the model name, type and parameters."""
        print("")
        print('-' * 80)
        print("model description")
        print('-' * 80)
        print("name: %s  " % (self.name))
        print("type: %s  " % (self.model_type))
        print('')
        print('-' * 80)
        self.parameters.show_pars()
        print('-' * 80)

    def show_pars(self, sort_key='par type'):
        """Shortcut to ``self.parameters.show_pars``; returns its result."""
        return self.parameters.show_pars(sort_key=sort_key)

    def show_best_fit_pars(self):
        """Shortcut to ``self.parameters.show_best_fit_pars``."""
        self.parameters.show_best_fit_pars()

    def set_par(self, par_name, val):
        """
        shortcut to :class:`ModelParametersArray.set` method
        set a parameter value

        :param par_name: (str), name of the parameter
        :param val: parameter value

        """
        self.parameters.set(par_name, val=val)

    def get_par_by_type(self, par_type):
        """
        get the first parameter whose ``par_type`` matches, or ``None``
        """
        for param in self.parameters.par_array:
            if param.par_type == par_type:
                return param

        return None

    def get_par_by_name(self, par_name):
        """
        get the first parameter whose ``name`` matches, or ``None``
        """
        for param in self.parameters.par_array:
            if param.name == par_name:
                return param

        return None

    def dep_func_get_default_args(self, par_func):
        """Return the argument names of ``par_func``.

        :raises RuntimeError: if an argument name is not a model parameter name
        """
        d = []
        for k in inspect.signature(par_func).parameters:
            if self.get_par_by_name(k) is None:
                raise RuntimeError('argument', k, 'is not valid, should be a model parameter name')
            d.append(k)

        return d

    def _test_par_expr(self, master_par_list, par_expr):
        """Check that a string expression evaluates with every master set to 1.

        Non-string expressions are not checked.

        :raises RuntimeError: if evaluating the expression fails
        """
        if isinstance(par_expr, str):
            # An explicit namespace is required: names created with exec()
            # inside a function are not reliably visible to a later eval().
            namespace = {p_name: 1 for p_name in master_par_list}
            try:
                eval(par_expr, globals(), namespace)
            except Exception as e:
                raise RuntimeError('the parameter expression is not valid') from e

    def make_dependent_par(self, par, depends_on, par_expr, verbose=True, set_par_expr_source_code=True, master_pars=None):
        """Make parameter ``par`` depend on the parameters in ``depends_on``.

        :param par: name of the parameter to make dependent
        :param depends_on: list of names of the master parameters
        :param par_expr: expression defining the dependent parameter
        :param verbose: when ``True`` ``par_expression_source_code`` is accessed
            on the dependent parameter
        :param set_par_expr_source_code: when ``True`` the source code of the
            expression is stored on the dependent parameter
        :param master_pars: accepted for interface compatibility, unused here
        :raises RuntimeError: if ``par`` is in ``depends_on``, if the
            expression is not valid or if a master parameter can not be linked
        """
        master_par_list = depends_on

        dep_par = self.parameters.get_par_by_name(par)

        if dep_par.name in master_par_list:
            raise RuntimeError("depending parameter:", dep_par.name, "can't be in master par list", master_par_list)

        self._test_par_expr(master_par_list, par_expr)

        dep_par.freeze()
        dep_par._is_dependent = True
        dep_par.par_expr = par_expr
        dep_par._func = dep_par._eval_par_func
        dep_par._master_par_list = master_par_list

        for p in master_par_list:
            try:
                m = self.parameters.get_par_by_name(p)
                dep_par._add_master_par(m)
                m._add_depending_par(dep_par)
            except Exception as e:
                message = 'problem with parameter name: %s' % p
                message += '\nexception:%s' % str(e)
                raise RuntimeError(message) from e

        for p in master_par_list:
            m = self.parameters.get_par_by_name(p)
            if m._is_dependent is False and m.par_type == 'user_defined':
                try:
                    # Re-assigning the value goes through the setter;
                    # presumably this propagates to the dependent parameter.
                    # Best effort by design, failures are ignored.
                    m.val = m.val
                except Exception:
                    pass

        if set_par_expr_source_code is True:
            dep_par._set_par_expr_source_code()

        if verbose is True:
            # NOTE(review): bare attribute access, presumably a property with
            # a printing side effect; confirm against ModelParameter.
            dep_par.par_expression_source_code

    def add_user_par(self, name, val, units='', val_min=None, val_max=None):
        """Add a parameter of type ``'user_defined'`` to the model."""
        user_par = ModelParameter(name=name,
                                  units=units,
                                  val=val,
                                  val_min=val_min,
                                  val_max=val_max,
                                  par_type='user_defined')
        self.parameters.add_par(user_par)

    def set_fit_range(self, down_tol=0.1, up_tol=100):
        """Set ``fit_range_min``/``fit_range_max`` on every numeric parameter.

        The range is ``[val_lin*down_tol, val_lin*(1+up_tol)]``, clipped to
        ``val_min``/``val_max`` when those are not ``None``.
        """
        for p in self.parameters.par_array:
            if not isinstance(p.val, numbers.Number):
                continue

            lower = p.val_lin * down_tol
            upper = p.val_lin * (1 + up_tol)

            p.fit_range_min = lower if p.val_min is None else max(p.val_min, lower)
            p.fit_range_max = upper if p.val_max is None else min(p.val_max, upper)

    def build_table(self, restframe='obs'):
        """Build ``self._SED_table`` from the SED container.

        :param restframe: ``'obs'`` (columns ``nu``, ``nuFnu``) or ``'src'``
            (columns ``nu``, ``nuLnu_src``)
        """
        _names = ['nu']
        _cols = []

        # NOTE(review): the original indentation was lost; the whole body is
        # assumed to be guarded by the SED check -- confirm upstream.
        if hasattr(self, 'SED'):
            check_frame(restframe)

            if restframe == 'obs':
                _cols.append(self.SED.nu)
            elif restframe == 'src':
                _cols.append(self.SED.nu_src)
            else:
                unexpected_behaviour()

            if restframe == 'obs':
                _names.append('nuFnu')
                _cols.append(self.SED.nuFnu)
            else:
                _names.append('nuLnu_src')
                _cols.append(self.SED.nuLnu_src)

            _meta = dict(model_name=self.name)
            _meta['restframe'] = restframe

            self._SED_table = Table(_cols, names=_names, meta=_meta)

    def sed_table(self, restframe='obs'):
        """Return the SED table, evaluating the model first if building fails."""
        try:
            self.build_table(restframe=restframe)
        except Exception:
            # Most likely the SED has not been filled yet: evaluate and retry.
            self.eval()
            self.build_table(restframe=restframe)

        return self._SED_table
class MultiplicativeModel(Model):
    """Model variant that carries no ``SED`` container.

    The base initializer is run with the given arguments and the ``SED``
    attribute it creates is then removed from the instance.
    """

    def __init__(self,
                 name='no-name',
                 nu_size=100,
                 model_type='multiplicative_model',
                 scale='lin-lin'):
        """
        :param name: model name
        :param nu_size: number of points of the default frequency grid
        :param model_type: label stored in ``model_type``
        :param scale: forwarded to the base class
        """
        super().__init__(name=name,
                         nu_size=nu_size,
                         model_type=model_type,
                         scale=scale)

        # This model type does not own an SED.
        del self.SED