8th day of python challenges 111-117
This commit is contained in:
@@ -0,0 +1,85 @@
|
||||
from pandas._config import get_option
|
||||
|
||||
from pandas.plotting._matplotlib.boxplot import (
|
||||
BoxPlot,
|
||||
boxplot,
|
||||
boxplot_frame,
|
||||
boxplot_frame_groupby,
|
||||
)
|
||||
from pandas.plotting._matplotlib.converter import deregister, register
|
||||
from pandas.plotting._matplotlib.core import (
|
||||
AreaPlot,
|
||||
BarhPlot,
|
||||
BarPlot,
|
||||
HexBinPlot,
|
||||
LinePlot,
|
||||
PiePlot,
|
||||
ScatterPlot,
|
||||
)
|
||||
from pandas.plotting._matplotlib.hist import HistPlot, KdePlot, hist_frame, hist_series
|
||||
from pandas.plotting._matplotlib.misc import (
|
||||
andrews_curves,
|
||||
autocorrelation_plot,
|
||||
bootstrap_plot,
|
||||
lag_plot,
|
||||
parallel_coordinates,
|
||||
radviz,
|
||||
scatter_matrix,
|
||||
)
|
||||
from pandas.plotting._matplotlib.timeseries import tsplot
|
||||
from pandas.plotting._matplotlib.tools import table
|
||||
|
||||
PLOT_CLASSES = {
|
||||
"line": LinePlot,
|
||||
"bar": BarPlot,
|
||||
"barh": BarhPlot,
|
||||
"box": BoxPlot,
|
||||
"hist": HistPlot,
|
||||
"kde": KdePlot,
|
||||
"area": AreaPlot,
|
||||
"pie": PiePlot,
|
||||
"scatter": ScatterPlot,
|
||||
"hexbin": HexBinPlot,
|
||||
}
|
||||
|
||||
if get_option("plotting.matplotlib.register_converters"):
|
||||
register(explicit=False)
|
||||
|
||||
|
||||
def plot(data, kind, **kwargs):
|
||||
# Importing pyplot at the top of the file (before the converters are
|
||||
# registered) causes problems in matplotlib 2 (converters seem to not
|
||||
# work)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if kwargs.pop("reuse_plot", False):
|
||||
ax = kwargs.get("ax")
|
||||
if ax is None and len(plt.get_fignums()) > 0:
|
||||
with plt.rc_context():
|
||||
ax = plt.gca()
|
||||
kwargs["ax"] = getattr(ax, "left_ax", ax)
|
||||
plot_obj = PLOT_CLASSES[kind](data, **kwargs)
|
||||
plot_obj.generate()
|
||||
plot_obj.draw()
|
||||
return plot_obj.result
|
||||
|
||||
|
||||
__all__ = [
|
||||
"plot",
|
||||
"hist_series",
|
||||
"hist_frame",
|
||||
"boxplot",
|
||||
"boxplot_frame",
|
||||
"boxplot_frame_groupby",
|
||||
"tsplot",
|
||||
"table",
|
||||
"andrews_curves",
|
||||
"autocorrelation_plot",
|
||||
"bootstrap_plot",
|
||||
"lag_plot",
|
||||
"parallel_coordinates",
|
||||
"radviz",
|
||||
"scatter_matrix",
|
||||
"register",
|
||||
"deregister",
|
||||
]
|
@@ -0,0 +1,416 @@
|
||||
from collections import namedtuple
|
||||
import warnings
|
||||
|
||||
from matplotlib.artist import setp
|
||||
import numpy as np
|
||||
|
||||
from pandas.core.dtypes.generic import ABCSeries
|
||||
from pandas.core.dtypes.missing import remove_na_arraylike
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from pandas.io.formats.printing import pprint_thing
|
||||
from pandas.plotting._matplotlib import converter
|
||||
from pandas.plotting._matplotlib.core import LinePlot, MPLPlot
|
||||
from pandas.plotting._matplotlib.style import _get_standard_colors
|
||||
from pandas.plotting._matplotlib.tools import _flatten, _subplots
|
||||
|
||||
|
||||
class BoxPlot(LinePlot):
|
||||
_kind = "box"
|
||||
_layout_type = "horizontal"
|
||||
|
||||
_valid_return_types = (None, "axes", "dict", "both")
|
||||
# namedtuple to hold results
|
||||
BP = namedtuple("Boxplot", ["ax", "lines"])
|
||||
|
||||
def __init__(self, data, return_type="axes", **kwargs):
|
||||
# Do not call LinePlot.__init__ which may fill nan
|
||||
if return_type not in self._valid_return_types:
|
||||
raise ValueError("return_type must be {None, 'axes', 'dict', 'both'}")
|
||||
|
||||
self.return_type = return_type
|
||||
MPLPlot.__init__(self, data, **kwargs)
|
||||
|
||||
def _args_adjust(self):
|
||||
if self.subplots:
|
||||
# Disable label ax sharing. Otherwise, all subplots shows last
|
||||
# column label
|
||||
if self.orientation == "vertical":
|
||||
self.sharex = False
|
||||
else:
|
||||
self.sharey = False
|
||||
|
||||
@classmethod
|
||||
def _plot(cls, ax, y, column_num=None, return_type="axes", **kwds):
|
||||
if y.ndim == 2:
|
||||
y = [remove_na_arraylike(v) for v in y]
|
||||
# Boxplot fails with empty arrays, so need to add a NaN
|
||||
# if any cols are empty
|
||||
# GH 8181
|
||||
y = [v if v.size > 0 else np.array([np.nan]) for v in y]
|
||||
else:
|
||||
y = remove_na_arraylike(y)
|
||||
bp = ax.boxplot(y, **kwds)
|
||||
|
||||
if return_type == "dict":
|
||||
return bp, bp
|
||||
elif return_type == "both":
|
||||
return cls.BP(ax=ax, lines=bp), bp
|
||||
else:
|
||||
return ax, bp
|
||||
|
||||
def _validate_color_args(self):
|
||||
if "color" in self.kwds:
|
||||
if self.colormap is not None:
|
||||
warnings.warn(
|
||||
"'color' and 'colormap' cannot be used "
|
||||
"simultaneously. Using 'color'"
|
||||
)
|
||||
self.color = self.kwds.pop("color")
|
||||
|
||||
if isinstance(self.color, dict):
|
||||
valid_keys = ["boxes", "whiskers", "medians", "caps"]
|
||||
for key, values in self.color.items():
|
||||
if key not in valid_keys:
|
||||
raise ValueError(
|
||||
"color dict contains invalid "
|
||||
"key '{0}' "
|
||||
"The key must be either {1}".format(key, valid_keys)
|
||||
)
|
||||
else:
|
||||
self.color = None
|
||||
|
||||
# get standard colors for default
|
||||
colors = _get_standard_colors(num_colors=3, colormap=self.colormap, color=None)
|
||||
# use 2 colors by default, for box/whisker and median
|
||||
# flier colors isn't needed here
|
||||
# because it can be specified by ``sym`` kw
|
||||
self._boxes_c = colors[0]
|
||||
self._whiskers_c = colors[0]
|
||||
self._medians_c = colors[2]
|
||||
self._caps_c = "k" # mpl default
|
||||
|
||||
def _get_colors(self, num_colors=None, color_kwds="color"):
|
||||
pass
|
||||
|
||||
def maybe_color_bp(self, bp):
|
||||
if isinstance(self.color, dict):
|
||||
boxes = self.color.get("boxes", self._boxes_c)
|
||||
whiskers = self.color.get("whiskers", self._whiskers_c)
|
||||
medians = self.color.get("medians", self._medians_c)
|
||||
caps = self.color.get("caps", self._caps_c)
|
||||
else:
|
||||
# Other types are forwarded to matplotlib
|
||||
# If None, use default colors
|
||||
boxes = self.color or self._boxes_c
|
||||
whiskers = self.color or self._whiskers_c
|
||||
medians = self.color or self._medians_c
|
||||
caps = self.color or self._caps_c
|
||||
|
||||
setp(bp["boxes"], color=boxes, alpha=1)
|
||||
setp(bp["whiskers"], color=whiskers, alpha=1)
|
||||
setp(bp["medians"], color=medians, alpha=1)
|
||||
setp(bp["caps"], color=caps, alpha=1)
|
||||
|
||||
def _make_plot(self):
|
||||
if self.subplots:
|
||||
self._return_obj = pd.Series()
|
||||
|
||||
for i, (label, y) in enumerate(self._iter_data()):
|
||||
ax = self._get_ax(i)
|
||||
kwds = self.kwds.copy()
|
||||
|
||||
ret, bp = self._plot(
|
||||
ax, y, column_num=i, return_type=self.return_type, **kwds
|
||||
)
|
||||
self.maybe_color_bp(bp)
|
||||
self._return_obj[label] = ret
|
||||
|
||||
label = [pprint_thing(label)]
|
||||
self._set_ticklabels(ax, label)
|
||||
else:
|
||||
y = self.data.values.T
|
||||
ax = self._get_ax(0)
|
||||
kwds = self.kwds.copy()
|
||||
|
||||
ret, bp = self._plot(
|
||||
ax, y, column_num=0, return_type=self.return_type, **kwds
|
||||
)
|
||||
self.maybe_color_bp(bp)
|
||||
self._return_obj = ret
|
||||
|
||||
labels = [l for l, _ in self._iter_data()]
|
||||
labels = [pprint_thing(l) for l in labels]
|
||||
if not self.use_index:
|
||||
labels = [pprint_thing(key) for key in range(len(labels))]
|
||||
self._set_ticklabels(ax, labels)
|
||||
|
||||
def _set_ticklabels(self, ax, labels):
|
||||
if self.orientation == "vertical":
|
||||
ax.set_xticklabels(labels)
|
||||
else:
|
||||
ax.set_yticklabels(labels)
|
||||
|
||||
def _make_legend(self):
|
||||
pass
|
||||
|
||||
def _post_plot_logic(self, ax, data):
|
||||
pass
|
||||
|
||||
@property
|
||||
def orientation(self):
|
||||
if self.kwds.get("vert", True):
|
||||
return "vertical"
|
||||
else:
|
||||
return "horizontal"
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
if self.return_type is None:
|
||||
return super().result
|
||||
else:
|
||||
return self._return_obj
|
||||
|
||||
|
||||
def _grouped_plot_by_column(
|
||||
plotf,
|
||||
data,
|
||||
columns=None,
|
||||
by=None,
|
||||
numeric_only=True,
|
||||
grid=False,
|
||||
figsize=None,
|
||||
ax=None,
|
||||
layout=None,
|
||||
return_type=None,
|
||||
**kwargs
|
||||
):
|
||||
grouped = data.groupby(by)
|
||||
if columns is None:
|
||||
if not isinstance(by, (list, tuple)):
|
||||
by = [by]
|
||||
columns = data._get_numeric_data().columns.difference(by)
|
||||
naxes = len(columns)
|
||||
fig, axes = _subplots(
|
||||
naxes=naxes, sharex=True, sharey=True, figsize=figsize, ax=ax, layout=layout
|
||||
)
|
||||
|
||||
_axes = _flatten(axes)
|
||||
|
||||
ax_values = []
|
||||
|
||||
for i, col in enumerate(columns):
|
||||
ax = _axes[i]
|
||||
gp_col = grouped[col]
|
||||
keys, values = zip(*gp_col)
|
||||
re_plotf = plotf(keys, values, ax, **kwargs)
|
||||
ax.set_title(col)
|
||||
ax.set_xlabel(pprint_thing(by))
|
||||
ax_values.append(re_plotf)
|
||||
ax.grid(grid)
|
||||
|
||||
result = pd.Series(ax_values, index=columns)
|
||||
|
||||
# Return axes in multiplot case, maybe revisit later # 985
|
||||
if return_type is None:
|
||||
result = axes
|
||||
|
||||
byline = by[0] if len(by) == 1 else by
|
||||
fig.suptitle("Boxplot grouped by {byline}".format(byline=byline))
|
||||
fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def boxplot(
|
||||
data,
|
||||
column=None,
|
||||
by=None,
|
||||
ax=None,
|
||||
fontsize=None,
|
||||
rot=0,
|
||||
grid=True,
|
||||
figsize=None,
|
||||
layout=None,
|
||||
return_type=None,
|
||||
**kwds
|
||||
):
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# validate return_type:
|
||||
if return_type not in BoxPlot._valid_return_types:
|
||||
raise ValueError("return_type must be {'axes', 'dict', 'both'}")
|
||||
|
||||
if isinstance(data, ABCSeries):
|
||||
data = data.to_frame("x")
|
||||
column = "x"
|
||||
|
||||
def _get_colors():
|
||||
# num_colors=3 is required as method maybe_color_bp takes the colors
|
||||
# in positions 0 and 2.
|
||||
return _get_standard_colors(color=kwds.get("color"), num_colors=3)
|
||||
|
||||
def maybe_color_bp(bp):
|
||||
if "color" not in kwds:
|
||||
setp(bp["boxes"], color=colors[0], alpha=1)
|
||||
setp(bp["whiskers"], color=colors[0], alpha=1)
|
||||
setp(bp["medians"], color=colors[2], alpha=1)
|
||||
|
||||
def plot_group(keys, values, ax):
|
||||
keys = [pprint_thing(x) for x in keys]
|
||||
values = [np.asarray(remove_na_arraylike(v)) for v in values]
|
||||
bp = ax.boxplot(values, **kwds)
|
||||
if fontsize is not None:
|
||||
ax.tick_params(axis="both", labelsize=fontsize)
|
||||
if kwds.get("vert", 1):
|
||||
ax.set_xticklabels(keys, rotation=rot)
|
||||
else:
|
||||
ax.set_yticklabels(keys, rotation=rot)
|
||||
maybe_color_bp(bp)
|
||||
|
||||
# Return axes in multiplot case, maybe revisit later # 985
|
||||
if return_type == "dict":
|
||||
return bp
|
||||
elif return_type == "both":
|
||||
return BoxPlot.BP(ax=ax, lines=bp)
|
||||
else:
|
||||
return ax
|
||||
|
||||
colors = _get_colors()
|
||||
if column is None:
|
||||
columns = None
|
||||
else:
|
||||
if isinstance(column, (list, tuple)):
|
||||
columns = column
|
||||
else:
|
||||
columns = [column]
|
||||
|
||||
if by is not None:
|
||||
# Prefer array return type for 2-D plots to match the subplot layout
|
||||
# https://github.com/pandas-dev/pandas/pull/12216#issuecomment-241175580
|
||||
result = _grouped_plot_by_column(
|
||||
plot_group,
|
||||
data,
|
||||
columns=columns,
|
||||
by=by,
|
||||
grid=grid,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
layout=layout,
|
||||
return_type=return_type,
|
||||
)
|
||||
else:
|
||||
if return_type is None:
|
||||
return_type = "axes"
|
||||
if layout is not None:
|
||||
raise ValueError(
|
||||
"The 'layout' keyword is not supported when " "'by' is None"
|
||||
)
|
||||
|
||||
if ax is None:
|
||||
rc = {"figure.figsize": figsize} if figsize is not None else {}
|
||||
with plt.rc_context(rc):
|
||||
ax = plt.gca()
|
||||
data = data._get_numeric_data()
|
||||
if columns is None:
|
||||
columns = data.columns
|
||||
else:
|
||||
data = data[columns]
|
||||
|
||||
result = plot_group(columns, data.values.T, ax)
|
||||
ax.grid(grid)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def boxplot_frame(
|
||||
self,
|
||||
column=None,
|
||||
by=None,
|
||||
ax=None,
|
||||
fontsize=None,
|
||||
rot=0,
|
||||
grid=True,
|
||||
figsize=None,
|
||||
layout=None,
|
||||
return_type=None,
|
||||
**kwds
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
converter._WARN = False # no warning for pandas plots
|
||||
ax = boxplot(
|
||||
self,
|
||||
column=column,
|
||||
by=by,
|
||||
ax=ax,
|
||||
fontsize=fontsize,
|
||||
grid=grid,
|
||||
rot=rot,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
return_type=return_type,
|
||||
**kwds
|
||||
)
|
||||
plt.draw_if_interactive()
|
||||
return ax
|
||||
|
||||
|
||||
def boxplot_frame_groupby(
|
||||
grouped,
|
||||
subplots=True,
|
||||
column=None,
|
||||
fontsize=None,
|
||||
rot=0,
|
||||
grid=True,
|
||||
ax=None,
|
||||
figsize=None,
|
||||
layout=None,
|
||||
sharex=False,
|
||||
sharey=True,
|
||||
**kwds
|
||||
):
|
||||
converter._WARN = False # no warning for pandas plots
|
||||
if subplots is True:
|
||||
naxes = len(grouped)
|
||||
fig, axes = _subplots(
|
||||
naxes=naxes,
|
||||
squeeze=False,
|
||||
ax=ax,
|
||||
sharex=sharex,
|
||||
sharey=sharey,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
)
|
||||
axes = _flatten(axes)
|
||||
|
||||
ret = pd.Series()
|
||||
for (key, group), ax in zip(grouped, axes):
|
||||
d = group.boxplot(
|
||||
ax=ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds
|
||||
)
|
||||
ax.set_title(pprint_thing(key))
|
||||
ret.loc[key] = d
|
||||
fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
|
||||
else:
|
||||
keys, frames = zip(*grouped)
|
||||
if grouped.axis == 0:
|
||||
df = pd.concat(frames, keys=keys, axis=1)
|
||||
else:
|
||||
if len(frames) > 1:
|
||||
df = frames[0].join(frames[1::])
|
||||
else:
|
||||
df = frames[0]
|
||||
ret = df.boxplot(
|
||||
column=column,
|
||||
fontsize=fontsize,
|
||||
rot=rot,
|
||||
grid=grid,
|
||||
ax=ax,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
**kwds
|
||||
)
|
||||
return ret
|
@@ -0,0 +1,22 @@
|
||||
# being a bit too dynamic
|
||||
from distutils.version import LooseVersion
|
||||
import operator
|
||||
|
||||
|
||||
def _mpl_version(version, op):
|
||||
def inner():
|
||||
try:
|
||||
import matplotlib as mpl
|
||||
except ImportError:
|
||||
return False
|
||||
return (
|
||||
op(LooseVersion(mpl.__version__), LooseVersion(version))
|
||||
and str(mpl.__version__)[0] != "0"
|
||||
)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
_mpl_ge_2_2_3 = _mpl_version("2.2.3", operator.ge)
|
||||
_mpl_ge_3_0_0 = _mpl_version("3.0.0", operator.ge)
|
||||
_mpl_ge_3_1_0 = _mpl_version("3.1.0", operator.ge)
|
File diff suppressed because it is too large
Load Diff
1502
venv/lib/python3.6/site-packages/pandas/plotting/_matplotlib/core.py
Normal file
1502
venv/lib/python3.6/site-packages/pandas/plotting/_matplotlib/core.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,421 @@
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pandas.core.dtypes.common import is_integer, is_list_like
|
||||
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass
|
||||
from pandas.core.dtypes.missing import isna, remove_na_arraylike
|
||||
|
||||
import pandas.core.common as com
|
||||
|
||||
from pandas.io.formats.printing import pprint_thing
|
||||
from pandas.plotting._matplotlib import converter
|
||||
from pandas.plotting._matplotlib.core import LinePlot, MPLPlot
|
||||
from pandas.plotting._matplotlib.tools import _flatten, _set_ticks_props, _subplots
|
||||
|
||||
|
||||
class HistPlot(LinePlot):
|
||||
_kind = "hist"
|
||||
|
||||
def __init__(self, data, bins=10, bottom=0, **kwargs):
|
||||
self.bins = bins # use mpl default
|
||||
self.bottom = bottom
|
||||
# Do not call LinePlot.__init__ which may fill nan
|
||||
MPLPlot.__init__(self, data, **kwargs)
|
||||
|
||||
def _args_adjust(self):
|
||||
if is_integer(self.bins):
|
||||
# create common bin edge
|
||||
values = self.data._convert(datetime=True)._get_numeric_data()
|
||||
values = np.ravel(values)
|
||||
values = values[~isna(values)]
|
||||
|
||||
hist, self.bins = np.histogram(
|
||||
values,
|
||||
bins=self.bins,
|
||||
range=self.kwds.get("range", None),
|
||||
weights=self.kwds.get("weights", None),
|
||||
)
|
||||
|
||||
if is_list_like(self.bottom):
|
||||
self.bottom = np.array(self.bottom)
|
||||
|
||||
@classmethod
|
||||
def _plot(
|
||||
cls,
|
||||
ax,
|
||||
y,
|
||||
style=None,
|
||||
bins=None,
|
||||
bottom=0,
|
||||
column_num=0,
|
||||
stacking_id=None,
|
||||
**kwds
|
||||
):
|
||||
if column_num == 0:
|
||||
cls._initialize_stacker(ax, stacking_id, len(bins) - 1)
|
||||
y = y[~isna(y)]
|
||||
|
||||
base = np.zeros(len(bins) - 1)
|
||||
bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"])
|
||||
# ignore style
|
||||
n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds)
|
||||
cls._update_stacker(ax, stacking_id, n)
|
||||
return patches
|
||||
|
||||
def _make_plot(self):
|
||||
colors = self._get_colors()
|
||||
stacking_id = self._get_stacking_id()
|
||||
|
||||
for i, (label, y) in enumerate(self._iter_data()):
|
||||
ax = self._get_ax(i)
|
||||
|
||||
kwds = self.kwds.copy()
|
||||
|
||||
label = pprint_thing(label)
|
||||
kwds["label"] = label
|
||||
|
||||
style, kwds = self._apply_style_colors(colors, kwds, i, label)
|
||||
if style is not None:
|
||||
kwds["style"] = style
|
||||
|
||||
kwds = self._make_plot_keywords(kwds, y)
|
||||
artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)
|
||||
self._add_legend_handle(artists[0], label, index=i)
|
||||
|
||||
def _make_plot_keywords(self, kwds, y):
|
||||
"""merge BoxPlot/KdePlot properties to passed kwds"""
|
||||
# y is required for KdePlot
|
||||
kwds["bottom"] = self.bottom
|
||||
kwds["bins"] = self.bins
|
||||
return kwds
|
||||
|
||||
def _post_plot_logic(self, ax, data):
|
||||
if self.orientation == "horizontal":
|
||||
ax.set_xlabel("Frequency")
|
||||
else:
|
||||
ax.set_ylabel("Frequency")
|
||||
|
||||
@property
|
||||
def orientation(self):
|
||||
if self.kwds.get("orientation", None) == "horizontal":
|
||||
return "horizontal"
|
||||
else:
|
||||
return "vertical"
|
||||
|
||||
|
||||
class KdePlot(HistPlot):
|
||||
_kind = "kde"
|
||||
orientation = "vertical"
|
||||
|
||||
def __init__(self, data, bw_method=None, ind=None, **kwargs):
|
||||
MPLPlot.__init__(self, data, **kwargs)
|
||||
self.bw_method = bw_method
|
||||
self.ind = ind
|
||||
|
||||
def _args_adjust(self):
|
||||
pass
|
||||
|
||||
def _get_ind(self, y):
|
||||
if self.ind is None:
|
||||
# np.nanmax() and np.nanmin() ignores the missing values
|
||||
sample_range = np.nanmax(y) - np.nanmin(y)
|
||||
ind = np.linspace(
|
||||
np.nanmin(y) - 0.5 * sample_range,
|
||||
np.nanmax(y) + 0.5 * sample_range,
|
||||
1000,
|
||||
)
|
||||
elif is_integer(self.ind):
|
||||
sample_range = np.nanmax(y) - np.nanmin(y)
|
||||
ind = np.linspace(
|
||||
np.nanmin(y) - 0.5 * sample_range,
|
||||
np.nanmax(y) + 0.5 * sample_range,
|
||||
self.ind,
|
||||
)
|
||||
else:
|
||||
ind = self.ind
|
||||
return ind
|
||||
|
||||
@classmethod
|
||||
def _plot(
|
||||
cls,
|
||||
ax,
|
||||
y,
|
||||
style=None,
|
||||
bw_method=None,
|
||||
ind=None,
|
||||
column_num=None,
|
||||
stacking_id=None,
|
||||
**kwds
|
||||
):
|
||||
from scipy.stats import gaussian_kde
|
||||
|
||||
y = remove_na_arraylike(y)
|
||||
gkde = gaussian_kde(y, bw_method=bw_method)
|
||||
|
||||
y = gkde.evaluate(ind)
|
||||
lines = MPLPlot._plot(ax, ind, y, style=style, **kwds)
|
||||
return lines
|
||||
|
||||
def _make_plot_keywords(self, kwds, y):
|
||||
kwds["bw_method"] = self.bw_method
|
||||
kwds["ind"] = self._get_ind(y)
|
||||
return kwds
|
||||
|
||||
def _post_plot_logic(self, ax, data):
|
||||
ax.set_ylabel("Density")
|
||||
|
||||
|
||||
def _grouped_plot(
|
||||
plotf,
|
||||
data,
|
||||
column=None,
|
||||
by=None,
|
||||
numeric_only=True,
|
||||
figsize=None,
|
||||
sharex=True,
|
||||
sharey=True,
|
||||
layout=None,
|
||||
rot=0,
|
||||
ax=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if figsize == "default":
|
||||
# allowed to specify mpl default with 'default'
|
||||
warnings.warn(
|
||||
"figsize='default' is deprecated. Specify figure " "size by tuple instead",
|
||||
FutureWarning,
|
||||
stacklevel=5,
|
||||
)
|
||||
figsize = None
|
||||
|
||||
grouped = data.groupby(by)
|
||||
if column is not None:
|
||||
grouped = grouped[column]
|
||||
|
||||
naxes = len(grouped)
|
||||
fig, axes = _subplots(
|
||||
naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
|
||||
)
|
||||
|
||||
_axes = _flatten(axes)
|
||||
|
||||
for i, (key, group) in enumerate(grouped):
|
||||
ax = _axes[i]
|
||||
if numeric_only and isinstance(group, ABCDataFrame):
|
||||
group = group._get_numeric_data()
|
||||
plotf(group, ax, **kwargs)
|
||||
ax.set_title(pprint_thing(key))
|
||||
|
||||
return fig, axes
|
||||
|
||||
|
||||
def _grouped_hist(
|
||||
data,
|
||||
column=None,
|
||||
by=None,
|
||||
ax=None,
|
||||
bins=50,
|
||||
figsize=None,
|
||||
layout=None,
|
||||
sharex=False,
|
||||
sharey=False,
|
||||
rot=90,
|
||||
grid=True,
|
||||
xlabelsize=None,
|
||||
xrot=None,
|
||||
ylabelsize=None,
|
||||
yrot=None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Grouped histogram
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : Series/DataFrame
|
||||
column : object, optional
|
||||
by : object, optional
|
||||
ax : axes, optional
|
||||
bins : int, default 50
|
||||
figsize : tuple, optional
|
||||
layout : optional
|
||||
sharex : bool, default False
|
||||
sharey : bool, default False
|
||||
rot : int, default 90
|
||||
grid : bool, default True
|
||||
kwargs : dict, keyword arguments passed to matplotlib.Axes.hist
|
||||
|
||||
Returns
|
||||
-------
|
||||
collection of Matplotlib Axes
|
||||
"""
|
||||
|
||||
def plot_group(group, ax):
|
||||
ax.hist(group.dropna().values, bins=bins, **kwargs)
|
||||
|
||||
converter._WARN = False # no warning for pandas plots
|
||||
xrot = xrot or rot
|
||||
|
||||
fig, axes = _grouped_plot(
|
||||
plot_group,
|
||||
data,
|
||||
column=column,
|
||||
by=by,
|
||||
sharex=sharex,
|
||||
sharey=sharey,
|
||||
ax=ax,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
rot=rot,
|
||||
)
|
||||
|
||||
_set_ticks_props(
|
||||
axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
|
||||
)
|
||||
|
||||
fig.subplots_adjust(
|
||||
bottom=0.15, top=0.9, left=0.1, right=0.9, hspace=0.5, wspace=0.3
|
||||
)
|
||||
return axes
|
||||
|
||||
|
||||
def hist_series(
|
||||
self,
|
||||
by=None,
|
||||
ax=None,
|
||||
grid=True,
|
||||
xlabelsize=None,
|
||||
xrot=None,
|
||||
ylabelsize=None,
|
||||
yrot=None,
|
||||
figsize=None,
|
||||
bins=10,
|
||||
**kwds
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if by is None:
|
||||
if kwds.get("layout", None) is not None:
|
||||
raise ValueError(
|
||||
"The 'layout' keyword is not supported when " "'by' is None"
|
||||
)
|
||||
# hack until the plotting interface is a bit more unified
|
||||
fig = kwds.pop(
|
||||
"figure", plt.gcf() if plt.get_fignums() else plt.figure(figsize=figsize)
|
||||
)
|
||||
if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
|
||||
fig.set_size_inches(*figsize, forward=True)
|
||||
if ax is None:
|
||||
ax = fig.gca()
|
||||
elif ax.get_figure() != fig:
|
||||
raise AssertionError("passed axis not bound to passed figure")
|
||||
values = self.dropna().values
|
||||
|
||||
ax.hist(values, bins=bins, **kwds)
|
||||
ax.grid(grid)
|
||||
axes = np.array([ax])
|
||||
|
||||
_set_ticks_props(
|
||||
axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
|
||||
)
|
||||
|
||||
else:
|
||||
if "figure" in kwds:
|
||||
raise ValueError(
|
||||
"Cannot pass 'figure' when using the "
|
||||
"'by' argument, since a new 'Figure' instance "
|
||||
"will be created"
|
||||
)
|
||||
axes = _grouped_hist(
|
||||
self,
|
||||
by=by,
|
||||
ax=ax,
|
||||
grid=grid,
|
||||
figsize=figsize,
|
||||
bins=bins,
|
||||
xlabelsize=xlabelsize,
|
||||
xrot=xrot,
|
||||
ylabelsize=ylabelsize,
|
||||
yrot=yrot,
|
||||
**kwds
|
||||
)
|
||||
|
||||
if hasattr(axes, "ndim"):
|
||||
if axes.ndim == 1 and len(axes) == 1:
|
||||
return axes[0]
|
||||
return axes
|
||||
|
||||
|
||||
def hist_frame(
|
||||
data,
|
||||
column=None,
|
||||
by=None,
|
||||
grid=True,
|
||||
xlabelsize=None,
|
||||
xrot=None,
|
||||
ylabelsize=None,
|
||||
yrot=None,
|
||||
ax=None,
|
||||
sharex=False,
|
||||
sharey=False,
|
||||
figsize=None,
|
||||
layout=None,
|
||||
bins=10,
|
||||
**kwds
|
||||
):
|
||||
converter._WARN = False # no warning for pandas plots
|
||||
if by is not None:
|
||||
axes = _grouped_hist(
|
||||
data,
|
||||
column=column,
|
||||
by=by,
|
||||
ax=ax,
|
||||
grid=grid,
|
||||
figsize=figsize,
|
||||
sharex=sharex,
|
||||
sharey=sharey,
|
||||
layout=layout,
|
||||
bins=bins,
|
||||
xlabelsize=xlabelsize,
|
||||
xrot=xrot,
|
||||
ylabelsize=ylabelsize,
|
||||
yrot=yrot,
|
||||
**kwds
|
||||
)
|
||||
return axes
|
||||
|
||||
if column is not None:
|
||||
if not isinstance(column, (list, np.ndarray, ABCIndexClass)):
|
||||
column = [column]
|
||||
data = data[column]
|
||||
data = data._get_numeric_data()
|
||||
naxes = len(data.columns)
|
||||
|
||||
if naxes == 0:
|
||||
raise ValueError("hist method requires numerical columns, " "nothing to plot.")
|
||||
|
||||
fig, axes = _subplots(
|
||||
naxes=naxes,
|
||||
ax=ax,
|
||||
squeeze=False,
|
||||
sharex=sharex,
|
||||
sharey=sharey,
|
||||
figsize=figsize,
|
||||
layout=layout,
|
||||
)
|
||||
_axes = _flatten(axes)
|
||||
|
||||
for i, col in enumerate(com.try_sort(data.columns)):
|
||||
ax = _axes[i]
|
||||
ax.hist(data[col].dropna().values, bins=bins, **kwds)
|
||||
ax.set_title(col)
|
||||
ax.grid(grid)
|
||||
|
||||
_set_ticks_props(
|
||||
axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
|
||||
)
|
||||
fig.subplots_adjust(wspace=0.3, hspace=0.3)
|
||||
|
||||
return axes
|
@@ -0,0 +1,431 @@
|
||||
import random
|
||||
|
||||
import matplotlib.lines as mlines
|
||||
import matplotlib.patches as patches
|
||||
import numpy as np
|
||||
|
||||
from pandas.core.dtypes.missing import notna
|
||||
|
||||
from pandas.io.formats.printing import pprint_thing
|
||||
from pandas.plotting._matplotlib.style import _get_standard_colors
|
||||
from pandas.plotting._matplotlib.tools import _set_ticks_props, _subplots
|
||||
|
||||
|
||||
def scatter_matrix(
|
||||
frame,
|
||||
alpha=0.5,
|
||||
figsize=None,
|
||||
ax=None,
|
||||
grid=False,
|
||||
diagonal="hist",
|
||||
marker=".",
|
||||
density_kwds=None,
|
||||
hist_kwds=None,
|
||||
range_padding=0.05,
|
||||
**kwds
|
||||
):
|
||||
df = frame._get_numeric_data()
|
||||
n = df.columns.size
|
||||
naxes = n * n
|
||||
fig, axes = _subplots(naxes=naxes, figsize=figsize, ax=ax, squeeze=False)
|
||||
|
||||
# no gaps between subplots
|
||||
fig.subplots_adjust(wspace=0, hspace=0)
|
||||
|
||||
mask = notna(df)
|
||||
|
||||
marker = _get_marker_compat(marker)
|
||||
|
||||
hist_kwds = hist_kwds or {}
|
||||
density_kwds = density_kwds or {}
|
||||
|
||||
# GH 14855
|
||||
kwds.setdefault("edgecolors", "none")
|
||||
|
||||
boundaries_list = []
|
||||
for a in df.columns:
|
||||
values = df[a].values[mask[a].values]
|
||||
rmin_, rmax_ = np.min(values), np.max(values)
|
||||
rdelta_ext = (rmax_ - rmin_) * range_padding / 2.0
|
||||
boundaries_list.append((rmin_ - rdelta_ext, rmax_ + rdelta_ext))
|
||||
|
||||
for i, a in enumerate(df.columns):
|
||||
for j, b in enumerate(df.columns):
|
||||
ax = axes[i, j]
|
||||
|
||||
if i == j:
|
||||
values = df[a].values[mask[a].values]
|
||||
|
||||
# Deal with the diagonal by drawing a histogram there.
|
||||
if diagonal == "hist":
|
||||
ax.hist(values, **hist_kwds)
|
||||
|
||||
elif diagonal in ("kde", "density"):
|
||||
from scipy.stats import gaussian_kde
|
||||
|
||||
y = values
|
||||
gkde = gaussian_kde(y)
|
||||
ind = np.linspace(y.min(), y.max(), 1000)
|
||||
ax.plot(ind, gkde.evaluate(ind), **density_kwds)
|
||||
|
||||
ax.set_xlim(boundaries_list[i])
|
||||
|
||||
else:
|
||||
common = (mask[a] & mask[b]).values
|
||||
|
||||
ax.scatter(
|
||||
df[b][common], df[a][common], marker=marker, alpha=alpha, **kwds
|
||||
)
|
||||
|
||||
ax.set_xlim(boundaries_list[j])
|
||||
ax.set_ylim(boundaries_list[i])
|
||||
|
||||
ax.set_xlabel(b)
|
||||
ax.set_ylabel(a)
|
||||
|
||||
if j != 0:
|
||||
ax.yaxis.set_visible(False)
|
||||
if i != n - 1:
|
||||
ax.xaxis.set_visible(False)
|
||||
|
||||
if len(df.columns) > 1:
|
||||
lim1 = boundaries_list[0]
|
||||
locs = axes[0][1].yaxis.get_majorticklocs()
|
||||
locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
|
||||
adj = (locs - lim1[0]) / (lim1[1] - lim1[0])
|
||||
|
||||
lim0 = axes[0][0].get_ylim()
|
||||
adj = adj * (lim0[1] - lim0[0]) + lim0[0]
|
||||
axes[0][0].yaxis.set_ticks(adj)
|
||||
|
||||
if np.all(locs == locs.astype(int)):
|
||||
# if all ticks are int
|
||||
locs = locs.astype(int)
|
||||
axes[0][0].yaxis.set_ticklabels(locs)
|
||||
|
||||
_set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
|
||||
|
||||
return axes
|
||||
|
||||
|
||||
def _get_marker_compat(marker):
|
||||
if marker not in mlines.lineMarkers:
|
||||
return "o"
|
||||
return marker
|
||||
|
||||
|
||||
def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def normalize(series):
|
||||
a = min(series)
|
||||
b = max(series)
|
||||
return (series - a) / (b - a)
|
||||
|
||||
n = len(frame)
|
||||
classes = frame[class_column].drop_duplicates()
|
||||
class_col = frame[class_column]
|
||||
df = frame.drop(class_column, axis=1).apply(normalize)
|
||||
|
||||
if ax is None:
|
||||
ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
|
||||
|
||||
to_plot = {}
|
||||
colors = _get_standard_colors(
|
||||
num_colors=len(classes), colormap=colormap, color_type="random", color=color
|
||||
)
|
||||
|
||||
for kls in classes:
|
||||
to_plot[kls] = [[], []]
|
||||
|
||||
m = len(frame.columns) - 1
|
||||
s = np.array(
|
||||
[
|
||||
(np.cos(t), np.sin(t))
|
||||
for t in [2.0 * np.pi * (i / float(m)) for i in range(m)]
|
||||
]
|
||||
)
|
||||
|
||||
for i in range(n):
|
||||
row = df.iloc[i].values
|
||||
row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
|
||||
y = (s * row_).sum(axis=0) / row.sum()
|
||||
kls = class_col.iat[i]
|
||||
to_plot[kls][0].append(y[0])
|
||||
to_plot[kls][1].append(y[1])
|
||||
|
||||
for i, kls in enumerate(classes):
|
||||
ax.scatter(
|
||||
to_plot[kls][0],
|
||||
to_plot[kls][1],
|
||||
color=colors[i],
|
||||
label=pprint_thing(kls),
|
||||
**kwds
|
||||
)
|
||||
ax.legend()
|
||||
|
||||
ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor="none"))
|
||||
|
||||
for xy, name in zip(s, df.columns):
|
||||
|
||||
ax.add_patch(patches.Circle(xy, radius=0.025, facecolor="gray"))
|
||||
|
||||
if xy[0] < 0.0 and xy[1] < 0.0:
|
||||
ax.text(
|
||||
xy[0] - 0.025, xy[1] - 0.025, name, ha="right", va="top", size="small"
|
||||
)
|
||||
elif xy[0] < 0.0 and xy[1] >= 0.0:
|
||||
ax.text(
|
||||
xy[0] - 0.025,
|
||||
xy[1] + 0.025,
|
||||
name,
|
||||
ha="right",
|
||||
va="bottom",
|
||||
size="small",
|
||||
)
|
||||
elif xy[0] >= 0.0 and xy[1] < 0.0:
|
||||
ax.text(
|
||||
xy[0] + 0.025, xy[1] - 0.025, name, ha="left", va="top", size="small"
|
||||
)
|
||||
elif xy[0] >= 0.0 and xy[1] >= 0.0:
|
||||
ax.text(
|
||||
xy[0] + 0.025, xy[1] + 0.025, name, ha="left", va="bottom", size="small"
|
||||
)
|
||||
|
||||
ax.axis("equal")
|
||||
return ax
|
||||
|
||||
|
||||
def andrews_curves(
|
||||
frame, class_column, ax=None, samples=200, color=None, colormap=None, **kwds
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def function(amplitudes):
|
||||
def f(t):
|
||||
x1 = amplitudes[0]
|
||||
result = x1 / np.sqrt(2.0)
|
||||
|
||||
# Take the rest of the coefficients and resize them
|
||||
# appropriately. Take a copy of amplitudes as otherwise numpy
|
||||
# deletes the element from amplitudes itself.
|
||||
coeffs = np.delete(np.copy(amplitudes), 0)
|
||||
coeffs.resize(int((coeffs.size + 1) / 2), 2)
|
||||
|
||||
# Generate the harmonics and arguments for the sin and cos
|
||||
# functions.
|
||||
harmonics = np.arange(0, coeffs.shape[0]) + 1
|
||||
trig_args = np.outer(harmonics, t)
|
||||
|
||||
result += np.sum(
|
||||
coeffs[:, 0, np.newaxis] * np.sin(trig_args)
|
||||
+ coeffs[:, 1, np.newaxis] * np.cos(trig_args),
|
||||
axis=0,
|
||||
)
|
||||
return result
|
||||
|
||||
return f
|
||||
|
||||
n = len(frame)
|
||||
class_col = frame[class_column]
|
||||
classes = frame[class_column].drop_duplicates()
|
||||
df = frame.drop(class_column, axis=1)
|
||||
t = np.linspace(-np.pi, np.pi, samples)
|
||||
used_legends = set()
|
||||
|
||||
color_values = _get_standard_colors(
|
||||
num_colors=len(classes), colormap=colormap, color_type="random", color=color
|
||||
)
|
||||
colors = dict(zip(classes, color_values))
|
||||
if ax is None:
|
||||
ax = plt.gca(xlim=(-np.pi, np.pi))
|
||||
for i in range(n):
|
||||
row = df.iloc[i].values
|
||||
f = function(row)
|
||||
y = f(t)
|
||||
kls = class_col.iat[i]
|
||||
label = pprint_thing(kls)
|
||||
if label not in used_legends:
|
||||
used_legends.add(label)
|
||||
ax.plot(t, y, color=colors[kls], label=label, **kwds)
|
||||
else:
|
||||
ax.plot(t, y, color=colors[kls], **kwds)
|
||||
|
||||
ax.legend(loc="upper right")
|
||||
ax.grid()
|
||||
return ax
|
||||
|
||||
|
||||
def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# random.sample(ndarray, int) fails on python 3.3, sigh
|
||||
data = list(series.values)
|
||||
samplings = [random.sample(data, size) for _ in range(samples)]
|
||||
|
||||
means = np.array([np.mean(sampling) for sampling in samplings])
|
||||
medians = np.array([np.median(sampling) for sampling in samplings])
|
||||
midranges = np.array(
|
||||
[(min(sampling) + max(sampling)) * 0.5 for sampling in samplings]
|
||||
)
|
||||
if fig is None:
|
||||
fig = plt.figure()
|
||||
x = list(range(samples))
|
||||
axes = []
|
||||
ax1 = fig.add_subplot(2, 3, 1)
|
||||
ax1.set_xlabel("Sample")
|
||||
axes.append(ax1)
|
||||
ax1.plot(x, means, **kwds)
|
||||
ax2 = fig.add_subplot(2, 3, 2)
|
||||
ax2.set_xlabel("Sample")
|
||||
axes.append(ax2)
|
||||
ax2.plot(x, medians, **kwds)
|
||||
ax3 = fig.add_subplot(2, 3, 3)
|
||||
ax3.set_xlabel("Sample")
|
||||
axes.append(ax3)
|
||||
ax3.plot(x, midranges, **kwds)
|
||||
ax4 = fig.add_subplot(2, 3, 4)
|
||||
ax4.set_xlabel("Mean")
|
||||
axes.append(ax4)
|
||||
ax4.hist(means, **kwds)
|
||||
ax5 = fig.add_subplot(2, 3, 5)
|
||||
ax5.set_xlabel("Median")
|
||||
axes.append(ax5)
|
||||
ax5.hist(medians, **kwds)
|
||||
ax6 = fig.add_subplot(2, 3, 6)
|
||||
ax6.set_xlabel("Midrange")
|
||||
axes.append(ax6)
|
||||
ax6.hist(midranges, **kwds)
|
||||
for axis in axes:
|
||||
plt.setp(axis.get_xticklabels(), fontsize=8)
|
||||
plt.setp(axis.get_yticklabels(), fontsize=8)
|
||||
return fig
|
||||
|
||||
|
||||
def parallel_coordinates(
|
||||
frame,
|
||||
class_column,
|
||||
cols=None,
|
||||
ax=None,
|
||||
color=None,
|
||||
use_columns=False,
|
||||
xticks=None,
|
||||
colormap=None,
|
||||
axvlines=True,
|
||||
axvlines_kwds=None,
|
||||
sort_labels=False,
|
||||
**kwds
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if axvlines_kwds is None:
|
||||
axvlines_kwds = {"linewidth": 1, "color": "black"}
|
||||
|
||||
n = len(frame)
|
||||
classes = frame[class_column].drop_duplicates()
|
||||
class_col = frame[class_column]
|
||||
|
||||
if cols is None:
|
||||
df = frame.drop(class_column, axis=1)
|
||||
else:
|
||||
df = frame[cols]
|
||||
|
||||
used_legends = set()
|
||||
|
||||
ncols = len(df.columns)
|
||||
|
||||
# determine values to use for xticks
|
||||
if use_columns is True:
|
||||
if not np.all(np.isreal(list(df.columns))):
|
||||
raise ValueError("Columns must be numeric to be used as xticks")
|
||||
x = df.columns
|
||||
elif xticks is not None:
|
||||
if not np.all(np.isreal(xticks)):
|
||||
raise ValueError("xticks specified must be numeric")
|
||||
elif len(xticks) != ncols:
|
||||
raise ValueError("Length of xticks must match number of columns")
|
||||
x = xticks
|
||||
else:
|
||||
x = list(range(ncols))
|
||||
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
|
||||
color_values = _get_standard_colors(
|
||||
num_colors=len(classes), colormap=colormap, color_type="random", color=color
|
||||
)
|
||||
|
||||
if sort_labels:
|
||||
classes = sorted(classes)
|
||||
color_values = sorted(color_values)
|
||||
colors = dict(zip(classes, color_values))
|
||||
|
||||
for i in range(n):
|
||||
y = df.iloc[i].values
|
||||
kls = class_col.iat[i]
|
||||
label = pprint_thing(kls)
|
||||
if label not in used_legends:
|
||||
used_legends.add(label)
|
||||
ax.plot(x, y, color=colors[kls], label=label, **kwds)
|
||||
else:
|
||||
ax.plot(x, y, color=colors[kls], **kwds)
|
||||
|
||||
if axvlines:
|
||||
for i in x:
|
||||
ax.axvline(i, **axvlines_kwds)
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(df.columns)
|
||||
ax.set_xlim(x[0], x[-1])
|
||||
ax.legend(loc="upper right")
|
||||
ax.grid()
|
||||
return ax
|
||||
|
||||
|
||||
def lag_plot(series, lag=1, ax=None, **kwds):
|
||||
# workaround because `c='b'` is hardcoded in matplotlibs scatter method
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
kwds.setdefault("c", plt.rcParams["patch.facecolor"])
|
||||
|
||||
data = series.values
|
||||
y1 = data[:-lag]
|
||||
y2 = data[lag:]
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
ax.set_xlabel("y(t)")
|
||||
ax.set_ylabel("y(t + {lag})".format(lag=lag))
|
||||
ax.scatter(y1, y2, **kwds)
|
||||
return ax
|
||||
|
||||
|
||||
def autocorrelation_plot(series, ax=None, **kwds):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
n = len(series)
|
||||
data = np.asarray(series)
|
||||
if ax is None:
|
||||
ax = plt.gca(xlim=(1, n), ylim=(-1.0, 1.0))
|
||||
mean = np.mean(data)
|
||||
c0 = np.sum((data - mean) ** 2) / float(n)
|
||||
|
||||
def r(h):
|
||||
return ((data[: n - h] - mean) * (data[h:] - mean)).sum() / float(n) / c0
|
||||
|
||||
x = np.arange(n) + 1
|
||||
y = [r(loc) for loc in x]
|
||||
z95 = 1.959963984540054
|
||||
z99 = 2.5758293035489004
|
||||
ax.axhline(y=z99 / np.sqrt(n), linestyle="--", color="grey")
|
||||
ax.axhline(y=z95 / np.sqrt(n), color="grey")
|
||||
ax.axhline(y=0.0, color="black")
|
||||
ax.axhline(y=-z95 / np.sqrt(n), color="grey")
|
||||
ax.axhline(y=-z99 / np.sqrt(n), linestyle="--", color="grey")
|
||||
ax.set_xlabel("Lag")
|
||||
ax.set_ylabel("Autocorrelation")
|
||||
ax.plot(x, y, **kwds)
|
||||
if "label" in kwds:
|
||||
ax.legend()
|
||||
ax.grid()
|
||||
return ax
|
@@ -0,0 +1,92 @@
|
||||
# being a bit too dynamic
|
||||
import warnings
|
||||
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib.colors
|
||||
import numpy as np
|
||||
|
||||
from pandas.core.dtypes.common import is_list_like
|
||||
|
||||
import pandas.core.common as com
|
||||
|
||||
|
||||
def _get_standard_colors(
|
||||
num_colors=None, colormap=None, color_type="default", color=None
|
||||
):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if color is None and colormap is not None:
|
||||
if isinstance(colormap, str):
|
||||
cmap = colormap
|
||||
colormap = cm.get_cmap(colormap)
|
||||
if colormap is None:
|
||||
raise ValueError("Colormap {0} is not recognized".format(cmap))
|
||||
colors = [colormap(num) for num in np.linspace(0, 1, num=num_colors)]
|
||||
elif color is not None:
|
||||
if colormap is not None:
|
||||
warnings.warn(
|
||||
"'color' and 'colormap' cannot be used " "simultaneously. Using 'color'"
|
||||
)
|
||||
colors = list(color) if is_list_like(color) else color
|
||||
else:
|
||||
if color_type == "default":
|
||||
# need to call list() on the result to copy so we don't
|
||||
# modify the global rcParams below
|
||||
try:
|
||||
colors = [c["color"] for c in list(plt.rcParams["axes.prop_cycle"])]
|
||||
except KeyError:
|
||||
colors = list(plt.rcParams.get("axes.color_cycle", list("bgrcmyk")))
|
||||
if isinstance(colors, str):
|
||||
colors = list(colors)
|
||||
|
||||
colors = colors[0:num_colors]
|
||||
elif color_type == "random":
|
||||
|
||||
def random_color(column):
|
||||
""" Returns a random color represented as a list of length 3"""
|
||||
# GH17525 use common._random_state to avoid resetting the seed
|
||||
rs = com.random_state(column)
|
||||
return rs.rand(3).tolist()
|
||||
|
||||
colors = [random_color(num) for num in range(num_colors)]
|
||||
else:
|
||||
raise ValueError("color_type must be either 'default' or 'random'")
|
||||
|
||||
if isinstance(colors, str):
|
||||
conv = matplotlib.colors.ColorConverter()
|
||||
|
||||
def _maybe_valid_colors(colors):
|
||||
try:
|
||||
[conv.to_rgba(c) for c in colors]
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# check whether the string can be convertible to single color
|
||||
maybe_single_color = _maybe_valid_colors([colors])
|
||||
# check whether each character can be convertible to colors
|
||||
maybe_color_cycle = _maybe_valid_colors(list(colors))
|
||||
if maybe_single_color and maybe_color_cycle and len(colors) > 1:
|
||||
hex_color = [c["color"] for c in list(plt.rcParams["axes.prop_cycle"])]
|
||||
colors = [hex_color[int(colors[1])]]
|
||||
elif maybe_single_color:
|
||||
colors = [colors]
|
||||
else:
|
||||
# ``colors`` is regarded as color cycle.
|
||||
# mpl will raise error any of them is invalid
|
||||
pass
|
||||
|
||||
# Append more colors by cycling if there is not enough color.
|
||||
# Extra colors will be ignored by matplotlib if there are more colors
|
||||
# than needed and nothing needs to be done here.
|
||||
if len(colors) < num_colors:
|
||||
try:
|
||||
multiple = num_colors // len(colors) - 1
|
||||
except ZeroDivisionError:
|
||||
raise ValueError("Invalid color argument: ''")
|
||||
mod = num_colors % len(colors)
|
||||
|
||||
colors += multiple * colors
|
||||
colors += colors[:mod]
|
||||
|
||||
return colors
|
@@ -0,0 +1,370 @@
|
||||
# TODO: Use the fact that axis can have units to simplify the process
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pandas._libs.tslibs.frequencies import (
|
||||
FreqGroup,
|
||||
get_base_alias,
|
||||
get_freq,
|
||||
is_subperiod,
|
||||
is_superperiod,
|
||||
)
|
||||
from pandas._libs.tslibs.period import Period
|
||||
|
||||
from pandas.core.dtypes.generic import (
|
||||
ABCDatetimeIndex,
|
||||
ABCPeriodIndex,
|
||||
ABCTimedeltaIndex,
|
||||
)
|
||||
|
||||
from pandas.io.formats.printing import pprint_thing
|
||||
from pandas.plotting._matplotlib.converter import (
|
||||
TimeSeries_DateFormatter,
|
||||
TimeSeries_DateLocator,
|
||||
TimeSeries_TimedeltaFormatter,
|
||||
)
|
||||
import pandas.tseries.frequencies as frequencies
|
||||
from pandas.tseries.offsets import DateOffset
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Plotting functions and monkey patches
|
||||
|
||||
|
||||
def tsplot(series, plotf, ax=None, **kwargs):
|
||||
"""
|
||||
Plots a Series on the given Matplotlib axes or the current axes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
axes : Axes
|
||||
series : Series
|
||||
|
||||
Notes
|
||||
_____
|
||||
Supports same kwargs as Axes.plot
|
||||
|
||||
|
||||
.. deprecated:: 0.23.0
|
||||
Use Series.plot() instead
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
warnings.warn(
|
||||
"'tsplot' is deprecated and will be removed in a "
|
||||
"future version. Please use Series.plot() instead.",
|
||||
FutureWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Used inferred freq is possible, need a test case for inferred
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
|
||||
freq, series = _maybe_resample(series, ax, kwargs)
|
||||
|
||||
# Set ax with freq info
|
||||
_decorate_axes(ax, freq, kwargs)
|
||||
ax._plot_data.append((series, plotf, kwargs))
|
||||
lines = plotf(ax, series.index._mpl_repr(), series.values, **kwargs)
|
||||
|
||||
# set date formatter, locators and rescale limits
|
||||
format_dateaxis(ax, ax.freq, series.index)
|
||||
return lines
|
||||
|
||||
|
||||
def _maybe_resample(series, ax, kwargs):
|
||||
# resample against axes freq if necessary
|
||||
freq, ax_freq = _get_freq(ax, series)
|
||||
|
||||
if freq is None: # pragma: no cover
|
||||
raise ValueError("Cannot use dynamic axis without frequency info")
|
||||
|
||||
# Convert DatetimeIndex to PeriodIndex
|
||||
if isinstance(series.index, ABCDatetimeIndex):
|
||||
series = series.to_period(freq=freq)
|
||||
|
||||
if ax_freq is not None and freq != ax_freq:
|
||||
if is_superperiod(freq, ax_freq): # upsample input
|
||||
series = series.copy()
|
||||
series.index = series.index.asfreq(ax_freq, how="s")
|
||||
freq = ax_freq
|
||||
elif _is_sup(freq, ax_freq): # one is weekly
|
||||
how = kwargs.pop("how", "last")
|
||||
series = getattr(series.resample("D"), how)().dropna()
|
||||
series = getattr(series.resample(ax_freq), how)().dropna()
|
||||
freq = ax_freq
|
||||
elif is_subperiod(freq, ax_freq) or _is_sub(freq, ax_freq):
|
||||
_upsample_others(ax, freq, kwargs)
|
||||
else: # pragma: no cover
|
||||
raise ValueError("Incompatible frequency conversion")
|
||||
return freq, series
|
||||
|
||||
|
||||
def _is_sub(f1, f2):
|
||||
return (f1.startswith("W") and is_subperiod("D", f2)) or (
|
||||
f2.startswith("W") and is_subperiod(f1, "D")
|
||||
)
|
||||
|
||||
|
||||
def _is_sup(f1, f2):
|
||||
return (f1.startswith("W") and is_superperiod("D", f2)) or (
|
||||
f2.startswith("W") and is_superperiod(f1, "D")
|
||||
)
|
||||
|
||||
|
||||
def _upsample_others(ax, freq, kwargs):
|
||||
legend = ax.get_legend()
|
||||
lines, labels = _replot_ax(ax, freq, kwargs)
|
||||
_replot_ax(ax, freq, kwargs)
|
||||
|
||||
other_ax = None
|
||||
if hasattr(ax, "left_ax"):
|
||||
other_ax = ax.left_ax
|
||||
if hasattr(ax, "right_ax"):
|
||||
other_ax = ax.right_ax
|
||||
|
||||
if other_ax is not None:
|
||||
rlines, rlabels = _replot_ax(other_ax, freq, kwargs)
|
||||
lines.extend(rlines)
|
||||
labels.extend(rlabels)
|
||||
|
||||
if legend is not None and kwargs.get("legend", True) and len(lines) > 0:
|
||||
title = legend.get_title().get_text()
|
||||
if title == "None":
|
||||
title = None
|
||||
ax.legend(lines, labels, loc="best", title=title)
|
||||
|
||||
|
||||
def _replot_ax(ax, freq, kwargs):
|
||||
data = getattr(ax, "_plot_data", None)
|
||||
|
||||
# clear current axes and data
|
||||
ax._plot_data = []
|
||||
ax.clear()
|
||||
|
||||
_decorate_axes(ax, freq, kwargs)
|
||||
|
||||
lines = []
|
||||
labels = []
|
||||
if data is not None:
|
||||
for series, plotf, kwds in data:
|
||||
series = series.copy()
|
||||
idx = series.index.asfreq(freq, how="S")
|
||||
series.index = idx
|
||||
ax._plot_data.append((series, plotf, kwds))
|
||||
|
||||
# for tsplot
|
||||
if isinstance(plotf, str):
|
||||
from pandas.plotting._matplotlib import PLOT_CLASSES
|
||||
|
||||
plotf = PLOT_CLASSES[plotf]._plot
|
||||
|
||||
lines.append(plotf(ax, series.index._mpl_repr(), series.values, **kwds)[0])
|
||||
labels.append(pprint_thing(series.name))
|
||||
|
||||
return lines, labels
|
||||
|
||||
|
||||
def _decorate_axes(ax, freq, kwargs):
|
||||
"""Initialize axes for time-series plotting"""
|
||||
if not hasattr(ax, "_plot_data"):
|
||||
ax._plot_data = []
|
||||
|
||||
ax.freq = freq
|
||||
xaxis = ax.get_xaxis()
|
||||
xaxis.freq = freq
|
||||
if not hasattr(ax, "legendlabels"):
|
||||
ax.legendlabels = [kwargs.get("label", None)]
|
||||
else:
|
||||
ax.legendlabels.append(kwargs.get("label", None))
|
||||
ax.view_interval = None
|
||||
ax.date_axis_info = None
|
||||
|
||||
|
||||
def _get_ax_freq(ax):
|
||||
"""
|
||||
Get the freq attribute of the ax object if set.
|
||||
Also checks shared axes (eg when using secondary yaxis, sharex=True
|
||||
or twinx)
|
||||
"""
|
||||
ax_freq = getattr(ax, "freq", None)
|
||||
if ax_freq is None:
|
||||
# check for left/right ax in case of secondary yaxis
|
||||
if hasattr(ax, "left_ax"):
|
||||
ax_freq = getattr(ax.left_ax, "freq", None)
|
||||
elif hasattr(ax, "right_ax"):
|
||||
ax_freq = getattr(ax.right_ax, "freq", None)
|
||||
if ax_freq is None:
|
||||
# check if a shared ax (sharex/twinx) has already freq set
|
||||
shared_axes = ax.get_shared_x_axes().get_siblings(ax)
|
||||
if len(shared_axes) > 1:
|
||||
for shared_ax in shared_axes:
|
||||
ax_freq = getattr(shared_ax, "freq", None)
|
||||
if ax_freq is not None:
|
||||
break
|
||||
return ax_freq
|
||||
|
||||
|
||||
def _get_freq(ax, series):
|
||||
# get frequency from data
|
||||
freq = getattr(series.index, "freq", None)
|
||||
if freq is None:
|
||||
freq = getattr(series.index, "inferred_freq", None)
|
||||
|
||||
ax_freq = _get_ax_freq(ax)
|
||||
|
||||
# use axes freq if no data freq
|
||||
if freq is None:
|
||||
freq = ax_freq
|
||||
|
||||
# get the period frequency
|
||||
if isinstance(freq, DateOffset):
|
||||
freq = freq.rule_code
|
||||
else:
|
||||
freq = get_base_alias(freq)
|
||||
|
||||
freq = frequencies.get_period_alias(freq)
|
||||
return freq, ax_freq
|
||||
|
||||
|
||||
def _use_dynamic_x(ax, data):
|
||||
freq = _get_index_freq(data)
|
||||
ax_freq = _get_ax_freq(ax)
|
||||
|
||||
if freq is None: # convert irregular if axes has freq info
|
||||
freq = ax_freq
|
||||
else: # do not use tsplot if irregular was plotted first
|
||||
if (ax_freq is None) and (len(ax.get_lines()) > 0):
|
||||
return False
|
||||
|
||||
if freq is None:
|
||||
return False
|
||||
|
||||
if isinstance(freq, DateOffset):
|
||||
freq = freq.rule_code
|
||||
else:
|
||||
freq = get_base_alias(freq)
|
||||
freq = frequencies.get_period_alias(freq)
|
||||
|
||||
if freq is None:
|
||||
return False
|
||||
|
||||
# hack this for 0.10.1, creating more technical debt...sigh
|
||||
if isinstance(data.index, ABCDatetimeIndex):
|
||||
base = get_freq(freq)
|
||||
x = data.index
|
||||
if base <= FreqGroup.FR_DAY:
|
||||
return x[:1].is_normalized
|
||||
return Period(x[0], freq).to_timestamp(tz=x.tz) == x[0]
|
||||
return True
|
||||
|
||||
|
||||
def _get_index_freq(data):
|
||||
freq = getattr(data.index, "freq", None)
|
||||
if freq is None:
|
||||
freq = getattr(data.index, "inferred_freq", None)
|
||||
if freq == "B":
|
||||
weekdays = np.unique(data.index.dayofweek)
|
||||
if (5 in weekdays) or (6 in weekdays):
|
||||
freq = None
|
||||
return freq
|
||||
|
||||
|
||||
def _maybe_convert_index(ax, data):
|
||||
# tsplot converts automatically, but don't want to convert index
|
||||
# over and over for DataFrames
|
||||
if isinstance(data.index, (ABCDatetimeIndex, ABCPeriodIndex)):
|
||||
freq = getattr(data.index, "freq", None)
|
||||
|
||||
if freq is None:
|
||||
freq = getattr(data.index, "inferred_freq", None)
|
||||
if isinstance(freq, DateOffset):
|
||||
freq = freq.rule_code
|
||||
|
||||
if freq is None:
|
||||
freq = _get_ax_freq(ax)
|
||||
|
||||
if freq is None:
|
||||
raise ValueError("Could not get frequency alias for plotting")
|
||||
|
||||
freq = get_base_alias(freq)
|
||||
freq = frequencies.get_period_alias(freq)
|
||||
|
||||
if isinstance(data.index, ABCDatetimeIndex):
|
||||
data = data.to_period(freq=freq)
|
||||
elif isinstance(data.index, ABCPeriodIndex):
|
||||
data.index = data.index.asfreq(freq=freq)
|
||||
return data
|
||||
|
||||
|
||||
# Patch methods for subplot. Only format_dateaxis is currently used.
|
||||
# Do we need the rest for convenience?
|
||||
|
||||
|
||||
def format_timedelta_ticks(x, pos, n_decimals):
|
||||
"""
|
||||
Convert seconds to 'D days HH:MM:SS.F'
|
||||
"""
|
||||
s, ns = divmod(x, 1e9)
|
||||
m, s = divmod(s, 60)
|
||||
h, m = divmod(m, 60)
|
||||
d, h = divmod(h, 24)
|
||||
decimals = int(ns * 10 ** (n_decimals - 9))
|
||||
s = r"{:02d}:{:02d}:{:02d}".format(int(h), int(m), int(s))
|
||||
if n_decimals > 0:
|
||||
s += ".{{:0{:0d}d}}".format(n_decimals).format(decimals)
|
||||
if d != 0:
|
||||
s = "{:d} days ".format(int(d)) + s
|
||||
return s
|
||||
|
||||
|
||||
def _format_coord(freq, t, y):
|
||||
return "t = {0} y = {1:8f}".format(Period(ordinal=int(t), freq=freq), y)
|
||||
|
||||
|
||||
def format_dateaxis(subplot, freq, index):
|
||||
"""
|
||||
Pretty-formats the date axis (x-axis).
|
||||
|
||||
Major and minor ticks are automatically set for the frequency of the
|
||||
current underlying series. As the dynamic mode is activated by
|
||||
default, changing the limits of the x axis will intelligently change
|
||||
the positions of the ticks.
|
||||
"""
|
||||
from matplotlib import pylab
|
||||
|
||||
# handle index specific formatting
|
||||
# Note: DatetimeIndex does not use this
|
||||
# interface. DatetimeIndex uses matplotlib.date directly
|
||||
if isinstance(index, ABCPeriodIndex):
|
||||
|
||||
majlocator = TimeSeries_DateLocator(
|
||||
freq, dynamic_mode=True, minor_locator=False, plot_obj=subplot
|
||||
)
|
||||
minlocator = TimeSeries_DateLocator(
|
||||
freq, dynamic_mode=True, minor_locator=True, plot_obj=subplot
|
||||
)
|
||||
subplot.xaxis.set_major_locator(majlocator)
|
||||
subplot.xaxis.set_minor_locator(minlocator)
|
||||
|
||||
majformatter = TimeSeries_DateFormatter(
|
||||
freq, dynamic_mode=True, minor_locator=False, plot_obj=subplot
|
||||
)
|
||||
minformatter = TimeSeries_DateFormatter(
|
||||
freq, dynamic_mode=True, minor_locator=True, plot_obj=subplot
|
||||
)
|
||||
subplot.xaxis.set_major_formatter(majformatter)
|
||||
subplot.xaxis.set_minor_formatter(minformatter)
|
||||
|
||||
# x and y coord info
|
||||
subplot.format_coord = functools.partial(_format_coord, freq)
|
||||
|
||||
elif isinstance(index, ABCTimedeltaIndex):
|
||||
subplot.xaxis.set_major_formatter(TimeSeries_TimedeltaFormatter())
|
||||
else:
|
||||
raise TypeError("index type not supported")
|
||||
|
||||
pylab.draw_if_interactive()
|
@@ -0,0 +1,379 @@
|
||||
# being a bit too dynamic
|
||||
from math import ceil
|
||||
import warnings
|
||||
|
||||
import matplotlib.table
|
||||
import matplotlib.ticker as ticker
|
||||
import numpy as np
|
||||
|
||||
from pandas.core.dtypes.common import is_list_like
|
||||
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
|
||||
|
||||
|
||||
def format_date_labels(ax, rot):
|
||||
# mini version of autofmt_xdate
|
||||
try:
|
||||
for label in ax.get_xticklabels():
|
||||
label.set_ha("right")
|
||||
label.set_rotation(rot)
|
||||
fig = ax.get_figure()
|
||||
fig.subplots_adjust(bottom=0.2)
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def table(ax, data, rowLabels=None, colLabels=None, **kwargs):
|
||||
if isinstance(data, ABCSeries):
|
||||
data = data.to_frame()
|
||||
elif isinstance(data, ABCDataFrame):
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Input data must be DataFrame or Series")
|
||||
|
||||
if rowLabels is None:
|
||||
rowLabels = data.index
|
||||
|
||||
if colLabels is None:
|
||||
colLabels = data.columns
|
||||
|
||||
cellText = data.values
|
||||
|
||||
table = matplotlib.table.table(
|
||||
ax, cellText=cellText, rowLabels=rowLabels, colLabels=colLabels, **kwargs
|
||||
)
|
||||
return table
|
||||
|
||||
|
||||
def _get_layout(nplots, layout=None, layout_type="box"):
|
||||
if layout is not None:
|
||||
if not isinstance(layout, (tuple, list)) or len(layout) != 2:
|
||||
raise ValueError("Layout must be a tuple of (rows, columns)")
|
||||
|
||||
nrows, ncols = layout
|
||||
|
||||
# Python 2 compat
|
||||
ceil_ = lambda x: int(ceil(x))
|
||||
if nrows == -1 and ncols > 0:
|
||||
layout = nrows, ncols = (ceil_(float(nplots) / ncols), ncols)
|
||||
elif ncols == -1 and nrows > 0:
|
||||
layout = nrows, ncols = (nrows, ceil_(float(nplots) / nrows))
|
||||
elif ncols <= 0 and nrows <= 0:
|
||||
msg = "At least one dimension of layout must be positive"
|
||||
raise ValueError(msg)
|
||||
|
||||
if nrows * ncols < nplots:
|
||||
raise ValueError(
|
||||
"Layout of {nrows}x{ncols} must be larger "
|
||||
"than required size {nplots}".format(
|
||||
nrows=nrows, ncols=ncols, nplots=nplots
|
||||
)
|
||||
)
|
||||
|
||||
return layout
|
||||
|
||||
if layout_type == "single":
|
||||
return (1, 1)
|
||||
elif layout_type == "horizontal":
|
||||
return (1, nplots)
|
||||
elif layout_type == "vertical":
|
||||
return (nplots, 1)
|
||||
|
||||
layouts = {1: (1, 1), 2: (1, 2), 3: (2, 2), 4: (2, 2)}
|
||||
try:
|
||||
return layouts[nplots]
|
||||
except KeyError:
|
||||
k = 1
|
||||
while k ** 2 < nplots:
|
||||
k += 1
|
||||
|
||||
if (k - 1) * k >= nplots:
|
||||
return k, (k - 1)
|
||||
else:
|
||||
return k, k
|
||||
|
||||
|
||||
# copied from matplotlib/pyplot.py and modified for pandas.plotting
|
||||
|
||||
|
||||
def _subplots(
|
||||
naxes=None,
|
||||
sharex=False,
|
||||
sharey=False,
|
||||
squeeze=True,
|
||||
subplot_kw=None,
|
||||
ax=None,
|
||||
layout=None,
|
||||
layout_type="box",
|
||||
**fig_kw
|
||||
):
|
||||
"""Create a figure with a set of subplots already made.
|
||||
|
||||
This utility wrapper makes it convenient to create common layouts of
|
||||
subplots, including the enclosing figure object, in a single call.
|
||||
|
||||
Keyword arguments:
|
||||
|
||||
naxes : int
|
||||
Number of required axes. Exceeded axes are set invisible. Default is
|
||||
nrows * ncols.
|
||||
|
||||
sharex : bool
|
||||
If True, the X axis will be shared amongst all subplots.
|
||||
|
||||
sharey : bool
|
||||
If True, the Y axis will be shared amongst all subplots.
|
||||
|
||||
squeeze : bool
|
||||
|
||||
If True, extra dimensions are squeezed out from the returned axis object:
|
||||
- if only one subplot is constructed (nrows=ncols=1), the resulting
|
||||
single Axis object is returned as a scalar.
|
||||
- for Nx1 or 1xN subplots, the returned object is a 1-d numpy object
|
||||
array of Axis objects are returned as numpy 1-d arrays.
|
||||
- for NxM subplots with N>1 and M>1 are returned as a 2d array.
|
||||
|
||||
If False, no squeezing is done: the returned axis object is always
|
||||
a 2-d array containing Axis instances, even if it ends up being 1x1.
|
||||
|
||||
subplot_kw : dict
|
||||
Dict with keywords passed to the add_subplot() call used to create each
|
||||
subplots.
|
||||
|
||||
ax : Matplotlib axis object, optional
|
||||
|
||||
layout : tuple
|
||||
Number of rows and columns of the subplot grid.
|
||||
If not specified, calculated from naxes and layout_type
|
||||
|
||||
layout_type : {'box', 'horizontal', 'vertical'}, default 'box'
|
||||
Specify how to layout the subplot grid.
|
||||
|
||||
fig_kw : Other keyword arguments to be passed to the figure() call.
|
||||
Note that all keywords not recognized above will be
|
||||
automatically included here.
|
||||
|
||||
Returns:
|
||||
|
||||
fig, ax : tuple
|
||||
- fig is the Matplotlib Figure object
|
||||
- ax can be either a single axis object or an array of axis objects if
|
||||
more than one subplot was created. The dimensions of the resulting array
|
||||
can be controlled with the squeeze keyword, see above.
|
||||
|
||||
**Examples:**
|
||||
|
||||
x = np.linspace(0, 2*np.pi, 400)
|
||||
y = np.sin(x**2)
|
||||
|
||||
# Just a figure and one subplot
|
||||
f, ax = plt.subplots()
|
||||
ax.plot(x, y)
|
||||
ax.set_title('Simple plot')
|
||||
|
||||
# Two subplots, unpack the output array immediately
|
||||
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
|
||||
ax1.plot(x, y)
|
||||
ax1.set_title('Sharing Y axis')
|
||||
ax2.scatter(x, y)
|
||||
|
||||
# Four polar axes
|
||||
plt.subplots(2, 2, subplot_kw=dict(polar=True))
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if subplot_kw is None:
|
||||
subplot_kw = {}
|
||||
|
||||
if ax is None:
|
||||
fig = plt.figure(**fig_kw)
|
||||
else:
|
||||
if is_list_like(ax):
|
||||
ax = _flatten(ax)
|
||||
if layout is not None:
|
||||
warnings.warn(
|
||||
"When passing multiple axes, layout keyword is " "ignored",
|
||||
UserWarning,
|
||||
)
|
||||
if sharex or sharey:
|
||||
warnings.warn(
|
||||
"When passing multiple axes, sharex and sharey "
|
||||
"are ignored. These settings must be specified "
|
||||
"when creating axes",
|
||||
UserWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
if len(ax) == naxes:
|
||||
fig = ax[0].get_figure()
|
||||
return fig, ax
|
||||
else:
|
||||
raise ValueError(
|
||||
"The number of passed axes must be {0}, the "
|
||||
"same as the output plot".format(naxes)
|
||||
)
|
||||
|
||||
fig = ax.get_figure()
|
||||
# if ax is passed and a number of subplots is 1, return ax as it is
|
||||
if naxes == 1:
|
||||
if squeeze:
|
||||
return fig, ax
|
||||
else:
|
||||
return fig, _flatten(ax)
|
||||
else:
|
||||
warnings.warn(
|
||||
"To output multiple subplots, the figure containing "
|
||||
"the passed axes is being cleared",
|
||||
UserWarning,
|
||||
stacklevel=4,
|
||||
)
|
||||
fig.clear()
|
||||
|
||||
nrows, ncols = _get_layout(naxes, layout=layout, layout_type=layout_type)
|
||||
nplots = nrows * ncols
|
||||
|
||||
# Create empty object array to hold all axes. It's easiest to make it 1-d
|
||||
# so we can just append subplots upon creation, and then
|
||||
axarr = np.empty(nplots, dtype=object)
|
||||
|
||||
# Create first subplot separately, so we can share it if requested
|
||||
ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw)
|
||||
|
||||
if sharex:
|
||||
subplot_kw["sharex"] = ax0
|
||||
if sharey:
|
||||
subplot_kw["sharey"] = ax0
|
||||
axarr[0] = ax0
|
||||
|
||||
# Note off-by-one counting because add_subplot uses the MATLAB 1-based
|
||||
# convention.
|
||||
for i in range(1, nplots):
|
||||
kwds = subplot_kw.copy()
|
||||
# Set sharex and sharey to None for blank/dummy axes, these can
|
||||
# interfere with proper axis limits on the visible axes if
|
||||
# they share axes e.g. issue #7528
|
||||
if i >= naxes:
|
||||
kwds["sharex"] = None
|
||||
kwds["sharey"] = None
|
||||
ax = fig.add_subplot(nrows, ncols, i + 1, **kwds)
|
||||
axarr[i] = ax
|
||||
|
||||
if naxes != nplots:
|
||||
for ax in axarr[naxes:]:
|
||||
ax.set_visible(False)
|
||||
|
||||
_handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey)
|
||||
|
||||
if squeeze:
|
||||
# Reshape the array to have the final desired dimension (nrow,ncol),
|
||||
# though discarding unneeded dimensions that equal 1. If we only have
|
||||
# one subplot, just return it instead of a 1-element array.
|
||||
if nplots == 1:
|
||||
axes = axarr[0]
|
||||
else:
|
||||
axes = axarr.reshape(nrows, ncols).squeeze()
|
||||
else:
|
||||
# returned axis array will be always 2-d, even if nrows=ncols=1
|
||||
axes = axarr.reshape(nrows, ncols)
|
||||
|
||||
return fig, axes
|
||||
|
||||
|
||||
def _remove_labels_from_axis(axis):
|
||||
for t in axis.get_majorticklabels():
|
||||
t.set_visible(False)
|
||||
|
||||
try:
|
||||
# set_visible will not be effective if
|
||||
# minor axis has NullLocator and NullFormattor (default)
|
||||
if isinstance(axis.get_minor_locator(), ticker.NullLocator):
|
||||
axis.set_minor_locator(ticker.AutoLocator())
|
||||
if isinstance(axis.get_minor_formatter(), ticker.NullFormatter):
|
||||
axis.set_minor_formatter(ticker.FormatStrFormatter(""))
|
||||
for t in axis.get_minorticklabels():
|
||||
t.set_visible(False)
|
||||
except Exception: # pragma no cover
|
||||
raise
|
||||
axis.get_label().set_visible(False)
|
||||
|
||||
|
||||
def _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey):
|
||||
if nplots > 1:
|
||||
|
||||
if nrows > 1:
|
||||
try:
|
||||
# first find out the ax layout,
|
||||
# so that we can correctly handle 'gaps"
|
||||
layout = np.zeros((nrows + 1, ncols + 1), dtype=np.bool)
|
||||
for ax in axarr:
|
||||
layout[ax.rowNum, ax.colNum] = ax.get_visible()
|
||||
|
||||
for ax in axarr:
|
||||
# only the last row of subplots should get x labels -> all
|
||||
# other off layout handles the case that the subplot is
|
||||
# the last in the column, because below is no subplot/gap.
|
||||
if not layout[ax.rowNum + 1, ax.colNum]:
|
||||
continue
|
||||
if sharex or len(ax.get_shared_x_axes().get_siblings(ax)) > 1:
|
||||
_remove_labels_from_axis(ax.xaxis)
|
||||
|
||||
except IndexError:
|
||||
# if gridspec is used, ax.rowNum and ax.colNum may different
|
||||
# from layout shape. in this case, use last_row logic
|
||||
for ax in axarr:
|
||||
if ax.is_last_row():
|
||||
continue
|
||||
if sharex or len(ax.get_shared_x_axes().get_siblings(ax)) > 1:
|
||||
_remove_labels_from_axis(ax.xaxis)
|
||||
|
||||
if ncols > 1:
|
||||
for ax in axarr:
|
||||
# only the first column should get y labels -> set all other to
|
||||
# off as we only have labels in the first column and we always
|
||||
# have a subplot there, we can skip the layout test
|
||||
if ax.is_first_col():
|
||||
continue
|
||||
if sharey or len(ax.get_shared_y_axes().get_siblings(ax)) > 1:
|
||||
_remove_labels_from_axis(ax.yaxis)
|
||||
|
||||
|
||||
def _flatten(axes):
|
||||
if not is_list_like(axes):
|
||||
return np.array([axes])
|
||||
elif isinstance(axes, (np.ndarray, ABCIndexClass)):
|
||||
return axes.ravel()
|
||||
return np.array(axes)
|
||||
|
||||
|
||||
def _get_all_lines(ax):
|
||||
lines = ax.get_lines()
|
||||
|
||||
if hasattr(ax, "right_ax"):
|
||||
lines += ax.right_ax.get_lines()
|
||||
|
||||
if hasattr(ax, "left_ax"):
|
||||
lines += ax.left_ax.get_lines()
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _get_xlim(lines):
|
||||
left, right = np.inf, -np.inf
|
||||
for l in lines:
|
||||
x = l.get_xdata(orig=False)
|
||||
left = min(np.nanmin(x), left)
|
||||
right = max(np.nanmax(x), right)
|
||||
return left, right
|
||||
|
||||
|
||||
def _set_ticks_props(axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
for ax in _flatten(axes):
|
||||
if xlabelsize is not None:
|
||||
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
|
||||
if xrot is not None:
|
||||
plt.setp(ax.get_xticklabels(), rotation=xrot)
|
||||
if ylabelsize is not None:
|
||||
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
|
||||
if yrot is not None:
|
||||
plt.setp(ax.get_yticklabels(), rotation=yrot)
|
||||
return axes
|
Reference in New Issue
Block a user