This commit is contained in:
2024-11-29 18:15:30 +00:00
parent 40aade2d8e
commit bc9415586e
5298 changed files with 1938676 additions and 80 deletions

View File

@ -0,0 +1,10 @@
from . import axes_size as Size
from .axes_divider import Divider, SubplotDivider, make_axes_locatable
from .axes_grid import AxesGrid, Grid, ImageGrid
from .parasite_axes import host_subplot, host_axes
__all__ = ["Size",
"Divider", "SubplotDivider", "make_axes_locatable",
"AxesGrid", "Grid", "ImageGrid",
"host_subplot", "host_axes"]

View File

@ -0,0 +1,462 @@
from matplotlib import _api, transforms
from matplotlib.offsetbox import (AnchoredOffsetbox, AuxTransformBox,
DrawingArea, TextArea, VPacker)
from matplotlib.patches import (Rectangle, Ellipse, ArrowStyle,
FancyArrowPatch, PathPatch)
from matplotlib.text import TextPath
__all__ = ['AnchoredDrawingArea', 'AnchoredAuxTransformBox',
'AnchoredEllipse', 'AnchoredSizeBar', 'AnchoredDirectionArrows']
class AnchoredDrawingArea(AnchoredOffsetbox):
def __init__(self, width, height, xdescent, ydescent,
loc, pad=0.4, borderpad=0.5, prop=None, frameon=True,
**kwargs):
"""
An anchored container with a fixed size and fillable `.DrawingArea`.
Artists added to the *drawing_area* will have their coordinates
interpreted as pixels. Any transformations set on the artists will be
overridden.
Parameters
----------
width, height : float
Width and height of the container, in pixels.
xdescent, ydescent : float
Descent of the container in the x- and y- direction, in pixels.
loc : str
Location of this artist. Valid locations are
'upper left', 'upper center', 'upper right',
'center left', 'center', 'center right',
'lower left', 'lower center', 'lower right'.
For backward compatibility, numeric values are accepted as well.
See the parameter *loc* of `.Legend` for details.
pad : float, default: 0.4
Padding around the child objects, in fraction of the font size.
borderpad : float, default: 0.5
Border padding, in fraction of the font size.
prop : `~matplotlib.font_manager.FontProperties`, optional
Font property used as a reference for paddings.
frameon : bool, default: True
If True, draw a box around this artist.
**kwargs
Keyword arguments forwarded to `.AnchoredOffsetbox`.
Attributes
----------
drawing_area : `~matplotlib.offsetbox.DrawingArea`
A container for artists to display.
Examples
--------
To display blue and red circles of different sizes in the upper right
of an Axes *ax*:
>>> ada = AnchoredDrawingArea(20, 20, 0, 0,
... loc='upper right', frameon=False)
>>> ada.drawing_area.add_artist(Circle((10, 10), 10, fc="b"))
>>> ada.drawing_area.add_artist(Circle((30, 10), 5, fc="r"))
>>> ax.add_artist(ada)
"""
self.da = DrawingArea(width, height, xdescent, ydescent)
self.drawing_area = self.da
super().__init__(
loc, pad=pad, borderpad=borderpad, child=self.da, prop=None,
frameon=frameon, **kwargs
)
class AnchoredAuxTransformBox(AnchoredOffsetbox):
def __init__(self, transform, loc,
pad=0.4, borderpad=0.5, prop=None, frameon=True, **kwargs):
"""
An anchored container with transformed coordinates.
Artists added to the *drawing_area* are scaled according to the
coordinates of the transformation used. The dimensions of this artist
will scale to contain the artists added.
Parameters
----------
transform : `~matplotlib.transforms.Transform`
The transformation object for the coordinate system in use, i.e.,
:attr:`matplotlib.axes.Axes.transData`.
loc : str
Location of this artist. Valid locations are
'upper left', 'upper center', 'upper right',
'center left', 'center', 'center right',
'lower left', 'lower center', 'lower right'.
For backward compatibility, numeric values are accepted as well.
See the parameter *loc* of `.Legend` for details.
pad : float, default: 0.4
Padding around the child objects, in fraction of the font size.
borderpad : float, default: 0.5
Border padding, in fraction of the font size.
prop : `~matplotlib.font_manager.FontProperties`, optional
Font property used as a reference for paddings.
frameon : bool, default: True
If True, draw a box around this artist.
**kwargs
Keyword arguments forwarded to `.AnchoredOffsetbox`.
Attributes
----------
drawing_area : `~matplotlib.offsetbox.AuxTransformBox`
A container for artists to display.
Examples
--------
To display an ellipse in the upper left, with a width of 0.1 and
height of 0.4 in data coordinates:
>>> box = AnchoredAuxTransformBox(ax.transData, loc='upper left')
>>> el = Ellipse((0, 0), width=0.1, height=0.4, angle=30)
>>> box.drawing_area.add_artist(el)
>>> ax.add_artist(box)
"""
self.drawing_area = AuxTransformBox(transform)
super().__init__(loc, pad=pad, borderpad=borderpad,
child=self.drawing_area, prop=prop, frameon=frameon,
**kwargs)
@_api.deprecated("3.8")
class AnchoredEllipse(AnchoredOffsetbox):
def __init__(self, transform, width, height, angle, loc,
pad=0.1, borderpad=0.1, prop=None, frameon=True, **kwargs):
"""
Draw an anchored ellipse of a given size.
Parameters
----------
transform : `~matplotlib.transforms.Transform`
The transformation object for the coordinate system in use, i.e.,
:attr:`matplotlib.axes.Axes.transData`.
width, height : float
Width and height of the ellipse, given in coordinates of
*transform*.
angle : float
Rotation of the ellipse, in degrees, anti-clockwise.
loc : str
Location of the ellipse. Valid locations are
'upper left', 'upper center', 'upper right',
'center left', 'center', 'center right',
'lower left', 'lower center', 'lower right'.
For backward compatibility, numeric values are accepted as well.
See the parameter *loc* of `.Legend` for details.
pad : float, default: 0.1
Padding around the ellipse, in fraction of the font size.
borderpad : float, default: 0.1
Border padding, in fraction of the font size.
frameon : bool, default: True
If True, draw a box around the ellipse.
prop : `~matplotlib.font_manager.FontProperties`, optional
Font property used as a reference for paddings.
**kwargs
Keyword arguments forwarded to `.AnchoredOffsetbox`.
Attributes
----------
ellipse : `~matplotlib.patches.Ellipse`
Ellipse patch drawn.
"""
self._box = AuxTransformBox(transform)
self.ellipse = Ellipse((0, 0), width, height, angle=angle)
self._box.add_artist(self.ellipse)
super().__init__(loc, pad=pad, borderpad=borderpad, child=self._box,
prop=prop, frameon=frameon, **kwargs)
class AnchoredSizeBar(AnchoredOffsetbox):
def __init__(self, transform, size, label, loc,
pad=0.1, borderpad=0.1, sep=2,
frameon=True, size_vertical=0, color='black',
label_top=False, fontproperties=None, fill_bar=None,
**kwargs):
"""
Draw a horizontal scale bar with a center-aligned label underneath.
Parameters
----------
transform : `~matplotlib.transforms.Transform`
The transformation object for the coordinate system in use, i.e.,
:attr:`matplotlib.axes.Axes.transData`.
size : float
Horizontal length of the size bar, given in coordinates of
*transform*.
label : str
Label to display.
loc : str
Location of the size bar. Valid locations are
'upper left', 'upper center', 'upper right',
'center left', 'center', 'center right',
'lower left', 'lower center', 'lower right'.
For backward compatibility, numeric values are accepted as well.
See the parameter *loc* of `.Legend` for details.
pad : float, default: 0.1
Padding around the label and size bar, in fraction of the font
size.
borderpad : float, default: 0.1
Border padding, in fraction of the font size.
sep : float, default: 2
Separation between the label and the size bar, in points.
frameon : bool, default: True
If True, draw a box around the horizontal bar and label.
size_vertical : float, default: 0
Vertical length of the size bar, given in coordinates of
*transform*.
color : str, default: 'black'
Color for the size bar and label.
label_top : bool, default: False
If True, the label will be over the size bar.
fontproperties : `~matplotlib.font_manager.FontProperties`, optional
Font properties for the label text.
fill_bar : bool, optional
If True and if *size_vertical* is nonzero, the size bar will
be filled in with the color specified by the size bar.
Defaults to True if *size_vertical* is greater than
zero and False otherwise.
**kwargs
Keyword arguments forwarded to `.AnchoredOffsetbox`.
Attributes
----------
size_bar : `~matplotlib.offsetbox.AuxTransformBox`
Container for the size bar.
txt_label : `~matplotlib.offsetbox.TextArea`
Container for the label of the size bar.
Notes
-----
If *prop* is passed as a keyword argument, but *fontproperties* is
not, then *prop* is assumed to be the intended *fontproperties*.
Using both *prop* and *fontproperties* is not supported.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from mpl_toolkits.axes_grid1.anchored_artists import (
... AnchoredSizeBar)
>>> fig, ax = plt.subplots()
>>> ax.imshow(np.random.random((10, 10)))
>>> bar = AnchoredSizeBar(ax.transData, 3, '3 data units', 4)
>>> ax.add_artist(bar)
>>> fig.show()
Using all the optional parameters
>>> import matplotlib.font_manager as fm
>>> fontprops = fm.FontProperties(size=14, family='monospace')
>>> bar = AnchoredSizeBar(ax.transData, 3, '3 units', 4, pad=0.5,
... sep=5, borderpad=0.5, frameon=False,
... size_vertical=0.5, color='white',
... fontproperties=fontprops)
"""
if fill_bar is None:
fill_bar = size_vertical > 0
self.size_bar = AuxTransformBox(transform)
self.size_bar.add_artist(Rectangle((0, 0), size, size_vertical,
fill=fill_bar, facecolor=color,
edgecolor=color))
if fontproperties is None and 'prop' in kwargs:
fontproperties = kwargs.pop('prop')
if fontproperties is None:
textprops = {'color': color}
else:
textprops = {'color': color, 'fontproperties': fontproperties}
self.txt_label = TextArea(label, textprops=textprops)
if label_top:
_box_children = [self.txt_label, self.size_bar]
else:
_box_children = [self.size_bar, self.txt_label]
self._box = VPacker(children=_box_children,
align="center",
pad=0, sep=sep)
super().__init__(loc, pad=pad, borderpad=borderpad, child=self._box,
prop=fontproperties, frameon=frameon, **kwargs)
class AnchoredDirectionArrows(AnchoredOffsetbox):
def __init__(self, transform, label_x, label_y, length=0.15,
fontsize=0.08, loc='upper left', angle=0, aspect_ratio=1,
pad=0.4, borderpad=0.4, frameon=False, color='w', alpha=1,
sep_x=0.01, sep_y=0, fontproperties=None, back_length=0.15,
head_width=10, head_length=15, tail_width=2,
text_props=None, arrow_props=None,
**kwargs):
"""
Draw two perpendicular arrows to indicate directions.
Parameters
----------
transform : `~matplotlib.transforms.Transform`
The transformation object for the coordinate system in use, i.e.,
:attr:`matplotlib.axes.Axes.transAxes`.
label_x, label_y : str
Label text for the x and y arrows
length : float, default: 0.15
Length of the arrow, given in coordinates of *transform*.
fontsize : float, default: 0.08
Size of label strings, given in coordinates of *transform*.
loc : str, default: 'upper left'
Location of the arrow. Valid locations are
'upper left', 'upper center', 'upper right',
'center left', 'center', 'center right',
'lower left', 'lower center', 'lower right'.
For backward compatibility, numeric values are accepted as well.
See the parameter *loc* of `.Legend` for details.
angle : float, default: 0
The angle of the arrows in degrees.
aspect_ratio : float, default: 1
The ratio of the length of arrow_x and arrow_y.
Negative numbers can be used to change the direction.
pad : float, default: 0.4
Padding around the labels and arrows, in fraction of the font size.
borderpad : float, default: 0.4
Border padding, in fraction of the font size.
frameon : bool, default: False
If True, draw a box around the arrows and labels.
color : str, default: 'white'
Color for the arrows and labels.
alpha : float, default: 1
Alpha values of the arrows and labels
sep_x, sep_y : float, default: 0.01 and 0 respectively
Separation between the arrows and labels in coordinates of
*transform*.
fontproperties : `~matplotlib.font_manager.FontProperties`, optional
Font properties for the label text.
back_length : float, default: 0.15
Fraction of the arrow behind the arrow crossing.
head_width : float, default: 10
Width of arrow head, sent to `.ArrowStyle`.
head_length : float, default: 15
Length of arrow head, sent to `.ArrowStyle`.
tail_width : float, default: 2
Width of arrow tail, sent to `.ArrowStyle`.
text_props, arrow_props : dict
Properties of the text and arrows, passed to `.TextPath` and
`.FancyArrowPatch`.
**kwargs
Keyword arguments forwarded to `.AnchoredOffsetbox`.
Attributes
----------
arrow_x, arrow_y : `~matplotlib.patches.FancyArrowPatch`
Arrow x and y
text_path_x, text_path_y : `~matplotlib.text.TextPath`
Path for arrow labels
p_x, p_y : `~matplotlib.patches.PathPatch`
Patch for arrow labels
box : `~matplotlib.offsetbox.AuxTransformBox`
Container for the arrows and labels.
Notes
-----
If *prop* is passed as a keyword argument, but *fontproperties* is
not, then *prop* is assumed to be the intended *fontproperties*.
Using both *prop* and *fontproperties* is not supported.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> from mpl_toolkits.axes_grid1.anchored_artists import (
... AnchoredDirectionArrows)
>>> fig, ax = plt.subplots()
>>> ax.imshow(np.random.random((10, 10)))
>>> arrows = AnchoredDirectionArrows(ax.transAxes, '111', '110')
>>> ax.add_artist(arrows)
>>> fig.show()
Using several of the optional parameters, creating downward pointing
arrow and high contrast text labels.
>>> import matplotlib.font_manager as fm
>>> fontprops = fm.FontProperties(family='monospace')
>>> arrows = AnchoredDirectionArrows(ax.transAxes, 'East', 'South',
... loc='lower left', color='k',
... aspect_ratio=-1, sep_x=0.02,
... sep_y=-0.01,
... text_props={'ec':'w', 'fc':'k'},
... fontproperties=fontprops)
"""
if arrow_props is None:
arrow_props = {}
if text_props is None:
text_props = {}
arrowstyle = ArrowStyle("Simple",
head_width=head_width,
head_length=head_length,
tail_width=tail_width)
if fontproperties is None and 'prop' in kwargs:
fontproperties = kwargs.pop('prop')
if 'color' not in arrow_props:
arrow_props['color'] = color
if 'alpha' not in arrow_props:
arrow_props['alpha'] = alpha
if 'color' not in text_props:
text_props['color'] = color
if 'alpha' not in text_props:
text_props['alpha'] = alpha
t_start = transform
t_end = t_start + transforms.Affine2D().rotate_deg(angle)
self.box = AuxTransformBox(t_end)
length_x = length
length_y = length*aspect_ratio
self.arrow_x = FancyArrowPatch(
(0, back_length*length_y),
(length_x, back_length*length_y),
arrowstyle=arrowstyle,
shrinkA=0.0,
shrinkB=0.0,
**arrow_props)
self.arrow_y = FancyArrowPatch(
(back_length*length_x, 0),
(back_length*length_x, length_y),
arrowstyle=arrowstyle,
shrinkA=0.0,
shrinkB=0.0,
**arrow_props)
self.box.add_artist(self.arrow_x)
self.box.add_artist(self.arrow_y)
text_path_x = TextPath((
length_x+sep_x, back_length*length_y+sep_y), label_x,
size=fontsize, prop=fontproperties)
self.p_x = PathPatch(text_path_x, transform=t_start, **text_props)
self.box.add_artist(self.p_x)
text_path_y = TextPath((
length_x*back_length+sep_x, length_y*(1-back_length)+sep_y),
label_y, size=fontsize, prop=fontproperties)
self.p_y = PathPatch(text_path_y, **text_props)
self.box.add_artist(self.p_y)
super().__init__(loc, pad=pad, borderpad=borderpad, child=self.box,
frameon=frameon, **kwargs)

View File

@ -0,0 +1,694 @@
"""
Helper classes to adjust the positions of multiple axes at drawing time.
"""
import functools
import numpy as np
import matplotlib as mpl
from matplotlib import _api
from matplotlib.gridspec import SubplotSpec
import matplotlib.transforms as mtransforms
from . import axes_size as Size
class Divider:
"""
An Axes positioning class.
The divider is initialized with lists of horizontal and vertical sizes
(:mod:`mpl_toolkits.axes_grid1.axes_size`) based on which a given
rectangular area will be divided.
The `new_locator` method then creates a callable object
that can be used as the *axes_locator* of the axes.
"""
def __init__(self, fig, pos, horizontal, vertical,
aspect=None, anchor="C"):
"""
Parameters
----------
fig : Figure
pos : tuple of 4 floats
Position of the rectangle that will be divided.
horizontal : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`
Sizes for horizontal division.
vertical : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`
Sizes for vertical division.
aspect : bool, optional
Whether overall rectangular area is reduced so that the relative
part of the horizontal and vertical scales have the same scale.
anchor : (float, float) or {'C', 'SW', 'S', 'SE', 'E', 'NE', 'N', \
'NW', 'W'}, default: 'C'
Placement of the reduced rectangle, when *aspect* is True.
"""
self._fig = fig
self._pos = pos
self._horizontal = horizontal
self._vertical = vertical
self._anchor = anchor
self.set_anchor(anchor)
self._aspect = aspect
self._xrefindex = 0
self._yrefindex = 0
self._locator = None
def get_horizontal_sizes(self, renderer):
return np.array([s.get_size(renderer) for s in self.get_horizontal()])
def get_vertical_sizes(self, renderer):
return np.array([s.get_size(renderer) for s in self.get_vertical()])
def set_position(self, pos):
"""
Set the position of the rectangle.
Parameters
----------
pos : tuple of 4 floats
position of the rectangle that will be divided
"""
self._pos = pos
def get_position(self):
"""Return the position of the rectangle."""
return self._pos
def set_anchor(self, anchor):
"""
Parameters
----------
anchor : (float, float) or {'C', 'SW', 'S', 'SE', 'E', 'NE', 'N', \
'NW', 'W'}
Either an (*x*, *y*) pair of relative coordinates (0 is left or
bottom, 1 is right or top), 'C' (center), or a cardinal direction
('SW', southwest, is bottom left, etc.).
See Also
--------
.Axes.set_anchor
"""
if isinstance(anchor, str):
_api.check_in_list(mtransforms.Bbox.coefs, anchor=anchor)
elif not isinstance(anchor, (tuple, list)) or len(anchor) != 2:
raise TypeError("anchor must be str or 2-tuple")
self._anchor = anchor
def get_anchor(self):
"""Return the anchor."""
return self._anchor
def get_subplotspec(self):
return None
def set_horizontal(self, h):
"""
Parameters
----------
h : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`
sizes for horizontal division
"""
self._horizontal = h
def get_horizontal(self):
"""Return horizontal sizes."""
return self._horizontal
def set_vertical(self, v):
"""
Parameters
----------
v : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`
sizes for vertical division
"""
self._vertical = v
def get_vertical(self):
"""Return vertical sizes."""
return self._vertical
def set_aspect(self, aspect=False):
"""
Parameters
----------
aspect : bool
"""
self._aspect = aspect
def get_aspect(self):
"""Return aspect."""
return self._aspect
def set_locator(self, _locator):
self._locator = _locator
def get_locator(self):
return self._locator
def get_position_runtime(self, ax, renderer):
if self._locator is None:
return self.get_position()
else:
return self._locator(ax, renderer).bounds
@staticmethod
def _calc_k(sizes, total):
# sizes is a (n, 2) array of (rel_size, abs_size); this method finds
# the k factor such that sum(rel_size * k + abs_size) == total.
rel_sum, abs_sum = sizes.sum(0)
return (total - abs_sum) / rel_sum if rel_sum else 0
@staticmethod
def _calc_offsets(sizes, k):
# Apply k factors to (n, 2) sizes array of (rel_size, abs_size); return
# the resulting cumulative offset positions.
return np.cumsum([0, *(sizes @ [k, 1])])
def new_locator(self, nx, ny, nx1=None, ny1=None):
"""
Return an axes locator callable for the specified cell.
Parameters
----------
nx, nx1 : int
Integers specifying the column-position of the
cell. When *nx1* is None, a single *nx*-th column is
specified. Otherwise, location of columns spanning between *nx*
to *nx1* (but excluding *nx1*-th column) is specified.
ny, ny1 : int
Same as *nx* and *nx1*, but for row positions.
"""
if nx1 is None:
nx1 = nx + 1
if ny1 is None:
ny1 = ny + 1
# append_size("left") adds a new size at the beginning of the
# horizontal size lists; this shift transforms e.g.
# new_locator(nx=2, ...) into effectively new_locator(nx=3, ...). To
# take that into account, instead of recording nx, we record
# nx-self._xrefindex, where _xrefindex is shifted by 1 by each
# append_size("left"), and re-add self._xrefindex back to nx in
# _locate, when the actual axes position is computed. Ditto for y.
xref = self._xrefindex
yref = self._yrefindex
locator = functools.partial(
self._locate, nx - xref, ny - yref, nx1 - xref, ny1 - yref)
locator.get_subplotspec = self.get_subplotspec
return locator
@_api.deprecated(
"3.8", alternative="divider.new_locator(...)(ax, renderer)")
def locate(self, nx, ny, nx1=None, ny1=None, axes=None, renderer=None):
"""
Implementation of ``divider.new_locator().__call__``.
Parameters
----------
nx, nx1 : int
Integers specifying the column-position of the cell. When *nx1* is
None, a single *nx*-th column is specified. Otherwise, the
location of columns spanning between *nx* to *nx1* (but excluding
*nx1*-th column) is specified.
ny, ny1 : int
Same as *nx* and *nx1*, but for row positions.
axes
renderer
"""
xref = self._xrefindex
yref = self._yrefindex
return self._locate(
nx - xref, (nx + 1 if nx1 is None else nx1) - xref,
ny - yref, (ny + 1 if ny1 is None else ny1) - yref,
axes, renderer)
def _locate(self, nx, ny, nx1, ny1, axes, renderer):
"""
Implementation of ``divider.new_locator().__call__``.
The axes locator callable returned by ``new_locator()`` is created as
a `functools.partial` of this method with *nx*, *ny*, *nx1*, and *ny1*
specifying the requested cell.
"""
nx += self._xrefindex
nx1 += self._xrefindex
ny += self._yrefindex
ny1 += self._yrefindex
fig_w, fig_h = self._fig.bbox.size / self._fig.dpi
x, y, w, h = self.get_position_runtime(axes, renderer)
hsizes = self.get_horizontal_sizes(renderer)
vsizes = self.get_vertical_sizes(renderer)
k_h = self._calc_k(hsizes, fig_w * w)
k_v = self._calc_k(vsizes, fig_h * h)
if self.get_aspect():
k = min(k_h, k_v)
ox = self._calc_offsets(hsizes, k)
oy = self._calc_offsets(vsizes, k)
ww = (ox[-1] - ox[0]) / fig_w
hh = (oy[-1] - oy[0]) / fig_h
pb = mtransforms.Bbox.from_bounds(x, y, w, h)
pb1 = mtransforms.Bbox.from_bounds(x, y, ww, hh)
x0, y0 = pb1.anchored(self.get_anchor(), pb).p0
else:
ox = self._calc_offsets(hsizes, k_h)
oy = self._calc_offsets(vsizes, k_v)
x0, y0 = x, y
if nx1 is None:
nx1 = -1
if ny1 is None:
ny1 = -1
x1, w1 = x0 + ox[nx] / fig_w, (ox[nx1] - ox[nx]) / fig_w
y1, h1 = y0 + oy[ny] / fig_h, (oy[ny1] - oy[ny]) / fig_h
return mtransforms.Bbox.from_bounds(x1, y1, w1, h1)
def append_size(self, position, size):
_api.check_in_list(["left", "right", "bottom", "top"],
position=position)
if position == "left":
self._horizontal.insert(0, size)
self._xrefindex += 1
elif position == "right":
self._horizontal.append(size)
elif position == "bottom":
self._vertical.insert(0, size)
self._yrefindex += 1
else: # 'top'
self._vertical.append(size)
def add_auto_adjustable_area(self, use_axes, pad=0.1, adjust_dirs=None):
"""
Add auto-adjustable padding around *use_axes* to take their decorations
(title, labels, ticks, ticklabels) into account during layout.
Parameters
----------
use_axes : `~matplotlib.axes.Axes` or list of `~matplotlib.axes.Axes`
The Axes whose decorations are taken into account.
pad : float, default: 0.1
Additional padding in inches.
adjust_dirs : list of {"left", "right", "bottom", "top"}, optional
The sides where padding is added; defaults to all four sides.
"""
if adjust_dirs is None:
adjust_dirs = ["left", "right", "bottom", "top"]
for d in adjust_dirs:
self.append_size(d, Size._AxesDecorationsSize(use_axes, d) + pad)
@_api.deprecated("3.8")
class AxesLocator:
"""
A callable object which returns the position and size of a given
`.AxesDivider` cell.
"""
def __init__(self, axes_divider, nx, ny, nx1=None, ny1=None):
"""
Parameters
----------
axes_divider : `~mpl_toolkits.axes_grid1.axes_divider.AxesDivider`
nx, nx1 : int
Integers specifying the column-position of the
cell. When *nx1* is None, a single *nx*-th column is
specified. Otherwise, location of columns spanning between *nx*
to *nx1* (but excluding *nx1*-th column) is specified.
ny, ny1 : int
Same as *nx* and *nx1*, but for row positions.
"""
self._axes_divider = axes_divider
_xrefindex = axes_divider._xrefindex
_yrefindex = axes_divider._yrefindex
self._nx, self._ny = nx - _xrefindex, ny - _yrefindex
if nx1 is None:
nx1 = len(self._axes_divider)
if ny1 is None:
ny1 = len(self._axes_divider[0])
self._nx1 = nx1 - _xrefindex
self._ny1 = ny1 - _yrefindex
def __call__(self, axes, renderer):
_xrefindex = self._axes_divider._xrefindex
_yrefindex = self._axes_divider._yrefindex
return self._axes_divider.locate(self._nx + _xrefindex,
self._ny + _yrefindex,
self._nx1 + _xrefindex,
self._ny1 + _yrefindex,
axes,
renderer)
def get_subplotspec(self):
return self._axes_divider.get_subplotspec()
class SubplotDivider(Divider):
"""
The Divider class whose rectangle area is specified as a subplot geometry.
"""
def __init__(self, fig, *args, horizontal=None, vertical=None,
aspect=None, anchor='C'):
"""
Parameters
----------
fig : `~matplotlib.figure.Figure`
*args : tuple (*nrows*, *ncols*, *index*) or int
The array of subplots in the figure has dimensions ``(nrows,
ncols)``, and *index* is the index of the subplot being created.
*index* starts at 1 in the upper left corner and increases to the
right.
If *nrows*, *ncols*, and *index* are all single digit numbers, then
*args* can be passed as a single 3-digit number (e.g. 234 for
(2, 3, 4)).
horizontal : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`, optional
Sizes for horizontal division.
vertical : list of :mod:`~mpl_toolkits.axes_grid1.axes_size`, optional
Sizes for vertical division.
aspect : bool, optional
Whether overall rectangular area is reduced so that the relative
part of the horizontal and vertical scales have the same scale.
anchor : (float, float) or {'C', 'SW', 'S', 'SE', 'E', 'NE', 'N', \
'NW', 'W'}, default: 'C'
Placement of the reduced rectangle, when *aspect* is True.
"""
self.figure = fig
super().__init__(fig, [0, 0, 1, 1],
horizontal=horizontal or [], vertical=vertical or [],
aspect=aspect, anchor=anchor)
self.set_subplotspec(SubplotSpec._from_subplot_args(fig, args))
def get_position(self):
"""Return the bounds of the subplot box."""
return self.get_subplotspec().get_position(self.figure).bounds
def get_subplotspec(self):
"""Get the SubplotSpec instance."""
return self._subplotspec
def set_subplotspec(self, subplotspec):
"""Set the SubplotSpec instance."""
self._subplotspec = subplotspec
self.set_position(subplotspec.get_position(self.figure))
class AxesDivider(Divider):
"""
Divider based on the preexisting axes.
"""
def __init__(self, axes, xref=None, yref=None):
"""
Parameters
----------
axes : :class:`~matplotlib.axes.Axes`
xref
yref
"""
self._axes = axes
if xref is None:
self._xref = Size.AxesX(axes)
else:
self._xref = xref
if yref is None:
self._yref = Size.AxesY(axes)
else:
self._yref = yref
super().__init__(fig=axes.get_figure(), pos=None,
horizontal=[self._xref], vertical=[self._yref],
aspect=None, anchor="C")
def _get_new_axes(self, *, axes_class=None, **kwargs):
axes = self._axes
if axes_class is None:
axes_class = type(axes)
return axes_class(axes.get_figure(), axes.get_position(original=True),
**kwargs)
def new_horizontal(self, size, pad=None, pack_start=False, **kwargs):
"""
Helper method for ``append_axes("left")`` and ``append_axes("right")``.
See the documentation of `append_axes` for more details.
:meta private:
"""
if pad is None:
pad = mpl.rcParams["figure.subplot.wspace"] * self._xref
pos = "left" if pack_start else "right"
if pad:
if not isinstance(pad, Size._Base):
pad = Size.from_any(pad, fraction_ref=self._xref)
self.append_size(pos, pad)
if not isinstance(size, Size._Base):
size = Size.from_any(size, fraction_ref=self._xref)
self.append_size(pos, size)
locator = self.new_locator(
nx=0 if pack_start else len(self._horizontal) - 1,
ny=self._yrefindex)
ax = self._get_new_axes(**kwargs)
ax.set_axes_locator(locator)
return ax
def new_vertical(self, size, pad=None, pack_start=False, **kwargs):
"""
Helper method for ``append_axes("top")`` and ``append_axes("bottom")``.
See the documentation of `append_axes` for more details.
:meta private:
"""
if pad is None:
pad = mpl.rcParams["figure.subplot.hspace"] * self._yref
pos = "bottom" if pack_start else "top"
if pad:
if not isinstance(pad, Size._Base):
pad = Size.from_any(pad, fraction_ref=self._yref)
self.append_size(pos, pad)
if not isinstance(size, Size._Base):
size = Size.from_any(size, fraction_ref=self._yref)
self.append_size(pos, size)
locator = self.new_locator(
nx=self._xrefindex,
ny=0 if pack_start else len(self._vertical) - 1)
ax = self._get_new_axes(**kwargs)
ax.set_axes_locator(locator)
return ax
def append_axes(self, position, size, pad=None, *, axes_class=None,
**kwargs):
"""
Add a new axes on a given side of the main axes.
Parameters
----------
position : {"left", "right", "bottom", "top"}
Where the new axes is positioned relative to the main axes.
size : :mod:`~mpl_toolkits.axes_grid1.axes_size` or float or str
The axes width or height. float or str arguments are interpreted
as ``axes_size.from_any(size, AxesX(<main_axes>))`` for left or
right axes, and likewise with ``AxesY`` for bottom or top axes.
pad : :mod:`~mpl_toolkits.axes_grid1.axes_size` or float or str
Padding between the axes. float or str arguments are interpreted
as for *size*. Defaults to :rc:`figure.subplot.wspace` times the
main Axes width (left or right axes) or :rc:`figure.subplot.hspace`
times the main Axes height (bottom or top axes).
axes_class : subclass type of `~.axes.Axes`, optional
The type of the new axes. Defaults to the type of the main axes.
**kwargs
All extra keywords arguments are passed to the created axes.
"""
create_axes, pack_start = _api.check_getitem({
"left": (self.new_horizontal, True),
"right": (self.new_horizontal, False),
"bottom": (self.new_vertical, True),
"top": (self.new_vertical, False),
}, position=position)
ax = create_axes(
size, pad, pack_start=pack_start, axes_class=axes_class, **kwargs)
self._fig.add_axes(ax)
return ax
def get_aspect(self):
if self._aspect is None:
aspect = self._axes.get_aspect()
if aspect == "auto":
return False
else:
return True
else:
return self._aspect
def get_position(self):
if self._pos is None:
bbox = self._axes.get_position(original=True)
return bbox.bounds
else:
return self._pos
def get_anchor(self):
if self._anchor is None:
return self._axes.get_anchor()
else:
return self._anchor
def get_subplotspec(self):
return self._axes.get_subplotspec()
# Helper for HBoxDivider/VBoxDivider.
# The variable names are written for a horizontal layout, but the calculations
# work identically for vertical layouts.
def _locate(x, y, w, h, summed_widths, equal_heights, fig_w, fig_h, anchor):
total_width = fig_w * w
max_height = fig_h * h
# Determine the k factors.
n = len(equal_heights)
eq_rels, eq_abss = equal_heights.T
sm_rels, sm_abss = summed_widths.T
A = np.diag([*eq_rels, 0])
A[:n, -1] = -1
A[-1, :-1] = sm_rels
B = [*(-eq_abss), total_width - sm_abss.sum()]
# A @ K = B: This finds factors {k_0, ..., k_{N-1}, H} so that
# eq_rel_i * k_i + eq_abs_i = H for all i: all axes have the same height
# sum(sm_rel_i * k_i + sm_abs_i) = total_width: fixed total width
# (foo_rel_i * k_i + foo_abs_i will end up being the size of foo.)
*karray, height = np.linalg.solve(A, B)
if height > max_height: # Additionally, upper-bound the height.
karray = (max_height - eq_abss) / eq_rels
# Compute the offsets corresponding to these factors.
ox = np.cumsum([0, *(sm_rels * karray + sm_abss)])
ww = (ox[-1] - ox[0]) / fig_w
h0_rel, h0_abs = equal_heights[0]
hh = (karray[0]*h0_rel + h0_abs) / fig_h
pb = mtransforms.Bbox.from_bounds(x, y, w, h)
pb1 = mtransforms.Bbox.from_bounds(x, y, ww, hh)
x0, y0 = pb1.anchored(anchor, pb).p0
return x0, y0, ox, hh
class HBoxDivider(SubplotDivider):
"""
A `.SubplotDivider` for laying out axes horizontally, while ensuring that
they have equal heights.
Examples
--------
.. plot:: gallery/axes_grid1/demo_axes_hbox_divider.py
"""
def new_locator(self, nx, nx1=None):
"""
Create an axes locator callable for the specified cell.
Parameters
----------
nx, nx1 : int
Integers specifying the column-position of the
cell. When *nx1* is None, a single *nx*-th column is
specified. Otherwise, location of columns spanning between *nx*
to *nx1* (but excluding *nx1*-th column) is specified.
"""
return super().new_locator(nx, 0, nx1, 0)
def _locate(self, nx, ny, nx1, ny1, axes, renderer):
# docstring inherited
nx += self._xrefindex
nx1 += self._xrefindex
fig_w, fig_h = self._fig.bbox.size / self._fig.dpi
x, y, w, h = self.get_position_runtime(axes, renderer)
summed_ws = self.get_horizontal_sizes(renderer)
equal_hs = self.get_vertical_sizes(renderer)
x0, y0, ox, hh = _locate(
x, y, w, h, summed_ws, equal_hs, fig_w, fig_h, self.get_anchor())
if nx1 is None:
nx1 = -1
x1, w1 = x0 + ox[nx] / fig_w, (ox[nx1] - ox[nx]) / fig_w
y1, h1 = y0, hh
return mtransforms.Bbox.from_bounds(x1, y1, w1, h1)
class VBoxDivider(SubplotDivider):
"""
A `.SubplotDivider` for laying out axes vertically, while ensuring that
they have equal widths.
"""
def new_locator(self, ny, ny1=None):
"""
Create an axes locator callable for the specified cell.
Parameters
----------
ny, ny1 : int
Integers specifying the row-position of the
cell. When *ny1* is None, a single *ny*-th row is
specified. Otherwise, location of rows spanning between *ny*
to *ny1* (but excluding *ny1*-th row) is specified.
"""
return super().new_locator(0, ny, 0, ny1)
def _locate(self, nx, ny, nx1, ny1, axes, renderer):
# docstring inherited
ny += self._yrefindex
ny1 += self._yrefindex
fig_w, fig_h = self._fig.bbox.size / self._fig.dpi
x, y, w, h = self.get_position_runtime(axes, renderer)
summed_hs = self.get_vertical_sizes(renderer)
equal_ws = self.get_horizontal_sizes(renderer)
y0, x0, oy, ww = _locate(
y, x, h, w, summed_hs, equal_ws, fig_h, fig_w, self.get_anchor())
if ny1 is None:
ny1 = -1
x1, w1 = x0, ww
y1, h1 = y0 + oy[ny] / fig_h, (oy[ny1] - oy[ny]) / fig_h
return mtransforms.Bbox.from_bounds(x1, y1, w1, h1)
def make_axes_locatable(axes):
divider = AxesDivider(axes)
locator = divider.new_locator(nx=0, ny=0)
axes.set_axes_locator(locator)
return divider
def make_axes_area_auto_adjustable(
ax, use_axes=None, pad=0.1, adjust_dirs=None):
"""
Add auto-adjustable padding around *ax* to take its decorations (title,
labels, ticks, ticklabels) into account during layout, using
`.Divider.add_auto_adjustable_area`.
By default, padding is determined from the decorations of *ax*.
Pass *use_axes* to consider the decorations of other Axes instead.
"""
if adjust_dirs is None:
adjust_dirs = ["left", "right", "bottom", "top"]
divider = make_axes_locatable(ax)
if use_axes is None:
use_axes = ax
divider.add_auto_adjustable_area(use_axes=use_axes, pad=pad,
adjust_dirs=adjust_dirs)

