__author__ = "Andrea Tramacere"
import os
from astropy.table import Table

# sherpa is an optional dependency: guard the import so that this module can
# still be imported on Read the Docs (READTHEDOCS == 'True') to build the
# documentation without sherpa installed.
try:
    from sherpa.models.model import ArithmeticModel, modelCacher1d, RegriddableModel1D
    from sherpa.models.parameter import Parameter
    from sherpa import data
    from sherpa.fit import Fit
    from sherpa.stats import Chi2
    from sherpa.optmethods import LevMar
    from sherpa import data as sherpa_data
    sherpa_installed = True
except Exception as err:
    # Always define the flag: SherpaMinimizer checks it and must be able to
    # raise a clear ImportError instead of failing with a NameError.
    sherpa_installed = False
    on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
    if on_rtd is True:
        # Placeholder base class so JetsetSherpaModel can still be defined
        # (and documented) without sherpa.
        RegriddableModel1D = object
    else:
        raise ImportError('to use sherpa plugin you need to install sherpa: https://sherpa.readthedocs.io/en/latest/install.html') from err
import numpy as np
from .plot_sedfit import PlotSED
from .minimizer import Minimizer
__all__=['JetsetSherpaModel','plot_sherpa_model','SherpaMinimizer']
[docs]
class JetsetSherpaModel(RegriddableModel1D):
    """
    Automatic sherpa model generator.

    Wraps a jetset model in a sherpa ``RegriddableModel1D``. Each selected
    jetset parameter is mirrored by a sherpa ``Parameter``, which is stored as
    an attribute of this model and in ``_sherpa_ref`` on the jetset parameter.
    On evaluation, the sherpa values are copied to the jetset parameters and
    the jetset model is evaluated.
    """

    def __init__(self, jetset_model, par_list=None, clone=False):
        """Create a new `JetsetSherpaModel` instance.

        Parameters
        ----------
        jetset_model : object
            jetset model to wrap. Its ``parameters.par_array`` provides the
            parameters and its ``name`` is used as the sherpa model name.
        par_list : object, optional
            Container of jetset parameters to expose to sherpa. If None, every
            non-None parameter of the model is exposed.
        clone : bool, optional
            If True, wrap ``jetset_model.clone()`` instead of ``jetset_model``.
        """
        if clone is True:
            self._jetset_model = jetset_model.clone()
        else:
            self._jetset_model = jetset_model

        self._jp_list = []          # sherpa Parameter objects, in model order
        self._jp_par_array = []     # matching jetset parameters, same order
        self._jp_list_names = []    # original (not renamed) jetset names
        self._jetset_ncalls = 0     # number of calc() evaluations

        # Lower-cased sherpa names already in use. sherpa looks parameters up
        # case-insensitively, so names differing only by case (in any mix of
        # upper/lower letters) clash; renamed names must be tracked as well.
        taken = set()

        for p in self._jetset_model.parameters.par_array:
            if p is None:
                continue
            if par_list is not None and p not in par_list:
                continue

            name = p.name
            while name.lower() in taken:
                name = name + '_sh'
            if name != p.name:
                print('jetset model name', p.name, 'renamed to ', name, 'due to sherpa internal naming convention')

            # jetset fit-range limits take precedence over the parameter's
            # validity range when they are defined.
            val_min = p.fit_range_min if p.fit_range_min is not None else p.val_min
            val_max = p.fit_range_max if p.fit_range_max is not None else p.val_max

            sh_p = Parameter(self._jetset_model.name, name, p.val, min=val_min, max=val_max,
                             units=p.units, frozen=p.frozen)
            setattr(self, sh_p.name, sh_p)
            p._sherpa_ref = sh_p
            # Replace NaN limits with the sherpa hard limits.
            if np.isnan(sh_p.max):
                sh_p.max = sh_p.hard_max
            if np.isnan(sh_p.min):
                sh_p.min = sh_p.hard_min

            self._jp_list.append(sh_p)
            self._jp_par_array.append(p)
            self._jp_list_names.append(p.name)
            taken.add(name.lower())

        RegriddableModel1D.__init__(self, jetset_model.name, tuple(self._jp_list))

    def calc(self, pars, x):
        """Evaluate the wrapped jetset model.

        Parameters
        ----------
        pars : sequence of float
            Parameter values, in the same order as the sherpa parameters
            registered in ``__init__``.
        x : array-like
            Frequencies passed to the jetset model as ``nu``.

        Returns
        -------
        object
            Result of ``jetset_model.eval(get_model=True, nu=x)``.
        """
        # Use the values handed in by the caller: sherpa may call calc with
        # explicit parameter values (direct calls, regridding, composite
        # models) rather than relying on the current Parameter.val.
        for j_p, value in zip(self._jp_par_array, pars):
            j_p.val = value
        self._jetset_ncalls += 1
        return self._jetset_model.eval(get_model=True, nu=x)

    def plot_model(self, fit_range, model_range=(1E10, 1E30), nu_grid_size=200, plot_obj=None, sed_data=None):
        """Plot the wrapped jetset model with its residuals.

        Note: this resets the jetset model frequency grid to ``model_range``
        with ``nu_grid_size`` points and re-evaluates the model.

        Parameters
        ----------
        fit_range : sequence
            ``[min, max]`` range used for the residual plot.
        model_range : sequence, optional
            ``[min, max]`` range of the model frequency grid.
        nu_grid_size : int, optional
            Number of points of the model frequency grid.
        plot_obj : object, optional
            Existing plot object to update.
        sed_data : object, optional
            Observational SED data container.
            NOTE(review): also passed to the residual plot; behavior when
            None depends on ``add_model_residual_plot`` — confirm.

        Returns
        -------
        object
            The plot object that was created or updated.
        """
        self._jetset_model.set_nu_grid(model_range[0], model_range[1], nu_grid_size)
        self._jetset_model.eval()
        plot_obj = self._jetset_model.plot_model(plot_obj=plot_obj, sed_data=sed_data)
        plot_obj.add_model_residual_plot(data=sed_data, model=self._jetset_model,
                                         fit_range=[fit_range[0], fit_range[1]])
        return plot_obj
