# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>
"""
:mod:`~watex.utils.plot` is a set of base plots for :term:`tensor`
visualization, data exploratory and analyses.
T-E-Q Plots encompass the tensors plots (:class:`~watex.view.TPlot`) dealing
with :term:`EM` methods, Exploratory plots ( :class:`~watex.view.ExPlot`) and
Quick analyses (:class:`~watex.view.QuickPlot`) visualization.
"""
from __future__ import annotations
import re
import copy
import warnings
import itertools
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.gridspec import GridSpec
import pandas as pd
from pandas.plotting import (
radviz ,
parallel_coordinates
)
import seaborn as sns
from .._docstring import (
DocstringComponents,
_core_docs,
_baseplot_params
)
from .._watexlog import watexlog
from ..decorators import temp2d
from ..cases.features import FeatureInspection
from ..exceptions import (
PlotError,
FeatureError,
NotFittedError,
EMError,
SiteError,
)
from ..property import BasePlot
from .._typing import (
Any ,
List,
Dict,
Optional,
ArrayLike,
DataFrame,
Series,
F,
EDIO
)
from ..utils._dependency import (
import_optional_dependency )
from ..utils.coreutils import _is_readable
from ..utils.exmath import (
moving_average , fittensor)
from ..utils.funcutils import (
_assert_all_types ,
_validate_name_in,
_isin,
repr_callable_obj,
smart_strobj_recognition,
remove_outliers,
smart_format,
reshape,
shrunkformat,
is_iterable,
station_id,
make_ids,
)
from ..utils.mlutils import (
existfeatures,
formatGenericObj,
selectfeatures ,
exporttarget
)
from ..utils.plotutils import(
make_mpl_properties,
plot_errorbar
)
from ..utils.validator import check_X_y
try:
import missingno as msno
except : pass
try :
from yellowbrick.features import (
JointPlotVisualizer,
Rank2D,
RadViz,
ParallelCoordinates,
)
except: pass
try :
from ..methods.em import (
EMAP,
MT
)
except: pass
_logger = watexlog.get_watex_logger(__name__)
#+++++++++++++++++++++++ add seaborn docs +++++++++++++++++++++++++++++++++++++
# Reusable parameter descriptions. They are bundled below into
# ``_param_docs`` through ``DocstringComponents`` so that docstrings can
# share them.
_sns_params = dict(
    sns_orient="""
sns_orient: 'v' | 'h', optional
    Orientation of the plot (vertical or horizontal). This is usually inferred
    based on the type of the input variables, but it can be used to resolve
    ambiguity when both x and y are numeric or when plotting wide-form data.
    *default* is ``v`` which refers to 'vertical'.
    """,
    sns_style="""
sns_style: dict, or one of {darkgrid, whitegrid, dark, white, ticks}
    A dictionary of parameters or the name of a preconfigured style.
    """,
    sns_palette="""
sns_palette: seaborn color palette | matplotlib colormap | hls | husl
    Palette definition. Should be something :func:`seaborn.color_palette`
    can process. The palette draws the points with different colors.
    """,
    sns_height="""
sns_height: float,
    Height (in inches) of each facet. *default* is ``4.``
    """,
    sns_aspect="""
sns_aspect: scalar (float, int)
    Aspect ratio of each facet, so that aspect * height gives the width of
    each facet in inches. *default* is ``.7``
    """,
)
_qkp_params = dict(
    classes="""
classes: list of int | float, [categorized classes]
    List of the categorical values encoded to numerical. For instance, for
    `flow` data analysis in the Bagoue dataset, the `classes` could be
    ``[0., 1., 3.]`` which means::

        * 0 m3/h --> FR0
        * > 0 to 1 m3/h --> FR1
        * > 1 to 3 m3/h --> FR2
        * > 3 m3/h --> FR3
    """,
    mapflow="""
mapflow: bool,
    Refers to the flow rate prediction using DC-resistivity features and
    works when the `tname` is set to ``flow``. If set to True, values in
    the target column should map to categorical values. Commonly the flow
    rate values are given as a trend of numerical values. For a
    classification purpose, the flow rate must be converted to categorical
    values which mainly refer to the type of hydraulic system. The type of
    hydraulic system is in turn mostly tied to the size of the living
    population in a specific area. For instance, flow classes can be
    ranged as follows:

        * FR = 0 is for dry boreholes
        * 0 < FR ≤ 3m3/h for village hydraulic (≤2000 inhabitants)
        * 3 < FR ≤ 6m3/h for improved village hydraulic
          (>2000-20 000 inhabitants)
        * 6 < FR ≤ 10m3/h for urban hydraulic (>200 000 inhabitants).

    Note that the flow range from `mapflow` is not exhaustive and can be
    modified according to the type of hydraulic required on the project.
    """
)
_param_docs = DocstringComponents.from_nested_components(
    core=_core_docs["params"],
    base=DocstringComponents(_baseplot_params),
    sns=DocstringComponents(_sns_params),
    qdoc=DocstringComponents(_qkp_params),
)
#++++++++++++++++++++++++++++++++++ end +++++++++++++++++++++++++++++++++++++++
[docs]
class TPlot (BasePlot):
_t= (
"survey_area",
"distance",
"prefix",
"window_size",
"component",
"mode",
"method",
"out",
"how",
"c"
)
def __init__(
    self,
    survey_area=None,
    distance=50.,
    prefix='S',
    how='py',
    window_size: int = 5,
    component: str = 'xy',
    mode: str = 'same',
    method: str = 'slinear',
    out: str = 'srho',
    c: int = 2,
    **kws
):
    """Store the tensor-plot settings and forward the rest to BasePlot.

    Parameters
    ----------
    survey_area: optional
        Name of the survey area. It is stored as is.
    distance: float, default=50.
        Spacing between consecutive stations, used to build the site
        positions in ``_make_tensor_utils``.
    prefix: str, default='S'
        Prefix used when generating default site names.
    how: str, default='py'
        Naming/indexing mode forwarded to the site-name generator.
    window_size, component, mode, method, out, c:
        Parameters forwarded to :class:`~watex.methods.em.EMAP` in
        :meth:`fit`.
    kws: dict
        Remaining keyword arguments handed to :class:`BasePlot`.
    """
    super().__init__(**kws)
    # Same assignment order as the constructor signature history.
    settings = (
        ("survey_area", survey_area),
        ("distance", distance),
        ("prefix", prefix),
        ("window_size", window_size),
        ("component", component),
        ("mode", mode),
        ("method", method),
        ("out", out),
        ("how", how),
        ("c", c),
    )
    for attr_name, attr_value in settings:
        setattr(self, attr_name, attr_value)
[docs]
def fit(
    self,
    data: Optional[str | List[EDIO]]
):
    """
    Fit data and populate attributes.

    Parameters
    -----------
    data : str, or list or :class:`pycsamt.core.edi.Edi` object
        Full path to EDI files or collection of EDI-objects

    Returns
    --------
    ``self``: :class:`watex.view.plot.TPlot` instanciated object
        returns ``self`` for chaining methods.
    """
    processor = EMAP(
        window_size=self.window_size,
        component=self.component,
        mode=self.mode,
        method=self.method,
        out=self.out,
        c=self.c,
    )
    processor.fit(data)
    # keep the fitted EM processing object as the ``p_`` attribute
    self.p_ = processor

    # index templates ``[freq_slice, row, col]`` to extract one component
    # from a 3-D tensor array: first letter -> row, second -> column
    n_freqs = len(processor.freqs_)
    row_col = {'x': 0, 'y': 1}
    self._c_ = {
        comp: [slice(None, n_freqs), row_col[comp[0]], row_col[comp[1]]]
        for comp in ('xx', 'xy', 'yx', 'yy')
    }
    return self
@property
def inspect(self):
    """Check that the object has been fitted.

    Returns ``1`` when the ``p_`` attribute (set by :meth:`fit`) exists;
    raises :class:`NotFittedError` otherwise.
    """
    if hasattr(self, 'p_'):
        return 1
    template = ("{obj.__class__.__name__} instance is not fitted yet."
                " Call 'fit' with appropriate arguments before using"
                " this method"
                )
    raise NotFittedError(template.format(obj=self))
[docs]
def plot_multi_recovery(
    self,
    sites: str | List[str | int],
    colors: List[str] = None,
    **kws
):
    """
    Plots multiple sites/stations with signal recovery.

    Parameters
    -----------
    sites: str, int or list
        Site(s) to visualize. Can also be the index of the sites.
    colors: str or list of str, optional
        Matplotlib format strings. The first ``len(sites)`` items style
        the restored-tensor lines. The first moving-average line is always
        drawn with ``'m-'``; the following ones take the remaining items.
        Missing entries are completed with generated styles. The caller's
        list is never modified.
    kws: dict
        Additional keyword arguments forwarded to the ``loglog`` call of
        the restored tensors.

    Returns
    ----------
    ax: Matplotlib subplot axes

    Raises
    ------
    PlotError
        If a requested site index is outside the fitted EDI collection.

    Examples
    --------
    >>> from watex.view.plot import TPlot
    >>> from watex.datasets import load_edis
    >>> # takes the 03 samples of EDIs
    >>> edi_data = load_edis (return_data= True, samples =3 )
    >>> TPlot(fig_size =(5, 3)).fit(edi_data).plot_multi_recovery (
        sites =['S00'], colors =['o', 'ok--'])
    <AxesSubplot:title={'center':'Recovered tensor $|Z_{xy}|$'},
    xlabel='$Frequency [H_z]$', ylabel='$ App.resistivity \\quad xy \\quad [ \\Omega.m]$'>
    """
    self.inspect
    if isinstance(sites, str) or not is_iterable(sites):
        sites = [sites]
    site_index = station_id(sites)
    # a single requested site may come back as a scalar index
    if not is_iterable(site_index):
        site_index = [site_index]
    n_edis = len(self.p_.ediObjs_)
    for i, s in enumerate(site_index):
        # ``s`` is a 0-based index, so ``n_edis`` itself is out of range
        if s >= n_edis:
            raise PlotError(f"Site {sites[i]!r} is out of the expected"
                            f" sites number: {n_edis}"
                            )
    comp_index = tuple(self._c_.get(self.component))
    # resistivity of the selected component
    res2d_r = self.p_.make2d(out=f'res{self.component}')
    z_rest = self.p_.zrestore()  # no buffered data
    zs = [z_rest[s].resistivity[comp_index] for s in site_index]
    ma = [moving_average(res2d_r[:, s]) for s in site_index]
    f = self.p_.getfullfrequency()

    n = len(ma)
    # work on a copy so the caller's list is not mutated; pad with
    # generated styles (two per site: restored + moving-average)
    if colors is None:
        colors = []
    elif isinstance(colors, str):
        colors = [colors]
    colors = list(colors) + list(make_mpl_properties(2 * n))

    fs = [f] * n  # same frequency axis for every site
    args = list(itertools.chain.from_iterable(zip(fs, zs, colors[:n])))
    # fitting lines: magenta for the first one, then the second half
    fit_colors = ['m-'] + colors[n:]
    fit_args = list(itertools.chain.from_iterable(zip(fs, ma, fit_colors)))
    xlim = (f.min() - .5 * f.min(), f.max() + .5 * f.max())
    return self._plot_recovery(
        *args, fit_args=fit_args, xlim=xlim, sites=sites, **kws)
def _plot_recovery(
    self,
    *args,
    fit_args=None,
    leg=None,
    xlim=None,
    sites=None,
    **kws
):
    """Template to plot stations with signal recovery.

    Isolated part of :meth:`~.TPlot.plot_multi_recovery`.

    Parameters
    -----------
    *args : list
        Positional arguments of :meth:`matplotlib.axes.Axes.loglog`
        (``x, y, fmt`` triplets) for the restored tensors.
    fit_args : list or tuple, optional
        ``loglog`` positional arguments for the fitted (e.g. moving
        average) curves drawn over the restored data.
    leg: list, optional
        Legend labels. It must fit the number of given plots. If not
        given, labels are built from `sites` (or the line index when
        `sites` is ``None``).
    xlim: tuple, optional
        X-axis limits.
    sites: list, optional
        Site names used to build the default legend labels.
    kws : dict
        Additional keyword arguments passed to the ``loglog`` call of
        `args`.

    Returns
    -------
    ax: Matplotlib.pyplot <AxesSubplot>
    """
    fig, ax = plt.subplots(1, figsize=self.fig_size)
    p1 = ax.loglog(*args,
                   markersize=self.ms,
                   linewidth=self.lw,
                   **kws
                   )
    p2 = []
    if fit_args is not None:
        fit_args = _assert_all_types(
            fit_args, list, tuple, objname="Fit arguments")
        p2 = ax.loglog(*fit_args,
                       markersize=self.ms,
                       linewidth=self.lw
                       )
    ax.set_xlabel(self.xlabel or '$Frequency [H_z]$',
                  fontsize=1.5 * self.font_size)
    ax.set_ylabel(
        self.ylabel or r'$ App.resistivity \quad xy \quad [ \Omega.m]$',
        fontsize=1.5 * self.font_size)

    # default legend labels; fall back to line indices without site names
    p1_names = sites if sites is not None else range(len(p1))
    p2_names = sites if sites is not None else range(len(p2))
    p1labels = [f'rec.tensor {i}' for i in p1_names]
    p2labels = ([f'mov-aver. line {i}' for i in p2_names]
                if fit_args is not None else [])
    ax.legend(handles=[*p1, *p2],
              labels=[*p1labels, *p2labels] if leg is None else leg,
              loc='best',
              fontsize=1.5 * self.font_size
              )
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.tick_params(axis='both', labelsize=1.5 * self.font_size)
    ax.set_title(self.fig_title or 'Recovered tensor $|Z_{xy}|$',
                 fontsize=1.5 * self.font_size)
    if self.show_grid:
        ax.grid(visible=True, alpha=self.galpha,
                which=self.gwhich, color=self.gc)
    if self.savefig is not None:
        plt.savefig(self.savefig, dpi=self.fig_dpi)
        # only release the figure once it has been written to disk
        plt.close(fig=fig)
    return ax
[docs]
@temp2d("Base template for 2D recovery tensors plot.")
def plot_tensor2d(
    self,
    tensor='res',
    sites=None,
    to_log10=False,
):
    r""" Plot two dimensional tensor.

    Parameters
    -----------
    tensor: str , ['res', 'phase'], default='res'
        Kind of tensor to plot. Can be resistivity or phase. If `phase`,
        customize your plot to not fit the default 'res' behaviour.
    sites: list of str, optional
        List of stations/sites names. If given, it must have the same
        length as the positions of the EDI data, i.e. fit the number
        of 'EDI' successfully read.
    to_log10: bool, default=False,
        Convert the resistivity data and frequency in log10.

    Returns
    -------
    ( arr2d , freqs, positions , sites , base_plot_kws):
        - arr2d: 2D array of the selected tensor for `component`
        - freqs: array-like 1d of frequency in the survey.
        - positions: Sites/stations positions. It equals the distance
          between stations times the site index.
        - sites: list of the names of the station/sites
        - base_plot_kws: plot keyword arguments inherited from
          :class:`watex.property.BasePlot`. It composes the last
          parameters for customizing plot as decorated return function.

    Examples
    --------
    >>> from watex.view.plot import TPlot
    >>> from watex.datasets import load_edis
    >>> # get some 3 samples of EDI for demo
    >>> edi_data = load_edis (return_data =True, samples =3 )
    >>> # customize plot by adding plot_kws
    >>> plot_kws = dict( ylabel = '$Log_{10}Frequency [Hz]$',
                        xlabel = '$Distance(m)$',
                        cb_label = '$Log_{10}Rhoa[\Omega.m$]',
                        fig_size =(6, 3),
                        font_size =7.
                        )
    >>> t= TPlot(**plot_kws ).fit(edi_data)
    >>> # plot recovery2d using the log10 resistivity
    >>> t.plot_tensor2d (to_log10=True)
    <AxesSubplot:xlabel='$Distance(m)$', ylabel='$Log_{10}Frequency [Hz]$'>
    """
    self.inspect
    name = str(tensor).lower()
    assert name in {"res", 'phase'}, (
        f"Expect either a resistivity 'res' or 'phase'. Got {tensor!r}")
    arr2d = self.p_.make2d(out=f'{name}{self.component}')
    return self._make_tensor_utils(arr2d, sites, to_log10, name)
[docs]
@temp2d("Base template for 2D filtered tensors plot.")
def plot_ctensor2d(
    self,
    tensor='res',
    ffilter='tma',
    sites=None,
    to_log10=False
):
    r""" Plot filtered tensors

    Parameters
    -----------
    tensor: str , ['res', 'phase'], default='res'
        Controls the post-processing of the filtered array: with
        ``'phase'`` values are wrapped modulo 90 and never log-transformed.
    ffilter: str ['ama', 'flma', 'tma'], default='tma'
        Kind of filter applied to correct the tensor data (case and
        surrounding whitespace are ignored).
    sites: list of str, optional
        List of stations/sites names. If given, it must have the same
        length as the positions of the EDI data, i.e. fit the number
        of 'EDI' successfully read.
    to_log10: bool, default=False,
        Convert the resistivity data and frequency in log10.

    Returns
    -------
    ( arr2d , freqs, positions , sites , base_plot_kws):
        - arr2d: 2D filtered tensor array from the `component`
        - freqs: array-like 1d of frequency in the survey.
        - positions: Sites/stations positions. It equals the distance
          between stations times the site index.
        - sites: list of the names of the station/sites
        - base_plot_kws: plot keyword arguments inherited from
          :class:`watex.property.BasePlot`. It composes the last
          parameters for customizing plot as decorated return function.

    Examples
    --------
    >>> from watex.view.plot import TPlot
    >>> from watex.datasets import load_edis
    >>> # get some 3 samples of EDI for demo
    >>> edi_data = load_edis (return_data =True, samples =3 )
    >>> # customize plot by adding plot_kws
    >>> plot_kws = dict( ylabel = '$Log_{10}Frequency [Hz]$',
                        xlabel = '$Distance(m)$',
                        cb_label = '$Log_{10}Rhoa[\Omega.m$]',
                        fig_size =(6, 3),
                        font_size =7.
                        )
    >>> t= TPlot(**plot_kws ).fit(edi_data)
    >>> # plot filtered tensor using the log10 resistivity
    >>> t.plot_ctensor2d (to_log10=True)
    <AxesSubplot:xlabel='$Distance(m)$', ylabel='$Log_{10}Frequency [Hz]$'>
    """
    self.inspect
    filters = {"tma": self.p_.tma, "flma": self.p_.flma, "ama": self.p_.ama}
    # normalize before validating so ' TMA ' is accepted like 'tma'
    name = str(ffilter).lower().strip()
    assert name in filters, (
        f"Supports only base filters {tuple(filters)}. Got {ffilter!r}."
        " To apply a simple filter like 'moving average' to a tensor, refer"
        " to <watex.utils.exmath.moving_average>. For other filters like"
        " 'Savitzky Golay1d/2d', 'remove distorsion' or 'remove outliers'"
        " and else, use the package 'pycsamt' instead. "
    )
    arr2d = filters[name]()
    return self._make_tensor_utils(arr2d, sites, to_log10, tensor)
def _make_tensor_utils(
        self, arr2d, sites, to_log10=False, tensor=None):
    """ Make utilities for plotting tensors

    Parameters
    ------------
    arr2d: arraylike of shape (n_freq, n_sites):
        Array of the tensor composed of number of frequency and number
        of sites that fit the number of EDI correctly read.
    sites: list of str, optional
        List of stations/sites names. If given, it must have the same
        length as ``arr2d.shape[1]``. If ``None`` or empty, names are
        generated from the positions with ``self.prefix``/``self.how``.
    to_log10: bool, default=False,
        Convert the data (except phases) and frequency in log10.
    tensor: str, optional
        Kind of tensor. ``'phase'``/``'phs'`` values are wrapped modulo
        90 and never log-transformed.

    Returns
    -------
    ( arr2d , freqs, positions , sites , base_plot_kws):
        - arr2d: 2D tensor array (the input is not modified in place)
        - freqs: array-like 1d of frequency in the survey.
        - positions: Sites/stations positions. It equals the distance
          between stations times the site index.
        - sites: list of the names of the station/sites
        - base_plot_kws: plot keyword arguments inherited from
          :class:`watex.property.BasePlot`. It composes the last
          parameters for customizing plot as decorated return function.

    Raises
    ------
    TypeError
        If `distance` is not convertible to float, if `sites` is not
        iterable or if its length does not match the number of positions.
    """
    try:
        distance = float(self.distance)
    except (TypeError, ValueError, OverflowError) as err:
        raise TypeError(
            f'Expect a float value not {type(self.distance).__name__!r}'
        ) from err

    freqs = self.p_.freqs_
    positions = np.arange(arr2d.shape[1]) * distance
    # explicit emptiness test: ``sites or ...`` fails on arrays/Series
    if sites is None or (hasattr(sites, '__len__') and len(sites) == 0):
        sites = make_ids(positions, self.prefix, how=self.how)
    if isinstance(sites, str):
        sites = [sites]
    if not is_iterable(sites):
        raise TypeError("Sites collection must be an iterable"
                        f" object. Got {type(sites).__name__!r}"
                        )
    if len(sites) != len(positions):
        raise TypeError(f"Sites={len(sites)} length must be consistent."
                        f" Expects positions={len(positions)}.")

    is_phase = tensor in {'phase', 'phs'}
    if is_phase:
        # new array: do not mutate the caller's data in place
        arr2d = arr2d % 90
    if to_log10:
        if not is_phase:
            arr2d = np.log10(arr2d)
        freqs = np.log10(freqs)

    # forward only generic plot settings: drop the TPlot-specific
    # attributes and the fitted state (``p_`` and the ``_c_`` slices)
    excluded = set(self._t) | {'p_', '_c_'}
    base_plot_kws = {
        k: v for k, v in self.__dict__.items() if k not in excluded
    }
    return arr2d, freqs, positions, sites, base_plot_kws