View File

@ -0,0 +1,563 @@
from numbers import Number
import functools
from types import MethodType
import numpy as np
from matplotlib import _api, cbook
from matplotlib.gridspec import SubplotSpec
from .axes_divider import Size, SubplotDivider, Divider
from .mpl_axes import Axes, SimpleAxisArtist
class CbarAxesBase:
def __init__(self, *args, orientation, **kwargs):
self.orientation = orientation
super().__init__(*args, **kwargs)
def colorbar(self, mappable, **kwargs):
return self.figure.colorbar(
mappable, cax=self, location=self.orientation, **kwargs)
@_api.deprecated("3.8", alternative="ax.tick_params and colorbar.set_label")
def toggle_label(self, b):
axis = self.axis[self.orientation]
axis.toggle(ticklabels=b, label=b)
_cbaraxes_class_factory = cbook._make_class_factory(CbarAxesBase, "Cbar{}")
class Grid:
"""
A grid of Axes.
In Matplotlib, the Axes location (and size) is specified in normalized
figure coordinates. This may not be ideal for images that needs to be
displayed with a given aspect ratio; for example, it is difficult to
display multiple images of a same size with some fixed padding between
them. AxesGrid can be used in such case.
Attributes
----------
axes_all : list of Axes
A flat list of Axes. Note that you can also access this directly
from the grid. The following is equivalent ::
grid[i] == grid.axes_all[i]
len(grid) == len(grid.axes_all)
axes_column : list of list of Axes
A 2D list of Axes where the first index is the column. This results
in the usage pattern ``grid.axes_column[col][row]``.
axes_row : list of list of Axes
A 2D list of Axes where the first index is the row. This results
in the usage pattern ``grid.axes_row[row][col]``.
axes_llc : Axes
The Axes in the lower left corner.
ngrids : int
Number of Axes in the grid.
"""
_defaultAxesClass = Axes
def __init__(self, fig,
rect,
nrows_ncols,
ngrids=None,
direction="row",
axes_pad=0.02,
*,
share_all=False,
share_x=True,
share_y=True,
label_mode="L",
axes_class=None,
aspect=False,
):
"""
Parameters
----------
fig : `.Figure`
The parent figure.
rect : (float, float, float, float), (int, int, int), int, or \
`~.SubplotSpec`
The axes position, as a ``(left, bottom, width, height)`` tuple,
as a three-digit subplot position code (e.g., ``(1, 2, 1)`` or
``121``), or as a `~.SubplotSpec`.
nrows_ncols : (int, int)
Number of rows and columns in the grid.
ngrids : int or None, default: None
If not None, only the first *ngrids* axes in the grid are created.
direction : {"row", "column"}, default: "row"
Whether axes are created in row-major ("row by row") or
column-major order ("column by column"). This also affects the
order in which axes are accessed using indexing (``grid[index]``).
axes_pad : float or (float, float), default: 0.02
Padding or (horizontal padding, vertical padding) between axes, in
inches.
share_all : bool, default: False
Whether all axes share their x- and y-axis. Overrides *share_x*
and *share_y*.
share_x : bool, default: True
Whether all axes of a column share their x-axis.
share_y : bool, default: True
Whether all axes of a row share their y-axis.
label_mode : {"L", "1", "all", "keep"}, default: "L"
Determines which axes will get tick labels:
- "L": All axes on the left column get vertical tick labels;
all axes on the bottom row get horizontal tick labels.
- "1": Only the bottom left axes is labelled.
- "all": All axes are labelled.
- "keep": Do not do anything.
axes_class : subclass of `matplotlib.axes.Axes`, default: `.mpl_axes.Axes`
The type of Axes to create.
aspect : bool, default: False
Whether the axes aspect ratio follows the aspect ratio of the data
limits.
"""
self._nrows, self._ncols = nrows_ncols
if ngrids is None:
ngrids = self._nrows * self._ncols
else:
if not 0 < ngrids <= self._nrows * self._ncols:
raise ValueError(
"ngrids must be positive and not larger than nrows*ncols")
self.ngrids = ngrids
self._horiz_pad_size, self._vert_pad_size = map(
Size.Fixed, np.broadcast_to(axes_pad, 2))
_api.check_in_list(["column", "row"], direction=direction)
self._direction = direction
if axes_class is None:
axes_class = self._defaultAxesClass
elif isinstance(axes_class, (list, tuple)):
cls, kwargs = axes_class
axes_class = functools.partial(cls, **kwargs)
kw = dict(horizontal=[], vertical=[], aspect=aspect)
if isinstance(rect, (Number, SubplotSpec)):
self._divider = SubplotDivider(fig, rect, **kw)
elif len(rect) == 3:
self._divider = SubplotDivider(fig, *rect, **kw)
elif len(rect) == 4:
self._divider = Divider(fig, rect, **kw)
else:
raise TypeError("Incorrect rect format")
rect = self._divider.get_position()
axes_array = np.full((self._nrows, self._ncols), None, dtype=object)
for i in range(self.ngrids):
col, row = self._get_col_row(i)
if share_all:
sharex = sharey = axes_array[0, 0]
else:
sharex = axes_array[0, col] if share_x else None
sharey = axes_array[row, 0] if share_y else None
axes_array[row, col] = axes_class(
fig, rect, sharex=sharex, sharey=sharey)
self.axes_all = axes_array.ravel(
order="C" if self._direction == "row" else "F").tolist()
self.axes_column = axes_array.T.tolist()
self.axes_row = axes_array.tolist()
self.axes_llc = self.axes_column[0][-1]
self._init_locators()
for ax in self.axes_all:
fig.add_axes(ax)
self.set_label_mode(label_mode)
def _init_locators(self):
self._divider.set_horizontal(
[Size.Scaled(1), self._horiz_pad_size] * (self._ncols-1) + [Size.Scaled(1)])
self._divider.set_vertical(
[Size.Scaled(1), self._vert_pad_size] * (self._nrows-1) + [Size.Scaled(1)])
for i in range(self.ngrids):
col, row = self._get_col_row(i)
self.axes_all[i].set_axes_locator(
self._divider.new_locator(nx=2 * col, ny=2 * (self._nrows - 1 - row)))
def _get_col_row(self, n):
if self._direction == "column":
col, row = divmod(n, self._nrows)
else:
row, col = divmod(n, self._ncols)
return col, row
# Good to propagate __len__ if we have __getitem__
def __len__(self):
return len(self.axes_all)
def __getitem__(self, i):
return self.axes_all[i]
def get_geometry(self):
"""
Return the number of rows and columns of the grid as (nrows, ncols).
"""
return self._nrows, self._ncols
def set_axes_pad(self, axes_pad):
"""
Set the padding between the axes.
Parameters
----------
axes_pad : (float, float)
The padding (horizontal pad, vertical pad) in inches.
"""
self._horiz_pad_size.fixed_size = axes_pad[0]
self._vert_pad_size.fixed_size = axes_pad[1]
def get_axes_pad(self):
"""
Return the axes padding.
Returns
-------
hpad, vpad
Padding (horizontal pad, vertical pad) in inches.
"""
return (self._horiz_pad_size.fixed_size,
self._vert_pad_size.fixed_size)
def set_aspect(self, aspect):
"""Set the aspect of the SubplotDivider."""
self._divider.set_aspect(aspect)
def get_aspect(self):
"""Return the aspect of the SubplotDivider."""
return self._divider.get_aspect()
def set_label_mode(self, mode):
"""
Define which axes have tick labels.
Parameters
----------
mode : {"L", "1", "all", "keep"}
The label mode:
- "L": All axes on the left column get vertical tick labels;
all axes on the bottom row get horizontal tick labels.
- "1": Only the bottom left axes is labelled.
- "all": All axes are labelled.
- "keep": Do not do anything.
"""
_api.check_in_list(["all", "L", "1", "keep"], mode=mode)
is_last_row, is_first_col = (
np.mgrid[:self._nrows, :self._ncols] == [[[self._nrows - 1]], [[0]]])
if mode == "all":
bottom = left = np.full((self._nrows, self._ncols), True)
elif mode == "L":
bottom = is_last_row
left = is_first_col
elif mode == "1":
bottom = left = is_last_row & is_first_col
else:
return
for i in range(self._nrows):
for j in range(self._ncols):
ax = self.axes_row[i][j]
if isinstance(ax.axis, MethodType):
bottom_axis = SimpleAxisArtist(ax.xaxis, 1, ax.spines["bottom"])
left_axis = SimpleAxisArtist(ax.yaxis, 1, ax.spines["left"])
else:
bottom_axis = ax.axis["bottom"]
left_axis = ax.axis["left"]
bottom_axis.toggle(ticklabels=bottom[i, j], label=bottom[i, j])
left_axis.toggle(ticklabels=left[i, j], label=left[i, j])
def get_divider(self):
return self._divider
def set_axes_locator(self, locator):
self._divider.set_locator(locator)
def get_axes_locator(self):
return self._divider.get_locator()
class ImageGrid(Grid):
"""
A grid of Axes for Image display.
This class is a specialization of `~.axes_grid1.axes_grid.Grid` for displaying a
grid of images. In particular, it forces all axes in a column to share their x-axis
and all axes in a row to share their y-axis. It further provides helpers to add
colorbars to some or all axes.
"""
def __init__(self, fig,
rect,
nrows_ncols,
ngrids=None,
direction="row",
axes_pad=0.02,
*,
share_all=False,
aspect=True,
label_mode="L",
cbar_mode=None,
cbar_location="right",
cbar_pad=None,
cbar_size="5%",
cbar_set_cax=True,
axes_class=None,
):
"""
Parameters
----------
fig : `.Figure`
The parent figure.
rect : (float, float, float, float) or int
The axes position, as a ``(left, bottom, width, height)`` tuple or
as a three-digit subplot position code (e.g., "121").
nrows_ncols : (int, int)
Number of rows and columns in the grid.
ngrids : int or None, default: None
If not None, only the first *ngrids* axes in the grid are created.
direction : {"row", "column"}, default: "row"
Whether axes are created in row-major ("row by row") or
column-major order ("column by column"). This also affects the
order in which axes are accessed using indexing (``grid[index]``).
axes_pad : float or (float, float), default: 0.02in
Padding or (horizontal padding, vertical padding) between axes, in
inches.
share_all : bool, default: False
Whether all axes share their x- and y-axis. Note that in any case,
all axes in a column share their x-axis and all axes in a row share
their y-axis.
aspect : bool, default: True
Whether the axes aspect ratio follows the aspect ratio of the data
limits.
label_mode : {"L", "1", "all"}, default: "L"
Determines which axes will get tick labels:
- "L": All axes on the left column get vertical tick labels;
all axes on the bottom row get horizontal tick labels.
- "1": Only the bottom left axes is labelled.
- "all": all axes are labelled.
cbar_mode : {"each", "single", "edge", None}, default: None
Whether to create a colorbar for "each" axes, a "single" colorbar
for the entire grid, colorbars only for axes on the "edge"
determined by *cbar_location*, or no colorbars. The colorbars are
stored in the :attr:`cbar_axes` attribute.
cbar_location : {"left", "right", "bottom", "top"}, default: "right"
cbar_pad : float, default: None
Padding between the image axes and the colorbar axes.
cbar_size : size specification (see `.Size.from_any`), default: "5%"
Colorbar size.
cbar_set_cax : bool, default: True
If True, each axes in the grid has a *cax* attribute that is bound
to associated *cbar_axes*.
axes_class : subclass of `matplotlib.axes.Axes`, default: None
"""
_api.check_in_list(["each", "single", "edge", None],
cbar_mode=cbar_mode)
_api.check_in_list(["left", "right", "bottom", "top"],
cbar_location=cbar_location)
self._colorbar_mode = cbar_mode
self._colorbar_location = cbar_location
self._colorbar_pad = cbar_pad
self._colorbar_size = cbar_size
# The colorbar axes are created in _init_locators().
super().__init__(
fig, rect, nrows_ncols, ngrids,
direction=direction, axes_pad=axes_pad,
share_all=share_all, share_x=True, share_y=True, aspect=aspect,
label_mode=label_mode, axes_class=axes_class)
for ax in self.cbar_axes:
fig.add_axes(ax)
if cbar_set_cax:
if self._colorbar_mode == "single":
for ax in self.axes_all:
ax.cax = self.cbar_axes[0]
elif self._colorbar_mode == "edge":
for index, ax in enumerate(self.axes_all):
col, row = self._get_col_row(index)
if self._colorbar_location in ("left", "right"):
ax.cax = self.cbar_axes[row]
else:
ax.cax = self.cbar_axes[col]
else:
for ax, cax in zip(self.axes_all, self.cbar_axes):
ax.cax = cax
def _init_locators(self):
# Slightly abusing this method to inject colorbar creation into init.
if self._colorbar_pad is None:
# horizontal or vertical arrangement?
if self._colorbar_location in ("left", "right"):
self._colorbar_pad = self._horiz_pad_size.fixed_size
else:
self._colorbar_pad = self._vert_pad_size.fixed_size
self.cbar_axes = [
_cbaraxes_class_factory(self._defaultAxesClass)(
self.axes_all[0].figure, self._divider.get_position(),
orientation=self._colorbar_location)
for _ in range(self.ngrids)]
cb_mode = self._colorbar_mode
cb_location = self._colorbar_location
h = []
v = []
h_ax_pos = []
h_cb_pos = []
if cb_mode == "single" and cb_location in ("left", "bottom"):
if cb_location == "left":
sz = self._nrows * Size.AxesX(self.axes_llc)
h.append(Size.from_any(self._colorbar_size, sz))
h.append(Size.from_any(self._colorbar_pad, sz))
locator = self._divider.new_locator(nx=0, ny=0, ny1=-1)
elif cb_location == "bottom":
sz = self._ncols * Size.AxesY(self.axes_llc)
v.append(Size.from_any(self._colorbar_size, sz))
v.append(Size.from_any(self._colorbar_pad, sz))
locator = self._divider.new_locator(nx=0, nx1=-1, ny=0)
for i in range(self.ngrids):
self.cbar_axes[i].set_visible(False)
self.cbar_axes[0].set_axes_locator(locator)
self.cbar_axes[0].set_visible(True)
for col, ax in enumerate(self.axes_row[0]):
if h:
h.append(self._horiz_pad_size)
if ax:
sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0])
else:
sz = Size.AxesX(self.axes_all[0],
aspect="axes", ref_ax=self.axes_all[0])
if (cb_location == "left"
and (cb_mode == "each"
or (cb_mode == "edge" and col == 0))):
h_cb_pos.append(len(h))
h.append(Size.from_any(self._colorbar_size, sz))
h.append(Size.from_any(self._colorbar_pad, sz))
h_ax_pos.append(len(h))
h.append(sz)
if (cb_location == "right"
and (cb_mode == "each"
or (cb_mode == "edge" and col == self._ncols - 1))):
h.append(Size.from_any(self._colorbar_pad, sz))
h_cb_pos.append(len(h))
h.append(Size.from_any(self._colorbar_size, sz))
v_ax_pos = []
v_cb_pos = []
for row, ax in enumerate(self.axes_column[0][::-1]):
if v:
v.append(self._vert_pad_size)
if ax:
sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0])
else:
sz = Size.AxesY(self.axes_all[0],
aspect="axes", ref_ax=self.axes_all[0])
if (cb_location == "bottom"
and (cb_mode == "each"
or (cb_mode == "edge" and row == 0))):
v_cb_pos.append(len(v))
v.append(Size.from_any(self._colorbar_size, sz))
v.append(Size.from_any(self._colorbar_pad, sz))
v_ax_pos.append(len(v))
v.append(sz)
if (cb_location == "top"
and (cb_mode == "each"
or (cb_mode == "edge" and row == self._nrows - 1))):
v.append(Size.from_any(self._colorbar_pad, sz))
v_cb_pos.append(len(v))
v.append(Size.from_any(self._colorbar_size, sz))
for i in range(self.ngrids):
col, row = self._get_col_row(i)
locator = self._divider.new_locator(nx=h_ax_pos[col],
ny=v_ax_pos[self._nrows-1-row])
self.axes_all[i].set_axes_locator(locator)
if cb_mode == "each":
if cb_location in ("right", "left"):
locator = self._divider.new_locator(
nx=h_cb_pos[col], ny=v_ax_pos[self._nrows - 1 - row])
elif cb_location in ("top", "bottom"):
locator = self._divider.new_locator(
nx=h_ax_pos[col], ny=v_cb_pos[self._nrows - 1 - row])
self.cbar_axes[i].set_axes_locator(locator)
elif cb_mode == "edge":
if (cb_location == "left" and col == 0
or cb_location == "right" and col == self._ncols - 1):
locator = self._divider.new_locator(
nx=h_cb_pos[0], ny=v_ax_pos[self._nrows - 1 - row])
self.cbar_axes[row].set_axes_locator(locator)
elif (cb_location == "bottom" and row == self._nrows - 1
or cb_location == "top" and row == 0):
locator = self._divider.new_locator(nx=h_ax_pos[col],
ny=v_cb_pos[0])
self.cbar_axes[col].set_axes_locator(locator)
if cb_mode == "single":
if cb_location == "right":
sz = self._nrows * Size.AxesX(self.axes_llc)
h.append(Size.from_any(self._colorbar_pad, sz))
h.append(Size.from_any(self._colorbar_size, sz))
locator = self._divider.new_locator(nx=-2, ny=0, ny1=-1)
elif cb_location == "top":
sz = self._ncols * Size.AxesY(self.axes_llc)
v.append(Size.from_any(self._colorbar_pad, sz))
v.append(Size.from_any(self._colorbar_size, sz))
locator = self._divider.new_locator(nx=0, nx1=-1, ny=-2)
if cb_location in ("right", "top"):
for i in range(self.ngrids):
self.cbar_axes[i].set_visible(False)
self.cbar_axes[0].set_axes_locator(locator)
self.cbar_axes[0].set_visible(True)
elif cb_mode == "each":
for i in range(self.ngrids):
self.cbar_axes[i].set_visible(True)
elif cb_mode == "edge":
if cb_location in ("right", "left"):
count = self._nrows
else:
count = self._ncols
for i in range(count):
self.cbar_axes[i].set_visible(True)
for j in range(i + 1, self.ngrids):
self.cbar_axes[j].set_visible(False)
else:
for i in range(self.ngrids):
self.cbar_axes[i].set_visible(False)
self.cbar_axes[i].set_position([1., 1., 0.001, 0.001],
which="active")
self._divider.set_horizontal(h)
self._divider.set_vertical(v)
AxesGrid = ImageGrid

View File

@ -0,0 +1,157 @@
from types import MethodType
import numpy as np
from .axes_divider import make_axes_locatable, Size
from .mpl_axes import Axes, SimpleAxisArtist
def make_rgb_axes(ax, pad=0.01, axes_class=None, **kwargs):
"""
Parameters
----------
ax : `~matplotlib.axes.Axes`
Axes instance to create the RGB Axes in.
pad : float, optional
Fraction of the Axes height to pad.
axes_class : `matplotlib.axes.Axes` or None, optional
Axes class to use for the R, G, and B Axes. If None, use
the same class as *ax*.
**kwargs
Forwarded to *axes_class* init for the R, G, and B Axes.
"""
divider = make_axes_locatable(ax)
pad_size = pad * Size.AxesY(ax)
xsize = ((1-2*pad)/3) * Size.AxesX(ax)
ysize = ((1-2*pad)/3) * Size.AxesY(ax)
divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])
ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))
ax_rgb = []
if axes_class is None:
axes_class = type(ax)
for ny in [4, 2, 0]:
ax1 = axes_class(ax.get_figure(), ax.get_position(original=True),
sharex=ax, sharey=ax, **kwargs)
locator = divider.new_locator(nx=2, ny=ny)
ax1.set_axes_locator(locator)
for t in ax1.yaxis.get_ticklabels() + ax1.xaxis.get_ticklabels():
t.set_visible(False)
try:
for axis in ax1.axis.values():
axis.major_ticklabels.set_visible(False)
except AttributeError:
pass
ax_rgb.append(ax1)
fig = ax.get_figure()
for ax1 in ax_rgb:
fig.add_axes(ax1)
return ax_rgb
class RGBAxes:
"""
4-panel `~.Axes.imshow` (RGB, R, G, B).
Layout::
┌───────────────┬─────┐
│ │ R │
│ ├─────┤
│ RGB │ G │
│ ├─────┤
│ │ B │
└───────────────┴─────┘
Subclasses can override the ``_defaultAxesClass`` attribute.
By default RGBAxes uses `.mpl_axes.Axes`.
Attributes
----------
RGB : ``_defaultAxesClass``
The Axes object for the three-channel `~.Axes.imshow`.
R : ``_defaultAxesClass``
The Axes object for the red channel `~.Axes.imshow`.
G : ``_defaultAxesClass``
The Axes object for the green channel `~.Axes.imshow`.
B : ``_defaultAxesClass``
The Axes object for the blue channel `~.Axes.imshow`.
"""
_defaultAxesClass = Axes
def __init__(self, *args, pad=0, **kwargs):
"""
Parameters
----------
pad : float, default: 0
Fraction of the Axes height to put as padding.
axes_class : `~matplotlib.axes.Axes`
Axes class to use. If not provided, ``_defaultAxesClass`` is used.
*args
Forwarded to *axes_class* init for the RGB Axes
**kwargs
Forwarded to *axes_class* init for the RGB, R, G, and B Axes
"""
axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
self.RGB = ax = axes_class(*args, **kwargs)
ax.get_figure().add_axes(ax)
self.R, self.G, self.B = make_rgb_axes(
ax, pad=pad, axes_class=axes_class, **kwargs)
# Set the line color and ticks for the axes.
for ax1 in [self.RGB, self.R, self.G, self.B]:
if isinstance(ax1.axis, MethodType):
ad = Axes.AxisDict(self)
ad.update(
bottom=SimpleAxisArtist(ax1.xaxis, 1, ax1.spines["bottom"]),
top=SimpleAxisArtist(ax1.xaxis, 2, ax1.spines["top"]),
left=SimpleAxisArtist(ax1.yaxis, 1, ax1.spines["left"]),
right=SimpleAxisArtist(ax1.yaxis, 2, ax1.spines["right"]))
else:
ad = ax1.axis
ad[:].line.set_color("w")
ad[:].major_ticks.set_markeredgecolor("w")
def imshow_rgb(self, r, g, b, **kwargs):
"""
Create the four images {rgb, r, g, b}.
Parameters
----------
r, g, b : array-like
The red, green, and blue arrays.
**kwargs
Forwarded to `~.Axes.imshow` calls for the four images.
Returns
-------
rgb : `~matplotlib.image.AxesImage`
r : `~matplotlib.image.AxesImage`
g : `~matplotlib.image.AxesImage`
b : `~matplotlib.image.AxesImage`
"""
if not (r.shape == g.shape == b.shape):
raise ValueError(
f'Input shapes ({r.shape}, {g.shape}, {b.shape}) do not match')
RGB = np.dstack([r, g, b])
R = np.zeros_like(RGB)
R[:, :, 0] = r
G = np.zeros_like(RGB)
G[:, :, 1] = g
B = np.zeros_like(RGB)
B[:, :, 2] = b
im_rgb = self.RGB.imshow(RGB, **kwargs)
im_r = self.R.imshow(R, **kwargs)
im_g = self.G.imshow(G, **kwargs)
im_b = self.B.imshow(B, **kwargs)
return im_rgb, im_r, im_g, im_b

View File

@ -0,0 +1,248 @@
"""
Provides classes of simple units that will be used with `.AxesDivider`
class (or others) to determine the size of each Axes. The unit
classes define `get_size` method that returns a tuple of two floats,
meaning relative and absolute sizes, respectively.
Note that this class is nothing more than a simple tuple of two
floats. Take a look at the Divider class to see how these two
values are used.
"""
from numbers import Real
from matplotlib import _api
from matplotlib.axes import Axes
class _Base:
def __rmul__(self, other):
return Fraction(other, self)
def __add__(self, other):
if isinstance(other, _Base):
return Add(self, other)
else:
return Add(self, Fixed(other))
def get_size(self, renderer):
"""
Return two-float tuple with relative and absolute sizes.
"""
raise NotImplementedError("Subclasses must implement")
class Add(_Base):
"""
Sum of two sizes.
"""
def __init__(self, a, b):
self._a = a
self._b = b
def get_size(self, renderer):
a_rel_size, a_abs_size = self._a.get_size(renderer)
b_rel_size, b_abs_size = self._b.get_size(renderer)
return a_rel_size + b_rel_size, a_abs_size + b_abs_size
class Fixed(_Base):
"""
Simple fixed size with absolute part = *fixed_size* and relative part = 0.
"""
def __init__(self, fixed_size):
_api.check_isinstance(Real, fixed_size=fixed_size)
self.fixed_size = fixed_size
def get_size(self, renderer):
rel_size = 0.
abs_size = self.fixed_size
return rel_size, abs_size
class Scaled(_Base):
"""
Simple scaled(?) size with absolute part = 0 and
relative part = *scalable_size*.
"""
def __init__(self, scalable_size):
self._scalable_size = scalable_size
def get_size(self, renderer):
rel_size = self._scalable_size
abs_size = 0.
return rel_size, abs_size
Scalable = Scaled
def _get_axes_aspect(ax):
aspect = ax.get_aspect()
if aspect == "auto":
aspect = 1.
return aspect
class AxesX(_Base):
"""
Scaled size whose relative part corresponds to the data width
of the *axes* multiplied by the *aspect*.
"""
def __init__(self, axes, aspect=1., ref_ax=None):
self._axes = axes
self._aspect = aspect
if aspect == "axes" and ref_ax is None:
raise ValueError("ref_ax must be set when aspect='axes'")
self._ref_ax = ref_ax
def get_size(self, renderer):
l1, l2 = self._axes.get_xlim()
if self._aspect == "axes":
ref_aspect = _get_axes_aspect(self._ref_ax)
aspect = ref_aspect / _get_axes_aspect(self._axes)
else:
aspect = self._aspect
rel_size = abs(l2-l1)*aspect
abs_size = 0.
return rel_size, abs_size
class AxesY(_Base):
"""
Scaled size whose relative part corresponds to the data height
of the *axes* multiplied by the *aspect*.
"""
def __init__(self, axes, aspect=1., ref_ax=None):
self._axes = axes
self._aspect = aspect
if aspect == "axes" and ref_ax is None:
raise ValueError("ref_ax must be set when aspect='axes'")
self._ref_ax = ref_ax
def get_size(self, renderer):
l1, l2 = self._axes.get_ylim()
if self._aspect == "axes":
ref_aspect = _get_axes_aspect(self._ref_ax)
aspect = _get_axes_aspect(self._axes)
else:
aspect = self._aspect
rel_size = abs(l2-l1)*aspect
abs_size = 0.
return rel_size, abs_size
class MaxExtent(_Base):
"""
Size whose absolute part is either the largest width or the largest height
of the given *artist_list*.
"""
def __init__(self, artist_list, w_or_h):
self._artist_list = artist_list
_api.check_in_list(["width", "height"], w_or_h=w_or_h)
self._w_or_h = w_or_h
def add_artist(self, a):
self._artist_list.append(a)
def get_size(self, renderer):
rel_size = 0.
extent_list = [
getattr(a.get_window_extent(renderer), self._w_or_h) / a.figure.dpi
for a in self._artist_list]
abs_size = max(extent_list, default=0)
return rel_size, abs_size
class MaxWidth(MaxExtent):
"""
Size whose absolute part is the largest width of the given *artist_list*.
"""
def __init__(self, artist_list):
super().__init__(artist_list, "width")
class MaxHeight(MaxExtent):
"""
Size whose absolute part is the largest height of the given *artist_list*.
"""
def __init__(self, artist_list):
super().__init__(artist_list, "height")
class Fraction(_Base):
"""
An instance whose size is a *fraction* of the *ref_size*.
>>> s = Fraction(0.3, AxesX(ax))
"""
def __init__(self, fraction, ref_size):
_api.check_isinstance(Real, fraction=fraction)
self._fraction_ref = ref_size
self._fraction = fraction
def get_size(self, renderer):
if self._fraction_ref is None:
return self._fraction, 0.
else:
r, a = self._fraction_ref.get_size(renderer)
rel_size = r*self._fraction
abs_size = a*self._fraction
return rel_size, abs_size
def from_any(size, fraction_ref=None):
"""
Create a Fixed unit when the first argument is a float, or a
Fraction unit if that is a string that ends with %. The second
argument is only meaningful when Fraction unit is created.
>>> from mpl_toolkits.axes_grid1.axes_size import from_any
>>> a = from_any(1.2) # => Fixed(1.2)
>>> from_any("50%", a) # => Fraction(0.5, a)
"""
if isinstance(size, Real):
return Fixed(size)
elif isinstance(size, str):
if size[-1] == "%":
return Fraction(float(size[:-1]) / 100, fraction_ref)
raise ValueError("Unknown format")
class _AxesDecorationsSize(_Base):
"""
Fixed size, corresponding to the size of decorations on a given Axes side.
"""
_get_size_map = {
"left": lambda tight_bb, axes_bb: axes_bb.xmin - tight_bb.xmin,
"right": lambda tight_bb, axes_bb: tight_bb.xmax - axes_bb.xmax,
"bottom": lambda tight_bb, axes_bb: axes_bb.ymin - tight_bb.ymin,
"top": lambda tight_bb, axes_bb: tight_bb.ymax - axes_bb.ymax,
}
def __init__(self, ax, direction):
_api.check_in_list(self._get_size_map, direction=direction)
self._direction = direction
self._ax_list = [ax] if isinstance(ax, Axes) else ax
def get_size(self, renderer):
sz = max([
self._get_size_map[self._direction](
ax.get_tightbbox(renderer, call_axes_locator=False), ax.bbox)
for ax in self._ax_list])
dpi = renderer.points_to_pixels(72)
abs_size = sz / dpi
rel_size = 0
return rel_size, abs_size

