# Source code for jetset.plot_sedfit



# Module authorship metadata.
__author__ = "Andrea Tramacere"


import matplotlib as mpl


try:
    from matplotlib import  pyplot as plt
except:
    try:
        from matplotlib import pylab as plt

    except:
        try:
           import  pylab as plt
        except:
            raise RuntimeError('Unable to import pylab/pyplot from matplotlib')



from matplotlib import gridspec
import numpy as np
import  os
from astropy.constants import m_e,m_p,c
import matplotlib.ticker as ticker
import warnings

from collections import namedtuple

from .output import section_separator,WorkPlace

from .utils import *

# Public names exported by ``from ... import *``; the module-level helpers
# (set_mpl, heatmap, annotate_heatmap, y_ev_transf*) are intentionally not listed.
__all__=['PlotSED','BasePlot','PlotPdistr','PlotSpecComp','PlotSeedPhotons','PlotSpectralMultipl','PlotTempEvDiagram','PlotTempEvEmitters']

def y_ev_transf(x):
    """Convert a frequency in Hz to a photon energy in eV.

    Used as the forward transform of the secondary ``E (eV)`` axis drawn
    on top of the frequency axis.

    Parameters
    ----------
    x : float or numpy.ndarray
        Frequency value(s) in Hz.

    Returns
    -------
    float or numpy.ndarray
        ``x`` divided by the 2.417E14 Hz-per-eV conversion factor.
    """
    hz_per_ev = 2.417E14
    return x / hz_per_ev

def y_ev_transf_inv(x):
    """Convert a photon energy in eV back to a frequency in Hz.

    Inverse of ``y_ev_transf``; used as the backward transform of the
    secondary ``E (eV)`` axis.

    Parameters
    ----------
    x : float or numpy.ndarray
        Energy value(s) in eV.

    Returns
    -------
    float or numpy.ndarray
        ``x`` multiplied by the 2.417E14 Hz-per-eV conversion factor.
    """
    hz_per_ev = 2.417E14
    return x * hz_per_ev



def set_mpl():
    """Apply JetSeT's default Matplotlib rc settings.

    Sets the default figure size and resolution (screen and saved files)
    and the base/legend/figure-title font sizes in ``mpl.rcParams``.
    """
    defaults = (
        ('figure.figsize', [12.0, 8.0]),
        ('figure.dpi', 100),
        ('savefig.dpi', 100),
        ('font.size', '14'),
        ('legend.fontsize', 'medium'),
        ('figure.titlesize', 'medium'),
    )
    for key, value in defaults:
        mpl.rcParams[key] = value



def _rescale( x_min=None, x_max=None, y_min=None, y_max=None):
        warnings.warn('`The rescale method has been removed and has been replaced by the setlim method')
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("!The rescale method as been replaced by the setlim method            !")
        print("!please notice that now jetset uses log axis rather than loglog plots!")
        print("!so, the correct way to use it is rescale(x_min=8)->setlim(x_min=1E8)!")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