[docs]
def plot_recovery(self, site='S00'):
    r""" Visualize the restored tensor per site.

    Parameters
    ------------
    site: str, int, default ="S00"
        Site/station name (or index) whose restored tensor is plotted.

    Returns
    --------
    ``self``: :class:`watex.view.plot.TPlot` instanciated object
        returns ``self`` for chaining methods.

    Raises
    ------
    PlotError
        If the site index is outside the fitted EDI collection.

    Examples
    --------
    >>> from watex.view import TPlot
    >>> from watex.datasets import load_edis
    >>> edi_data = load_edis (return_data =True, samples =7)
    >>> plot_kws = dict( ylabel = '$Log_{10}Frequency [Hz]$',
                        xlabel = '$Distance(m)$',
                        cb_label = '$Log_{10}Rhoa[\Omega.m$]',
                        fig_size =(7, 4),
                        font_size =7.
                        )
    >>> t= TPlot(**plot_kws ).fit(edi_data)
    >>> # plot recovery of site 'S01'
    >>> t.plot_recovery ('S01')
    """
    self.inspect
    if isinstance(site, str):
        site = [site]
    site_index = station_id(site)
    site_index = site_index[0] if isinstance(
        site_index, tuple) else site_index
    ediObjs = self.p_.ediObjs_
    # 0-based index: len(ediObjs) itself is already out of range
    if site_index >= len(ediObjs):
        raise PlotError(f"Site {site!r} is out of the expected"
                        f" sites number: {len(ediObjs)}"
                        )
    comp = self.component
    comp_index = tuple(self._c_.get(comp))
    zobjs = self.p_.zrestore()  # with no buffer
    zxy_restored = np.abs(zobjs[site_index].z[comp_index])
    # raw Z of the selected site, possibly missing data in the dead band
    z1 = np.abs(ediObjs[site_index].Z.z)
    z1freq = ediObjs[site_index].Z._freq
    # reference frequency: the longest frequency array of the collection
    indice_reffreq = np.argmax([len(o.Z._freq) for o in ediObjs])
    reffreq = ediObjs[indice_reffreq].Z._freq
    zxy = np.abs(z1[comp_index])
    # fit zxy to get the missing data in the dead band
    zfit = fittensor(refreq=reffreq, compfreq=z1freq, z=zxy)
    # optional smoothing of the restored values (trend line)
    zcorrected = moving_average(zxy_restored)

    plt.figure(figsize=self.fig_size)
    plt.loglog(reffreq, zfit, '^r', reffreq, zxy_restored, 'ok--')
    plt.loglog(reffreq, zcorrected, '1-.')
    plt.legend(['raw data', f'tensor $res_{{{comp}}}$ restored',
                'moving-average trend line'], loc='best')
    plt.xlabel('$Frequency [H_z]$')
    plt.ylabel(rf'$ Resistivity_{{{comp}}} [ \Omega.m]$')
    # ``site`` may be an int when an index was given
    site_label = site[0] if isinstance(site, list) else site
    plt.title(f'Recovered tensor $|Z_{{{comp}}}|$'
              f" at site {str(site_label).upper()}")
    plt.grid(visible=True, alpha=0.8, which='both', color='k')
    plt.xlim(reffreq.min() - .5 * reffreq.min(),
             reffreq.max() + .5 * reffreq.max())
    plt.tight_layout()
    return self
[docs]
def plot_phase_tensors(
    self,
    mode='frequency',
    stretch=(7000, 20),
    linedir='ns',
    tensor='phimin',
    ellipse_dict=None,
    **kws
):
    """ Plot phase tensor pseudosection and skew ellipsis
    visualization.

    Method plots the phase tensor ellipses in a pseudo section format.
    It uses `mtpy` as dependency.

    Parameters
    -----------
    mode: str, default ='frequency'
        Temporal scale in y-axis. Can be ['frequency' | 'period']
    stretch : float or tuple (xstretch, ystretch), default=(7000, 20)
        Is a factor that scales the distance from one station to the next
        to make the plot readable. It determines (x,y) aspect ratio of plot.
    linedir: str [ 'ns' | 'ew' ], default='ns'
        The predominant direction of profile line. It can be ['ns' | 'ew']
        where:
            * 'ns' refers to a North-South line (or closer to north-south)
            * 'ew' refers to an East-West line (or closer to east-west)
    tensor: str, default='phimin'
        Is the tensor skew or ellipsis visualizations. The color for plot
        style is referred accordingly. Tensor can be:
        [ 'phimin' | 'phimax' | 'skew' |'skew_seg' | 'phidet' |'ellipticity' ]
        where:
            - 'phimin' -> colors by minimum phase
            - 'phimax' -> colors by maximum phase
            - 'skew' -> colors by skew
            - 'skew_seg' -> colors by skew in discrete segments defined
               by the range
            - 'normalized_skew' -> colors by skew see [Booker, 2014]
            - 'normalized_skew_seg' -> colors by normalized skew in
               discrete segments defined by the range
            - 'phidet' -> colors by determinant of the phase tensor
            - 'ellipticity' -> colors by ellipticity
        The default color range is ``[-7, 7]`` for skew-like tensors and
        ``[0, 90]`` otherwise.
    ellipse_dict: dict, optional
        Ellipse keyword arguments unpacked into the |MTpy| pseudosection
        class. If not given, defaults to::

            {'ellipse_colorby': tensor, 'ellipse_range': <range above>,
             'ellipse_size': 2, 'ellipse_cmap': 'mt_bl2wh2rd'}

        'ellipse_range' is (min, max[, step]); 'skew_seg' needs the step,
        e.g. ``[-12, 12, 3]``. Available cmaps:
            - 'mt_yl2rd' -> yellow to red
            - 'mt_bl2yl2rd' -> blue to yellow to red
            - 'mt_wh2bl' -> white to blue
            - 'mt_rd2bl' -> red to blue
            - 'mt_bl2wh2rd' -> blue to white to red
            - 'mt_bl2gr2rd' -> blue to green to red
            - 'mt_rd2gr2bl' -> red to green to blue
            - 'mt_seg_bl2wh2rd' -> discrete blue to white to red
    kws: dict
        Additional keyword arguments passed to the |MTpy| pseudosection
        phase tensor class: :class:`~.PlotPhaseTensorPseudoSection`. They
        take precedence over `ellipse_dict` and the defaults built here.

    See Also
    ----------
    mtpy.imaging.phase_tensor_pseudosection.PlotPhaseTensorPseudoSection:
        PlotPhase pseudo section tensor from |MTpy| package.
    watex.utils.plot_skew:
        Phase sensitive skew visualization.

    Examples
    ---------
    >>> import watex as wx
    >>> edi_data = wx.fetch_data ('edis', key='edi', return_data =True , samples =17 )
    >>> tplot = wx.methods.TPlot ().fit(edi_data )
    >>> tplot.plot_phase_tensors (tensor ='skew')
    """
    extra = ("Phase tensor plots or skew ellipsis visualization"
             " uses 'mtpy' as dependency. Alternatively, you may"
             " also use the phase sensitive 'skew' visualization"
             " of plot utilities if plot only refers to 'skew'."
             )
    import_optional_dependency('mtpy', extra=extra)
    from mtpy.imaging.phase_tensor_pseudosection import (
        PlotPhaseTensorPseudoSection)
    self.inspect
    zobjs = [edi_obj.Z for edi_obj in self.p_.ediObjs_]
    elrange = [-7, 7] if 'skew' in str(tensor).lower() else [0, 90]
    ellipse_dict = ellipse_dict or {
        'ellipse_colorby': tensor,
        'ellipse_range': elrange,  # color limits
        'ellipse_size': 2,
        'ellipse_cmap': 'mt_bl2wh2rd',
    }
    params = dict(
        fn_list=self.p_.edifiles,
        z_object_list=zobjs,
        fig_size=self.fig_size,
        tscale=mode,
        plot_num=self.fig_num,
        plot_title=self.fig_title,
        xlimits=self.xlim,
        ylimits=self.ylim,
        linedir=linedir,
        stretch=stretch,
        station_id=(0, len(self.p_.ediObjs_)),
        font_size=self.font_size,
        lw=self.lw,
    )
    # later sources override earlier ones instead of raising
    # "got multiple values for keyword argument"
    params.update(ellipse_dict)
    params.update(kws)
    ptsection = PlotPhaseTensorPseudoSection(**params)
    if self.savefig:
        ptsection.save_figure(save_fn=self.savefig, fig_dpi=self.fig_dpi)
    else:
        ptsection.plot()
    return self
[docs]
def plotSkew(
    self,
    method='Bahr',
    sensitivity='skew',
    mode=None,
    threshold_line=None,
    show_average_sensistivity=True,
    suppress_outliers=True,
    **plot_kws
):
    r""" Plot phase sensitive skew visualization

    'Skew' is also known as the conventional asymmetry parameter
    based on the Z magnitude.

    Mostly, the :term:`EM` signal is influenced by several factors such
    as the dimensionality of the propagation medium and the physical
    anomalies, which can distort the EM field both locally and regionally.
    The distortion of Z was determined from the quantification of its
    asymmetry and the deviation from the conditions that define its
    dimensionality. The parameters used for this purpose are all rotational
    invariant because the Z components involved in its definition are
    independent of the orientation system used. The conventional asymmetry
    parameter based on the Z magnitude is the skew defined by Swift (1967)
    [1]_ and Bahr (1991) [2]_.

    Parameters
    -----------
    method: str, default='Bahr':
        Kind of correction. Can be:

        - ``swift`` for the remove distorsion proposed by Swift in 1967.
          The value close to 0. assume the 1D and 2D structures, and 3D
          otherwise. However, In general case, the electrical structure
          of :math:`\eta < 0.4` can be treated as a 2D medium.
        - ``bahr`` for the remove distorsion proposed by Bahr in 1991.
          The latter threshold is set to 0.3. Above this value the
          structures is 3D.

    sensitivity: str, default='skew'
        Phase sensitive visualization. Can be rotational invariant
        ``invariant`` (any value containing 'inv', 'rot' or 'mu'). In fact,
        setting to ``mu`` or ``invariant`` does not change any
        interpretation since the distortion of Z are all rotational
        invariant whether using the ``Bahr`` or ``swift`` methods.

        .. versionchanged::
            Param `view` is deprecated and replaced with `sensistivity`.

    mode:str, optional
        X-axis coordinates for visualisation. plot either ``'frequency'`` or
        ``'periods'``. The default is ``'frequency'``
    threshold_line: str or bool, optional
        Visualize the threshold line. Can be ['bahr', 'swift', 'both']
        (``'*'`` is an alias of ``'both'``); ``True`` draws the line of the
        selected `method`:

        - Note that when method is set to ``swift``, the value close
          to :math:`0.` assume the 1D and 2D structures
          (:math:`\eta <0.4`), and 3D otherwise( :math:`\eta >0.4`).
          The threshold line for ``swift`` is set to :math:`0.4`.
        - when method is set to ``Bahr``, :math:`\eta > 0.3` is 3D
          structures, between :math:`[0.1 - 0.3]` assumes modified 3D/2D
          structures whereas :math:`<0.1` 1D, 2D or distorted 2D.
          The threshold line for ``Bahr`` is set to :math:`0.3`.

    show_average_sensistivity: bool, default=True
        Display the averaged value of skew data at all frequencies.
        Value can help a dimensionality interpretation purposes.
    suppress_outliers: bool, default=True
        Remove the outliers in the data if exists. It uses the
        Inter Quartile Range (``IQR``) approach. See the documentation
        of :func:`watex.utils.remove_outliers`. This is useful for clear
        interpretation using the skew threshold value.
    plot_kws: dict
        Additional keyword arguments passed to
        :meth:`matplotlib.axes.Axes.scatter`. ``marker``, ``cmap``,
        ``alpha`` and ``s`` override the instance defaults.

    See Also
    ---------
    watex.methods.EMAP.skew:
        For mathematical skew `Bahr` and `Swift` concept formulations.
    watex.utils.plot_skew:
        For phase sensitive skew visualization - naive plot.

    Examples
    --------
    >>> import watex
    >>> test_data = watex.fetch_data ('edis', samples =37, return_data =True )
    >>> watex.TPlot(fig_size =(10, 4), marker ='x').fit(
        test_data).plotSkew(method ='swift', threshold_line=True)

    References
    -----------
    .. [1] Swift, C., 1967. A magnetotelluric investigation of an
        electrical conductivity anomaly in the southwestern United
        States. Ph.D. Thesis, MIT Press. Cambridge.
    .. [2] Bahr, K., 1991. Geological noise in magnetotelluric data: a
        classification of distortion types. Physics of the Earth and
        Planetary Interiors 66 (1–2), 24–38.
    """
    self.inspect
    sensitivity = str(sensitivity).lower()
    if any(key in sensitivity for key in ('inv', 'rot', 'mu')):
        sensitivity = 'mu'
    sensitivity = 'skew' if sensitivity == 'none' else sensitivity
    assert sensitivity in {"skew", 'mu'}, ("expect 'skew' or 'rotational'"
                                          f" invariant plot, got {sensitivity!r}")
    if 'period' in str(mode).lower():
        mode = 'period'
    skew, mu = self.p_.skew(
        method=method, suppress_outliers=suppress_outliers
    )
    freqs = 1 / self.p_.freqs_ if mode == 'period' else self.p_.freqs_
    ymat = skew if sensitivity == 'skew' else mu

    # pop the scatter styles so they are not passed twice to ``scatter``
    marker = plot_kws.pop('marker', None) or self.marker
    cmap = plot_kws.pop('cmap', None) or self.cmap
    alpha = plot_kws.pop('alpha', None) or self.alpha
    size = plot_kws.pop('s', None) or self.s

    fig, ax = plt.subplots(figsize=self.fig_size)
    #---manage threshold hline: 1 -> Bahr (0.3), 2 -> Swift (0.4) ------
    thr_code = {"bahr": [1], "swift": [2], 'both': [1, 2]}
    if str(threshold_line).lower() == 'true':
        threshold_line = str(method).lower()
    ct = None  # no threshold line unless explicitly requested
    if threshold_line is not None:
        if str(threshold_line).lower() in ("*", "both"):
            threshold_line = 'both'
        ct = thr_code.get(str(threshold_line).lower(), None)

    for i in range(skew.shape[1]):
        ax.scatter(freqs, reshape(ymat[:, i]),
                   marker=marker,
                   cmap=cmap,
                   alpha=alpha,
                   s=size,
                   **plot_kws
                   )
    if ct:
        for m in ct:
            value = 0.4 if m == 2 else 0.3
            ax.axhline(y=value, color="k" if m == 1 else "g",
                       linestyle="-",
                       label=rf'threshold: $\mu={value}$'
                       )
    # see phase sensitive trend
    if show_average_sensistivity:
        average = np.around(np.average(ymat[~np.isnan(ymat)]), 3)
        plt.text(x=np.nanmin(freqs), y=np.nanmax(ymat),
                 s="aver.-{}:{}={}".format(
                     sensitivity, str(method).capitalize(), average),
                 fontdict=dict(style='italic', bbox=dict(
                     boxstyle='round', facecolor='#CED9EF'))
                 )
    ax.set_xscale('log')
    default_xlabel = ('Period ($s$)' if mode == 'period'
                      else 'Frequency ($H_z$)')
    ax.set_xlabel(self.xlabel or default_xlabel)
    default_ylabel = (('Skew' if sensitivity == 'skew' else 'Rot.Invariant')
                      + r"($\mu$)")
    ax.set_ylabel(self.ylabel or default_ylabel)
    ax.set_xlim(self.xlim if self.xlim is not None
                else [freqs.min(), freqs.max()])
    if ct:
        ax.legend()
    if self.savefig is not None:
        plt.savefig(self.savefig, dpi=self.fig_dpi)
        plt.close()
    else:
        plt.show()
    return self
def _check_component_validity(self, tensor, components):
    """Retrieve the 2D tensors of the components that exist in the data.

    Parameters
    -----------
    tensor: str,
        Tensor prefix passed to ``self.p_.make2d`` together with each
        component, e.g. ``'res'`` or ``'phase'``.
    components: str, list,
        Name of components. Could be ['xx', 'xy', 'yx', 'yy']

    Returns
    --------
    rp: list of valid 2D dimensional tensors, or ``None`` if no valid
        tensors are found.
    """
    tensor = str(tensor)
    components = is_iterable(components, exclude_string=True,
                             transform=True, parse_string=True)
    valid = []
    for comp in components:
        try:
            mat2d = self.p_.make2d(tensor + comp)
        except Exception:
            # best effort: skip components not available in the data,
            # but let KeyboardInterrupt/SystemExit propagate
            continue
        valid.append(mat2d)
    return valid or None