View File

@ -0,0 +1,561 @@
"""
A collection of functions and objects for creating or placing inset axes.
"""
from matplotlib import _api, _docstring
from matplotlib.offsetbox import AnchoredOffsetbox
from matplotlib.patches import Patch, Rectangle
from matplotlib.path import Path
from matplotlib.transforms import Bbox, BboxTransformTo
from matplotlib.transforms import IdentityTransform, TransformedBbox
from . import axes_size as Size
from .parasite_axes import HostAxes
@_api.deprecated("3.8", alternative="Axes.inset_axes")
class InsetPosition:
@_docstring.dedent_interpd
def __init__(self, parent, lbwh):
"""
An object for positioning an inset axes.
This is created by specifying the normalized coordinates in the axes,
instead of the figure.
Parameters
----------
parent : `~matplotlib.axes.Axes`
Axes to use for normalizing coordinates.
lbwh : iterable of four floats
The left edge, bottom edge, width, and height of the inset axes, in
units of the normalized coordinate of the *parent* axes.
See Also
--------
:meth:`matplotlib.axes.Axes.set_axes_locator`
Examples
--------
The following bounds the inset axes to a box with 20%% of the parent
axes height and 40%% of the width. The size of the axes specified
([0, 0, 1, 1]) ensures that the axes completely fills the bounding box:
>>> parent_axes = plt.gca()
>>> ax_ins = plt.axes([0, 0, 1, 1])
>>> ip = InsetPosition(parent_axes, [0.5, 0.1, 0.4, 0.2])
>>> ax_ins.set_axes_locator(ip)
"""
self.parent = parent
self.lbwh = lbwh
def __call__(self, ax, renderer):
bbox_parent = self.parent.get_position(original=False)
trans = BboxTransformTo(bbox_parent)
bbox_inset = Bbox.from_bounds(*self.lbwh)
bb = TransformedBbox(bbox_inset, trans)
return bb
class AnchoredLocatorBase(AnchoredOffsetbox):
def __init__(self, bbox_to_anchor, offsetbox, loc,
borderpad=0.5, bbox_transform=None):
super().__init__(
loc, pad=0., child=None, borderpad=borderpad,
bbox_to_anchor=bbox_to_anchor, bbox_transform=bbox_transform
)
def draw(self, renderer):
raise RuntimeError("No draw method should be called")
def __call__(self, ax, renderer):
if renderer is None:
renderer = ax.figure._get_renderer()
self.axes = ax
bbox = self.get_window_extent(renderer)
px, py = self.get_offset(bbox.width, bbox.height, 0, 0, renderer)
bbox_canvas = Bbox.from_bounds(px, py, bbox.width, bbox.height)
tr = ax.figure.transSubfigure.inverted()
return TransformedBbox(bbox_canvas, tr)
class AnchoredSizeLocator(AnchoredLocatorBase):
def __init__(self, bbox_to_anchor, x_size, y_size, loc,
borderpad=0.5, bbox_transform=None):
super().__init__(
bbox_to_anchor, None, loc,
borderpad=borderpad, bbox_transform=bbox_transform
)
self.x_size = Size.from_any(x_size)
self.y_size = Size.from_any(y_size)
def get_bbox(self, renderer):
bbox = self.get_bbox_to_anchor()
dpi = renderer.points_to_pixels(72.)
r, a = self.x_size.get_size(renderer)
width = bbox.width * r + a * dpi
r, a = self.y_size.get_size(renderer)
height = bbox.height * r + a * dpi
fontsize = renderer.points_to_pixels(self.prop.get_size_in_points())
pad = self.pad * fontsize
return Bbox.from_bounds(0, 0, width, height).padded(pad)
class AnchoredZoomLocator(AnchoredLocatorBase):
def __init__(self, parent_axes, zoom, loc,
borderpad=0.5,
bbox_to_anchor=None,
bbox_transform=None):
self.parent_axes = parent_axes
self.zoom = zoom
if bbox_to_anchor is None:
bbox_to_anchor = parent_axes.bbox
super().__init__(
bbox_to_anchor, None, loc, borderpad=borderpad,
bbox_transform=bbox_transform)
def get_bbox(self, renderer):
bb = self.parent_axes.transData.transform_bbox(self.axes.viewLim)
fontsize = renderer.points_to_pixels(self.prop.get_size_in_points())
pad = self.pad * fontsize
return (
Bbox.from_bounds(
0, 0, abs(bb.width * self.zoom), abs(bb.height * self.zoom))
.padded(pad))
class BboxPatch(Patch):
@_docstring.dedent_interpd
def __init__(self, bbox, **kwargs):
"""
Patch showing the shape bounded by a Bbox.
Parameters
----------
bbox : `~matplotlib.transforms.Bbox`
Bbox to use for the extents of this patch.
**kwargs
Patch properties. Valid arguments include:
%(Patch:kwdoc)s
"""
if "transform" in kwargs:
raise ValueError("transform should not be set")
kwargs["transform"] = IdentityTransform()
super().__init__(**kwargs)
self.bbox = bbox
def get_path(self):
# docstring inherited
x0, y0, x1, y1 = self.bbox.extents
return Path._create_closed([(x0, y0), (x1, y0), (x1, y1), (x0, y1)])
class BboxConnector(Patch):
@staticmethod
def get_bbox_edge_pos(bbox, loc):
"""
Return the ``(x, y)`` coordinates of corner *loc* of *bbox*; parameters
behave as documented for the `.BboxConnector` constructor.
"""
x0, y0, x1, y1 = bbox.extents
if loc == 1:
return x1, y1
elif loc == 2:
return x0, y1
elif loc == 3:
return x0, y0
elif loc == 4:
return x1, y0
@staticmethod
def connect_bbox(bbox1, bbox2, loc1, loc2=None):
"""
Construct a `.Path` connecting corner *loc1* of *bbox1* to corner
*loc2* of *bbox2*, where parameters behave as documented as for the
`.BboxConnector` constructor.
"""
if isinstance(bbox1, Rectangle):
bbox1 = TransformedBbox(Bbox.unit(), bbox1.get_transform())
if isinstance(bbox2, Rectangle):
bbox2 = TransformedBbox(Bbox.unit(), bbox2.get_transform())
if loc2 is None:
loc2 = loc1
x1, y1 = BboxConnector.get_bbox_edge_pos(bbox1, loc1)
x2, y2 = BboxConnector.get_bbox_edge_pos(bbox2, loc2)
return Path([[x1, y1], [x2, y2]])
@_docstring.dedent_interpd
def __init__(self, bbox1, bbox2, loc1, loc2=None, **kwargs):
"""
Connect two bboxes with a straight line.
Parameters
----------
bbox1, bbox2 : `~matplotlib.transforms.Bbox`
Bounding boxes to connect.
loc1, loc2 : {1, 2, 3, 4}
Corner of *bbox1* and *bbox2* to draw the line. Valid values are::
'upper right' : 1,
'upper left' : 2,
'lower left' : 3,
'lower right' : 4
*loc2* is optional and defaults to *loc1*.
**kwargs
Patch properties for the line drawn. Valid arguments include:
%(Patch:kwdoc)s
"""
if "transform" in kwargs:
raise ValueError("transform should not be set")
kwargs["transform"] = IdentityTransform()
kwargs.setdefault(
"fill", bool({'fc', 'facecolor', 'color'}.intersection(kwargs)))
super().__init__(**kwargs)
self.bbox1 = bbox1
self.bbox2 = bbox2
self.loc1 = loc1
self.loc2 = loc2
def get_path(self):
# docstring inherited
return self.connect_bbox(self.bbox1, self.bbox2,
self.loc1, self.loc2)
class BboxConnectorPatch(BboxConnector):
@_docstring.dedent_interpd
def __init__(self, bbox1, bbox2, loc1a, loc2a, loc1b, loc2b, **kwargs):
"""
Connect two bboxes with a quadrilateral.
The quadrilateral is specified by two lines that start and end at
corners of the bboxes. The four sides of the quadrilateral are defined
by the two lines given, the line between the two corners specified in
*bbox1* and the line between the two corners specified in *bbox2*.
Parameters
----------
bbox1, bbox2 : `~matplotlib.transforms.Bbox`
Bounding boxes to connect.
loc1a, loc2a, loc1b, loc2b : {1, 2, 3, 4}
The first line connects corners *loc1a* of *bbox1* and *loc2a* of
*bbox2*; the second line connects corners *loc1b* of *bbox1* and
*loc2b* of *bbox2*. Valid values are::
'upper right' : 1,
'upper left' : 2,
'lower left' : 3,
'lower right' : 4
**kwargs
Patch properties for the line drawn:
%(Patch:kwdoc)s
"""
if "transform" in kwargs:
raise ValueError("transform should not be set")
super().__init__(bbox1, bbox2, loc1a, loc2a, **kwargs)
self.loc1b = loc1b
self.loc2b = loc2b
def get_path(self):
# docstring inherited
path1 = self.connect_bbox(self.bbox1, self.bbox2, self.loc1, self.loc2)
path2 = self.connect_bbox(self.bbox2, self.bbox1,
self.loc2b, self.loc1b)
path_merged = [*path1.vertices, *path2.vertices, path1.vertices[0]]
return Path(path_merged)
def _add_inset_axes(parent_axes, axes_class, axes_kwargs, axes_locator):
"""Helper function to add an inset axes and disable navigation in it."""
if axes_class is None:
axes_class = HostAxes
if axes_kwargs is None:
axes_kwargs = {}
inset_axes = axes_class(
parent_axes.figure, parent_axes.get_position(),
**{"navigate": False, **axes_kwargs, "axes_locator": axes_locator})
return parent_axes.figure.add_axes(inset_axes)
@_docstring.dedent_interpd
def inset_axes(parent_axes, width, height, loc='upper right',
bbox_to_anchor=None, bbox_transform=None,
axes_class=None, axes_kwargs=None,
borderpad=0.5):
"""
Create an inset axes with a given width and height.
Both sizes used can be specified either in inches or percentage.
For example,::
inset_axes(parent_axes, width='40%%', height='30%%', loc='lower left')
creates in inset axes in the lower left corner of *parent_axes* which spans
over 30%% in height and 40%% in width of the *parent_axes*. Since the usage
of `.inset_axes` may become slightly tricky when exceeding such standard
cases, it is recommended to read :doc:`the examples
</gallery/axes_grid1/inset_locator_demo>`.
Notes
-----
The meaning of *bbox_to_anchor* and *bbox_to_transform* is interpreted
differently from that of legend. The value of bbox_to_anchor
(or the return value of its get_points method; the default is
*parent_axes.bbox*) is transformed by the bbox_transform (the default
is Identity transform) and then interpreted as points in the pixel
coordinate (which is dpi dependent).
Thus, following three calls are identical and creates an inset axes
with respect to the *parent_axes*::
axins = inset_axes(parent_axes, "30%%", "40%%")
axins = inset_axes(parent_axes, "30%%", "40%%",
bbox_to_anchor=parent_axes.bbox)
axins = inset_axes(parent_axes, "30%%", "40%%",
bbox_to_anchor=(0, 0, 1, 1),
bbox_transform=parent_axes.transAxes)
Parameters
----------
parent_axes : `matplotlib.axes.Axes`
Axes to place the inset axes.
width, height : float or str
Size of the inset axes to create. If a float is provided, it is
the size in inches, e.g. *width=1.3*. If a string is provided, it is
the size in relative units, e.g. *width='40%%'*. By default, i.e. if
neither *bbox_to_anchor* nor *bbox_transform* are specified, those
are relative to the parent_axes. Otherwise, they are to be understood
relative to the bounding box provided via *bbox_to_anchor*.
loc : str, default: 'upper right'
Location to place the inset axes. Valid locations are
'upper left', 'upper center', 'upper right',
'center left', 'center', 'center right',
'lower left', 'lower center', 'lower right'.
For backward compatibility, numeric values are accepted as well.
See the parameter *loc* of `.Legend` for details.
bbox_to_anchor : tuple or `~matplotlib.transforms.BboxBase`, optional
Bbox that the inset axes will be anchored to. If None,
a tuple of (0, 0, 1, 1) is used if *bbox_transform* is set
to *parent_axes.transAxes* or *parent_axes.figure.transFigure*.
Otherwise, *parent_axes.bbox* is used. If a tuple, can be either
[left, bottom, width, height], or [left, bottom].
If the kwargs *width* and/or *height* are specified in relative units,
the 2-tuple [left, bottom] cannot be used. Note that,
unless *bbox_transform* is set, the units of the bounding box
are interpreted in the pixel coordinate. When using *bbox_to_anchor*
with tuple, it almost always makes sense to also specify
a *bbox_transform*. This might often be the axes transform
*parent_axes.transAxes*.
bbox_transform : `~matplotlib.transforms.Transform`, optional
Transformation for the bbox that contains the inset axes.
If None, a `.transforms.IdentityTransform` is used. The value
of *bbox_to_anchor* (or the return value of its get_points method)
is transformed by the *bbox_transform* and then interpreted
as points in the pixel coordinate (which is dpi dependent).
You may provide *bbox_to_anchor* in some normalized coordinate,
and give an appropriate transform (e.g., *parent_axes.transAxes*).
axes_class : `~matplotlib.axes.Axes` type, default: `.HostAxes`
The type of the newly created inset axes.
axes_kwargs : dict, optional
Keyword arguments to pass to the constructor of the inset axes.
Valid arguments include:
%(Axes:kwdoc)s
borderpad : float, default: 0.5
Padding between inset axes and the bbox_to_anchor.
The units are axes font size, i.e. for a default font size of 10 points
*borderpad = 0.5* is equivalent to a padding of 5 points.
Returns
-------
inset_axes : *axes_class*
Inset axes object created.
"""
if (bbox_transform in [parent_axes.transAxes, parent_axes.figure.transFigure]
and bbox_to_anchor is None):
_api.warn_external("Using the axes or figure transform requires a "
"bounding box in the respective coordinates. "
"Using bbox_to_anchor=(0, 0, 1, 1) now.")
bbox_to_anchor = (0, 0, 1, 1)
if bbox_to_anchor is None:
bbox_to_anchor = parent_axes.bbox
if (isinstance(bbox_to_anchor, tuple) and
(isinstance(width, str) or isinstance(height, str))):
if len(bbox_to_anchor) != 4:
raise ValueError("Using relative units for width or height "
"requires to provide a 4-tuple or a "
"`Bbox` instance to `bbox_to_anchor.")
return _add_inset_axes(
parent_axes, axes_class, axes_kwargs,
AnchoredSizeLocator(
bbox_to_anchor, width, height, loc=loc,
bbox_transform=bbox_transform, borderpad=borderpad))
@_docstring.dedent_interpd
def zoomed_inset_axes(parent_axes, zoom, loc='upper right',
bbox_to_anchor=None, bbox_transform=None,
axes_class=None, axes_kwargs=None,
borderpad=0.5):
"""
Create an anchored inset axes by scaling a parent axes. For usage, also see
:doc:`the examples </gallery/axes_grid1/inset_locator_demo2>`.
Parameters
----------
parent_axes : `~matplotlib.axes.Axes`
Axes to place the inset axes.
zoom : float
Scaling factor of the data axes. *zoom* > 1 will enlarge the
coordinates (i.e., "zoomed in"), while *zoom* < 1 will shrink the
coordinates (i.e., "zoomed out").
loc : str, default: 'upper right'
Location to place the inset axes. Valid locations are
'upper left', 'upper center', 'upper right',
'center left', 'center', 'center right',
'lower left', 'lower center', 'lower right'.
For backward compatibility, numeric values are accepted as well.
See the parameter *loc* of `.Legend` for details.
bbox_to_anchor : tuple or `~matplotlib.transforms.BboxBase`, optional
Bbox that the inset axes will be anchored to. If None,
*parent_axes.bbox* is used. If a tuple, can be either
[left, bottom, width, height], or [left, bottom].
If the kwargs *width* and/or *height* are specified in relative units,
the 2-tuple [left, bottom] cannot be used. Note that
the units of the bounding box are determined through the transform
in use. When using *bbox_to_anchor* it almost always makes sense to
also specify a *bbox_transform*. This might often be the axes transform
*parent_axes.transAxes*.
bbox_transform : `~matplotlib.transforms.Transform`, optional
Transformation for the bbox that contains the inset axes.
If None, a `.transforms.IdentityTransform` is used (i.e. pixel
coordinates). This is useful when not providing any argument to
*bbox_to_anchor*. When using *bbox_to_anchor* it almost always makes
sense to also specify a *bbox_transform*. This might often be the
axes transform *parent_axes.transAxes*. Inversely, when specifying
the axes- or figure-transform here, be aware that not specifying
*bbox_to_anchor* will use *parent_axes.bbox*, the units of which are
in display (pixel) coordinates.
axes_class : `~matplotlib.axes.Axes` type, default: `.HostAxes`
The type of the newly created inset axes.
axes_kwargs : dict, optional
Keyword arguments to pass to the constructor of the inset axes.
Valid arguments include:
%(Axes:kwdoc)s
borderpad : float, default: 0.5
Padding between inset axes and the bbox_to_anchor.
The units are axes font size, i.e. for a default font size of 10 points
*borderpad = 0.5* is equivalent to a padding of 5 points.
Returns
-------
inset_axes : *axes_class*
Inset axes object created.
"""
return _add_inset_axes(
parent_axes, axes_class, axes_kwargs,
AnchoredZoomLocator(
parent_axes, zoom=zoom, loc=loc,
bbox_to_anchor=bbox_to_anchor, bbox_transform=bbox_transform,
borderpad=borderpad))
class _TransformedBboxWithCallback(TransformedBbox):
"""
Variant of `.TransformBbox` which calls *callback* before returning points.
Used by `.mark_inset` to unstale the parent axes' viewlim as needed.
"""
def __init__(self, *args, callback, **kwargs):
super().__init__(*args, **kwargs)
self._callback = callback
def get_points(self):
self._callback()
return super().get_points()
@_docstring.dedent_interpd
def mark_inset(parent_axes, inset_axes, loc1, loc2, **kwargs):
"""
Draw a box to mark the location of an area represented by an inset axes.
This function draws a box in *parent_axes* at the bounding box of
*inset_axes*, and shows a connection with the inset axes by drawing lines
at the corners, giving a "zoomed in" effect.
Parameters
----------
parent_axes : `~matplotlib.axes.Axes`
Axes which contains the area of the inset axes.
inset_axes : `~matplotlib.axes.Axes`
The inset axes.
loc1, loc2 : {1, 2, 3, 4}
Corners to use for connecting the inset axes and the area in the
parent axes.
**kwargs
Patch properties for the lines and box drawn:
%(Patch:kwdoc)s
Returns
-------
pp : `~matplotlib.patches.Patch`
The patch drawn to represent the area of the inset axes.
p1, p2 : `~matplotlib.patches.Patch`
The patches connecting two corners of the inset axes and its area.
"""
rect = _TransformedBboxWithCallback(
inset_axes.viewLim, parent_axes.transData,
callback=parent_axes._unstale_viewLim)
kwargs.setdefault("fill", bool({'fc', 'facecolor', 'color'}.intersection(kwargs)))
pp = BboxPatch(rect, **kwargs)
parent_axes.add_patch(pp)
p1 = BboxConnector(inset_axes.bbox, rect, loc1=loc1, **kwargs)
inset_axes.add_patch(p1)
p1.set_clip_on(False)
p2 = BboxConnector(inset_axes.bbox, rect, loc1=loc2, **kwargs)
inset_axes.add_patch(p2)
p2.set_clip_on(False)
return pp, p1, p2

View File

@ -0,0 +1,128 @@
import matplotlib.axes as maxes
from matplotlib.artist import Artist
from matplotlib.axis import XAxis, YAxis
class SimpleChainedObjects:
def __init__(self, objects):
self._objects = objects
def __getattr__(self, k):
_a = SimpleChainedObjects([getattr(a, k) for a in self._objects])
return _a
def __call__(self, *args, **kwargs):
for m in self._objects:
m(*args, **kwargs)
class Axes(maxes.Axes):
class AxisDict(dict):
def __init__(self, axes):
self.axes = axes
super().__init__()
def __getitem__(self, k):
if isinstance(k, tuple):
r = SimpleChainedObjects(
# super() within a list comprehension needs explicit args.
[super(Axes.AxisDict, self).__getitem__(k1) for k1 in k])
return r
elif isinstance(k, slice):
if k.start is None and k.stop is None and k.step is None:
return SimpleChainedObjects(list(self.values()))
else:
raise ValueError("Unsupported slice")
else:
return dict.__getitem__(self, k)
def __call__(self, *v, **kwargs):
return maxes.Axes.axis(self.axes, *v, **kwargs)
@property
def axis(self):
return self._axislines
def clear(self):
# docstring inherited
super().clear()
# Init axis artists.
self._axislines = self.AxisDict(self)
self._axislines.update(
bottom=SimpleAxisArtist(self.xaxis, 1, self.spines["bottom"]),
top=SimpleAxisArtist(self.xaxis, 2, self.spines["top"]),
left=SimpleAxisArtist(self.yaxis, 1, self.spines["left"]),
right=SimpleAxisArtist(self.yaxis, 2, self.spines["right"]))
class SimpleAxisArtist(Artist):
def __init__(self, axis, axisnum, spine):
self._axis = axis
self._axisnum = axisnum
self.line = spine
if isinstance(axis, XAxis):
self._axis_direction = ["bottom", "top"][axisnum-1]
elif isinstance(axis, YAxis):
self._axis_direction = ["left", "right"][axisnum-1]
else:
raise ValueError(
f"axis must be instance of XAxis or YAxis, but got {axis}")
super().__init__()
@property
def major_ticks(self):
tickline = "tick%dline" % self._axisnum
return SimpleChainedObjects([getattr(tick, tickline)
for tick in self._axis.get_major_ticks()])
@property
def major_ticklabels(self):
label = "label%d" % self._axisnum
return SimpleChainedObjects([getattr(tick, label)
for tick in self._axis.get_major_ticks()])
@property
def label(self):
return self._axis.label
def set_visible(self, b):
self.toggle(all=b)
self.line.set_visible(b)
self._axis.set_visible(True)
super().set_visible(b)
def set_label(self, txt):
self._axis.set_label_text(txt)
def toggle(self, all=None, ticks=None, ticklabels=None, label=None):
if all:
_ticks, _ticklabels, _label = True, True, True
elif all is not None:
_ticks, _ticklabels, _label = False, False, False
else:
_ticks, _ticklabels, _label = None, None, None
if ticks is not None:
_ticks = ticks
if ticklabels is not None:
_ticklabels = ticklabels
if label is not None:
_label = label
if _ticks is not None:
tickparam = {f"tick{self._axisnum}On": _ticks}
self._axis.set_tick_params(**tickparam)
if _ticklabels is not None:
tickparam = {f"label{self._axisnum}On": _ticklabels}
self._axis.set_tick_params(**tickparam)
if _label is not None:
pos = self._axis.get_label_position()
if (pos == self._axis_direction) and not _label:
self._axis.label.set_visible(False)
elif _label:
self._axis.label.set_visible(True)
self._axis.set_label_position(self._axis_direction)

View File

@ -0,0 +1,257 @@
from matplotlib import _api, cbook
import matplotlib.artist as martist
import matplotlib.transforms as mtransforms
from matplotlib.transforms import Bbox
from .mpl_axes import Axes
class ParasiteAxesBase:
def __init__(self, parent_axes, aux_transform=None,
*, viewlim_mode=None, **kwargs):
self._parent_axes = parent_axes
self.transAux = aux_transform
self.set_viewlim_mode(viewlim_mode)
kwargs["frameon"] = False
super().__init__(parent_axes.figure, parent_axes._position, **kwargs)
def clear(self):
super().clear()
martist.setp(self.get_children(), visible=False)
self._get_lines = self._parent_axes._get_lines
self._parent_axes.callbacks._connect_picklable(
"xlim_changed", self._sync_lims)
self._parent_axes.callbacks._connect_picklable(
"ylim_changed", self._sync_lims)
def pick(self, mouseevent):
# This most likely goes to Artist.pick (depending on axes_class given
# to the factory), which only handles pick events registered on the
# axes associated with each child:
super().pick(mouseevent)
# But parasite axes are additionally given pick events from their host
# axes (cf. HostAxesBase.pick), which we handle here:
for a in self.get_children():
if (hasattr(mouseevent.inaxes, "parasites")
and self in mouseevent.inaxes.parasites):
a.pick(mouseevent)
# aux_transform support
def _set_lim_and_transforms(self):
if self.transAux is not None:
self.transAxes = self._parent_axes.transAxes
self.transData = self.transAux + self._parent_axes.transData
self._xaxis_transform = mtransforms.blended_transform_factory(
self.transData, self.transAxes)
self._yaxis_transform = mtransforms.blended_transform_factory(
self.transAxes, self.transData)
else:
super()._set_lim_and_transforms()
def set_viewlim_mode(self, mode):
_api.check_in_list([None, "equal", "transform"], mode=mode)
self._viewlim_mode = mode
def get_viewlim_mode(self):
return self._viewlim_mode
def _sync_lims(self, parent):
viewlim = parent.viewLim.frozen()
mode = self.get_viewlim_mode()
if mode is None:
pass
elif mode == "equal":
self.viewLim.set(viewlim)
elif mode == "transform":
self.viewLim.set(viewlim.transformed(self.transAux.inverted()))
else:
_api.check_in_list([None, "equal", "transform"], mode=mode)
# end of aux_transform support
parasite_axes_class_factory = cbook._make_class_factory(
ParasiteAxesBase, "{}Parasite")
ParasiteAxes = parasite_axes_class_factory(Axes)
class HostAxesBase:
def __init__(self, *args, **kwargs):
self.parasites = []
super().__init__(*args, **kwargs)
def get_aux_axes(
self, tr=None, viewlim_mode="equal", axes_class=None, **kwargs):
"""
Add a parasite axes to this host.
Despite this method's name, this should actually be thought of as an
``add_parasite_axes`` method.
.. versionchanged:: 3.7
Defaults to same base axes class as host axes.
Parameters
----------
tr : `~matplotlib.transforms.Transform` or None, default: None
If a `.Transform`, the following relation will hold:
``parasite.transData = tr + host.transData``.
If None, the parasite's and the host's ``transData`` are unrelated.
viewlim_mode : {"equal", "transform", None}, default: "equal"
How the parasite's view limits are set: directly equal to the
parent axes ("equal"), equal after application of *tr*
("transform"), or independently (None).
axes_class : subclass type of `~matplotlib.axes.Axes`, optional
The `~.axes.Axes` subclass that is instantiated. If None, the base
class of the host axes is used.
**kwargs
Other parameters are forwarded to the parasite axes constructor.
"""
if axes_class is None:
axes_class = self._base_axes_class
parasite_axes_class = parasite_axes_class_factory(axes_class)
ax2 = parasite_axes_class(
self, tr, viewlim_mode=viewlim_mode, **kwargs)
# note that ax2.transData == tr + ax1.transData
# Anything you draw in ax2 will match the ticks and grids of ax1.
self.parasites.append(ax2)
ax2._remove_method = self.parasites.remove
return ax2
def draw(self, renderer):
orig_children_len = len(self._children)
locator = self.get_axes_locator()
if locator:
pos = locator(self, renderer)
self.set_position(pos, which="active")
self.apply_aspect(pos)
else:
self.apply_aspect()
rect = self.get_position()
for ax in self.parasites:
ax.apply_aspect(rect)
self._children.extend(ax.get_children())
super().draw(renderer)
del self._children[orig_children_len:]
def clear(self):
super().clear()
for ax in self.parasites:
ax.clear()
def pick(self, mouseevent):
super().pick(mouseevent)
# Also pass pick events on to parasite axes and, in turn, their
# children (cf. ParasiteAxesBase.pick)
for a in self.parasites:
a.pick(mouseevent)
def twinx(self, axes_class=None):
"""
Create a twin of Axes with a shared x-axis but independent y-axis.
The y-axis of self will have ticks on the left and the returned axes
will have ticks on the right.
"""
ax = self._add_twin_axes(axes_class, sharex=self)
self.axis["right"].set_visible(False)
ax.axis["right"].set_visible(True)
ax.axis["left", "top", "bottom"].set_visible(False)
return ax
def twiny(self, axes_class=None):
"""
Create a twin of Axes with a shared y-axis but independent x-axis.
The x-axis of self will have ticks on the bottom and the returned axes
will have ticks on the top.
"""
ax = self._add_twin_axes(axes_class, sharey=self)
self.axis["top"].set_visible(False)
ax.axis["top"].set_visible(True)
ax.axis["left", "right", "bottom"].set_visible(False)
return ax
def twin(self, aux_trans=None, axes_class=None):
"""
Create a twin of Axes with no shared axis.
While self will have ticks on the left and bottom axis, the returned
axes will have ticks on the top and right axis.
"""
if aux_trans is None:
aux_trans = mtransforms.IdentityTransform()
ax = self._add_twin_axes(
axes_class, aux_transform=aux_trans, viewlim_mode="transform")
self.axis["top", "right"].set_visible(False)
ax.axis["top", "right"].set_visible(True)
ax.axis["left", "bottom"].set_visible(False)
return ax
def _add_twin_axes(self, axes_class, **kwargs):
"""
Helper for `.twinx`/`.twiny`/`.twin`.
*kwargs* are forwarded to the parasite axes constructor.
"""
if axes_class is None:
axes_class = self._base_axes_class
ax = parasite_axes_class_factory(axes_class)(self, **kwargs)
self.parasites.append(ax)
ax._remove_method = self._remove_any_twin
return ax
def _remove_any_twin(self, ax):
self.parasites.remove(ax)
restore = ["top", "right"]
if ax._sharex:
restore.remove("top")
if ax._sharey:
restore.remove("right")
self.axis[tuple(restore)].set_visible(True)
self.axis[tuple(restore)].toggle(ticklabels=False, label=False)
@_api.make_keyword_only("3.8", "call_axes_locator")
def get_tightbbox(self, renderer=None, call_axes_locator=True,
bbox_extra_artists=None):
bbs = [
*[ax.get_tightbbox(renderer, call_axes_locator=call_axes_locator)
for ax in self.parasites],
super().get_tightbbox(renderer,
call_axes_locator=call_axes_locator,
bbox_extra_artists=bbox_extra_artists)]
return Bbox.union([b for b in bbs if b.width != 0 or b.height != 0])
host_axes_class_factory = host_subplot_class_factory = \
cbook._make_class_factory(HostAxesBase, "{}HostAxes", "_base_axes_class")
HostAxes = SubplotHost = host_axes_class_factory(Axes)
def host_axes(*args, axes_class=Axes, figure=None, **kwargs):
"""
Create axes that can act as a hosts to parasitic axes.
Parameters
----------
figure : `~matplotlib.figure.Figure`
Figure to which the axes will be added. Defaults to the current figure
`.pyplot.gcf()`.
*args, **kwargs
Will be passed on to the underlying `~.axes.Axes` object creation.
"""
import matplotlib.pyplot as plt
host_axes_class = host_axes_class_factory(axes_class)
if figure is None:
figure = plt.gcf()
ax = host_axes_class(figure, *args, **kwargs)
figure.add_axes(ax)
return ax
host_subplot = host_axes

