Source code for multineas.plot


import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from .util import Util
from .version import __version__

def multineas_watermark(ax, enlarge=1, alpha=0.5):
    """Add a MultiNEAs watermark to a 2d or 3d plot.

    The mark reads ``"MultiNEAs <version>"``. It is drawn in pink, rotated by
    270 degrees, and anchored at axes coordinates ``(1, 1)``, the upper-right
    corner of the axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes (2d or 3d) where the MultiNEAs mark will be placed.
    enlarge : float, optional
        Multiplicative factor applied to the font size, default 1.
    alpha : float, optional
        Transparency of the mark, default 0.5.

    Returns
    -------
    text : matplotlib.text.Text
        The text artist created for the watermark.
    """
    # Height of the axes in inches: the font size scales with it, so the mark
    # keeps a similar relative size in small and large panels.
    to_inches = ax.get_figure().dpi_scale_trans.inverted()
    axh = ax.get_window_extent().transformed(to_inches).height
    fig_factor = axh / 4

    # Options of the watermark
    args = dict(
        rotation=270, ha='left', va='top',
        transform=ax.transAxes, color='pink',
        fontsize=8 * fig_factor * enlarge, zorder=100, alpha=alpha,
    )

    # Text of the watermark
    mark = f"MultiNEAs {__version__}"

    # Choose the text method according to whether this is a 2d or a 3d plot:
    # axes providing ``add_collection3d`` are treated as 3d and use ``text2D``
    # so the mark is positioned in 2d axes coordinates.
    try:
        ax.add_collection3d
        plt_text = ax.text2D
    except AttributeError:
        plt_text = ax.text

    return plt_text(1, 1, mark, **args)