[docs]
def plot_rhoa(
    self,
    mode='TE',
    scale='period',
    site=None,
    seed=None,
    how='py',
    show_site=True,
    survey=None,
    style=None,
    errorbar=True,
    suppress_outliers=False,
    **kws
):
    """ Plot apparent resistivity and phase curves

    Parameters
    ----------
    mode: str, default='TE',
        Electromagnetic mode. Can be ['TM' |'both']. If ``both``,
        components `xy` and `yx` are expected in the data.
    scale: str, default='period'
        Visualization on axis label. Can be ``'frequency'``.
    site: int,str, optional
        index of name of the site to plot. `site` must be composed of
        a position number. For instance ``'S13'``. If not provided,
        a random station is selected instead.
    seed : int, optional
        If site is not provided, seed fetches randomly a site. To fetch
        the same site every time, it is better to set the seed value.
    how: str, default='py'
        The way the site is fetched for plot. With Python indexing
        (default), the site is numbered from 0. Any other value uses a
        numbering starting at 1.
    show_site:bool, default=True,
        Display the number of site.
    survey: str, optional
        Method used for the survey. e.g., 'AMT' for |AMT|.
    style:str, default='default'
        Matplotlib style.
    errorbar: bool, default=True
        display the error bar. If the error tensors are not available,
        a warning is emitted and the curves are plotted without them.
    suppress_outliers: bool, default=False,
        Remove outliers in the data before plotting
    kws: dict,
        Additional keywords arguments passed to
        Matplotlib.Axes.Scatter plots.

    Examples
    ---------
    >>> import watex as wx
    >>> edi_data = wx.fetch_data ('edis', return_data =True, samples =27)
    >>> wx.methods.TPlot(show_grid=True).fit(edi_data).plot_rhoa (
    ...     seed =52, mode ='*')
    """
    self.inspect
    m = _validate_name_in(mode, ('*', 'both', 'tetm'), expect_name='*')
    if m != '*':
        m = _validate_name_in(mode, defaults='tm transverse-magnetic',
                              expect_name='tm')
        if not m:
            m = 'te'
    scale = _validate_name_in(scale, deep=True, defaults='periods',
                              expect_name='period')
    cpm = {'te': ["xy"], 'tm': ["yx"], '*': ('xy', 'yx')}
    components = cpm.get(m)

    res, phs, site, *s = self._validate_correction(
        components=components,
        errorbar=errorbar,
        how=how,
        seed=seed,
        sites=site,
        style=style,
        n_sites=1,
    )
    s, res_err, phs_err = s
    # ``_check_component_validity`` returns ``None`` when the error
    # tensors are missing; zipping ``None`` below would raise TypeError.
    if errorbar and (res_err is None or phs_err is None):
        warnings.warn("Error tensors are not available for component(s)"
                      f" {components}. Error bars are skipped.")
        errorbar = False

    # plot only a single site
    site = site[0]
    s = s[0]

    fig = plt.figure(self.fig_num, figsize=self.fig_size,
                     dpi=self.fig_dpi)
    gs = GridSpec(3, 1, figure=fig)
    ax1 = fig.add_subplot(gs[:-1, 0])
    ax2 = fig.add_subplot(gs[-1, 0], sharex=ax1)
    plt.setp(ax1.get_xticklabels(), visible=False)

    survey = survey or self.p_.survey_name
    if not survey:
        survey = ''
    colors = ['#069AF3', '#DC143C']

    # resistivities and the abscissa are both plotted in log10
    res = [np.log10(r) for r in res]
    fp = self.p_.freqs_
    fp = 1 / fp if scale == 'period' else fp
    fp = np.log10(fp)

    if suppress_outliers:
        res = remove_outliers(res, fill_value=np.nan)
        phs = remove_outliers(phs, fill_value=np.nan)
        if errorbar:
            res_err = remove_outliers(res_err, fill_value=np.nan)
            phs_err = remove_outliers(phs_err, fill_value=np.nan)

    # lowest resistivity, used to anchor the site label
    min_y = np.nanmin(res[0][:, site])
    # append the error tensors so each loop step gets (r, p[, e, ep])
    data = [res, phs]
    if errorbar:
        data += [res_err, phs_err]

    for i, sloop in enumerate(zip(*data)):
        r, p, *sl = sloop
        if len(sl) != 0:
            e, ep = sl  # only present when errorbar is True
        y = reshape(r[:, site])
        if errorbar:
            plot_errorbar(ax1, fp, y, y_err=reshape(e[:, site]))
        ax1.scatter(fp, y,
                    marker=self.marker,
                    color=colors[i],
                    edgecolors='k',
                    label=fr'{survey}$\rho_a${components[i]}',
                    **kws
                    )
        if errorbar:
            plot_errorbar(ax2, fp, reshape(p[:, site]),
                          y_err=reshape(ep[:, site]))
        ax2.scatter(fp,
                    reshape(p[:, site]),
                    marker=self.marker,
                    color=colors[i],
                    edgecolors='k',
                    label=fr'{survey}$\phi${components[i]}',
                    **kws
                    )
        min_y = np.nanmin(y) if np.nanmin(y) < min_y else min_y

    try:
        ax1.legend(ncols=len(res))
        ax2.legend(ncols=len(phs))
    except TypeError:
        # Matplotlib < 3.6 only knows the ``ncol`` spelling.
        ax1.legend(ncol=len(res))
        ax2.legend(ncol=len(phs))

    if show_site:
        ax1.text(np.nanmin(fp),
                 min_y,
                 f'site {s}',
                 fontdict=dict(style='italic', bbox=dict(
                     boxstyle='round', facecolor='#CED9EF'),
                     alpha=0.5)
                 )
    ax2.set_ylim([0, 90])
    # the abscissa is log10-transformed above in both cases
    xlabel = self.xlabel or ('Log$_{10}$Period($s$)' if scale == 'period'
                             else 'Log$_{10}$Frequency($H_z$)')
    ax2.set_xlabel(xlabel)
    ax1.set_ylabel(self.ylabel or r'Log$_{10}\rho_a$($\Omega$.m)')
    ax2.set_ylabel(r'$\phi$($\degree$)')

    if self.show_grid:
        for ax in (ax1, ax2):
            ax.grid(visible=True, alpha=self.galpha,
                    which=self.gwhich, color=self.gc)

    if self.savefig is not None:
        plt.savefig(self.savefig, dpi=self.fig_dpi)
        plt.close()
    else:
        plt.show()
    return self
[docs]
def plot_rhophi(
    self,
    sites=None,
    mode='TE',
    scale='period',
    seed=None,
    how='py',
    show_site=True,
    survey=None,
    style=None,
    errorbar=True,
    suppress_outliers=False,
    kind='2',
    n_sites=1,
    spad=.5,
    **kws
):
    """ Plot resistivities and phases from multiples stations.

    Parameters
    ----------
    mode: str, default='TE',
        Electromagnetic mode. Can be ['TM' |'both']. If ``both``,
        components `xy` and `yx` are expected in the data.
    sites: int,str, or list, optional
        A collection of index of name of the site . Each `site` must be
        composed of a position number. For instance ``'S13'``. If not
        provided, random sites are selected instead using the `n_sites`
        parameter.
    scale: str, default='period'
        Visualization on axis label. Can be ``'frequency'``.
    seed : int, optional
        If site is not provided, seed fetches randomly a site. To fetch
        the same site every time, it is better to set the seed value.
    how: str, default='py'
        The way the site is fetched for plot. With Python indexing
        (default), the site is numbered from 0. Any other value uses a
        numbering starting at 1.
    show_site:bool, default=True,
        Display the number of site.
    survey: str, optional
        Method used for the survey. e.g., 'AMT' for |AMT|.
    style:str, default='default'
        Matplotlib style.
    errorbar: bool, default=True
        display the error bar. If the error tensors are not available,
        a warning is emitted and the curves are plotted without them.
    suppress_outliers: bool, default=False,
        Remove outliers in the data before plotting
    kind: str, {'1', '2'}, default='2'
        Layout style passed to the grid plot.
    n_sites: int, default =1.
        Number of random sites to select for visualizing. It cannot work
        if the names of sites are given.
    spad: float, default=.5,
        pad to display the station in the top of each section plot.

        .. versionadded:: 0.2.1

    kws: dict,
        Additional keywords arguments passed to the grid plot.

    Examples
    ---------
    >>> import watex as wx
    >>> edi_data = wx.fetch_data ('edis', return_data =True, samples =27)
    >>> wx.methods.TPlot(show_grid=True).fit(edi_data).plot_rhophi (
    ...     seed =52, mode ='*', n_sites =3 )
    """
    self.inspect
    m = _validate_name_in(mode, ('*', 'both', 'tetm'), expect_name='*')
    if m != '*':
        m = _validate_name_in(mode, defaults='tm transverse-magnetic',
                              expect_name='tm')
        if not m:
            m = 'te'
    scale = _validate_name_in(scale, deep=True, defaults='periods',
                              expect_name='period')
    cpm = {'te': ["xy"], 'tm': ["yx"], '*': ('xy', 'yx')}
    components = cpm.get(m)

    res, phs, sites, *s = self._validate_correction(
        components=components,
        errorbar=errorbar,
        how=how,
        seed=seed,
        sites=sites,
        style=style,
        n_sites=n_sites,
    )
    s, res_err, phs_err = s
    # ``_check_component_validity`` returns ``None`` when the error
    # tensors are missing; they cannot be zipped with the data then.
    if errorbar and (res_err is None or phs_err is None):
        warnings.warn("Error tensors are not available for component(s)"
                      f" {components}. Error bars are skipped.")
        errorbar = False

    survey = survey or self.p_.survey_name
    if not survey:
        survey = ''
    colors = ['#0000FF', '#FF00FF']

    fp = self.p_.freqs_
    fp = 1 / fp if scale == 'period' else fp

    if suppress_outliers:
        res = remove_outliers(res, fill_value=np.nan)
        phs = remove_outliers(phs, fill_value=np.nan)
        if errorbar:
            res_err = remove_outliers(res_err, fill_value=np.nan)
            phs_err = remove_outliers(phs_err, fill_value=np.nan)

    # The site values are column indexes into the data at this point, so
    # they must lie in [0, n_total) whatever the `how` numbering.
    n_total = res[0].shape[1]
    exp_sites = n_total - 1 if how == 'py' else n_total
    half_ranges = []
    for ii in sites:
        if not 0 <= ii < n_total:
            raise SiteError(
                f"Expects {exp_sites} sites. Got {ii}. Note that"
                f" for how={how!r}, the site numbering starts"
                f" at {0 if how=='py' else 1}."
            )
        half_ranges.append(
            (np.nanmax(res[0][:, ii]) - np.nanmin(res[0][:, ii])) / 2)
    # coordinates used to place the site labels
    sy = np.average(half_ranges) + spad
    sx = np.average(fp)

    # append the error tensors so each loop step gets (r, p[, e, ep])
    data = [res, phs]
    if errorbar:
        data += [res_err, phs_err]

    # resistivity and phase labels
    rlabels = [fr'{survey}$\rho_a${components[i]}'
               for i in range(len(res))]
    plabels = [fr'{survey}$\phi${components[i]}'
               for i in range(len(phs))]

    self._plot_grid_spec(
        data=data,
        x=fp,
        sites=sites,
        errorbar=errorbar,
        colors=colors,
        xysites=(sx, sy),
        show_site=show_site,
        scale=scale,
        rlabels=rlabels,
        plabels=plabels,
        kind=kind,
        **kws
    )
    if self.savefig is not None:
        plt.savefig(self.savefig, dpi=self.fig_dpi)
        plt.close()
    else:
        plt.show()
    return self
def _plot_grid_spec(
    self,
    data,
    x,
    sites=None,
    errorbar=False,
    colors=None,
    show_site=False,
    scale=None,
    xysites=None,
    color_mode='color',
    kind='2',
    **kws
):
    """ Plot multiple stations side by side using a GridSpec layout.

    The top row holds the apparent resistivity of each site (log scale,
    y-axis shared across the sites) and the bottom row holds the phase.

    Parameters
    -----------
    data: list,
        ``[res, phs]`` or ``[res, phs, res_err, phs_err]`` where each item
        is a collection (one entry per component) of 2D arrays indexed
        as ``array[:, site]``.
    x: arraylike
        One-dimensional array used as abscissa (frequencies or periods).
        It is not modified.
    sites: list of int
        Column indexes of the sites to plot. A collection is required
        since it is iterated.
    errorbar: bool, default=False
        Draw the error bars when the error arrays are present in `data`.
    colors: str, list
        Accepted for API compatibility; the curve colors are driven by
        `color_mode`.
    show_site: bool, default=False,
        Display the site index as the title of each resistivity section.
    scale: str, optional
        ``'period'`` (or ``None``) labels the x-axis with periods; any
        other value labels it with frequencies.
    xysites: tuple , optional
        Accepted for API compatibility; not used.
    color_mode: str, {"color", "bw"}, default='color'
        Plot the components in different colors, otherwise in
        black-white.
    kind: str, {'1', '2'}, default='2'
        Layout style. ``'2'`` packs the sections together (no horizontal
        space), shares the phase y-axis and shows y tick labels on the
        first column only; ``'1'`` spaces the sections out.
    kws: dict,
        Accepted for API compatibility (e.g. ``rlabels`` and ``plabels``
        from :meth:`plot_rhophi`); currently not forwarded to matplotlib.

    Returns
    --------
    axr, axp : list of Matplotlib.Axes
        The resistivity and phase axes of each station.

    Raises
    -------
    ValueError
        If `color_mode` is neither ``'color'`` nor ``'bw'``.
    """
    ncols = len(sites) if sites is not None else 1
    fig = plt.figure(figsize=self.fig_size, dpi=self.fig_dpi)
    gs = GridSpec(2, ncols or 1,
                  wspace=0. if kind == '2' else .3,
                  left=.08,
                  top=.85,
                  bottom=.1,
                  right=.98,
                  hspace=.0,
                  height_ratios=[1.5, 1])

    cmode = str(color_mode).lower()
    if cmode == 'color':
        line_colors = [(0, 0, 1), (1, 0, 0)]
    elif cmode == 'bw':
        line_colors = [(0, 0, 0), (0, 0, 0)]
    else:
        raise ValueError("color_mode expects 'color' or 'bw'."
                         f" Got {color_mode!r}.")
    # square for the first component, circle for the second
    markers = ['s', 'o']
    # errorbar keywords shared by the resistivity and phase curves
    ms, lw, e_capthick, e_capsize = 1.5, .5, .5, 2

    axr, axp = [], []
    res_limits = []
    sharey = sharey2 = None
    for j, site in enumerate(sites):
        ax1 = fig.add_subplot(gs[0, j], sharey=sharey)
        # the phase axes is always needed, with or without error bars
        ax2 = fig.add_subplot(gs[1, j], sharey=sharey2)
        if j == 0:
            sharey = ax1
            if kind == '2':
                sharey2 = ax2
        for i, (r, p, *sl) in enumerate(zip(*data)):
            e, ep = sl if (sl and errorbar) else (None, None)
            y = reshape(r[:, site])
            kw = {'color': line_colors[i],
                  'marker': markers[i],
                  'ms': ms,
                  'ls': ':',
                  'lw': lw,
                  'e_capsize': e_capsize,
                  'e_capthick': e_capthick}
            # y_err=None presumably draws the points without bars
            # (plot_errorbar default) -- TODO confirm in plotutils.
            plot_errorbar(ax1, x, y,
                          y_err=None if e is None else reshape(e[:, site]),
                          **kw)
            plot_errorbar(ax2, x, reshape(p[:, site]),
                          y_err=None if ep is None else reshape(ep[:, site]),
                          **kw)
            # NaN-aware: suppressed outliers are filled with NaN
            res_limits.append((np.nanmin(y), np.nanmax(y)))
        axr.append(ax1)
        axp.append(ax2)

    # --> set default font size (global rcParams change)
    self.font_size = 6
    plt.rcParams['font.size'] = self.font_size
    fontdict = {'size': self.font_size + 2, 'weight': 'bold'}
    if show_site:
        for ax0, site in zip(axr, sites):
            ax0.set_title(f'S{site}', fontdict=fontdict)

    # resistivity limits rounded out to whole decades
    res_min = np.array([lo for lo, _ in res_limits])
    res_max = np.array([hi for _, hi in res_limits])
    res_limits_d = [10 ** np.floor(np.log10(np.nanmin(res_min))),
                    10 ** np.ceil(np.log10(np.nanmax(res_max)))]
    # x limits from the data range, whatever the ordering of `x`
    xlims = (10 ** (np.floor(np.log10(np.nanmin(x)))) * 1.01,
             10 ** (np.ceil(np.log10(np.nanmax(x)))) * .99)
    xlabel = 'Period (s)' if scale in (None, 'period') else 'Frequency (Hz)'

    n_res = len(axr)
    for aa, ax in enumerate([*axr, *axp]):
        is_res = aa < n_res
        ax.tick_params(axis='y', pad=1.5)
        if is_res:
            ax.set_yscale('log')
            try:
                ax.set_ylim(res_limits_d)
            except ValueError:
                # NaN/inf limits (e.g. all-NaN data or non-positive values)
                ax.set_ylim(None)
        else:
            # only the bottom row carries the x label (hspace=0)
            ax.set_xlabel(xlabel, fontdict=fontdict)
            ax.yaxis.set_major_locator(mticker.MultipleLocator(10.0))
        # axes labels on the first column of each row
        if aa == 0:
            ax.set_ylabel(r'App. Res. ($\mathbf{\Omega \cdot m}$)',
                          fontdict=fontdict)
        elif aa == n_res:
            ax.set_ylabel('Phase (deg)', fontdict=fontdict)
        ax.set_xscale('log')
        ax.set_xlim(left=xlims[0], right=xlims[1])
        ax.grid(True, alpha=.25)
        if kind == '2' and aa not in (0, n_res):
            # per-axes switch: does not touch the shared y formatter
            ax.tick_params(axis='y', labelleft=False)
        if is_res:
            plt.setp(ax.get_xticklabels(), visible=False)
    return axr, axp
def _axesproperties1(self, j, ax1, ax2, r, p, sites, scale):
    """ Set properties of plot kind number 1.

    Parameters
    ----------
    j: int
        Column index of the current site in `sites`.
    ax1, ax2: Matplotlib.Axes
        Resistivity and phase axes of the current site.
    r, p: sized
        Their lengths set the number of legend columns.
    sites: list
        Collection of the plotted sites.
    scale: str
        ``'period'`` labels the x-axis with periods, otherwise frequency.

    Returns
    -------
    ax1, ax2: the configured axes.
    """
    if j > 0:
        plt.setp(ax1.get_yticklabels(), visible=False)
        plt.setp(ax2.get_yticklabels(), visible=False)
    # Put the legend in the last image
    if j == len(sites) - 1:
        try:
            ax1.legend(ncols=len(r))
            ax2.legend(ncols=len(p))
        except TypeError:
            # Matplotlib < 3.6 only knows the ``ncol`` spelling.
            ax1.legend(ncol=len(r))
            ax2.legend(ncol=len(p))
    ax1.set_xscale('log')
    ax1.set_yscale('log')
    ax2.set_xscale('log')
    ax2.set_ylim([0, 90])
    xlabel = self.xlabel or ('Period($s$)' if scale == 'period'
                             else 'Frequency ($H_z$)')
    ax2.set_xlabel(xlabel)
    if j == 0:
        # y labels only on the first column to avoid repeating them
        ax1.set_ylabel(self.ylabel or r'$\rho_a$($\Omega$.m)')
        ax2.set_ylabel(r'$\phi$($\degree$)')
    if self.show_grid:
        for ax in (ax1, ax2):
            ax.grid(visible=True, alpha=self.galpha,
                    which=self.gwhich, color=self.gc)
    return ax1, ax2
def _validate_correction(
    self,
    components,
    errorbar,
    seed,
    sites,
    how,
    style,
    n_sites,
):
    """Isolated part to validate the :meth:`plot_corrections`,
    :meth:`plot_rhoa` and :meth:`plot_rhophi` arguments.

    Parameters
    ----------
    components: str ,
        could be 'xx', 'xy', 'yx' or 'yy'
    sites: int,str, optional
        index of name of the site to plot. `site` must be composed of
        a position number. For instance ``'S13'``. If not provided,
        a random station is selected instead.
    seed : int, optional
        Get the same site if site is not provided. `seed` fetches
        a random number of site.
    how: str, default='py'
        The way the site is fetched for plot. With Python indexing
        (default), the site is numbered from 0. Any other value uses a
        numbering starting at 1.
    style:str, default='default'
        Matplotlib style.
    errorbar: bool, default=True
        display the error bar.
    n_sites: int,
        Number of sites to randomly display when sites is not given.

    Returns
    --------
    (res, phs, sites, s, res_err, phs_err) : Tuple
        - res: resistivity tensors collected at the given components
        - phs: phase tensors collected at the given components
        - sites: the validated site indexes
        - s : the sites as given/selected before index conversion
        - res_err: errors in resistivity (``[]`` when `errorbar` is
          False, ``None`` when no error tensor is found)
        - phs_err: errors in phase (same convention as `res_err`).

    Raises
    -------
    EMError
        If no resistivity or phase tensor exists for `components`.
    """
    res = self._check_component_validity('res', components)
    phs = self._check_component_validity('phase', components)
    res_err, phs_err = [], []
    if errorbar:
        res_err = self._check_component_validity('res_err', components)
        phs_err = self._check_component_validity('phase_err', components)

    # Explicit field numbering only: mixing '{0!r}' with '{}' made
    # str.format raise ValueError instead of the intended EMError.
    terror = ("{0!r} does not contain component {1}. Provide the"
              " right component of the valid tensor.")
    if res is None:
        raise EMError(terror.format('resistivity', components))
    if phs is None:
        raise EMError(terror.format('phase', components))

    # assert sites
    sites, s = self._validate_sites(res, sites=sites, seed=seed,
                                    n_sites=n_sites, how=how)
    try:
        plt.style.use(style or 'default')
    except Exception:
        # best effort: unknown styles fall back to the default one
        warnings.warn(
            f"{style} is not available. Use `plt.style.available`"
            " to get the list of available styles.")
        plt.style.use('default')
    return res, phs, sites, s, res_err, phs_err
