__author__ = "Andrea Tramacere"

# sherpa plugin for jetset: wraps jetset models as sherpa models and provides
# a sherpa-based Minimizer backend.

import os

from astropy.table import Table

try:
    from sherpa.models.model import ArithmeticModel, modelCacher1d, RegriddableModel1D
    from sherpa.models.parameter import Parameter
    from sherpa import data
    from sherpa.fit import Fit
    from sherpa.stats import Chi2
    from sherpa.optmethods import LevMar
    from sherpa import data as sherpa_data
    sherpa_installed = True
except ImportError as err:
    # Always define the flag so code checking it (e.g. SherpaMinimizer) gets
    # a meaningful ImportError instead of a NameError.
    sherpa_installed = False
    on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
    if on_rtd is True:
        # Documentation build without sherpa: use a placeholder base class so
        # the module and JetsetSherpaModel can still be imported.
        RegriddableModel1D = object
    else:
        raise ImportError('to use sherpa plugin you need to install sherpa: https://sherpa.readthedocs.io/en/latest/install.html') from err

import numpy as np

from .plot_sedfit import PlotSED
from .minimizer import Minimizer

__all__ = ['JetsetSherpaModel', 'plot_sherpa_model']
class JetsetSherpaModel(RegriddableModel1D):
    """
    Automatic sherpa model generator for jetset models.

    Every selected jetset parameter is mirrored by a sherpa ``Parameter``
    with the same value, units and frozen state. Limits are taken from the
    jetset fit range when set, otherwise from ``val_min``/``val_max``.
    Each mirrored jetset parameter gets a back-reference to its sherpa
    parameter in ``_sherpa_ref``.

    Parameters
    ----------
    jetset_model :
        jetset model to wrap.
    par_list : optional
        If given, only the jetset parameters contained in it are exposed to
        sherpa. If None, all parameters are exposed.
    clone : bool
        If True, wrap ``jetset_model.clone()`` instead of the model itself.
    """

    def __init__(self, jetset_model, par_list=None, clone=False):
        if clone is True:
            self._jetset_model = jetset_model.clone()
        else:
            self._jetset_model = jetset_model

        # Parallel lists: self._jp_list[i] mirrors self._jp_par_array[i].
        self._jp_list = []
        self._jp_par_array = []
        # Original (un-renamed) jetset parameter names.
        self._jp_list_names = []
        # Lower-cased sherpa names already in use. Sherpa parameter names are
        # case-insensitive, so clashes must be detected ignoring case.
        used_names = set()
        setattr(self, '_jetset_ncalls', 0)

        for p in self._jetset_model.parameters.par_array:
            if p is None:
                continue
            if par_list is not None and p not in par_list:
                continue

            name = p.name
            if name.lower() in used_names:
                # Append the suffix until the name no longer clashes.
                while name.lower() in used_names:
                    name = name + '_sh'
                print('jetset model name', p.name, 'renamed to ', name, 'due to sherpa internal naming convention')
            used_names.add(name.lower())

            val_min = p.fit_range_min if p.fit_range_min is not None else p.val_min
            val_max = p.fit_range_max if p.fit_range_max is not None else p.val_max
            # An unset (None) limit is not passed at all, so sherpa's own
            # default limit is used for it.
            limits = {}
            if val_min is not None:
                limits['min'] = val_min
            if val_max is not None:
                limits['max'] = val_max

            sh_p = Parameter(self._jetset_model.name, name, p.val, units=p.units, frozen=p.frozen, **limits)
            setattr(self, sh_p.name, sh_p)
            p._sherpa_ref = sh_p
            # NaN limits fall back to sherpa's hard limits.
            if np.isnan(sh_p.max):
                sh_p.max = sh_p.hard_max
            if np.isnan(sh_p.min):
                sh_p.min = sh_p.hard_min

            self._jp_list.append(sh_p)
            self._jp_par_array.append(p)
            self._jp_list_names.append(p.name)

        RegriddableModel1D.__init__(self, jetset_model.name, tuple(self._jp_list))

    def calc(self, pars, x):
        """
        Evaluate the jetset model at ``x`` with the current sherpa values.

        Parameters
        ----------
        pars :
            Parameter values passed by sherpa. They are not used here; the
            values are read from the mirrored sherpa ``Parameter`` objects.
        x :
            Grid passed to the jetset model as ``nu``.

        Returns
        -------
        The result of ``jetset_model.eval(get_model=True, nu=x)``.
        """
        for sh_p, j_p in zip(self._jp_list, self._jp_par_array):
            j_p.val = sh_p.val
        self._jetset_ncalls += 1
        return self._jetset_model.eval(get_model=True, nu=x)

    def plot_model(self, fit_range, model_range=(1E10, 1E30), nu_grid_size=200, plot_obj=None, sed_data=None):
        """
        Plot the wrapped jetset model, plus residuals when data is given.

        Parameters
        ----------
        fit_range :
            Two-element (min, max) range forwarded to the residual plot.
        model_range :
            (min, max) of the grid on which the jetset model is evaluated.
        nu_grid_size : int
            Number of points in that grid.
        plot_obj : optional
            Existing plot object to draw on.
        sed_data : optional
            Observed data. Residuals are only drawn when it is given.

        Returns
        -------
        The plot object returned by ``jetset_model.plot_model``.
        """
        self._jetset_model.set_nu_grid(model_range[0], model_range[1], nu_grid_size)
        self._jetset_model.eval()
        plot_obj = self._jetset_model.plot_model(plot_obj=plot_obj, sed_data=sed_data)
        # Residuals need observed data to compare against.
        if sed_data is not None:
            plot_obj.add_model_residual_plot(data=sed_data, model=self._jetset_model,
                                             fit_range=[fit_range[0], fit_range[1]])
        return plot_obj