View File

@ -0,0 +1,10 @@
from pathlib import Path
# Check that the test directories exist
if not (Path(__file__).parent / "baseline_images").exists():
raise OSError(
'The baseline image directory does not exist. '
'This is most likely because the test data is not installed. '
'You may need to install matplotlib from source to get the '
'test data.')

View File

@ -0,0 +1,2 @@
from matplotlib.testing.conftest import (mpl_test_settings, # noqa
pytest_configure, pytest_unconfigure)

View File

@ -0,0 +1,792 @@
from itertools import product
import io
import platform
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib import cbook
from matplotlib.backend_bases import MouseEvent
from matplotlib.colors import LogNorm
from matplotlib.patches import Circle, Ellipse
from matplotlib.transforms import Bbox, TransformedBbox
from matplotlib.testing.decorators import (
check_figures_equal, image_comparison, remove_ticks_and_titles)
from mpl_toolkits.axes_grid1 import (
axes_size as Size,
host_subplot, make_axes_locatable,
Grid, AxesGrid, ImageGrid)
from mpl_toolkits.axes_grid1.anchored_artists import (
AnchoredAuxTransformBox, AnchoredDrawingArea, AnchoredEllipse,
AnchoredDirectionArrows, AnchoredSizeBar)
from mpl_toolkits.axes_grid1.axes_divider import (
Divider, HBoxDivider, make_axes_area_auto_adjustable, SubplotDivider,
VBoxDivider)
from mpl_toolkits.axes_grid1.axes_rgb import RGBAxes
from mpl_toolkits.axes_grid1.inset_locator import (
zoomed_inset_axes, mark_inset, inset_axes, BboxConnectorPatch,
InsetPosition)
import mpl_toolkits.axes_grid1.mpl_axes
import pytest
import numpy as np
from numpy.testing import assert_array_equal, assert_array_almost_equal
def test_divider_append_axes():
fig, ax = plt.subplots()
divider = make_axes_locatable(ax)
axs = {
"main": ax,
"top": divider.append_axes("top", 1.2, pad=0.1, sharex=ax),
"bottom": divider.append_axes("bottom", 1.2, pad=0.1, sharex=ax),
"left": divider.append_axes("left", 1.2, pad=0.1, sharey=ax),
"right": divider.append_axes("right", 1.2, pad=0.1, sharey=ax),
}
fig.canvas.draw()
bboxes = {k: axs[k].get_window_extent() for k in axs}
dpi = fig.dpi
assert bboxes["top"].height == pytest.approx(1.2 * dpi)
assert bboxes["bottom"].height == pytest.approx(1.2 * dpi)
assert bboxes["left"].width == pytest.approx(1.2 * dpi)
assert bboxes["right"].width == pytest.approx(1.2 * dpi)
assert bboxes["top"].y0 - bboxes["main"].y1 == pytest.approx(0.1 * dpi)
assert bboxes["main"].y0 - bboxes["bottom"].y1 == pytest.approx(0.1 * dpi)
assert bboxes["main"].x0 - bboxes["left"].x1 == pytest.approx(0.1 * dpi)
assert bboxes["right"].x0 - bboxes["main"].x1 == pytest.approx(0.1 * dpi)
assert bboxes["left"].y0 == bboxes["main"].y0 == bboxes["right"].y0
assert bboxes["left"].y1 == bboxes["main"].y1 == bboxes["right"].y1
assert bboxes["top"].x0 == bboxes["main"].x0 == bboxes["bottom"].x0
assert bboxes["top"].x1 == bboxes["main"].x1 == bboxes["bottom"].x1
# Update style when regenerating the test image
@image_comparison(['twin_axes_empty_and_removed'], extensions=["png"], tol=1,
style=('classic', '_classic_test_patch'))
def test_twin_axes_empty_and_removed():
# Purely cosmetic font changes (avoid overlap)
mpl.rcParams.update(
{"font.size": 8, "xtick.labelsize": 8, "ytick.labelsize": 8})
generators = ["twinx", "twiny", "twin"]
modifiers = ["", "host invisible", "twin removed", "twin invisible",
"twin removed\nhost invisible"]
# Unmodified host subplot at the beginning for reference
h = host_subplot(len(modifiers)+1, len(generators), 2)
h.text(0.5, 0.5, "host_subplot",
horizontalalignment="center", verticalalignment="center")
# Host subplots with various modifications (twin*, visibility) applied
for i, (mod, gen) in enumerate(product(modifiers, generators),
len(generators) + 1):
h = host_subplot(len(modifiers)+1, len(generators), i)
t = getattr(h, gen)()
if "twin invisible" in mod:
t.axis[:].set_visible(False)
if "twin removed" in mod:
t.remove()
if "host invisible" in mod:
h.axis[:].set_visible(False)
h.text(0.5, 0.5, gen + ("\n" + mod if mod else ""),
horizontalalignment="center", verticalalignment="center")
plt.subplots_adjust(wspace=0.5, hspace=1)
def test_twin_axes_both_with_units():
host = host_subplot(111)
with pytest.warns(mpl.MatplotlibDeprecationWarning):
host.plot_date([0, 1, 2], [0, 1, 2], xdate=False, ydate=True)
twin = host.twinx()
twin.plot(["a", "b", "c"])
assert host.get_yticklabels()[0].get_text() == "00:00:00"
assert twin.get_yticklabels()[0].get_text() == "a"
def test_axesgrid_colorbar_log_smoketest():
fig = plt.figure()
grid = AxesGrid(fig, 111, # modified to be only subplot
nrows_ncols=(1, 1),
ngrids=1,
label_mode="L",
cbar_location="top",
cbar_mode="single",
)
Z = 10000 * np.random.rand(10, 10)
im = grid[0].imshow(Z, interpolation="nearest", norm=LogNorm())
grid.cbar_axes[0].colorbar(im)
def test_inset_colorbar_tight_layout_smoketest():
fig, ax = plt.subplots(1, 1)
pts = ax.scatter([0, 1], [0, 1], c=[1, 5])
cax = inset_axes(ax, width="3%", height="70%")
plt.colorbar(pts, cax=cax)
with pytest.warns(UserWarning, match="This figure includes Axes"):
# Will warn, but not raise an error
plt.tight_layout()
@image_comparison(['inset_locator.png'], style='default', remove_text=True)
def test_inset_locator():
fig, ax = plt.subplots(figsize=[5, 4])
# prepare the demo image
# Z is a 15x15 array
Z = cbook.get_sample_data("axes_grid/bivariate_normal.npy")
extent = (-3, 4, -4, 3)
Z2 = np.zeros((150, 150))
ny, nx = Z.shape
Z2[30:30+ny, 30:30+nx] = Z
ax.imshow(Z2, extent=extent, interpolation="nearest",
origin="lower")
axins = zoomed_inset_axes(ax, zoom=6, loc='upper right')
axins.imshow(Z2, extent=extent, interpolation="nearest",
origin="lower")
axins.yaxis.get_major_locator().set_params(nbins=7)
axins.xaxis.get_major_locator().set_params(nbins=7)
# sub region of the original image
x1, x2, y1, y2 = -1.5, -0.9, -2.5, -1.9
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)
plt.xticks(visible=False)
plt.yticks(visible=False)
# draw a bbox of the region of the inset axes in the parent axes and
# connecting lines between the bbox and the inset axes area
mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="0.5")
asb = AnchoredSizeBar(ax.transData,
0.5,
'0.5',
loc='lower center',
pad=0.1, borderpad=0.5, sep=5,
frameon=False)
ax.add_artist(asb)
@image_comparison(['inset_axes.png'], style='default', remove_text=True)
def test_inset_axes():
fig, ax = plt.subplots(figsize=[5, 4])
# prepare the demo image
# Z is a 15x15 array
Z = cbook.get_sample_data("axes_grid/bivariate_normal.npy")
extent = (-3, 4, -4, 3)
Z2 = np.zeros((150, 150))
ny, nx = Z.shape
Z2[30:30+ny, 30:30+nx] = Z
ax.imshow(Z2, extent=extent, interpolation="nearest",
origin="lower")
# creating our inset axes with a bbox_transform parameter
axins = inset_axes(ax, width=1., height=1., bbox_to_anchor=(1, 1),
bbox_transform=ax.transAxes)
axins.imshow(Z2, extent=extent, interpolation="nearest",
origin="lower")
axins.yaxis.get_major_locator().set_params(nbins=7)
axins.xaxis.get_major_locator().set_params(nbins=7)
# sub region of the original image
x1, x2, y1, y2 = -1.5, -0.9, -2.5, -1.9
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)
plt.xticks(visible=False)
plt.yticks(visible=False)
# draw a bbox of the region of the inset axes in the parent axes and
# connecting lines between the bbox and the inset axes area
mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="0.5")
asb = AnchoredSizeBar(ax.transData,
0.5,
'0.5',
loc='lower center',
pad=0.1, borderpad=0.5, sep=5,
frameon=False)
ax.add_artist(asb)
def test_inset_axes_complete():
dpi = 100
figsize = (6, 5)
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
fig.subplots_adjust(.1, .1, .9, .9)
ins = inset_axes(ax, width=2., height=2., borderpad=0)
fig.canvas.draw()
assert_array_almost_equal(
ins.get_position().extents,
[(0.9*figsize[0]-2.)/figsize[0], (0.9*figsize[1]-2.)/figsize[1],
0.9, 0.9])
ins = inset_axes(ax, width="40%", height="30%", borderpad=0)
fig.canvas.draw()
assert_array_almost_equal(
ins.get_position().extents, [.9-.8*.4, .9-.8*.3, 0.9, 0.9])
ins = inset_axes(ax, width=1., height=1.2, bbox_to_anchor=(200, 100),
loc=3, borderpad=0)
fig.canvas.draw()
assert_array_almost_equal(
ins.get_position().extents,
[200/dpi/figsize[0], 100/dpi/figsize[1],
(200/dpi+1)/figsize[0], (100/dpi+1.2)/figsize[1]])
ins1 = inset_axes(ax, width="35%", height="60%", loc=3, borderpad=1)
ins2 = inset_axes(ax, width="100%", height="100%",
bbox_to_anchor=(0, 0, .35, .60),
bbox_transform=ax.transAxes, loc=3, borderpad=1)
fig.canvas.draw()
assert_array_equal(ins1.get_position().extents,
ins2.get_position().extents)
with pytest.raises(ValueError):
ins = inset_axes(ax, width="40%", height="30%",
bbox_to_anchor=(0.4, 0.5))
with pytest.warns(UserWarning):
ins = inset_axes(ax, width="40%", height="30%",
bbox_transform=ax.transAxes)
def test_inset_axes_tight():
# gh-26287 found that inset_axes raised with bbox_inches=tight
fig, ax = plt.subplots()
inset_axes(ax, width=1.3, height=0.9)
f = io.BytesIO()
fig.savefig(f, bbox_inches="tight")
@image_comparison(['fill_facecolor.png'], remove_text=True, style='mpl20')
def test_fill_facecolor():
fig, ax = plt.subplots(1, 5)
fig.set_size_inches(5, 5)
for i in range(1, 4):
ax[i].yaxis.set_visible(False)
ax[4].yaxis.tick_right()
bbox = Bbox.from_extents(0, 0.4, 1, 0.6)
# fill with blue by setting 'fc' field
bbox1 = TransformedBbox(bbox, ax[0].transData)
bbox2 = TransformedBbox(bbox, ax[1].transData)
# set color to BboxConnectorPatch
p = BboxConnectorPatch(
bbox1, bbox2, loc1a=1, loc2a=2, loc1b=4, loc2b=3,
ec="r", fc="b")
p.set_clip_on(False)
ax[0].add_patch(p)
# set color to marked area
axins = zoomed_inset_axes(ax[0], 1, loc='upper right')
axins.set_xlim(0, 0.2)
axins.set_ylim(0, 0.2)
plt.gca().axes.xaxis.set_ticks([])
plt.gca().axes.yaxis.set_ticks([])
mark_inset(ax[0], axins, loc1=2, loc2=4, fc="b", ec="0.5")
# fill with yellow by setting 'facecolor' field
bbox3 = TransformedBbox(bbox, ax[1].transData)
bbox4 = TransformedBbox(bbox, ax[2].transData)
# set color to BboxConnectorPatch
p = BboxConnectorPatch(
bbox3, bbox4, loc1a=1, loc2a=2, loc1b=4, loc2b=3,
ec="r", facecolor="y")
p.set_clip_on(False)
ax[1].add_patch(p)
# set color to marked area
axins = zoomed_inset_axes(ax[1], 1, loc='upper right')
axins.set_xlim(0, 0.2)
axins.set_ylim(0, 0.2)
plt.gca().axes.xaxis.set_ticks([])
plt.gca().axes.yaxis.set_ticks([])
mark_inset(ax[1], axins, loc1=2, loc2=4, facecolor="y", ec="0.5")
# fill with green by setting 'color' field
bbox5 = TransformedBbox(bbox, ax[2].transData)
bbox6 = TransformedBbox(bbox, ax[3].transData)
# set color to BboxConnectorPatch
p = BboxConnectorPatch(
bbox5, bbox6, loc1a=1, loc2a=2, loc1b=4, loc2b=3,
ec="r", color="g")
p.set_clip_on(False)
ax[2].add_patch(p)
# set color to marked area
axins = zoomed_inset_axes(ax[2], 1, loc='upper right')
axins.set_xlim(0, 0.2)
axins.set_ylim(0, 0.2)
plt.gca().axes.xaxis.set_ticks([])
plt.gca().axes.yaxis.set_ticks([])
mark_inset(ax[2], axins, loc1=2, loc2=4, color="g", ec="0.5")
# fill with green but color won't show if set fill to False
bbox7 = TransformedBbox(bbox, ax[3].transData)
bbox8 = TransformedBbox(bbox, ax[4].transData)
# BboxConnectorPatch won't show green
p = BboxConnectorPatch(
bbox7, bbox8, loc1a=1, loc2a=2, loc1b=4, loc2b=3,
ec="r", fc="g", fill=False)
p.set_clip_on(False)
ax[3].add_patch(p)
# marked area won't show green
axins = zoomed_inset_axes(ax[3], 1, loc='upper right')
axins.set_xlim(0, 0.2)
axins.set_ylim(0, 0.2)
axins.xaxis.set_ticks([])
axins.yaxis.set_ticks([])
mark_inset(ax[3], axins, loc1=2, loc2=4, fc="g", ec="0.5", fill=False)
# Update style when regenerating the test image
@image_comparison(['zoomed_axes.png', 'inverted_zoomed_axes.png'],
style=('classic', '_classic_test_patch'),
tol=0.02 if platform.machine() == 'arm64' else 0)
def test_zooming_with_inverted_axes():
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 2, 3])
ax.axis([1, 3, 1, 3])
inset_ax = zoomed_inset_axes(ax, zoom=2.5, loc='lower right')
inset_ax.axis([1.1, 1.4, 1.1, 1.4])
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 2, 3])
ax.axis([3, 1, 3, 1])
inset_ax = zoomed_inset_axes(ax, zoom=2.5, loc='lower right')
inset_ax.axis([1.4, 1.1, 1.4, 1.1])
# Update style when regenerating the test image
@image_comparison(['anchored_direction_arrows.png'],
tol=0 if platform.machine() == 'x86_64' else 0.01,
style=('classic', '_classic_test_patch'))
def test_anchored_direction_arrows():
fig, ax = plt.subplots()
ax.imshow(np.zeros((10, 10)), interpolation='nearest')
simple_arrow = AnchoredDirectionArrows(ax.transAxes, 'X', 'Y')
ax.add_artist(simple_arrow)
# Update style when regenerating the test image
@image_comparison(['anchored_direction_arrows_many_args.png'],
style=('classic', '_classic_test_patch'))
def test_anchored_direction_arrows_many_args():
fig, ax = plt.subplots()
ax.imshow(np.ones((10, 10)))
direction_arrows = AnchoredDirectionArrows(
ax.transAxes, 'A', 'B', loc='upper right', color='red',
aspect_ratio=-0.5, pad=0.6, borderpad=2, frameon=True, alpha=0.7,
sep_x=-0.06, sep_y=-0.08, back_length=0.1, head_width=9,
head_length=10, tail_width=5)
ax.add_artist(direction_arrows)
def test_axes_locatable_position():
fig, ax = plt.subplots()
divider = make_axes_locatable(ax)
with mpl.rc_context({"figure.subplot.wspace": 0.02}):
cax = divider.append_axes('right', size='5%')
fig.canvas.draw()
assert np.isclose(cax.get_position(original=False).width,
0.03621495327102808)
@image_comparison(['image_grid_each_left_label_mode_all.png'], style='mpl20',
savefig_kwarg={'bbox_inches': 'tight'})
def test_image_grid_each_left_label_mode_all():
imdata = np.arange(100).reshape((10, 10))
fig = plt.figure(1, (3, 3))
grid = ImageGrid(fig, (1, 1, 1), nrows_ncols=(3, 2), axes_pad=(0.5, 0.3),
cbar_mode="each", cbar_location="left", cbar_size="15%",
label_mode="all")
# 3-tuple rect => SubplotDivider
assert isinstance(grid.get_divider(), SubplotDivider)
assert grid.get_axes_pad() == (0.5, 0.3)
assert grid.get_aspect() # True by default for ImageGrid
for ax, cax in zip(grid, grid.cbar_axes):
im = ax.imshow(imdata, interpolation='none')
cax.colorbar(im)
@image_comparison(['image_grid_single_bottom_label_mode_1.png'], style='mpl20',
savefig_kwarg={'bbox_inches': 'tight'})
def test_image_grid_single_bottom():
imdata = np.arange(100).reshape((10, 10))
fig = plt.figure(1, (2.5, 1.5))
grid = ImageGrid(fig, (0, 0, 1, 1), nrows_ncols=(1, 3),
axes_pad=(0.2, 0.15), cbar_mode="single",
cbar_location="bottom", cbar_size="10%", label_mode="1")
# 4-tuple rect => Divider, isinstance will give True for SubplotDivider
assert type(grid.get_divider()) is Divider
for i in range(3):
im = grid[i].imshow(imdata, interpolation='none')
grid.cbar_axes[0].colorbar(im)
def test_image_grid_label_mode_invalid():
fig = plt.figure()
with pytest.raises(ValueError, match="'foo' is not a valid value for mode"):
ImageGrid(fig, (0, 0, 1, 1), (2, 1), label_mode="foo")
@image_comparison(['image_grid.png'],
remove_text=True, style='mpl20',
savefig_kwarg={'bbox_inches': 'tight'})
def test_image_grid():
# test that image grid works with bbox_inches=tight.
im = np.arange(100).reshape((10, 10))
fig = plt.figure(1, (4, 4))
grid = ImageGrid(fig, 111, nrows_ncols=(2, 2), axes_pad=0.1)
assert grid.get_axes_pad() == (0.1, 0.1)
for i in range(4):
grid[i].imshow(im, interpolation='nearest')
def test_gettightbbox():
fig, ax = plt.subplots(figsize=(8, 6))
l, = ax.plot([1, 2, 3], [0, 1, 0])
ax_zoom = zoomed_inset_axes(ax, 4)
ax_zoom.plot([1, 2, 3], [0, 1, 0])
mark_inset(ax, ax_zoom, loc1=1, loc2=3, fc="none", ec='0.3')
remove_ticks_and_titles(fig)
bbox = fig.get_tightbbox(fig.canvas.get_renderer())
np.testing.assert_array_almost_equal(bbox.extents,
[-17.7, -13.9, 7.2, 5.4])
@pytest.mark.parametrize("click_on", ["big", "small"])
@pytest.mark.parametrize("big_on_axes,small_on_axes", [
("gca", "gca"),
("host", "host"),
("host", "parasite"),
("parasite", "host"),
("parasite", "parasite")
])
def test_picking_callbacks_overlap(big_on_axes, small_on_axes, click_on):
"""Test pick events on normal, host or parasite axes."""
# Two rectangles are drawn and "clicked on", a small one and a big one
# enclosing the small one. The axis on which they are drawn as well as the
# rectangle that is clicked on are varied.
# In each case we expect that both rectangles are picked if we click on the
# small one and only the big one is picked if we click on the big one.
# Also tests picking on normal axes ("gca") as a control.
big = plt.Rectangle((0.25, 0.25), 0.5, 0.5, picker=5)
small = plt.Rectangle((0.4, 0.4), 0.2, 0.2, facecolor="r", picker=5)
# Machinery for "receiving" events
received_events = []
def on_pick(event):
received_events.append(event)
plt.gcf().canvas.mpl_connect('pick_event', on_pick)
# Shortcut
rectangles_on_axes = (big_on_axes, small_on_axes)
# Axes setup
axes = {"gca": None, "host": None, "parasite": None}
if "gca" in rectangles_on_axes:
axes["gca"] = plt.gca()
if "host" in rectangles_on_axes or "parasite" in rectangles_on_axes:
axes["host"] = host_subplot(111)
axes["parasite"] = axes["host"].twin()
# Add rectangles to axes
axes[big_on_axes].add_patch(big)
axes[small_on_axes].add_patch(small)
# Simulate picking with click mouse event
if click_on == "big":
click_axes = axes[big_on_axes]
axes_coords = (0.3, 0.3)
else:
click_axes = axes[small_on_axes]
axes_coords = (0.5, 0.5)
# In reality mouse events never happen on parasite axes, only host axes
if click_axes is axes["parasite"]:
click_axes = axes["host"]
(x, y) = click_axes.transAxes.transform(axes_coords)
m = MouseEvent("button_press_event", click_axes.figure.canvas, x, y,
button=1)
click_axes.pick(m)
# Checks
expected_n_events = 2 if click_on == "small" else 1
assert len(received_events) == expected_n_events
event_rects = [event.artist for event in received_events]
assert big in event_rects
if click_on == "small":
assert small in event_rects
@image_comparison(['anchored_artists.png'], remove_text=True, style='mpl20')
def test_anchored_artists():
fig, ax = plt.subplots(figsize=(3, 3))
ada = AnchoredDrawingArea(40, 20, 0, 0, loc='upper right', pad=0.,
frameon=False)
p1 = Circle((10, 10), 10)
ada.drawing_area.add_artist(p1)
p2 = Circle((30, 10), 5, fc="r")
ada.drawing_area.add_artist(p2)
ax.add_artist(ada)
box = AnchoredAuxTransformBox(ax.transData, loc='upper left')
el = Ellipse((0, 0), width=0.1, height=0.4, angle=30, color='cyan')
box.drawing_area.add_artist(el)
ax.add_artist(box)
# Manually construct the ellipse instead, once the deprecation elapses.
with pytest.warns(mpl.MatplotlibDeprecationWarning):
ae = AnchoredEllipse(ax.transData, width=0.1, height=0.25, angle=-60,
loc='lower left', pad=0.5, borderpad=0.4,
frameon=True)
ax.add_artist(ae)
asb = AnchoredSizeBar(ax.transData, 0.2, r"0.2 units", loc='lower right',
pad=0.3, borderpad=0.4, sep=4, fill_bar=True,
frameon=False, label_top=True, prop={'size': 20},
size_vertical=0.05, color='green')
ax.add_artist(asb)
def test_hbox_divider():
arr1 = np.arange(20).reshape((4, 5))
arr2 = np.arange(20).reshape((5, 4))
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(arr1)
ax2.imshow(arr2)
pad = 0.5 # inches.
divider = HBoxDivider(
fig, 111, # Position of combined axes.
horizontal=[Size.AxesX(ax1), Size.Fixed(pad), Size.AxesX(ax2)],
vertical=[Size.AxesY(ax1), Size.Scaled(1), Size.AxesY(ax2)])
ax1.set_axes_locator(divider.new_locator(0))
ax2.set_axes_locator(divider.new_locator(2))
fig.canvas.draw()
p1 = ax1.get_position()
p2 = ax2.get_position()
assert p1.height == p2.height
assert p2.width / p1.width == pytest.approx((4 / 5) ** 2)
def test_vbox_divider():
arr1 = np.arange(20).reshape((4, 5))
arr2 = np.arange(20).reshape((5, 4))
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(arr1)
ax2.imshow(arr2)
pad = 0.5 # inches.
divider = VBoxDivider(
fig, 111, # Position of combined axes.
horizontal=[Size.AxesX(ax1), Size.Scaled(1), Size.AxesX(ax2)],
vertical=[Size.AxesY(ax1), Size.Fixed(pad), Size.AxesY(ax2)])
ax1.set_axes_locator(divider.new_locator(0))
ax2.set_axes_locator(divider.new_locator(2))
fig.canvas.draw()
p1 = ax1.get_position()
p2 = ax2.get_position()
assert p1.width == p2.width
assert p1.height / p2.height == pytest.approx((4 / 5) ** 2)
def test_axes_class_tuple():
fig = plt.figure()
axes_class = (mpl_toolkits.axes_grid1.mpl_axes.Axes, {})
gr = AxesGrid(fig, 111, nrows_ncols=(1, 1), axes_class=axes_class)
def test_grid_axes_lists():
"""Test Grid axes_all, axes_row and axes_column relationship."""
fig = plt.figure()
grid = Grid(fig, 111, (2, 3), direction="row")
assert_array_equal(grid, grid.axes_all)
assert_array_equal(grid.axes_row, np.transpose(grid.axes_column))
assert_array_equal(grid, np.ravel(grid.axes_row), "row")
assert grid.get_geometry() == (2, 3)
grid = Grid(fig, 111, (2, 3), direction="column")
assert_array_equal(grid, np.ravel(grid.axes_column), "column")
@pytest.mark.parametrize('direction', ('row', 'column'))
def test_grid_axes_position(direction):
"""Test positioning of the axes in Grid."""
fig = plt.figure()
grid = Grid(fig, 111, (2, 2), direction=direction)
loc = [ax.get_axes_locator() for ax in np.ravel(grid.axes_row)]
# Test nx.
assert loc[1].args[0] > loc[0].args[0]
assert loc[0].args[0] == loc[2].args[0]
assert loc[3].args[0] == loc[1].args[0]
# Test ny.
assert loc[2].args[1] < loc[0].args[1]
assert loc[0].args[1] == loc[1].args[1]
assert loc[3].args[1] == loc[2].args[1]
@pytest.mark.parametrize('rect, ngrids, error, message', (
((1, 1), None, TypeError, "Incorrect rect format"),
(111, -1, ValueError, "ngrids must be positive"),
(111, 7, ValueError, "ngrids must be positive"),
))
def test_grid_errors(rect, ngrids, error, message):
fig = plt.figure()
with pytest.raises(error, match=message):
Grid(fig, rect, (2, 3), ngrids=ngrids)
@pytest.mark.parametrize('anchor, error, message', (
(None, TypeError, "anchor must be str"),
("CC", ValueError, "'CC' is not a valid value for anchor"),
((1, 1, 1), TypeError, "anchor must be str"),
))
def test_divider_errors(anchor, error, message):
fig = plt.figure()
with pytest.raises(error, match=message):
Divider(fig, [0, 0, 1, 1], [Size.Fixed(1)], [Size.Fixed(1)],
anchor=anchor)
@check_figures_equal(extensions=["png"])
def test_mark_inset_unstales_viewlim(fig_test, fig_ref):
inset, full = fig_test.subplots(1, 2)
full.plot([0, 5], [0, 5])
inset.set(xlim=(1, 2), ylim=(1, 2))
# Check that mark_inset unstales full's viewLim before drawing the marks.
mark_inset(full, inset, 1, 4)
inset, full = fig_ref.subplots(1, 2)
full.plot([0, 5], [0, 5])
inset.set(xlim=(1, 2), ylim=(1, 2))
mark_inset(full, inset, 1, 4)
# Manually unstale the full's viewLim.
fig_ref.canvas.draw()
def test_auto_adjustable():
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
pad = 0.1
make_axes_area_auto_adjustable(ax, pad=pad)
fig.canvas.draw()
tbb = ax.get_tightbbox()
assert tbb.x0 == pytest.approx(pad * fig.dpi)
assert tbb.x1 == pytest.approx(fig.bbox.width - pad * fig.dpi)
assert tbb.y0 == pytest.approx(pad * fig.dpi)
assert tbb.y1 == pytest.approx(fig.bbox.height - pad * fig.dpi)
# Update style when regenerating the test image
@image_comparison(['rgb_axes.png'], remove_text=True,
style=('classic', '_classic_test_patch'))
def test_rgb_axes():
fig = plt.figure()
ax = RGBAxes(fig, (0.1, 0.1, 0.8, 0.8), pad=0.1)
rng = np.random.default_rng(19680801)
r = rng.random((5, 5))
g = rng.random((5, 5))
b = rng.random((5, 5))
ax.imshow_rgb(r, g, b, interpolation='none')
# Update style when regenerating the test image
@image_comparison(['insetposition.png'], remove_text=True,
style=('classic', '_classic_test_patch'))
def test_insetposition():
fig, ax = plt.subplots(figsize=(2, 2))
ax_ins = plt.axes([0, 0, 1, 1])
with pytest.warns(mpl.MatplotlibDeprecationWarning):
ip = InsetPosition(ax, [0.2, 0.25, 0.5, 0.4])
ax_ins.set_axes_locator(ip)
# The original version of this test relied on mpl_toolkits's slightly different
# colorbar implementation; moving to matplotlib's own colorbar implementation
# caused the small image comparison error.
@image_comparison(['imagegrid_cbar_mode.png'],
remove_text=True, style='mpl20', tol=0.3)
def test_imagegrid_cbar_mode_edge():
arr = np.arange(16).reshape((4, 4))
fig = plt.figure(figsize=(18, 9))
positions = (241, 242, 243, 244, 245, 246, 247, 248)
directions = ['row']*4 + ['column']*4
cbar_locations = ['left', 'right', 'top', 'bottom']*2
for position, direction, location in zip(
positions, directions, cbar_locations):
grid = ImageGrid(fig, position,
nrows_ncols=(2, 2),
direction=direction,
cbar_location=location,
cbar_size='20%',
cbar_mode='edge')
ax1, ax2, ax3, ax4 = grid
ax1.imshow(arr, cmap='nipy_spectral')
ax2.imshow(arr.T, cmap='hot')
ax3.imshow(np.hypot(arr, arr.T), cmap='jet')
ax4.imshow(np.arctan2(arr, arr.T), cmap='hsv')
# In each row/column, the "first" colorbars must be overwritten by the
# "second" ones. To achieve this, clear out the axes first.
for ax in grid:
ax.cax.cla()
cb = ax.cax.colorbar(ax.images[0])
def test_imagegrid():
fig = plt.figure()
grid = ImageGrid(fig, 111, nrows_ncols=(1, 1))
ax = grid[0]
im = ax.imshow([[1, 2]], norm=mpl.colors.LogNorm())
cb = ax.cax.colorbar(im)
assert isinstance(cb.locator, mticker.LogLocator)
def test_removal():
import matplotlib.pyplot as plt
import mpl_toolkits.axisartist as AA
fig = plt.figure()
ax = host_subplot(111, axes_class=AA.Axes, figure=fig)
col = ax.fill_between(range(5), 0, range(5))
fig.canvas.draw()
col.remove()
fig.canvas.draw()
@image_comparison(['anchored_locator_base_call.png'], style="mpl20")
def test_anchored_locator_base_call():
fig = plt.figure(figsize=(3, 3))
fig1, fig2 = fig.subfigures(nrows=2, ncols=1)
ax = fig1.subplots()
ax.set(aspect=1, xlim=(-15, 15), ylim=(-20, 5))
ax.set(xticks=[], yticks=[])
Z = cbook.get_sample_data("axes_grid/bivariate_normal.npy")
extent = (-3, 4, -4, 3)
axins = zoomed_inset_axes(ax, zoom=2, loc="upper left")
axins.set(xticks=[], yticks=[])
axins.imshow(Z, extent=extent, origin="lower")
def test_grid_with_axes_class_not_overriding_axis():
Grid(plt.figure(), 111, (2, 2), axes_class=mpl.axes.Axes)
RGBAxes(plt.figure(), 111, axes_class=mpl.axes.Axes)

View File

@ -0,0 +1,14 @@
from .axislines import Axes
from .axislines import ( # noqa: F401
AxesZero, AxisArtistHelper, AxisArtistHelperRectlinear,
GridHelperBase, GridHelperRectlinear, Subplot, SubplotZero)
from .axis_artist import AxisArtist, GridlinesCollection # noqa: F401
from .grid_helper_curvelinear import GridHelperCurveLinear # noqa: F401
from .floating_axes import FloatingAxes, FloatingSubplot # noqa: F401
from mpl_toolkits.axes_grid1.parasite_axes import (
host_axes_class_factory, parasite_axes_class_factory)
ParasiteAxes = parasite_axes_class_factory(Axes)
HostAxes = host_axes_class_factory(Axes)
SubplotHost = HostAxes