def _validate_sites(self,
                    data, /, sites=None, seed=None, n_sites=1, how='py'
                    ):
    """ validate sites or choose random sites from number of stations
    in the survey data.

    Parameters
    -----------
    data: List of resistivity-error and phases
        A collection of resistivity, errors and phases from
        EDI-objects. ``data[0].shape[1]`` is taken as the number of sites.
    sites: str, list
        A collection of sites to visualize.
    seed : int,
        `seed` is used to reproduce the same stations when sites are not
        given. ``0`` is a valid seed.
    n_sites: int, default=1
        Number of sites to randomly selected for displaying. Note that it
        only works if sites are ``None``.
    how: str, default='py'
        The way to fetch and display sites. By default uses the Python
        indexing i.e. the site numbering starts with 0. Any other value
        uses a numbering starting with 1 (``'S1'`` is the first site).

    Returns
    -------
    S, s: Tuple
        The site column indexes and the sites as given/selected.

    Raises
    -------
    TypeError
        If a site does not contain a position number.
    ValueError
        If a site position is outside the available sites.
    """
    n_total = data[0].shape[1]
    if seed is not None:
        seed = _assert_all_types(seed, int, float, objname='Seed')
        # np.random.seed rejects float values
        np.random.seed(int(seed))
    if sites is None:
        n_sites = 1 if n_sites is None else int(n_sites)
        sites = np.random.permutation(range(n_total))[:n_sites]

    # make site as an iterable object
    sites = is_iterable(sites, exclude_string=True, transform=True)
    s = copy.deepcopy(sites)

    S = []
    for raw in sites:
        match = re.search(r'\d+', str(raw))
        if match is None:
            raise TypeError("Missing position number. Station must prefix"
                            f" with position, e.g. 'S7', got {raw!r}")
        site = int(match.group())
        if how != 'py':
            # 1-based numbering -> 0-based column index
            site -= 1
        if not 0 <= site < n_total:
            raise ValueError(
                f"Site position {site} is out of the range. The total"
                f" number of sites/stations ={n_total}")
        S.append(site)
    return S, s
[docs]
def plot_corrections(
    self,
    fltr='ama',
    ss_fx=None,
    ss_fy=None,
    r=1000.,
    nfreq=21,
    skipfreq=5,
    tol=.12,
    rotate=0.,
    distortion=None,
    distortion_err=None,
    mode='TE',
    scale='period',
    sites=None,
    seed=None,
    how='py',
    show_site=True,
    survey=None,
    style=None,
    errorbar=True,
    spad=.5,
    n_sites=1,
    mcolors=None,
    markers=None,
    **kws
):
    """Plot apparent resistivity/phase curves and corrections.

    .. versionchanged:: 0.2.1
        Can henceforth display multiple sites by providing the
        sites as a collection.

    Parameters
    ----------
    fltr: str , default='ama'
        Type of filter to apply. ``ss`` is used to remove the static
        shift using spatial median filter. Whereas ``dist`` is for
        distortion removal. Note that `distortion` must be provided
        otherwise an error raises. Can also be ['tma'|'ama'|'flma'] for
        :term:`EMAP` filters.

        - ``tma`` for trimming moving-average
        - ``ama`` for adaptative moving-average
        - ``flma`` for fixed-length moving-average

        .. versionadded: 0.2.1
           Applied EMAP filters for the visualization.

    ss_fx: float, Optional
        static shift factor to be applied to x components
        (ie z[:, 0, :]). This is assumed to be in resistivity scale.
        If None should be automatically computed using the
        spatial median filter.
    ss_fy: float, optional
        static shift factor to be applied to y components
        (ie z[:, 1, :]). This is assumed to be in resistivity scale. If
        ``None`` , should be computed using the spatial filter median.
    r: float, default=1000.
        radius to look for nearby stations, in meters.
    nfreq: int, default=21
        number of frequencies calculate the median static shift.
        This is assuming the first frequency is the highest frequency.
        Cause usually highest frequencies are sampling a 1D earth.
    skipfreq: int, default=5
        number of frequencies to skip from the highest frequency.
        Sometimes the highest frequencies are not reliable due to noise
        or low signal in the :term:`AMT` deadband. This allows you to
        skip those frequencies.
    tol: float, default=0.12
        Tolerance on the median static shift correction. If the data is
        noisy the correction factor can be biased away from 1. Therefore
        the shift_tol is used to stop that bias. If
        ``1-tol < correction < 1+tol`` then the correction factor is set
        to ``1``
    rotate: float, default=0.
        Rotate Z array by angle alpha in degrees. All angles are referenced
        to geographic North, positive in clockwise direction.
        (Mathematically negative!).
        In non-rotated state, X refs to North and Y to East direction.
    distortion: np.ndarray(2, 2, dtype=real)
        Real distortion tensor as a 2x2. Required when ``fltr='dist'``.
    distortion_err: np.ndarray(2, 2, dtype=real), Optional
        Propagation of errors/uncertainties included
    mode: str, default='TE',
        Electromagnetic mode. Can be ['TM'].
    scale: str, default='period'
        Visualization on axis label. Can be ``'frequency'``.
    sites: int,str, optional
        index of name of the site to plot. `site` must be composed of
        a position number. For instance ``'S13'``. If not provided,
        a random station is selected instead.
    seed : int, optional
        Get the same site if site is not provided. `seed` fetches
        a random number of site.
    how: str, default='py'
        The way the site is fetched for plot. With Python indexing
        (default), the site is numbered from 0. Any other value uses a
        numbering starting at 1.
    show_site:bool, default=True,
        Display the number of site.
    survey: str, optional
        Method used for the survey. e.g., 'AMT' for |AMT|.
    style:str, default='default'
        Matplotlib style.
    errorbar: bool, default=True
        display the error bar. If the error tensors are not available,
        a warning is emitted and the curves are plotted without them.
    spad: float, default=.5,
        pad to display the station in the top of each section plot.

        .. versionadded:: 0.2.1

    n_sites: int, default =1.
        Number of random sites to select for visualizing. It cannot work
        if the names of sites are given.
    mcolors: str, list, optional
        The list of colors for resistivity and phase.
    markers : str, list, optional
        The list of marker for resistivity and phase.
    kws: dict,
        Additional keywords arguments passed to
        Matplotlib.Axes.Scatter plots.

    Examples
    ---------
    >>> import numpy as np
    >>> import watex as wx
    >>> edi_data = wx.fetch_data ('edis', return_data =True, samples =27)
    >>> wx.methods.TPlot(show_grid=True).fit(edi_data).plot_corrections (
    ...     seed =52, )
    >>> distortion = np.array([[1.1 , 0.6 ],[0.23, 1.9 ]])
    >>> wx.methods.TPlot(show_grid=True).fit(edi_data).plot_corrections (
    ...     seed =52, mode ='tm', fltr ='dist', distortion =distortion
    ...     )
    """
    self.inspect
    m = _validate_name_in(mode, 'tm transverse-magnetic', expect_name='tm')
    if not m:
        m = 'te'
    scale = _validate_name_in(scale, deep=True, defaults='periods',
                              expect_name='period')
    cpm = {'te': ["xy"], 'tm': ["yx"]}
    components = cpm.get(m)

    res, phs, sites, *s = self._validate_correction(
        components=components,
        errorbar=errorbar,
        how=how,
        seed=seed,
        sites=sites,
        style=style,
        n_sites=n_sites,
    )
    s, res_err, phs_err = s
    # ``_check_component_validity`` returns ``None`` when the error
    # tensors are missing; they cannot be zipped with the data then.
    if errorbar and (res_err is None or phs_err is None):
        warnings.warn("Error tensors are not available for component(s)"
                      f" {components}. Error bars are skipped.")
        errorbar = False

    # -> assert the filter
    mc = _validate_name_in(fltr, defaults=('static shift', 'ss', '1'),
                           expect_name='ss')
    if mc != 'ss':
        mc = _validate_name_in(fltr, defaults=('distortion', 'dist', '2'),
                               expect_name='dist')
    if mc not in ('dist', 'ss'):
        if str(fltr).lower() not in ('tma', 'ama', 'flma'):
            ff = ('ss', 'dist', 'tma', 'ama', 'flma')
            raise ValueError(f"Wrong filter {fltr!r}. Expect "
                             f"{smart_format(ff, 'or')} for corrections."
                             )
        mc = str(fltr).lower()
    if mc == 'dist' and distortion is None:
        raise TypeError("Distorsion cannot be None!")

    # -> compute the corrected values
    zo = MT().fit(self.p_.ediObjs_)
    if mc == 'ss':
        zo.remove_static_shift(
            ss_fx=ss_fx,
            ss_fy=ss_fy,  # was wrongly passed `ss_fx`
            nfreq=nfreq,
            r=r,
            skipfreq=skipfreq,
            tol=tol,
            rotate=rotate,
        )
    elif mc == 'dist':
        zo.remove_distortion(distortion, error=distortion_err)
    else:
        zo.remove_ss_emap(fltr=mc)

    # corrected tensors at the selected component
    zc = zo.new_Z_
    zindex = tuple(self._c_.get(components[0]))
    zc_res = [z.resistivity[zindex] for z in zc]
    # fold the corrected phase into the [0, 90) degree range
    zc_phase = [np.abs(z.phase[zindex]) % 90 for z in zc]

    survey = survey or self.p_.survey_name
    if not survey:
        survey = ''

    # user colors/markers first, then the defaults
    colors = list(is_iterable([] if mcolors is None else mcolors,
                              exclude_string=True, transform=True)
                  ) + [(0, .6, .3), (.9, 0, .8)]
    markers = list(is_iterable([] if markers is None else markers,
                               exclude_string=True, transform=True)
                   ) + ['o', 'D']

    # frequencies of the first EDI object
    fp = self.p_.ediObjs_[0].Z._freq
    fp = 1 / fp if scale == 'period' else fp

    # append the error tensors so each loop step gets (r, p[, e, ep])
    data = [res, phs]
    if errorbar:
        data += [res_err, phs_err]

    # coordinates used to place the site labels
    sy = np.average([
        (np.nanmax(res[0][:, ii]) - np.nanmin(res[0][:, ii])) / 2
        for ii in sites])
    sy += spad
    sx = np.average(fp)

    fig = plt.figure(figsize=self.fig_size, dpi=self.fig_dpi)
    gs = GridSpec(2, len(sites),
                  wspace=0.,
                  left=.08,
                  top=.85,
                  bottom=.1,
                  right=.98,
                  hspace=.0,
                  height_ratios=[1.5, 1.]
                  )
    sharey = None
    for j, site in enumerate(sites):
        ax1 = fig.add_subplot(gs[0, j], sharey=sharey)
        if j == 0:
            sharey = ax1
        # the phase axes is always needed, with or without error bars
        ax2 = fig.add_subplot(gs[1, j])
        for i, sloop in enumerate(zip(*data)):
            r, p, *sl = sloop
            if len(sl) != 0:
                e, ep = sl  # only present when errorbar is True
            y = reshape(r[:, site])
            if errorbar:
                plot_errorbar(ax1, fp, y, y_err=reshape(e[:, site]))
            ax1.scatter(fp, y,
                        marker=markers[i],
                        color=colors[i],
                        edgecolors='k',
                        label=fr'{survey}$\rho_a${components[i]}',
                        **kws
                        )
            # corrected resistivity
            ax1.scatter(fp, zc_res[site],
                        marker='*',
                        color="#FF00FF",
                        edgecolors='k',
                        label=fr'{survey}$\rho_a${components[i]} {mc}',
                        **kws
                        )
            if errorbar:
                plot_errorbar(ax2, fp, reshape(p[:, site]),
                              y_err=reshape(ep[:, site]))
            ax2.scatter(fp,
                        reshape(p[:, site]),
                        marker=markers[i],
                        color=colors[i],
                        edgecolors='k',
                        label=fr'{survey}$\phi${components[i]}',
                        **kws
                        )
            # corrected phase
            ax2.scatter(fp,
                        zc_phase[site],
                        marker='*',
                        color="#FF00FF",
                        edgecolors='k',
                        label=fr'{survey}$\phi${components[i]} {mc}',
                        **kws
                        )
        # y tick labels only on the first column
        if j > 0:
            plt.setp(ax1.get_yticklabels(), visible=False)
            plt.setp(ax2.get_yticklabels(), visible=False)
        # legend in the last image: raw + corrected entry per component
        if j == len(sites) - 1:
            n_entries = 2 * len(res)
            try:
                ax1.legend(ncols=n_entries)
                ax2.legend(ncols=n_entries)
            except TypeError:
                # Matplotlib < 3.6 only knows the ``ncol`` spelling.
                ax1.legend(ncol=n_entries)
                ax2.legend(ncol=n_entries)
        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax2.set_xscale('log')
        if show_site:
            ax1.text(sx,
                     sy,
                     f'S{site}',
                     fontdict=dict(style='italic', bbox=dict(
                         boxstyle='round', facecolor='#CED9EF'),
                         alpha=0.5)
                     )
        ax2.set_ylim([0, 90])
        # `fp` is not log-transformed here; the axis itself is log-scaled
        xlabel = self.xlabel or ('Period($s$)' if scale == 'period'
                                 else 'Frequency ($H_z$)')
        ax2.set_xlabel(xlabel)
        if j == 0:
            ax1.set_ylabel(self.ylabel or r'$\rho_a$($\Omega$.m)')
            ax2.set_ylabel(r'$\phi$($\degree$)')
        if self.show_grid:
            for ax in (ax1, ax2):
                ax.grid(visible=True, alpha=self.galpha,
                        linestyle=self.gls,
                        which=self.gwhich, color=self.gc
                        )
    if self.savefig is not None:
        plt.savefig(self.savefig, dpi=self.fig_dpi)
        plt.close()
    else:
        plt.show()
    return self
def __repr__(self):
""" Represents the output class format """
outm = ( '<{!r}:' + ', '.join(
[f"{k}={getattr(self, k)!r}" for k in self._t]) + '>'
)
return outm.format(self.__class__.__name__)
[docs]
class ExPlot (BasePlot):
msg = ("{expobj.__class__.__name__} instance is not"
" fitted yet. Call 'fit' with appropriate"
" arguments before using this method."
)
def __init__(
    self,
    tname: str = None,
    inplace: bool = False,
    **kws
):
    # generic plotting keywords are handled by ``BasePlot``
    super().__init__(**kws)
    # target column name and whether it is removed in place by ``fit``
    self.tname, self.inplace = tname, inplace
    # attributes populated by ``fit``
    self.data = self.target_ = self.y_ = None
    self.xname_ = self.yname_ = None
@property
def inspect(self):
    """ Check that the plotter is fitted before triggering a plot.

    Returns ``1`` once ``data`` is set, otherwise raises
    `NotFittedError` because `ExPlot` is not fitted yet."""
    if self.data is not None:
        return 1
    raise NotFittedError(self.msg.format(expobj=self))
[docs]
def save(self, fig):
    """ Save `fig` if ``savefig`` is set, otherwise show the figures.

    Parameters
    ----------
    fig: Matplotlib.Figure
        The figure to save (with ``bbox_inches='tight'``) and then close.
    """
    if self.savefig is not None:
        fig.savefig(self.savefig, dpi=self.fig_dpi,
                    bbox_inches='tight'
                    )
        # close the saved figure itself, not whichever figure pyplot
        # currently considers active
        plt.close(fig)
    else:
        plt.show()
[docs]
def fit(self, data: str | DataFrame, **fit_params) -> 'ExPlot':
    """ Fit data and populate the arguments for plotting purposes.

    There is no conventional procedure for checking if a method is fitted.
    However, a class that is not fitted should raise
    :class:`exceptions.NotFittedError` when a method is called.

    Parameters
    ------------
    data: Filepath or Dataframe or shape (M, N) from
        :class:`pandas.DataFrame`. Dataframe containing samples M
        and features N
    fit_params: dict
        Additional keywords arguments for reading the data when given as
        a path-like object, passed to
        :func:`watex.utils.coreutils._is_readable`

    Return
    -------
    ``self``: `Plot` instance
        returns ``self`` for easy method chaining.
    """
    if data is not None:
        self.data = _is_readable(data, **fit_params)
    if self.tname is None:
        return self
    # split the target out of the frame
    self.target_, self.data = exporttarget(
        self.data, self.tname, self.inplace)
    # flattened target values, for consistency
    self.y_ = reshape(self.target_.values)
    return self
[docs]
def plotparallelcoords(
    self,
    classes: List[Any] = None,
    pkg='pd',
    rxlabel: int = 45,
    **kwd
) -> 'ExPlot':
    """ Use parallel coordinates in multivariates for clustering
    visualization

    Parameters
    ------------
    classes: list, default: None
        a list of class names for the legend The class labels for each
        class in y, ordered by sorted class index. These names act as a
        label encoder for the legend, identifying integer classes or
        renaming string labels. If omitted, the class labels will be taken
        from the unique values in y. Only used with ``pkg='yb'``.
        Note that the length of this list must match the number of unique
        values in y, otherwise an exception is raised.
    pkg: str, Optional,
        kind or library to use for visualization. can be ['yb'|'pd'] for
        'yellowbrick' or 'pandas' respectively. *default* is ``pd``.
    rxlabel: int, default is ``45``
        rotation of the x tick labels when `pkg` is set to ``yb``.
    kws: dict,
        Additional keywords arguments are passed down to
        :class:`yellowbrick.ParallelCoordinates` and
        :func:`pandas.plotting.parallel_coordinates`

    Returns
    --------
    ``self``: `ExPlot` instance and returns ``self`` for easy method chaining.

    Raises
    -------
    ValueError
        If ``pkg='pd'`` and no target name (`tname`) is set, since pandas
        needs a class column.

    Examples
    --------
    >>> from watex.datasets import fetch_data
    >>> from watex.view import ExPlot
    >>> data =fetch_data('original data').get('data=dfy1')
    >>> p = ExPlot (tname ='flow').fit(data)
    >>> p.plotparallelcoords(pkg='yb')
    ... <'ExPlot':xname=None, yname=None , tname='flow'>
    """
    self.inspect
    pkg = 'yb' if str(pkg) in ('yellowbrick', 'yb') else 'pd'
    if pkg == 'pd' and self.tname is None:
        # checked before creating the figure so nothing is left open
        raise ValueError("A target name 'tname' is required to plot the"
                         " parallel coordinates with pandas.")

    fig, ax = plt.subplots(figsize=self.fig_size)
    # only numerical features are drawn
    df = selectfeatures(self.data.copy(), include='number')

    if pkg == 'yb':
        import_optional_dependency('yellowbrick', (
            "Cannot plot 'parallelcoordinates' with missing"
            " 'yellowbrick'package.")
        )
        pc = ParallelCoordinates(ax=ax,
                                 features=df.columns,
                                 classes=classes,
                                 **kwd
                                 )
        pc.fit(df, y=self.y_)
        pc.transform(df)
        # pin the tick positions before relabelling them
        ticks_loc = list(ax.get_xticks())
        ax.xaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
        ax.set_xticklabels(['{:.0f}'.format(x) for x in ticks_loc],
                           rotation=rxlabel)
        pc.show()
    else:
        # put the target back as class column if it was exported
        if self.tname not in df.columns:
            df[self.tname] = self.y_
        parallel_coordinates(df, class_column=self.tname,
                             ax=ax, **kwd
                             )
    self.save(fig)
    return self
[docs]
def plotradviz(self,
               classes: List[Any] = None,
               pkg: str = 'pd',
               **kwd
               ) -> 'ExPlot':
    """Plot each sample on a circle or square, with features on the
    circumference, to visualize the separation between target classes.

    Values are normalized and each feature has a spring that pulls samples
    towards it based on the value.

    Parameters
    ------------
    classes: list of int | float, [categorized classes]
        (yellowbrick only) must be values in the target. The specified
        classes must match the number of unique values in the target,
        otherwise an error occurs. The default, ``None``, takes all the
        unique values of the target.
    pkg: str, Optional,
        kind or library to use for visualization. Can be ['yb'|'pd'] for
        'yellowbrick' or 'pandas' respectively. default is ``pd``.
    kwd: dict,
        Additional keyword arguments passed down to
        :class:`yellowbrick.RadViz` or :func:`pandas.plotting.radviz`.

    Returns
    -----------
    ``self``: `ExPlot` instance and returns ``self`` for easy method chaining.

    Raises
    ------
    TypeError
        When the target name is missing but needed (``classes=None`` or
        ``pkg='pd'``).

    Examples
    ---------
    (1)-> using yellowbrick RadViz
    >>> from watex.datasets import fetch_data
    >>> from watex.view import ExPlot
    >>> data0 = fetch_data('bagoue original').get('data=dfy1')
    >>> p = ExPlot(tname ='flow').fit(data0)
    >>> p.plotradviz(classes= [0, 1, 2, 3], pkg='yb') # can set to None
    (2) -> Using pandas radviz plot
    >>> data2 = fetch_data('bagoue original').get('data=dfy2')
    >>> p = ExPlot(tname ='flow').fit(data2)
    >>> p.plotradviz(classes= None, pkg='pd' )
    ... <'ExPlot':xname=None, yname=None , tname='flow'>
    """
    missing_target_msg = (
        "target name is missing. Can not fetch the target."
        " Provide the target name instead."
    )
    self.inspect
    fig, ax = plt.subplots(figsize=self.fig_size)
    df = self.data.copy()
    pkg = 'yb' if str(pkg) in ('yellowbrick', 'yb') else 'pd'
    if classes is None:
        if self.tname is None:
            raise TypeError(missing_target_msg)
        classes = list(np.unique(self.y_))
    df = selectfeatures(df, include='number')
    if pkg == 'yb':
        import_optional_dependency('yellowbrick', (
            "Cannot plot 'radviz' with missing"
            " 'yellowbrick'package.")
        )
        rv = RadViz(ax=ax,
                    classes=classes,
                    features=df.columns,
                    **kwd
                    )
        rv.fit(df, y=self.y_)
        rv.transform(df)
        rv.show()
    else:
        # pandas needs the class column inside the frame
        if self.tname is None:
            raise TypeError(missing_target_msg)
        if self.y_ is not None:
            df[self.tname] = self.y_
        radviz(df, class_column=self.tname, ax=ax, **kwd)
    self.save(fig)
    return self