def plot_sherpa_model(sherpa_model, fit_range=None, model_range=(1E10, 1E30), nu_grid_size=200, sed_data=None,
                      add_res=False, plot_obj=None, label=None, line_style=None):
    """
    Plot a (callable) sherpa model on a log-spaced grid.

    Parameters
    ----------
    sherpa_model : callable
        Evaluated as ``sherpa_model(x)``.
    fit_range : optional
        (min, max) of the plotting grid. When given it takes precedence over
        ``model_range`` and enables residuals.
    model_range :
        (min, max) of the grid used when ``fit_range`` is None.
    nu_grid_size : int
        Number of grid points.
    sed_data : optional
        Observed data. It is added to the plot, and is needed for residuals.
    add_res : bool
        If True, ``fit_range`` is given and ``sed_data`` is given, plot the
        normalised residuals ``(nuFnu_data - model) / dnuFnu_data``.
    plot_obj : optional
        Existing plot object. If None, a new ``PlotSED`` is created.
    label, line_style :
        Forwarded to ``add_xy_plot``.

    Returns
    -------
    The plot object used.
    """
    grid_range = fit_range if fit_range is not None else model_range
    x = np.logspace(np.log10(grid_range[0]), np.log10(grid_range[1]), nu_grid_size)
    y = sherpa_model(x)

    if plot_obj is None:
        plot_obj = PlotSED(frame='obs', density=False)

    if sed_data is not None:
        plot_obj.add_data_plot(sed_data=sed_data)

    plot_obj.add_xy_plot(x, y, label=label, line_style=line_style)

    # Residuals need both a fit range and observed data to compare against.
    if add_res is True and fit_range is not None and sed_data is not None:
        x_res = sed_data.data['nu_data']
        nufnu_res = sherpa_model(x_res)
        y_res = (sed_data.data['nuFnu_data'] - nufnu_res) / sed_data.data['dnuFnu_data']
        plot_obj.add_xy_residual_plot(x=x_res, y=y_res, fit_range=[fit_range[0], fit_range[1]])

    return plot_obj