View File

@ -0,0 +1,394 @@
import numpy as np
import math
from mpl_toolkits.axisartist.grid_finder import ExtremeFinderSimple
def select_step_degree(dv):
degree_limits_ = [1.5, 3, 7, 13, 20, 40, 70, 120, 270, 520]
degree_steps_ = [1, 2, 5, 10, 15, 30, 45, 90, 180, 360]
degree_factors = [1.] * len(degree_steps_)
minsec_limits_ = [1.5, 2.5, 3.5, 8, 11, 18, 25, 45]
minsec_steps_ = [1, 2, 3, 5, 10, 15, 20, 30]
minute_limits_ = np.array(minsec_limits_) / 60
minute_factors = [60.] * len(minute_limits_)
second_limits_ = np.array(minsec_limits_) / 3600
second_factors = [3600.] * len(second_limits_)
degree_limits = [*second_limits_, *minute_limits_, *degree_limits_]
degree_steps = [*minsec_steps_, *minsec_steps_, *degree_steps_]
degree_factors = [*second_factors, *minute_factors, *degree_factors]
n = np.searchsorted(degree_limits, dv)
step = degree_steps[n]
factor = degree_factors[n]
return step, factor
def select_step_hour(dv):
hour_limits_ = [1.5, 2.5, 3.5, 5, 7, 10, 15, 21, 36]
hour_steps_ = [1, 2, 3, 4, 6, 8, 12, 18, 24]
hour_factors = [1.] * len(hour_steps_)
minsec_limits_ = [1.5, 2.5, 3.5, 4.5, 5.5, 8, 11, 14, 18, 25, 45]
minsec_steps_ = [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30]
minute_limits_ = np.array(minsec_limits_) / 60
minute_factors = [60.] * len(minute_limits_)
second_limits_ = np.array(minsec_limits_) / 3600
second_factors = [3600.] * len(second_limits_)
hour_limits = [*second_limits_, *minute_limits_, *hour_limits_]
hour_steps = [*minsec_steps_, *minsec_steps_, *hour_steps_]
hour_factors = [*second_factors, *minute_factors, *hour_factors]
n = np.searchsorted(hour_limits, dv)
step = hour_steps[n]
factor = hour_factors[n]
return step, factor
def select_step_sub(dv):
# subarcsec or degree
tmp = 10.**(int(math.log10(dv))-1.)
factor = 1./tmp
if 1.5*tmp >= dv:
step = 1
elif 3.*tmp >= dv:
step = 2
elif 7.*tmp >= dv:
step = 5
else:
step = 1
factor = 0.1*factor
return step, factor
def select_step(v1, v2, nv, hour=False, include_last=True,
threshold_factor=3600.):
if v1 > v2:
v1, v2 = v2, v1
dv = (v2 - v1) / nv
if hour:
_select_step = select_step_hour
cycle = 24.
else:
_select_step = select_step_degree
cycle = 360.
# for degree
if dv > 1 / threshold_factor:
step, factor = _select_step(dv)
else:
step, factor = select_step_sub(dv*threshold_factor)
factor = factor * threshold_factor
levs = np.arange(np.floor(v1 * factor / step),
np.ceil(v2 * factor / step) + 0.5,
dtype=int) * step
# n : number of valid levels. If there is a cycle, e.g., [0, 90, 180,
# 270, 360], the grid line needs to be extended from 0 to 360, so
# we need to return the whole array. However, the last level (360)
# needs to be ignored often. In this case, so we return n=4.
n = len(levs)
# we need to check the range of values
# for example, -90 to 90, 0 to 360,
if factor == 1. and levs[-1] >= levs[0] + cycle: # check for cycle
nv = int(cycle / step)
if include_last:
levs = levs[0] + np.arange(0, nv+1, 1) * step
else:
levs = levs[0] + np.arange(0, nv, 1) * step
n = len(levs)
return np.array(levs), n, factor
def select_step24(v1, v2, nv, include_last=True, threshold_factor=3600):
v1, v2 = v1 / 15, v2 / 15
levs, n, factor = select_step(v1, v2, nv, hour=True,
include_last=include_last,
threshold_factor=threshold_factor)
return levs * 15, n, factor
def select_step360(v1, v2, nv, include_last=True, threshold_factor=3600):
return select_step(v1, v2, nv, hour=False,
include_last=include_last,
threshold_factor=threshold_factor)
class LocatorBase:
def __init__(self, nbins, include_last=True):
self.nbins = nbins
self._include_last = include_last
def set_params(self, nbins=None):
if nbins is not None:
self.nbins = int(nbins)
class LocatorHMS(LocatorBase):
def __call__(self, v1, v2):
return select_step24(v1, v2, self.nbins, self._include_last)
class LocatorHM(LocatorBase):
def __call__(self, v1, v2):
return select_step24(v1, v2, self.nbins, self._include_last,
threshold_factor=60)
class LocatorH(LocatorBase):
def __call__(self, v1, v2):
return select_step24(v1, v2, self.nbins, self._include_last,
threshold_factor=1)
class LocatorDMS(LocatorBase):
def __call__(self, v1, v2):
return select_step360(v1, v2, self.nbins, self._include_last)
class LocatorDM(LocatorBase):
def __call__(self, v1, v2):
return select_step360(v1, v2, self.nbins, self._include_last,
threshold_factor=60)
class LocatorD(LocatorBase):
def __call__(self, v1, v2):
return select_step360(v1, v2, self.nbins, self._include_last,
threshold_factor=1)
class FormatterDMS:
deg_mark = r"^{\circ}"
min_mark = r"^{\prime}"
sec_mark = r"^{\prime\prime}"
fmt_d = "$%d" + deg_mark + "$"
fmt_ds = r"$%d.%s" + deg_mark + "$"
# %s for sign
fmt_d_m = r"$%s%d" + deg_mark + r"\,%02d" + min_mark + "$"
fmt_d_ms = r"$%s%d" + deg_mark + r"\,%02d.%s" + min_mark + "$"
fmt_d_m_partial = "$%s%d" + deg_mark + r"\,%02d" + min_mark + r"\,"
fmt_s_partial = "%02d" + sec_mark + "$"
fmt_ss_partial = "%02d.%s" + sec_mark + "$"
def _get_number_fraction(self, factor):
## check for fractional numbers
number_fraction = None
# check for 60
for threshold in [1, 60, 3600]:
if factor <= threshold:
break
d = factor // threshold
int_log_d = int(np.floor(np.log10(d)))
if 10**int_log_d == d and d != 1:
number_fraction = int_log_d
factor = factor // 10**int_log_d
return factor, number_fraction
return factor, number_fraction
def __call__(self, direction, factor, values):
if len(values) == 0:
return []
ss = np.sign(values)
signs = ["-" if v < 0 else "" for v in values]
factor, number_fraction = self._get_number_fraction(factor)
values = np.abs(values)
if number_fraction is not None:
values, frac_part = divmod(values, 10 ** number_fraction)
frac_fmt = "%%0%dd" % (number_fraction,)
frac_str = [frac_fmt % (f1,) for f1 in frac_part]
if factor == 1:
if number_fraction is None:
return [self.fmt_d % (s * int(v),) for s, v in zip(ss, values)]
else:
return [self.fmt_ds % (s * int(v), f1)
for s, v, f1 in zip(ss, values, frac_str)]
elif factor == 60:
deg_part, min_part = divmod(values, 60)
if number_fraction is None:
return [self.fmt_d_m % (s1, d1, m1)
for s1, d1, m1 in zip(signs, deg_part, min_part)]
else:
return [self.fmt_d_ms % (s, d1, m1, f1)
for s, d1, m1, f1
in zip(signs, deg_part, min_part, frac_str)]
elif factor == 3600:
if ss[-1] == -1:
inverse_order = True
values = values[::-1]
signs = signs[::-1]
else:
inverse_order = False
l_hm_old = ""
r = []
deg_part, min_part_ = divmod(values, 3600)
min_part, sec_part = divmod(min_part_, 60)
if number_fraction is None:
sec_str = [self.fmt_s_partial % (s1,) for s1 in sec_part]
else:
sec_str = [self.fmt_ss_partial % (s1, f1)
for s1, f1 in zip(sec_part, frac_str)]
for s, d1, m1, s1 in zip(signs, deg_part, min_part, sec_str):
l_hm = self.fmt_d_m_partial % (s, d1, m1)
if l_hm != l_hm_old:
l_hm_old = l_hm
l = l_hm + s1
else:
l = "$" + s + s1
r.append(l)
if inverse_order:
return r[::-1]
else:
return r
else: # factor > 3600.
return [r"$%s^{\circ}$" % v for v in ss*values]
class FormatterHMS(FormatterDMS):
deg_mark = r"^\mathrm{h}"
min_mark = r"^\mathrm{m}"
sec_mark = r"^\mathrm{s}"
fmt_d = "$%d" + deg_mark + "$"
fmt_ds = r"$%d.%s" + deg_mark + "$"
# %s for sign
fmt_d_m = r"$%s%d" + deg_mark + r"\,%02d" + min_mark+"$"
fmt_d_ms = r"$%s%d" + deg_mark + r"\,%02d.%s" + min_mark+"$"
fmt_d_m_partial = "$%s%d" + deg_mark + r"\,%02d" + min_mark + r"\,"
fmt_s_partial = "%02d" + sec_mark + "$"
fmt_ss_partial = "%02d.%s" + sec_mark + "$"
def __call__(self, direction, factor, values): # hour
return super().__call__(direction, factor, np.asarray(values) / 15)
class ExtremeFinderCycle(ExtremeFinderSimple):
# docstring inherited
def __init__(self, nx, ny,
lon_cycle=360., lat_cycle=None,
lon_minmax=None, lat_minmax=(-90, 90)):
"""
This subclass handles the case where one or both coordinates should be
taken modulo 360, or be restricted to not exceed a specific range.
Parameters
----------
nx, ny : int
The number of samples in each direction.
lon_cycle, lat_cycle : 360 or None
If not None, values in the corresponding direction are taken modulo
*lon_cycle* or *lat_cycle*; in theory this can be any number but
the implementation actually assumes that it is 360 (if not None);
other values give nonsensical results.
This is done by "unwrapping" the transformed grid coordinates so
that jumps are less than a half-cycle; then normalizing the span to
no more than a full cycle.
For example, if values are in the union of the [0, 2] and
[358, 360] intervals (typically, angles measured modulo 360), the
values in the second interval are normalized to [-2, 0] instead so
that the values now cover [-2, 2]. If values are in a range of
[5, 1000], this gets normalized to [5, 365].
lon_minmax, lat_minmax : (float, float) or None
If not None, the computed bounding box is clipped to the given
range in the corresponding direction.
"""
self.nx, self.ny = nx, ny
self.lon_cycle, self.lat_cycle = lon_cycle, lat_cycle
self.lon_minmax = lon_minmax
self.lat_minmax = lat_minmax
def __call__(self, transform_xy, x1, y1, x2, y2):
# docstring inherited
x, y = np.meshgrid(
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
lon, lat = transform_xy(np.ravel(x), np.ravel(y))
# iron out jumps, but algorithm should be improved.
# This is just naive way of doing and my fail for some cases.
# Consider replacing this with numpy.unwrap
# We are ignoring invalid warnings. They are triggered when
# comparing arrays with NaNs using > We are already handling
# that correctly using np.nanmin and np.nanmax
with np.errstate(invalid='ignore'):
if self.lon_cycle is not None:
lon0 = np.nanmin(lon)
lon -= 360. * ((lon - lon0) > 180.)
if self.lat_cycle is not None:
lat0 = np.nanmin(lat)
lat -= 360. * ((lat - lat0) > 180.)
lon_min, lon_max = np.nanmin(lon), np.nanmax(lon)
lat_min, lat_max = np.nanmin(lat), np.nanmax(lat)
lon_min, lon_max, lat_min, lat_max = \
self._add_pad(lon_min, lon_max, lat_min, lat_max)
# check cycle
if self.lon_cycle:
lon_max = min(lon_max, lon_min + self.lon_cycle)
if self.lat_cycle:
lat_max = min(lat_max, lat_min + self.lat_cycle)
if self.lon_minmax is not None:
min0 = self.lon_minmax[0]
lon_min = max(min0, lon_min)
max0 = self.lon_minmax[1]
lon_max = min(max0, lon_max)
if self.lat_minmax is not None:
min0 = self.lat_minmax[0]
lat_min = max(min0, lat_min)
max0 = self.lat_minmax[1]
lat_max = min(max0, lat_max)
return lon_min, lon_max, lat_min, lat_max

View File

@ -0,0 +1,2 @@
from mpl_toolkits.axes_grid1.axes_divider import ( # noqa
Divider, AxesLocator, SubplotDivider, AxesDivider, make_axes_locatable)

View File

@ -0,0 +1,23 @@
from matplotlib import _api
import mpl_toolkits.axes_grid1.axes_grid as axes_grid_orig
from .axislines import Axes
_api.warn_deprecated(
"3.8", name=__name__, obj_type="module", alternative="axes_grid1.axes_grid")
@_api.deprecated("3.8", alternative=(
"axes_grid1.axes_grid.Grid(..., axes_class=axislines.Axes"))
class Grid(axes_grid_orig.Grid):
_defaultAxesClass = Axes
@_api.deprecated("3.8", alternative=(
"axes_grid1.axes_grid.ImageGrid(..., axes_class=axislines.Axes"))
class ImageGrid(axes_grid_orig.ImageGrid):
_defaultAxesClass = Axes
AxesGrid = ImageGrid

View File

@ -0,0 +1,18 @@
from matplotlib import _api
from mpl_toolkits.axes_grid1.axes_rgb import ( # noqa
make_rgb_axes, RGBAxes as _RGBAxes)
from .axislines import Axes
_api.warn_deprecated(
"3.8", name=__name__, obj_type="module", alternative="axes_grid1.axes_rgb")
@_api.deprecated("3.8", alternative=(
"axes_grid1.axes_rgb.RGBAxes(..., axes_class=axislines.Axes"))
class RGBAxes(_RGBAxes):
"""
Subclass of `~.axes_grid1.axes_rgb.RGBAxes` with
``_defaultAxesClass`` = `.axislines.Axes`.
"""
_defaultAxesClass = Axes

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,193 @@
"""
Provides classes to style the axis lines.
"""
import math
import numpy as np
import matplotlib as mpl
from matplotlib.patches import _Style, FancyArrowPatch
from matplotlib.path import Path
from matplotlib.transforms import IdentityTransform
class _FancyAxislineStyle:
class SimpleArrow(FancyArrowPatch):
"""The artist class that will be returned for SimpleArrow style."""
_ARROW_STYLE = "->"
def __init__(self, axis_artist, line_path, transform,
line_mutation_scale):
self._axis_artist = axis_artist
self._line_transform = transform
self._line_path = line_path
self._line_mutation_scale = line_mutation_scale
FancyArrowPatch.__init__(self,
path=self._line_path,
arrowstyle=self._ARROW_STYLE,
patchA=None,
patchB=None,
shrinkA=0.,
shrinkB=0.,
mutation_scale=line_mutation_scale,
mutation_aspect=None,
transform=IdentityTransform(),
)
def set_line_mutation_scale(self, scale):
self.set_mutation_scale(scale*self._line_mutation_scale)
def _extend_path(self, path, mutation_size=10):
"""
Extend the path to make a room for drawing arrow.
"""
(x0, y0), (x1, y1) = path.vertices[-2:]
theta = math.atan2(y1 - y0, x1 - x0)
x2 = x1 + math.cos(theta) * mutation_size
y2 = y1 + math.sin(theta) * mutation_size
if path.codes is None:
return Path(np.concatenate([path.vertices, [[x2, y2]]]))
else:
return Path(np.concatenate([path.vertices, [[x2, y2]]]),
np.concatenate([path.codes, [Path.LINETO]]))
def set_path(self, path):
self._line_path = path
def draw(self, renderer):
"""
Draw the axis line.
1) Transform the path to the display coordinate.
2) Extend the path to make a room for arrow.
3) Update the path of the FancyArrowPatch.
4) Draw.
"""
path_in_disp = self._line_transform.transform_path(self._line_path)
mutation_size = self.get_mutation_scale() # line_mutation_scale()
extended_path = self._extend_path(path_in_disp,
mutation_size=mutation_size)
self._path_original = extended_path
FancyArrowPatch.draw(self, renderer)
def get_window_extent(self, renderer=None):
path_in_disp = self._line_transform.transform_path(self._line_path)
mutation_size = self.get_mutation_scale() # line_mutation_scale()
extended_path = self._extend_path(path_in_disp,
mutation_size=mutation_size)
self._path_original = extended_path
return FancyArrowPatch.get_window_extent(self, renderer)
class FilledArrow(SimpleArrow):
"""The artist class that will be returned for FilledArrow style."""
_ARROW_STYLE = "-|>"
def __init__(self, axis_artist, line_path, transform,
line_mutation_scale, facecolor):
super().__init__(axis_artist, line_path, transform,
line_mutation_scale)
self.set_facecolor(facecolor)
class AxislineStyle(_Style):
"""
A container class which defines style classes for AxisArtists.
An instance of any axisline style class is a callable object,
whose call signature is ::
__call__(self, axis_artist, path, transform)
When called, this should return an `.Artist` with the following methods::
def set_path(self, path):
# set the path for axisline.
def set_line_mutation_scale(self, scale):
# set the scale
def draw(self, renderer):
# draw
"""
_style_list = {}
class _Base:
# The derived classes are required to be able to be initialized
# w/o arguments, i.e., all its argument (except self) must have
# the default values.
def __init__(self):
"""
initialization.
"""
super().__init__()
def __call__(self, axis_artist, transform):
"""
Given the AxisArtist instance, and transform for the path (set_path
method), return the Matplotlib artist for drawing the axis line.
"""
return self.new_line(axis_artist, transform)
class SimpleArrow(_Base):
"""
A simple arrow.
"""
ArrowAxisClass = _FancyAxislineStyle.SimpleArrow
def __init__(self, size=1):
"""
Parameters
----------
size : float
Size of the arrow as a fraction of the ticklabel size.
"""
self.size = size
super().__init__()
def new_line(self, axis_artist, transform):
linepath = Path([(0, 0), (0, 1)])
axisline = self.ArrowAxisClass(axis_artist, linepath, transform,
line_mutation_scale=self.size)
return axisline
_style_list["->"] = SimpleArrow
class FilledArrow(SimpleArrow):
"""
An arrow with a filled head.
"""
ArrowAxisClass = _FancyAxislineStyle.FilledArrow
def __init__(self, size=1, facecolor=None):
"""
Parameters
----------
size : float
Size of the arrow as a fraction of the ticklabel size.
facecolor : :mpltype:`color`, default: :rc:`axes.edgecolor`
Fill color.
.. versionadded:: 3.7
"""
if facecolor is None:
facecolor = mpl.rcParams['axes.edgecolor']
self.size = size
self._facecolor = facecolor
super().__init__(size=size)
def new_line(self, axis_artist, transform):
linepath = Path([(0, 0), (0, 1)])
axisline = self.ArrowAxisClass(axis_artist, linepath, transform,
line_mutation_scale=self.size,
facecolor=self._facecolor)
return axisline
_style_list["-|>"] = FilledArrow

View File

@ -0,0 +1,483 @@
"""
Axislines includes modified implementation of the Axes class. The
biggest difference is that the artists responsible for drawing the axis spine,
ticks, ticklabels and axis labels are separated out from Matplotlib's Axis
class. Originally, this change was motivated to support curvilinear
grid. Here are a few reasons that I came up with a new axes class:
* "top" and "bottom" x-axis (or "left" and "right" y-axis) can have
different ticks (tick locations and labels). This is not possible
with the current Matplotlib, although some twin axes trick can help.
* Curvilinear grid.
* angled ticks.
In the new axes class, xaxis and yaxis is set to not visible by
default, and new set of artist (AxisArtist) are defined to draw axis
line, ticks, ticklabels and axis label. Axes.axis attribute serves as
a dictionary of these artists, i.e., ax.axis["left"] is a AxisArtist
instance responsible to draw left y-axis. The default Axes.axis contains
"bottom", "left", "top" and "right".
AxisArtist can be considered as a container artist and has the following
children artists which will draw ticks, labels, etc.
* line
* major_ticks, major_ticklabels
* minor_ticks, minor_ticklabels
* offsetText
* label
Note that these are separate artists from `matplotlib.axis.Axis`, thus most
tick-related functions in Matplotlib won't work. For example, color and
markerwidth of the ``ax.axis["bottom"].major_ticks`` will follow those of
Axes.xaxis unless explicitly specified.
In addition to AxisArtist, the Axes will have *gridlines* attribute,
which obviously draws grid lines. The gridlines needs to be separated
from the axis as some gridlines can never pass any axis.
"""
import numpy as np
import matplotlib as mpl
from matplotlib import _api
import matplotlib.axes as maxes
from matplotlib.path import Path
from mpl_toolkits.axes_grid1 import mpl_axes
from .axisline_style import AxislineStyle # noqa
from .axis_artist import AxisArtist, GridlinesCollection
class _AxisArtistHelperBase:
"""
Base class for axis helper.
Subclasses should define the methods listed below. The *axes*
argument will be the ``.axes`` attribute of the caller artist. ::
# Construct the spine.
def get_line_transform(self, axes):
return transform
def get_line(self, axes):
return path
# Construct the label.
def get_axislabel_transform(self, axes):
return transform
def get_axislabel_pos_angle(self, axes):
return (x, y), angle
# Construct the ticks.
def get_tick_transform(self, axes):
return transform
def get_tick_iterators(self, axes):
# A pair of iterables (one for major ticks, one for minor ticks)
# that yield (tick_position, tick_angle, tick_label).
return iter_major, iter_minor
"""
def __init__(self, nth_coord):
self.nth_coord = nth_coord
def update_lim(self, axes):
pass
def get_nth_coord(self):
return self.nth_coord
def _to_xy(self, values, const):
"""
Create a (*values.shape, 2)-shape array representing (x, y) pairs.
The other coordinate is filled with the constant *const*.
Example::
>>> self.nth_coord = 0
>>> self._to_xy([1, 2, 3], const=0)
array([[1, 0],
[2, 0],
[3, 0]])
"""
if self.nth_coord == 0:
return np.stack(np.broadcast_arrays(values, const), axis=-1)
elif self.nth_coord == 1:
return np.stack(np.broadcast_arrays(const, values), axis=-1)
else:
raise ValueError("Unexpected nth_coord")
class _FixedAxisArtistHelperBase(_AxisArtistHelperBase):
"""Helper class for a fixed (in the axes coordinate) axis."""
@_api.delete_parameter("3.9", "nth_coord")
def __init__(self, loc, nth_coord=None):
"""``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis."""
super().__init__(_api.check_getitem(
{"bottom": 0, "top": 0, "left": 1, "right": 1}, loc=loc))
self._loc = loc
self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc]
# axis line in transAxes
self._path = Path(self._to_xy((0, 1), const=self._pos))
# LINE
def get_line(self, axes):
return self._path
def get_line_transform(self, axes):
return axes.transAxes
# LABEL
def get_axislabel_transform(self, axes):
return axes.transAxes
def get_axislabel_pos_angle(self, axes):
"""
Return the label reference position in transAxes.
get_label_transform() returns a transform of (transAxes+offset)
"""
return dict(left=((0., 0.5), 90), # (position, angle_tangent)
right=((1., 0.5), 90),
bottom=((0.5, 0.), 0),
top=((0.5, 1.), 0))[self._loc]
# TICK
def get_tick_transform(self, axes):
return [axes.get_xaxis_transform(), axes.get_yaxis_transform()][self.nth_coord]
class _FloatingAxisArtistHelperBase(_AxisArtistHelperBase):
def __init__(self, nth_coord, value):
self._value = value
super().__init__(nth_coord)
def get_line(self, axes):
raise RuntimeError("get_line method should be defined by the derived class")
class FixedAxisArtistHelperRectilinear(_FixedAxisArtistHelperBase):
@_api.delete_parameter("3.9", "nth_coord")
def __init__(self, axes, loc, nth_coord=None):
"""
nth_coord = along which coordinate value varies
in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
"""
super().__init__(loc)
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
# TICK
def get_tick_iterators(self, axes):
"""tick_loc, tick_angle, tick_label"""
angle_normal, angle_tangent = {0: (90, 0), 1: (0, 90)}[self.nth_coord]
major = self.axis.major
major_locs = major.locator()
major_labels = major.formatter.format_ticks(major_locs)
minor = self.axis.minor
minor_locs = minor.locator()
minor_labels = minor.formatter.format_ticks(minor_locs)
tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
def _f(locs, labels):
for loc, label in zip(locs, labels):
c = self._to_xy(loc, const=self._pos)
# check if the tick point is inside axes
c2 = tick_to_axes.transform(c)
if mpl.transforms._interval_contains_close((0, 1), c2[self.nth_coord]):
yield c, angle_normal, angle_tangent, label
return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
class FloatingAxisArtistHelperRectilinear(_FloatingAxisArtistHelperBase):
def __init__(self, axes, nth_coord,
passingthrough_point, axis_direction="bottom"):
super().__init__(nth_coord, passingthrough_point)
self._axis_direction = axis_direction
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
def get_line(self, axes):
fixed_coord = 1 - self.nth_coord
data_to_axes = axes.transData - axes.transAxes
p = data_to_axes.transform([self._value, self._value])
return Path(self._to_xy((0, 1), const=p[fixed_coord]))
def get_line_transform(self, axes):
return axes.transAxes
def get_axislabel_transform(self, axes):
return axes.transAxes
def get_axislabel_pos_angle(self, axes):
"""
Return the label reference position in transAxes.
get_label_transform() returns a transform of (transAxes+offset)
"""
angle = [0, 90][self.nth_coord]
fixed_coord = 1 - self.nth_coord
data_to_axes = axes.transData - axes.transAxes
p = data_to_axes.transform([self._value, self._value])
verts = self._to_xy(0.5, const=p[fixed_coord])
return (verts, angle) if 0 <= verts[fixed_coord] <= 1 else (None, None)
def get_tick_transform(self, axes):
return axes.transData
def get_tick_iterators(self, axes):
"""tick_loc, tick_angle, tick_label"""
angle_normal, angle_tangent = {0: (90, 0), 1: (0, 90)}[self.nth_coord]
major = self.axis.major
major_locs = major.locator()
major_labels = major.formatter.format_ticks(major_locs)
minor = self.axis.minor
minor_locs = minor.locator()
minor_labels = minor.formatter.format_ticks(minor_locs)
data_to_axes = axes.transData - axes.transAxes
def _f(locs, labels):
for loc, label in zip(locs, labels):
c = self._to_xy(loc, const=self._value)
c1, c2 = data_to_axes.transform(c)
if 0 <= c1 <= 1 and 0 <= c2 <= 1:
yield c, angle_normal, angle_tangent, label
return _f(major_locs, major_labels), _f(minor_locs, minor_labels)
class AxisArtistHelper: # Backcompat.
Fixed = _FixedAxisArtistHelperBase
Floating = _FloatingAxisArtistHelperBase
class AxisArtistHelperRectlinear: # Backcompat.
Fixed = FixedAxisArtistHelperRectilinear
Floating = FloatingAxisArtistHelperRectilinear
class GridHelperBase:
def __init__(self):
self._old_limits = None
super().__init__()
def update_lim(self, axes):
x1, x2 = axes.get_xlim()
y1, y2 = axes.get_ylim()
if self._old_limits != (x1, x2, y1, y2):
self._update_grid(x1, y1, x2, y2)
self._old_limits = (x1, x2, y1, y2)
def _update_grid(self, x1, y1, x2, y2):
"""Cache relevant computations when the axes limits have changed."""
def get_gridlines(self, which, axis):
"""
Return list of grid lines as a list of paths (list of points).
Parameters
----------
which : {"both", "major", "minor"}
axis : {"both", "x", "y"}
"""
return []
class GridHelperRectlinear(GridHelperBase):
def __init__(self, axes):
super().__init__()
self.axes = axes
@_api.delete_parameter(
"3.9", "nth_coord", addendum="'nth_coord' is now inferred from 'loc'.")
def new_fixed_axis(
self, loc, nth_coord=None, axis_direction=None, offset=None, axes=None):
if axes is None:
_api.warn_external(
"'new_fixed_axis' explicitly requires the axes keyword.")
axes = self.axes
if axis_direction is None:
axis_direction = loc
return AxisArtist(axes, FixedAxisArtistHelperRectilinear(axes, loc),
offset=offset, axis_direction=axis_direction)
def new_floating_axis(self, nth_coord, value, axis_direction="bottom", axes=None):
if axes is None:
_api.warn_external(
"'new_floating_axis' explicitly requires the axes keyword.")
axes = self.axes
helper = FloatingAxisArtistHelperRectilinear(
axes, nth_coord, value, axis_direction)
axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
axisline.line.set_clip_on(True)
axisline.line.set_clip_box(axisline.axes.bbox)
return axisline
def get_gridlines(self, which="major", axis="both"):
"""
Return list of gridline coordinates in data coordinates.
Parameters
----------
which : {"both", "major", "minor"}
axis : {"both", "x", "y"}
"""
_api.check_in_list(["both", "major", "minor"], which=which)
_api.check_in_list(["both", "x", "y"], axis=axis)
gridlines = []
if axis in ("both", "x"):
locs = []
y1, y2 = self.axes.get_ylim()
if which in ("both", "major"):
locs.extend(self.axes.xaxis.major.locator())
if which in ("both", "minor"):
locs.extend(self.axes.xaxis.minor.locator())
gridlines.extend([[x, x], [y1, y2]] for x in locs)
if axis in ("both", "y"):
x1, x2 = self.axes.get_xlim()
locs = []
if self.axes.yaxis._major_tick_kw["gridOn"]:
locs.extend(self.axes.yaxis.major.locator())
if self.axes.yaxis._minor_tick_kw["gridOn"]:
locs.extend(self.axes.yaxis.minor.locator())
gridlines.extend([[x1, x2], [y, y]] for y in locs)
return gridlines
class Axes(maxes.Axes):
@_api.deprecated("3.8", alternative="ax.axis")
def __call__(self, *args, **kwargs):
return maxes.Axes.axis(self.axes, *args, **kwargs)
def __init__(self, *args, grid_helper=None, **kwargs):
self._axisline_on = True
self._grid_helper = grid_helper if grid_helper else GridHelperRectlinear(self)
super().__init__(*args, **kwargs)
self.toggle_axisline(True)
def toggle_axisline(self, b=None):
if b is None:
b = not self._axisline_on
if b:
self._axisline_on = True
self.spines[:].set_visible(False)
self.xaxis.set_visible(False)
self.yaxis.set_visible(False)
else:
self._axisline_on = False
self.spines[:].set_visible(True)
self.xaxis.set_visible(True)
self.yaxis.set_visible(True)
@property
def axis(self):
return self._axislines
def clear(self):
# docstring inherited
# Init gridlines before clear() as clear() calls grid().
self.gridlines = gridlines = GridlinesCollection(
[],
colors=mpl.rcParams['grid.color'],
linestyles=mpl.rcParams['grid.linestyle'],
linewidths=mpl.rcParams['grid.linewidth'])
self._set_artist_props(gridlines)
gridlines.set_grid_helper(self.get_grid_helper())
super().clear()
# clip_path is set after Axes.clear(): that's when a patch is created.
gridlines.set_clip_path(self.axes.patch)
# Init axis artists.
self._axislines = mpl_axes.Axes.AxisDict(self)
new_fixed_axis = self.get_grid_helper().new_fixed_axis
self._axislines.update({
loc: new_fixed_axis(loc=loc, axes=self, axis_direction=loc)
for loc in ["bottom", "top", "left", "right"]})
for axisline in [self._axislines["top"], self._axislines["right"]]:
axisline.label.set_visible(False)
axisline.major_ticklabels.set_visible(False)
axisline.minor_ticklabels.set_visible(False)
def get_grid_helper(self):
return self._grid_helper
def grid(self, visible=None, which='major', axis="both", **kwargs):
"""
Toggle the gridlines, and optionally set the properties of the lines.
"""
# There are some discrepancies in the behavior of grid() between
# axes_grid and Matplotlib, because axes_grid explicitly sets the
# visibility of the gridlines.
super().grid(visible, which=which, axis=axis, **kwargs)
if not self._axisline_on:
return
if visible is None:
visible = (self.axes.xaxis._minor_tick_kw["gridOn"]
or self.axes.xaxis._major_tick_kw["gridOn"]
or self.axes.yaxis._minor_tick_kw["gridOn"]
or self.axes.yaxis._major_tick_kw["gridOn"])
self.gridlines.set(which=which, axis=axis, visible=visible)
self.gridlines.set(**kwargs)
def get_children(self):
if self._axisline_on:
children = [*self._axislines.values(), self.gridlines]
else:
children = []
children.extend(super().get_children())
return children
def new_fixed_axis(self, loc, offset=None):
return self.get_grid_helper().new_fixed_axis(loc, offset=offset, axes=self)
def new_floating_axis(self, nth_coord, value, axis_direction="bottom"):
return self.get_grid_helper().new_floating_axis(
nth_coord, value, axis_direction=axis_direction, axes=self)
class AxesZero(Axes):
def clear(self):
super().clear()
new_floating_axis = self.get_grid_helper().new_floating_axis
self._axislines.update(
xzero=new_floating_axis(
nth_coord=0, value=0., axis_direction="bottom", axes=self),
yzero=new_floating_axis(
nth_coord=1, value=0., axis_direction="left", axes=self),
)
for k in ["xzero", "yzero"]:
self._axislines[k].line.set_clip_path(self.patch)
self._axislines[k].set_visible(False)
Subplot = Axes
SubplotZero = AxesZero

