"""MCMC sampling utilities built on emcee for JetSeT fit models."""
__author__ = "Andrea Tramacere"
from .minimizer import _eval_res
import emcee
from itertools import cycle
import types
import time
import numpy as np
import corner
import dill as pickle
import warnings
import math
from .plot_sedfit import plt, set_mpl
from .mcmc_parameters import(
McmcCompositeModelParameterArray,
get_mcmc_bound_max,
get_mcmc_bound_min,
set_mcmc_bound_max,
set_mcmc_bound_min,
_check_par_mcmc_bounds
)
__all__=['McmcSampler']
class Counter:
    """Mutable call counter handed to the log-probability callback.

    Notes
    -----
    Keeps the number of calls made so far (``count``), the number of
    accepted calls (``count_OK``), the expected total (``count_tot``) and
    an endless spinner iterator (``_progress_iter``) for text-based
    progress feedback.
    """

    def __init__(self, count_tot):
        """Initialise all counters to zero.

        Parameters
        ----------
        count_tot : int
            Total number of expected iterations/samples.
        """
        self.count, self.count_OK = 0, 0
        self.count_tot = count_tot
        # Iterating a string yields its characters one at a time.
        self._progress_iter = cycle('|/-\\')
def sample_ball(p0, std, size=1):
    """Draw Gaussian-scattered points around ``p0``.

    Parameters
    ----------
    p0 : array-like
        Central parameter vector.
    std : array-like
        Per-parameter standard deviation; must have the same length as
        ``p0``.
    size : int, optional
        Number of points to draw.

    Returns
    -------
    numpy.ndarray
        Array with one drawn point per row.
    """
    n_dim = len(p0)
    assert n_dim == len(std)
    draws = []
    for _ in range(size):
        # One independent normal deviate per parameter, scaled by its std.
        draws.append(p0 + std * np.random.normal(size=n_dim))
    return np.vstack(draws)
[docs]
class McmcSampler(object):
    """Run and manage ``emcee`` sampling for a fitted JetSeT model.

    Notes
    -----
    Wraps an input :class:`~jetset.minimizer.ModelMinimizer`, clones its model
    state, stores sampler outputs, and provides serialization-safe state
    handling for chain analysis and plotting.
    """
def __init__(self,model_minimizer,build_mcmc_parameters=True):
    """Create a new `McmcSampler` instance.

    Parameters
    ----------
    model_minimizer : object
        Initialized model-minimizer object; its ``fit_model`` is cloned and
        its ``data`` is stored by reference.
    build_mcmc_parameters : bool, optional
        If ``True`` (default), convert the cloned model parameters into
        MCMC-aware parameters via ``_bulild_mcmc_paramters``.

    Raises
    ------
    RuntimeError
        If the installed ``emcee`` major version is older than 3.
    """
    # Compare the major version as a number: a plain string comparison
    # would classify e.g. "10.0" as older than "3".
    version = str(emcee.__version__)
    try:
        too_old = int(version.split('.')[0]) < 3
    except ValueError:
        # Unparsable version string: keep the historical string comparison.
        too_old = version < "3"
    if too_old:
        raise RuntimeError('Please update to emcee v>=3.0.0')
    self.model = model_minimizer.fit_model.clone()
    self.data = model_minimizer.data
    self._bounds_sampler=[]
    # Re-class the parameter container in place so it exposes MCMC helpers.
    self.model.parameters.__class__=McmcCompositeModelParameterArray
    self._progress_iter = cycle(['|', '/', '-', '\\'])
    if build_mcmc_parameters:
        self._bulild_mcmc_paramters()