[docs]
def plotpairwisecomparison(
    self,
    corr: str = 'pearson',
    pkg: str = 'sns',
    **kws
) -> 'ExPlot':
    """Create pairwise comparisons between features.

    The plot shows a ['pearson'|'spearman'|'kendalltau'|'covariance']
    relationship matrix.

    Parameters
    -----------
    corr: str, ['pearson'|'spearman'|'kendalltau'|'covariance']
        Method of correlation to perform. With ``pkg='yb'`` it is passed as
        the ``algorithm`` of :class:`yellowbrick.features.Rank2D`. With
        ``pkg='sns'`` it selects :meth:`pandas.DataFrame.corr` (``kendalltau``
        is mapped to pandas' ``kendall``) or :meth:`pandas.DataFrame.cov`
        for ``covariance``, computed on the numeric columns only.
        *default* is ``pearson``.
    pkg: str, Optional,
        kind or library to use for visualization. Can be ['sns'|'yb'] for
        'seaborn' or 'yellowbrick' respectively. default is ``sns``.
    kws: dict,
        Additional keyword arguments passed down to
        :class:`yellowbrick.features.Rank2D` or :func:`seaborn.heatmap`.

    Returns
    -----------
    ``self``: `ExPlot` instance and returns ``self`` for easy method chaining.

    Example
    ---------
    >>> from watex.datasets import fetch_data
    >>> from watex.view import ExPlot
    >>> data = fetch_data ('bagoue original').get('data=dfy1')
    >>> p= ExPlot(tname='flow').fit(data)
    >>> p.plotpairwisecomparison(fmt='.2f', corr='spearman', pkg ='sns',
                                 annot=True,
                                 cmap='RdBu_r',
                                 vmin=-1,
                                 vmax=1 )
    ... <'ExPlot':xname='sfi', yname='ohmS' , tname='flow'>
    """
    self.inspect
    pkg = 'yb' if str(pkg) in ('yellowbrick', 'yb') else 'sns'
    fig, ax = plt.subplots(figsize=self.fig_size)
    df = self.data.copy()
    if pkg == 'yb':
        import_optional_dependency('yellowbrick', (
            "Cannot plot 'pairwise comparison' with missing"
            " 'yellowbrick'package.")
        )
        pcv = Rank2D(ax=ax,
                     features=df.columns,
                     algorithm=corr, **kws)
        pcv.fit(df, y=self.y_)
        pcv.transform(df)
        pcv.show()
    else:
        sns.set(rc={"figure.figsize": self.fig_size})
        # pandas correlation/covariance only makes sense on numeric columns
        numeric = selectfeatures(df, include='number')
        method = str(corr).lower().strip()
        if method in ('covariance', 'cov'):
            matrix = numeric.cov()
        else:
            # yellowbrick calls it 'kendalltau', pandas 'kendall'
            if method.startswith('kendall'):
                method = 'kendall'
            matrix = numeric.corr(method=method)
        kws.setdefault('ax', ax)
        heat_ax = sns.heatmap(data=matrix, **kws)
        # heatmap returns an Axes; save the figure that owns it
        fig = heat_ax.get_figure()
    self.save(fig)
    return self
[docs]
def plotcutcomparison(
    self,
    xname: str = None,
    yname: str = None,
    q: int = 10,
    bins: int = 3,
    cmap: str = 'viridis',
    duplicates: str = 'drop',
    **kws
) -> 'ExPlot':
    """Compare the cut or `q` quantile values of ordinal categories.

    `xname` is binned into `q` quantiles and `yname` into `bins`. For each
    x-bin, the counts of y-bins are normalized so that every stacked bar
    fills the whole vertical area.

    Parameters
    -------------
    xname, yname : str
        Column names of the dataframe used on the x and y axes. When
        ``None``, the names stored from a previous call are reused.
    q: int or list-like of float
        Number of quantiles. 10 for deciles, 4 for quartiles, etc.
        Alternately array of quantiles, e.g. [0, .25, .5, .75, 1.] for
        quartiles.
    bins: int, sequence of scalars, or IntervalIndex
        The criteria to bin `yname` by.
        * int : Defines the number of equal-width bins in the range of x.
          The range of x is extended by .1% on each side to include the
          minimum and maximum values of x.
        * sequence of scalars : Defines the bin edges allowing for
          non-uniform width. No extension of the range of x is done.
        * IntervalIndex : Defines the exact bins to be used. Note that
          IntervalIndex for bins must be non-overlapping.
    cmap: str, color or list of color, optional
        The matplotlib colormap of the bar faces.
    duplicates: {'raise', 'drop'}, optional
        If bin edges are not unique, raise ValueError or drop non-uniques.
        *default* is 'drop'.
    kws: dict,
        Other keyword arguments passed down to :func:`pandas.qcut`.
        ``q``, ``labels`` (always ``False``) and ``duplicates`` are set by
        this method and cannot be given here.

    Returns
    -------
    ``self``: `ExPlot` instance and returns ``self`` for easy method chaining.

    Examples
    ---------
    >>> from watex.datasets import fetch_data
    >>> from watex.view import ExPlot
    >>> data = fetch_data ('bagoue original').get('data=dfy1')
    >>> p= ExPlot(tname='flow').fit(data)
    >>> p.plotcutcomparison(xname ='sfi', yname='ohmS')
    """
    self.inspect
    self.xname_ = xname or self.xname_
    self.yname_ = yname or self.yname_
    fig, ax = plt.subplots(figsize=self.fig_size)
    frame = self.data.copy()
    # integer bin codes: quantile bins for x, equal-width bins for y
    x_codes = pd.qcut(frame[self.xname_], q=q, labels=False,
                      duplicates=duplicates, **kws)
    y_codes = pd.cut(frame[self.yname_], bins=bins, labels=False,
                     duplicates=duplicates)
    binned = frame.assign(xname_bin=x_codes, yname_bin=y_codes)
    counts = binned.groupby(['xname_bin', 'yname_bin']).size().unstack()
    # row-normalize so each stacked bar sums to one
    shares = counts.div(counts.sum(1), axis=0)
    bar_ax = shares.plot.bar(stacked=True, width=1, ax=ax, cmap=cmap)
    bar_ax.legend(bbox_to_anchor=(1, 1))
    self.save(fig)
    return self
[docs]
def plotbv(
    self,
    xname: str = None,
    yname: str = None,
    kind: str = 'box',
    **kwd
) -> 'ExPlot':
    """Visualize distributions using the box, boxen or violin plots.

    Parameters
    -----------
    xname, yname : str
        Column names of the dataframe used on the x and y axes. When
        ``None``, the names stored from a previous call are reused.
    kind: str
        style of the plot. Any value containing ``'violin'`` or ``'boxen'``
        selects that plot; anything else falls back to ``'box'``.
        *default* is ``box``.
    kwd: dict,
        Other keyword arguments passed down to :func:`seaborn.boxplot`,
        :func:`seaborn.boxenplot` or :func:`seaborn.violinplot`.

    Returns
    -----------
    ``self``: `ExPlot` instance and returns ``self`` for easy
        method chaining.

    Example
    --------
    >>> from watex.datasets import fetch_data
    >>> from watex.view import ExPlot
    >>> data = fetch_data ('bagoue original').get('data=dfy1')
    >>> p= ExPlot(tname='flow').fit(data)
    >>> p.plotbv(xname='flow', yname='sfi', kind='violin')
    """
    self.inspect
    self.xname_ = xname or self.xname_
    self.yname_ = yname or self.yname_
    kind = str(kind).lower()
    if 'violin' in kind:
        kind = 'violin'
    elif 'boxen' in kind:
        kind = 'boxen'
    else:
        kind = 'box'
    df = self.data.copy()
    # re-attach the target when `fit` exported it from the data
    if (self.tname not in df.columns) and (self.y_ is not None):
        df[self.tname] = self.y_
    plotter = {
        'box': sns.boxplot,
        'boxen': sns.boxenplot,
        'violin': sns.violinplot,
    }[kind]
    ax = plotter(data=df, x=self.xname_, y=self.yname_, **kwd)
    # axes-level seaborn functions return an Axes, which has no savefig
    self.save(ax.get_figure())
    return self
[docs]
def plotpairgrid(
    self,
    xname: str = None,
    yname: str = None,
    vars: List[str] = None,
    **kwd
) -> 'ExPlot':
    """Create a pair grid.

    It is a matrix of scatter plots and kernel density estimations, colored
    by the target column (`tname`) when it is set.

    Parameters
    -------------
    xname, yname : str
        Column names of the dataframe. When `vars` is ``None``, the
        non-``None`` names among them are used as `vars`; if both are
        ``None``, seaborn uses every numeric column.
    vars: list, str
        list of items in the dataframe columns. Raise an error if items
        don't exist in the dataframe columns.
    kwd: dict,
        Other keyword arguments passed down to :func:`seaborn.pairplot`.

    Returns
    -----------
    ``self``: `ExPlot` instance and returns ``self`` for easy method chaining.

    Example
    --------
    >>> from watex.datasets import fetch_data
    >>> from watex.view import ExPlot
    >>> data = fetch_data ('bagoue original').get('data=dfy1')
    >>> p= ExPlot(tname='flow').fit(data)
    >>> p.plotpairgrid (vars = ['magnitude', 'power', 'ohmS'] )
    ... <'ExPlot':xname=None, yname=None , tname='flow'>
    """
    self.inspect
    self.xname_ = xname or self.xname_
    self.yname_ = yname or self.yname_
    df = self.data.copy()
    if (self.tname not in df.columns) and (self.y_ is not None):
        df[self.tname] = self.y_  # set new dataframe with a target
    if vars is None:
        # previously read the non-existent `self.y_name`; unset names are
        # dropped and an empty selection lets seaborn use all numeric columns
        vars = [name for name in (self.xname_, self.yname_)
                if name is not None] or None
    sns.set(rc={"figure.figsize": self.fig_size})
    g = sns.pairplot(df, vars=vars, hue=self.tname, **kwd)
    self.save(g)
    return self
[docs]
def plotjoint(
    self,
    xname: str,
    yname: str = None,
    corr: str = 'pearson',
    kind: str = 'scatter',
    pkg='sns',
    yb_kws=None,
    **kws
) -> 'ExPlot':
    """Fancier scatterplot that includes histograms on the edges as well
    as a regression line, called a `jointplot`.

    Parameters
    -------------
    xname, yname : str
        Column names of the dataframe used on the x and y axes. With
        yellowbrick, `xname` is plotted against the target.
    pkg: str, Optional,
        kind or library to use for visualization. Can be ['sns'|'yb'] for
        'seaborn' or 'yellowbrick'. default is ``sns``.
    kind : str, default: 'scatter'
        The type of plot to render in the joint axes; used by both
        packages. yellowbrick accepts {'scatter', 'hex'}; seaborn also
        accepts 'kde', 'hist', 'reg' and 'resid'. Note that when
        kind='hex' the target cannot be plotted by color.
    corr: str, default: 'pearson'
        (yellowbrick only) The algorithm used to compute the relationship
        between the variables in the joint plot, one of: 'pearson',
        'covariance', 'spearman', 'kendalltau'.
    yb_kws: dict,
        Additional keyword arguments for
        :class:`yellowbrick.features.JointPlotVisualizer`.
    kws: dict,
        Other keyword arguments passed down to `seaborn.jointplot`_ .

    Returns
    -----------
    ``self``: `ExPlot` instance and returns ``self`` for easy method chaining.

    Notes
    -------
    When using the `yellowbrick` library, the (x, y) variables as well as
    the target array must not contain infs or NaNs; a ValueError is raised
    if that is the case.

    .. _seaborn.jointplot: https://seaborn.pydata.org/generated/seaborn.jointplot.html
    """
    pkg = 'yb' if str(pkg).lower().strip() in ('yb', 'yellowbrick') else 'sns'
    self.inspect
    self.xname_ = xname or self.xname_
    self.yname_ = yname or self.yname_
    # assert yb_kws arguments
    yb_kws = _assert_all_types(yb_kws or dict(), dict)
    if pkg == 'yb':
        import_optional_dependency('yellowbrick', (
            "Cannot plot 'jointplot' with missing"
            " 'yellowbrick'package.")
        )
        fig, ax = plt.subplots(figsize=self.fig_size)
        jpv = JointPlotVisualizer(
            ax=ax,
            correlation=corr,
            kind=kind,
            fig=fig,
            **yb_kws
        )
        jpv.fit(
            self.data[self.xname_],
            self.data[self.tname] if self.y_ is None else self.y_,
        )
        jpv.show()
    else:
        sns.set(rc={"figure.figsize": self.fig_size})
        if self.sns_style is not None:
            sns.set_style(self.sns_style)
        df = self.data.copy()
        if (self.tname not in df.columns) and (self.y_ is not None):
            df[self.tname] = self.y_  # set new dataframe with a target
        # `kind` used to be silently ignored with seaborn
        fig = sns.jointplot(
            data=df,
            x=self.xname_,
            y=self.yname_,
            kind=kind,
            **kws
        )
    self.save(fig)
    return self
[docs]
def plotscatter(
    self,
    xname: str = None,
    yname: str = None,
    c: str | Series | ArrayLike = None,
    s: int | ArrayLike = None,
    **kwd
) -> 'ExPlot':
    """Show the relationship between two numeric columns.

    Parameters
    ------------
    xname, yname : str
        Column names of the dataframe used on the x and y axes. Names
        passed explicitly are checked with `existfeatures`.
    c: str, Optional
        Colour of the points:
        * the name of a column of the data: its values color the points
          (used as seaborn ``hue`` unless `hue` is passed explicitly);
        * any other string: a single colour (e.g. 'red' or '#a98d19'),
          passed as ``color``.
        Non-string values are not supported by the seaborn backend and are
        ignored.
    s: scalar or array_like, Optional,
        The size of the points, forwarded to matplotlib ``scatter``.
    kwd: dict,
        Other keyword arguments passed down to `seaborn.scatterplot`_ .
        An explicit ``hue`` colours the points for this call only; when no
        hue is given, the target column (`tname`) is used.

    Returns
    -----------
    ``self``: `ExPlot` instance
        returns ``self`` for easy method chaining.

    Example
    ---------
    >>> from watex.view import ExPlot
    >>> p = ExPlot(tname='flow').fit(data).plotscatter (
        xname ='sfi', yname='ohmS')
    >>> p
    ... <'ExPlot':xname='sfi', yname='ohmS' , tname='flow'>

    References
    ------------
    Scatterplot: https://seaborn.pydata.org/generated/seaborn.scatterplot.html
    Pd.scatter plot: https://www.w3resource.com/pandas/dataframe/dataframe-plot-scatter.php

    .. _seaborn.scatterplot: https://seaborn.pydata.org/generated/seaborn.scatterplot.html
    """
    self.inspect
    hue = kwd.pop('hue', None)
    self.xname_ = xname or self.xname_
    self.yname_ = yname or self.yname_
    if xname is not None:
        existfeatures(self.data, self.xname_)
    if yname is not None:
        existfeatures(self.data, self.yname_)
    df = self.data.copy()
    # re-attach the target when `fit` exported it, so it can serve as hue
    if (self.tname is not None and self.tname not in df.columns
            and self.y_ is not None):
        df[self.tname] = self.y_
    if isinstance(c, str):
        if c in df.columns:
            hue = c if hue is None else hue
        else:
            kwd.setdefault('color', c)
    if s is not None:
        kwd.setdefault('s', s)
    # do not overwrite `self.tname`: it is the target used by other plots
    if hue is None:
        hue = self.tname
    # state the fig plot and change the figure size
    sns.set(rc={"figure.figsize": self.fig_size})
    if self.sns_style is not None:
        sns.set_style(self.sns_style)
    ax = sns.scatterplot(data=df, x=self.xname_, y=self.yname_,
                         hue=hue, **kwd)
    # scatterplot returns an Axes; save the figure that owns it
    self.save(ax.get_figure())
    return self
[docs]
def plothistvstarget(
    self,
    xname: str,
    c: Any = None, *,
    posilabel: str = None,
    neglabel: str = None,
    kind='binarize',
    **kws
) -> 'ExPlot':
    """
    A histogram of a continuous feature against the target.

    Parameters
    ----------
    xname: str,
        the column name to consider on x-axis. Should be an item in the
        dataframe columns. Raise an error if element does not exist.
    c: str or int
        the class value in `y` to consider. With ``kind='binarize'`` it is
        required, must be in `y` and is treated as the positive class.
        With any other kind it is optional: when given, the target is
        mapped to ``1`` where it equals `c` (string `c`) or is ``<= c``
        (numeric `c`) and to ``0`` otherwise; when ``None`` the raw target
        values are used as hue.
    posilabel: str, Optional
        the label of `c` considered as the positive class
    neglabel: str, Optional
        the label of other classes (categories) except `c` considered as
        the negative class
    kind: str, Optional, (default, 'binarize')
        the kind of plot features against target. `binarize` considers
        plotting the positive class ('c') vs negative class ('not c')
    kws: dict,
        Additional keyword arguments of :func:`seaborn.histplot`.

    Returns
    -----------
    ``self``: `ExPlot` instance
        returns ``self`` for easy method chaining.

    Raises
    ------
    ValueError
        When `c` is missing or not a class of `y` (binarize), or when the
        target is missing.

    Examples
    --------
    >>> from watex.utils import read_data
    >>> from watex.view import ExPlot
    >>> data = read_data ( 'data/geodata/main.bagciv.data.csv' )
    >>> p = ExPlot(tname ='flow').fit(data)
    >>> p.fig_size = (7, 5)
    >>> p.savefig ='bbox.png'
    >>> p.plothistvstarget (xname= 'sfi', c = 0, kind = 'binarize', kde=True,
                            posilabel='dried borehole (m3/h)',
                            neglabel = 'accept. boreholes'
                            )
    ... <'ExPlot':xname='sfi', yname=None , tname='flow'>
    """
    self.inspect
    self.xname_ = xname or self.xname_
    existfeatures(self.data, self.xname_)  # assert the name in the columns
    df = self.data.copy()
    binarize = str(kind).lower().strip().find('bin') >= 0
    if binarize:
        if c is None:
            raise ValueError("Need a categorical class value for binarizing")
        _assert_all_types(c, float, int)
        if self.y_ is None:
            raise ValueError("target name is missing. Specify the `tname`"
                             f" and refit {self.__class__.__name__!r} ")
        if not _isin(self.y_, c):
            raise ValueError(f"c-value should be a class label, got '{c}'"
                             )
        mask = self.y_ == c
        # for consistency use np.unique to get the classes
        neglabel = neglabel or shrunkformat(np.unique(self.y_[~mask]))
    else:
        if self.tname is None:
            raise ValueError("Can't plot binary classes with missing"
                             " target name ")
        # the target may have been exported from the data by `fit`
        if (self.tname not in df.columns) and (self.y_ is not None):
            df[self.tname] = self.y_
        if c is not None:
            df[self.tname] = df[self.tname].map(
                lambda x: 1 if (x == c if isinstance(c, str) else x <= c
                                ) else 0  # mapping binary target
            )
    # style must be set before drawing to have any effect on this plot
    sns.set(rc={"figure.figsize": self.fig_size})
    if self.sns_style is not None:
        sns.set_style(self.sns_style)
    if binarize:
        g = sns.histplot(data=df[mask][self.xname_],
                         label=posilabel or str(c),
                         linewidth=self.lw,
                         **kws
                         )
        g = sns.histplot(data=df[~mask][self.xname_],
                         label=neglabel,
                         linewidth=self.lw,
                         **kws,
                         )
    else:
        g = sns.histplot(data=df,
                         x=self.xname_,
                         hue=self.tname,
                         linewidth=self.lw,
                         **kws
                         )
    g.legend()
    # histplot returns an Axes; save (or show) the figure that owns it
    self.save(g.get_figure())
    return self
