import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from .util import Util
from .version import __version__
def multineas_watermark(ax, enlarge=1, alpha=0.5):
    """Add a MultiNEAs watermark to a 2d or 3d plot.

    The mark reads "MultiNEAs <version>", is drawn in pink, rotated 270
    degrees and anchored at the top-right corner of the axes; its font size
    scales with the height of the axes.

    Parameters:
        ax: Class axes:
            Axes where the MultiNEAs mark will be placed.
        enlarge: float, optional
            Extra factor applied to the font size of the mark (default 1).
        alpha: float, optional
            Opacity of the mark (default 0.5).

    Returns:
        text: matplotlib.text.Text
            The text artist holding the mark.
    """
    # Height of the axes in inches (display extent mapped back through the
    # figure dpi transform); the font is 8 pt for a 4-inch tall axes and
    # scales linearly with it.
    fig = ax.get_figure()
    axh = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).height
    fig_factor = axh / 4

    # Options of the water mark (axes coordinates, drawn above other artists)
    args = dict(
        rotation=270, ha='left', va='top',
        transform=ax.transAxes, color='pink',
        fontsize=8 * fig_factor * enlarge, zorder=100, alpha=alpha,
    )
    mark = f"MultiNEAs {__version__}"

    # 3d axes (detected by add_collection3d) need text2D to place text in
    # 2d axes coordinates; any other axes uses the regular text method.
    try:
        ax.add_collection3d
        plt_text = ax.text2D
    except AttributeError:
        plt_text = ax.text
    return plt_text(1, 1, mark, **args)