@property
def _par_array_sampler(self):
return [p for p in self.model.parameters.par_array if p.frozen is False]
@staticmethod
def _new_progress_iter():
return cycle(['|', '/', '-', '\\'])
def __getstate__(self):
# Keep serialized state free of runtime-only objects not stable across Python versions.
state = self.__dict__.copy()
state['_progress_iter'] = None
state['sampler'] = None
state['sate_of_mcmc_parameters']={}
for model in self.model.components.components_list:
state['sate_of_mcmc_parameters'][model.name]={}
for p in self.model.parameters.par_array:
state['sate_of_mcmc_parameters'][model.name][p.name]={}
state['sate_of_mcmc_parameters'][model.name][p.name]['best_fit_mcmc_val']=p.best_fit_mcmc_val
state['sate_of_mcmc_parameters'][model.name][p.name]['q_16']=p.q_16
state['sate_of_mcmc_parameters'][model.name][p.name]['q_50']=p.q_50
state['sate_of_mcmc_parameters'][model.name][p.name]['q_84']=p.q_84
state['sate_of_mcmc_parameters'][model.name][p.name]['plot_label']=p.plot_label
state['sate_of_mcmc_parameters'][model.name][p.name]['mcmc_bound_min']=p.mcmc_bound_min
state['sate_of_mcmc_parameters'][model.name][p.name]['mcmc_bound_max']=p.mcmc_bound_max
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._progress_iter = self._new_progress_iter()
if hasattr(self, 'chain') and self.chain is not None:
self.chain = self._as_walker_first_chain(self.chain)
if hasattr(self, 'log_prob_chain') and self.log_prob_chain is not None:
self.log_prob_chain = self._as_walker_first_log_prob(self.log_prob_chain)
def _as_walker_first_chain(self, chain):
_chain = np.asarray(chain)
if _chain.ndim != 3:
return _chain
if hasattr(self, 'nwalkers'):
if _chain.shape[0] == self.nwalkers:
return _chain
if _chain.shape[1] == self.nwalkers:
return np.swapaxes(_chain, 0, 1)
return _chain
def _as_walker_first_log_prob(self, log_prob_chain):
_logp = np.asarray(log_prob_chain)
if _logp.ndim != 2:
return _logp
if hasattr(self, 'nwalkers'):
if _logp.shape[0] == self.nwalkers:
return _logp
if _logp.shape[1] == self.nwalkers:
return np.swapaxes(_logp, 0, 1)
return _logp
def _cache_sampler_chains(self):
# Persist full (pre-burnin) traces in walker-first layout for plotting after reload.
_chain = None
try:
_chain = self.sampler.get_chain()
except Exception:
_chain = self.sampler.get_chain(flat=False)
self.chain = self._as_walker_first_chain(_chain)
try:
_logp = self.sampler.get_log_prob()
except Exception:
_logp = self.sampler.get_log_prob(flat=False)
self.log_prob_chain = self._as_walker_first_log_prob(_logp)
def _bulild_mcmc_paramters(self):
    """dynamically create McmcParameter with proper inheritance

    For every parameter of ``self.model`` the parameter instance is
    re-classed in place to a dynamically created subclass named
    ``McmcParameter_<OriginalClass>`` and then given the MCMC bookkeeping
    attributes ``best_fit_mcmc_val``, ``q_16``, ``q_50``, ``q_84``,
    ``mcmc_bound_min``, ``mcmc_bound_max`` and ``plot_label``.
    """
    for p in self.model.parameters.par_array:
        Base = p.__class__
        # Skip re-classing for parameters that were already converted.
        if not Base.__name__.startswith("McmcParameter_"):
            McmcParameter = types.new_class(f"McmcParameter_{Base.__name__}", (Base,))
            # NOTE(review): both assignments target the class attribute ``x``,
            # so the second replaces the first, and (judging by the helper
            # names) the setter is passed in ``property``'s getter position.
            # Confirm whether ``mcmc_bound_max`` / ``mcmc_bound_min``
            # properties with (getter, setter) order were intended.
            McmcParameter.x = property(set_mcmc_bound_max,get_mcmc_bound_max)
            McmcParameter.x = property(set_mcmc_bound_min,get_mcmc_bound_min)
            p.__class__ = McmcParameter
        # Bind the bounds-check helper as a method of this specific instance.
        p._check_par_mcmc_bounds = types.MethodType(_check_par_mcmc_bounds, p)
        # The MCMC best-fit value starts at the current parameter value.
        p.best_fit_mcmc_val=p.val
        p.q_16=None
        p.q_50=None
        p.q_84=None
        if not p.frozen:
            # Free parameters inherit the fit range when one is defined;
            # otherwise the bound stays unset.
            if p.fit_range_min is not None:
                p.mcmc_bound_min=p.fit_range_min
            else:
                p.mcmc_bound_min=None
            if p.fit_range_max is not None:
                p.mcmc_bound_max= p.fit_range_max
            else:
                p.mcmc_bound_max=None
        else:
            # Frozen parameters take their bounds from val_min / val_max.
            p.mcmc_bound_min=p.val_min
            p.mcmc_bound_max=p.val_max
        p.plot_label=p.name
