# Source code for jetset.mcmc

__author__ = "Andrea Tramacere"

from .minimizer import  _eval_res

import emcee
from itertools import cycle

import numpy as np
import scipy as sp
from scipy import stats
import corner
import dill as pickle
from multiprocessing import cpu_count, Pool
import multiprocessing as mp
import warnings
import  time
import copy
import threading
from .plot_sedfit import  plt, PlotSED, set_mpl


class Counter(object):
    """Bookkeeping object shared with the log-probability callbacks.

    Attributes
    ----------
    count : int
        Number of evaluations recorded so far (starts at 0).
    count_OK : int
        Number of evaluations recorded as accepted (starts at 0).
    count_tot : int
        Total number of evaluations expected, as given by the caller.
    """

    def __init__(self, count_tot):
        self.count_tot = count_tot
        self.count = self.count_OK = 0
        # Endless spinner glyphs used by the textual progress bar.
        self._progress_iter = cycle('|/-\\')

#class RunThread(object):
#    def __init__(self, target_class):
#        self.target_class = target_class

#    def run(self):
#        self.target_class.model=self.target_class.model.clone()
#        self.target_class.sampler.run_mcmc(self.target_class._pos, self.target_class._npernode, rstate0=np.random.get_state(), progress=True,store = True)

# Local copy of the emcee `sample_ball` helper, kept here so that its deprecation in emcee does not turn into an error for this module.
def sample_ball(p0, std, size=1):
    """
    Produce a ball of walkers around an initial parameter value.

    :param p0: The initial parameter value.
    :param std: The axis-aligned standard deviation.
    :param size: The number of samples to produce.

    :returns: Array of shape ``(size, len(p0))``; row ``i`` is ``p0`` plus
        independent Gaussian noise scaled by ``std`` on each axis.
    """
    assert len(p0) == len(std)
    centre = np.asarray(p0)
    spread = np.asarray(std)
    # One draw of shape (size, ndim) consumes the global random stream in the
    # same order as `size` successive draws of length ndim, and also yields a
    # well-formed (0, ndim) array when size == 0.
    return centre + spread * np.random.normal(size=(size, len(centre)))