class PlotSED(object):
    """Main SED plotting utility for data, models, and residuals.

    Notes
    -----
    Manages a two-panel Matplotlib figure (spectrum on top, residuals below,
    sharing the frequency axis) and provides helpers to overlay observational
    data, model curves, residuals, and time-dependent snapshots.
    """

    def __init__(self,
                 sed_data=None,
                 model=None,
                 interactive=False,
                 plot_workplace=None,
                 title='Plot',
                 frame='obs',
                 density=False,
                 dpi=100,
                 figsize=(12, 8),
                 use_grid=True):
        """Create a new `PlotSED` instance.

        Parameters
        ----------
        sed_data : object, optional
            Observational SED data container; plotted immediately if given.
        model : object, optional
            Model instance; plotted immediately if given.
        interactive : bool, optional
            If ``True``, switch pyplot to interactive mode (``plt.ion()``).
        plot_workplace : object, optional
            Object providing ``out_dir`` and ``flag``; a new ``WorkPlace()``
            is created when ``None``.
        title : str, optional
            Title prefix; stored as ``"<title>_<flag>"``.
        frame : str, optional
            Reference frame (``'obs'``, ``'src'`` or ``'blob'``). Data and
            models added later must use the same frame.
        density : bool, optional
            If ``True``, plot F_nu / L_nu (model y divided by nu) instead of
            nu F_nu / nu L_nu.
        dpi : int, optional
            Figure resolution.
        figsize : tuple, optional
            Figure size in inches; ``None`` falls back to ``(10, 8)``.
        use_grid : bool, optional
            If ``True``, draw a grid on the SED panel.
        """
        check_frame(frame)
        self.frame = frame
        self._sed_data = None
        self.density = density
        self.axis_kw = ['x_min', 'x_max', 'y_min', 'y_max']
        self.interactive = interactive
        self.lines_data_list = []
        self.lines_model_list = []
        self.lines_res_list = []
        self.counter_res = 0

        if self.interactive is True:
            plt.ion()
            print('running PyLab in interactive mode')

        if plot_workplace is None:
            plot_workplace = WorkPlace()
        self.out_dir = plot_workplace.out_dir
        self.flag = plot_workplace.flag

        self.title = "%s_%s" % (title, self.flag)

        if figsize is None:
            figsize = (10, 8)

        self.fig = plt.figure(figsize=figsize, dpi=dpi)
        self.gs = gridspec.GridSpec(2, 1, height_ratios=[4, 1])
        self.sedplot = self.fig.add_subplot(self.gs[0])
        self._add_res_plot()
        self.set_plot_axis_labels(density=self.density)

        self.sedplot.set_autoscalex_on(True)
        self.sedplot.set_autoscaley_on(True)
        self.sedplot.set_autoscale_on(True)
        self.counter = 0

        if use_grid is True:
            self.sedplot.grid(use_grid, alpha=0.5)

        # Default view per frame, until data/models are added.
        self.sedplot.set_xlim(1E6, 1E30)
        if frame == 'obs':
            self.sedplot.set_ylim(1E-20, 1E-8)
        elif frame == 'src':
            self.sedplot.set_ylim(1E38, 1E55)
        elif frame == 'blob':
            self.sedplot.set_ylim(1E34, 1E51)
        else:
            unexpected_behaviour()

        self.sedplot.set_xscale("log", nonpositive='clip')
        self.sedplot.set_yscale("log", nonpositive='clip')

        # Secondary top axis showing photon energy in eV.
        self.secaxy = self.sedplot.secondary_xaxis('top', functions=(y_ev_transf, y_ev_transf_inv))
        self.secaxy.set_xlabel('E (eV)')

        self.resplot.set_ybound(-2, 2)

        try:
            if hasattr(self.fig.canvas.manager, 'toolbar'):
                self.fig.canvas.manager.toolbar.update()
        except Exception:
            # Best effort: some backends have no manager or a None toolbar.
            pass

        if sed_data is not None:
            self.add_data_plot(sed_data)

        if model is not None:
            self.add_model_plot(model)

        self.add_residual_plot = self.add_model_residual_plot

    def _check_frame(self, frame):
        """Return the frame to use, enforcing the one chosen at construction.

        Parameters
        ----------
        frame : str or None
            Requested frame; ``None`` means "use ``self.frame``".

        Raises
        ------
        RuntimeError
            If ``frame`` differs from ``self.frame``.
        """
        if frame is None:
            return self.frame
        if frame != self.frame:
            raise RuntimeError('you have to use the same restframe of the PlotSED class: %s' % self.frame)
        return frame

    def _add_res_plot(self):
        """Create the residuals panel (x axis shared with the SED panel) with its zero line."""
        self.resplot = self.fig.add_subplot(self.gs[1], sharex=self.sedplot)
        self.lx_res = '$ \\nu $ (Hz)'
        self.ly_res = 'res'
        self.resplot.set_ylabel(self.ly_res)
        self.resplot.set_xlabel(self.lx_res)
        self.add_res_zeroline()

    def clean_residuals_lines(self):
        """Remove every residual entry from the residuals panel."""
        while self.lines_res_list:
            self.del_residuals_line(0)

    def clean_data_lines(self):
        """Remove every data entry from the SED panel."""
        while self.lines_data_list:
            self.del_data_line(0)

    def clean_model_lines(self):
        """Remove every model line from the SED panel."""
        while self.lines_model_list:
            self.del_model_line(0)

    def list_lines(self):
        """Print the index and label of every data and model line."""
        for ID, plot_line in enumerate(self.lines_data_list):
            print('data', ID, plot_line.get_label())
        for ID, plot_line in enumerate(self.lines_model_list):
            print('model', ID, plot_line.get_label())

    def del_data_line(self, line_ID):
        """Remove one data entry (an errorbar container) from the SED panel.

        Parameters
        ----------
        line_ID : int
            Index into ``self.lines_data_list``.
        """
        if self.lines_data_list == []:
            print("no lines to delete ")
            return
        print("removing line: ", self.lines_data_list[line_ID])
        container = self.lines_data_list[line_ID]
        for item in container:
            if item is None:
                # The data line of an errorbar container can be None (e.g. fmt='none').
                continue
            if isinstance(item, (tuple, list)):
                # Tuples of cap lines / error-bar collections.
                for artist in item:
                    artist.remove()
            else:
                item.remove()
        del self.lines_data_list[line_ID]
        self.update_plot()

    def del_model_line(self, line_ID):
        """Remove one model line from the SED panel (no-op if there are none).

        Parameters
        ----------
        line_ID : int
            Index into ``self.lines_model_list``.
        """
        if self.lines_model_list == []:
            return
        self.lines_model_list[line_ID].remove()
        del self.lines_model_list[line_ID]
        self.update_plot()

    def del_residuals_line(self, line_ID):
        """Remove one residual entry from the residuals panel (no-op if there are none).

        Parameters
        ----------
        line_ID : int
            Index into ``self.lines_res_list``.
        """
        if self.lines_res_list == []:
            return
        self.lines_res_list[line_ID].remove()
        del self.lines_res_list[line_ID]
        self.update_plot()

    def set_plot_axis_labels(self, density=False):
        """Set the SED panel axis labels for the current frame.

        Parameters
        ----------
        density : bool, optional
            If ``True``, label the y axis as F_nu / L_nu instead of nu F_nu / nu L_nu.
        """
        self.lx = '$ \\nu $ (Hz)'

        if self.frame == 'src' or self.frame == 'blob':
            if density is False:
                self.ly = '$ \\nu L_{\\nu} $ (erg s$^{-1})$'
            else:
                self.ly = '$ L_{\\nu} $ (erg s$^{-1}$ Hz$^{-1})$'
        elif self.frame == 'obs':
            if density is False:
                self.ly = '$ \\nu F_{\\nu} $ (erg cm$^{-2}$ s$^{-1})$'
            else:
                self.ly = '$ F_{\\nu} $ (erg cm$^{-2}$ s$^{-1}$ Hz$^{-1})$'
        else:
            unexpected_behaviour()

        self.sedplot.set_ylabel(self.ly)
        self.sedplot.set_xlabel(self.lx)

    def add_res_zeroline(self):
        """Draw the dashed y=0 reference line on the residuals panel."""
        self.resplot.axhline(0, ls='--', color='black')
        self.update_plot()

    def rescale(self, x_min=None, x_max=None, y_min=None, y_max=None):
        """Removed API: only warns and points to `setlim`; the limits are ignored."""
        _rescale(x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max)

    def setlim(self, x_min=None, x_max=None, y_min=None, y_max=None):
        """Set the SED panel limits (linear values; ``None`` keeps the current bound).

        Parameters
        ----------
        x_min, x_max : float, optional
            Frequency limits in Hz.
        y_min, y_max : float, optional
            Flux/luminosity limits.
        """
        self.sedplot.set_xlim(x_min, x_max)
        self.sedplot.set_ylim(y_min, y_max)

    def setlim_res(self, x_min=None, x_max=None, y_min=None, y_max=None):
        """Set the residuals panel limits (``None`` keeps the current bound), then refresh.

        Parameters
        ----------
        x_min, x_max : float, optional
            Frequency limits in Hz.
        y_min, y_max : float, optional
            Residual limits.
        """
        self.resplot.set_xlim(x_min, x_max)
        self.resplot.set_ylim(y_min, y_max)
        self.update_plot()

    def update_plot(self):
        """Recompute the SED panel limits from its lines, refresh legend and layout, redraw.

        The y range goes from 1/1000 of the smallest per-line peak to 10 times
        the largest per-line peak; the x range spans, with a factor-10 margin,
        the points whose y lies inside that range. Non-finite and non-positive
        values are ignored because both axes are logarithmic. When no usable
        data is found, Matplotlib autoscaling is used instead.
        """
        y_peaks = []
        for line in self.sedplot.lines:
            y_data = np.asarray(line.get_ydata(), dtype=float)
            usable = y_data[np.isfinite(y_data) & (y_data > 0)]
            if usable.size > 0:
                y_peaks.append(usable.max())

        x_lows = []
        x_highs = []
        if y_peaks:
            y_min = min(y_peaks) / 1000
            y_max = max(y_peaks) * 10
            self.sedplot.set_ylim(y_min, y_max)
            for line in self.sedplot.lines:
                x_data = np.asarray(line.get_xdata(), dtype=float)
                y_data = np.asarray(line.get_ydata(), dtype=float)
                x_in = x_data[y_data >= y_min]
                x_in = x_in[np.isfinite(x_in) & (x_in > 0)]
                if x_in.size > 0:
                    x_lows.append(x_in.min())
                    x_highs.append(x_in.max())

        if x_lows:
            self.sedplot.set_xlim(min(x_lows) / 10, max(x_highs) * 10)
        else:
            self.sedplot.relim()
            self.sedplot.autoscale(axis='y')
            self.sedplot.autoscale(axis='x')

        self.update_legend()
        self.fig.tight_layout()
        # Draw last so the canvas reflects the limits/legend set above.
        self.fig.canvas.draw()

    def update_legend(self, label=None):
        """Rebuild the SED legend from data and model entries.

        Entries whose label is ``None`` or starts with ``'_'`` (Matplotlib's
        "hidden" convention) are skipped.

        Parameters
        ----------
        label : object, optional
            Unused; kept for backward compatibility.
        """
        handles = []
        for handle in list(self.lines_data_list or []) + list(self.lines_model_list or []):
            h_label = handle.get_label()
            if h_label is None or h_label.startswith('_'):
                continue
            handles.append(handle)

        self.sedplot.legend(handles=handles, loc='center left', bbox_to_anchor=(1.0, 0.5),
                            ncol=1, prop={'size': 10})

    def add_model_plot(self, model, label=None, color=None, line_style=None, flim=None,
                       auto_label=True, fit_range=None, update=True, lw=1.0, frame=None):
        """Plot a model SED on the SED panel.

        Parameters
        ----------
        model : object
            Model providing ``get_model_points`` (or ``SED.get_model_points``) and ``name``.
        label : str, optional
            Legend label; defaults to ``model.name`` when ``auto_label`` is ``True``.
        color : object, optional
            Matplotlib color specification.
        line_style : str, optional
            Matplotlib format string; defaults to ``'-'``.
        flim : float, optional
            Only points with y strictly greater than ``flim`` are drawn.
        auto_label : bool, optional
            If ``True`` and ``label`` is ``None``, derive a label automatically.
        fit_range : sequence of two floats, optional
            Only points with ``fit_range[0] < x < fit_range[1]`` are drawn.
        update : bool, optional
            If ``True``, refresh limits/legend after plotting.
        lw : float, optional
            Line width.
        frame : str, optional
            Must match the plot frame (``None`` uses it).

        Raises
        ------
        RuntimeError
            If the frame mismatches or the model points cannot be computed.
        """
        frame = self._check_frame(frame=frame)
        if hasattr(model, 'get_model_points'):
            try:
                x, y = model.get_model_points(log_log=False, frame=self.frame)
            except Exception as e:
                raise RuntimeError('for model', model.name, "problem with get_model_points()", e) from e
        else:
            try:
                x, y = model.SED.get_model_points(log_log=False, frame=self.frame)
            except Exception as e:
                raise RuntimeError('for model', model.name, "problem with SED.get_model_points()", e) from e

        if self.density is True:
            y = y / x

        if line_style is None:
            line_style = '-'

        if label is None and auto_label is True:
            if model.name is not None:
                label = model.name
            else:
                label = 'line %d' % self.counter

        if flim is not None:
            keep = y > flim
            x = x[keep]
            y = y[keep]

        if fit_range is not None:
            keep = (x > fit_range[0]) & (x < fit_range[1])
            x = x[keep]
            y = y[keep]

        line, = self.sedplot.plot(x, y, line_style, label=label, color=color, lw=lw)
        self.lines_model_list.append(line)

        if update is True:
            self.update_plot()

        self.counter += 1

    def plot_tempev_model(self, temp_ev, region, comp='Sum', frame=None, t1=None, t2=None,
                          time_slice=None, time_slice_bin=None, time=None, time_bin=None,
                          use_cached=False, sed_data=None, density=False, average=False):
        """Plot the SED snapshots of a time-evolving region.

        Intermediate snapshots are drawn thin (green when injection is active,
        blue when acceleration is active, red otherwise); the first and last
        snapshots are drawn thick and dashed and labelled with their time.

        Parameters
        ----------
        temp_ev : object
            Time-evolution object (provides ``custom_q_inj_profile``,
            ``custom_acc_profile`` and ``_get_time_slice_T_array``).
        region : object
            Region providing ``time_sampled_emitters`` and ``get_SED``.
        comp : str, optional
            SED component to plot.
        frame : str, optional
            Must match the plot frame (``None`` uses it).
        t1, t2 : float, optional
            Time window; clipped to the sampled blob times.
        time_slice : int, optional
            Starting time slice (mutually exclusive with ``time``).
        time_slice_bin : int, optional
            Time-slice step.
        time : float, optional
            Blob time in seconds (mutually exclusive with ``time_slice``).
        time_bin : float, optional
            Time step used together with ``time``.
        use_cached : bool, optional
            Forwarded to ``region.get_SED``.
        sed_data : object, optional
            Observational data to overlay after the models.
        density : bool, optional
            Currently unused.
        average : bool, optional
            Forwarded to ``region.get_SED``.

        Raises
        ------
        RuntimeError
            If both ``time_slice`` and ``time`` are given, or the frame mismatches.
        """
        frame = self._check_frame(frame)
        if time_slice is not None and time is not None:
            raise RuntimeError('you can pass either the N-th time slice "time_slice", or the blob time in seconds "time" ')

        time_blob = region.time_sampled_emitters.time_blob
        if t1 is None or t1 < time_blob[0]:
            t1 = time_blob[0]
        if t2 is None or t2 > time_blob[-1]:
            t2 = time_blob[-1]

        _time_slice = 0 if time_slice is None else time_slice
        _time_slice_bin = time_slice_bin
        if time_slice_bin is None and time_slice is None:
            _time_slice_bin = 1

        if time is not None and time_bin is not None:
            # NOTE(review): in this branch the value of `time` is not used,
            # only the [t1, t2) grid with step time_bin -- confirm intended.
            t_array = np.arange(t1, t2, time_bin)
            time_id_array = None
        elif time is not None:
            t_array = np.array([time])
            time_id_array = None
        else:
            t_array, time_id_array = region.time_sampled_emitters._get_time_samples(
                time_slice=_time_slice, time_slice_bin=_time_slice_bin)
            # One mask for both arrays keeps ids and times aligned.
            in_window = (t_array >= t1) & (t_array <= t2)
            time_id_array = time_id_array[in_window]
            t_array = t_array[in_window]

        g = plt.cm.Greens(np.linspace(0.5, 1, t_array.size))
        r = plt.cm.Reds(np.linspace(0.5, 1, t_array.size))
        b = plt.cm.Blues(np.linspace(0.5, 1, t_array.size))

        for ID, t in enumerate(t_array):
            if time is not None:
                s = region.get_SED(comp, frame=frame, time=t, use_cached=use_cached,
                                   time_bin=time_bin, average=average)
            else:
                s = region.get_SED(comp, frame=frame, time_slice=time_id_array[ID],
                                   use_cached=use_cached, time_slice_bin=time_slice_bin,
                                   average=average)
            # Reset per snapshot so styles never leak between iterations.
            label = None
            ls = '-'
            lw = 0.2
            color = r[ID]
            i_T = temp_ev._get_time_slice_T_array(t)
            if temp_ev.custom_q_inj_profile[i_T] > 0:
                color = g[ID]
            if temp_ev.custom_acc_profile[i_T] > 0:
                color = b[ID]
            if ID == 0:
                lw = 2
                ls = '--'
                label = 'start, t=%2.2e (s)' % t
                color = 'green'
            if ID == t_array.size - 1:
                lw = 2
                ls = '--'
                color = 'purple'
                label = 'stop, t=%2.2e (s)' % t

            self.add_model_plot(model=s, label=label, line_style=ls, color=color,
                                update=False, lw=lw, auto_label=False)

        if sed_data is not None:
            self.add_data_plot(sed_data)

        self.update_plot()

    def add_data_plot(self, sed_data, label=None, color=None, frame=None, fmt='o', ms=4, mew=0.5,
                      fit_range=None):
        """Plot observational SED points with error bars and upper limits.

        Parameters
        ----------
        sed_data : object
            Data container providing ``get_data_points``, ``data`` (with
            ``'nu_data'`` and ``'UL'``) and ``obj_name``.
        label : str, optional
            Legend label; defaults to ``sed_data.obj_name``.
        color : object, optional
            Matplotlib color specification.
        frame : str, optional
            Must match the plot frame (``None`` uses it).
        fmt : str, optional
            Marker format for ``errorbar``.
        ms : float, optional
            Marker size.
        mew : float, optional
            Marker edge width.
        fit_range : sequence of two floats, optional
            Only points with ``fit_range[0] < x < fit_range[1]`` are drawn.

        Raises
        ------
        RuntimeError
            If the frame mismatches or the data points cannot be retrieved.
        """
        self._sed_data = sed_data
        frame = self._check_frame(frame)
        try:
            x, y, dx, dy = sed_data.get_data_points(log_log=False, frame=self.frame, density=self.density)
        except Exception as e:
            raise RuntimeError("!!! ERROR failed to get data points from", sed_data, e) from e

        if dx is None:
            dx = np.zeros(len(sed_data.data['nu_data']))
        if dy is None:
            dy = np.zeros(len(sed_data.data['nu_data']))

        UL = sed_data.data['UL']

        if label is None:
            if sed_data.obj_name is not None:
                label = sed_data.obj_name
            else:
                label = 'line %d' % self.counter

        if fit_range is not None:
            keep = (x > fit_range[0]) & (x < fit_range[1])
            x = x[keep]
            y = y[keep]
            dx = dx[keep]
            dy = dy[keep]
            UL = UL[keep]

        line = self.sedplot.errorbar(x, y, xerr=dx, yerr=dy, fmt=fmt, uplims=UL, label=label,
                                     ms=ms, mew=mew, color=color)
        self.lines_data_list.append(line)
        self.counter += 1
        self.update_plot()

    def add_xy_plot(self, x, y, label=None, color=None, line_style=None, autoscale=False):
        """Plot an arbitrary x/y curve on the SED panel (tracked as a model line).

        Parameters
        ----------
        x, y : array-like
            Curve coordinates.
        label : str, optional
            Legend label; defaults to ``'line <counter>'``.
        color : object, optional
            Matplotlib color specification.
        line_style : str, optional
            Matplotlib format string; defaults to ``'-'``.
        autoscale : bool, optional
            Unused; kept for backward compatibility.
        """
        if line_style is None:
            line_style = '-'
        if label is None:
            label = 'line %d' % self.counter
        line, = self.sedplot.plot(x, y, line_style, label=label, color=color)
        self.lines_model_list.append(line)
        self.counter += 1
        self.update_plot()

    def add_model_residual_plot(self, model, data, label=None, color=None, filter_UL=True, fit_range=None):
        """Plot the residuals of ``model`` against ``data`` (no-op if ``data`` is ``None``).

        Parameters
        ----------
        model : object
            Model providing ``get_residuals``.
        data : object
            Data container passed to ``model.get_residuals``.
        label : str, optional
            Unused; kept for backward compatibility.
        color : object, optional
            Matplotlib color specification.
        filter_UL : bool, optional
            Forwarded to ``model.get_residuals``.
        fit_range : sequence of two floats, optional
            Only residuals with ``fit_range[0] < x < fit_range[1]`` are drawn.
        """
        if data is None:
            return
        x, y = model.get_residuals(log_log=False, data=data, filter_UL=filter_UL)
        self.add_xy_residual_plot(x=x, y=y, fit_range=fit_range, color=color)

    def add_xy_residual_plot(self, x, y, fit_range=None, color=None):
        """Plot residual points (unit error bars) on the residuals panel.

        The y=0 reference line is already drawn by the constructor, so it is
        not added again here.

        Parameters
        ----------
        x, y : numpy.ndarray
            Residual coordinates.
        fit_range : sequence of two floats, optional
            Only points with ``fit_range[0] < x < fit_range[1]`` are drawn.
        color : object, optional
            Matplotlib color specification.
        """
        if fit_range is not None:
            keep = (x > fit_range[0]) & (x < fit_range[1])
            x = x[keep]
            y = y[keep]

        line = self.resplot.errorbar(x, y, yerr=np.ones(x.size), fmt='+', color=color)
        self.lines_res_list.append(line)
        self.counter_res += 1
        self.update_plot()

    def add_text(self, lines):
        """Write a block of text in the lower-left corner of the SED panel.

        Parameters
        ----------
        lines : iterable of str
            Text lines; each is stripped and they are joined with newlines.
        """
        text = '\n'.join(line.strip() for line in lines)
        self.sedplot.text(0.02, 0.02, text, transform=self.sedplot.transAxes,
                          fontsize=10, ha='left', va='bottom')
        self.fig.canvas.draw()

    def save(self, filename=None):
        """Save the figure.

        Parameters
        ----------
        filename : str, optional
            Output path; defaults to ``jetset_fig.png`` inside ``self.out_dir``.
        """
        if filename is None:
            wd = self.out_dir
            filename = 'jetset_fig.png'
        else:
            wd = ''
        self.fig.savefig(os.path.join(wd, filename))

    def show(self):
        """Show the figure window."""
        self.fig.show()