@property
def best_fit_par_table(self):
return self.model.parameters.best_fit_par_table
@property
def sampler_parameters(self):
"""sampler table.
Returns
-------
object
Requested value.
"""
return self.model.parameters._build_sampler_par_table()
[docs]
def set_bounds(self,par_name=None,comp_name=None,bound=0.2,bound_rel=False,preserve_fit_range=True,par_bounds=None,zero_abs_tol=1E-200):
"""Set bounds.
Parameters
----------
bound : float, optional
Absolute parameter-bound span.
bound_rel : bool, optional
Relative parameter-bound span.
preserve_fit_range : bool, optional
Range for preserve fit.
"""
#all pars
if comp_name is None and par_name is None:
if par_bounds is not None:
raise RuntimeError('if you pass par_bounds, please provide both model comp_name and par_name')
for par in self._par_array_sampler:
self._set_bounds(par=par,bound=bound,bound_rel=bound_rel,preserve_fit_range=preserve_fit_range,par_bounds=par_bounds,zero_abs_tol=zero_abs_tol)
return
elif comp_name is not None and par_name is not None:
if np.shape(par_bounds)!=(2,):
raise RuntimeError('please provide par_bounds as [min_bound, max_bound], with min_bound<max_bound')
if par_bounds[0]>=par_bounds[1]:
raise RuntimeError('please provide par_bounds as [min_bound, max_bound], with min_bound<max_bound')
par=self.get_par(par_name,comp_name=comp_name)
self._set_bounds(par=par,bound=bound,bound_rel=bound_rel,preserve_fit_range=preserve_fit_range,par_bounds=par_bounds,zero_abs_tol=zero_abs_tol)
return
else:
raise RuntimeError('please, provide both par_name and comp_name')
def _set_bounds(self, par, bound=0.2,bound_rel=True,preserve_fit_range=True,par_bounds=None,zero_abs_tol=1E-200):
    """Compute and store ``mcmc_bound_min`` / ``mcmc_bound_max`` for one parameter.

    Parameters
    ----------
    par : object
        Parameter to update.
    bound : float or sequence of two floats, optional
        Fractional half-width(s) ``[lower, upper]`` around the reference
        value; a scalar is used for both sides. Used only when
        ``par_bounds`` is ``None`` or ``[]``.
    bound_rel : bool, optional
        Selects which of the two delta formulas below is used.
        NOTE(review): the default here is ``True`` while the public
        ``set_bounds`` defaults to ``False`` -- confirm this is intended.
    preserve_fit_range : bool, optional
        If ``True`` and a fit-range limit is set, clamp to it; otherwise
        clamp to ``val_min`` / ``val_max`` when those are set.
    par_bounds : sequence of two floats, optional
        Explicit ``[min, max]``; for logarithmic parameters they are
        interpreted as log10 values.
    zero_abs_tol : float, optional
        Absolute tolerance used to decide that the reference value is zero.

    Raises
    ------
    RuntimeError
        If ``bound`` has an unsupported shape, if the reference value is
        zero in the ``bound_rel`` branch, or if the reference value falls
        outside the final bounds.
    """
    # Reference value: best-fit value when available, current value otherwise.
    if par.best_fit_val is None:
        ref_val=par.val
    else:
        ref_val=par.best_fit_val
    # Work in linear space for logarithmic parameters.
    if par.islog:
        ref_val=10**ref_val
    if par_bounds == [] or par_bounds is None:
        # Normalize ``bound`` to a [lower, upper] pair.
        if np.shape(bound) == ():
            bound=[bound,bound]
        elif np.shape(bound) == (2,):
            pass
        else:
            raise RuntimeError('bound shape', np.shape(bound), 'it is wrong, has to be a scalar or (2,)')
        if not bound_rel:
            # NOTE(review): the deltas keep the sign of ref_val here, so a
            # negative reference value yields _min > _max and ends in the
            # RuntimeError further down -- confirm whether that is intended.
            delta_p = ref_val * bound[1]
            delta_m = ref_val * bound[0]
        else:
            if math.isclose(np.fabs(ref_val),0,abs_tol=zero_abs_tol):
                raise RuntimeError(f"You can't set relative bounds for reference value equal to 0, par={par.name}, of model comp: {par.model.name}, val={ref_val}")
            else:
                delta_p = np.fabs(ref_val)*bound[1]
                delta_m = np.fabs(ref_val)*bound[0]
        _min = ref_val - delta_m
        _max = ref_val + delta_p
    else:
        # Explicit bounds; for logarithmic parameters convert them to linear space.
        if par.islog:
            _min,_max=[10**par_bounds[0],10**par_bounds[1]]
        else:
            _min,_max=par_bounds
    # Clamp the lower limit: fit range first (if requested), else val_min.
    if par.fit_range_min is not None and preserve_fit_range is True:
        if par.islog:
            fit_range_min=10**par.fit_range_min
        else:
            fit_range_min=par.fit_range_min
        _min= max(_min, fit_range_min )
    elif par.val_min is not None:
        if par.islog:
            val_min=10**par.val_min
        else:
            val_min=par.val_min
        _min= max(_min,val_min)
    # Clamp the upper limit: fit range first (if requested), else val_max.
    if par.fit_range_max is not None and preserve_fit_range is True:
        if par.islog:
            fit_range_max=10**par.fit_range_max
        else:
            fit_range_max=par.fit_range_max
        _max= min(_max,fit_range_max )
    elif par.val_max is not None:
        if par.islog:
            val_max=10**par.val_max
        else:
            val_max=par.val_max
        _max= min(_max,val_max)
    # The reference value must lie inside the final interval.
    if ref_val>_max or ref_val<_min:
        if par.islog:
            # Report the values in log10 space, as the user sees them.
            _max=np.log10(_max)
            _min=np.log10(_min)
            ref_val=np.log10(ref_val)
        raise RuntimeError(f'please set bounds for par: {par.name} of model comp: {par.model.name} such that {_min}<={ref_val}<={_max}')
    # Convert back to log10 space before storing.
    # NOTE(review): if no lower clamp applied and bound[0] >= 1, _min can be
    # <= 0 here and log10 would produce -inf/nan -- confirm callers avoid this.
    if par.islog:
        _max=np.log10(_max)
        _min=np.log10(_min)
        ref_val=np.log10(ref_val)
    print('par:',par.name,' ref value: ',ref_val,' mcmc bounds:',[_min, _max])
    par.mcmc_bound_max=_max
    par.mcmc_bound_min=_min