class McmcSampler(object):
    """Run an emcee ensemble sampler around a fitted model and inspect it.

    Typical use: build the sampler from a model minimizer, call
    :meth:`set_labels` and :meth:`set_bounds`, then :meth:`run_sampler`.
    After the run the flattened post-burn-in chain is available as
    ``self.samples`` and can be explored with the plotting helpers.
    """

    def __init__(self, model_minimizer):
        """Store the model, data and free parameters of ``model_minimizer``.

        Raises
        ------
        RuntimeError
            If the installed emcee is older than version 3.
        """
        # Compare the major version numerically: a plain string comparison
        # would classify e.g. "10.0" as older than "3".
        version = str(emcee.__version__)
        try:
            too_old = int(version.split('.')[0]) < 3
        except ValueError:
            too_old = version < "3"
        if too_old:
            raise RuntimeError('Please update to emcee v>=3.0.0')

        self.model = model_minimizer.fit_model
        # NOTE(review): assumes `model_minimizer.data` is the mapping with the
        # 'x', 'y', 'dy' and 'UL' entries read by emcee_log_like - confirm.
        self.data = model_minimizer.data
        self.fit_par_free = model_minimizer.fit_par_free
        self._progress_iter = cycle(['|', '/', '-', '\\'])
        self.labels = None
        self.par_array = None
        self._bounds = None

    def set_labels(self, use_labels_dict=None):
        """Select the parameters to sample and record their labels.

        Parameters
        ----------
        use_labels_dict : dict, optional
            Mapping ``model name -> iterable of parameter names``. Each pair
            is looked up with ``self.model.parameters.get_par_by_name``;
            names that are not found are skipped with a warning. When None
            (default) all the free fit parameters are used.
        """
        if use_labels_dict is None:
            self.par_array = self.fit_par_free
            self.labels = [par.name for par in self.par_array]
            self.labels_units = [par.units for par in self.par_array]
            self.labels_start_val = [par.best_fit_val for par in self.par_array]
            return

        self.labels = []
        self.par_array = []
        self.labels_units = []
        self.labels_start_val = []
        for model_name, par_names in use_labels_dict.items():
            for par_name in par_names:
                p = self.model.parameters.get_par_by_name(model_name, par_name)
                if p is None:
                    warnings.warn('par %s' % par_name + ' not present in model, will be skipped')
                    continue
                self.par_array.append(p)
                self.labels.append(p.name)
                self.labels_units.append(p.units)
                self.labels_start_val.append(p.best_fit_val)

    def set_bounds(self, bound=0.2, bound_rel=False,):
        """Build the flat-prior bounds; see :meth:`_build_bounds`."""
        self._build_bounds(bound=bound, bound_rel=bound_rel)

    def run_sampler(self,
                    nwalkers=None,
                    steps=100,
                    pos=None,
                    burnin=50,
                    use_UL=False,
                    threads=None,
                    walker_start_bound=0.005,
                    loglog=False,
                    progress='notebook',
                    use_labels_dict=None,
                    bound=None,
                    bound_rel=False):
        """Run the emcee ensemble sampler and store the resulting chain.

        Parameters
        ----------
        nwalkers : int, optional
            Number of walkers; defaults to four times the number of sampled
            parameters and must be at least twice that number.
        steps : int
            Number of steps passed to ``run_mcmc``.
        pos : array-like, optional
            Initial walker positions. When None they are drawn with
            ``sample_ball`` around the best-fit values, with a spread of
            ``best_fit_val * walker_start_bound`` per parameter.
        burnin : int
            Number of initial steps discarded from ``self.samples``.
        use_UL : bool
            Forwarded to the likelihood evaluation.
        threads : int, optional
            Ignored apart from a warning when larger than 1.
        walker_start_bound : float
            Relative spread of the initial walker ball.
        loglog : bool
            Forwarded to the model evaluation.
        progress : bool or str
            Forwarded to ``run_mcmc``.
        use_labels_dict, bound, bound_rel
            Deprecated: use :meth:`set_labels` and :meth:`set_bounds`.

        Raises
        ------
        RuntimeError
            If labels/bounds were never set, or if there are too few walkers.
        """
        self.calls = 0
        self.calls_OK = 0
        self.use_UL = use_UL

        if use_labels_dict is not None or bound is not None:
            warnings.warn('use_labels_dict, and bounds in run_sampler are deprecated and will result in an error in the next version, please use the .set_labels and .set_bounds methods as explained in the documentation')
            self.set_labels(use_labels_dict=use_labels_dict)
            # When only use_labels_dict is given, fall back to the set_bounds
            # default instead of forwarding None (which cannot be multiplied).
            self.set_bounds(bound=0.2 if bound is None else bound, bound_rel=bound_rel)

        if self.labels is None or self._bounds is None:
            raise RuntimeError('labels and bounds are not set, please call the .set_labels and .set_bounds methods before run_sampler')

        self.plot_labels = copy.deepcopy(self.labels)
        # Snapshot of the parameters, used by reset_to_best_fit.
        self.par_array_best_fit = copy.deepcopy(self.par_array)
        self.ndim = len(self.labels)
        self.pos = pos
        self.burnin = burnin

        if nwalkers is None:
            self.nwalkers = 4 * len(self.labels)
            print('setting nwalkers to:', self.nwalkers)
        else:
            self.nwalkers = nwalkers
        if self.nwalkers < 2 * len(self.labels):
            raise RuntimeError("numbers of walkers has to be at least two times the number of sampling pars")

        counter = Counter(self.nwalkers * steps)
        self.steps = steps
        self.calls_tot = self.nwalkers * steps

        if pos is None:
            best_fit = np.array([p.best_fit_val for p in self.par_array])
            pos = sample_ball(best_fit, best_fit * walker_start_bound, self.nwalkers)

        print('mcmc run starting')
        print('')
        start = time.time()
        if threads is not None and threads > 1:
            warnings.warn('python multithreading is not effective, JetSeT uses C threads to speedup computation')
        # The sampler always runs in a single Python thread.
        threads = 1
        self.sampler = emcee.EnsembleSampler(self.nwalkers,
                                             self.ndim,
                                             log_prob,
                                             args=(self.model, self.data, use_UL, counter,
                                                   self._bounds, self.par_array, loglog))
        self.sampler.run_mcmc(pos, steps, progress=progress)
        end = time.time()
        comp_time = end - start
        print("mcmc run done, with %d threads took %2.2f seconds" % (threads, comp_time))

        self.samples = self.sampler.get_chain(flat=True, discard=burnin)
        self.samples_log_prob = self.sampler.get_log_prob(flat=True, discard=burnin)
        self.acceptance_fraction = np.mean(self.sampler.acceptance_fraction)
        self.reset_to_best_fit()

    def set_plot_label(self, label, plot_label):
        """Replace the plotting label of the parameter named ``label``."""
        self.plot_labels[self.labels.index(label)] = plot_label

    def get_par_quantiles(self, p, quantiles=(0.16, 0.5, 0.84)):
        """Return the requested quantiles of the samples of parameter ``p``."""
        _d, idx = self.get_par(p)
        return np.array(np.quantile(_d, quantiles))

    def reset_to_best_fit(self):
        """Restore every sampled parameter to its pre-run value."""
        for par, best in zip(self.par_array, self.par_array_best_fit):
            par.val = best.val

    def reset_to_mcmc(self, quantile=0.5):
        """Set every sampled parameter to the given quantile of its samples."""
        for ID, par in enumerate(self.par_array):
            par.val = self.get_par_quantiles(ID, quantiles=quantile)

    def _build_bounds(self, bound=0.2, bound_rel=True):
        """Compute ``[min, max]`` prior bounds for each sampled parameter.

        ``bound`` is a scalar or a ``(lower, upper)`` pair. With
        ``bound_rel=False`` and an available ``best_fit_err`` the interval is
        ``best_fit_val -/+ best_fit_err * bound``; otherwise it is
        ``best_fit_val -/+ |best_fit_val| * bound``. The result is clipped to
        the parameter ``fit_range_min`` / ``fit_range_max`` when those are set.
        """
        if self.par_array is None:
            raise RuntimeError('labels are not set, please call the .set_labels method before setting the bounds')

        if np.shape(bound) == ():
            bound = [bound, bound]
        elif np.shape(bound) == (2,):
            pass
        else:
            raise RuntimeError('bound shape', np.shape(bound), 'it is wrong, has to be a scalar or (2,)')

        self._bounds = []
        for par in self.par_array:
            if bound_rel is False and par.best_fit_err is not None:
                scale = par.best_fit_err
            else:
                scale = np.fabs(par.best_fit_val)
            _min = par.best_fit_val - scale * bound[0]
            _max = par.best_fit_val + scale * bound[1]

            if par.fit_range_min is not None:
                _min = max(_min, par.fit_range_min)
            if par.fit_range_max is not None:
                _max = min(_max, par.fit_range_max)

            print('par:', par.name, ' best fit value: ', par.best_fit_val, ' mcmc bounds:', [_min, _max])
            self._bounds.append([_min, _max])

    def corner_plot(self, labels=None, quantiles=(0.16, 0.5, 0.84), levels=None, title_kwargs=None, **kwargs):
        """Draw a corner plot of the selected parameters and return the figure.

        ``labels`` is a parameter name or a sequence of names (default: all
        sampled parameters). Extra keyword arguments go to ``corner.corner``.
        """
        if labels is None:
            labels = self.labels
        elif isinstance(labels, str):
            labels = [labels]
        _id = [self.labels.index(l) for l in labels]
        if title_kwargs is None:
            title_kwargs = {}

        f = corner.corner(self.samples[:, _id],
                          quantiles=quantiles,
                          # Only the labels of the plotted columns, so axis
                          # titles stay aligned when a subset is selected.
                          labels=[self.plot_labels[i] for i in _id],
                          truths=[self.labels_start_val[i] for i in _id],
                          title_kwargs=title_kwargs,
                          show_titles=True,
                          levels=levels,
                          **kwargs)
        title = 'quantiles =' + str(quantiles)
        f.suptitle(title, y=1.0)
        return f

    def get_par(self, p,):
        """Return ``(samples, index)`` for parameter ``p``.

        ``p`` is either an integer column index (Python or numpy integer) or
        a parameter name listed in ``self.labels``.

        Raises
        ------
        RuntimeError
            If the name is unknown or the index is past the last parameter.
        """
        if isinstance(p, (int, np.integer)) and not isinstance(p, bool):
            p = int(p)
        else:
            try:
                p = self.labels.index(p)
            except (ValueError, AttributeError):
                raise RuntimeError('parameter p', p, 'not found')

        if p >= len(self.labels):
            raise RuntimeError('label id larger then labels size')

        return self.samples[:, p].flatten(), p

    def plot_chain(self, p=None, log_plot=False):
        """Plot the walker traces of one or more parameters, one per panel."""
        if p is None:
            p = self.labels
        else:
            p = np.atleast_1d(p)

        f, axes = plt.subplots(len(p), sharex=True)
        axes = np.atleast_1d(axes)
        for ID, _p in enumerate(p):
            self._plot_chain(_p, axes[ID], log_plot=log_plot)
        axes[-1].set_xlabel('steps')
        return f

    def _plot_chain(self, p, ax, log_plot=False):
        """Draw on ``ax`` the traces of parameter ``p`` for every walker."""
        _d, idx = self.get_par(p)
        n = self.plot_labels[idx]
        # get_chain() is (steps, walkers, ndim); transpose to one row per
        # walker, i.e. the layout of the deprecated `sampler.chain`.
        traces = self.sampler.get_chain()[:, :, idx].T

        if self.labels_units is not None:
            if self.labels_units[idx] is not None:
                n += ' (%s)' % self.labels_units[idx]

        alpha_true = np.median(_d)

        if log_plot == True:
            n = 'log10(%s)' % n
            if np.any(_d <= 0):
                raise RuntimeWarning('negative values in p')
            traces = np.log10(traces)
            alpha_true = np.log10(alpha_true)

        for t in traces:
            ax.plot(t, '-', color='k', alpha=0.5)
        ax.axhline(alpha_true, color='blue')
        if hasattr(self, 'burnin'):
            ax.axvline(self.burnin, ls='--', color='orange', alpha=0.5)
        ax.set_ylabel(n)

    def plot_par(self, p, nbins=20, log_plot=False, quantiles=(0.16, 0.5, 0.84), figsize=None):
        """Plot the histogram of the samples of parameter ``p``.

        The requested quantiles are drawn as vertical lines; with at least
        three quantiles the legend reports the second one with the distances
        to its two neighbours. With ``log_plot`` the histogram is in log10.
        """
        set_mpl()
        _d, idx = self.get_par(p)
        par_name = self.plot_labels[idx]
        x_name = par_name
        if self.labels_units is not None:
            if self.labels_units[idx] is not None and str(self.labels_units[idx]).strip() != '':
                x_name += ' (%s)' % self.labels_units[idx]

        if log_plot == True:
            # Validate before creating the figure so a failure leaves no
            # dangling empty figure behind.
            if np.any(_d <= 0):
                raise RuntimeWarning('negative values in p')
            x_name = 'log10(%s)' % x_name
            _d = np.log10(_d)

        f = plt.figure(figsize=figsize)
        ax = f.add_subplot(111)

        q_vals = self.get_par_quantiles(idx, quantiles=quantiles)
        q_diff = np.diff(q_vals)
        if len(q_vals) >= 3:
            # Central value is the middle quantile, not the lowest one.
            label = r'%s = $%.3e^{+%.3e}_{-%.3e} $' % (par_name, q_vals[1], q_diff[1], q_diff[0])
        else:
            label = par_name

        ax.hist(_d, bins=nbins, density=True, alpha=0.5, label=label)

        for q in q_vals:
            if log_plot is True:
                q = np.log10(q)
            ax.axvline(q, c='black', ls='--', lw=0.5)

        ax.set_xlabel(x_name)
        f.suptitle('quantiles = %s' % str(quantiles))
        ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), ncol=1)
        return f

    def plot_model(self, sed_data=None, fit_range=None, size=100, frame='obs', density=False, quantiles=None, get_model=False, plot_mcmc_best_fit_model=False):
        """Plot the band spanned by models evaluated on chain samples.

        Parameters
        ----------
        sed_data : optional
            Data passed to the plot set-up; defaults to ``self.sed_data``.
        fit_range : list, optional
            Defaults to ``[model.nu_min_fit, model.nu_max_fit]``.
        size : int or None
            Number of random chain samples to evaluate; None uses them all.
        frame : str
            Forwarded to the plot set-up and to ``get_model_points``.
        density : bool
            When True the model values are divided by ``x``.
        quantiles : pair, optional
            Lower/upper quantile of the band; None uses min/max.
        get_model : bool
            When True also return ``[x, y_min, y_max]`` of the band.
        plot_mcmc_best_fit_model : bool
            When True overplot the model at the 0.5 quantile of the samples
            instead of the best-fit one.
        """
        if sed_data is None:
            # NOTE(review): `sed_data` is not assigned anywhere in this
            # class; the default only works if it was set on the instance.
            sed_data = self.sed_data

        if fit_range is None:
            fit_range = [self.model.nu_min_fit, self.model.nu_max_fit]

        p = self.model._set_up_plot(None, sed_data, frame, density)

        self.reset_to_best_fit()
        self.model.eval(fill_SED=True)
        x, y = self.model.SED.get_model_points(log_log=False, frame=frame)

        if size is None:
            size = len(self.samples)
            ID_mcmc = np.arange(size)
        else:
            size = min(len(self.samples), int(size))
            ID_mcmc = np.random.randint(len(self.samples), size=size)

        y = np.zeros((size, x.size))
        for ID, ID_rand in enumerate(ID_mcmc):
            for ID_par, pi in enumerate(self.par_array):
                pi.set(val=self.samples[ID_rand, ID_par])
            self.model.eval(fill_SED=True)
            x, y[ID] = self.model.SED.get_model_points(log_log=False, frame=frame)

        if density is True:
            y = y / x

        if quantiles is None:
            y_min = np.amin(y, axis=0)
            y_max = np.amax(y, axis=0)
            band_label = 'mcmc model range'
        else:
            y_min, y_max = np.quantile(y, quantiles, axis=0)
            band_label = 'mcmc model conf %s' % quantiles
        msk = y_min > self.model.flux_plot_lim
        l = p.sedplot.fill_between(x[msk], y_max[msk], y_min[msk], color='gray', alpha=0.3, label=band_label)
        p.lines_model_list.append(l)

        self.reset_to_best_fit()
        self.model.eval(fill_SED=True)
        if plot_mcmc_best_fit_model is False:
            p.add_model_plot(self.model, color='red', fit_range=fit_range, density=density, flim=self.model.flux_plot_lim)
            p.add_model_residual_plot(model=self.model, data=sed_data, fit_range=fit_range, color='red')
        else:
            self.reset_to_mcmc(quantile=0.5)
            self.model.eval()
            p.add_model_plot(self.model, color='red', fit_range=fit_range, density=density, flim=self.model.flux_plot_lim, label='mcmc 0.5 quantile')
            p.add_model_residual_plot(model=self.model, data=sed_data, fit_range=fit_range, color='red')

        if get_model is True:
            return p, [x[msk], y_min[msk], y_max[msk]]
        else:
            return p

    def _progess_bar(self,):
        """Print a one-line progress report every tenth recorded call."""
        if np.mod(self.calls, 10) == 0 and self.calls != 0:
            print("\r%s progress=%3.3f%% calls=%d accepted=%d" % (next(self._progress_iter), float(100 * self.calls) / (self.calls_tot), self.calls, self.calls_OK), end="")

    def save(self, name):
        """Pickle this sampler to the file ``name``."""
        with open(name, 'wb') as output:
            pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, name):
        """Return the object pickled in the file ``name``.

        Unpickling can execute arbitrary code: only load trusted files.
        """
        with open(name, 'rb') as in_file:
            return pickle.load(in_file)