View File

@ -0,0 +1,286 @@
"""
An experimental support for curvilinear grid.
"""
# TODO :
# see if tick_iterator method can be simplified by reusing the parent method.
import functools
import numpy as np
import matplotlib as mpl
from matplotlib import _api, cbook
import matplotlib.patches as mpatches
from matplotlib.path import Path
from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
from . import axislines, grid_helper_curvelinear
from .axis_artist import AxisArtist
from .grid_finder import ExtremeFinderSimple
class FloatingAxisArtistHelper(
grid_helper_curvelinear.FloatingAxisArtistHelper):
pass
class FixedAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper):
def __init__(self, grid_helper, side, nth_coord_ticks=None):
"""
nth_coord = along which coordinate value varies.
nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
"""
lon1, lon2, lat1, lat2 = grid_helper.grid_finder.extreme_finder(*[None] * 5)
value, nth_coord = _api.check_getitem(
dict(left=(lon1, 0), right=(lon2, 0), bottom=(lat1, 1), top=(lat2, 1)),
side=side)
super().__init__(grid_helper, nth_coord, value, axis_direction=side)
if nth_coord_ticks is None:
nth_coord_ticks = nth_coord
self.nth_coord_ticks = nth_coord_ticks
self.value = value
self.grid_helper = grid_helper
self._side = side
def update_lim(self, axes):
self.grid_helper.update_lim(axes)
self._grid_info = self.grid_helper._grid_info
def get_tick_iterators(self, axes):
"""tick_loc, tick_angle, tick_label, (optionally) tick_label"""
grid_finder = self.grid_helper.grid_finder
lat_levs, lat_n, lat_factor = self._grid_info["lat_info"]
yy0 = lat_levs / lat_factor
lon_levs, lon_n, lon_factor = self._grid_info["lon_info"]
xx0 = lon_levs / lon_factor
extremes = self.grid_helper.grid_finder.extreme_finder(*[None] * 5)
xmin, xmax = sorted(extremes[:2])
ymin, ymax = sorted(extremes[2:])
def trf_xy(x, y):
trf = grid_finder.get_transform() + axes.transData
return trf.transform(np.column_stack(np.broadcast_arrays(x, y))).T
if self.nth_coord == 0:
mask = (ymin <= yy0) & (yy0 <= ymax)
(xx1, yy1), (dxx1, dyy1), (dxx2, dyy2) = \
grid_helper_curvelinear._value_and_jacobian(
trf_xy, self.value, yy0[mask], (xmin, xmax), (ymin, ymax))
labels = self._grid_info["lat_labels"]
elif self.nth_coord == 1:
mask = (xmin <= xx0) & (xx0 <= xmax)
(xx1, yy1), (dxx2, dyy2), (dxx1, dyy1) = \
grid_helper_curvelinear._value_and_jacobian(
trf_xy, xx0[mask], self.value, (xmin, xmax), (ymin, ymax))
labels = self._grid_info["lon_labels"]
labels = [l for l, m in zip(labels, mask) if m]
angle_normal = np.arctan2(dyy1, dxx1)
angle_tangent = np.arctan2(dyy2, dxx2)
mm = (dyy1 == 0) & (dxx1 == 0) # points with degenerate normal
angle_normal[mm] = angle_tangent[mm] + np.pi / 2
tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
in_01 = functools.partial(
mpl.transforms._interval_contains_close, (0, 1))
def f1():
for x, y, normal, tangent, lab \
in zip(xx1, yy1, angle_normal, angle_tangent, labels):
c2 = tick_to_axes.transform((x, y))
if in_01(c2[0]) and in_01(c2[1]):
yield [x, y], *np.rad2deg([normal, tangent]), lab
return f1(), iter([])
def get_line(self, axes):
self.update_lim(axes)
k, v = dict(left=("lon_lines0", 0),
right=("lon_lines0", 1),
bottom=("lat_lines0", 0),
top=("lat_lines0", 1))[self._side]
xx, yy = self._grid_info[k][v]
return Path(np.column_stack([xx, yy]))
class ExtremeFinderFixed(ExtremeFinderSimple):
# docstring inherited
def __init__(self, extremes):
"""
This subclass always returns the same bounding box.
Parameters
----------
extremes : (float, float, float, float)
The bounding box that this helper always returns.
"""
self._extremes = extremes
def __call__(self, transform_xy, x1, y1, x2, y2):
# docstring inherited
return self._extremes
class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
def __init__(self, aux_trans, extremes,
grid_locator1=None,
grid_locator2=None,
tick_formatter1=None,
tick_formatter2=None):
# docstring inherited
super().__init__(aux_trans,
extreme_finder=ExtremeFinderFixed(extremes),
grid_locator1=grid_locator1,
grid_locator2=grid_locator2,
tick_formatter1=tick_formatter1,
tick_formatter2=tick_formatter2)
@_api.deprecated("3.8")
def get_data_boundary(self, side):
"""
Return v=0, nth=1.
"""
lon1, lon2, lat1, lat2 = self.grid_finder.extreme_finder(*[None] * 5)
return dict(left=(lon1, 0),
right=(lon2, 0),
bottom=(lat1, 1),
top=(lat2, 1))[side]
def new_fixed_axis(
self, loc, nth_coord=None, axis_direction=None, offset=None, axes=None):
if axes is None:
axes = self.axes
if axis_direction is None:
axis_direction = loc
# This is not the same as the FixedAxisArtistHelper class used by
# grid_helper_curvelinear.GridHelperCurveLinear.new_fixed_axis!
helper = FixedAxisArtistHelper(
self, loc, nth_coord_ticks=nth_coord)
axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
# Perhaps should be moved to the base class?
axisline.line.set_clip_on(True)
axisline.line.set_clip_box(axisline.axes.bbox)
return axisline
# new_floating_axis will inherit the grid_helper's extremes.
# def new_floating_axis(self, nth_coord, value, axes=None, axis_direction="bottom"):
# axis = super(GridHelperCurveLinear,
# self).new_floating_axis(nth_coord,
# value, axes=axes,
# axis_direction=axis_direction)
# # set extreme values of the axis helper
# if nth_coord == 1:
# axis.get_helper().set_extremes(*self._extremes[:2])
# elif nth_coord == 0:
# axis.get_helper().set_extremes(*self._extremes[2:])
# return axis
def _update_grid(self, x1, y1, x2, y2):
if self._grid_info is None:
self._grid_info = dict()
grid_info = self._grid_info
grid_finder = self.grid_finder
extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
x1, y1, x2, y2)
lon_min, lon_max = sorted(extremes[:2])
lat_min, lat_max = sorted(extremes[2:])
grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max # extremes
lon_levs, lon_n, lon_factor = \
grid_finder.grid_locator1(lon_min, lon_max)
lon_levs = np.asarray(lon_levs)
lat_levs, lat_n, lat_factor = \
grid_finder.grid_locator2(lat_min, lat_max)
lat_levs = np.asarray(lat_levs)
grid_info["lon_info"] = lon_levs, lon_n, lon_factor
grid_info["lat_info"] = lat_levs, lat_n, lat_factor
grid_info["lon_labels"] = grid_finder._format_ticks(
1, "bottom", lon_factor, lon_levs)
grid_info["lat_labels"] = grid_finder._format_ticks(
2, "bottom", lat_factor, lat_levs)
lon_values = lon_levs[:lon_n] / lon_factor
lat_values = lat_levs[:lat_n] / lat_factor
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
lon_values[(lon_min < lon_values) & (lon_values < lon_max)],
lat_values[(lat_min < lat_values) & (lat_values < lat_max)],
lon_min, lon_max, lat_min, lat_max)
grid_info["lon_lines"] = lon_lines
grid_info["lat_lines"] = lat_lines
lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
# lon_min, lon_max, lat_min, lat_max)
extremes[:2], extremes[2:], *extremes)
grid_info["lon_lines0"] = lon_lines
grid_info["lat_lines0"] = lat_lines
def get_gridlines(self, which="major", axis="both"):
grid_lines = []
if axis in ["both", "x"]:
grid_lines.extend(self._grid_info["lon_lines"])
if axis in ["both", "y"]:
grid_lines.extend(self._grid_info["lat_lines"])
return grid_lines
class FloatingAxesBase:
def __init__(self, *args, grid_helper, **kwargs):
_api.check_isinstance(GridHelperCurveLinear, grid_helper=grid_helper)
super().__init__(*args, grid_helper=grid_helper, **kwargs)
self.set_aspect(1.)
def _gen_axes_patch(self):
# docstring inherited
x0, x1, y0, y1 = self.get_grid_helper().grid_finder.extreme_finder(*[None] * 5)
patch = mpatches.Polygon([(x0, y0), (x1, y0), (x1, y1), (x0, y1)])
patch.get_path()._interpolation_steps = 100
return patch
def clear(self):
super().clear()
self.patch.set_transform(
self.get_grid_helper().grid_finder.get_transform()
+ self.transData)
# The original patch is not in the draw tree; it is only used for
# clipping purposes.
orig_patch = super()._gen_axes_patch()
orig_patch.set_figure(self.figure)
orig_patch.set_transform(self.transAxes)
self.patch.set_clip_path(orig_patch)
self.gridlines.set_clip_path(orig_patch)
self.adjust_axes_lim()
def adjust_axes_lim(self):
bbox = self.patch.get_path().get_extents(
# First transform to pixel coords, then to parent data coords.
self.patch.get_transform() - self.transData)
bbox = bbox.expanded(1.02, 1.02)
self.set_xlim(bbox.xmin, bbox.xmax)
self.set_ylim(bbox.ymin, bbox.ymax)
floatingaxes_class_factory = cbook._make_class_factory(FloatingAxesBase, "Floating{}")
FloatingAxes = floatingaxes_class_factory(host_axes_class_factory(axislines.Axes))
FloatingSubplot = FloatingAxes

View File

@ -0,0 +1,326 @@
import numpy as np
from matplotlib import ticker as mticker, _api
from matplotlib.transforms import Bbox, Transform
def _find_line_box_crossings(xys, bbox):
"""
Find the points where a polyline crosses a bbox, and the crossing angles.
Parameters
----------
xys : (N, 2) array
The polyline coordinates.
bbox : `.Bbox`
The bounding box.
Returns
-------
list of ((float, float), float)
Four separate lists of crossings, for the left, right, bottom, and top
sides of the bbox, respectively. For each list, the entries are the
``((x, y), ccw_angle_in_degrees)`` of the crossing, where an angle of 0
means that the polyline is moving to the right at the crossing point.
The entries are computed by linearly interpolating at each crossing
between the nearest points on either side of the bbox edges.
"""
crossings = []
dxys = xys[1:] - xys[:-1]
for sl in [slice(None), slice(None, None, -1)]:
us, vs = xys.T[sl] # "this" coord, "other" coord
dus, dvs = dxys.T[sl]
umin, vmin = bbox.min[sl]
umax, vmax = bbox.max[sl]
for u0, inside in [(umin, us > umin), (umax, us < umax)]:
cross = []
idxs, = (inside[:-1] ^ inside[1:]).nonzero()
for idx in idxs:
v = vs[idx] + (u0 - us[idx]) * dvs[idx] / dus[idx]
if not vmin <= v <= vmax:
continue
crossing = (u0, v)[sl]
theta = np.degrees(np.arctan2(*dxys[idx][::-1]))
cross.append((crossing, theta))
crossings.append(cross)
return crossings
class ExtremeFinderSimple:
"""
A helper class to figure out the range of grid lines that need to be drawn.
"""
def __init__(self, nx, ny):
"""
Parameters
----------
nx, ny : int
The number of samples in each direction.
"""
self.nx = nx
self.ny = ny
def __call__(self, transform_xy, x1, y1, x2, y2):
"""
Compute an approximation of the bounding box obtained by applying
*transform_xy* to the box delimited by ``(x1, y1, x2, y2)``.
The intended use is to have ``(x1, y1, x2, y2)`` in axes coordinates,
and have *transform_xy* be the transform from axes coordinates to data
coordinates; this method then returns the range of data coordinates
that span the actual axes.
The computation is done by sampling ``nx * ny`` equispaced points in
the ``(x1, y1, x2, y2)`` box and finding the resulting points with
extremal coordinates; then adding some padding to take into account the
finite sampling.
As each sampling step covers a relative range of *1/nx* or *1/ny*,
the padding is computed by expanding the span covered by the extremal
coordinates by these fractions.
"""
x, y = np.meshgrid(
np.linspace(x1, x2, self.nx), np.linspace(y1, y2, self.ny))
xt, yt = transform_xy(np.ravel(x), np.ravel(y))
return self._add_pad(xt.min(), xt.max(), yt.min(), yt.max())
def _add_pad(self, x_min, x_max, y_min, y_max):
"""Perform the padding mentioned in `__call__`."""
dx = (x_max - x_min) / self.nx
dy = (y_max - y_min) / self.ny
return x_min - dx, x_max + dx, y_min - dy, y_max + dy
class _User2DTransform(Transform):
"""A transform defined by two user-set functions."""
input_dims = output_dims = 2
def __init__(self, forward, backward):
"""
Parameters
----------
forward, backward : callable
The forward and backward transforms, taking ``x`` and ``y`` as
separate arguments and returning ``(tr_x, tr_y)``.
"""
# The normal Matplotlib convention would be to take and return an
# (N, 2) array but axisartist uses the transposed version.
super().__init__()
self._forward = forward
self._backward = backward
def transform_non_affine(self, values):
# docstring inherited
return np.transpose(self._forward(*np.transpose(values)))
def inverted(self):
# docstring inherited
return type(self)(self._backward, self._forward)
class GridFinder:
"""
Internal helper for `~.grid_helper_curvelinear.GridHelperCurveLinear`, with
the same constructor parameters; should not be directly instantiated.
"""
def __init__(self,
transform,
extreme_finder=None,
grid_locator1=None,
grid_locator2=None,
tick_formatter1=None,
tick_formatter2=None):
if extreme_finder is None:
extreme_finder = ExtremeFinderSimple(20, 20)
if grid_locator1 is None:
grid_locator1 = MaxNLocator()
if grid_locator2 is None:
grid_locator2 = MaxNLocator()
if tick_formatter1 is None:
tick_formatter1 = FormatterPrettyPrint()
if tick_formatter2 is None:
tick_formatter2 = FormatterPrettyPrint()
self.extreme_finder = extreme_finder
self.grid_locator1 = grid_locator1
self.grid_locator2 = grid_locator2
self.tick_formatter1 = tick_formatter1
self.tick_formatter2 = tick_formatter2
self.set_transform(transform)
def _format_ticks(self, idx, direction, factor, levels):
"""
Helper to support both standard formatters (inheriting from
`.mticker.Formatter`) and axisartist-specific ones; should be called instead of
directly calling ``self.tick_formatter1`` and ``self.tick_formatter2``. This
method should be considered as a temporary workaround which will be removed in
the future at the same time as axisartist-specific formatters.
"""
fmt = _api.check_getitem(
{1: self.tick_formatter1, 2: self.tick_formatter2}, idx=idx)
return (fmt.format_ticks(levels) if isinstance(fmt, mticker.Formatter)
else fmt(direction, factor, levels))
def get_grid_info(self, x1, y1, x2, y2):
"""
lon_values, lat_values : list of grid values. if integer is given,
rough number of grids in each direction.
"""
extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)
# min & max rage of lat (or lon) for each grid line will be drawn.
# i.e., gridline of lon=0 will be drawn from lat_min to lat_max.
lon_min, lon_max, lat_min, lat_max = extremes
lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
lon_levs = np.asarray(lon_levs)
lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)
lat_levs = np.asarray(lat_levs)
lon_values = lon_levs[:lon_n] / lon_factor
lat_values = lat_levs[:lat_n] / lat_factor
lon_lines, lat_lines = self._get_raw_grid_lines(lon_values,
lat_values,
lon_min, lon_max,
lat_min, lat_max)
bb = Bbox.from_extents(x1, y1, x2, y2).expanded(1 + 2e-10, 1 + 2e-10)
grid_info = {
"extremes": extremes,
# "lon", "lat", filled below.
}
for idx, lon_or_lat, levs, factor, values, lines in [
(1, "lon", lon_levs, lon_factor, lon_values, lon_lines),
(2, "lat", lat_levs, lat_factor, lat_values, lat_lines),
]:
grid_info[lon_or_lat] = gi = {
"lines": [[l] for l in lines],
"ticks": {"left": [], "right": [], "bottom": [], "top": []},
}
for (lx, ly), v, level in zip(lines, values, levs):
all_crossings = _find_line_box_crossings(np.column_stack([lx, ly]), bb)
for side, crossings in zip(
["left", "right", "bottom", "top"], all_crossings):
for crossing in crossings:
gi["ticks"][side].append({"level": level, "loc": crossing})
for side in gi["ticks"]:
levs = [tick["level"] for tick in gi["ticks"][side]]
labels = self._format_ticks(idx, side, factor, levs)
for tick, label in zip(gi["ticks"][side], labels):
tick["label"] = label
return grid_info
def _get_raw_grid_lines(self,
lon_values, lat_values,
lon_min, lon_max, lat_min, lat_max):
lons_i = np.linspace(lon_min, lon_max, 100) # for interpolation
lats_i = np.linspace(lat_min, lat_max, 100)
lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
for lon in lon_values]
lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
for lat in lat_values]
return lon_lines, lat_lines
def set_transform(self, aux_trans):
if isinstance(aux_trans, Transform):
self._aux_transform = aux_trans
elif len(aux_trans) == 2 and all(map(callable, aux_trans)):
self._aux_transform = _User2DTransform(*aux_trans)
else:
raise TypeError("'aux_trans' must be either a Transform "
"instance or a pair of callables")
def get_transform(self):
return self._aux_transform
update_transform = set_transform # backcompat alias.
def transform_xy(self, x, y):
return self._aux_transform.transform(np.column_stack([x, y])).T
def inv_transform_xy(self, x, y):
return self._aux_transform.inverted().transform(
np.column_stack([x, y])).T
def update(self, **kwargs):
for k, v in kwargs.items():
if k in ["extreme_finder",
"grid_locator1",
"grid_locator2",
"tick_formatter1",
"tick_formatter2"]:
setattr(self, k, v)
else:
raise ValueError(f"Unknown update property {k!r}")
class MaxNLocator(mticker.MaxNLocator):
def __init__(self, nbins=10, steps=None,
trim=True,
integer=False,
symmetric=False,
prune=None):
# trim argument has no effect. It has been left for API compatibility
super().__init__(nbins, steps=steps, integer=integer,
symmetric=symmetric, prune=prune)
self.create_dummy_axis()
def __call__(self, v1, v2):
locs = super().tick_values(v1, v2)
return np.array(locs), len(locs), 1 # 1: factor (see angle_helper)
class FixedLocator:
def __init__(self, locs):
self._locs = locs
def __call__(self, v1, v2):
v1, v2 = sorted([v1, v2])
locs = np.array([l for l in self._locs if v1 <= l <= v2])
return locs, len(locs), 1 # 1: factor (see angle_helper)
# Tick Formatter
class FormatterPrettyPrint:
def __init__(self, useMathText=True):
self._fmt = mticker.ScalarFormatter(
useMathText=useMathText, useOffset=False)
self._fmt.create_dummy_axis()
def __call__(self, direction, factor, values):
return self._fmt.format_ticks(values)
class DictFormatter:
def __init__(self, format_dict, formatter=None):
"""
format_dict : dictionary for format strings to be used.
formatter : fall-back formatter
"""
super().__init__()
self._format_dict = format_dict
self._fallback_formatter = formatter
def __call__(self, direction, factor, values):
"""
factor is ignored if value is found in the dictionary
"""
if self._fallback_formatter:
fallback_strings = self._fallback_formatter(
direction, factor, values)
else:
fallback_strings = [""] * len(values)
return [self._format_dict.get(k, v)
for k, v in zip(values, fallback_strings)]

View File

@ -0,0 +1,328 @@
"""
An experimental support for curvilinear grid.
"""
import functools
import numpy as np
import matplotlib as mpl
from matplotlib import _api
from matplotlib.path import Path
from matplotlib.transforms import Affine2D, IdentityTransform
from .axislines import (
_FixedAxisArtistHelperBase, _FloatingAxisArtistHelperBase, GridHelperBase)
from .axis_artist import AxisArtist
from .grid_finder import GridFinder
def _value_and_jacobian(func, xs, ys, xlims, ylims):
"""
Compute *func* and its derivatives along x and y at positions *xs*, *ys*,
while ensuring that finite difference calculations don't try to evaluate
values outside of *xlims*, *ylims*.
"""
eps = np.finfo(float).eps ** (1/2) # see e.g. scipy.optimize.approx_fprime
val = func(xs, ys)
# Take the finite difference step in the direction where the bound is the
# furthest; the step size is min of epsilon and distance to that bound.
xlo, xhi = sorted(xlims)
dxlo = xs - xlo
dxhi = xhi - xs
xeps = (np.take([-1, 1], dxhi >= dxlo)
* np.minimum(eps, np.maximum(dxlo, dxhi)))
val_dx = func(xs + xeps, ys)
ylo, yhi = sorted(ylims)
dylo = ys - ylo
dyhi = yhi - ys
yeps = (np.take([-1, 1], dyhi >= dylo)
* np.minimum(eps, np.maximum(dylo, dyhi)))
val_dy = func(xs, ys + yeps)
return (val, (val_dx - val) / xeps, (val_dy - val) / yeps)
class FixedAxisArtistHelper(_FixedAxisArtistHelperBase):
"""
Helper class for a fixed axis.
"""
def __init__(self, grid_helper, side, nth_coord_ticks=None):
"""
nth_coord = along which coordinate value varies.
nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
"""
super().__init__(loc=side)
self.grid_helper = grid_helper
if nth_coord_ticks is None:
nth_coord_ticks = self.nth_coord
self.nth_coord_ticks = nth_coord_ticks
self.side = side
def update_lim(self, axes):
self.grid_helper.update_lim(axes)
def get_tick_transform(self, axes):
return axes.transData
def get_tick_iterators(self, axes):
"""tick_loc, tick_angle, tick_label"""
v1, v2 = axes.get_ylim() if self.nth_coord == 0 else axes.get_xlim()
if v1 > v2: # Inverted limits.
side = {"left": "right", "right": "left",
"top": "bottom", "bottom": "top"}[self.side]
else:
side = self.side
angle_tangent = dict(left=90, right=90, bottom=0, top=0)[side]
def iter_major():
for nth_coord, show_labels in [
(self.nth_coord_ticks, True), (1 - self.nth_coord_ticks, False)]:
gi = self.grid_helper._grid_info[["lon", "lat"][nth_coord]]
for tick in gi["ticks"][side]:
yield (*tick["loc"], angle_tangent,
(tick["label"] if show_labels else ""))
return iter_major(), iter([])
class FloatingAxisArtistHelper(_FloatingAxisArtistHelperBase):
def __init__(self, grid_helper, nth_coord, value, axis_direction=None):
"""
nth_coord = along which coordinate value varies.
nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
"""
super().__init__(nth_coord, value)
self.value = value
self.grid_helper = grid_helper
self._extremes = -np.inf, np.inf
self._line_num_points = 100 # number of points to create a line
def set_extremes(self, e1, e2):
if e1 is None:
e1 = -np.inf
if e2 is None:
e2 = np.inf
self._extremes = e1, e2
def update_lim(self, axes):
self.grid_helper.update_lim(axes)
x1, x2 = axes.get_xlim()
y1, y2 = axes.get_ylim()
grid_finder = self.grid_helper.grid_finder
extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
x1, y1, x2, y2)
lon_min, lon_max, lat_min, lat_max = extremes
e_min, e_max = self._extremes # ranges of other coordinates
if self.nth_coord == 0:
lat_min = max(e_min, lat_min)
lat_max = min(e_max, lat_max)
elif self.nth_coord == 1:
lon_min = max(e_min, lon_min)
lon_max = min(e_max, lon_max)
lon_levs, lon_n, lon_factor = \
grid_finder.grid_locator1(lon_min, lon_max)
lat_levs, lat_n, lat_factor = \
grid_finder.grid_locator2(lat_min, lat_max)
if self.nth_coord == 0:
xx0 = np.full(self._line_num_points, self.value)
yy0 = np.linspace(lat_min, lat_max, self._line_num_points)
xx, yy = grid_finder.transform_xy(xx0, yy0)
elif self.nth_coord == 1:
xx0 = np.linspace(lon_min, lon_max, self._line_num_points)
yy0 = np.full(self._line_num_points, self.value)
xx, yy = grid_finder.transform_xy(xx0, yy0)
self._grid_info = {
"extremes": (lon_min, lon_max, lat_min, lat_max),
"lon_info": (lon_levs, lon_n, np.asarray(lon_factor)),
"lat_info": (lat_levs, lat_n, np.asarray(lat_factor)),
"lon_labels": grid_finder._format_ticks(
1, "bottom", lon_factor, lon_levs),
"lat_labels": grid_finder._format_ticks(
2, "bottom", lat_factor, lat_levs),
"line_xy": (xx, yy),
}
def get_axislabel_transform(self, axes):
return Affine2D() # axes.transData
def get_axislabel_pos_angle(self, axes):
def trf_xy(x, y):
trf = self.grid_helper.grid_finder.get_transform() + axes.transData
return trf.transform([x, y]).T
xmin, xmax, ymin, ymax = self._grid_info["extremes"]
if self.nth_coord == 0:
xx0 = self.value
yy0 = (ymin + ymax) / 2
elif self.nth_coord == 1:
xx0 = (xmin + xmax) / 2
yy0 = self.value
xy1, dxy1_dx, dxy1_dy = _value_and_jacobian(
trf_xy, xx0, yy0, (xmin, xmax), (ymin, ymax))
p = axes.transAxes.inverted().transform(xy1)
if 0 <= p[0] <= 1 and 0 <= p[1] <= 1:
d = [dxy1_dy, dxy1_dx][self.nth_coord]
return xy1, np.rad2deg(np.arctan2(*d[::-1]))
else:
return None, None
def get_tick_transform(self, axes):
return IdentityTransform() # axes.transData
def get_tick_iterators(self, axes):
"""tick_loc, tick_angle, tick_label, (optionally) tick_label"""
lat_levs, lat_n, lat_factor = self._grid_info["lat_info"]
yy0 = lat_levs / lat_factor
lon_levs, lon_n, lon_factor = self._grid_info["lon_info"]
xx0 = lon_levs / lon_factor
e0, e1 = self._extremes
def trf_xy(x, y):
trf = self.grid_helper.grid_finder.get_transform() + axes.transData
return trf.transform(np.column_stack(np.broadcast_arrays(x, y))).T
# find angles
if self.nth_coord == 0:
mask = (e0 <= yy0) & (yy0 <= e1)
(xx1, yy1), (dxx1, dyy1), (dxx2, dyy2) = _value_and_jacobian(
trf_xy, self.value, yy0[mask], (-np.inf, np.inf), (e0, e1))
labels = self._grid_info["lat_labels"]
elif self.nth_coord == 1:
mask = (e0 <= xx0) & (xx0 <= e1)
(xx1, yy1), (dxx2, dyy2), (dxx1, dyy1) = _value_and_jacobian(
trf_xy, xx0[mask], self.value, (-np.inf, np.inf), (e0, e1))
labels = self._grid_info["lon_labels"]
labels = [l for l, m in zip(labels, mask) if m]
angle_normal = np.arctan2(dyy1, dxx1)
angle_tangent = np.arctan2(dyy2, dxx2)
mm = (dyy1 == 0) & (dxx1 == 0) # points with degenerate normal
angle_normal[mm] = angle_tangent[mm] + np.pi / 2
tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
in_01 = functools.partial(
mpl.transforms._interval_contains_close, (0, 1))
def iter_major():
for x, y, normal, tangent, lab \
in zip(xx1, yy1, angle_normal, angle_tangent, labels):
c2 = tick_to_axes.transform((x, y))
if in_01(c2[0]) and in_01(c2[1]):
yield [x, y], *np.rad2deg([normal, tangent]), lab
return iter_major(), iter([])
def get_line_transform(self, axes):
return axes.transData
def get_line(self, axes):
self.update_lim(axes)
x, y = self._grid_info["line_xy"]
return Path(np.column_stack([x, y]))
class GridHelperCurveLinear(GridHelperBase):
def __init__(self, aux_trans,
extreme_finder=None,
grid_locator1=None,
grid_locator2=None,
tick_formatter1=None,
tick_formatter2=None):
"""
Parameters
----------
aux_trans : `.Transform` or tuple[Callable, Callable]
The transform from curved coordinates to rectilinear coordinate:
either a `.Transform` instance (which provides also its inverse),
or a pair of callables ``(trans, inv_trans)`` that define the
transform and its inverse. The callables should have signature::
x_rect, y_rect = trans(x_curved, y_curved)
x_curved, y_curved = inv_trans(x_rect, y_rect)
extreme_finder
grid_locator1, grid_locator2
Grid locators for each axis.
tick_formatter1, tick_formatter2
Tick formatters for each axis.
"""
super().__init__()
self._grid_info = None
self.grid_finder = GridFinder(aux_trans,
extreme_finder,
grid_locator1,
grid_locator2,
tick_formatter1,
tick_formatter2)
def update_grid_finder(self, aux_trans=None, **kwargs):
if aux_trans is not None:
self.grid_finder.update_transform(aux_trans)
self.grid_finder.update(**kwargs)
self._old_limits = None # Force revalidation.
@_api.make_keyword_only("3.9", "nth_coord")
def new_fixed_axis(
self, loc, nth_coord=None, axis_direction=None, offset=None, axes=None):
if axes is None:
axes = self.axes
if axis_direction is None:
axis_direction = loc
helper = FixedAxisArtistHelper(self, loc, nth_coord_ticks=nth_coord)
axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
# Why is clip not set on axisline, unlike in new_floating_axis or in
# the floating_axig.GridHelperCurveLinear subclass?
return axisline
def new_floating_axis(self, nth_coord, value, axes=None, axis_direction="bottom"):
if axes is None:
axes = self.axes
helper = FloatingAxisArtistHelper(
self, nth_coord, value, axis_direction)
axisline = AxisArtist(axes, helper)
axisline.line.set_clip_on(True)
axisline.line.set_clip_box(axisline.axes.bbox)
# axisline.major_ticklabels.set_visible(True)
# axisline.minor_ticklabels.set_visible(False)
return axisline
def _update_grid(self, x1, y1, x2, y2):
self._grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2)
def get_gridlines(self, which="major", axis="both"):
grid_lines = []
if axis in ["both", "x"]:
for gl in self._grid_info["lon"]["lines"]:
grid_lines.extend(gl)
if axis in ["both", "y"]:
for gl in self._grid_info["lat"]["lines"]:
grid_lines.extend(gl)
return grid_lines
@_api.deprecated("3.9")
def get_tick_iterator(self, nth_coord, axis_side, minor=False):
angle_tangent = dict(left=90, right=90, bottom=0, top=0)[axis_side]
lon_or_lat = ["lon", "lat"][nth_coord]
if not minor: # major ticks
for tick in self._grid_info[lon_or_lat]["ticks"][axis_side]:
yield *tick["loc"], angle_tangent, tick["label"]
else:
for tick in self._grid_info[lon_or_lat]["ticks"][axis_side]:
yield *tick["loc"], angle_tangent, ""

View File

@ -0,0 +1,7 @@
from mpl_toolkits.axes_grid1.parasite_axes import (
host_axes_class_factory, parasite_axes_class_factory)
from .axislines import Axes
ParasiteAxes = parasite_axes_class_factory(Axes)
HostAxes = SubplotHost = host_axes_class_factory(Axes)

View File

@ -0,0 +1,10 @@
from pathlib import Path
# Check that the test directories exist
if not (Path(__file__).parent / "baseline_images").exists():
raise OSError(
'The baseline image directory does not exist. '
'This is most likely because the test data is not installed. '
'You may need to install matplotlib from source to get the '
'test data.')