def _build_sampler_bounds(self):
self._bounds_sampler=[]
for par in self._par_array_sampler:
self._bounds_sampler.append([par.mcmc_bound_min,par.mcmc_bound_max])
[docs]
def get_par(self, par_name_or_idx, comp_name=None, get_index=False):
"""Return par.
Parameters
----------
par_name_or_idx : object
Parameter identifier by name or index.
comp_name : object, optional
Model-component name.
get_index : bool, optional
If ``True``, also return the index of the selected item.
Returns
-------
object
Requested value.
"""
if type(par_name_or_idx) == int:
par_idx=par_name_or_idx
else:
par_name=par_name_or_idx
try:
if comp_name is None:
par_idx = [par.name for par in self._par_array_sampler].index(par_name)
else:
par_idx = [par.name+par.model.name for par in self._par_array_sampler].index(par_name+comp_name)
except:
raise RuntimeError('parameter p', par_name, 'not found')
if par_idx > len(self._par_array_sampler):
raise RuntimeError('label id larger then labels size')
if get_index is True:
return self._par_array_sampler[par_idx], par_idx
else:
return self._par_array_sampler[par_idx]
[docs]
def set_plot_label(self,par_name,plot_label,comp_name):
"""Set the display label used for a sampled parameter in plots.
Parameters
----------
par_name : str
Name of the parameter whose label must be updated.
plot_label : str
New text label used in corner and chain plots.
comp_name : str
Name of the model component that owns ``par_name``.
Raises
------
RuntimeError
If the requested parameter/component pair is not found.
"""
p=self.get_par(par_name,comp_name=comp_name)
p.plot_label=plot_label
[docs]
def reset_to_minimizer_best_fit(self):
"""Reset sampled parameter values to minimizer best-fit values.
Notes
-----
This updates only the internal parameter dictionary used by the
sampler helper; it does not run a new minimization.
"""
for par in self._par_array_sampler:
if par.best_fit_val is not None:
par.val = par.best_fit_val
[docs]
def reset_to_mcmc_best_fit(self,verbose=True):
"""Reset to mcmc best fit.
Parameters
----------
verbose : bool, optional
If ``True``, print additional information.
"""
_prob_max = np.argmax(self.samples_log_prob)
_id_prob_max = np.unravel_index(_prob_max, self.samples_log_prob.shape)
if verbose:
print("----------------------------")
print("MCMC best fit solution")
for ID,par in enumerate(self._par_array_sampler):
par.val= self.get_sample(ID)[_id_prob_max]
par.best_fit_mcmc_val= par.val
quantiles=self.get_par_quantiles(par_name=par.name,comp_name=par.model.name,quantiles=(0.16,0.5,0.84))
par.q_16=quantiles[0]
par.q_50=quantiles[1]
par.q_84=quantiles[2]
if verbose:
print(f"comp: {par.model.name} par: {par.name} mcmc best fit val: {par.val} quantiles(0.16,0.5,0.84): {quantiles} ")
if verbose:
print("----------------------------")
[docs]
def run_sampler(self,
nwalkers=None,
steps=100,
pos=None,
burnin=50,
use_UL=False,
threads=None,
walker_start_bound=0.005,
loglog = False,
progress='notebook'):
"""Run sampler.
Parameters
----------
nwalkers : int, optional
Number of MCMC walkers.
steps : int, optional
Number of MCMC steps.
pos : object, optional
Initial walker positions.
burnin : int, optional
Number of burn-in steps to discard.
use_UL : bool, optional
If ``True``, enable ul.
threads : object, optional
Number of worker threads/processes.
walker_start_bound : float, optional
Initial spread factor for walker starting points.
loglog : bool, optional
If ``True``, operate in log10 space.
progress : str, optional
If ``True``, display sampling progress.
"""
to_fix=False
err_str='\n'
for par in self._par_array_sampler:
if par.mcmc_bound_min is None or par.mcmc_bound_max is None:
err_str+=f"please set mcmc_bound_min and mcmc_bound_max for par: {par.name} in model component: {par.model.name}\n"
to_fix=True
if to_fix:
raise RuntimeError(f"can not run sampler if you do not set bounds for these parameters: {err_str}")
self._build_sampler_bounds()
self.calls = 0
self.calls_OK = 0
self.use_UL = use_UL
self.ndim = len(self._par_array_sampler)
self.pos = pos
self.burnin=burnin
if nwalkers is None:
self.nwalkers = 4*len(self._par_array_sampler)
print('setting nwalkers to:', self.nwalkers)
else:
self.nwalkers = nwalkers
if self.nwalkers < 2*len(self._par_array_sampler):
raise RuntimeError("numbers of walkers has to be at least two times the number of sampling pars")
counter=Counter( self.nwalkers*steps)
self.steps = steps
self.calls_tot = self.nwalkers * steps
if pos is None:
pos = sample_ball(np.array([p.val for p in self._par_array_sampler]),
np.array([p.val * walker_start_bound for p in self._par_array_sampler]),
self.nwalkers)
print('mcmc run starting')
print('')
start = time.time()
if threads is not None and threads>1:
warnings.warn('python multithreading is not effective, JetSeT uses C threads to speedup computation')
threads=1
self.sampler = emcee.EnsembleSampler(self.nwalkers, self.ndim, log_prob, args=(self.model, self.data, use_UL, counter, self._bounds_sampler, self._par_array_sampler, loglog))
self.sampler.run_mcmc(pos, steps, progress=progress)
else:
threads=1
self.sampler = emcee.EnsembleSampler(self.nwalkers, self.ndim, log_prob, args=(self.model, self.data, use_UL, counter, self._bounds_sampler, self._par_array_sampler, loglog))
self.sampler.run_mcmc(pos, steps, progress=progress)
end = time.time()
comp_time = end - start
print("mcmc run done, with %d threads took %2.2f seconds"%(threads,comp_time))
self._set_samples_post_run()
[docs]
def tune_burnin(self, tau_coeff=None):
"""Tune burn-in length from the sampler autocorrelation time.
Parameters
----------
tau_coeff : float, optional
Multiplicative factor applied to the maximum integrated
autocorrelation time. The burn-in is set to
``int(tau_coeff * max(tau))``. If ``None``, an internal value is
computed to keep burn-in below the total chain length.
Notes
-----
This method uses ``emcee`` integrated autocorrelation times computed
through ``sampler.get_autocorr_time(tol=0)`` and then updates
``self.burnin`` and all post-run cached products (samples, log-prob,
best-fit and quantiles) via :meth:`_set_samples_post_run`.
See emcee documentation for details:
- https://emcee.readthedocs.io/en/stable/tutorials/autocorr/
- https://emcee.readthedocs.io/en/stable/user/sampler/#emcee.EnsembleSampler.get_autocorr_time
"""
tau = np.asarray(self.sampler.get_autocorr_time(tol=0), dtype=float)
tau_max = float(np.max(tau))
if tau_coeff is None:
tau_coeff = min(1.0, (self.steps - 10) / max(tau_max, 1.0))
self.burnin = int(tau_coeff * np.max(tau)) # or a few times max(tau)
self._set_samples_post_run()
def _set_samples_post_run(self):
self.samples = self.sampler.get_chain(flat=True,discard=self.burnin)
self.samples_log_prob = self.sampler.get_log_prob(flat=True,discard=self.burnin)
self.posterior_weights=None
self._cache_sampler_chains()
self.acceptance_fraction=np.mean(self.sampler.acceptance_fraction)
self.reset_to_mcmc_best_fit()
[docs]
def get_par_quantiles(self,par_name,comp_name=None,quantiles=(0.16,0.5,0.84)):
"""Return par quantiles.
Parameters
----------
par_name : str
Parameter name.
comp_name : object, optional
Model-component name.
quantiles : tuple, optional
Quantiles to evaluate/report.
Returns
-------
object
Requested value.
"""
return np.array(np.quantile(self.get_sample(par_name,comp_name=comp_name),quantiles))
[docs]
def corner_plot(self, comp_name=None,quantiles = (0.16, 0.5, 0.84), levels = None, title_kwargs = {}, per_component=True, **kwargs):
"""Corner plot.
Parameters
----------
comp_name : str, optional
Model-component name.
quantiles : tuple, optional
Quantiles to evaluate/report.
levels : object, optional
Contour levels for corner plots.
title_kwargs : dict, optional
Keyword arguments forwarded to plot-title rendering.
**kwargs : dict
Additional keyword arguments.
Returns
-------
object
Computed value.
"""
if per_component:
if comp_name is None:
components=np.unique([p.model.name for p in self._par_array_sampler])
else:
components=[comp_name]
f_list=[]
for c in components:
_idxs = []
truths = []
msk=np.array([p.model.name==c for p in self._par_array_sampler])
names=[p.name for p in np.array(self._par_array_sampler)[msk]]
plot_labels=[]
if msk.sum()>0:
for name in names:
_idx=self.get_par(name,comp_name=c,get_index=True)[1]
_idxs.append(_idx)
p=self.get_par(_idx)
truths.append(p.best_fit_mcmc_val)
if hasattr(p,'plot_label'):
plot_labels.append(p.plot_label)
else:
plot_labels.append(p.name)
f=self._do_corner_plot(c,
self.samples,
_idxs,
quantiles,
plot_labels,
truths,
title_kwargs,
levels,
weights=self.posterior_weights,
**kwargs)
f_list.append(f)
return f_list
else:
f_list=[]
_idxs = []
truths = []
plot_labels=[]
components=np.unique([p.model.name for p in self._par_array_sampler])
for c in components:
msk=np.array([p.model.name==c for p in self._par_array_sampler])
if msk.sum()>0:
_names=([p.name for p in np.array(self._par_array_sampler)[msk]])
for name in _names:
_idx=self.get_par(name,comp_name=c,get_index=True)[1]
_idxs.append(_idx)
p=self.get_par(_idx)
truths.append(p.best_fit_mcmc_val)
if hasattr(p,'plot_label'):
plot_labels.append(f'{p.model.name}\n{p.plot_label}')
else:
plot_labels.append(f'{p.model.name}\n{p.name}')
f=self._do_corner_plot(self.model.name,
self.samples,
_idxs,
quantiles,
plot_labels,
truths,
title_kwargs,
levels,
weights=self.posterior_weights,
**kwargs)
f_list.append(f)
return f_list
def _do_corner_plot(self, c,samples,_idxs,quantiles,plot_labels,truths,title_kwargs,levels,**kwargs):
    """Draw one corner figure for the selected sample columns.

    The columns ``_idxs`` of ``samples`` are passed to ``corner.corner``
    together with the labels, truth values, quantiles and contour levels;
    the figure title is ``c`` followed by the quantiles.

    Returns
    -------
    object
        The figure returned by ``corner.corner``.
    """
    selected_columns = samples[:, _idxs]
    figure = corner.corner(selected_columns,
                           quantiles=quantiles,
                           labels=plot_labels,
                           truths=truths,
                           title_kwargs=title_kwargs,
                           show_titles=True,
                           levels=levels,
                           **kwargs)
    figure.suptitle("".join([c, ' quantiles =', str(quantiles)]), y=1.0)
    return figure