class CornerPlot(object):
    """
    Create a grid of plots showing the projection of a N-dimensional data.

    Parameters
    ----------
    properties : dict
        Properties to be shown, dictionary of dictionaries (N entries, N >= 2).
        Keys are label of attribute, ex. "q".
        Dictionary values:
            * label: label used in axis, string (defaults to the key).
            * range: range for property, tuple (2), or None to use the data
              range (defaults to None).
    figsize : int, optional
        Base size for panels (the size of figure will be M x figsize), default 3.
    fontsize : int, optional
        Base fontsize, default 10.
    direction : str, optional
        Direction of ticks in panels, default 'out'. Ticks on the outer left
        column (y) and bottom row (x) always point 'out'.

    Raises
    ------
    ValueError
        If fewer than two properties are given.

    Attributes
    ----------
    N : int
        Number of properties.
    M : int
        Size of grid matrix (M=N-1).
    fw : int
        Figsize.
    fs : int
        Fontsize.
    fig : matplotlib.figure.Figure
        Figure handle.
    axs : numpy.ndarray
        Matrix with subplots, axes handles (MxM).
    axp : dict
        Panels indexed by pairs of property keys, dictionary of dictionaries.
        axp[p][q] and axp[q][p] are the same panel; the property listed first
        in ``properties`` is on its x-axis.
    properties : list
        List of properties labels, list of strings (N).
    dproperties : dict
        The ``properties`` dictionary given to the constructor.
    constrained : bool
        Whether the figure was created with constrained_layout.
    single : bool
        Whether the grid has a single panel (N=2).

    Methods
    -------
    tight_layout()
        Tight layout if no constrained_layout was used.
    set_labels(**args)
        Set labels parameters.
    set_ranges()
        Set ranges in panels according to ranges defined in dproperties.
    set_tick_params(**args)
        Set tick parameters.
    plot_hist(data, colorbar=False, **args)
        Create a 2d-histograms of data on all panels of the CornerPlot.
    scatter_plot(data, **args)
        Scatter plot on all panels of the CornerPlot.
    """

    def __init__(self, properties, figsize=3, fontsize=10, direction='out'):
        # Basic attributes
        self.dproperties = properties
        self.properties = list(properties.keys())

        # Secondary attributes
        self.N = len(properties)
        self.M = self.N - 1
        if self.M < 1:
            raise ValueError(
                f"CornerPlot needs at least 2 properties, got {self.N}"
            )

        # Optional properties
        self.fw = figsize
        self.fs = fontsize

        # Text artists created by set_labels and the watermark; tracked so
        # that later calls replace them instead of stacking duplicates.
        self._label_texts = []
        self._watermark = None

        # Create figure and axes. squeeze=False always yields an (M, M) array
        # of axes, even for a single panel.
        subplot_opts = dict(
            figsize=(self.M * self.fw, self.M * self.fw),
            sharex="col", sharey="row", squeeze=False,
        )
        try:
            self.fig, self.axs = plt.subplots(
                self.M, self.M, constrained_layout=True, **subplot_opts
            )
            self.constrained = True
        except (TypeError, AttributeError):
            # Matplotlib versions without constrained_layout reject the keyword
            self.fig, self.axs = plt.subplots(self.M, self.M, **subplot_opts)
            self.constrained = False
        self.single = self.M == 1

        # Create named axis: the pair (properties[i], properties[j]) with
        # i < j lives in column i (x-axis) and row j-1 (y-axis).
        self.axp = {prop: {} for prop in self.properties}
        for i, propi in enumerate(self.properties):
            for j in range(i + 1, self.N):
                propj = self.properties[j]
                ax = self.axs[j - 1][i]
                self.axp[propi][propj] = ax
                self.axp[propj][propi] = ax

        # Deactivate unused panels (upper triangle)
        for i in range(self.M):
            for j in range(i + 1, self.M):
                self.axs[i][j].axis("off")

        # Place ticks in the used panels; outer edges always point outwards
        for i in range(self.M):
            for j in range(i + 1):
                self.axs[i, j].tick_params(axis='both', direction=direction)
        for i in range(self.M):
            self.axs[i, 0].tick_params(axis='y', direction="out")
            self.axs[self.M - 1, i].tick_params(axis='x', direction="out")

        # Set properties of panels
        self.set_labels()
        self.set_ranges()
        self.set_tick_params()
        self.tight_layout()

    def _prop_label(self, prop):
        """Return the axis label of ``prop`` (its key if no label is given)."""
        return self.dproperties[prop].get("label", prop)

    def _prop_range(self, prop):
        """Return the configured range of ``prop``, or None if not set."""
        return self.dproperties[prop].get("range")

    def _axis_range(self, prop, values):
        """Return (min, max) for ``prop``: its configured range or the data range.

        NaN values are ignored when the range comes from the data, so that
        missing values do not turn the histogram range into NaN.
        """
        prange = self._prop_range(prop)
        if prange is not None:
            lo, hi = prange
            return lo, hi
        return np.nanmin(values), np.nanmax(values)

    @staticmethod
    def _remove_artist(artist):
        """Remove ``artist`` from the figure, tolerating artists already gone."""
        try:
            artist.remove()
        except (ValueError, NotImplementedError):
            # The artist may have been removed or cleared by the user already
            pass

    def tight_layout(self):
        """
        Tight layout if no constrained_layout was used.

        Attr. [HC]
        """
        if not self.constrained:
            self.fig.subplots_adjust(wspace=self.fw / 100., hspace=self.fw / 100.)
            self.fig.tight_layout()

    def set_tick_params(self, **args):
        """
        Set tick parameters.

        Parameters
        ----------
        **args : dict
            Same arguments as tick_params method (defaults: axis='both',
            which='major', labelsize=0.8*fontsize).

        Attr. [HC]
        """
        opts = dict(axis='both', which='major', labelsize=0.8 * self.fs)
        opts.update(args)
        for ax in self.axs.flat:
            ax.tick_params(**opts)

    def set_ranges(self):
        """
        Set ranges in panels according to ranges defined in dproperties.

        Properties whose range is None (or missing) keep their current limits.

        Attr. [HC]
        """
        for i, propi in enumerate(self.properties):
            xrange_ = self._prop_range(propi)
            for propj in self.properties[i + 1:]:
                ax = self.axp[propi][propj]
                if xrange_ is not None:
                    ax.set_xlim(xrange_)
                yrange_ = self._prop_range(propj)
                if yrange_ is not None:
                    ax.set_ylim(yrange_)

    def set_labels(self, **args):
        """
        Set labels parameters.

        Text labels drawn by a previous call are removed first, so calling
        this method repeatedly (the constructor, plot_hist and scatter_plot
        all do) does not stack copies of the labels.

        Parameters
        ----------
        **args : dict
            Common arguments of set_xlabel, set_ylabel and text (default
            fontsize=fontsize). They take precedence over the per-label
            defaults (rotation, alignment, labelpad).

        Attr. [HC]
        """
        opts = dict(fontsize=self.fs)
        opts.update(args)

        for text in getattr(self, "_label_texts", []):
            self._remove_artist(text)
        texts = []

        # Axis labels: x in the bottom row, y in the left column
        for i, prop in enumerate(self.properties[:-1]):
            self.axs[self.M - 1][i].set_xlabel(self._prop_label(prop), **opts)
        for i, prop in enumerate(self.properties[1:]):
            self.axs[i][0].set_ylabel(
                self._prop_label(prop), **{"rotation": 90, "labelpad": 10, **opts}
            )

        # Labels written inside the hidden panels just above the diagonal
        for i in range(1, self.M):
            ax = self.axs[i - 1][i]
            label = self._prop_label(self.properties[i])
            texts.append(ax.text(0.5, 0.0, label, transform=ax.transAxes,
                                 **{"ha": "center", **opts}))
            # 270 if you want rotation
            texts.append(ax.text(0.0, 0.5, label, transform=ax.transAxes,
                                 **{"rotation": 270, "va": "center", **opts}))

        # Label of the first property at the top of the grid
        if not self.single:
            ax = self.axs[0][1]
            texts.append(ax.text(0.0, 1.0, self._prop_label(self.properties[0]),
                                 transform=ax.transAxes,
                                 **{"rotation": 0, "ha": "left", "va": "top", **opts}))

        # Label of the last property on the right of the bottom-right panel
        ax = self.axs[-1][-1]
        texts.append(ax.text(1.05, 0.5, self._prop_label(self.properties[-1]),
                             transform=ax.transAxes,
                             **{"rotation": 270, "ha": "left", "va": "center", **opts}))

        self._label_texts = texts
        self.tight_layout()

    def _add_colorbar(self, ax, h, im):
        """Attach a horizontal colorbar on top of panel ``ax`` for image ``im``.

        ``h`` is the histogram array returned by hist2d; tick labels are shown
        as mantissas with the common power of ten written once in the bar.
        """
        divider = make_axes_locatable(ax)
        # append_axes already adds the new axes to the figure
        cax = divider.append_axes("top", size="9%", pad=0.1)
        cticks = np.linspace(h.min(), h.max(), 10)[2:-1]
        self.fig.colorbar(im, ax=ax, cax=cax, orientation="horizontal", ticks=cticks)
        cax.xaxis.set_tick_params(labelsize=0.5 * self.fs, direction="in",
                                  pad=-0.8 * self.fs)
        xt = cax.get_xticks()
        _, e = Util.mantisa_exp(xt.mean())
        cax.set_xticklabels(["%.1f" % (x / 10**e) for x in xt])
        cax.text(0, 0.5, r"$\times 10^{%d}$" % e, ha="left", va="center",
                 transform=cax.transAxes, fontsize=6, color='w')

    def _finalize(self):
        """Refresh labels, ranges, ticks and layout, and redraw the watermark."""
        self.set_labels()
        self.set_ranges()
        self.set_tick_params()
        self.tight_layout()
        # Replace (rather than stack) the watermark on repeated plot calls
        if getattr(self, "_watermark", None) is not None:
            self._remove_artist(self._watermark)
        self._watermark = multineas_watermark(self.axs[0][0])

    def plot_hist(self, data, colorbar=False, **args):
        """
        Create a 2d-histograms of data on all panels of the CornerPlot.

        Parameters
        ----------
        data : numpy.ndarray
            Data to be histogramed (n=len(data)), numpy array (nxN).
        colorbar : bool, optional
            Include a colorbar? (default False).
        **args : dict
            All arguments of hist2d method (``range`` is overridden by the
            property ranges, or the data ranges ignoring NaN).

        Returns
        -------
        hist : list
            List of the image artists returned by hist2d, one per panel.

        Examples
        --------
        >>> properties = {
        ...     'Q': {'label': r"$Q$", 'range': None},
        ...     'E': {'label': r"$C$", 'range': None},
        ...     'I': {'label': r"$I$", 'range': None},
        ... }
        >>> G = mm.CornerPlot(properties, figsize=3)
        >>> hargs = dict(bins=100, cmap='viridis')
        >>> hist = G.plot_hist(udata, **hargs)

        Attr. [HC]
        """
        opts = dict(args)
        hist = []
        for i, propi in enumerate(self.properties):
            xmin, xmax = self._axis_range(propi, data[:, i])
            for j in range(i + 1, self.N):
                propj = self.properties[j]
                ymin, ymax = self._axis_range(propj, data[:, j])
                opts["range"] = [[xmin, xmax], [ymin, ymax]]
                ax = self.axp[propi][propj]
                h, _, _, im = ax.hist2d(data[:, i], data[:, j], **opts)
                hist.append(im)
                if colorbar:
                    self._add_colorbar(ax, h, im)
        self._finalize()
        return hist

    def scatter_plot(self, data, **args):
        """
        Scatter plot on all panels of the CornerPlot.

        Parameters
        ----------
        data : numpy.ndarray
            Data to be plotted (n=len(data)), numpy array (nxN).
        **args : dict
            All arguments of scatter method.

        Returns
        -------
        scatter : list
            List of scatter instances, one per panel.

        Examples
        --------
        >>> sargs = dict(s=0.2, edgecolor='None', color='r')
        >>> scatter = G.scatter_plot(udata, **sargs)

        Attr. [HC]
        """
        scatter = []
        for i, propi in enumerate(self.properties):
            for j in range(i + 1, self.N):
                propj = self.properties[j]
                scatter.append(
                    self.axp[propi][propj].scatter(data[:, i], data[:, j], **args)
                )
        self._finalize()
        return scatter