[docs]
def plot_sherpa_model(sherpa_model, fit_range=None, model_range=(1E10, 1E30), nu_grid_size=200, sed_data=None,
                      add_res=False, plot_obj=None, label=None, line_style=None):
    """Plot a sherpa model, optionally with SED data and residuals.

    Parameters
    ----------
    sherpa_model : callable
        sherpa model, evaluated as ``sherpa_model(x)``.
    fit_range : sequence, optional
        ``[min, max]`` frequency range. When given, the model is sampled over
        this range and residuals can be plotted.
    model_range : sequence, optional
        ``[min, max]`` frequency range used when ``fit_range`` is None.
    nu_grid_size : int, optional
        Number of log-spaced frequencies at which the model is evaluated.
    sed_data : object, optional
        Observational SED data container; its ``data`` must provide
        ``'nu_data'``, ``'nuFnu_data'`` and ``'dnuFnu_data'`` for residuals.
    add_res : bool, optional
        If True, plot ``(data - model) / error`` residuals. This requires both
        ``fit_range`` and ``sed_data``; otherwise it is skipped.
    plot_obj : object, optional
        Existing plot object to update; a new ``PlotSED`` is created if None.
    label : object, optional
        Label of the model curve.
    line_style : object, optional
        Line style of the model curve.

    Returns
    -------
    object
        The plot object that was created or updated.
    """
    sample_range = fit_range if fit_range is not None else model_range
    x = np.logspace(np.log10(sample_range[0]), np.log10(sample_range[1]), nu_grid_size)
    y = sherpa_model(x)

    if plot_obj is None:
        plot_obj = PlotSED(frame='obs', density=False)

    if sed_data is not None:
        plot_obj.add_data_plot(sed_data=sed_data)

    plot_obj.add_xy_plot(x, y, label=label, line_style=line_style)

    # Residuals need observed data to compare against; without sed_data they
    # are skipped instead of failing on ``None.data``.
    if add_res is True and fit_range is not None and sed_data is not None:
        x_res = sed_data.data['nu_data']
        nufnu_res = sherpa_model(x_res)
        y_res = (sed_data.data['nuFnu_data'] - nufnu_res) / sed_data.data['dnuFnu_data']
        plot_obj.add_xy_residual_plot(x=x_res, y=y_res, fit_range=[fit_range[0], fit_range[1]])

    return plot_obj