[docs]
def plot_chain(self,par_name=None, comp_name=None,log_plot=False):
    """Plot the walker traces of the sampled parameters, one figure per component.

    Parameters
    ----------
    par_name : str or sequence of str, optional
        Parameter name(s) to plot. If ``None``, every free parameter of
        each selected component is plotted.
    comp_name : str, optional
        Model-component name. If ``None``, every component owning free
        parameters is considered; when ``par_name`` is also given, only the
        components that own one of the requested parameters get a figure.
    log_plot : bool, optional
        If ``True``, use logarithmic plot scaling where applicable.

    Returns
    -------
    list
        One figure per plotted component.

    Raises
    ------
    RuntimeError
        If a requested parameter is not found.
    """
    sampler_pars = self._par_array_sampler
    if comp_name is None:
        components=np.unique([p.model.name for p in sampler_pars])
    else:
        components=[comp_name]
    requested = None if par_name is None else list(np.atleast_1d(par_name))

    f_list=[]
    for c in components:
        owned = [p.name for p in sampler_pars if p.model.name == c]
        if requested is None:
            par_names = owned
        elif comp_name is None:
            # Scanning all components: keep only the names this one owns.
            par_names = [name for name in requested if name in owned]
        else:
            # Explicit component: let get_par report unknown names.
            par_names = requested
        if not par_names:
            continue
        f, axes = plt.subplots(len(par_names),figsize=(10, 3*len(par_names)), sharex=True)
        axes=np.atleast_1d(axes)
        for ax, _p_name in zip(axes, par_names):
            # Resolve the parameter inside the component being plotted;
            # passing the caller's comp_name (possibly None) would pick the
            # first parameter with that name in any component.
            self._plot_chain(_p_name,ax,comp_name=c,log_plot=log_plot)
        axes[-1].set_xlabel('steps')
        f_list.append(f)
    if requested is not None and not f_list:
        raise RuntimeError('parameter p', par_name, 'not found')
    return f_list