class BasePlot(object):
    """Lightweight base wrapper around a single Matplotlib axis.

    Notes
    -----
    Provides common axis limit helpers and redraw/autoscale behavior reused
    by specialized JetSeT plotting classes.
    """

    def __init__(self, figsize=(8, 6), dpi=100):
        """Create a new `BasePlot` instance with one figure and one axis.

        Parameters
        ----------
        figsize : tuple, optional
            Figure size in inches.
        dpi : int, optional
            Figure resolution.
        """
        self.fig, self.ax = plt.subplots(figsize=figsize, dpi=dpi)

    def rescale(self, x_min=None, x_max=None, y_min=None, y_max=None):
        """Removed API: only warns and points to `setlim`; the limits are ignored."""
        _rescale(x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max)

    def setlim(self, x_min=None, x_max=None, y_min=None, y_max=None):
        """Set the axis limits (``None`` keeps the current bound).

        Parameters
        ----------
        x_min, x_max : float, optional
            x-axis limits.
        y_min, y_max : float, optional
            y-axis limits.
        """
        self.ax.set_xlim(x_min, x_max)
        self.ax.set_ylim(y_min, y_max)

    def update_plot(self):
        """Autoscale y to the current data, refresh legend and layout, and redraw."""
        self.fig.canvas.draw()
        self.ax.relim()
        self.ax.autoscale(axis='y')
        # Only build a legend if something is labelled: this avoids Matplotlib's
        # "No artists with labels found" warning and an empty legend box.
        handles, _ = self.ax.get_legend_handles_labels()
        if handles:
            self.ax.legend()
        self.fig.tight_layout()