[docs]
def plothist(self, xname: str = None, *, kind: str = 'hist',
             **kws
             ) -> 'ExPlot':
    """A histogram visualization of numerical data.

    Parameters
    ----------
    xname: str , xlabel
        feature name in the dataframe, also the label on the x-axis.
        Raises an error if it does not exist in the dataframe. When
        ``None``, the name stored from a previous call is reused.
    kind: str
        Mode of pandas series plotting. the *default* is ``hist``.
    kws: dict,
        additional keyword arguments of :meth:`pandas.Series.plot`.

    Returns
    -------
    ``self``: `ExPlot` instance
        returns ``self`` for easy method chaining.
    """
    self.inspect
    self.xname_ = xname or self.xname_
    # type check only: the feature name must be a string
    _assert_all_types(self.xname_, str)
    # assert whether the feature exists
    existfeatures(self.data, self.xname_)
    fig, ax = plt.subplots(figsize=self.fig_size)
    self.data[self.xname_].plot(kind=kind, ax=ax, **kws)
    self.save(fig)
    return self
[docs]
def plotmissing(self, *,
                kind: str = None,
                sample: float = None,
                **kwd
                ) -> 'ExPlot':
    """
    Visualize patterns in the missing data.

    Parameters
    ------------
    kind: str, Optional
        kind of visualization. Can be ``dendrogram``, ``mbar``, ``bar``,
        ``correlation`` or ``mpatterns`` (prefixes such as ``dendro``,
        ``corr`` or ``mpat`` are accepted):
        * ``bar`` plot counts the non-missing data using pandas.
        * ``mbar`` uses the :mod:`msno` package to count the number
          of non-missing data.
        * ``dendrogram`` shows the clusterings of where the data is
          missing. Leaves that are on the same level predict one another's
          presence (empty or filled). The vertical arms indicate how
          different clusters are: short arms mean that branches are
          similar.
        * ``correlation`` creates a heat map showing whether there are
          correlations in where the data is missing.
        * ``mpatterns`` is the default visualization (also selected by
          ``None``, ``default`` or ``base``). It is useful for viewing
          contiguous areas of missing data, which would indicate that the
          missing data is not random. The :code:`matrix` function includes
          a sparkline along the right side; patterns there also indicate
          non-random missing data. Limiting the number of samples helps
          to see the patterns.
        Any other value raises an error.
    sample: int or float, Optional
        Number of rows to visualize (``int``) or, as a float in ``(0, 1]``,
        the fraction of rows to sample. Only used by the ``mbar`` and
        ``mpatterns`` plots. ``None`` plots all the samples in the data.
    kwd: dict
        Additional keyword arguments, passed to
        :func:`matplotlib.pyplot.subplots` for ``bar``, to
        :func:`msno.dendrogram` for ``dendrogram`` and to
        :func:`msno.matrix` for ``mpatterns``; unused otherwise.

    Returns
    -------
    ``self``: `ExPlot` instance
        returns ``self`` for easy method chaining.

    Raises
    ------
    ValueError
        If `kind` does not match any supported plot.
    ModuleNotFoundError
        If a ``missingno`` plot is requested but the package is missing.

    Example
    --------
    >>> import pandas as pd
    >>> from watex.view import ExPlot
    >>> data = pd.read_csv ('data/geodata/main.bagciv.data.csv' )
    >>> p = ExPlot().fit(data)
    >>> p.fig_size = (12, 4)
    >>> p.plotmissing(kind ='corr')
    """
    self.inspect
    kstr = ('dendrogram', 'bar', 'mbar', 'correlation', 'mpatterns')
    regex = re.compile(r'none|dendro|corr|base|default|mbar|bar|mpat',
                       flags=re.IGNORECASE)
    match = regex.search(str(kind).lower().strip())
    if match is None:
        # report the user's value, not the failed match object (None)
        raise ValueError(f"Expect {smart_format(kstr, 'or')} not: {kind!r}")
    kind = match.group()
    if kind in ('none', 'default', 'base', 'mpat'):
        kind = 'mpat'
    if sample is not None:
        sample = _assert_all_types(sample, int, float)

    def _sampled():
        """Return the data, down-sampled according to `sample` if set."""
        if sample is None:
            return self.data
        if isinstance(sample, float) and 0 < sample <= 1:
            return self.data.sample(frac=sample)
        return self.data.sample(n=int(sample))

    if kind == 'bar':
        fig, ax = plt.subplots(figsize=self.fig_size, **kwd)
        (1 - self.data.isnull().mean()).abs().plot.bar(ax=ax)
    else:
        # `msno` is only defined when the optional import succeeded
        try:
            msno
        except NameError:
            raise ModuleNotFoundError(
                f"Missing 'missingno' package. Can not plot {kind!r}"
            ) from None
        if kind == 'mbar':
            ax = msno.bar(_sampled(), figsize=self.fig_size)
        elif kind == 'dendro':
            ax = msno.dendrogram(self.data, figsize=self.fig_size, **kwd)
        elif kind == 'corr':
            ax = msno.heatmap(self.data, figsize=self.fig_size)
        else:
            ax = msno.matrix(_sampled(), figsize=self.fig_size, **kwd)
    if self.savefig is not None:
        out_fig = fig if kind == 'bar' else ax.get_figure()
        out_fig.savefig(self.savefig, dpi=self.fig_dpi)
    return self
def __repr__(self):
    """Return ``<'ClassName':xname=..., yname=... , tname=...>``."""
    template = "<{0!r}:xname={1!r}, yname={2!r} , tname={3!r}>"
    fields = (self.__class__.__name__, self.xname_, self.yname_, self.tname)
    return template.format(*fields)
[docs]
class QuickPlot (BasePlot):
def __init__(self, classes=None, tname=None, mapflow=False, **kws):
    """Initialize the quick-plot helper.

    Parameters
    ----------
    classes: list, optional
        Flow classes used when the target is ``'flow'`` (read by the
        ``data`` setter).
    tname: str, optional
        Name of the target column.
    mapflow: bool, default False
        Whether the flow target is mapped to categorical values during
        data inspection.
    kws: dict
        Remaining keyword arguments forwarded to ``BasePlot``.
    """
    super().__init__(**kws)
    self._logging = watexlog().get_watex_logger(self.__class__.__name__)
    self.classes, self.tname, self.mapflow = classes, tname, mapflow
@property
def data(self):
    """The data read and inspected by the setter (stored in ``data_``)."""
    return self.data_

@data.setter
def data(self, data):
    """Read the data file.

    Reads the data provided and stores it in ``data_``. When the target
    name is ``flow`` (case-insensitive), the data goes through
    :class:`watex.cases.features.FeatureInspection` (using `classes` and
    `mapflow`) and the inspected data is kept; otherwise it is read with
    :func:`watex.utils.coreutils._is_readable`. Set another target name if
    you don't want the flow feature inspection.

    When the target column is found in the data (case-insensitive match),
    it is stored in ``y`` and the remaining columns in ``X_``; otherwise
    ``y`` is reset to ``None`` and ``X_`` holds all the data.
    """
    if str(self.tname).lower() == 'flow':
        # default inspection for DC -flow rate prediction
        fobj = FeatureInspection(set_index=True,
                                 flow_classes=self.classes or [0., 1., 3],
                                 target=self.tname,
                                 mapflow=self.mapflow
                                 ).fit(data=data)
        # keep the inspected data: re-reading the raw input here used to
        # discard the flow mapping
        self.data_ = fobj.data
    else:
        self.data_ = _is_readable(data, input_name="'data'")
    # reset so a target/features from a previous dataset do not leak
    self.y = None
    self.X_ = self.data_
    if self.tname is not None:
        lowered = list(self.data_.columns.str.lower())
        target = str(self.tname).lower()
        if target in lowered:
            ix = lowered.index(target)
            self.y = self.data_.iloc[:, ix]
            self.X_ = self.data_.drop(columns=self.data_.columns[ix])
[docs]
def fit(
    self,
    data: str | DataFrame,
    y: Optional[Series | ArrayLike] = None
) -> "QuickPlot":
    """
    Fit data and populate the attributes for plotting purposes.

    Parameters
    ----------
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as a
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.
    y: array-like, optional
        array of the target. Must be the same length as the data. If `y`
        is provided and `data` is given as ``str`` or ``DataFrame``,
        all the data should be considered as the X data for analysis.

    Returns
    -------
    self: :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Raises
    ------
    ValueError
        If `y` and the data do not have the same length.

    Examples
    --------
    >>> from watex.datasets import load_bagoue
    >>> data = load_bagoue ().frame
    >>> from watex.view.plot import QuickPlot
    >>> qplotObj= QuickPlot(xlabel = 'Flow classes in m3/h',
                            ylabel='Number of occurence (%)')
    >>> qplotObj.tname= None # with the name of the target set to None
    >>> qplotObj.fit(data)
    >>> qplotObj.data.iloc[1:2, :]
    ...     num name      east      north  ...         ohmS        lwi      geol flow
        1  2.0   b2  791227.0  1159566.0  ...  1135.551531  21.406531  GRANITES  0.0
    >>> qplotObj.tname= 'flow'
    >>> qplotObj.mapflow= True # map the flow from num. values to categ. values
    >>> qplotObj.fit(data)
    >>> qplotObj.data.iloc[1:2, :]
    ...    num name      east      north  ...         ohmS        lwi      geol flow
        1  2.0   b2  791227.0  1159566.0  ...  1135.551531  21.406531  GRANITES  FR0
    """
    self.data = data
    if y is None:
        return self
    _, y = check_X_y(
        self.data, y,
        force_all_finite="allow-nan",
        dtype=object,
        to_frame=True
    )
    y = _assert_all_types(y, np.ndarray, list, tuple, pd.Series)
    n_target, n_rows = len(y), len(self.data)
    if n_target != n_rows:
        raise ValueError(
            f"y and data must have the same length but {n_target} and"
            f" {n_rows} were given respectively.")
    self.y = pd.Series(y, name=self.tname or 'none')
    # for consistency get the name of target
    self.tname = self.y.name
    return self
[docs]
def histcatdist(
    self,
    stacked: bool = False,
    **kws
):
    """
    Histogram plot distribution.

    Plots the distribution of categorized classes according to the
    percentage of occurrence.

    Parameters
    -----------
    stacked: bool
        Pile bins one on another as cumulative values. *default* is
        ``False``.
    bins : int, optional
        contains the integer or sequence or string. When given (and not
        ``None``) it is stored on ``self.bins``.
    range : list, optional
        is the lower and upper range of the bins
    density : bool, optional
        contains the boolean values
    weights : array-like, optional
        is an array of weights, of the same shape as `data`
    bottom : float, optional
        is the location of the bottom baseline of each bin
    histtype : str, optional
        is used to draw type of histogram. {'bar', 'barstacked', 'step',
        'stepfilled'}
    align : str, optional
        controls how the histogram is plotted. {'left', 'mid', 'right'}
    rwidth : float, optional,
        is a relative width of the bars as a fraction of the bin width
    log : bool, optional
        is used to set histogram axis to a log scale
    color : str, optional
        is a color spec or sequence of color specs, one per dataset
    label : str , optional
        is a string, or sequence of strings to match multiple datasets
    normed : bool, optional
        an optional parameter and it contains the boolean values. It uses
        the density keyword argument instead.
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as a
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.

    Returns
    -------
    :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Raises
    ------
    FeatureError
        If no target is available, neither as a data column nor as `y`.

    Notes
    -------
    The argument for `data` must be passed to `fit` method. `data`
    parameter is not allowed in other `QuickPlot` methods. The description
    of the parameter `data` gives a synopsis of the kind of data the plot
    expects. An error is raised if `data` is forced as a keyword argument.
    The target values are taken from the `tname` column of the data when
    present, otherwise from the `y` given to :meth:`fit`.

    Examples
    ---------
    >>> from watex.view.plot import QuickPlot
    >>> from watex.datasets import load_bagoue
    >>> data = load_bagoue ().frame
    >>> qplotObj= QuickPlot(xlabel = 'Flow classes',
                            ylabel='Number of occurence (%)',
                            lc='b', tname='flow')
    >>> qplotObj.sns_style = 'darkgrid'
    >>> qplotObj.fit(data)
    >>> qplotObj. histcatdist()
    """
    self._logging.info('Quick plot of categorized classes distributions.'
                       f' the target name: {self.tname!r}')
    self.inspect
    if self.tname is None and self.y is None:
        raise FeatureError(
            "Please specify 'tname' as the name of the target")
    # reset index
    df_ = self.data_.copy()  # make a copy for safety
    df_.reset_index(inplace=True)
    # always pop `bins` so a None value is not passed twice to plt.hist
    bins = kws.pop('bins', None)
    if bins is not None:
        self.bins = bins
    # the target lives in the data, or was given separately as `y` to fit
    if self.tname in df_.columns:
        target = df_[self.tname]
    else:
        target = self.y
    if target is None:
        raise FeatureError(
            f"Target {self.tname!r} not found in the data.")
    plt.figure(figsize=self.fig_size)
    plt.hist(target, bins=self.bins,
             stacked=stacked, color=self.lc, **kws)
    plt.xlabel(self.xlabel)
    plt.ylabel(self.ylabel)
    plt.title(self.fig_title)
    if self.savefig is not None:
        plt.savefig(self.savefig, dpi=self.fig_dpi,
                    orientation=self.fig_orientation
                    )
    return self
[docs]
def barcatdist(
    self,
    basic_plot: bool = True,
    groupby: List[str] | Dict[str, float] = None,
    **kws):
    """
    Bar plot distribution.

    Plots a categorical distribution according to the occurrence of the
    target (`tname`) in the data.

    Parameters
    -----------
    basic_plot: bool, default=True
        Plot the relative frequency of each target class with the
        `matplotlib.pyplot.bar` function. When `groupby` is ``None`` this
        is forced to ``True`` (a warning is issued if it was ``False``)
        since nothing else could be drawn.
    groupby: str, list or dict, optional
        Features (data columns) to plot with `seaborn.countplot`, using the
        target as ``hue``. A single string is treated as a one-element
        list. When a ``dict`` is given, only its keys (the feature names)
        are used; the per-feature properties are currently ignored::

            groupby = ['shape', 'type']
            groupby = {'type': {'color': 'b', 'width': 0.25, 'sep': 0.},
                       'shape': {'color': 'g', 'width': 0.25, 'sep': 0.25}}

    kws: dict,
        Additional keywords arguments passed to `seaborn.countplot`.
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.

    Returns
    -------
    :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Notes
    -------
    The argument for `data` must be passed to `fit` method. `data`
    parameter is not allowed in other `QuickPlot` method. The description
    of the parameter `data` is to give a synopsis of the kind of data
    the plot expected. An error will raise if force to pass `data`
    argument as a keyword arguments.

    Examples
    ----------
    >>> from watex.view.plot import QuickPlot
    >>> from watex.datasets import load_bagoue
    >>> data = load_bagoue ().frame
    >>> qplotObj= QuickPlot(xlabel = 'Anomaly type',
    ...                     ylabel='Number of occurence (%)',
    ...                     lc='b', tname='flow')
    >>> qplotObj.sns_style = 'darkgrid'
    >>> qplotObj.fit(data)
    >>> qplotObj.barcatdist(basic_plot =False, groupby=['shape'])
    """
    self.inspect
    _, ax = plt.subplots(figsize=self.fig_size)
    df_ = self.data.copy(deep=True)  # make a copy for safety
    df_.reset_index(inplace=True)

    if groupby is None:
        if not basic_plot:
            # The basic plot was disabled but nothing else was requested:
            # warn, then fall back on the basic plot.
            mess = ''.join([
                "Basic plot is turned to ``False`` but no specific plot is",
                " detected. Please provide the specific columns to the",
                " `groupby` argument."])
            self._logging.debug(mess)
            warnings.warn(mess)
        basic_plot = True
    elif isinstance(groupby, str):
        # Wrap a single name; iterating the bare string would loop over
        # its characters.
        groupby = self.groupby = [groupby]
    elif isinstance(groupby, dict):
        groupby = list(groupby.keys())

    if basic_plot:
        # Take the labels from the ``value_counts`` index so that every bar
        # label stays aligned with its own frequency (``set`` ordering is
        # arbitrary and does not follow ``value_counts`` ordering).
        freqs = df_[self.tname].value_counts(normalize=True)
        ax.bar(list(freqs.index), freqs.values,
               label=self.fig_title, color=self.lc)

    if groupby is not None:
        if hasattr(self, 'sns_style'):
            sns.set_style(self.sns_style)
        for sll in groupby:
            ax = sns.countplot(x=sll, hue=self.tname,
                               data=df_, orient=self.sns_orient,
                               ax=ax, **kws)

    ax.set_xlabel(self.xlabel)
    ax.set_ylabel(self.ylabel)
    ax.set_title(self.fig_title)
    ax.legend()

    if groupby is not None:
        self._logging.info(
            'Multiple bar plot distribution grouped by {0}.'.format(
                formatGenericObj(groupby)).format(*groupby))

    if self.savefig is None:
        plt.show()
    else:
        plt.savefig(self.savefig, dpi=self.fig_dpi,
                    orientation=self.fig_orientation)
        plt.close()

    if self.verbose > 0:
        print('--> Bar distribution plot successfully done!')
    return self
[docs]
def multicatdist(
    self,
    *,
    x=None,
    col=None,
    hue=None,
    targets: List[str] = None,
    x_features: List[str] = None,
    y_features: List[str] = None,
    kind: str = 'count',
    **kws):
    """
    Figure-level interface for drawing multiple categorical distribution
    plots onto a FacetGrid.

    One `seaborn.catplot` figure is drawn per position of the `x`, `col`
    and `hue` lists. The lists are zipped together; the shorter ones are
    padded with ``None``.

    Parameters
    -----------
    x, col, hue: str or list of str, optional
        Names of variables in data. ``x`` is the feature on the x-axis,
        ``col`` the feature used for column faceting, and ``hue`` the
        feature (often the target) used for colour grouping. A single
        string is treated as a one-element list.
    targets, x_features, y_features: list, optional
        Kept for backward compatibility; they are not used by this method.
    kind: str, default='count'
        The kind of plot to draw. It corresponds to the name of a
        categorical axes-level plotting function. Options are: "strip",
        "swarm", "box", "violin", "boxen", "point", "bar", or "count".
    kws: dict
        Other keyword arguments are passed through to `seaborn.catplot`.
        The most useful ones are:

        - ``row``: name of a variable in data. Categorical variable that
          determines the row faceting of the grid.
        - ``col_wrap``: int. "Wrap" the column variable at this width, so
          that the column facets span multiple rows. Incompatible with a
          row facet.
        - ``estimator``: string or callable that maps vector -> scalar.
          Statistical function to estimate within each categorical bin.
        - ``errorbar``: string, (string, number) tuple, or callable. Name
          of an errorbar method ("ci", "pi", "se" or "sd"), or a tuple with
          a method name and a level parameter, or a function that maps a
          vector to a (min, max) interval.
        - ``n_boot``: int. Number of bootstrap samples used to compute
          confidence intervals.
        - ``units``: name of a variable in data, or vector data.
          Identifier of sampling units, used to perform a multilevel
          bootstrap that accounts for a repeated-measures design.
        - ``seed``: int, numpy.random.Generator or
          numpy.random.RandomState. Seed for reproducible bootstrapping.
        - ``order``, ``hue_order``: lists of strings. Order in which to
          plot the categorical levels.
        - ``row_order``, ``col_order``: lists of strings. Order in which to
          organize the rows and columns of the grid.
        - ``native_scale``: bool. Keep the original scaling of numeric or
          datetime values on the categorical axis.
        - ``formatter``: callable. Converts categorical data into strings.
        - ``orient``: "v" | "h". Orientation of the plot.
        - ``color``: matplotlib color. Single color for the elements.
        - ``palette``: palette name, list or dict. Colors for the hue
          levels.
        - ``hue_norm``: tuple or Normalize. Normalization for a numeric hue.
        - ``legend``: str or bool. Set to False to disable the legend.
        - ``legend_out``: bool. Draw the legend outside the plot.
        - ``sharex``, ``sharey``: bool, 'col' or 'row'. Share axes
          across facets.
        - ``margin_titles``: bool. Draw row titles on the right margin.
        - ``facet_kws``: dict. Other keyword arguments for FacetGrid.

        ``linewidth``, ``height`` and ``aspect`` are taken from the
        instance attributes ``lw``, ``sns_height`` and ``sns_aspect``.
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.

    Returns
    -------
    :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Notes
    -------
    The argument for `data` must be passed to `fit` method. `data`
    parameter is not allowed in other `QuickPlot` method. The description
    of the parameter `data` is to give a synopsis of the kind of data
    the plot expected. An error will raise if force to pass `data`
    argument as a keyword arguments.

    Examples
    ---------
    >>> from watex.view.plot import QuickPlot
    >>> from watex.datasets import load_bagoue
    >>> data = load_bagoue ().frame
    >>> qplotObj= QuickPlot(lc='b', tname='flow')
    >>> qplotObj.sns_style = 'darkgrid'
    >>> qplotObj.mapflow=True # to categorize the flow rate
    >>> qplotObj.fit(data)
    >>> fdict={
    ...     'x':['shape', 'type', 'type'],
    ...     'col':['type', 'geol', 'shape'],
    ...     'hue':['flow', 'flow', 'geol'],
    ...     }
    >>> qplotObj.multicatdist(**fdict)
    """
    self.inspect
    # Put every facet argument into a list. A bare string is wrapped
    # rather than given to ``list()``, which would split it into chars.
    x, col, hue = (
        [None] if v is None else [v] if isinstance(v, str) else list(v)
        for v in (x, col, hue)
    )
    # Pad the shorter lists with ``None`` so that all three can be zipped.
    maxlen = max(len(x), len(col), len(hue))
    for seq in (x, col, hue):
        seq.extend([None] * (maxlen - len(seq)))

    df_ = self.data.copy(deep=True)
    df_.reset_index(inplace=True)
    if not hasattr(self, 'ylabel'):
        self.ylabel = 'Number of occurence (%)'

    # Apply the style *before* drawing so that it affects these plots.
    if self.sns_style is not None:
        sns.set_style(self.sns_style)

    self._logging.info(
        'Multiple categorical plots from targetted {0}.'.format(
            formatGenericObj(hue)).format(*hue))

    for x_i, col_i, hue_i in zip(x, col, hue):
        sns.catplot(data=df_,
                    kind=kind,
                    x=x_i,
                    col=col_i,
                    hue=hue_i,
                    linewidth=self.lw,
                    height=self.sns_height,
                    aspect=self.sns_aspect,
                    **kws
                    ).set_ylabels(self.ylabel)
    plt.show()

    if self.verbose > 0:
        print('--> Multiple distribution plots sucessfully done!')
    return self