def _plot_chain(self, par_name,ax,comp_name=None, log_plot=False):
par = self.get_par(par_name,comp_name=comp_name)
n = par.plot_label
traces=self.get_trace(par_name,comp_name=comp_name)
if par.units is not None:
n += ' (%s)' % par.units
_s=self.get_sample(par_name,comp_name=comp_name)
alpha_true = np.median(_s)
if log_plot == True:
n = 'log10(%s)'%n
if np.any(_s <= 0):
raise RuntimeWarning('negative values in p')
else:
traces = np.log10(traces)
alpha_true = np.log10(alpha_true)
for t in traces:
ax.plot(t, '-', color='k', alpha=0.5)
ax.axhline(alpha_true, color='blue')
if hasattr(self,'burnin'):
ax.axvline(self.burnin, ls='--',color='orange',alpha=0.5)
ax.set_ylabel(n)
[docs]
def get_trace(self, par_name,comp_name=None):
"""Return trace.
Parameters
----------
par_name : str
Parameter name.
comp_name : object, optional
Model-component name.
Returns
-------
object
Requested value.
"""
_p,p_idx=self.get_par(par_name,comp_name=comp_name,get_index=True)
if hasattr(self, 'chain') and self.chain is not None:
return self.chain[:, :, p_idx]
if hasattr(self, 'sampler') and self.sampler is not None:
try:
_chain = self._as_walker_first_chain(self.sampler.get_chain(flat=False))
return _chain[:, :, p_idx]
except Exception:
_chain = self._as_walker_first_chain(self.sampler.get_chain())
return _chain[:, :, p_idx]
raise RuntimeError('MCMC traces are not available in this sampler')
[docs]
def plot_par(self, par_name, comp_name=None,nbins=20, log_plot=False,quantiles=(0.16,0.5,0.84),figsize=None):
    """Plot the posterior histogram of one parameter with quantile markers.

    Parameters
    ----------
    par_name : str
        Parameter name.
    comp_name : str, optional
        Model-component name.
    nbins : int, optional
        Number of histogram bins.
    log_plot : bool, optional
        If ``True``, histogram ``log10`` of the samples.
    quantiles : tuple, optional
        Quantiles to mark with vertical lines. With exactly three
        quantiles the legend reports the middle one as central value with
        the upper/lower differences as asymmetric errors; otherwise the
        legend shows only the parameter name.
    figsize : object, optional
        Matplotlib figure size.

    Returns
    -------
    object
        The created figure.

    Raises
    ------
    RuntimeWarning
        If ``log_plot`` is requested and a sample is not strictly positive.
    """
    set_mpl()
    par = self.get_par(par_name,comp_name=comp_name)
    par_name = par.name
    x_name = par_name
    if par.units is not None:
        x_name += ' (%s)' % par.units
    _d=self.get_sample(par_name,comp_name=comp_name)
    if log_plot == True:
        x_name = 'log10(%s)' % x_name
        # Validate before creating the figure so a failure leaves no
        # orphan figure behind.
        if np.any(_d <= 0):
            raise RuntimeWarning('negative values in p')
        _d = np.log10(_d)
    q_vals=self.get_par_quantiles(par_name,comp_name=comp_name,quantiles=quantiles)
    if len(q_vals) == 3:
        q_diff = np.diff(q_vals)
        # Central value is the middle quantile; q_diff holds
        # [middle - lower, upper - middle].
        label = r'%s = $%.3e^{+%.3e}_{-%.3e} $'%(par_name,q_vals[1],q_diff[1],q_diff[0])
    else:
        label = par_name
    f = plt.figure(figsize=figsize)
    ax = f.add_subplot(111)
    ax.hist(_d,
            bins=nbins,
            density=True,
            alpha=0.5,
            label=label)
    for q in q_vals:
        if log_plot is True:
            q=np.log10(q)
        ax.axvline(q,c='black',ls='--',lw=0.5)
    ax.set_xlabel(x_name)
    f.suptitle('quantiles = %s'%str(quantiles))
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), ncol=1)
    return f