class PlotSpectralMultipl(BasePlot):
    """Plot helper for spectral multiplicative terms in log-log space.

    Curves are drawn as ``log10(y)`` versus ``log10(nu)``. The secondary top
    axis shows ``log10(E / eV)`` for the same abscissa, using the 2.417E14
    Hz-per-eV factor (the same one used by ``y_ev_transf``).
    """

    _LOG10_HZ_PER_EV = np.log10(2.417E14)

    def __init__(self):
        """Create a new `PlotSpectralMultipl` instance."""
        super(PlotSpectralMultipl, self).__init__()
        shift = self._LOG10_HZ_PER_EV

        # The x axis holds log10(nu), so the Hz -> eV conversion is a shift in log space.
        def _log_nu_to_log_e(log_nu):
            return log_nu - shift

        def _log_e_to_log_nu(log_e):
            return log_e + shift

        secax = self.ax.secondary_xaxis('top', functions=(_log_nu_to_log_e, _log_e_to_log_nu))
        secax.set_xlabel('log(E) (eV)')

    def plot(self, nu, y, y_label, y_min=None, y_max=None, label=None, line_style=None, color=None):
        """Plot ``log10(y)`` against ``log10(nu)``.

        Parameters
        ----------
        nu : array-like
            Frequency values in Hz (must be positive).
        y : array-like
            Values to plot (must be positive).
        y_label : str
            y-axis label.
        y_min, y_max : float, optional
            y-axis limits in the plotted (log10) units; ``None`` keeps autoscaling.
        label : str, optional
            Legend label.
        line_style : str, optional
            Matplotlib line style.
        color : object, optional
            Matplotlib color specification.
        """
        self.ax.plot(np.log10(nu), np.log10(y), label=label, ls=line_style, color=color)
        self.ax.set_xlabel(r'log($ \nu $) (Hz)')
        self.ax.set_ylabel(y_label)
        self.update_plot()
        # Applied after update_plot(): its y autoscaling would otherwise discard them.
        self.ax.set_ylim(y_min, y_max)