class CornerPlot(object):
    """
    Create a grid of plots showing the projection of a N-dimensional data.

    Parameters
    ----------
    properties : dict
        Properties to be shown, dictionary of dictionaries (N entries,
        N >= 2). Keys are the labels of the attributes, ex. "q".
        Dictionary values:

            * label: label used in axis, string
            * range: range for property, tuple (2), or None for no fixed
              range.

    figsize : int, optional
        Base size for panels (the size of figure will be M x figsize),
        default 3.
    fontsize : int, optional
        Base fontsize, default 10.
    direction : str, optional
        Direction of ticks in panels, default 'out'.

    Raises
    ------
    ValueError
        If fewer than 2 properties are given.

    Attributes
    ----------
    N : int
        Number of properties.
    M : int
        Size of grid matrix (M=N-1).
    fw : int
        Base panel size (figsize).
    fs : int
        Base fontsize.
    fig : matplotlib.figure.Figure
        Figure handle.
    axs : numpy.ndarray
        Matrix with subplots, axes handles (MxM).
    axp : dict
        Subplots indexed by property labels, dictionary of dictionaries:
        ``axp[p][q]`` and ``axp[q][p]`` are the same panel, the one that
        shows properties ``p`` and ``q``.
    properties : list
        List of properties labels, list of strings (N).
    dproperties : dict
        The ``properties`` dictionary given at construction.
    constrained : bool
        Whether the figure was created with constrained_layout.
    single : bool
        Whether the grid has a single panel (N=2).

    Methods
    -------
    tight_layout()
        Tight layout if no constrained_layout was used.
    set_labels(**args)
        Set labels parameters.
    set_ranges()
        Set ranges in panels according to ranges defined in dproperties.
    set_tick_params(**args)
        Set tick parameters.
    plot_hist(data, colorbar=False, **args)
        Create a 2d-histograms of data on all panels of the CornerPlot.
    scatter_plot(data, **args)
        Scatter plot on all panels of the CornerPlot.
    """

    def __init__(self, properties, figsize=3, fontsize=10, direction='out'):
        # Basic attributes
        self.dproperties = properties
        self.properties = list(properties.keys())

        # Secondary attributes
        self.N = len(properties)
        self.M = self.N - 1
        if self.M < 1:
            # At least one pair of properties is needed to draw a panel.
            raise ValueError(
                f"CornerPlot needs at least 2 properties, got {self.N}"
            )

        # Optional properties
        self.fw = figsize
        self.fs = fontsize

        # Free-standing artists drawn by set_labels and by the plotting
        # methods; tracked so they are replaced instead of stacked when the
        # methods are called again.
        self._label_texts = []
        self._watermark = None

        self._create_figure()
        self._build_named_axes()
        self._setup_panels(direction)

        # Set properties of panels
        self.set_labels()
        self.set_ranges()
        self.set_tick_params()
        self.tight_layout()

    def _create_figure(self):
        """Create the figure and the (M x M) matrix of axes ``self.axs``."""
        size = (self.M * self.fw, self.M * self.fw)
        # squeeze=False always yields a 2d array of axes, also when M == 1.
        common = dict(figsize=size, sharex="col", sharey="row", squeeze=False)
        try:
            self.fig, self.axs = plt.subplots(
                self.M, self.M, constrained_layout=True, **common
            )
            self.constrained = True
        except TypeError:
            # Presumably a matplotlib version that does not accept
            # constrained_layout: fall back to a plain figure.
            self.fig, self.axs = plt.subplots(self.M, self.M, **common)
            self.constrained = False
        self.single = self.M == 1

    def _build_named_axes(self):
        """Fill ``self.axp`` so both ``axp[p][q]`` and ``axp[q][p]`` point to
        the panel of the pair of properties (p, q)."""
        self.axp = {prop: {} for prop in self.properties}
        for j, propj in enumerate(self.properties):
            for i in range(j + 1, self.N):
                propi = self.properties[i]
                # Column j shows property j on x; row i-1 shows property i on y.
                panel = self.axs[i - 1][j]
                self.axp[propj][propi] = panel
                self.axp[propi][propj] = panel

    def _setup_panels(self, direction):
        """Hide the upper-triangle panels and set the tick directions."""
        for i in range(self.M):
            for j in range(self.M):
                if j > i:
                    # Upper triangle holds no data (set_labels writes the
                    # property labels on the panels next to the diagonal).
                    self.axs[i][j].axis("off")
                else:
                    self.axs[i][j].tick_params(axis='both', direction=direction)
        # Ticks on the left column and the bottom row always point outwards,
        # whatever the requested direction.
        for i in range(self.M):
            self.axs[i][0].tick_params(axis='y', direction="out")
            self.axs[self.M - 1][i].tick_params(axis='x', direction="out")

    def tight_layout(self):
        """
        Tight layout if no constrained_layout was used.

        Attr. [HC]
        """
        if not self.constrained:
            self.fig.subplots_adjust(wspace=self.fw / 100., hspace=self.fw / 100.)
            self.fig.tight_layout()

    def set_tick_params(self, **args):
        """
        Set tick parameters.

        Parameters
        ----------
        **args : dict
            Same arguments as tick_params method.

        Attr. [HC]
        """
        opts = dict(axis='both', which='major', labelsize=0.8 * self.fs)
        opts.update(args)
        for row in self.axs:
            for ax in row:
                ax.tick_params(**opts)

    def set_ranges(self):
        """
        Set ranges in panels according to ranges defined in dproperties.

        Properties whose range is None are left untouched.

        Attr. [HC]
        """
        for i, propi in enumerate(self.properties):
            xlim = self.dproperties[propi]["range"]
            for propj in self.properties[i + 1:]:
                ylim = self.dproperties[propj]["range"]
                ax = self.axp[propi][propj]
                if xlim is not None:
                    ax.set_xlim(xlim)
                if ylim is not None:
                    ax.set_ylim(ylim)

    @staticmethod
    def _remove_artist(artist):
        """Remove ``artist`` from its axes, ignoring artists that were
        already detached (e.g. because the user cleared the axes)."""
        try:
            artist.remove()
        except (ValueError, NotImplementedError):
            pass

    def set_labels(self, **args):
        """
        Set labels parameters.

        Labels drawn as free text by a previous call are removed first, so
        calling this method again restyles the labels instead of stacking
        new copies over the old ones.

        Parameters
        ----------
        **args : dict
            Common arguments of set_xlabel, set_ylabel and text.

        Attr. [HC]
        """
        opts = dict(fontsize=self.fs)
        opts.update(args)

        for text in getattr(self, "_label_texts", []):
            self._remove_artist(text)
        texts = self._label_texts = []

        # x labels on the bottom row, y labels on the first column
        for i, prop in enumerate(self.properties[:-1]):
            label = self.dproperties[prop]["label"]
            self.axs[self.M - 1][i].set_xlabel(label, **opts)
        for i, prop in enumerate(self.properties[1:]):
            label = self.dproperties[prop]["label"]
            self.axs[i][0].set_ylabel(label, rotation=90, labelpad=10, **opts)

        # Labels on the (hidden) panels right next to the diagonal
        for i in range(1, self.M):
            ax = self.axs[i - 1][i]
            label = self.dproperties[self.properties[i]]["label"]
            texts.append(ax.text(0.5, 0.0, label, ha='center',
                                 transform=ax.transAxes, **opts))
            # 270 if you want rotation
            texts.append(ax.text(0.0, 0.5, label, rotation=270, va='center',
                                 transform=ax.transAxes, **opts))

        if not self.single:
            ax = self.axs[0][1]
            label = self.dproperties[self.properties[0]]["label"]
            texts.append(ax.text(0.0, 1.0, label, rotation=0, ha='left', va='top',
                                 transform=ax.transAxes, **opts))

        ax = self.axs[-1][-1]
        label = self.dproperties[self.properties[-1]]["label"]
        # 270 if you want rotation
        texts.append(ax.text(1.05, 0.5, label, rotation=270, ha='left', va='center',
                             transform=ax.transAxes, **opts))
        self.tight_layout()

    def _add_watermark(self):
        """Draw the MultiNEAs watermark on the top-left panel, replacing the
        one previously drawn by this CornerPlot (if any)."""
        old = getattr(self, "_watermark", None)
        if old is not None:
            self._remove_artist(old)
        self._watermark = multineas_watermark(self.axs[0][0])

    def _data_range(self, prop, values):
        """Return (min, max) for ``prop``: its configured range when given,
        otherwise the extremes of ``values``."""
        prange = self.dproperties[prop]["range"]
        if prange is not None:
            vmin, vmax = prange
            return vmin, vmax
        return values.min(), values.max()

    def _add_hist_colorbar(self, ax, counts, image):
        """Add a horizontal colorbar on top of panel ``ax`` for a 2d histogram
        with bin ``counts`` drawn as ``image``."""
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("top", size="9%", pad=0.1)
        self.fig.add_axes(cax)
        cticks = np.linspace(counts.min(), counts.max(), 10)[2:-1]
        self.fig.colorbar(image, ax=ax, cax=cax,
                          orientation="horizontal", ticks=cticks)
        cax.xaxis.set_tick_params(labelsize=0.5 * self.fs, direction="in",
                                  pad=-0.8 * self.fs)
        # Tick labels are divided by a common power of ten, which is written
        # once inside the bar. Util.mantisa_exp presumably returns
        # (mantissa, exponent) of its argument -- confirm in .util.
        xt = cax.get_xticks()
        _, e = Util.mantisa_exp(xt.mean())
        cax.set_xticklabels(["%.1f" % (x / 10**e) for x in xt])
        cax.text(0, 0.5, r"$\times 10^{%d}$" % e, ha="left", va="center",
                 transform=cax.transAxes, fontsize=6, color='w')

    def plot_hist(self, data, colorbar=False, **args):
        """
        Create a 2d-histograms of data on all panels of the CornerPlot.

        Parameters
        ----------
        data : numpy.ndarray or array-like
            Data to be histogramed (n=len(data)), array (nxN).
        colorbar : bool, optional
            Include a colorbar? (default False).
        **args : dict
            All arguments of hist2d method (``range`` is always overridden
            with the ranges of the properties / data).

        Returns
        -------
        hist : list
            List of the image artists returned by hist2d, one per panel.

        Examples
        --------
        >>> properties = {
        ...     'Q': {'label': r"$Q$", 'range': None},
        ...     'E': {'label': r"$C$", 'range': None},
        ...     'I': {'label': r"$I$", 'range': None},
        ... }
        >>> G = mm.CornerPlot(properties, figsize=3)
        >>> hargs = dict(bins=100, cmap='viridis')
        >>> hist = G.plot_hist(udata, **hargs)

        Attr. [HC]
        """
        # asanyarray keeps ndarray subclasses (e.g. masked arrays) intact.
        data = np.asanyarray(data)
        opts = dict(args)
        hist = []
        for i, propi in enumerate(self.properties):
            xmin, xmax = self._data_range(propi, data[:, i])
            for j in range(i + 1, self.N):
                propj = self.properties[j]
                ymin, ymax = self._data_range(propj, data[:, j])
                ax = self.axp[propi][propj]
                opts["range"] = [[xmin, xmax], [ymin, ymax]]
                counts, _, _, image = ax.hist2d(data[:, i], data[:, j], **opts)
                hist.append(image)
                if colorbar:
                    self._add_hist_colorbar(ax, counts, image)

        self.set_labels()
        self.set_ranges()
        self.set_tick_params()
        self.tight_layout()
        self._add_watermark()
        return hist

    def scatter_plot(self, data, **args):
        """
        Scatter plot on all panels of the CornerPlot.

        Parameters
        ----------
        data : numpy.ndarray or array-like
            Data to be plotted (n=len(data)), array (nxN).
        **args : dict
            All arguments of scatter method.

        Returns
        -------
        scatter : list
            List of scatter instances, one per panel.

        Examples
        --------
        >>> sargs = dict(s=0.2, edgecolor='None', color='r')
        >>> hist = G.scatter_plot(udata, **sargs)

        Attr. [HC]
        """
        data = np.asanyarray(data)
        scatter = []
        for i, propi in enumerate(self.properties):
            for j in range(i + 1, self.N):
                propj = self.properties[j]
                scatter.append(
                    self.axp[propi][propj].scatter(data[:, i], data[:, j], **args)
                )

        self.set_labels()
        self.set_ranges()
        self.set_tick_params()
        self.tight_layout()
        self._add_watermark()
        return scatter