[docs]
def get_sample(self, par_name,comp_name=None):
"""Return sample.
Parameters
----------
par_name : str
Parameter name.
comp_name : object, optional
Model-component name.
Returns
-------
object
Requested value.
"""
_p,p_idx=self.get_par(par_name,comp_name=comp_name,get_index=True)
return self.samples[:,p_idx]
[docs]
def plot_model(self, sed_data=None, fit_range=None, size=100, frame='obs', density=False,quantiles=None, get_model=False, plot_mcmc_best_fit_model=True,rnd_seed=0,plot_components=False):
    """Plot posterior model envelope and a best-fit reference curve.

    Parameters
    ----------
    sed_data : ObsData, optional
        Observational SED data used for plotting and residuals. If ``None``,
        ``self.sed_data`` is used.
    fit_range : sequence of float, optional
        Two-element fit interval ``[nu_min, nu_max]`` used for model and
        residual overlays. If ``None``, ``[self.model.nu_min_fit,
        self.model.nu_max_fit]`` is used.
    size : int, optional
        Number of posterior samples used to build the shaded model envelope.
    frame : {'obs', 'src'}, optional
        Frame used to evaluate and display model SED values.
    density : bool, optional
        If ``True``, convert sampled ``nuFnu`` curves to ``Fnu`` by dividing
        by frequency before computing the envelope.
    quantiles : tuple of float, optional
        Lower/upper quantiles for the shaded envelope (for example
        ``(0.16, 0.84)``). If ``None``, use the full min/max range.
    get_model : bool, optional
        If ``True``, also return the sampled envelope arrays
        ``[x, y_min, y_max]`` after flux-limit masking.
    plot_mcmc_best_fit_model : bool, optional
        If ``True``, overlay the MCMC best-fit model. If ``False``, overlay
        the minimizer best-fit model.
    rnd_seed : int, optional
        Seed used to draw posterior samples reproducibly.
    plot_components : bool, optional
        If ``True``, plot component/sub-component curves before the final
        reference curve.

    Returns
    -------
    PlotSED or tuple
        Plot object. If ``get_model`` is ``True``, returns
        ``(plot_obj, [x, y_min, y_max])``.
    """
    if sed_data is None:
        # NOTE(review): ``self.sed_data`` is not assigned anywhere in the
        # visible part of this class (``__init__`` stores ``self.data``);
        # confirm it is set elsewhere before relying on this default.
        sed_data=self.sed_data
    if fit_range is None:
        fit_range = [self.model.nu_min_fit, self.model.nu_max_fit]
    p = self.model._set_up_plot(None, sed_data, frame, density)
    # One model curve per randomly drawn posterior sample (rows of ``y``).
    x,y=self._get_model_samples(size=size,rnd_seed=rnd_seed,frame=frame)
    if density is True:
        y=y/x
    if quantiles is None:
        # Envelope spans the full range of the sampled curves.
        y_min=np.amin(y, axis=0)
        y_max=np.amax(y, axis=0)
        _l='mcmc model range'
        #msk = y_min > self.model.flux_plot_lim
        #l=p.sedplot.fill_between(x[msk],y_max[msk],y_min[msk],color='gray',alpha=0.3,label='mcmc model range')
    else:
        _l='mcmc model conf. %s'%quantiles
        y_min,y_max=np.quantile(y, quantiles, axis=0)
    # Only shade where the lower envelope is above the plot flux limit.
    msk = y_min > self.model.flux_plot_lim
    l=p.sedplot.fill_between(x[msk],y_max[msk],y_min[msk],color='gray',alpha=0.3,label=_l)
    p.lines_model_list.append(l)
    # Choose which best-fit parameter set the red reference curve uses.
    if not plot_mcmc_best_fit_model :
        self.reset_to_minimizer_best_fit()
        label=None
    else:
        label='mcmc best fit'
        self.reset_to_mcmc_best_fit(verbose=False)
    self.model.eval()
    if plot_components:
        self.model.plot_model(sed_data=sed_data,plot_obj=p,only_components=True)
    p.add_model_plot(self.model, color='red',fit_range = fit_range,flim=self.model.flux_plot_lim,label=label)
    p.add_model_residual_plot(model = self.model, data = sed_data, fit_range = fit_range, color='red')
    # Leave the parameters at the MCMC best fit regardless of what was plotted.
    self.reset_to_mcmc_best_fit(verbose=False)
    if get_model is True:
        return p, [x[msk],y_min[msk],y_max[msk]]
    else:
        return p
def _get_model_samples(self,size,rnd_seed,frame):
self.model.eval()
x, _y = self.model.SED.get_model_points(log_log=False, frame=frame)
if size is None:
size = len(self.samples)
else:
size = min(len(self.samples), int(size))
rng = np.random.default_rng(rnd_seed)
ID_mcmc = rng.integers(0,len(self.samples),size=size)
y = np.zeros((size,x.size))
for ID,ID_rand in enumerate(ID_mcmc):
for id_p,par in enumerate(self._par_array_sampler):
par.val=self.get_sample(id_p)[ID_rand]
self.model.parameters.set_par(model_name= par.model.name,par_name=par.name,val=par.val)
self.model.eval(fill_SED=True)
x, y[ID] = self.model.SED.get_model_points(log_log=False, frame=frame)
return x,y
def _progess_bar(self,):
if np.mod(self.calls, 10) == 0 and self.calls != 0:
print("\r%s progress=%3.3f%% calls=%d accepted=%d" % (next(self._progress_iter),float(100*self.calls)/(self.calls_tot),self.calls,self.calls_OK), end="")
[docs]
def save(self, name):
    """Serialize this sampler to a file.

    Parameters
    ----------
    name : str
        Path of the output file; it is opened in binary write mode.
    """
    with open(name, 'wb') as handle:
        pickle.dump(self, handle, pickle.HIGHEST_PROTOCOL)