class SherpaMinimizer(Minimizer):
    """
    jetset ``Minimizer`` backend that delegates the fit to sherpa.

    Parameters
    ----------
    model :
        Object passed to ``Minimizer.__init__``. This class reads
        ``model.fit_model``, ``model.fit_par_free`` and
        ``model.data['x']``, ``['y']`` and ``['dy']`` from it.
    method : optional
        sherpa optimisation method instance. Defaults to ``LevMar()``.
    stat : optional
        sherpa fit statistic instance. Defaults to ``Chi2()``.

    Raises
    ------
    ImportError
        If sherpa is not installed.
    """

    def __init__(self, model, method=None, stat=None):
        if sherpa_installed is not True:
            raise ImportError('sherpa not installed, \n to use sherpa plugin you need to install sherpa: https://sherpa.readthedocs.io/en/latest/install.html')
        if method is None:
            method = LevMar()
        if stat is None:
            stat = Chi2()
        # Set before the base constructor runs, so the `calls` property
        # (which reads _sherpa_model) is usable even if the base class uses it.
        self._sherpa_model = None
        self._sherpa_data = None
        super(SherpaMinimizer, self).__init__(model)
        self._method = method
        self._stat = stat
        self._sherpa_model = None
        self._sherpa_data = None
        # Created by _fit(). Initialised here so `sherpa_fitter` returns None
        # instead of raising AttributeError before the first fit.
        self._sherpa_fitter = None
        self.pbar = None

    def _create_sherpa_model(self):
        """Wrap the jetset fit model, exposing only its free parameters."""
        self._sherpa_model = JetsetSherpaModel(jetset_model=self.model.fit_model, par_list=self.model.fit_par_free)

    def _create_sherpa_data(self):
        """Build a sherpa ``Data1D`` named "sed" from ``model.data``."""
        self._sherpa_data = sherpa_data.Data1D("sed", self.model.data['x'], self.model.data['y'], staterror=self.model.data['dy'])

    @property
    def sherpa_fitter(self):
        """The sherpa ``Fit`` object from the last fit, or None before any fit."""
        return self._sherpa_fitter

    @property
    def calls(self):
        """Number of model evaluations so far, or None if no sherpa model exists yet."""
        if self._sherpa_model is not None:
            return self._sherpa_model._jetset_ncalls
        return None

    @calls.setter
    def calls(self, n):
        if self._sherpa_model is not None:
            self._sherpa_model._jetset_ncalls = n

    def _fit(self, max_ev,):
        """
        Run the sherpa fit and store its results on ``self``.

        Sets ``mesg`` (the value returned by ``Fit.fit()``), ``covar``, and
        ``pout``/``p`` (lists of best-fit parameter values).

        NOTE(review): ``max_ev`` is accepted for interface compatibility but
        is not forwarded to sherpa. The evaluation limit is whatever
        ``self._method`` is configured with.
        """
        self._create_sherpa_model()
        self._create_sherpa_data()
        self._sherpa_model._jetset_ncalls = 0
        self._sherpa_fitter = Fit(self._sherpa_data, self._sherpa_model, method=self._method, stat=self._stat)
        self.mesg = self._sherpa_fitter.fit()
        self.covar = self.mesg.covar
        self.pout = list(self.mesg.parvals)
        self.p = list(self.mesg.parvals)

    def _set_fit_errors(self):
        """Set ``errors`` to sqrt(|diagonal of covar|) for each free parameter."""
        n_free = len(self.model.fit_par_free)
        self.errors = [np.sqrt(np.fabs(self.covar[i, i])) for i in range(n_free)]
def sherpa_model_to_table(sherpa_model):
    """
    Summarise the parameters of a sherpa model as an astropy ``Table``.

    The table has one row per entry of ``sherpa_model.pars``. For a linked
    parameter, the ``linked`` column is True and the name and model name of
    the link target are reported. Otherwise those two columns are empty
    strings.
    """
    columns = ['model name', 'name', 'val', 'min', 'max', 'frozen', 'units',
               'linked', 'linked par', 'linked model']

    def _as_row(par):
        # Common leading columns, followed by the link description.
        head = [par.modelname, par.name, par.val, par.min, par.max, par.frozen, par.units]
        target = par.link
        if target is None:
            return head + [False, '', '']
        return head + [True, target.name, target.modelname]

    table_rows = [_as_row(par) for par in sherpa_model.pars]
    return Table(names=columns, rows=table_rows)