def emcee_log_like(theta, fit_model, data, use_UL, par_array, loglog):
    """Return the log-likelihood of the model at the parameter vector ``theta``.

    Each entry of ``theta`` is assigned to the matching parameter of
    ``par_array``, the model is evaluated at ``data['x']`` and the summed
    residual returned by ``_eval_res`` is multiplied by ``-0.5``.
    """
    for pi in range(len(theta)):
        par_array[pi].set(val=theta[pi])

    _m = fit_model.eval(nu=data['x'], fill_SED=False, get_model=True, loglog=loglog)
    _res_sum, _res, _res_UL = _eval_res(data['y'], _m, data['dy'], data['UL'], use_UL=use_UL)
    return _res_sum * -0.5


def _progess_bar(counter):
    """Print a one-line progress report every tenth call recorded in ``counter``."""
    if np.mod(counter.count, 10) == 0 and counter.count != 0:
        print("\r%s progress=%3.3f%% calls=%d accepted=%d" % (next(counter._progress_iter), float(100 * counter.count) / (counter.count_tot), counter.count, counter.count_OK), end="")


def log_prob(theta, fit_model, data, use_UL, counter, bounds, par_array, loglog):
    """Log-probability callback handed to the emcee sampler.

    Returns ``-inf`` when ``theta`` is rejected by :func:`log_prior`;
    otherwise the prior plus the likelihood. ``counter.count`` is incremented
    on every call and ``counter.count_OK`` only when the likelihood is
    evaluated.
    """
    lp = log_prior(theta, bounds)
    counter.count += 1
    if not np.isfinite(lp):
        return -np.inf

    lp += emcee_log_like(theta, fit_model, data, use_UL, par_array, loglog)
    counter.count_OK += 1
    return lp


def log_prior(theta, bounds):
    """Flat prior: ``0.`` inside ``bounds``, ``-inf`` outside.

    ``bounds[i]`` is the ``(min, max)`` pair of parameter ``i``; a None entry
    leaves that side unbounded. Non-finite parameter values are always
    rejected, because NaN passes every comparison against the bounds.
    """
    for pi in range(len(theta)):
        value = theta[pi]
        lower, upper = bounds[pi][0], bounds[pi][1]
        if not np.isfinite(value):
            return -np.inf
        if upper is not None and value > upper:
            return -np.inf
        if lower is not None and value < lower:
            return -np.inf
    return 0.