664 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			664 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
 | 
						|
Module consolidating common testing functions for checking plotting.
 | 
						|
 | 
						|
Currently all plotting tests are marked as slow via
 | 
						|
``pytestmark = pytest.mark.slow`` at the module level.
 | 
						|
"""
 | 
						|
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import os
 | 
						|
from typing import (
 | 
						|
    TYPE_CHECKING,
 | 
						|
    Sequence,
 | 
						|
)
 | 
						|
import warnings
 | 
						|
 | 
						|
import numpy as np
 | 
						|
 | 
						|
from pandas.util._decorators import cache_readonly
 | 
						|
import pandas.util._test_decorators as td
 | 
						|
 | 
						|
from pandas.core.dtypes.api import is_list_like
 | 
						|
 | 
						|
import pandas as pd
 | 
						|
from pandas import (
 | 
						|
    DataFrame,
 | 
						|
    Series,
 | 
						|
    to_datetime,
 | 
						|
)
 | 
						|
import pandas._testing as tm
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from matplotlib.axes import Axes
 | 
						|
 | 
						|
 | 
						|
@td.skip_if_no_mpl
 | 
						|
class TestPlotBase:
 | 
						|
    """
 | 
						|
    This is a common base class used for various plotting tests
 | 
						|
    """
 | 
						|
 | 
						|
    def setup_method(self, method):
 | 
						|
 | 
						|
        import matplotlib as mpl
 | 
						|
 | 
						|
        from pandas.plotting._matplotlib import compat
 | 
						|
 | 
						|
        self.compat = compat
 | 
						|
 | 
						|
        mpl.rcdefaults()
 | 
						|
 | 
						|
        self.start_date_to_int64 = 812419200000000000
 | 
						|
        self.end_date_to_int64 = 819331200000000000
 | 
						|
 | 
						|
        self.mpl_ge_2_2_3 = compat.mpl_ge_2_2_3()
 | 
						|
        self.mpl_ge_3_0_0 = compat.mpl_ge_3_0_0()
 | 
						|
        self.mpl_ge_3_1_0 = compat.mpl_ge_3_1_0()
 | 
						|
        self.mpl_ge_3_2_0 = compat.mpl_ge_3_2_0()
 | 
						|
 | 
						|
        self.bp_n_objects = 7
 | 
						|
        self.polycollection_factor = 2
 | 
						|
        self.default_figsize = (6.4, 4.8)
 | 
						|
        self.default_tick_position = "left"
 | 
						|
 | 
						|
        n = 100
 | 
						|
        with tm.RNGContext(42):
 | 
						|
            gender = np.random.choice(["Male", "Female"], size=n)
 | 
						|
            classroom = np.random.choice(["A", "B", "C"], size=n)
 | 
						|
 | 
						|
            self.hist_df = DataFrame(
 | 
						|
                {
 | 
						|
                    "gender": gender,
 | 
						|
                    "classroom": classroom,
 | 
						|
                    "height": np.random.normal(66, 4, size=n),
 | 
						|
                    "weight": np.random.normal(161, 32, size=n),
 | 
						|
                    "category": np.random.randint(4, size=n),
 | 
						|
                    "datetime": to_datetime(
 | 
						|
                        np.random.randint(
 | 
						|
                            self.start_date_to_int64,
 | 
						|
                            self.end_date_to_int64,
 | 
						|
                            size=n,
 | 
						|
                            dtype=np.int64,
 | 
						|
                        )
 | 
						|
                    ),
 | 
						|
                }
 | 
						|
            )
 | 
						|
 | 
						|
        self.tdf = tm.makeTimeDataFrame()
 | 
						|
        self.hexbin_df = DataFrame(
 | 
						|
            {
 | 
						|
                "A": np.random.uniform(size=20),
 | 
						|
                "B": np.random.uniform(size=20),
 | 
						|
                "C": np.arange(20) + np.random.uniform(size=20),
 | 
						|
            }
 | 
						|
        )
 | 
						|
 | 
						|
    def teardown_method(self, method):
 | 
						|
        tm.close()
 | 
						|
 | 
						|
    @cache_readonly
 | 
						|
    def plt(self):
 | 
						|
        import matplotlib.pyplot as plt
 | 
						|
 | 
						|
        return plt
 | 
						|
 | 
						|
    @cache_readonly
 | 
						|
    def colorconverter(self):
 | 
						|
        import matplotlib.colors as colors
 | 
						|
 | 
						|
        return colors.colorConverter
 | 
						|
 | 
						|
    def _check_legend_labels(self, axes, labels=None, visible=True):
 | 
						|
        """
 | 
						|
        Check each axes has expected legend labels
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        axes : matplotlib Axes object, or its list-like
 | 
						|
        labels : list-like
 | 
						|
            expected legend labels
 | 
						|
        visible : bool
 | 
						|
            expected legend visibility. labels are checked only when visible is
 | 
						|
            True
 | 
						|
        """
 | 
						|
        if visible and (labels is None):
 | 
						|
            raise ValueError("labels must be specified when visible is True")
 | 
						|
        axes = self._flatten_visible(axes)
 | 
						|
        for ax in axes:
 | 
						|
            if visible:
 | 
						|
                assert ax.get_legend() is not None
 | 
						|
                self._check_text_labels(ax.get_legend().get_texts(), labels)
 | 
						|
            else:
 | 
						|
                assert ax.get_legend() is None
 | 
						|
 | 
						|
    def _check_legend_marker(self, ax, expected_markers=None, visible=True):
 | 
						|
        """
 | 
						|
        Check ax has expected legend markers
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        ax : matplotlib Axes object
 | 
						|
        expected_markers : list-like
 | 
						|
            expected legend markers
 | 
						|
        visible : bool
 | 
						|
            expected legend visibility. labels are checked only when visible is
 | 
						|
            True
 | 
						|
        """
 | 
						|
        if visible and (expected_markers is None):
 | 
						|
            raise ValueError("Markers must be specified when visible is True")
 | 
						|
        if visible:
 | 
						|
            handles, _ = ax.get_legend_handles_labels()
 | 
						|
            markers = [handle.get_marker() for handle in handles]
 | 
						|
            assert markers == expected_markers
 | 
						|
        else:
 | 
						|
            assert ax.get_legend() is None
 | 
						|
 | 
						|
    def _check_data(self, xp, rs):
 | 
						|
        """
 | 
						|
        Check each axes has identical lines
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        xp : matplotlib Axes object
 | 
						|
        rs : matplotlib Axes object
 | 
						|
        """
 | 
						|
        xp_lines = xp.get_lines()
 | 
						|
        rs_lines = rs.get_lines()
 | 
						|
 | 
						|
        def check_line(xpl, rsl):
 | 
						|
            xpdata = xpl.get_xydata()
 | 
						|
            rsdata = rsl.get_xydata()
 | 
						|
            tm.assert_almost_equal(xpdata, rsdata)
 | 
						|
 | 
						|
        assert len(xp_lines) == len(rs_lines)
 | 
						|
        [check_line(xpl, rsl) for xpl, rsl in zip(xp_lines, rs_lines)]
 | 
						|
        tm.close()
 | 
						|
 | 
						|
    def _check_visible(self, collections, visible=True):
 | 
						|
        """
 | 
						|
        Check each artist is visible or not
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        collections : matplotlib Artist or its list-like
 | 
						|
            target Artist or its list or collection
 | 
						|
        visible : bool
 | 
						|
            expected visibility
 | 
						|
        """
 | 
						|
        from matplotlib.collections import Collection
 | 
						|
 | 
						|
        if not isinstance(collections, Collection) and not is_list_like(collections):
 | 
						|
            collections = [collections]
 | 
						|
 | 
						|
        for patch in collections:
 | 
						|
            assert patch.get_visible() == visible
 | 
						|
 | 
						|
    def _check_patches_all_filled(
 | 
						|
        self, axes: Axes | Sequence[Axes], filled: bool = True
 | 
						|
    ) -> None:
 | 
						|
        """
 | 
						|
        Check for each artist whether it is filled or not
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        axes : matplotlib Axes object, or its list-like
 | 
						|
        filled : bool
 | 
						|
            expected filling
 | 
						|
        """
 | 
						|
 | 
						|
        axes = self._flatten_visible(axes)
 | 
						|
        for ax in axes:
 | 
						|
            for patch in ax.patches:
 | 
						|
                assert patch.fill == filled
 | 
						|
 | 
						|
    def _get_colors_mapped(self, series, colors):
 | 
						|
        unique = series.unique()
 | 
						|
        # unique and colors length can be differed
 | 
						|
        # depending on slice value
 | 
						|
        mapped = dict(zip(unique, colors))
 | 
						|
        return [mapped[v] for v in series.values]
 | 
						|
 | 
						|
    def _check_colors(
 | 
						|
        self, collections, linecolors=None, facecolors=None, mapping=None
 | 
						|
    ):
 | 
						|
        """
 | 
						|
        Check each artist has expected line colors and face colors
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        collections : list-like
 | 
						|
            list or collection of target artist
 | 
						|
        linecolors : list-like which has the same length as collections
 | 
						|
            list of expected line colors
 | 
						|
        facecolors : list-like which has the same length as collections
 | 
						|
            list of expected face colors
 | 
						|
        mapping : Series
 | 
						|
            Series used for color grouping key
 | 
						|
            used for andrew_curves, parallel_coordinates, radviz test
 | 
						|
        """
 | 
						|
        from matplotlib.collections import (
 | 
						|
            Collection,
 | 
						|
            LineCollection,
 | 
						|
            PolyCollection,
 | 
						|
        )
 | 
						|
        from matplotlib.lines import Line2D
 | 
						|
 | 
						|
        conv = self.colorconverter
 | 
						|
        if linecolors is not None:
 | 
						|
 | 
						|
            if mapping is not None:
 | 
						|
                linecolors = self._get_colors_mapped(mapping, linecolors)
 | 
						|
                linecolors = linecolors[: len(collections)]
 | 
						|
 | 
						|
            assert len(collections) == len(linecolors)
 | 
						|
            for patch, color in zip(collections, linecolors):
 | 
						|
                if isinstance(patch, Line2D):
 | 
						|
                    result = patch.get_color()
 | 
						|
                    # Line2D may contains string color expression
 | 
						|
                    result = conv.to_rgba(result)
 | 
						|
                elif isinstance(patch, (PolyCollection, LineCollection)):
 | 
						|
                    result = tuple(patch.get_edgecolor()[0])
 | 
						|
                else:
 | 
						|
                    result = patch.get_edgecolor()
 | 
						|
 | 
						|
                expected = conv.to_rgba(color)
 | 
						|
                assert result == expected
 | 
						|
 | 
						|
        if facecolors is not None:
 | 
						|
 | 
						|
            if mapping is not None:
 | 
						|
                facecolors = self._get_colors_mapped(mapping, facecolors)
 | 
						|
                facecolors = facecolors[: len(collections)]
 | 
						|
 | 
						|
            assert len(collections) == len(facecolors)
 | 
						|
            for patch, color in zip(collections, facecolors):
 | 
						|
                if isinstance(patch, Collection):
 | 
						|
                    # returned as list of np.array
 | 
						|
                    result = patch.get_facecolor()[0]
 | 
						|
                else:
 | 
						|
                    result = patch.get_facecolor()
 | 
						|
 | 
						|
                if isinstance(result, np.ndarray):
 | 
						|
                    result = tuple(result)
 | 
						|
 | 
						|
                expected = conv.to_rgba(color)
 | 
						|
                assert result == expected
 | 
						|
 | 
						|
    def _check_text_labels(self, texts, expected):
 | 
						|
        """
 | 
						|
        Check each text has expected labels
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        texts : matplotlib Text object, or its list-like
 | 
						|
            target text, or its list
 | 
						|
        expected : str or list-like which has the same length as texts
 | 
						|
            expected text label, or its list
 | 
						|
        """
 | 
						|
        if not is_list_like(texts):
 | 
						|
            assert texts.get_text() == expected
 | 
						|
        else:
 | 
						|
            labels = [t.get_text() for t in texts]
 | 
						|
            assert len(labels) == len(expected)
 | 
						|
            for label, e in zip(labels, expected):
 | 
						|
                assert label == e
 | 
						|
 | 
						|
    def _check_ticks_props(
 | 
						|
        self, axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None
 | 
						|
    ):
 | 
						|
        """
 | 
						|
        Check each axes has expected tick properties
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        axes : matplotlib Axes object, or its list-like
 | 
						|
        xlabelsize : number
 | 
						|
            expected xticks font size
 | 
						|
        xrot : number
 | 
						|
            expected xticks rotation
 | 
						|
        ylabelsize : number
 | 
						|
            expected yticks font size
 | 
						|
        yrot : number
 | 
						|
            expected yticks rotation
 | 
						|
        """
 | 
						|
        from matplotlib.ticker import NullFormatter
 | 
						|
 | 
						|
        axes = self._flatten_visible(axes)
 | 
						|
        for ax in axes:
 | 
						|
            if xlabelsize is not None or xrot is not None:
 | 
						|
                if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
 | 
						|
                    # If minor ticks has NullFormatter, rot / fontsize are not
 | 
						|
                    # retained
 | 
						|
                    labels = ax.get_xticklabels()
 | 
						|
                else:
 | 
						|
                    labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True)
 | 
						|
 | 
						|
                for label in labels:
 | 
						|
                    if xlabelsize is not None:
 | 
						|
                        tm.assert_almost_equal(label.get_fontsize(), xlabelsize)
 | 
						|
                    if xrot is not None:
 | 
						|
                        tm.assert_almost_equal(label.get_rotation(), xrot)
 | 
						|
 | 
						|
            if ylabelsize is not None or yrot is not None:
 | 
						|
                if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
 | 
						|
                    labels = ax.get_yticklabels()
 | 
						|
                else:
 | 
						|
                    labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True)
 | 
						|
 | 
						|
                for label in labels:
 | 
						|
                    if ylabelsize is not None:
 | 
						|
                        tm.assert_almost_equal(label.get_fontsize(), ylabelsize)
 | 
						|
                    if yrot is not None:
 | 
						|
                        tm.assert_almost_equal(label.get_rotation(), yrot)
 | 
						|
 | 
						|
    def _check_ax_scales(self, axes, xaxis="linear", yaxis="linear"):
 | 
						|
        """
 | 
						|
        Check each axes has expected scales
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        axes : matplotlib Axes object, or its list-like
 | 
						|
        xaxis : {'linear', 'log'}
 | 
						|
            expected xaxis scale
 | 
						|
        yaxis : {'linear', 'log'}
 | 
						|
            expected yaxis scale
 | 
						|
        """
 | 
						|
        axes = self._flatten_visible(axes)
 | 
						|
        for ax in axes:
 | 
						|
            assert ax.xaxis.get_scale() == xaxis
 | 
						|
            assert ax.yaxis.get_scale() == yaxis
 | 
						|
 | 
						|
    def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
 | 
						|
        """
 | 
						|
        Check expected number of axes is drawn in expected layout
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        axes : matplotlib Axes object, or its list-like
 | 
						|
        axes_num : number
 | 
						|
            expected number of axes. Unnecessary axes should be set to
 | 
						|
            invisible.
 | 
						|
        layout : tuple
 | 
						|
            expected layout, (expected number of rows , columns)
 | 
						|
        figsize : tuple
 | 
						|
            expected figsize. default is matplotlib default
 | 
						|
        """
 | 
						|
        from pandas.plotting._matplotlib.tools import flatten_axes
 | 
						|
 | 
						|
        if figsize is None:
 | 
						|
            figsize = self.default_figsize
 | 
						|
        visible_axes = self._flatten_visible(axes)
 | 
						|
 | 
						|
        if axes_num is not None:
 | 
						|
            assert len(visible_axes) == axes_num
 | 
						|
            for ax in visible_axes:
 | 
						|
                # check something drawn on visible axes
 | 
						|
                assert len(ax.get_children()) > 0
 | 
						|
 | 
						|
        if layout is not None:
 | 
						|
            result = self._get_axes_layout(flatten_axes(axes))
 | 
						|
            assert result == layout
 | 
						|
 | 
						|
        tm.assert_numpy_array_equal(
 | 
						|
            visible_axes[0].figure.get_size_inches(),
 | 
						|
            np.array(figsize, dtype=np.float64),
 | 
						|
        )
 | 
						|
 | 
						|
    def _get_axes_layout(self, axes):
 | 
						|
        x_set = set()
 | 
						|
        y_set = set()
 | 
						|
        for ax in axes:
 | 
						|
            # check axes coordinates to estimate layout
 | 
						|
            points = ax.get_position().get_points()
 | 
						|
            x_set.add(points[0][0])
 | 
						|
            y_set.add(points[0][1])
 | 
						|
        return (len(y_set), len(x_set))
 | 
						|
 | 
						|
    def _flatten_visible(self, axes):
 | 
						|
        """
 | 
						|
        Flatten axes, and filter only visible
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        axes : matplotlib Axes object, or its list-like
 | 
						|
 | 
						|
        """
 | 
						|
        from pandas.plotting._matplotlib.tools import flatten_axes
 | 
						|
 | 
						|
        axes = flatten_axes(axes)
 | 
						|
        axes = [ax for ax in axes if ax.get_visible()]
 | 
						|
        return axes
 | 
						|
 | 
						|
    def _check_has_errorbars(self, axes, xerr=0, yerr=0):
 | 
						|
        """
 | 
						|
        Check axes has expected number of errorbars
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        axes : matplotlib Axes object, or its list-like
 | 
						|
        xerr : number
 | 
						|
            expected number of x errorbar
 | 
						|
        yerr : number
 | 
						|
            expected number of y errorbar
 | 
						|
        """
 | 
						|
        axes = self._flatten_visible(axes)
 | 
						|
        for ax in axes:
 | 
						|
            containers = ax.containers
 | 
						|
            xerr_count = 0
 | 
						|
            yerr_count = 0
 | 
						|
            for c in containers:
 | 
						|
                has_xerr = getattr(c, "has_xerr", False)
 | 
						|
                has_yerr = getattr(c, "has_yerr", False)
 | 
						|
                if has_xerr:
 | 
						|
                    xerr_count += 1
 | 
						|
                if has_yerr:
 | 
						|
                    yerr_count += 1
 | 
						|
            assert xerr == xerr_count
 | 
						|
            assert yerr == yerr_count
 | 
						|
 | 
						|
    def _check_box_return_type(
 | 
						|
        self, returned, return_type, expected_keys=None, check_ax_title=True
 | 
						|
    ):
 | 
						|
        """
 | 
						|
        Check box returned type is correct
 | 
						|
 | 
						|
        Parameters
 | 
						|
        ----------
 | 
						|
        returned : object to be tested, returned from boxplot
 | 
						|
        return_type : str
 | 
						|
            return_type passed to boxplot
 | 
						|
        expected_keys : list-like, optional
 | 
						|
            group labels in subplot case. If not passed,
 | 
						|
            the function checks assuming boxplot uses single ax
 | 
						|
        check_ax_title : bool
 | 
						|
            Whether to check the ax.title is the same as expected_key
 | 
						|
            Intended to be checked by calling from ``boxplot``.
 | 
						|
            Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
 | 
						|
        """
 | 
						|
        from matplotlib.axes import Axes
 | 
						|
 | 
						|
        types = {"dict": dict, "axes": Axes, "both": tuple}
 | 
						|
        if expected_keys is None:
 | 
						|
            # should be fixed when the returning default is changed
 | 
						|
            if return_type is None:
 | 
						|
                return_type = "dict"
 | 
						|
 | 
						|
            assert isinstance(returned, types[return_type])
 | 
						|
            if return_type == "both":
 | 
						|
                assert isinstance(returned.ax, Axes)
 | 
						|
                assert isinstance(returned.lines, dict)
 | 
						|
        else:
 | 
						|
            # should be fixed when the returning default is changed
 | 
						|
            if return_type is None:
 | 
						|
                for r in self._flatten_visible(returned):
 | 
						|
                    assert isinstance(r, Axes)
 | 
						|
                return
 | 
						|
 | 
						|
            assert isinstance(returned, Series)
 | 
						|
 | 
						|
            assert sorted(returned.keys()) == sorted(expected_keys)
 | 
						|
            for key, value in returned.items():
 | 
						|
                assert isinstance(value, types[return_type])
 | 
						|
                # check returned dict has correct mapping
 | 
						|
                if return_type == "axes":
 | 
						|
                    if check_ax_title:
 | 
						|
                        assert value.get_title() == key
 | 
						|
                elif return_type == "both":
 | 
						|
                    if check_ax_title:
 | 
						|
                        assert value.ax.get_title() == key
 | 
						|
                    assert isinstance(value.ax, Axes)
 | 
						|
                    assert isinstance(value.lines, dict)
 | 
						|
                elif return_type == "dict":
 | 
						|
                    line = value["medians"][0]
 | 
						|
                    axes = line.axes
 | 
						|
                    if check_ax_title:
 | 
						|
                        assert axes.get_title() == key
 | 
						|
                else:
 | 
						|
                    raise AssertionError
 | 
						|
 | 
						|
    def _check_grid_settings(self, obj, kinds, kws={}):
 | 
						|
        # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792
 | 
						|
 | 
						|
        import matplotlib as mpl
 | 
						|
 | 
						|
        def is_grid_on():
 | 
						|
            xticks = self.plt.gca().xaxis.get_major_ticks()
 | 
						|
            yticks = self.plt.gca().yaxis.get_major_ticks()
 | 
						|
            # for mpl 2.2.2, gridOn and gridline.get_visible disagree.
 | 
						|
            # for new MPL, they are the same.
 | 
						|
 | 
						|
            if self.mpl_ge_3_1_0:
 | 
						|
                xoff = all(not g.gridline.get_visible() for g in xticks)
 | 
						|
                yoff = all(not g.gridline.get_visible() for g in yticks)
 | 
						|
            else:
 | 
						|
                xoff = all(not g.gridOn for g in xticks)
 | 
						|
                yoff = all(not g.gridOn for g in yticks)
 | 
						|
 | 
						|
            return not (xoff and yoff)
 | 
						|
 | 
						|
        spndx = 1
 | 
						|
        for kind in kinds:
 | 
						|
 | 
						|
            self.plt.subplot(1, 4 * len(kinds), spndx)
 | 
						|
            spndx += 1
 | 
						|
            mpl.rc("axes", grid=False)
 | 
						|
            obj.plot(kind=kind, **kws)
 | 
						|
            assert not is_grid_on()
 | 
						|
 | 
						|
            self.plt.subplot(1, 4 * len(kinds), spndx)
 | 
						|
            spndx += 1
 | 
						|
            mpl.rc("axes", grid=True)
 | 
						|
            obj.plot(kind=kind, grid=False, **kws)
 | 
						|
            assert not is_grid_on()
 | 
						|
 | 
						|
            if kind not in ["pie", "hexbin", "scatter"]:
 | 
						|
                self.plt.subplot(1, 4 * len(kinds), spndx)
 | 
						|
                spndx += 1
 | 
						|
                mpl.rc("axes", grid=True)
 | 
						|
                obj.plot(kind=kind, **kws)
 | 
						|
                assert is_grid_on()
 | 
						|
 | 
						|
                self.plt.subplot(1, 4 * len(kinds), spndx)
 | 
						|
                spndx += 1
 | 
						|
                mpl.rc("axes", grid=False)
 | 
						|
                obj.plot(kind=kind, grid=True, **kws)
 | 
						|
                assert is_grid_on()
 | 
						|
 | 
						|
    def _unpack_cycler(self, rcParams, field="color"):
 | 
						|
        """
 | 
						|
        Auxiliary function for correctly unpacking cycler after MPL >= 1.5
 | 
						|
        """
 | 
						|
        return [v[field] for v in rcParams["axes.prop_cycle"]]
 | 
						|
 | 
						|
    def get_x_axis(self, ax):
 | 
						|
        return ax._shared_axes["x"] if self.compat.mpl_ge_3_5_0() else ax._shared_x_axes
 | 
						|
 | 
						|
    def get_y_axis(self, ax):
 | 
						|
        return ax._shared_axes["y"] if self.compat.mpl_ge_3_5_0() else ax._shared_y_axes
 | 
						|
 | 
						|
 | 
						|
def _check_plot_works(f, filterwarnings="always", default_axes=False, **kwargs):
 | 
						|
    """
 | 
						|
    Create plot and ensure that plot return object is valid.
 | 
						|
 | 
						|
    Parameters
 | 
						|
    ----------
 | 
						|
    f : func
 | 
						|
        Plotting function.
 | 
						|
    filterwarnings : str
 | 
						|
        Warnings filter.
 | 
						|
        See https://docs.python.org/3/library/warnings.html#warning-filter
 | 
						|
    default_axes : bool, optional
 | 
						|
        If False (default):
 | 
						|
            - If `ax` not in `kwargs`, then create subplot(211) and plot there
 | 
						|
            - Create new subplot(212) and plot there as well
 | 
						|
            - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`)
 | 
						|
        If True:
 | 
						|
            - Simply run plotting function with kwargs provided
 | 
						|
            - All required axes instances will be created automatically
 | 
						|
            - It is recommended to use it when the plotting function
 | 
						|
            creates multiple axes itself. It helps avoid warnings like
 | 
						|
            'UserWarning: To output multiple subplots,
 | 
						|
            the figure containing the passed axes is being cleared'
 | 
						|
    **kwargs
 | 
						|
        Keyword arguments passed to the plotting function.
 | 
						|
 | 
						|
    Returns
 | 
						|
    -------
 | 
						|
    Plot object returned by the last plotting.
 | 
						|
    """
 | 
						|
    import matplotlib.pyplot as plt
 | 
						|
 | 
						|
    if default_axes:
 | 
						|
        gen_plots = _gen_default_plot
 | 
						|
    else:
 | 
						|
        gen_plots = _gen_two_subplots
 | 
						|
 | 
						|
    ret = None
 | 
						|
    with warnings.catch_warnings():
 | 
						|
        warnings.simplefilter(filterwarnings)
 | 
						|
        try:
 | 
						|
            fig = kwargs.get("figure", plt.gcf())
 | 
						|
            plt.clf()
 | 
						|
 | 
						|
            for ret in gen_plots(f, fig, **kwargs):
 | 
						|
                tm.assert_is_valid_plot_return_object(ret)
 | 
						|
 | 
						|
            with tm.ensure_clean(return_filelike=True) as path:
 | 
						|
                plt.savefig(path)
 | 
						|
 | 
						|
        except Exception as err:
 | 
						|
            raise err
 | 
						|
        finally:
 | 
						|
            tm.close(fig)
 | 
						|
 | 
						|
        return ret
 | 
						|
 | 
						|
 | 
						|
def _gen_default_plot(f, fig, **kwargs):
 | 
						|
    """
 | 
						|
    Create plot in a default way.
 | 
						|
    """
 | 
						|
    yield f(**kwargs)
 | 
						|
 | 
						|
 | 
						|
def _gen_two_subplots(f, fig, **kwargs):
 | 
						|
    """
 | 
						|
    Create plot on two subplots forcefully created.
 | 
						|
    """
 | 
						|
    if "ax" not in kwargs:
 | 
						|
        fig.add_subplot(211)
 | 
						|
    yield f(**kwargs)
 | 
						|
 | 
						|
    if f is pd.plotting.bootstrap_plot:
 | 
						|
        assert "ax" not in kwargs
 | 
						|
    else:
 | 
						|
        kwargs["ax"] = fig.add_subplot(212)
 | 
						|
    yield f(**kwargs)
 | 
						|
 | 
						|
 | 
						|
def curpath():
 | 
						|
    pth, _ = os.path.split(os.path.abspath(__file__))
 | 
						|
    return pth
 |