class PlotPdistr(BasePlot):
    """Plotter for particle/injection energy distributions.

    Notes
    -----
    Supports multiple energy units, optional powers of energy (e.g. ``E^2 n``),
    and log10-valued or log-log-axis views for electrons and protons.
    """

    def __init__(self, figsize=(8, 6), dpi=100, injection=False, loglog=True):
        """Create a new `PlotPdistr` instance.

        Parameters
        ----------
        figsize : tuple, optional
            Figure size in inches.
        dpi : int, optional
            Figure resolution.
        injection : bool, optional
            If ``True``, label the y axis as an injection rate Q_inj.
        loglog : bool, optional
            If ``True``, plot log10 values on linear axes; otherwise plot raw
            values on log-log axes.
        """
        super(PlotPdistr, self).__init__(figsize=figsize, dpi=dpi)
        self.loglog = loglog
        self.injection = injection

    def _set_variable(self, gamma, n_gamma, particle, energy_unit, pow=None):
        """Convert a distribution to the plotted coordinates.

        Parameters
        ----------
        gamma : numpy.ndarray
            Lorentz-factor grid.
        n_gamma : numpy.ndarray
            Distribution values on ``gamma``.
        particle : str
            ``'electrons'`` or ``'protons'`` (used only when converting to energy).
        energy_unit : str
            ``'gamma'`` to keep Lorentz factors, otherwise an astropy energy unit.
        pow : float, optional
            If given, multiply y by x**pow.

        Returns
        -------
        tuple
            ``(x, y, energy_name, energy_units)`` restricted to points with y > 0,
            in log10 when ``self.loglog`` is ``True``.

        Raises
        ------
        RuntimeError
            If ``particle`` is not supported and an energy unit is requested.
        """
        if energy_unit == 'gamma':
            energy_name = r'\gamma'
            energy_units = ''
            x = gamma
            y = n_gamma
        else:
            energy_name = 'E'
            energy_units = '%s' % energy_unit
            if particle == 'electrons':
                rest_energy = (m_e * c * c).to(energy_unit).value
            elif particle == 'protons':
                rest_energy = (m_p * c * c).to(energy_unit).value
            else:
                raise RuntimeError('particle ', particle, 'not implemented')
            x = gamma * rest_energy
            # dN/dE = dN/dgamma / (m c^2)
            y = n_gamma * 1.0 / rest_energy

        m = y > 0
        x = np.copy(x)
        y = np.copy(y)
        if pow is not None:
            y[m] = y[m] * np.power(x[m], pow)

        if self.loglog is True:
            x[m] = np.log10(x[m])
            y[m] = np.log10(y[m])

        return x[m], y[m], energy_name, energy_units

    def _set_xy_label(self, energy_name, energy_units, pow):
        """Set axis labels for the given energy variable, units and power of energy."""
        x_units = '(%s)' % energy_units if energy_units != '' else ''
        if self.loglog is True:
            self.ax.set_xlabel(r'log($%s$) %s' % (energy_name, x_units))
        else:
            self.ax.set_xlabel(r'$%s$ %s' % (energy_name, x_units))

        y_units = '%s^{-1}' % energy_units if energy_units != '' else ''
        n_str = 'n($%s$)' % energy_name
        if pow is not None and pow != 0:
            n_str = 'n($%s$) $%s^{%d}$' % (energy_name, energy_name, pow)
            if energy_units != '':
                exponent = -1 + pow
                if exponent == 0:
                    # E^1 cancels the per-energy unit of n(E).
                    y_units = ''
                elif exponent == 1:
                    y_units = '%s' % energy_units
                else:
                    y_units = '%s^{%d}' % (energy_units, exponent)

        if self.injection is False:
            if self.loglog is True:
                self.ax.set_ylabel(r'log(%s) ($cm^{-3} %s$) ' % (n_str, y_units))
            else:
                self.ax.set_ylabel(r'%s ($cm^{-3} %s$) ' % (n_str, y_units))
        else:
            if self.loglog is True:
                self.ax.set_ylabel(r'log(Q$_{inj}$($%s$)) ($cm^{-3} s^{-1} %s$)' % (energy_name, y_units))
            else:
                self.ax.set_ylabel(r'Q$_{inj}$($%s$) ($cm^{-3} s^{-1} %s$)' % (energy_name, y_units))

    def _plot(self, x, y, c=None, lw=None, ls=None, label=None):
        """Draw one curve: linear axes for log10 data, log-log axes otherwise."""
        if self.loglog is True:
            self.ax.plot(x, y, c=c, lw=lw, label=label, ls=ls)
        else:
            self.ax.loglog(x, y, c=c, lw=lw, label=label, ls=ls)

    def _plot_powered(self, gamma, n_gamma, y_min, y_max, x_min, x_max, particle, energy_unit,
                      label, pow):
        """Shared body of plot_distr/plot_distr2p/plot_distr3p."""
        if label is None:
            label = particle
        x, y, energy_name, energy_units = self._set_variable(gamma, n_gamma, particle, energy_unit, pow=pow)
        self._plot(x, y, label=label)
        self._set_xy_label(energy_name, energy_units, pow=pow)
        self.update_plot()
        self.ax.set_ylim(y_min, y_max)
        self.ax.set_xlim(x_min, x_max)

    def plot_distr(self, gamma, n_gamma, y_min=None, y_max=None, x_min=None, x_max=None,
                   particle='electrons', energy_unit='gamma', label=None):
        """Plot n(gamma) or n(E).

        Parameters
        ----------
        gamma, n_gamma : numpy.ndarray
            Lorentz-factor grid and distribution values.
        y_min, y_max, x_min, x_max : float, optional
            Axis limits in plotted units (``None`` keeps the current bound).
        particle : str, optional
            ``'electrons'`` or ``'protons'``.
        energy_unit : str, optional
            ``'gamma'`` or an astropy energy unit.
        label : str, optional
            Legend label; defaults to ``particle``.
        """
        self._plot_powered(gamma, n_gamma, y_min, y_max, x_min, x_max, particle, energy_unit, label, None)

    def plot_distr2p(self, gamma, n_gamma, y_min=None, y_max=None, x_min=None, x_max=None,
                     particle='electrons', energy_unit='gamma', label=None):
        """Plot n(x) * x^2 (same parameters as `plot_distr`)."""
        self._plot_powered(gamma, n_gamma, y_min, y_max, x_min, x_max, particle, energy_unit, label, 2)

    def plot_distr3p(self, gamma, n_gamma, y_min=None, y_max=None, x_min=None, x_max=None,
                     particle='electrons', energy_unit='gamma', label=None):
        """Plot n(x) * x^3 (same parameters as `plot_distr`)."""
        self._plot_powered(gamma, n_gamma, y_min, y_max, x_min, x_max, particle, energy_unit, label, 3)

    def update_plot(self):
        """Autoscale both axes to the current data, refresh legend and layout, redraw."""
        self.fig.canvas.draw()
        self.ax.relim()
        self.ax.autoscale(axis='y')
        self.ax.autoscale(axis='x')
        handles, _ = self.ax.get_legend_handles_labels()
        if handles:
            self.ax.legend()
        self.fig.tight_layout()