[docs]
class SherpaMinimizer(Minimizer):
    """Sherpa-based minimizer backend for JetSeT fits.

    Notes
    -----
    Wraps a JetSeT model into a Sherpa model/data pair, runs the selected
    Sherpa optimizer/statistic, and maps fit outputs back to JetSeT structures.
    """

    def __init__(self, model, method=None, stat=None):
        """Create a new `SherpaMinimizer` instance.

        Parameters
        ----------
        model : object
            Model instance, forwarded to the `Minimizer` base class.
        method : object, optional
            sherpa optimization method; ``LevMar()`` if None.
        stat : object, optional
            sherpa fit statistic; ``Chi2()`` if None.

        Raises
        ------
        ImportError
            If sherpa is not installed.
        """
        if not sherpa_installed:
            raise ImportError('sherpa not installed, \n to use sherpa plugin you need to install sherpa: https://sherpa.readthedocs.io/en/latest/install.html')
        if method is None:
            method = LevMar()
        if stat is None:
            stat = Chi2()
        super(SherpaMinimizer, self).__init__(model)
        self._method = method
        self._stat = stat
        self._sherpa_model = None
        self._sherpa_data = None
        # Created by _fit(); initialised here so `sherpa_fitter` is safe to
        # read before the first fit.
        self._sherpa_fitter = None
        self.pbar = None

    def _create_sherpa_model(self):
        """Wrap ``self.model.fit_model``, exposing only ``self.model.fit_par_free``."""
        self._sherpa_model = JetsetSherpaModel(jetset_model=self.model.fit_model, par_list=self.model.fit_par_free)

    def _create_sherpa_data(self):
        """Build a sherpa ``Data1D`` from ``self.model.data`` ('dy' is the staterror)."""
        self._sherpa_data = sherpa_data.Data1D("sed", self.model.data['x'], self.model.data['y'], staterror=self.model.data['dy'])

    @property
    def sherpa_fitter(self):
        """sherpa ``Fit`` object used by the last fit.

        Returns
        -------
        object or None
            The ``Fit`` instance, or None if no fit has been run yet.
        """
        return self._sherpa_fitter

    @property
    def calls(self):
        """Number of model evaluations performed by the current sherpa model.

        Returns
        -------
        int or None
            The call counter, or None if the sherpa model is not built yet.
        """
        if self._sherpa_model is not None:
            return self._sherpa_model._jetset_ncalls
        return None

    @calls.setter
    def calls(self, n):
        """Set the model-evaluation counter (ignored before the model exists).

        Parameters
        ----------
        n : int
            New counter value.
        """
        if self._sherpa_model is not None:
            self._sherpa_model._jetset_ncalls = n

    def _fit(self, max_ev):
        """Build the sherpa model/data pair and run the sherpa fit.

        Parameters
        ----------
        max_ev : object
            NOTE(review): accepted for interface compatibility with
            `Minimizer` but not forwarded to the sherpa optimizer.
        """
        self._create_sherpa_model()
        self._create_sherpa_data()
        self._sherpa_model._jetset_ncalls = 0
        self._sherpa_fitter = Fit(self._sherpa_data, self._sherpa_model, method=self._method, stat=self._stat)
        self.mesg = self._sherpa_fitter.fit()
        self.covar = self.mesg.covar
        self.pout = list(self.mesg.parvals)
        self.p = list(self.mesg.parvals)

    def _set_fit_errors(self):
        """Set ``self.errors`` from the diagonal of the covariance matrix.

        If no covariance matrix is available, the errors are set to NaN
        instead of failing.
        """
        n_free = len(self.model.fit_par_free)
        if self.covar is None:
            print('!Warning, no covariance matrix returned by sherpa: fit errors set to nan')
            self.errors = [np.nan] * n_free
        else:
            self.errors = [np.sqrt(np.fabs(self.covar[pi, pi])) for pi in range(n_free)]
def sherpa_model_to_table(sherpa_model):
    """Summarise the parameters of a sherpa model as an astropy Table.

    Parameters
    ----------
    sherpa_model : object
        sherpa model whose ``pars`` are listed, one row per parameter.

    Returns
    -------
    astropy.table.Table
        Columns: model name, name, val, min, max, frozen, units, linked,
        linked par, linked model. For parameters without a link, ``linked``
        is False and the two link columns are empty strings.
    """
    columns = ['model name', 'name', 'val', 'min', 'max', 'frozen', 'units',
               'linked', 'linked par', 'linked model']

    def _row(par):
        # Columns common to linked and unlinked parameters.
        row = [par.modelname, par.name, par.val, par.min, par.max, par.frozen, par.units]
        link = par.link
        if link is None:
            row.extend([False, '', ''])
        else:
            row.extend([True, link.name, link.modelname])
        return row

    return Table(names=columns, rows=[_row(par) for par in sherpa_model.pars])