[docs]
def corrmatrix(
    self,
    cortype: str = 'num',
    features: Optional[List[str]] = None,
    method: str = 'pearson',
    min_periods: int = 1,
    **sns_kws):
    """
    Quick plot of the correlation matrix of numerical or categorical
    features.

    Set `features` by providing the names of features for visualization.

    Parameters
    -----------
    cortype: str, default='num'
        The type of features whose correlations are visualized. Any string
        containing ``num`` (or one of 'value', 'digit', 'quan',
        'quantitative') selects numerical features. Any string containing
        ``cat`` (or one of 'string', 'letter', 'qual', 'qualitative')
        selects categorical features, which are encoded with their
        category codes before correlation.
    features: List, optional
        List of the names of features for correlation analysis. If given,
        the names must belong to the dataframe columns, otherwise an error
        will occur. If the features are valid, the dataframe is shrunk to
        these features before the correlation plot.
    method: str, default='pearson'
        The correlation method passed to `pandas.DataFrame.corr`, e.g.
        'pearson' or 'spearman'.
    min_periods: int, default=1
        Minimum number of observations required per pair of columns to
        have a valid result.
    sns_kws: dict
        Other seaborn heatmap arguments. Refer to
        https://seaborn.pydata.org/generated/seaborn.heatmap.html
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.

    Returns
    -------
    :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Raises
    ------
    ValueError
        If `cortype` cannot be interpreted as numerical or categorical.

    Notes
    -------
    The argument for `data` must be passed to `fit` method. `data`
    parameter is not allowed in other `QuickPlot` method. The description
    of the parameter `data` is to give a synopsis of the kind of data
    the plot expected. An error will raise if force to pass `data`
    argument as a keyword arguments.

    Example
    ---------
    >>> from watex.view.plot import QuickPlot
    >>> from watex.datasets import load_bagoue
    >>> data = load_bagoue ().frame
    >>> qplotObj = QuickPlot().fit(data)
    >>> sns_kwargs ={'annot': False,
    ...              'linewidth': .5,
    ...              'center':0 ,
    ...              'cbar':True}
    >>> qplotObj.corrmatrix(cortype='cat', **sns_kwargs)
    """
    self.inspect
    corc = str(cortype)  # original spelling, kept for the error message
    cortype = corc.lower().strip()
    if 'num' in cortype or cortype in (
            'value', 'digit', 'quan', 'quantitative'):
        cortype = 'num'
    elif 'cat' in cortype or cortype in (
            'string', 'letter', 'qual', 'qualitative'):
        cortype = 'cat'
    if cortype not in ('num', 'cat'):
        # Raise: the original *returned* the exception instance, so an
        # invalid ``cortype`` was silently accepted.
        raise ValueError("Expect 'num' or 'cat' for numerical and"
                         f" categorical features, not : {corc!r}")

    df_ = self.data.copy(deep=True)
    df_ = selectfeatures(df_, features=features,
                         include='number' if cortype == 'num' else None,
                         exclude='number' if cortype == 'cat' else None,
                         )
    features = list(df_.columns)  # for consistency
    if cortype == 'cat':
        # encode categories as integer codes so they can be correlated
        for ftn in features:
            df_[ftn] = df_[ftn].astype('category').cat.codes
    elif 'id' in features:
        features.remove('id')
        df_ = df_.drop('id', axis=1)

    ax = sns.heatmap(data=df_[features].corr(
        method=method, min_periods=min_periods),
        **sns_kws
    )
    ax.set_xlabel(self.xlabel)
    ax.set_ylabel(self.ylabel)
    ax.set_title(self.fig_title)

    if self.savefig is None:
        plt.show()
    else:
        plt.savefig(self.savefig, dpi=self.fig_dpi,
                    orientation=self.fig_orientation)
        plt.close()

    if self.verbose > 0:
        print(" --> Correlation matrix plot successfully done !")
    return self
[docs]
def numfeatures(
    self,
    features=None,
    coerce: bool = False,
    map_lower_kws=None,
    **sns_kws):
    """
    Plots the distribution of quantitative (numerical) features through
    pairwise correlation plots (`seaborn.pairplot`), coloured by the target.

    Parameters
    -----------
    features: list, optional
        List of numerical features to plot for correlation analyses. It
        raises an error if the features do not exist in the data. If
        ``None``, all columns are used.
    coerce: bool, default=False
        Constrain the data to keep only the numerical features. If
        ``False`` and the data contain non-numerical features, a
        ``ValueError`` is raised. The target column (`tname`) is kept in
        any case so it can be used as ``hue``.
    map_lower_kws: dict, optional
        Keyword arguments for the pairplot ``map_lower`` call with
        `seaborn.kdeplot`. It is only applied when ``diag_kind='kde'`` is
        given in `sns_kws`. If `diag_kind` is absent, a warning is issued
        instead.
    sns_kws: dict,
        Keyword arguments of seaborn pairplot. Refer to
        http://seaborn.pydata.org/generated/seaborn.pairplot.html for
        further details.
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.

    Returns
    -------
    :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Raises
    ------
    ValueError
        If non-numerical features are found while `coerce` is ``False``.

    Notes
    -------
    The argument for `data` must be passed to `fit` method. `data`
    parameter is not allowed in other `QuickPlot` method. The description
    of the parameter `data` is to give a synopsis of the kind of data
    the plot expected. An error will raise if force to pass `data`
    argument as a keyword arguments.

    Examples
    ---------
    >>> from watex.view.plot import QuickPlot
    >>> from watex.datasets import load_bagoue
    >>> data = load_bagoue ().frame
    >>> qkObj = QuickPlot(mapflow =False, tname='flow'
    ...                   ).fit(data)
    >>> qkObj.sns_style ='darkgrid'
    >>> qkObj.fig_title='Quantitative features correlation'
    >>> sns_pkws={'aspect':2 ,
    ...           "height": 2,
    ...           'diag_kind':'kde',
    ...           'corner':False,
    ...           }
    >>> marklow = {'level':4,
    ...            'color':".2"}
    >>> qkObj.numfeatures(coerce=True, map_lower_kws=marklow, **sns_pkws)
    """
    self.inspect
    df_ = self.data.copy(deep=True)
    if features is not None:
        # ``features`` was documented but previously ignored: honour it.
        df_ = selectfeatures(df_, features=features)
    try:
        df_ = df_.astype(float)
    except (TypeError, ValueError):
        if not coerce:
            non_num = list(selectfeatures(df_, exclude='number').columns)
            msg = f"non-numerical features detected: {smart_format(non_num)}"
            warnings.warn(msg + "; set 'coerce' to 'True' to only visualize"
                          " the numerical features.")
            raise ValueError(msg + "; set 'coerce'to 'True' to keep the"
                             " the numerical insights")
        df_ = selectfeatures(df_, include='number')

    # The target is often categorical, so the numeric filtering (or the
    # feature selection) may have dropped it. It is still needed as
    # ``hue``; otherwise seaborn fails to find the column.
    tname = self.tname
    if (tname is not None and tname not in df_.columns
            and tname in self.data.columns):
        df_[tname] = self.data[tname]

    ax = sns.pairplot(data=df_, hue=tname, **sns_kws)
    if map_lower_kws is not None:
        if 'diag_kind' not in sns_kws:
            self._logging.info('Impossible to set `map_lower_kws`.')
            warnings.warn(
                '``kde|sns.kdeplot``is not found for seaborn pairplot.'
                "Impossible to lowering the distribution map.")
        elif sns_kws['diag_kind'] == 'kde':
            ax.map_lower(sns.kdeplot, **map_lower_kws)

    if self.savefig is None:
        plt.show()
    else:
        plt.savefig(self.savefig, dpi=self.fig_dpi,
                    orientation=self.fig_orientation)
        plt.close()
    return self
[docs]
def joint2features(
    self,
    features: List[str], *,
    join_kws=None, marginals_kws=None,
    **sns_kws):
    """
    Visualize the correlation of two numerical features.

    Draws a plot of two features with bivariate and univariate graphs
    (`seaborn.jointplot`).

    Parameters
    -----------
    features: list
        List of numerical features to plot for correlation analyses. It
        raises an error if the features do not exist in the data. After the
        non-numerical entries are discarded, exactly two features must
        remain.
    join_kws: dict, optional
        Keyword arguments for the `seaborn.kdeplot` drawn on the joint Axes
        through ``plot_joint``.
    marginals_kws: dict, optional
        Keyword arguments for the `seaborn.rugplot` drawn on the marginal
        Axes through ``plot_marginals``.
    sns_kws: dict, optional
        Keyword arguments of seaborn jointplot. Refer to
        http://seaborn.pydata.org/generated/seaborn.jointplot.html
        for more details about useful kwargs to customize plots.
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.

    Returns
    -------
    :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Raises
    ------
    PlotError
        If `features` is ``None`` or does not yield exactly two numerical
        features.

    Notes
    -------
    The argument for `data` must be passed to `fit` method. `data`
    parameter is not allowed in other `QuickPlot` method. The description
    of the parameter `data` is to give a synopsis of the kind of data
    the plot expected. An error will raise if force to pass `data`
    argument as a keyword arguments.

    Examples
    ----------
    >>> from watex.view.plot import QuickPlot
    >>> from watex.datasets import load_bagoue
    >>> data = load_bagoue ().frame
    >>> qkObj = QuickPlot( lc='b', sns_style ='darkgrid',
    ...             fig_title='Quantitative features correlation'
    ...             ).fit(data)
    >>> sns_pkws={
    ...            'kind':'reg' , #'kde', 'hex'
    ...           }
    >>> joinpl_kws={"color": "r",
    ...             'zorder':0, 'levels':6}
    >>> plmarg_kws={'color':"r", 'height':-.15, 'clip_on':False}
    >>> qkObj.joint2features(features=['ohmS', 'lwi'],
    ...            join_kws=joinpl_kws, marginals_kws=plmarg_kws,
    ...            **sns_pkws,
    ...            )
    """
    self.inspect
    df_ = self.data.copy(deep=True)
    if isinstance(features, str):
        features = [features]
    if features is None:
        self._logging.error(f"Valid features are {smart_format(df_.columns)}")
        raise PlotError("NoneType can not be a feature nor plotted.")

    df_ = selectfeatures(df_, features)
    # keep only the quantitative features
    df_ = selectfeatures(df_, include='number')
    if len(df_.columns) != 2:
        raise PlotError(f" Joinplot needs two features. {len(df_.columns)}"
                        f" {'was' if len(df_.columns)<=1 else 'were'} given")

    # Choose x/y among the columns that survived the numeric filter, in
    # the caller's order. Using ``features[0]``/``features[1]`` directly
    # could hand seaborn a non-numeric name that is no longer in ``df_``.
    xy = list(dict.fromkeys(f for f in features if f in df_.columns))
    if len(xy) != 2:
        xy = list(df_.columns)
    x, y = xy

    ax = sns.jointplot(data=df_, x=x, y=y, **sns_kws)
    if join_kws is not None:
        join_kws = _assert_all_types(join_kws, dict)
        ax.plot_joint(sns.kdeplot, **join_kws)
    if marginals_kws is not None:
        marginals_kws = _assert_all_types(marginals_kws, dict)
        ax.plot_marginals(sns.rugplot, **marginals_kws)

    # Save *before* closing: the original closed the figure first, so a
    # new, empty figure was written to ``self.savefig``.
    if self.savefig is None:
        plt.show()
    else:
        plt.savefig(self.savefig, dpi=self.fig_dpi,
                    orientation=self.fig_orientation)
        plt.close()
    return self
[docs]
def scatteringfeatures(
    self, features: List[str],
    *,
    relplot_kws=None,
    **sns_kws
):
    """
    Draw a scatter plot with the possibility of several semantic feature
    groupings.

    `scatteringfeatures` analysis is a process of understanding how
    features in a dataset relate to each other and how those relationships
    depend on other features. Visualization can be a core component of
    this process because, when data are visualized properly, the human
    visual system can see trends and patterns that indicate a
    relationship.

    Parameters
    -----------
    features: list
        List of features to plot. The first two are used as ``x`` and
        ``y``. It raises an error if fewer than two features are given or
        if the features do not exist in the data.
    relplot_kws: dict, optional
        Extra keyword arguments for an additional `seaborn.relplot` of the
        same two features, showing their relationship with semantic
        mappings of subsets. Refer to
        http://seaborn.pydata.org/generated/seaborn.relplot.html
        for more details.
    sns_kws: dict, optional
        Keyword arguments for `seaborn.scatterplot`, which control the
        visual semantics used to identify the different subsets. For more
        details, please consult
        http://seaborn.pydata.org/generated/seaborn.scatterplot.html
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.

    Returns
    -------
    :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Raises
    ------
    PlotError
        If `features` is ``None`` or holds fewer than two names.

    Notes
    -------
    The argument for `data` must be passed to `fit` method. `data`
    parameter is not allowed in other `QuickPlot` method. The description
    of the parameter `data` is to give a synopsis of the kind of data
    the plot expected. An error will raise if force to pass `data`
    argument as a keyword arguments.

    Examples
    ----------
    >>> from watex.view.plot import QuickPlot
    >>> from watex.datasets import load_bagoue
    >>> data = load_bagoue ().frame
    >>> qkObj = QuickPlot(lc='b', sns_style ='darkgrid',
    ...             fig_title='geol vs level of water inflow',
    ...             xlabel='Level of water inflow (lwi)',
    ...             ylabel='Flow rate in m3/h'
    ...            )
    >>> qkObj.tname='flow' # target the DC-flow rate prediction dataset
    >>> qkObj.mapflow=True # to hold category FR0, FR1 etc..
    >>> qkObj.fit(data)
    >>> marker_list= ['o','s','P', 'H']
    >>> markers_dict = {key:mv for key, mv in zip( list (
    ...                       dict(qkObj.data ['geol'].value_counts(
    ...                           normalize=True)).keys()),
    ...                            marker_list)}
    >>> sns_pkws={'markers':markers_dict,
    ...           'sizes':(20, 200),
    ...           "hue":'geol',
    ...           'style':'geol',
    ...           "palette":'deep',
    ...           'legend':'full',
    ...           }
    >>> regpl_kws = {'col':'flow',
    ...              'hue':'lwi',
    ...              'style':'geol',
    ...              'kind':'scatter'
    ...             }
    >>> qkObj.scatteringfeatures(features=['lwi', 'flow'],
    ...                          relplot_kws=regpl_kws,
    ...                          **sns_pkws,
    ...                     )
    """
    self.inspect
    df_ = self.data.copy(deep=True)
    # controller function
    if isinstance(features, str):
        features = [features]
    if features is None:
        self._logging.error(f"Valid features are {smart_format(df_.columns)}")
        raise PlotError("NoneType can not be a feature nor plotted.")
    if len(features) < 2:
        # Report how many *features* were given; the original reported the
        # number of columns of the whole dataset.
        raise PlotError(
            f" Scatterplot needs at least two features. {len(features)}"
            f" {'was' if len(features) <= 1 else 'were'} given")
    # assert whether the features exist (raises otherwise)
    selectfeatures(df_, features)

    ax = sns.scatterplot(data=df_, x=features[0], y=features[1],
                         **sns_kws)
    ax.set_xlabel(self.xlabel)
    ax.set_ylabel(self.ylabel)
    ax.set_title(self.fig_title)
    if relplot_kws is not None:
        relplot_kws = _assert_all_types(relplot_kws, dict)
        sns.relplot(data=df_, x=features[0], y=features[1],
                    **relplot_kws)

    if self.savefig is None:
        plt.show()
    else:
        plt.savefig(self.savefig, dpi=self.fig_dpi,
                    orientation=self.fig_orientation)
        plt.close()
    return self
[docs]
def discussingfeatures(
    self, features, *,
    map_kws: Optional[dict] = None,
    map_func: Optional[F] = None,
    **sns_kws) -> "QuickPlot":
    """
    Discuss at least three (ideally four) features through their
    distribution.

    This method maps a dataset onto multiple axes arrayed in a grid of
    rows and columns that correspond to levels of features in the dataset
    (`seaborn.FacetGrid`). The plots produced are often called "lattice",
    "trellis", or "small-multiple" graphics.

    Parameters
    -----------
    features: list
        List of features to discuss. The recommended number of features
        is four (04), laid out as follows::

            features_disposal = ['x', 'y', 'col', 'target|hue']

        where:

        - `x` is the feature on the x-axis, e.g. ``ohmS``;
        - `y` is the feature on the y-axis, e.g. ``sfi``;
        - `col` is the feature used for column subsets, e.g. ``geol``;
        - `target` or `hue` is the targeted feature, e.g. ``flow``.

        If 03 features are given, the last one is used as the `target` and
        no column faceting is applied. Extra features beyond four are
        ignored, and fewer than three raise a ``PlotError``.
    map_kws: dict, optional
        Extra keyword arguments for the mapping plot. Defaults to
        ``{'edgecolor': 'w'}``.
    map_func: callable, optional
        Plot style function, e.g. a 'matplotlib-pyplot' function like
        ``plt.scatter`` or a seaborn one like ``sns.scatterplot``. The
        *default* is ``sns.scatterplot``.
    sns_kws: dict, optional
        Keyword arguments for `seaborn.FacetGrid`. For more details, please
        consult http://seaborn.pydata.org/generated/seaborn.FacetGrid.html
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.

    Returns
    -------
    :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Raises
    ------
    PlotError
        If fewer than three features are given.
    TypeError
        If `map_func` is not callable.

    Notes
    -------
    The argument for `data` must be passed to `fit` method. `data`
    parameter is not allowed in other `QuickPlot` method. The description
    of the parameter `data` is to give a synopsis of the kind of data
    the plot expected. An error will raise if force to pass `data`
    argument as a keyword arguments.

    Examples
    --------
    >>> from watex.view.plot import QuickPlot
    >>> from watex.datasets import load_bagoue
    >>> data = load_bagoue ().frame
    >>> qkObj = QuickPlot( leg_kws={'loc':'upper right'},
    ...          fig_title = '`sfi` vs`ohmS|`geol`',
    ...            )
    >>> qkObj.tname='flow' # target the DC-flow rate prediction dataset
    >>> qkObj.mapflow=True # to hold category FR0, FR1 etc..
    >>> qkObj.fit(data)
    >>> sns_pkws={'aspect':2 ,
    ...           "height": 2,
    ...                  }
    >>> map_kws={'edgecolor':"w"}
    >>> qkObj.discussingfeatures(features =['ohmS', 'sfi','geol', 'flow'],
    ...                           map_kws=map_kws,  **sns_pkws
    ...                         )
    """
    self.inspect
    df_ = self.data.copy(deep=True)
    if isinstance(features, str):
        features = [features]
    # Work on a private list: tuples must be accepted, and the caller's
    # list must not be mutated by the ``insert`` below.
    features = list(features)
    nfeat = len(features)

    if nfeat > 4:
        if self.verbose:
            self._logging.debug(
                'Features length provided is = {0:02}. The first four '
                'features `{1}` are used for joinplot.'.format(
                    nfeat, features[:4]))
        features = features[:4]
    elif nfeat <= 2:
        verb, pl = ('are', 's') if nfeat == 2 else ('is', '')
        if self.verbose:
            self._logging.error(
                'Expect three features at least. {0} '
                '{1} given.'.format(nfeat, verb))
        raise PlotError(
            '{0:02} feature{1} {2} given. Expect at least 03 '
            'features!'.format(nfeat, pl, verb))
    elif nfeat == 3:
        msg = ('03 Features are given. The last feature `{}` should be'
               ' considered as the`targetted`feature or `hue` value.'.format(
                   features[-1]))
        if self.verbose:
            self._logging.debug(msg)
        warnings.warn(
            '03 features are given, the last one `{}` is used as '
            'target!'.format(features[-1]))
        # no column faceting: layout becomes [x, y, None, hue]
        features.insert(2, None)

    ax = sns.FacetGrid(data=df_, col=features[-2], hue=features[-1],
                       **sns_kws)

    if map_func is None:
        map_func = sns.scatterplot  # plt.scatter
    elif not callable(map_func):
        # ``type(...).__name__`` always exists; arbitrary non-callables may
        # not have a ``__name__`` attribute.
        raise TypeError(
            'map_func must be a callable object not'
            f' {type(map_func).__name__!r}'
        )

    # Default first, validate only user-supplied values. The original
    # validated ``None`` against ``dict`` before setting the default.
    if map_kws is None:
        map_kws = {'edgecolor': "w"}
    else:
        map_kws = _assert_all_types(map_kws, dict)

    ax.map(map_func, features[0], features[1],
           **map_kws).add_legend(**(self.leg_kws or {}))

    if self.savefig is None:
        plt.show()
    else:
        plt.savefig(self.savefig, dpi=self.fig_dpi,
                    orientation=self.fig_orientation)
        plt.close()
    return self