class PlotTempEvEmitters(PlotPdistr):
    """Emitter-distribution plotter specialized for time-evolution outputs."""

    def __init__(self, figsize=(8, 6), dpi=100, loglog=True):
        """Create a new `PlotTempEvEmitters` instance.

        Parameters
        ----------
        figsize : tuple, optional
            Figure size in inches.
        dpi : int, optional
            Figure resolution.
        loglog : bool, optional
            If ``True``, plot log10 values on linear axes; otherwise raw values
            on log-log axes.
        """
        super(PlotTempEvEmitters, self).__init__(figsize=figsize, dpi=dpi, loglog=loglog)

    def _plot_distr(self, temp_ev, region, particle='electrons', energy_unit='gamma', pow=None,
                    plot_Q_inj=True, t1=None, t2=None):
        """Plot the sampled emitter distributions of ``region`` over time.

        Intermediate samples are drawn thin (green when injection is active,
        blue when acceleration is active, red otherwise); the first and last
        samples are drawn thick and dashed. Optionally overlays Q_inj * delta_t.

        Parameters
        ----------
        temp_ev : object
            Time-evolution object (injection/acceleration profiles, ``Q_inj``,
            ``delta_t``).
        region : object
            Region providing ``time_sampled_emitters``.
            NOTE(review): public wrappers default ``region='acc'`` (a string),
            which has no ``time_sampled_emitters`` -- callers must pass an object.
        particle : str, optional
            Particle species used for energy conversion.
        energy_unit : str, optional
            ``'gamma'`` or an astropy energy unit.
        pow : float, optional
            Power of energy multiplying the distributions.
        plot_Q_inj : bool, optional
            If ``True``, overlay the injection term.
        t1, t2 : float, optional
            Time window; defaults to the full sampled range.
        """
        emitters = region.time_sampled_emitters
        if t1 is None:
            t1 = emitters.time_blob[0]
        if t2 is None:
            t2 = emitters.time_blob[-1]

        t_array = np.linspace(t1, t2, emitters.time_blob.size)

        g = plt.cm.Greens(np.linspace(0.5, 1, t_array.size))
        r = plt.cm.Reds(np.linspace(0.5, 1, t_array.size))
        b = plt.cm.Blues(np.linspace(0.5, 1, t_array.size))
        n = emitters.n_gamma

        for ID, t in enumerate(t_array):
            # Reset per sample so styles never leak between iterations.
            label = None
            ls = '-'
            lw = 0.2
            color = r[ID]
            i_T = temp_ev._get_time_slice_T_array(t)
            if temp_ev.custom_q_inj_profile[i_T] > 0:
                color = g[ID]
            if temp_ev.custom_acc_profile[i_T] > 0:
                color = b[ID]
            if ID == 0:
                lw = 2
                ls = '--'
                label = 'start, t=%2.2e (s)' % t
                color = 'green'
            if ID == t_array.size - 1:
                lw = 2
                ls = '--'
                color = 'purple'
                label = 'stop, t=%2.2e (s)' % t

            # NOTE(review): n[ID] is the ID-th stored sample while t is taken from a
            # uniform grid on [t1, t2]; they match only for the default window.
            x, y, energy_name, energy_units = self._set_variable(emitters.gamma, n[ID], particle,
                                                                 energy_unit, pow=pow)
            self._plot(x, y, c=color, lw=lw, label=label, ls=ls)

        self._set_xy_label(energy_name, energy_units, pow=pow)

        # TODO move to plot inj if inj is used in region
        if temp_ev.Q_inj is not None and (region._region_type == 'acc' or temp_ev._only_radiation is True):
            if plot_Q_inj is True:
                # Same transformation as the distributions (units, power, log10),
                # so the overlay is drawn on the same axes scale.
                x, y, _, _ = self._set_variable(temp_ev.Q_inj.gamma_e,
                                                temp_ev.Q_inj.n_gamma_e * temp_ev.delta_t,
                                                particle, energy_unit, pow=pow)
                self._plot(x, y, c='red', lw=1, label='$Q_{inj}$ delta t')

        self.ax.legend()

    def plot_distr(self, temp_ev, region='acc', energy_unit='gamma', plot_Q_inj=True, pow=None):
        """Plot the electron distributions over time.

        Parameters
        ----------
        temp_ev : object
            Time-evolution object.
        region : object, optional
            Region with ``time_sampled_emitters``.
        energy_unit : str, optional
            ``'gamma'`` or an astropy energy unit.
        plot_Q_inj : bool, optional
            If ``True``, overlay the injection term.
        pow : float, optional
            Power of energy multiplying the distributions.
        """
        self._plot_distr(temp_ev, region=region, particle='electrons', energy_unit=energy_unit,
                         pow=pow, plot_Q_inj=plot_Q_inj)

    def plot_distr2p(self, temp_ev, region='acc', energy_unit='gamma', plot_Q_inj=True):
        """Plot x^2 n(x) of the electron distributions over time (see `plot_distr`)."""
        self._plot_distr(temp_ev, region=region, particle='electrons', energy_unit=energy_unit,
                         pow=2, plot_Q_inj=plot_Q_inj)

    def plot_distr3p(self, temp_ev, region='acc', energy_unit='gamma', plot_Q_inj=True):
        """Plot x^3 n(x) of the electron distributions over time (see `plot_distr`)."""
        self._plot_distr(temp_ev, region=region, particle='electrons', energy_unit=energy_unit,
                         pow=3, plot_Q_inj=plot_Q_inj)