View File

@ -0,0 +1,2 @@
from matplotlib.testing.conftest import (mpl_test_settings, # noqa
pytest_configure, pytest_unconfigure)

View File

@ -0,0 +1,141 @@
import re
import numpy as np
import pytest
from mpl_toolkits.axisartist.angle_helper import (
FormatterDMS, FormatterHMS, select_step, select_step24, select_step360)
_MS_RE = (
r'''\$ # Mathtext
(
# The sign sometimes appears on a 0 when a fraction is shown.
# Check later that there's only one.
(?P<degree_sign>-)?
(?P<degree>[0-9.]+) # Degrees value
{degree} # Degree symbol (to be replaced by format.)
)?
(
(?(degree)\\,) # Separator if degrees are also visible.
(?P<minute_sign>-)?
(?P<minute>[0-9.]+) # Minutes value
{minute} # Minute symbol (to be replaced by format.)
)?
(
(?(minute)\\,) # Separator if minutes are also visible.
(?P<second_sign>-)?
(?P<second>[0-9.]+) # Seconds value
{second} # Second symbol (to be replaced by format.)
)?
\$ # Mathtext
'''
)
DMS_RE = re.compile(_MS_RE.format(degree=re.escape(FormatterDMS.deg_mark),
minute=re.escape(FormatterDMS.min_mark),
second=re.escape(FormatterDMS.sec_mark)),
re.VERBOSE)
HMS_RE = re.compile(_MS_RE.format(degree=re.escape(FormatterHMS.deg_mark),
minute=re.escape(FormatterHMS.min_mark),
second=re.escape(FormatterHMS.sec_mark)),
re.VERBOSE)
def dms2float(degrees, minutes=0, seconds=0):
return degrees + minutes / 60.0 + seconds / 3600.0
@pytest.mark.parametrize('args, kwargs, expected_levels, expected_factor', [
((-180, 180, 10), {'hour': False}, np.arange(-180, 181, 30), 1.0),
((-12, 12, 10), {'hour': True}, np.arange(-12, 13, 2), 1.0)
])
def test_select_step(args, kwargs, expected_levels, expected_factor):
levels, n, factor = select_step(*args, **kwargs)
assert n == len(levels)
np.testing.assert_array_equal(levels, expected_levels)
assert factor == expected_factor
@pytest.mark.parametrize('args, kwargs, expected_levels, expected_factor', [
((-180, 180, 10), {}, np.arange(-180, 181, 30), 1.0),
((-12, 12, 10), {}, np.arange(-750, 751, 150), 60.0)
])
def test_select_step24(args, kwargs, expected_levels, expected_factor):
levels, n, factor = select_step24(*args, **kwargs)
assert n == len(levels)
np.testing.assert_array_equal(levels, expected_levels)
assert factor == expected_factor
@pytest.mark.parametrize('args, kwargs, expected_levels, expected_factor', [
((dms2float(20, 21.2), dms2float(21, 33.3), 5), {},
np.arange(1215, 1306, 15), 60.0),
((dms2float(20.5, seconds=21.2), dms2float(20.5, seconds=33.3), 5), {},
np.arange(73820, 73835, 2), 3600.0),
((dms2float(20, 21.2), dms2float(20, 53.3), 5), {},
np.arange(1220, 1256, 5), 60.0),
((21.2, 33.3, 5), {},
np.arange(20, 35, 2), 1.0),
((dms2float(20, 21.2), dms2float(21, 33.3), 5), {},
np.arange(1215, 1306, 15), 60.0),
((dms2float(20.5, seconds=21.2), dms2float(20.5, seconds=33.3), 5), {},
np.arange(73820, 73835, 2), 3600.0),
((dms2float(20.5, seconds=21.2), dms2float(20.5, seconds=21.4), 5), {},
np.arange(7382120, 7382141, 5), 360000.0),
# test threshold factor
((dms2float(20.5, seconds=11.2), dms2float(20.5, seconds=53.3), 5),
{'threshold_factor': 60}, np.arange(12301, 12310), 600.0),
((dms2float(20.5, seconds=11.2), dms2float(20.5, seconds=53.3), 5),
{'threshold_factor': 1}, np.arange(20502, 20517, 2), 1000.0),
])
def test_select_step360(args, kwargs, expected_levels, expected_factor):
levels, n, factor = select_step360(*args, **kwargs)
assert n == len(levels)
np.testing.assert_array_equal(levels, expected_levels)
assert factor == expected_factor
@pytest.mark.parametrize('Formatter, regex',
[(FormatterDMS, DMS_RE),
(FormatterHMS, HMS_RE)],
ids=['Degree/Minute/Second', 'Hour/Minute/Second'])
@pytest.mark.parametrize('direction, factor, values', [
("left", 60, [0, -30, -60]),
("left", 600, [12301, 12302, 12303]),
("left", 3600, [0, -30, -60]),
("left", 36000, [738210, 738215, 738220]),
("left", 360000, [7382120, 7382125, 7382130]),
("left", 1., [45, 46, 47]),
("left", 10., [452, 453, 454]),
])
def test_formatters(Formatter, regex, direction, factor, values):
fmt = Formatter()
result = fmt(direction, factor, values)
prev_degree = prev_minute = prev_second = None
for tick, value in zip(result, values):
m = regex.match(tick)
assert m is not None, f'{tick!r} is not an expected tick format.'
sign = sum(m.group(sign + '_sign') is not None
for sign in ('degree', 'minute', 'second'))
assert sign <= 1, f'Only one element of tick {tick!r} may have a sign.'
sign = 1 if sign == 0 else -1
degree = float(m.group('degree') or prev_degree or 0)
minute = float(m.group('minute') or prev_minute or 0)
second = float(m.group('second') or prev_second or 0)
if Formatter == FormatterHMS:
# 360 degrees as plot range -> 24 hours as labelled range
expected_value = pytest.approx((value // 15) / factor)
else:
expected_value = pytest.approx(value / factor)
assert sign * dms2float(degree, minute, second) == expected_value, \
f'{tick!r} does not match expected tick value.'
prev_degree = degree
prev_minute = minute
prev_second = second

View File

@ -0,0 +1,99 @@
import matplotlib.pyplot as plt
from matplotlib.testing.decorators import image_comparison
from mpl_toolkits.axisartist import AxisArtistHelperRectlinear
from mpl_toolkits.axisartist.axis_artist import (AxisArtist, AxisLabel,
LabelBase, Ticks, TickLabels)
@image_comparison(['axis_artist_ticks.png'], style='default')
def test_ticks():
fig, ax = plt.subplots()
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
locs_angles = [((i / 10, 0.0), i * 30) for i in range(-1, 12)]
ticks_in = Ticks(ticksize=10, axis=ax.xaxis)
ticks_in.set_locs_angles(locs_angles)
ax.add_artist(ticks_in)
ticks_out = Ticks(ticksize=10, tick_out=True, color='C3', axis=ax.xaxis)
ticks_out.set_locs_angles(locs_angles)
ax.add_artist(ticks_out)
@image_comparison(['axis_artist_labelbase.png'], style='default')
def test_labelbase():
# Remove this line when this test image is regenerated.
plt.rcParams['text.kerning_factor'] = 6
fig, ax = plt.subplots()
ax.plot([0.5], [0.5], "o")
label = LabelBase(0.5, 0.5, "Test")
label._ref_angle = -90
label._offset_radius = 50
label.set_rotation(-90)
label.set(ha="center", va="top")
ax.add_artist(label)
@image_comparison(['axis_artist_ticklabels.png'], style='default')
def test_ticklabels():
# Remove this line when this test image is regenerated.
plt.rcParams['text.kerning_factor'] = 6
fig, ax = plt.subplots()
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
ax.plot([0.2, 0.4], [0.5, 0.5], "o")
ticks = Ticks(ticksize=10, axis=ax.xaxis)
ax.add_artist(ticks)
locs_angles_labels = [((0.2, 0.5), -90, "0.2"),
((0.4, 0.5), -120, "0.4")]
tick_locs_angles = [(xy, a + 180) for xy, a, l in locs_angles_labels]
ticks.set_locs_angles(tick_locs_angles)
ticklabels = TickLabels(axis_direction="left")
ticklabels._locs_angles_labels = locs_angles_labels
ticklabels.set_pad(10)
ax.add_artist(ticklabels)
ax.plot([0.5], [0.5], "s")
axislabel = AxisLabel(0.5, 0.5, "Test")
axislabel._offset_radius = 20
axislabel._ref_angle = 0
axislabel.set_axis_direction("bottom")
ax.add_artist(axislabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
@image_comparison(['axis_artist.png'], style='default')
def test_axis_artist():
# Remove this line when this test image is regenerated.
plt.rcParams['text.kerning_factor'] = 6
fig, ax = plt.subplots()
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
for loc in ('left', 'right', 'bottom'):
helper = AxisArtistHelperRectlinear.Fixed(ax, loc=loc)
axisline = AxisArtist(ax, helper, offset=None, axis_direction=loc)
ax.add_artist(axisline)
# Settings for bottom AxisArtist.
axisline.set_label("TTT")
axisline.major_ticks.set_tick_out(False)
axisline.label.set_pad(5)
ax.set_ylabel("Test")

View File

@ -0,0 +1,147 @@
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.testing.decorators import image_comparison
from matplotlib.transforms import IdentityTransform
from mpl_toolkits.axisartist.axislines import AxesZero, SubplotZero, Subplot
from mpl_toolkits.axisartist import Axes, SubplotHost
@image_comparison(['SubplotZero.png'], style='default')
def test_SubplotZero():
# Remove this line when this test image is regenerated.
plt.rcParams['text.kerning_factor'] = 6
fig = plt.figure()
ax = SubplotZero(fig, 1, 1, 1)
fig.add_subplot(ax)
ax.axis["xzero"].set_visible(True)
ax.axis["xzero"].label.set_text("Axis Zero")
for n in ["top", "right"]:
ax.axis[n].set_visible(False)
xx = np.arange(0, 2 * np.pi, 0.01)
ax.plot(xx, np.sin(xx))
ax.set_ylabel("Test")
@image_comparison(['Subplot.png'], style='default')
def test_Subplot():
# Remove this line when this test image is regenerated.
plt.rcParams['text.kerning_factor'] = 6
fig = plt.figure()
ax = Subplot(fig, 1, 1, 1)
fig.add_subplot(ax)
xx = np.arange(0, 2 * np.pi, 0.01)
ax.plot(xx, np.sin(xx))
ax.set_ylabel("Test")
ax.axis["top"].major_ticks.set_tick_out(True)
ax.axis["bottom"].major_ticks.set_tick_out(True)
ax.axis["bottom"].set_label("Tk0")
def test_Axes():
fig = plt.figure()
ax = Axes(fig, [0.15, 0.1, 0.65, 0.8])
fig.add_axes(ax)
ax.plot([1, 2, 3], [0, 1, 2])
ax.set_xscale('log')
fig.canvas.draw()
@image_comparison(['ParasiteAxesAuxTrans_meshplot.png'],
remove_text=True, style='default', tol=0.075)
def test_ParasiteAxesAuxTrans():
data = np.ones((6, 6))
data[2, 2] = 2
data[0, :] = 0
data[-2, :] = 0
data[:, 0] = 0
data[:, -2] = 0
x = np.arange(6)
y = np.arange(6)
xx, yy = np.meshgrid(x, y)
funcnames = ['pcolor', 'pcolormesh', 'contourf']
fig = plt.figure()
for i, name in enumerate(funcnames):
ax1 = SubplotHost(fig, 1, 3, i+1)
fig.add_subplot(ax1)
ax2 = ax1.get_aux_axes(IdentityTransform(), viewlim_mode=None)
if name.startswith('pcolor'):
getattr(ax2, name)(xx, yy, data[:-1, :-1])
else:
getattr(ax2, name)(xx, yy, data)
ax1.set_xlim((0, 5))
ax1.set_ylim((0, 5))
ax2.contour(xx, yy, data, colors='k')
@image_comparison(['axisline_style.png'], remove_text=True, style='mpl20')
def test_axisline_style():
fig = plt.figure(figsize=(2, 2))
ax = fig.add_subplot(axes_class=AxesZero)
ax.axis["xzero"].set_axisline_style("-|>")
ax.axis["xzero"].set_visible(True)
ax.axis["yzero"].set_axisline_style("->")
ax.axis["yzero"].set_visible(True)
for direction in ("left", "right", "bottom", "top"):
ax.axis[direction].set_visible(False)
@image_comparison(['axisline_style_size_color.png'], remove_text=True,
style='mpl20')
def test_axisline_style_size_color():
fig = plt.figure(figsize=(2, 2))
ax = fig.add_subplot(axes_class=AxesZero)
ax.axis["xzero"].set_axisline_style("-|>", size=2.0, facecolor='r')
ax.axis["xzero"].set_visible(True)
ax.axis["yzero"].set_axisline_style("->, size=1.5")
ax.axis["yzero"].set_visible(True)
for direction in ("left", "right", "bottom", "top"):
ax.axis[direction].set_visible(False)
@image_comparison(['axisline_style_tight.png'], remove_text=True,
style='mpl20')
def test_axisline_style_tight():
fig = plt.figure(figsize=(2, 2))
ax = fig.add_subplot(axes_class=AxesZero)
ax.axis["xzero"].set_axisline_style("-|>", size=5, facecolor='g')
ax.axis["xzero"].set_visible(True)
ax.axis["yzero"].set_axisline_style("->, size=8")
ax.axis["yzero"].set_visible(True)
for direction in ("left", "right", "bottom", "top"):
ax.axis[direction].set_visible(False)
fig.tight_layout()
@image_comparison(['subplotzero_ylabel.png'], style='mpl20')
def test_subplotzero_ylabel():
fig = plt.figure()
ax = fig.add_subplot(111, axes_class=SubplotZero)
ax.set(xlim=(-3, 7), ylim=(-3, 7), xlabel="x", ylabel="y")
zero_axis = ax.axis["xzero", "yzero"]
zero_axis.set_visible(True) # they are hidden by default
ax.axis["left", "right", "bottom", "top"].set_visible(False)
zero_axis.set_axisline_style("->")

View File

@ -0,0 +1,115 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.projections as mprojections
import matplotlib.transforms as mtransforms
from matplotlib.testing.decorators import image_comparison
from mpl_toolkits.axisartist.axislines import Subplot
from mpl_toolkits.axisartist.floating_axes import (
FloatingAxes, GridHelperCurveLinear)
from mpl_toolkits.axisartist.grid_finder import FixedLocator
from mpl_toolkits.axisartist import angle_helper
def test_subplot():
fig = plt.figure(figsize=(5, 5))
ax = Subplot(fig, 111)
fig.add_subplot(ax)
# Rather high tolerance to allow ongoing work with floating axes internals;
# remove when image is regenerated.
@image_comparison(['curvelinear3.png'], style='default', tol=5)
def test_curvelinear3():
fig = plt.figure(figsize=(5, 5))
tr = (mtransforms.Affine2D().scale(np.pi / 180, 1) +
mprojections.PolarAxes.PolarTransform(apply_theta_transforms=False))
grid_helper = GridHelperCurveLinear(
tr,
extremes=(0, 360, 10, 3),
grid_locator1=angle_helper.LocatorDMS(15),
grid_locator2=FixedLocator([2, 4, 6, 8, 10]),
tick_formatter1=angle_helper.FormatterDMS(),
tick_formatter2=None)
ax1 = fig.add_subplot(axes_class=FloatingAxes, grid_helper=grid_helper)
r_scale = 10
tr2 = mtransforms.Affine2D().scale(1, 1 / r_scale) + tr
grid_helper2 = GridHelperCurveLinear(
tr2,
extremes=(0, 360, 10 * r_scale, 3 * r_scale),
grid_locator2=FixedLocator([30, 60, 90]))
ax1.axis["right"] = axis = grid_helper2.new_fixed_axis("right", axes=ax1)
ax1.axis["left"].label.set_text("Test 1")
ax1.axis["right"].label.set_text("Test 2")
ax1.axis["left", "right"].set_visible(False)
axis = grid_helper.new_floating_axis(1, 7, axes=ax1,
axis_direction="bottom")
ax1.axis["z"] = axis
axis.toggle(all=True, label=True)
axis.label.set_text("z = ?")
axis.label.set_visible(True)
axis.line.set_color("0.5")
ax2 = ax1.get_aux_axes(tr)
xx, yy = [67, 90, 75, 30], [2, 5, 8, 4]
ax2.scatter(xx, yy)
l, = ax2.plot(xx, yy, "k-")
l.set_clip_path(ax1.patch)
# Rather high tolerance to allow ongoing work with floating axes internals;
# remove when image is regenerated.
@image_comparison(['curvelinear4.png'], style='default', tol=0.9)
def test_curvelinear4():
# Remove this line when this test image is regenerated.
plt.rcParams['text.kerning_factor'] = 6
fig = plt.figure(figsize=(5, 5))
tr = (mtransforms.Affine2D().scale(np.pi / 180, 1) +
mprojections.PolarAxes.PolarTransform(apply_theta_transforms=False))
grid_helper = GridHelperCurveLinear(
tr,
extremes=(120, 30, 10, 0),
grid_locator1=angle_helper.LocatorDMS(5),
grid_locator2=FixedLocator([2, 4, 6, 8, 10]),
tick_formatter1=angle_helper.FormatterDMS(),
tick_formatter2=None)
ax1 = fig.add_subplot(axes_class=FloatingAxes, grid_helper=grid_helper)
ax1.clear() # Check that clear() also restores the correct limits on ax1.
ax1.axis["left"].label.set_text("Test 1")
ax1.axis["right"].label.set_text("Test 2")
ax1.axis["top"].set_visible(False)
axis = grid_helper.new_floating_axis(1, 70, axes=ax1,
axis_direction="bottom")
ax1.axis["z"] = axis
axis.toggle(all=True, label=True)
axis.label.set_axis_direction("top")
axis.label.set_text("z = ?")
axis.label.set_visible(True)
axis.line.set_color("0.5")
ax2 = ax1.get_aux_axes(tr)
xx, yy = [67, 90, 75, 30], [2, 5, 8, 4]
ax2.scatter(xx, yy)
l, = ax2.plot(xx, yy, "k-")
l.set_clip_path(ax1.patch)
def test_axis_direction():
# Check that axis direction is propagated on a floating axis
fig = plt.figure()
ax = Subplot(fig, 111)
fig.add_subplot(ax)
ax.axis['y'] = ax.new_floating_axis(nth_coord=1, value=0,
axis_direction='left')
assert ax.axis['y']._axis_direction == 'left'

View File

@ -0,0 +1,34 @@
import numpy as np
import pytest
from matplotlib.transforms import Bbox
from mpl_toolkits.axisartist.grid_finder import (
_find_line_box_crossings, FormatterPrettyPrint, MaxNLocator)
def test_find_line_box_crossings():
x = np.array([-3, -2, -1, 0., 1, 2, 3, 2, 1, 0, -1, -2, -3, 5])
y = np.arange(len(x))
bbox = Bbox.from_extents(-2, 3, 2, 12.5)
left, right, bottom, top = _find_line_box_crossings(
np.column_stack([x, y]), bbox)
((lx0, ly0), la0), ((lx1, ly1), la1), = left
((rx0, ry0), ra0), ((rx1, ry1), ra1), = right
((bx0, by0), ba0), = bottom
((tx0, ty0), ta0), = top
assert (lx0, ly0, la0) == (-2, 11, 135)
assert (lx1, ly1, la1) == pytest.approx((-2., 12.125, 7.125016))
assert (rx0, ry0, ra0) == (2, 5, 45)
assert (rx1, ry1, ra1) == (2, 7, 135)
assert (bx0, by0, ba0) == (0, 3, 45)
assert (tx0, ty0, ta0) == pytest.approx((1., 12.5, 7.125016))
def test_pretty_print_format():
locator = MaxNLocator()
locs, nloc, factor = locator(0, 100)
fmt = FormatterPrettyPrint()
assert fmt("left", None, locs) == \
[r'$\mathdefault{%d}$' % (l, ) for l in locs]

View File

@ -0,0 +1,207 @@
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.projections import PolarAxes
from matplotlib.ticker import FuncFormatter
from matplotlib.transforms import Affine2D, Transform
from matplotlib.testing.decorators import image_comparison
from mpl_toolkits.axisartist import SubplotHost
from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
from mpl_toolkits.axisartist import angle_helper
from mpl_toolkits.axisartist.axislines import Axes
from mpl_toolkits.axisartist.grid_helper_curvelinear import \
GridHelperCurveLinear
@image_comparison(['custom_transform.png'], style='default', tol=0.2)
def test_custom_transform():
class MyTransform(Transform):
input_dims = output_dims = 2
def __init__(self, resolution):
"""
Resolution is the number of steps to interpolate between each input
line segment to approximate its path in transformed space.
"""
Transform.__init__(self)
self._resolution = resolution
def transform(self, ll):
x, y = ll.T
return np.column_stack([x, y - x])
transform_non_affine = transform
def transform_path(self, path):
ipath = path.interpolated(self._resolution)
return Path(self.transform(ipath.vertices), ipath.codes)
transform_path_non_affine = transform_path
def inverted(self):
return MyTransformInv(self._resolution)
class MyTransformInv(Transform):
input_dims = output_dims = 2
def __init__(self, resolution):
Transform.__init__(self)
self._resolution = resolution
def transform(self, ll):
x, y = ll.T
return np.column_stack([x, y + x])
def inverted(self):
return MyTransform(self._resolution)
fig = plt.figure()
SubplotHost = host_axes_class_factory(Axes)
tr = MyTransform(1)
grid_helper = GridHelperCurveLinear(tr)
ax1 = SubplotHost(fig, 1, 1, 1, grid_helper=grid_helper)
fig.add_subplot(ax1)
ax2 = ax1.get_aux_axes(tr, viewlim_mode="equal")
ax2.plot([3, 6], [5.0, 10.])
ax1.set_aspect(1.)
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.grid(True)
@image_comparison(['polar_box.png'], style='default', tol=0.04)
def test_polar_box():
fig = plt.figure(figsize=(5, 5))
# PolarAxes.PolarTransform takes radian. However, we want our coordinate
# system in degree
tr = (Affine2D().scale(np.pi / 180., 1.) +
PolarAxes.PolarTransform(apply_theta_transforms=False))
# polar projection, which involves cycle, and also has limits in
# its coordinates, needs a special method to find the extremes
# (min, max of the coordinate within the view).
extreme_finder = angle_helper.ExtremeFinderCycle(20, 20,
lon_cycle=360,
lat_cycle=None,
lon_minmax=None,
lat_minmax=(0, np.inf))
grid_helper = GridHelperCurveLinear(
tr,
extreme_finder=extreme_finder,
grid_locator1=angle_helper.LocatorDMS(12),
tick_formatter1=angle_helper.FormatterDMS(),
tick_formatter2=FuncFormatter(lambda x, p: "eight" if x == 8 else f"{int(x)}"),
)
ax1 = SubplotHost(fig, 1, 1, 1, grid_helper=grid_helper)
ax1.axis["right"].major_ticklabels.set_visible(True)
ax1.axis["top"].major_ticklabels.set_visible(True)
# let right axis shows ticklabels for 1st coordinate (angle)
ax1.axis["right"].get_helper().nth_coord_ticks = 0
# let bottom axis shows ticklabels for 2nd coordinate (radius)
ax1.axis["bottom"].get_helper().nth_coord_ticks = 1
fig.add_subplot(ax1)
ax1.axis["lat"] = axis = grid_helper.new_floating_axis(0, 45, axes=ax1)
axis.label.set_text("Test")
axis.label.set_visible(True)
axis.get_helper().set_extremes(2, 12)
ax1.axis["lon"] = axis = grid_helper.new_floating_axis(1, 6, axes=ax1)
axis.label.set_text("Test 2")
axis.get_helper().set_extremes(-180, 90)
# A parasite axes with given transform
ax2 = ax1.get_aux_axes(tr, viewlim_mode="equal")
assert ax2.transData == tr + ax1.transData
# Anything you draw in ax2 will match the ticks and grids of ax1.
ax2.plot(np.linspace(0, 30, 50), np.linspace(10, 10, 50))
ax1.set_aspect(1.)
ax1.set_xlim(-5, 12)
ax1.set_ylim(-5, 10)
ax1.grid(True)
# Remove tol & kerning_factor when this test image is regenerated.
@image_comparison(['axis_direction.png'], style='default', tol=0.13)
def test_axis_direction():
plt.rcParams['text.kerning_factor'] = 6
fig = plt.figure(figsize=(5, 5))
# PolarAxes.PolarTransform takes radian. However, we want our coordinate
# system in degree
tr = (Affine2D().scale(np.pi / 180., 1.) +
PolarAxes.PolarTransform(apply_theta_transforms=False))
# polar projection, which involves cycle, and also has limits in
# its coordinates, needs a special method to find the extremes
# (min, max of the coordinate within the view).
# 20, 20 : number of sampling points along x, y direction
extreme_finder = angle_helper.ExtremeFinderCycle(20, 20,
lon_cycle=360,
lat_cycle=None,
lon_minmax=None,
lat_minmax=(0, np.inf),
)
grid_locator1 = angle_helper.LocatorDMS(12)
tick_formatter1 = angle_helper.FormatterDMS()
grid_helper = GridHelperCurveLinear(tr,
extreme_finder=extreme_finder,
grid_locator1=grid_locator1,
tick_formatter1=tick_formatter1)
ax1 = SubplotHost(fig, 1, 1, 1, grid_helper=grid_helper)
for axis in ax1.axis.values():
axis.set_visible(False)
fig.add_subplot(ax1)
ax1.axis["lat1"] = axis = grid_helper.new_floating_axis(
0, 130,
axes=ax1, axis_direction="left")
axis.label.set_text("Test")
axis.label.set_visible(True)
axis.get_helper().set_extremes(0.001, 10)
ax1.axis["lat2"] = axis = grid_helper.new_floating_axis(
0, 50,
axes=ax1, axis_direction="right")
axis.label.set_text("Test")
axis.label.set_visible(True)
axis.get_helper().set_extremes(0.001, 10)
ax1.axis["lon"] = axis = grid_helper.new_floating_axis(
1, 10,
axes=ax1, axis_direction="bottom")
axis.label.set_text("Test 2")
axis.get_helper().set_extremes(50, 130)
axis.major_ticklabels.set_axis_direction("top")
axis.label.set_axis_direction("top")
grid_helper.grid_finder.grid_locator1.set_params(nbins=5)
grid_helper.grid_finder.grid_locator2.set_params(nbins=5)
ax1.set_aspect(1.)
ax1.set_xlim(-8, 8)
ax1.set_ylim(-4, 12)
ax1.grid(True)

View File

@ -0,0 +1,3 @@
from .axes3d import Axes3D
__all__ = ['Axes3D']

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,760 @@
# axis3d.py, original mplot3d version by John Porter
# Created: 23 Sep 2005
# Parts rewritten by Reinier Heeres <reinier@heeres.eu>
import inspect
import numpy as np
import matplotlib as mpl
from matplotlib import (
_api, artist, lines as mlines, axis as maxis, patches as mpatches,
transforms as mtransforms, colors as mcolors)
from . import art3d, proj3d
def _move_from_center(coord, centers, deltas, axmask=(True, True, True)):
"""
For each coordinate where *axmask* is True, move *coord* away from
*centers* by *deltas*.
"""
coord = np.asarray(coord)
return coord + axmask * np.copysign(1, coord - centers) * deltas
def _tick_update_position(tick, tickxs, tickys, labelpos):
"""Update tick line and label position and style."""
tick.label1.set_position(labelpos)
tick.label2.set_position(labelpos)
tick.tick1line.set_visible(True)
tick.tick2line.set_visible(False)
tick.tick1line.set_linestyle('-')
tick.tick1line.set_marker('')
tick.tick1line.set_data(tickxs, tickys)
tick.gridline.set_data([0], [0])
class Axis(maxis.XAxis):
"""An Axis class for the 3D plots."""
# These points from the unit cube make up the x, y and z-planes
_PLANES = (
(0, 3, 7, 4), (1, 2, 6, 5), # yz planes
(0, 1, 5, 4), (3, 2, 6, 7), # xz planes
(0, 1, 2, 3), (4, 5, 6, 7), # xy planes
)
# Some properties for the axes
_AXINFO = {
'x': {'i': 0, 'tickdir': 1, 'juggled': (1, 0, 2)},
'y': {'i': 1, 'tickdir': 0, 'juggled': (0, 1, 2)},
'z': {'i': 2, 'tickdir': 0, 'juggled': (0, 2, 1)},
}
def _old_init(self, adir, v_intervalx, d_intervalx, axes, *args,
rotate_label=None, **kwargs):
return locals()
def _new_init(self, axes, *, rotate_label=None, **kwargs):
return locals()
def __init__(self, *args, **kwargs):
params = _api.select_matching_signature(
[self._old_init, self._new_init], *args, **kwargs)
if "adir" in params:
_api.warn_deprecated(
"3.6", message=f"The signature of 3D Axis constructors has "
f"changed in %(since)s; the new signature is "
f"{inspect.signature(type(self).__init__)}", pending=True)
if params["adir"] != self.axis_name:
raise ValueError(f"Cannot instantiate {type(self).__name__} "
f"with adir={params['adir']!r}")
axes = params["axes"]
rotate_label = params["rotate_label"]
args = params.get("args", ())
kwargs = params["kwargs"]
name = self.axis_name
self._label_position = 'default'
self._tick_position = 'default'
# This is a temporary member variable.
# Do not depend on this existing in future releases!
self._axinfo = self._AXINFO[name].copy()
# Common parts
self._axinfo.update({
'label': {'va': 'center', 'ha': 'center',
'rotation_mode': 'anchor'},
'color': mpl.rcParams[f'axes3d.{name}axis.panecolor'],
'tick': {
'inward_factor': 0.2,
'outward_factor': 0.1,
},
})
if mpl.rcParams['_internal.classic_mode']:
self._axinfo.update({
'axisline': {'linewidth': 0.75, 'color': (0, 0, 0, 1)},
'grid': {
'color': (0.9, 0.9, 0.9, 1),
'linewidth': 1.0,
'linestyle': '-',
},
})
self._axinfo['tick'].update({
'linewidth': {
True: mpl.rcParams['lines.linewidth'], # major
False: mpl.rcParams['lines.linewidth'], # minor
}
})
else:
self._axinfo.update({
'axisline': {
'linewidth': mpl.rcParams['axes.linewidth'],
'color': mpl.rcParams['axes.edgecolor'],
},
'grid': {
'color': mpl.rcParams['grid.color'],
'linewidth': mpl.rcParams['grid.linewidth'],
'linestyle': mpl.rcParams['grid.linestyle'],
},
})
self._axinfo['tick'].update({
'linewidth': {
True: ( # major
mpl.rcParams['xtick.major.width'] if name in 'xz'
else mpl.rcParams['ytick.major.width']),
False: ( # minor
mpl.rcParams['xtick.minor.width'] if name in 'xz'
else mpl.rcParams['ytick.minor.width']),
}
})
super().__init__(axes, *args, **kwargs)
# data and viewing intervals for this direction
if "d_intervalx" in params:
self.set_data_interval(*params["d_intervalx"])
if "v_intervalx" in params:
self.set_view_interval(*params["v_intervalx"])
self.set_rotate_label(rotate_label)
self._init3d() # Inline after init3d deprecation elapses.
__init__.__signature__ = inspect.signature(_new_init)
adir = _api.deprecated("3.6", pending=True)(
property(lambda self: self.axis_name))
def _init3d(self):
self.line = mlines.Line2D(
xdata=(0, 0), ydata=(0, 0),
linewidth=self._axinfo['axisline']['linewidth'],
color=self._axinfo['axisline']['color'],
antialiased=True)
# Store dummy data in Polygon object
self.pane = mpatches.Polygon([[0, 0], [0, 1]], closed=False)
self.set_pane_color(self._axinfo['color'])
self.axes._set_artist_props(self.line)
self.axes._set_artist_props(self.pane)
self.gridlines = art3d.Line3DCollection([])
self.axes._set_artist_props(self.gridlines)
self.axes._set_artist_props(self.label)
self.axes._set_artist_props(self.offsetText)
# Need to be able to place the label at the correct location
self.label._transform = self.axes.transData
self.offsetText._transform = self.axes.transData
@_api.deprecated("3.6", pending=True)
def init3d(self): # After deprecation elapses, inline _init3d to __init__.
self._init3d()
def get_major_ticks(self, numticks=None):
ticks = super().get_major_ticks(numticks)
for t in ticks:
for obj in [
t.tick1line, t.tick2line, t.gridline, t.label1, t.label2]:
obj.set_transform(self.axes.transData)
return ticks
def get_minor_ticks(self, numticks=None):
ticks = super().get_minor_ticks(numticks)
for t in ticks:
for obj in [
t.tick1line, t.tick2line, t.gridline, t.label1, t.label2]:
obj.set_transform(self.axes.transData)
return ticks
def set_ticks_position(self, position):
"""
Set the ticks position.
Parameters
----------
position : {'lower', 'upper', 'both', 'default', 'none'}
The position of the bolded axis lines, ticks, and tick labels.
"""
if position in ['top', 'bottom']:
_api.warn_deprecated('3.8', name=f'{position=}',
obj_type='argument value',
alternative="'upper' or 'lower'")
return
_api.check_in_list(['lower', 'upper', 'both', 'default', 'none'],
position=position)
self._tick_position = position
def get_ticks_position(self):
"""
Get the ticks position.
Returns
-------
str : {'lower', 'upper', 'both', 'default', 'none'}
The position of the bolded axis lines, ticks, and tick labels.
"""
return self._tick_position
def set_label_position(self, position):
"""
Set the label position.
Parameters
----------
position : {'lower', 'upper', 'both', 'default', 'none'}
The position of the axis label.
"""
if position in ['top', 'bottom']:
_api.warn_deprecated('3.8', name=f'{position=}',
obj_type='argument value',
alternative="'upper' or 'lower'")
return
_api.check_in_list(['lower', 'upper', 'both', 'default', 'none'],
position=position)
self._label_position = position
def get_label_position(self):
"""
Get the label position.
Returns
-------
str : {'lower', 'upper', 'both', 'default', 'none'}
The position of the axis label.
"""
return self._label_position
def set_pane_color(self, color, alpha=None):
"""
Set pane color.
Parameters
----------
color : :mpltype:`color`
Color for axis pane.
alpha : float, optional
Alpha value for axis pane. If None, base it on *color*.
"""
color = mcolors.to_rgba(color, alpha)
self._axinfo['color'] = color
self.pane.set_edgecolor(color)
self.pane.set_facecolor(color)
self.pane.set_alpha(color[-1])
self.stale = True
def set_rotate_label(self, val):
"""
Whether to rotate the axis label: True, False or None.
If set to None the label will be rotated if longer than 4 chars.
"""
self._rotate_label = val
self.stale = True
def get_rotate_label(self, text):
if self._rotate_label is not None:
return self._rotate_label
else:
return len(text) > 4
def _get_coord_info(self):
mins, maxs = np.array([
self.axes.get_xbound(),
self.axes.get_ybound(),
self.axes.get_zbound(),
]).T
# Project the bounds along the current position of the cube:
bounds = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2]
bounds_proj = self.axes._transformed_cube(bounds)
# Determine which one of the parallel planes are higher up:
means_z0 = np.zeros(3)
means_z1 = np.zeros(3)
for i in range(3):
means_z0[i] = np.mean(bounds_proj[self._PLANES[2 * i], 2])
means_z1[i] = np.mean(bounds_proj[self._PLANES[2 * i + 1], 2])
highs = means_z0 < means_z1
# Special handling for edge-on views
equals = np.abs(means_z0 - means_z1) <= np.finfo(float).eps
if np.sum(equals) == 2:
vertical = np.where(~equals)[0][0]
if vertical == 2: # looking at XY plane
highs = np.array([True, True, highs[2]])
elif vertical == 1: # looking at XZ plane
highs = np.array([True, highs[1], False])
elif vertical == 0: # looking at YZ plane
highs = np.array([highs[0], False, False])
return mins, maxs, bounds_proj, highs
def _calc_centers_deltas(self, maxs, mins):
centers = 0.5 * (maxs + mins)
# In mpl3.8, the scale factor was 1/12. mpl3.9 changes this to
# 1/12 * 24/25 = 0.08 to compensate for the change in automargin
# behavior and keep appearance the same. The 24/25 factor is from the
# 1/48 padding added to each side of the axis in mpl3.8.
scale = 0.08
deltas = (maxs - mins) * scale
return centers, deltas
def _get_axis_line_edge_points(self, minmax, maxmin, position=None):
"""Get the edge points for the black bolded axis line."""
# When changing vertical axis some of the axes has to be
# moved to the other plane so it looks the same as if the z-axis
# was the vertical axis.
mb = [minmax, maxmin] # line from origin to nearest corner to camera
mb_rev = mb[::-1]
mm = [[mb, mb_rev, mb_rev], [mb_rev, mb_rev, mb], [mb, mb, mb]]
mm = mm[self.axes._vertical_axis][self._axinfo["i"]]
juggled = self._axinfo["juggled"]
edge_point_0 = mm[0].copy() # origin point
if ((position == 'lower' and mm[1][juggled[-1]] < mm[0][juggled[-1]]) or
(position == 'upper' and mm[1][juggled[-1]] > mm[0][juggled[-1]])):
edge_point_0[juggled[-1]] = mm[1][juggled[-1]]
else:
edge_point_0[juggled[0]] = mm[1][juggled[0]]
edge_point_1 = edge_point_0.copy()
edge_point_1[juggled[1]] = mm[1][juggled[1]]
return edge_point_0, edge_point_1
def _get_all_axis_line_edge_points(self, minmax, maxmin, axis_position=None):
# Determine edge points for the axis lines
edgep1s = []
edgep2s = []
position = []
if axis_position in (None, 'default'):
edgep1, edgep2 = self._get_axis_line_edge_points(minmax, maxmin)
edgep1s = [edgep1]
edgep2s = [edgep2]
position = ['default']
else:
edgep1_l, edgep2_l = self._get_axis_line_edge_points(minmax, maxmin,
position='lower')
edgep1_u, edgep2_u = self._get_axis_line_edge_points(minmax, maxmin,
position='upper')
if axis_position in ('lower', 'both'):
edgep1s.append(edgep1_l)
edgep2s.append(edgep2_l)
position.append('lower')
if axis_position in ('upper', 'both'):
edgep1s.append(edgep1_u)
edgep2s.append(edgep2_u)
position.append('upper')
return edgep1s, edgep2s, position
def _get_tickdir(self, position):
"""
Get the direction of the tick.
Parameters
----------
position : str, optional : {'upper', 'lower', 'default'}
The position of the axis.
Returns
-------
tickdir : int
Index which indicates which coordinate the tick line will
align with.
"""
_api.check_in_list(('upper', 'lower', 'default'), position=position)
# TODO: Move somewhere else where it's triggered less:
tickdirs_base = [v["tickdir"] for v in self._AXINFO.values()] # default
elev_mod = np.mod(self.axes.elev + 180, 360) - 180
azim_mod = np.mod(self.axes.azim, 360)
if position == 'upper':
if elev_mod >= 0:
tickdirs_base = [2, 2, 0]
else:
tickdirs_base = [1, 0, 0]
if 0 <= azim_mod < 180:
tickdirs_base[2] = 1
elif position == 'lower':
if elev_mod >= 0:
tickdirs_base = [1, 0, 1]
else:
tickdirs_base = [2, 2, 1]
if 0 <= azim_mod < 180:
tickdirs_base[2] = 0
info_i = [v["i"] for v in self._AXINFO.values()]
i = self._axinfo["i"]
vert_ax = self.axes._vertical_axis
j = vert_ax - 2
# default: tickdir = [[1, 2, 1], [2, 2, 0], [1, 0, 0]][vert_ax][i]
tickdir = np.roll(info_i, -j)[np.roll(tickdirs_base, j)][i]
return tickdir
def active_pane(self):
mins, maxs, tc, highs = self._get_coord_info()
info = self._axinfo
index = info['i']
if not highs[index]:
loc = mins[index]
plane = self._PLANES[2 * index]
else:
loc = maxs[index]
plane = self._PLANES[2 * index + 1]
xys = np.array([tc[p] for p in plane])
return xys, loc
def draw_pane(self, renderer):
"""
Draw pane.
Parameters
----------
renderer : `~matplotlib.backend_bases.RendererBase` subclass
"""
renderer.open_group('pane3d', gid=self.get_gid())
xys, loc = self.active_pane()
self.pane.xy = xys[:, :2]
self.pane.draw(renderer)
renderer.close_group('pane3d')
def _axmask(self):
axmask = [True, True, True]
axmask[self._axinfo["i"]] = False
return axmask
def _draw_ticks(self, renderer, edgep1, centers, deltas, highs,
deltas_per_point, pos):
ticks = self._update_ticks()
info = self._axinfo
index = info["i"]
juggled = info["juggled"]
mins, maxs, tc, highs = self._get_coord_info()
centers, deltas = self._calc_centers_deltas(maxs, mins)
# Draw ticks:
tickdir = self._get_tickdir(pos)
tickdelta = deltas[tickdir] if highs[tickdir] else -deltas[tickdir]
tick_info = info['tick']
tick_out = tick_info['outward_factor'] * tickdelta
tick_in = tick_info['inward_factor'] * tickdelta
tick_lw = tick_info['linewidth']
edgep1_tickdir = edgep1[tickdir]
out_tickdir = edgep1_tickdir + tick_out
in_tickdir = edgep1_tickdir - tick_in
default_label_offset = 8. # A rough estimate
points = deltas_per_point * deltas
for tick in ticks:
# Get tick line positions
pos = edgep1.copy()
pos[index] = tick.get_loc()
pos[tickdir] = out_tickdir
x1, y1, z1 = proj3d.proj_transform(*pos, self.axes.M)
pos[tickdir] = in_tickdir
x2, y2, z2 = proj3d.proj_transform(*pos, self.axes.M)
# Get position of label
labeldeltas = (tick.get_pad() + default_label_offset) * points
pos[tickdir] = edgep1_tickdir
pos = _move_from_center(pos, centers, labeldeltas, self._axmask())
lx, ly, lz = proj3d.proj_transform(*pos, self.axes.M)
_tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly))
tick.tick1line.set_linewidth(tick_lw[tick._major])
tick.draw(renderer)
def _draw_offset_text(self, renderer, edgep1, edgep2, labeldeltas, centers,
highs, pep, dx, dy):
# Get general axis information:
info = self._axinfo
index = info["i"]
juggled = info["juggled"]
tickdir = info["tickdir"]
# Which of the two edge points do we want to
# use for locating the offset text?
if juggled[2] == 2:
outeredgep = edgep1
outerindex = 0
else:
outeredgep = edgep2
outerindex = 1
pos = _move_from_center(outeredgep, centers, labeldeltas,
self._axmask())
olx, oly, olz = proj3d.proj_transform(*pos, self.axes.M)
self.offsetText.set_text(self.major.formatter.get_offset())
self.offsetText.set_position((olx, oly))
angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx)))
self.offsetText.set_rotation(angle)
# Must set rotation mode to "anchor" so that
# the alignment point is used as the "fulcrum" for rotation.
self.offsetText.set_rotation_mode('anchor')
# ----------------------------------------------------------------------
# Note: the following statement for determining the proper alignment of
# the offset text. This was determined entirely by trial-and-error
# and should not be in any way considered as "the way". There are
# still some edge cases where alignment is not quite right, but this
# seems to be more of a geometry issue (in other words, I might be
# using the wrong reference points).
#
# (TT, FF, TF, FT) are the shorthand for the tuple of
# (centpt[tickdir] <= pep[tickdir, outerindex],
# centpt[index] <= pep[index, outerindex])
#
# Three-letters (e.g., TFT, FTT) are short-hand for the array of bools
# from the variable 'highs'.
# ---------------------------------------------------------------------
centpt = proj3d.proj_transform(*centers, self.axes.M)
if centpt[tickdir] > pep[tickdir, outerindex]:
# if FT and if highs has an even number of Trues
if (centpt[index] <= pep[index, outerindex]
and np.count_nonzero(highs) % 2 == 0):
# Usually, this means align right, except for the FTT case,
# in which offset for axis 1 and 2 are aligned left.
if highs.tolist() == [False, True, True] and index in (1, 2):
align = 'left'
else:
align = 'right'
else:
# The FF case
align = 'left'
else:
# if TF and if highs has an even number of Trues
if (centpt[index] > pep[index, outerindex]
and np.count_nonzero(highs) % 2 == 0):
# Usually mean align left, except if it is axis 2
align = 'right' if index == 2 else 'left'
else:
# The TT case
align = 'right'
self.offsetText.set_va('center')
self.offsetText.set_ha(align)
self.offsetText.draw(renderer)
def _draw_labels(self, renderer, edgep1, edgep2, labeldeltas, centers, dx, dy):
label = self._axinfo["label"]
# Draw labels
lxyz = 0.5 * (edgep1 + edgep2)
lxyz = _move_from_center(lxyz, centers, labeldeltas, self._axmask())
tlx, tly, tlz = proj3d.proj_transform(*lxyz, self.axes.M)
self.label.set_position((tlx, tly))
if self.get_rotate_label(self.label.get_text()):
angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx)))
self.label.set_rotation(angle)
self.label.set_va(label['va'])
self.label.set_ha(label['ha'])
self.label.set_rotation_mode(label['rotation_mode'])
self.label.draw(renderer)
@artist.allow_rasterization
def draw(self, renderer):
self.label._transform = self.axes.transData
self.offsetText._transform = self.axes.transData
renderer.open_group("axis3d", gid=self.get_gid())
# Get general axis information:
mins, maxs, tc, highs = self._get_coord_info()
centers, deltas = self._calc_centers_deltas(maxs, mins)
# Calculate offset distances
# A rough estimate; points are ambiguous since 3D plots rotate
reltoinches = self.figure.dpi_scale_trans.inverted()
ax_inches = reltoinches.transform(self.axes.bbox.size)
ax_points_estimate = sum(72. * ax_inches)
deltas_per_point = 48 / ax_points_estimate
default_offset = 21.
labeldeltas = (self.labelpad + default_offset) * deltas_per_point * deltas
# Determine edge points for the axis lines
minmax = np.where(highs, maxs, mins) # "origin" point
maxmin = np.where(~highs, maxs, mins) # "opposite" corner near camera
for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points(
minmax, maxmin, self._tick_position)):
# Project the edge points along the current position
pep = proj3d._proj_trans_points([edgep1, edgep2], self.axes.M)
pep = np.asarray(pep)
# The transAxes transform is used because the Text object
# rotates the text relative to the display coordinate system.
# Therefore, if we want the labels to remain parallel to the
# axis regardless of the aspect ratio, we need to convert the
# edge points of the plane to display coordinates and calculate
# an angle from that.
# TODO: Maybe Text objects should handle this themselves?
dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) -
self.axes.transAxes.transform([pep[0:2, 0]]))[0]
# Draw the lines
self.line.set_data(pep[0], pep[1])
self.line.draw(renderer)
# Draw ticks
self._draw_ticks(renderer, edgep1, centers, deltas, highs,
deltas_per_point, pos)
# Draw Offset text
self._draw_offset_text(renderer, edgep1, edgep2, labeldeltas,
centers, highs, pep, dx, dy)
for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points(
minmax, maxmin, self._label_position)):
# See comments above
pep = proj3d._proj_trans_points([edgep1, edgep2], self.axes.M)
pep = np.asarray(pep)
dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) -
self.axes.transAxes.transform([pep[0:2, 0]]))[0]
# Draw labels
self._draw_labels(renderer, edgep1, edgep2, labeldeltas, centers, dx, dy)
renderer.close_group('axis3d')
self.stale = False
@artist.allow_rasterization
def draw_grid(self, renderer):
if not self.axes._draw_grid:
return
renderer.open_group("grid3d", gid=self.get_gid())
ticks = self._update_ticks()
if len(ticks):
# Get general axis information:
info = self._axinfo
index = info["i"]
mins, maxs, tc, highs = self._get_coord_info()
minmax = np.where(highs, maxs, mins)
maxmin = np.where(~highs, maxs, mins)
# Grid points where the planes meet
xyz0 = np.tile(minmax, (len(ticks), 1))
xyz0[:, index] = [tick.get_loc() for tick in ticks]
# Grid lines go from the end of one plane through the plane
# intersection (at xyz0) to the end of the other plane. The first
# point (0) differs along dimension index-2 and the last (2) along
# dimension index-1.
lines = np.stack([xyz0, xyz0, xyz0], axis=1)
lines[:, 0, index - 2] = maxmin[index - 2]
lines[:, 2, index - 1] = maxmin[index - 1]
self.gridlines.set_segments(lines)
gridinfo = info['grid']
self.gridlines.set_color(gridinfo['color'])
self.gridlines.set_linewidth(gridinfo['linewidth'])
self.gridlines.set_linestyle(gridinfo['linestyle'])
self.gridlines.do_3d_projection()
self.gridlines.draw(renderer)
renderer.close_group('grid3d')
# TODO: Get this to work (more) properly when mplot3d supports the
# transforms framework.
def get_tightbbox(self, renderer=None, *, for_layout_only=False):
# docstring inherited
if not self.get_visible():
return
# We have to directly access the internal data structures
# (and hope they are up to date) because at draw time we
# shift the ticks and their labels around in (x, y) space
# based on the projection, the current view port, and their
# position in 3D space. If we extend the transforms framework
# into 3D we would not need to do this different book keeping
# than we do in the normal axis
major_locs = self.get_majorticklocs()
minor_locs = self.get_minorticklocs()
ticks = [*self.get_minor_ticks(len(minor_locs)),
*self.get_major_ticks(len(major_locs))]
view_low, view_high = self.get_view_interval()
if view_low > view_high:
view_low, view_high = view_high, view_low
interval_t = self.get_transform().transform([view_low, view_high])
ticks_to_draw = []
for tick in ticks:
try:
loc_t = self.get_transform().transform(tick.get_loc())
except AssertionError:
# Transform.transform doesn't allow masked values but
# some scales might make them, so we need this try/except.
pass
else:
if mtransforms._interval_contains_close(interval_t, loc_t):
ticks_to_draw.append(tick)
ticks = ticks_to_draw
bb_1, bb_2 = self._get_ticklabel_bboxes(ticks, renderer)
other = []
if self.line.get_visible():
other.append(self.line.get_window_extent(renderer))
if (self.label.get_visible() and not for_layout_only and
self.label.get_text()):
other.append(self.label.get_window_extent(renderer))
return mtransforms.Bbox.union([*bb_1, *bb_2, *other])
d_interval = _api.deprecated(
"3.6", alternative="get_data_interval", pending=True)(
property(lambda self: self.get_data_interval(),
lambda self, minmax: self.set_data_interval(*minmax)))
v_interval = _api.deprecated(
"3.6", alternative="get_view_interval", pending=True)(
property(lambda self: self.get_view_interval(),
lambda self, minmax: self.set_view_interval(*minmax)))
class XAxis(Axis):
axis_name = "x"
get_view_interval, set_view_interval = maxis._make_getset_interval(
"view", "xy_viewLim", "intervalx")
get_data_interval, set_data_interval = maxis._make_getset_interval(
"data", "xy_dataLim", "intervalx")
class YAxis(Axis):
axis_name = "y"
get_view_interval, set_view_interval = maxis._make_getset_interval(
"view", "xy_viewLim", "intervaly")
get_data_interval, set_data_interval = maxis._make_getset_interval(
"data", "xy_dataLim", "intervaly")
class ZAxis(Axis):
axis_name = "z"
get_view_interval, set_view_interval = maxis._make_getset_interval(
"view", "zz_viewLim", "intervalx")
get_data_interval, set_data_interval = maxis._make_getset_interval(
"data", "zz_dataLim", "intervalx")