[docs]
def naiveviz(
    self,
    x: str = None,
    y: str = None,
    kind: str = 'scatter',
    s_col='lwi',
    leg_kws: dict = None,
    **pd_kws
):
    """ Creates a plot to visualize the samples distributions
    according to the geographical coordinates `x` and `y`.

    Parameters
    -----------
    x: str,
        Column name holding the x-axis values.
    y: str,
        Column name holding the y-axis values.
    kind: str, default='scatter'
        Kind of pandas plot (`pandas.DataFrame.plot`).
    s_col: str, default='lwi'
        Column used to size the scatter points: ``s = data[s_col] * fs``.
        It is only used when ``kind='scatter'`` and no explicit ``s`` is
        passed in `pd_kws`. If the column is absent, a warning is issued
        and the default marker size is kept.
    leg_kws: dict, optional
        Matplotlib legend keyword arguments. If ``None``, the instance
        attribute `leg_kws` is used.
    pd_kws: dict, optional,
        Pandas plot keyword arguments.
    data: str or pd.core.DataFrame
        Path-like object or Dataframe. Long-form (tidy) dataset for
        plotting. Each column should correspond to a variable, and each
        row should correspond to an observation. If data is given as
        path-like object, `QuickPlot` reads and sanitizes data before
        plotting. Be aware in this case to provide the target name and
        possibly the `classes` for data inspection. Both str or dataframe
        need to provide the name of target.

    Returns
    -------
    :class:`QuickPlot` instance
        Returns ``self`` for easy method chaining.

    Notes
    -------
    The argument for `data` must be passed to `fit` method. `data`
    parameter is not allowed in other `QuickPlot` method. The description
    of the parameter `data` is to give a synopsis of the kind of data
    the plot expected. An error will raise if force to pass `data`
    argument as a keyword arguments.

    Examples
    ---------
    >>> from watex.transformers import StratifiedWithCategoryAdder
    >>> from watex.view.plot import QuickPlot
    >>> from watex.datasets import load_bagoue
    >>> df = load_bagoue ().frame
    >>> stratifiedNumObj= StratifiedWithCategoryAdder('flow')
    >>> strat_train_set , *_= \\
    ...    stratifiedNumObj.fit_transform(X=df)
    >>> pd_kws ={'alpha': 0.4,
    ...         'label': 'flow m3/h',
    ...         'c':'flow',
    ...         'cmap':plt.get_cmap('jet'),
    ...         'colorbar':True}
    >>> qkObj=QuickPlot(fs=25.)
    >>> qkObj.fit(strat_train_set)
    >>> qkObj.naiveviz( x= 'east', y='north', **pd_kws)
    """
    self.inspect
    df_ = self.data.copy(deep=True)
    # Marker sizes only make sense for scatter plots: other pandas kinds
    # (e.g. ``line``, which forwards kwargs to ``Line2D``) reject ``s``.
    if kind == 'scatter' and 's' not in pd_kws and s_col is not None:
        if s_col in df_.columns:
            pd_kws['s'] = df_[s_col] * self.fs
        else:
            warnings.warn(
                f"Column {s_col!r} not found; the default marker size"
                " is used.")
    df_.plot(kind=kind, x=x, y=y, **pd_kws)

    self.leg_kws = self.leg_kws or dict()
    # an explicit ``leg_kws`` wins over the instance-level settings
    plt.legend(**(self.leg_kws if leg_kws is None else leg_kws))

    if self.savefig is None:
        plt.show()
    else:
        plt.savefig(self.savefig, dpi=self.fig_dpi,
                    orientation=self.fig_orientation)
        plt.close()
    return self
def __repr__(self):
    """Build the programmer-facing representation of the object.

    The rendering is delegated to ``repr_callable_obj``, with the ``y``
    attribute excluded from the output.
    """
    excluded_attr = 'y'
    return repr_callable_obj(self, skip=excluded_attr)
def __getattr__(self, name):
    """Fallback hook, called only when normal attribute lookup fails.

    A name with a trailing single underscore (but not a dunder) denotes a
    fitted attribute, so ``NotFittedError`` is raised. Any other missing
    name raises ``AttributeError``, with a suggestion appended when
    ``smart_strobj_recognition`` finds a close match in ``__dict__``.
    """
    cls_name = self.__class__.__name__
    looks_fitted = name.endswith('_') and not name.endswith('__')
    if looks_fitted:
        raise NotFittedError(
            f"{cls_name!r} instance is not fitted yet."
            " Call 'fit' method with appropriate arguments before"
            f" retreiving the attribute {name!r} value."
        )

    suggestion = smart_strobj_recognition(name, self.__dict__, deep=True)
    hint = '' if suggestion is None else f'. Do you mean {suggestion!r}?'
    raise AttributeError(
        f'{cls_name!r} object has no attribute {name!r}{hint}'
    )
@property
def inspect(self):
    """Check whether the instance has been fitted.

    Returns
    -------
    int
        ``1`` when the fitted attribute ``data_`` is available.

    Raises
    ------
    NotFittedError
        If ``data_`` cannot be retrieved, i.e. ``fit`` has not been called
        yet.
    """
    msg = ("{obj.__class__.__name__} instance is not fitted yet."
           " Call 'fit' with appropriate arguments before using"
           " this method"
           )
    # ``hasattr`` only treats AttributeError as "missing", but this class's
    # ``__getattr__`` raises NotFittedError for any missing name ending with
    # a single ``_`` (such as ``data_``). That exception would escape
    # ``hasattr`` and hide the message below, so catch both explicitly.
    try:
        getattr(self, 'data_')
    except (AttributeError, NotFittedError):
        raise NotFittedError(msg.format(obj=self)) from None
    return 1
ExPlot.__doc__ = """\
Exploratory plot for data analysis
`ExPlot` is a shadow class. Exploring data is needed to create a model since
it gives a feel for the data and is also a great excuse to meet and discuss
issues with the business units that control the data. `ExPlot` methods
return an instantiated object that inherits from :class:`watex.property.BasePlot`
ABC (Abstract Base Class) for visualization.
Parameters
-------------
{params.base.savefig}
{params.base.fig_dpi}
{params.base.fig_num}
{params.base.fig_size}
{params.base.fig_orientation}
{params.base.fig_title}
{params.base.fs}
{params.base.ls}
{params.base.lc}
{params.base.lw}
{params.base.alpha}
{params.base.font_weight}
{params.base.font_style}
{params.base.font_size}
{params.base.ms}
{params.base.marker}
{params.base.marker_facecolor}
{params.base.marker_edgecolor}
{params.base.marker_edgewidth}
{params.base.xminorticks}
{params.base.yminorticks}
{params.base.bins}
{params.base.xlim}
{params.base.ylim}
{params.base.xlabel}
{params.base.ylabel}
{params.base.rotate_xlabel}
{params.base.rotate_ylabel}
{params.base.leg_kws}
{params.base.plt_kws}
{params.base.glc}
{params.base.glw}
{params.base.galpha}
{params.base.gaxis}
{params.base.gwhich}
{params.base.tp_axis}
{params.base.tp_labelsize}
{params.base.tp_bottom}
{params.base.tp_labelbottom}
{params.base.tp_labeltop}
{params.base.cb_orientation}
{params.base.cb_aspect}
{params.base.cb_shrink}
{params.base.cb_pad}
{params.base.cb_anchor}
{params.base.cb_panchor}
{params.base.cb_label}
{params.base.cb_spacing}
{params.base.cb_drawedges}
{params.sns.sns_orient}
{params.sns.sns_style}
{params.sns.sns_palette}
{params.sns.sns_height}
{params.sns.sns_aspect}
Returns
--------
{returns.self}
Examples
---------
>>> import pandas as pd
>>> from watex.view import ExPlot
>>> data = pd.read_csv ('data/geodata/main.bagciv.data.csv' )
>>> ExPlot(fig_size = (12, 4)).fit(data).missing(kind ='corr')
<watex.view.plot.ExPlot at 0x21162a975e0>
""".format(
    params=_param_docs,
    returns=_core_docs["returns"],
)
QuickPlot.__doc__ = """\
Special class dealing with analysis modules for quick diagrams,
histograms and bar visualizations.
Originally designed for flow rate prediction, it still works with any
other dataset by following the parameter details below.
Parameters
-------------
{params.core.data}
{params.core.y}
{params.core.tname}
{params.qdoc.classes}
{params.qdoc.mapflow}
{params.base.savefig}
{params.base.fig_dpi}
{params.base.fig_num}
{params.base.fig_size}
{params.base.fig_orientation}
{params.base.fig_title}
{params.base.fs}
{params.base.ls}
{params.base.lc}
{params.base.lw}
{params.base.alpha}
{params.base.font_weight}
{params.base.font_style}
{params.base.font_size}
{params.base.ms}
{params.base.marker}
{params.base.marker_facecolor}
{params.base.marker_edgecolor}
{params.base.marker_edgewidth}
{params.base.xminorticks}
{params.base.yminorticks}
{params.base.bins}
{params.base.xlim}
{params.base.ylim}
{params.base.xlabel}
{params.base.ylabel}
{params.base.rotate_xlabel}
{params.base.rotate_ylabel}
{params.base.leg_kws}
{params.base.plt_kws}
{params.base.glc}
{params.base.glw}
{params.base.galpha}
{params.base.gaxis}
{params.base.gwhich}
{params.base.tp_axis}
{params.base.tp_labelsize}
{params.base.tp_bottom}
{params.base.tp_labelbottom}
{params.base.tp_labeltop}
{params.base.cb_orientation}
{params.base.cb_aspect}
{params.base.cb_shrink}
{params.base.cb_pad}
{params.base.cb_anchor}
{params.base.cb_panchor}
{params.base.cb_label}
{params.base.cb_spacing}
{params.base.cb_drawedges}
{params.sns.sns_orient}
{params.sns.sns_style}
{params.sns.sns_palette}
{params.sns.sns_height}
{params.sns.sns_aspect}
Returns
--------
{returns.self}
Examples
---------
>>> from watex.view.plot import QuickPlot
>>> data = 'data/geodata/main.bagciv.data.csv'
>>> qkObj = QuickPlot( leg_kws= dict( loc='upper right'),
...                   fig_title = '`sfi` vs`ohmS|`geol`',
...                   )
>>> qkObj.tname='flow' # target the DC-flow rate prediction dataset
>>> qkObj.mapflow=True # to hold category FR0, FR1 etc..
>>> qkObj.fit(data)
>>> sns_pkws= dict ( aspect = 2 ,
...                 height= 2,
...                 )
>>> map_kws= dict( edgecolor="w")
>>> qkObj.discussingfeatures(features =['ohmS', 'sfi','geol', 'flow'],
...                          map_kws=map_kws, **sns_pkws
...                          )
""".format(
    params=_param_docs,
    returns=_core_docs["returns"],
)
TPlot.__doc__ = """\
Tensor plot from EMAP or AMT processing data.
`TPlot` is a :term:`Tensor` (impedances, resistivity and phases) plot class.
Explore SEG (Society of Exploration Geophysicists) class data. Plot recovery
tensors. `TPlot` methods return an instantiated object that inherits
from :class:`watex.property.BasePlot` ABC (Abstract Base Class) for
visualization.
Parameters
------------
window_size : int
the length of the window. Must be greater than 1 and preferably
an odd integer number. Default is ``5``
component: str
field tensors direction. It can be ``xx``, ``xy``, ``yx``, ``yy``. If
``arr2d`` is provided, no need to give an argument. It becomes useful
when a collection of EDI-objects is provided. If not specified, the
resistivity and phase values at component ``xy`` are fetched for
correction by default. Change the component value to get the appropriate
data for correction. Default is ``xy``.
mode: str , ['valid', 'same'], default='same'
mode of the border trimming. Should be 'valid' or 'same'. 'valid' is used
for regular trimming whereas 'same' is used for appending the first
and last value of resistivity. Any other argument except 'valid' is
considered as 'same'. Default is ``same``.
method: str, default ``slinear``
Interpolation technique to use. Can be ``nearest`` or ``pad``. Refer to
the documentation of :func:`~.interpolate2d`.
out : str
Value to export. Can be ``sfactor``, ``tensor`` for correction factor
and impedance tensor. Any other value exports the static corrected
resistivity ``srho``.
c : int,
A window-width expansion factor that must be input to the filter
adaptation process to control the roll-off characteristics
of the applied Hanning window. It is recommended to select `c` between
``1`` and ``4``. Default is ``2``.
distance: float
The step between two stations/sites. If given, it creates an array of
position for plotting purpose. Default value is ``50`` meters.
prefix: str
string value to add as prefix of given id. Prefix can be the site
name. Default is ``S``.
how: str
Mode to index the station. Default is 'Python indexing' i.e.
the counting of stations starts at 0. Any other mode will
start the counting at 1.
{params.base.savefig}
{params.base.fig_dpi}
{params.base.fig_num}
{params.base.fig_size}
{params.base.fig_orientation}
{params.base.fig_title}
{params.base.fs}
{params.base.ls}
{params.base.lc}
{params.base.lw}
{params.base.alpha}
{params.base.font_weight}
{params.base.font_style}
{params.base.font_size}
{params.base.ms}
{params.base.marker}
{params.base.marker_facecolor}
{params.base.marker_edgecolor}
{params.base.marker_edgewidth}
{params.base.xminorticks}
{params.base.yminorticks}
{params.base.bins}
{params.base.xlim}
{params.base.ylim}
{params.base.xlabel}
{params.base.ylabel}
{params.base.rotate_xlabel}
{params.base.rotate_ylabel}
{params.base.leg_kws}
{params.base.plt_kws}
{params.base.glc}
{params.base.glw}
{params.base.galpha}
{params.base.gaxis}
{params.base.gwhich}
{params.base.tp_axis}
{params.base.tp_labelsize}
{params.base.tp_bottom}
{params.base.tp_labelbottom}
{params.base.tp_labeltop}
{params.base.cb_orientation}
{params.base.cb_aspect}
{params.base.cb_shrink}
{params.base.cb_pad}
{params.base.cb_anchor}
{params.base.cb_panchor}
{params.base.cb_label}
{params.base.cb_spacing}
{params.base.cb_drawedges}
{params.sns.sns_orient}
{params.sns.sns_style}
{params.sns.sns_palette}
{params.sns.sns_height}
{params.sns.sns_aspect}
Returns
--------
{returns.self}
Examples
---------
>>> from watex.view.plot import TPlot
>>> from watex.datasets import load_edis
>>> plot_kws = dict( ylabel = '$Log_{{10}}Frequency [Hz]$',
...                  xlabel = '$Distance(m)$',
...                  cb_label = '$Log_{{10}}Rhoa[\\Omega.m]$',
...                  fig_size =(6, 3),
...                  font_size =7.,
...                  rotate_xlabel=45,
...                  imshow_interp='bicubic',
...                  )
>>> edi_data =load_edis (return_data= True, samples=7 )
>>> t= TPlot(**plot_kws ).fit(edi_data)
>>> t.plot_tensor2d (to_log10=True )
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|Data collected = 7 |EDI success. read= 7 |Rate = 100.0 %|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
<AxesSubplot:xlabel='$Distance(m)$', ylabel='$Log_{{10}}Frequency [Hz]$'>
""".format(
    params=_param_docs,
    returns=_core_docs["returns"],
)
def viewtemplate(y, /, xlabel=None, ylabel=None, **kws):
    """
    Quick view template.

    Plot `y` on a single axes using the default line, font and grid
    settings of a fresh :class:`ExPlot` instance.

    Parameters
    -----------
    y: Arraylike , shape (N, )
        Values to plot, passed as the first positional argument of
        :meth:`matplotlib.axes.Axes.plot`.
    xlabel: str, Optional
        Label for naming the x-abscissa. Used only when the `ExPlot`
        default ``xlabel`` is None.
    ylabel: str, Optional,
        Label for naming the y-coordinates. Used only when the `ExPlot`
        default ``ylabel`` is None.
    kws: dict,
        keywords argument passed to :meth:`matplotlib.axes.Axes.plot`.
        A ``label`` keyword, if given, is used as the legend entry.

    Returns
    --------
    None
        The figure is shown with :func:`matplotlib.pyplot.show`. If the
        `ExPlot` default ``savefig`` is set, the figure is instead written
        to that path and then closed.
    """
    label = kws.pop('label', None)
    # A default ExPlot is used only as a container of plot settings.
    obj = ExPlot()
    fig = plt.figure(figsize=obj.fig_size)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(y,
            color=obj.lc,
            linewidth=obj.lw,
            linestyle=obj.ls,
            label=label,
            **kws
            )
    if obj.xlabel is None:
        obj.xlabel = xlabel or ''
    if obj.ylabel is None:
        obj.ylabel = ylabel or ''

    fontsize = .5 * obj.font_size * obj.fs
    ax.set_xlabel(obj.xlabel, fontsize=fontsize)
    ax.set_ylabel(obj.ylabel, fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)

    if obj.show_grid is True:
        if obj.gwhich == 'minor':
            ax.minorticks_on()
        ax.grid(obj.show_grid,
                axis=obj.gaxis,
                which=obj.gwhich,
                color=obj.gc,
                linestyle=obj.gls,
                linewidth=obj.glw,
                alpha=obj.galpha
                )

    # ``leg_kws`` may be None (it is normalised with ``or dict()`` elsewhere
    # in this module). Work on a copy so the instance default (possibly a
    # shared dict -- not verified here) is never mutated.
    leg_kws = dict(obj.leg_kws or {})
    leg_kws.setdefault('loc', 'upper left')
    ax.legend(**leg_kws)

    # Save BEFORE displaying: once ``plt.show`` returns, the figure may
    # already be closed and ``savefig`` would write an empty canvas.
    # Show-or-close mirrors the pattern used by the QuickPlot methods.
    if obj.savefig is not None:
        fig.savefig(obj.savefig,
                    dpi=obj.fig_dpi,
                    orientation=obj.fig_orientation
                    )
        plt.close(fig)
    else:
        plt.show()
# import matplotlib.cm as cm
# import matplotlib.colorbar as mplcb
# from mpl_toolkits.axes_grid1 import make_axes_locatable
# from matplotlib.ticker import MultipleLocator, NullLocator
# import matplotlib.gridspec as gspec