class PlotTempEvDiagram(object):
    """Diagnostic four-panel summary for time-evolution control profiles.

    Panels: 0 acceleration profile, 1 injection profile (both vs blob time,
    sharing x); 2 region size, 3 magnetic field (both vs log BH distance,
    sharing x).
    """

    def __init__(self, figsize=(8, 6), dpi=100, expanding_region=False):
        """Create a new `PlotTempEvDiagram` instance.

        Parameters
        ----------
        figsize : tuple, optional
            Figure size in inches.
        dpi : int, optional
            Figure resolution.
        expanding_region : bool, optional
            Kept for backward compatibility; the layout always has 4 rows.
        """
        n_rows = 4
        self.fig, self.axs = plt.subplots(n_rows, 1, figsize=figsize, dpi=dpi, sharex=False)
        # Share x within each pair of panels once, here, so repeated plot() calls are safe.
        self.axs[1].sharex(self.axs[0])
        self.axs[3].sharex(self.axs[2])

    def plot(self, T_array, inj_profile, acc_profile, R_exp, B_exp, R_H_exp):
        """Draw the four diagnostic panels.

        Parameters
        ----------
        T_array : array-like
            Blob-frame times in seconds.
        inj_profile : array-like
            Injection profile on ``T_array``.
        acc_profile : array-like
            Acceleration start/stop profile on ``T_array``.
        R_exp : array-like
            Region size in cm (plotted as log10).
        B_exp : array-like
            Magnetic field in G (plotted as log10).
        R_H_exp : array-like
            Distance from the black hole in cm (plotted as log10).
        """
        self.axs[0].plot(T_array, acc_profile, label='Acc. start/stop', c='g')
        self.axs[0].set_ylim(0, 1.5)

        self.axs[1].plot(T_array, inj_profile, label='Inj. profile', c='b')
        self.axs[1].set_ylim(0, None)
        self.axs[1].set_xlabel('Time in blob frame (s)')

        self.axs[2].plot(np.log10(R_H_exp), np.log10(R_exp), label='Rad. region size', c='orange')
        self.axs[2].set_ylabel('log(R_rad) (cm)')

        self.axs[3].plot(np.log10(R_H_exp), np.log10(B_exp), label='B Rad. region', c='green')
        self.axs[3].set_xlabel('log(BH distance) in observer frame (cm)')
        self.axs[3].set_ylabel('log(B) (G)')

        for ax in self.axs:
            ax.legend()

        self.fig.tight_layout()

    def rescale(self, x_min=None, x_max=None, y_min=None, y_max=None):
        """Removed API: only warns and points to `setlim`; the limits are ignored."""
        _rescale(x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max)

    def setlim(self, x_min=None, x_max=None, y_min=None, y_max=None, ax_id=0):
        """Set the limits of one panel (``None`` keeps the current bound).

        Parameters
        ----------
        x_min, x_max, y_min, y_max : float, optional
            Axis limits.
        ax_id : int, optional
            Index of the panel to modify (0-3); defaults to the first panel.
        """
        ax = self.axs[ax_id]
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
class PlotSpecComp(BasePlot):
    """Quick-look plotter for individual spectral components (log10 nuFnu vs log10 nu)."""

    def __init__(self):
        """Create a new `PlotSpecComp` instance."""
        super(PlotSpecComp, self).__init__()

    def plot(self, nu, nuFnu, y_min=None, y_max=None):
        """Plot ``log10(nuFnu)`` against ``log10(nu)``.

        Parameters
        ----------
        nu : array-like
            Frequency values in Hz (must be positive).
        nuFnu : array-like
            nu F_nu values (must be positive).
        y_min, y_max : float, optional
            y-axis limits in log10 units; ``None`` keeps autoscaling.
        """
        self.ax.plot(np.log10(nu), np.log10(nuFnu))
        self.ax.set_xlabel(r'log($ \nu $) (Hz)')
        self.ax.set_ylabel(r'log($ \nu F_{\nu} $ ) (erg cm$^{-2}$ s$^{-1}$)')
        self.update_plot()
        # Applied after update_plot(): its y autoscaling would otherwise discard them.
        self.ax.set_ylim(y_min, y_max)