View File

@ -0,0 +1,259 @@
"""
Various transforms used for by the 3D code
"""
import numpy as np
from matplotlib import _api
def world_transformation(xmin, xmax,
ymin, ymax,
zmin, zmax, pb_aspect=None):
"""
Produce a matrix that scales homogeneous coords in the specified ranges
to [0, 1], or [0, pb_aspect[i]] if the plotbox aspect ratio is specified.
"""
dx = xmax - xmin
dy = ymax - ymin
dz = zmax - zmin
if pb_aspect is not None:
ax, ay, az = pb_aspect
dx /= ax
dy /= ay
dz /= az
return np.array([[1/dx, 0, 0, -xmin/dx],
[0, 1/dy, 0, -ymin/dy],
[0, 0, 1/dz, -zmin/dz],
[0, 0, 0, 1]])
@_api.deprecated("3.8")
def rotation_about_vector(v, angle):
"""
Produce a rotation matrix for an angle in radians about a vector.
"""
return _rotation_about_vector(v, angle)
def _rotation_about_vector(v, angle):
"""
Produce a rotation matrix for an angle in radians about a vector.
"""
vx, vy, vz = v / np.linalg.norm(v)
s = np.sin(angle)
c = np.cos(angle)
t = 2*np.sin(angle/2)**2 # more numerically stable than t = 1-c
R = np.array([
[t*vx*vx + c, t*vx*vy - vz*s, t*vx*vz + vy*s],
[t*vy*vx + vz*s, t*vy*vy + c, t*vy*vz - vx*s],
[t*vz*vx - vy*s, t*vz*vy + vx*s, t*vz*vz + c]])
return R
def _view_axes(E, R, V, roll):
"""
Get the unit viewing axes in data coordinates.
Parameters
----------
E : 3-element numpy array
The coordinates of the eye/camera.
R : 3-element numpy array
The coordinates of the center of the view box.
V : 3-element numpy array
Unit vector in the direction of the vertical axis.
roll : float
The roll angle in radians.
Returns
-------
u : 3-element numpy array
Unit vector pointing towards the right of the screen.
v : 3-element numpy array
Unit vector pointing towards the top of the screen.
w : 3-element numpy array
Unit vector pointing out of the screen.
"""
w = (E - R)
w = w/np.linalg.norm(w)
u = np.cross(V, w)
u = u/np.linalg.norm(u)
v = np.cross(w, u) # Will be a unit vector
# Save some computation for the default roll=0
if roll != 0:
# A positive rotation of the camera is a negative rotation of the world
Rroll = _rotation_about_vector(w, -roll)
u = np.dot(Rroll, u)
v = np.dot(Rroll, v)
return u, v, w
def _view_transformation_uvw(u, v, w, E):
"""
Return the view transformation matrix.
Parameters
----------
u : 3-element numpy array
Unit vector pointing towards the right of the screen.
v : 3-element numpy array
Unit vector pointing towards the top of the screen.
w : 3-element numpy array
Unit vector pointing out of the screen.
E : 3-element numpy array
The coordinates of the eye/camera.
"""
Mr = np.eye(4)
Mt = np.eye(4)
Mr[:3, :3] = [u, v, w]
Mt[:3, -1] = -E
M = np.dot(Mr, Mt)
return M
@_api.deprecated("3.8")
def view_transformation(E, R, V, roll):
"""
Return the view transformation matrix.
Parameters
----------
E : 3-element numpy array
The coordinates of the eye/camera.
R : 3-element numpy array
The coordinates of the center of the view box.
V : 3-element numpy array
Unit vector in the direction of the vertical axis.
roll : float
The roll angle in radians.
"""
u, v, w = _view_axes(E, R, V, roll)
M = _view_transformation_uvw(u, v, w, E)
return M
@_api.deprecated("3.8")
def persp_transformation(zfront, zback, focal_length):
return _persp_transformation(zfront, zback, focal_length)
def _persp_transformation(zfront, zback, focal_length):
e = focal_length
a = 1 # aspect ratio
b = (zfront+zback)/(zfront-zback)
c = -2*(zfront*zback)/(zfront-zback)
proj_matrix = np.array([[e, 0, 0, 0],
[0, e/a, 0, 0],
[0, 0, b, c],
[0, 0, -1, 0]])
return proj_matrix
@_api.deprecated("3.8")
def ortho_transformation(zfront, zback):
return _ortho_transformation(zfront, zback)
def _ortho_transformation(zfront, zback):
# note: w component in the resulting vector will be (zback-zfront), not 1
a = -(zfront + zback)
b = -(zfront - zback)
proj_matrix = np.array([[2, 0, 0, 0],
[0, 2, 0, 0],
[0, 0, -2, 0],
[0, 0, a, b]])
return proj_matrix
def _proj_transform_vec(vec, M):
vecw = np.dot(M, vec)
w = vecw[3]
# clip here..
txs, tys, tzs = vecw[0]/w, vecw[1]/w, vecw[2]/w
return txs, tys, tzs
def _proj_transform_vec_clip(vec, M):
vecw = np.dot(M, vec)
w = vecw[3]
# clip here.
txs, tys, tzs = vecw[0] / w, vecw[1] / w, vecw[2] / w
tis = (0 <= vecw[0]) & (vecw[0] <= 1) & (0 <= vecw[1]) & (vecw[1] <= 1)
if np.any(tis):
tis = vecw[1] < 1
return txs, tys, tzs, tis
def inv_transform(xs, ys, zs, invM):
"""
Transform the points by the inverse of the projection matrix, *invM*.
"""
vec = _vec_pad_ones(xs, ys, zs)
vecr = np.dot(invM, vec)
if vecr.shape == (4,):
vecr = vecr.reshape((4, 1))
for i in range(vecr.shape[1]):
if vecr[3][i] != 0:
vecr[:, i] = vecr[:, i] / vecr[3][i]
return vecr[0], vecr[1], vecr[2]
def _vec_pad_ones(xs, ys, zs):
return np.array([xs, ys, zs, np.ones_like(xs)])
def proj_transform(xs, ys, zs, M):
"""
Transform the points by the projection matrix *M*.
"""
vec = _vec_pad_ones(xs, ys, zs)
return _proj_transform_vec(vec, M)
transform = _api.deprecated(
"3.8", obj_type="function", name="transform",
alternative="proj_transform")(proj_transform)
def proj_transform_clip(xs, ys, zs, M):
"""
Transform the points by the projection matrix
and return the clipping result
returns txs, tys, tzs, tis
"""
vec = _vec_pad_ones(xs, ys, zs)
return _proj_transform_vec_clip(vec, M)
@_api.deprecated("3.8")
def proj_points(points, M):
return _proj_points(points, M)
def _proj_points(points, M):
return np.column_stack(_proj_trans_points(points, M))
@_api.deprecated("3.8")
def proj_trans_points(points, M):
return _proj_trans_points(points, M)
def _proj_trans_points(points, M):
xs, ys, zs = zip(*points)
return proj_transform(xs, ys, zs, M)
@_api.deprecated("3.8")
def rot_x(V, alpha):
cosa, sina = np.cos(alpha), np.sin(alpha)
M1 = np.array([[1, 0, 0, 0],
[0, cosa, -sina, 0],
[0, sina, cosa, 0],
[0, 0, 0, 1]])
return np.dot(M1, V)

View File

@ -0,0 +1,10 @@
from pathlib import Path
# Check that the test directories exist
if not (Path(__file__).parent / "baseline_images").exists():
raise OSError(
'The baseline image directory does not exist. '
'This is most likely because the test data is not installed. '
'You may need to install matplotlib from source to get the '
'test data.')

View File

@ -0,0 +1,2 @@
from matplotlib.testing.conftest import (mpl_test_settings, # noqa
pytest_configure, pytest_unconfigure)

View File

@ -0,0 +1,56 @@
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseEvent
from mpl_toolkits.mplot3d.art3d import Line3DCollection
def test_scatter_3d_projection_conservation():
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# fix axes3d projection
ax.roll = 0
ax.elev = 0
ax.azim = -45
ax.stale = True
x = [0, 1, 2, 3, 4]
scatter_collection = ax.scatter(x, x, x)
fig.canvas.draw_idle()
# Get scatter location on canvas and freeze the data
scatter_offset = scatter_collection.get_offsets()
scatter_location = ax.transData.transform(scatter_offset)
# Yaw -44 and -46 are enough to produce two set of scatter
# with opposite z-order without moving points too far
for azim in (-44, -46):
ax.azim = azim
ax.stale = True
fig.canvas.draw_idle()
for i in range(5):
# Create a mouse event used to locate and to get index
# from each dots
event = MouseEvent("button_press_event", fig.canvas,
*scatter_location[i, :])
contains, ind = scatter_collection.contains(event)
assert contains is True
assert len(ind["ind"]) == 1
assert ind["ind"][0] == i
def test_zordered_error():
# Smoke test for https://github.com/matplotlib/matplotlib/issues/26497
lc = [(np.fromiter([0.0, 0.0, 0.0], dtype="float"),
np.fromiter([1.0, 1.0, 1.0], dtype="float"))]
pc = [np.fromiter([0.0, 0.0], dtype="float"),
np.fromiter([0.0, 1.0], dtype="float"),
np.fromiter([1.0, 1.0], dtype="float")]
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.add_collection(Line3DCollection(lc))
ax.scatter(*pc, visible=False)
plt.draw()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,117 @@
import platform
import numpy as np
import matplotlib as mpl
from matplotlib.colors import same_color
from matplotlib.testing.decorators import image_comparison
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import art3d
@image_comparison(['legend_plot.png'], remove_text=True, style='mpl20')
def test_legend_plot():
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
x = np.arange(10)
ax.plot(x, 5 - x, 'o', zdir='y', label='z=1')
ax.plot(x, x - 5, 'o', zdir='y', label='z=-1')
ax.legend()
@image_comparison(['legend_bar.png'], remove_text=True, style='mpl20')
def test_legend_bar():
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
x = np.arange(10)
b1 = ax.bar(x, x, zdir='y', align='edge', color='m')
b2 = ax.bar(x, x[::-1], zdir='x', align='edge', color='g')
ax.legend([b1[0], b2[0]], ['up', 'down'])
@image_comparison(['fancy.png'], remove_text=True, style='mpl20',
tol=0.011 if platform.machine() == 'arm64' else 0)
def test_fancy():
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
ax.plot(np.arange(10), np.full(10, 5), np.full(10, 5), 'o--', label='line')
ax.scatter(np.arange(10), np.arange(10, 0, -1), label='scatter')
ax.errorbar(np.full(10, 5), np.arange(10), np.full(10, 10),
xerr=0.5, zerr=0.5, label='errorbar')
ax.legend(loc='lower left', ncols=2, title='My legend', numpoints=1)
def test_linecollection_scaled_dashes():
lines1 = [[(0, .5), (.5, 1)], [(.3, .6), (.2, .2)]]
lines2 = [[[0.7, .2], [.8, .4]], [[.5, .7], [.6, .1]]]
lines3 = [[[0.6, .2], [.8, .4]], [[.5, .7], [.1, .1]]]
lc1 = art3d.Line3DCollection(lines1, linestyles="--", lw=3)
lc2 = art3d.Line3DCollection(lines2, linestyles="-.")
lc3 = art3d.Line3DCollection(lines3, linestyles=":", lw=.5)
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
ax.add_collection(lc1)
ax.add_collection(lc2)
ax.add_collection(lc3)
leg = ax.legend([lc1, lc2, lc3], ['line1', 'line2', 'line 3'])
h1, h2, h3 = leg.legend_handles
for oh, lh in zip((lc1, lc2, lc3), (h1, h2, h3)):
assert oh.get_linestyles()[0] == lh._dash_pattern
def test_handlerline3d():
# Test marker consistency for monolithic Line3D legend handler.
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
ax.scatter([0, 1], [0, 1], marker="v")
handles = [art3d.Line3D([0], [0], [0], marker="v")]
leg = ax.legend(handles, ["Aardvark"], numpoints=1)
assert handles[0].get_marker() == leg.legend_handles[0].get_marker()
def test_contour_legend_elements():
x, y = np.mgrid[1:10, 1:10]
h = x * y
colors = ['blue', '#00FF00', 'red']
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
cs = ax.contour(x, y, h, levels=[10, 30, 50], colors=colors, extend='both')
artists, labels = cs.legend_elements()
assert labels == ['$x = 10.0$', '$x = 30.0$', '$x = 50.0$']
assert all(isinstance(a, mpl.lines.Line2D) for a in artists)
assert all(same_color(a.get_color(), c)
for a, c in zip(artists, colors))
def test_contourf_legend_elements():
x, y = np.mgrid[1:10, 1:10]
h = x * y
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
cs = ax.contourf(x, y, h, levels=[10, 30, 50],
colors=['#FFFF00', '#FF00FF', '#00FFFF'],
extend='both')
cs.cmap.set_over('red')
cs.cmap.set_under('blue')
cs.changed()
artists, labels = cs.legend_elements()
assert labels == ['$x \\leq -1e+250s$',
'$10.0 < x \\leq 30.0$',
'$30.0 < x \\leq 50.0$',
'$x > 1e+250s$']
expected_colors = ('blue', '#FFFF00', '#FF00FF', 'red')
assert all(isinstance(a, mpl.patches.Rectangle) for a in artists)
assert all(same_color(a.get_facecolor(), c)
for a, c in zip(artists, expected_colors))
def test_legend_Poly3dCollection():
verts = np.asarray([[0, 0, 0], [0, 1, 1], [1, 0, 1]])
mesh = art3d.Poly3DCollection([verts], label="surface")
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
mesh.set_edgecolor('k')
handle = ax.add_collection3d(mesh)
leg = ax.legend()
assert (leg.legend_handles[0].get_facecolor()
== handle.get_facecolor()).all()