[docs]
@classmethod
#def load(self, name):
# with open(name, 'rb') as input:
# return pickle.load(input)
def load(cls, file_name):
"""Load object state from disk.
Parameters
----------
file_name : object
Input/output file path.
Returns
-------
object
Loaded object.
"""
try:
c = pickle.load(open(file_name, "rb"))
if isinstance(c, McmcSampler):
c.model.parameters.__class__=McmcCompositeModelParameterArray
if hasattr(c,'model'):
c.model=c.model._build_model(c.model)
for p in c.model.parameters.par_array:
Base = p.__class__
if not Base.__name__.startswith("McmcParameter_"):
McmcParameter = types.new_class(f"McmcParameter_{Base.__name__}", (Base,))
McmcParameter.x = property(set_mcmc_bound_max,get_mcmc_bound_max)
McmcParameter.x = property(set_mcmc_bound_min,get_mcmc_bound_min)
p.__class__ = McmcParameter
p._check_par_mcmc_bounds = types.MethodType(_check_par_mcmc_bounds, p)
p.best_fit_mcmc_val=c.sate_of_mcmc_parameters[p.model.name][p.name]['best_fit_mcmc_val']
p.q_16=c.sate_of_mcmc_parameters[p.model.name][p.name]['q_16']
p.q_50=c.sate_of_mcmc_parameters[p.model.name][p.name]['q_50']
p.q_84=c.sate_of_mcmc_parameters[p.model.name][p.name]['q_84']
p.plot_label=c.sate_of_mcmc_parameters[p.model.name][p.name]['plot_label']
p.mcmc_bound_min=c.sate_of_mcmc_parameters[p.model.name][p.name]['mcmc_bound_min']
p.mcmc_bound_max=c.sate_of_mcmc_parameters[p.model.name][p.name]['mcmc_bound_max']
delattr(c,'sate_of_mcmc_parameters')
return c
else:
raise RuntimeError('The model you loaded is not valid please check the file name')
except Exception as e:
raise RuntimeError(e)
def emcee_log_like(theta,fit_model,data,use_UL,par_array,loglog):
    """Log-likelihood of ``theta`` for use with ``emcee``.

    Parameters
    ----------
    theta : array-like
        Parameter vector sampled by MCMC.
    fit_model : object
        Model instance used for fitting.
    data : mapping
        Data container; the keys ``'x'``, ``'y'``, ``'dy'`` and ``'UL'``
        are read.
    use_UL : bool
        Forwarded to the residual evaluation.
    par_array : sequence
        Parameters matching the entries of ``theta``, in the same order.
    loglog : bool
        Forwarded to the model evaluation.

    Returns
    -------
    float
        ``-0.5`` times the residual sum, or ``-1e100`` when ``theta``, the
        model or the residual sum is not finite or any step raises.
    """
    rejected = -1e100
    theta = np.asarray(theta, dtype=float)
    # Fast reject before mutating model
    if not np.all(np.isfinite(theta)):
        return rejected

    try:
        for position, value in enumerate(theta):
            par = par_array[position]
            par.val = value
            fit_model.parameters.set_par(model_name=par.model.name,
                                         par_name=par.name,
                                         val=par.val)
    except Exception:
        # Parameter-setting failure
        return rejected

    try:
        model_values = fit_model.eval(nu=data['x'], fill_SED=False, get_model=True, loglog=loglog)
    except Exception:
        # Model-evaluation failure
        return rejected

    model_values = np.asarray(model_values, dtype=float)
    if not np.all(np.isfinite(model_values)):
        return rejected

    try:
        residual_sum, _res, _res_UL = _eval_res(data['y'],
                                                model_values,
                                                data['dy'],
                                                data['UL'],
                                                use_UL=use_UL)
    except Exception:
        return rejected

    if not np.isfinite(residual_sum):
        return rejected
    return residual_sum *-0.5
#def _progess_bar(counter):
#
# if np.mod(counter.count, 10) == 0 and counter.count != 0:
# print("\r%s progress=%3.3f%% calls=%d accepted=%d" % (next( counter._progress_iter),float(100* counter.count)/( counter.count_tot),counter.count,counter.count_OK), end="")
def log_prob(theta,fit_model,data,use_UL,counter,bounds,par_array,loglog):
    """Log-posterior (flat prior plus log-likelihood) for ``emcee``.

    Parameters
    ----------
    theta : array-like
        Parameter vector sampled by MCMC.
    fit_model : object
        Model instance used for fitting.
    data : object
        Data container forwarded to the likelihood.
    use_UL : bool
        Forwarded to the likelihood.
    counter : object
        Progress counter; ``count`` is incremented on every call and
        ``count_OK`` on every call whose prior is finite.
    bounds : sequence
        ``[min, max]`` pair per entry of ``theta``.
    par_array : sequence
        Parameters matching the entries of ``theta``.
    loglog : bool
        Forwarded to the likelihood.

    Returns
    -------
    float
        ``-inf`` outside the bounds, otherwise prior plus log-likelihood.
    """
    prior = log_prior(theta,bounds)
    counter.count += 1
    if not np.isfinite(prior):
        # Outside the prior support: skip the model evaluation.
        return -np.inf
    posterior = prior + emcee_log_like(theta,fit_model,data,use_UL,par_array,loglog)
    counter.count_OK += 1
    return posterior
def log_prior(theta,bounds):
    """Flat log-prior: ``0`` inside the bounds, ``-inf`` outside.

    Parameters
    ----------
    theta : sequence
        Parameter vector sampled by MCMC.
    bounds : sequence
        One ``[min, max]`` pair per entry of ``theta``; a ``None`` limit
        means that side is unbounded.

    Returns
    -------
    float
        ``0.`` if every value respects its limits, ``-inf`` otherwise.
    """
    violated = False
    for position, value in enumerate(theta):
        upper = bounds[position][1]
        lower = bounds[position][0]
        if upper is not None and value > upper:
            violated = True
        if lower is not None and value < lower:
            violated = True
    return -np.inf if violated else 0.