class PlotSeedPhotons(BasePlot):
    """Quick-look plotter for seed-photon number density spectra (log10 n vs log10 nu)."""

    def __init__(self):
        """Create a new `PlotSeedPhotons` instance."""
        super(PlotSeedPhotons, self).__init__()

    def plot(self, nu, nuFnu, y_min=None, y_max=None):
        """Plot ``log10(nuFnu)`` (photon number density) against ``log10(nu)``.

        Parameters
        ----------
        nu : array-like
            Frequency values in Hz (must be positive).
        nuFnu : array-like
            Photon number density values (must be positive).
        y_min, y_max : float, optional
            y-axis limits in log10 units; ``None`` keeps autoscaling.
        """
        self.ax.plot(np.log10(nu), np.log10(nuFnu))
        self.ax.set_xlabel(r'log($ \nu $) (Hz)')
        self.ax.set_ylabel(r'log(n ) (photons cm$^{-3}$ Hz$^{-1}$ ster$^{-1}$)')
        self.update_plot()
        # Applied after update_plot(): its y autoscaling would otherwise discard them.
        self.ax.set_ylim(y_min, y_max)
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Create a heatmap from a 2D array and two lists of labels.

    Parameters
    ----------
    data
        A 2D array-like of shape (M, N). Nested lists/tuples are converted
        with `numpy.asanyarray`; ndarrays and masked arrays pass through
        unchanged.
    row_labels
        A list or array of length M with the labels for the rows.
    col_labels
        A list or array of length N with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, the current Axes (``plt.gca()``) is used. Optional.
    cbar_kw
        A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
    cbarlabel
        The label for the colorbar. Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    im
        The image returned by `ax.imshow`.
    cbar
        The colorbar returned by `ax.figure.colorbar`.
    """
    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    # The tick set-up below needs `.shape`, so accept plain nested sequences
    # too; asanyarray keeps ndarray/masked-array inputs as the same object.
    data = np.asanyarray(data)

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and create a white grid on the minor ticks placed at
    # the cell borders (half a cell off the major, centered ticks).
    ax.spines[:].set_visible(False)

    ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
    ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar


def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Annotate each cell of a heatmap with its value.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate. If None, the image's data (``im.get_array()``)
        is used. Any other array-like (list, tuple, ndarray, masked array) is
        converted with `numpy.asanyarray`. Optional.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter` (any callable ``f(value, pos)`` works).
        Optional.
    textcolors
        A pair of colors. The first is used for values at or below the
        threshold, the second for those above. Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied. If None (the default), half of the normalized maximum of
        `data` is used. That is the middle of the colormap when the norm is
        autoscaled to the data. Optional.
    **textkw
        All other keyword arguments are forwarded to each call to `text` used
        to create the text labels (they override the default centering).

    Returns
    -------
    list
        The objects returned by ``im.axes.text``, in row-major order.
    """
    if data is None:
        data = im.get_array()
    else:
        # Lists/tuples have no `.max()`/`.shape`; convert them (keeping any
        # mask) instead of crashing or silently ignoring the caller's data.
        data = np.asanyarray(data)

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            # bool -> 0/1 index into textcolors (above threshold -> second).
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts