asd
This commit is contained in:
234
venv/lib/python3.12/site-packages/matplotlib/testing/__init__.py
Normal file
234
venv/lib/python3.12/site-packages/matplotlib/testing/__init__.py
Normal file
@ -0,0 +1,234 @@
|
||||
"""
|
||||
Helper functions for testing.
|
||||
"""
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
import locale
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import matplotlib as mpl
|
||||
from matplotlib import _api
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_font_settings_for_testing():
|
||||
mpl.rcParams['font.family'] = 'DejaVu Sans'
|
||||
mpl.rcParams['text.hinting'] = 'none'
|
||||
mpl.rcParams['text.hinting_factor'] = 8
|
||||
|
||||
|
||||
def set_reproducibility_for_testing():
|
||||
mpl.rcParams['svg.hashsalt'] = 'matplotlib'
|
||||
|
||||
|
||||
def setup():
|
||||
# The baseline images are created in this locale, so we should use
|
||||
# it during all of the tests.
|
||||
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')
|
||||
except locale.Error:
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, 'English_United States.1252')
|
||||
except locale.Error:
|
||||
_log.warning(
|
||||
"Could not set locale to English/United States. "
|
||||
"Some date-related tests may fail.")
|
||||
|
||||
mpl.use('Agg')
|
||||
|
||||
with _api.suppress_matplotlib_deprecation_warning():
|
||||
mpl.rcdefaults() # Start with all defaults
|
||||
|
||||
# These settings *must* be hardcoded for running the comparison tests and
|
||||
# are not necessarily the default values as specified in rcsetup.py.
|
||||
set_font_settings_for_testing()
|
||||
set_reproducibility_for_testing()
|
||||
|
||||
|
||||
def subprocess_run_for_testing(command, env=None, timeout=60, stdout=None,
|
||||
stderr=None, check=False, text=True,
|
||||
capture_output=False):
|
||||
"""
|
||||
Create and run a subprocess.
|
||||
|
||||
Thin wrapper around `subprocess.run`, intended for testing. Will
|
||||
mark fork() failures on Cygwin as expected failures: not a
|
||||
success, but not indicating a problem with the code either.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args : list of str
|
||||
env : dict[str, str]
|
||||
timeout : float
|
||||
stdout, stderr
|
||||
check : bool
|
||||
text : bool
|
||||
Also called ``universal_newlines`` in subprocess. I chose this
|
||||
name since the main effect is returning bytes (`False`) vs. str
|
||||
(`True`), though it also tries to normalize newlines across
|
||||
platforms.
|
||||
capture_output : bool
|
||||
Set stdout and stderr to subprocess.PIPE
|
||||
|
||||
Returns
|
||||
-------
|
||||
proc : subprocess.Popen
|
||||
|
||||
See Also
|
||||
--------
|
||||
subprocess.run
|
||||
|
||||
Raises
|
||||
------
|
||||
pytest.xfail
|
||||
If platform is Cygwin and subprocess reports a fork() failure.
|
||||
"""
|
||||
if capture_output:
|
||||
stdout = stderr = subprocess.PIPE
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
command, env=env,
|
||||
timeout=timeout, check=check,
|
||||
stdout=stdout, stderr=stderr,
|
||||
text=text
|
||||
)
|
||||
except BlockingIOError:
|
||||
if sys.platform == "cygwin":
|
||||
# Might want to make this more specific
|
||||
import pytest
|
||||
pytest.xfail("Fork failure")
|
||||
raise
|
||||
return proc
|
||||
|
||||
|
||||
def subprocess_run_helper(func, *args, timeout, extra_env=None):
|
||||
"""
|
||||
Run a function in a sub-process.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : function
|
||||
The function to be run. It must be in a module that is importable.
|
||||
*args : str
|
||||
Any additional command line arguments to be passed in
|
||||
the first argument to ``subprocess.run``.
|
||||
extra_env : dict[str, str]
|
||||
Any additional environment variables to be set for the subprocess.
|
||||
"""
|
||||
target = func.__name__
|
||||
module = func.__module__
|
||||
file = func.__code__.co_filename
|
||||
proc = subprocess_run_for_testing(
|
||||
[
|
||||
sys.executable,
|
||||
"-c",
|
||||
f"import importlib.util;"
|
||||
f"_spec = importlib.util.spec_from_file_location({module!r}, {file!r});"
|
||||
f"_module = importlib.util.module_from_spec(_spec);"
|
||||
f"_spec.loader.exec_module(_module);"
|
||||
f"_module.{target}()",
|
||||
*args
|
||||
],
|
||||
env={**os.environ, "SOURCE_DATE_EPOCH": "0", **(extra_env or {})},
|
||||
timeout=timeout, check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True
|
||||
)
|
||||
return proc
|
||||
|
||||
|
||||
def _check_for_pgf(texsystem):
|
||||
"""
|
||||
Check if a given TeX system + pgf is available
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texsystem : str
|
||||
The executable name to check
|
||||
"""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
tex_path = Path(tmpdir, "test.tex")
|
||||
tex_path.write_text(r"""
|
||||
\documentclass{article}
|
||||
\usepackage{pgf}
|
||||
\begin{document}
|
||||
\typeout{pgfversion=\pgfversion}
|
||||
\makeatletter
|
||||
\@@end
|
||||
""", encoding="utf-8")
|
||||
try:
|
||||
subprocess.check_call(
|
||||
[texsystem, "-halt-on-error", str(tex_path)], cwd=tmpdir,
|
||||
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
except (OSError, subprocess.CalledProcessError):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _has_tex_package(package):
|
||||
try:
|
||||
mpl.dviread.find_tex_file(f"{package}.sty")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
def ipython_in_subprocess(requested_backend_or_gui_framework, all_expected_backends):
|
||||
import pytest
|
||||
IPython = pytest.importorskip("IPython")
|
||||
|
||||
if sys.platform == "win32":
|
||||
pytest.skip("Cannot change backend running IPython in subprocess on Windows")
|
||||
|
||||
if (IPython.version_info[:3] == (8, 24, 0) and
|
||||
requested_backend_or_gui_framework == "osx"):
|
||||
pytest.skip("Bug using macosx backend in IPython 8.24.0 fixed in 8.24.1")
|
||||
|
||||
# This code can be removed when Python 3.12, the latest version supported
|
||||
# by IPython < 8.24, reaches end-of-life in late 2028.
|
||||
for min_version, backend in all_expected_backends.items():
|
||||
if IPython.version_info[:2] >= min_version:
|
||||
expected_backend = backend
|
||||
break
|
||||
|
||||
code = ("import matplotlib as mpl, matplotlib.pyplot as plt;"
|
||||
"fig, ax=plt.subplots(); ax.plot([1, 3, 2]); mpl.get_backend()")
|
||||
proc = subprocess_run_for_testing(
|
||||
[
|
||||
"ipython",
|
||||
"--no-simple-prompt",
|
||||
f"--matplotlib={requested_backend_or_gui_framework}",
|
||||
"-c", code,
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
assert proc.stdout.strip().endswith(f"'{expected_backend}'")
|
||||
|
||||
|
||||
def is_ci_environment():
|
||||
# Common CI variables
|
||||
ci_environment_variables = [
|
||||
'CI', # Generic CI environment variable
|
||||
'CONTINUOUS_INTEGRATION', # Generic CI environment variable
|
||||
'TRAVIS', # Travis CI
|
||||
'CIRCLECI', # CircleCI
|
||||
'JENKINS', # Jenkins
|
||||
'GITLAB_CI', # GitLab CI
|
||||
'GITHUB_ACTIONS', # GitHub Actions
|
||||
'TEAMCITY_VERSION' # TeamCity
|
||||
# Add other CI environment variables as needed
|
||||
]
|
||||
|
||||
for env_var in ci_environment_variables:
|
||||
if os.getenv(env_var):
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -0,0 +1,54 @@
|
||||
from collections.abc import Callable
|
||||
import subprocess
|
||||
from typing import Any, IO, Literal, overload
|
||||
|
||||
def set_font_settings_for_testing() -> None: ...
|
||||
def set_reproducibility_for_testing() -> None: ...
|
||||
def setup() -> None: ...
|
||||
@overload
|
||||
def subprocess_run_for_testing(
|
||||
command: list[str],
|
||||
env: dict[str, str] | None = ...,
|
||||
timeout: float | None = ...,
|
||||
stdout: int | IO[Any] | None = ...,
|
||||
stderr: int | IO[Any] | None = ...,
|
||||
check: bool = ...,
|
||||
*,
|
||||
text: Literal[True],
|
||||
capture_output: bool = ...,
|
||||
) -> subprocess.CompletedProcess[str]: ...
|
||||
@overload
|
||||
def subprocess_run_for_testing(
|
||||
command: list[str],
|
||||
env: dict[str, str] | None = ...,
|
||||
timeout: float | None = ...,
|
||||
stdout: int | IO[Any] | None = ...,
|
||||
stderr: int | IO[Any] | None = ...,
|
||||
check: bool = ...,
|
||||
text: Literal[False] = ...,
|
||||
capture_output: bool = ...,
|
||||
) -> subprocess.CompletedProcess[bytes]: ...
|
||||
@overload
|
||||
def subprocess_run_for_testing(
|
||||
command: list[str],
|
||||
env: dict[str, str] | None = ...,
|
||||
timeout: float | None = ...,
|
||||
stdout: int | IO[Any] | None = ...,
|
||||
stderr: int | IO[Any] | None = ...,
|
||||
check: bool = ...,
|
||||
text: bool = ...,
|
||||
capture_output: bool = ...,
|
||||
) -> subprocess.CompletedProcess[bytes] | subprocess.CompletedProcess[str]: ...
|
||||
def subprocess_run_helper(
|
||||
func: Callable[[], None],
|
||||
*args: Any,
|
||||
timeout: float,
|
||||
extra_env: dict[str, str] | None = ...,
|
||||
) -> subprocess.CompletedProcess[str]: ...
|
||||
def _check_for_pgf(texsystem: str) -> bool: ...
|
||||
def _has_tex_package(package: str) -> bool: ...
|
||||
def ipython_in_subprocess(
|
||||
requested_backend_or_gui_framework: str,
|
||||
all_expected_backends: dict[tuple[int, int], str],
|
||||
) -> None: ...
|
||||
def is_ci_environment() -> bool: ...
|
||||
@ -0,0 +1,49 @@
|
||||
"""
|
||||
pytest markers for the internal Matplotlib test suite.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
import matplotlib.testing
|
||||
import matplotlib.testing.compare
|
||||
from matplotlib import _get_executable_info, ExecutableNotFoundError
|
||||
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _checkdep_usetex() -> bool:
|
||||
if not shutil.which("tex"):
|
||||
_log.warning("usetex mode requires TeX.")
|
||||
return False
|
||||
try:
|
||||
_get_executable_info("dvipng")
|
||||
except ExecutableNotFoundError:
|
||||
_log.warning("usetex mode requires dvipng.")
|
||||
return False
|
||||
try:
|
||||
_get_executable_info("gs")
|
||||
except ExecutableNotFoundError:
|
||||
_log.warning("usetex mode requires ghostscript.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
needs_ghostscript = pytest.mark.skipif(
|
||||
"eps" not in matplotlib.testing.compare.converter,
|
||||
reason="This test needs a ghostscript installation")
|
||||
needs_pgf_lualatex = pytest.mark.skipif(
|
||||
not matplotlib.testing._check_for_pgf('lualatex'),
|
||||
reason='lualatex + pgf is required')
|
||||
needs_pgf_pdflatex = pytest.mark.skipif(
|
||||
not matplotlib.testing._check_for_pgf('pdflatex'),
|
||||
reason='pdflatex + pgf is required')
|
||||
needs_pgf_xelatex = pytest.mark.skipif(
|
||||
not matplotlib.testing._check_for_pgf('xelatex'),
|
||||
reason='xelatex + pgf is required')
|
||||
needs_usetex = pytest.mark.skipif(
|
||||
not _checkdep_usetex(),
|
||||
reason="This test needs a TeX installation")
|
||||
520
venv/lib/python3.12/site-packages/matplotlib/testing/compare.py
Normal file
520
venv/lib/python3.12/site-packages/matplotlib/testing/compare.py
Normal file
@ -0,0 +1,520 @@
|
||||
"""
|
||||
Utilities for comparing image results.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import functools
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from tempfile import TemporaryDirectory, TemporaryFile
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import matplotlib as mpl
|
||||
from matplotlib import cbook
|
||||
from matplotlib.testing.exceptions import ImageComparisonFailure
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['calculate_rms', 'comparable_formats', 'compare_images']
|
||||
|
||||
|
||||
def make_test_filename(fname, purpose):
|
||||
"""
|
||||
Make a new filename by inserting *purpose* before the file's extension.
|
||||
"""
|
||||
base, ext = os.path.splitext(fname)
|
||||
return f'{base}-{purpose}{ext}'
|
||||
|
||||
|
||||
def _get_cache_path():
|
||||
cache_dir = Path(mpl.get_cachedir(), 'test_cache')
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir
|
||||
|
||||
|
||||
def get_cache_dir():
|
||||
return str(_get_cache_path())
|
||||
|
||||
|
||||
def get_file_hash(path, block_size=2 ** 20):
|
||||
md5 = hashlib.md5()
|
||||
with open(path, 'rb') as fd:
|
||||
while True:
|
||||
data = fd.read(block_size)
|
||||
if not data:
|
||||
break
|
||||
md5.update(data)
|
||||
|
||||
if Path(path).suffix == '.pdf':
|
||||
md5.update(str(mpl._get_executable_info("gs").version)
|
||||
.encode('utf-8'))
|
||||
elif Path(path).suffix == '.svg':
|
||||
md5.update(str(mpl._get_executable_info("inkscape").version)
|
||||
.encode('utf-8'))
|
||||
|
||||
return md5.hexdigest()
|
||||
|
||||
|
||||
class _ConverterError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class _Converter:
|
||||
def __init__(self):
|
||||
self._proc = None
|
||||
# Explicitly register deletion from an atexit handler because if we
|
||||
# wait until the object is GC'd (which occurs later), then some module
|
||||
# globals (e.g. signal.SIGKILL) has already been set to None, and
|
||||
# kill() doesn't work anymore...
|
||||
atexit.register(self.__del__)
|
||||
|
||||
def __del__(self):
|
||||
if self._proc:
|
||||
self._proc.kill()
|
||||
self._proc.wait()
|
||||
for stream in filter(None, [self._proc.stdin,
|
||||
self._proc.stdout,
|
||||
self._proc.stderr]):
|
||||
stream.close()
|
||||
self._proc = None
|
||||
|
||||
def _read_until(self, terminator):
|
||||
"""Read until the prompt is reached."""
|
||||
buf = bytearray()
|
||||
while True:
|
||||
c = self._proc.stdout.read(1)
|
||||
if not c:
|
||||
raise _ConverterError(os.fsdecode(bytes(buf)))
|
||||
buf.extend(c)
|
||||
if buf.endswith(terminator):
|
||||
return bytes(buf)
|
||||
|
||||
|
||||
class _GSConverter(_Converter):
|
||||
def __call__(self, orig, dest):
|
||||
if not self._proc:
|
||||
self._proc = subprocess.Popen(
|
||||
[mpl._get_executable_info("gs").executable,
|
||||
"-dNOSAFER", "-dNOPAUSE", "-dEPSCrop", "-sDEVICE=png16m"],
|
||||
# As far as I can see, ghostscript never outputs to stderr.
|
||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE)
|
||||
try:
|
||||
self._read_until(b"\nGS")
|
||||
except _ConverterError as e:
|
||||
raise OSError(f"Failed to start Ghostscript:\n\n{e.args[0]}") from None
|
||||
|
||||
def encode_and_escape(name):
|
||||
return (os.fsencode(name)
|
||||
.replace(b"\\", b"\\\\")
|
||||
.replace(b"(", br"\(")
|
||||
.replace(b")", br"\)"))
|
||||
|
||||
self._proc.stdin.write(
|
||||
b"<< /OutputFile ("
|
||||
+ encode_and_escape(dest)
|
||||
+ b") >> setpagedevice ("
|
||||
+ encode_and_escape(orig)
|
||||
+ b") run flush\n")
|
||||
self._proc.stdin.flush()
|
||||
# GS> if nothing left on the stack; GS<n> if n items left on the stack.
|
||||
err = self._read_until((b"GS<", b"GS>"))
|
||||
stack = self._read_until(b">") if err.endswith(b"GS<") else b""
|
||||
if stack or not os.path.exists(dest):
|
||||
stack_size = int(stack[:-1]) if stack else 0
|
||||
self._proc.stdin.write(b"pop\n" * stack_size)
|
||||
# Using the systemencoding should at least get the filenames right.
|
||||
raise ImageComparisonFailure(
|
||||
(err + stack).decode(sys.getfilesystemencoding(), "replace"))
|
||||
|
||||
|
||||
class _SVGConverter(_Converter):
|
||||
def __call__(self, orig, dest):
|
||||
old_inkscape = mpl._get_executable_info("inkscape").version.major < 1
|
||||
terminator = b"\n>" if old_inkscape else b"> "
|
||||
if not hasattr(self, "_tmpdir"):
|
||||
self._tmpdir = TemporaryDirectory()
|
||||
# On Windows, we must make sure that self._proc has terminated
|
||||
# (which __del__ does) before clearing _tmpdir.
|
||||
weakref.finalize(self._tmpdir, self.__del__)
|
||||
if (not self._proc # First run.
|
||||
or self._proc.poll() is not None): # Inkscape terminated.
|
||||
if self._proc is not None and self._proc.poll() is not None:
|
||||
for stream in filter(None, [self._proc.stdin,
|
||||
self._proc.stdout,
|
||||
self._proc.stderr]):
|
||||
stream.close()
|
||||
env = {
|
||||
**os.environ,
|
||||
# If one passes e.g. a png file to Inkscape, it will try to
|
||||
# query the user for conversion options via a GUI (even with
|
||||
# `--without-gui`). Unsetting `DISPLAY` prevents this (and
|
||||
# causes GTK to crash and Inkscape to terminate, but that'll
|
||||
# just be reported as a regular exception below).
|
||||
"DISPLAY": "",
|
||||
# Do not load any user options.
|
||||
"INKSCAPE_PROFILE_DIR": self._tmpdir.name,
|
||||
}
|
||||
# Old versions of Inkscape (e.g. 0.48.3.1) seem to sometimes
|
||||
# deadlock when stderr is redirected to a pipe, so we redirect it
|
||||
# to a temporary file instead. This is not necessary anymore as of
|
||||
# Inkscape 0.92.1.
|
||||
stderr = TemporaryFile()
|
||||
self._proc = subprocess.Popen(
|
||||
["inkscape", "--without-gui", "--shell"] if old_inkscape else
|
||||
["inkscape", "--shell"],
|
||||
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=stderr,
|
||||
env=env, cwd=self._tmpdir.name)
|
||||
# Slight abuse, but makes shutdown handling easier.
|
||||
self._proc.stderr = stderr
|
||||
try:
|
||||
self._read_until(terminator)
|
||||
except _ConverterError as err:
|
||||
raise OSError(
|
||||
"Failed to start Inkscape in interactive mode:\n\n"
|
||||
+ err.args[0]) from err
|
||||
|
||||
# Inkscape's shell mode does not support escaping metacharacters in the
|
||||
# filename ("\n", and ":;" for inkscape>=1). Avoid any problems by
|
||||
# running from a temporary directory and using fixed filenames.
|
||||
inkscape_orig = Path(self._tmpdir.name, os.fsdecode(b"f.svg"))
|
||||
inkscape_dest = Path(self._tmpdir.name, os.fsdecode(b"f.png"))
|
||||
try:
|
||||
inkscape_orig.symlink_to(Path(orig).resolve())
|
||||
except OSError:
|
||||
shutil.copyfile(orig, inkscape_orig)
|
||||
self._proc.stdin.write(
|
||||
b"f.svg --export-png=f.png\n" if old_inkscape else
|
||||
b"file-open:f.svg;export-filename:f.png;export-do;file-close\n")
|
||||
self._proc.stdin.flush()
|
||||
try:
|
||||
self._read_until(terminator)
|
||||
except _ConverterError as err:
|
||||
# Inkscape's output is not localized but gtk's is, so the output
|
||||
# stream probably has a mixed encoding. Using the filesystem
|
||||
# encoding should at least get the filenames right...
|
||||
self._proc.stderr.seek(0)
|
||||
raise ImageComparisonFailure(
|
||||
self._proc.stderr.read().decode(
|
||||
sys.getfilesystemencoding(), "replace")) from err
|
||||
os.remove(inkscape_orig)
|
||||
shutil.move(inkscape_dest, dest)
|
||||
|
||||
def __del__(self):
|
||||
super().__del__()
|
||||
if hasattr(self, "_tmpdir"):
|
||||
self._tmpdir.cleanup()
|
||||
|
||||
|
||||
class _SVGWithMatplotlibFontsConverter(_SVGConverter):
|
||||
"""
|
||||
A SVG converter which explicitly adds the fonts shipped by Matplotlib to
|
||||
Inkspace's font search path, to better support `svg.fonttype = "none"`
|
||||
(which is in particular used by certain mathtext tests).
|
||||
"""
|
||||
|
||||
def __call__(self, orig, dest):
|
||||
if not hasattr(self, "_tmpdir"):
|
||||
self._tmpdir = TemporaryDirectory()
|
||||
shutil.copytree(cbook._get_data_path("fonts/ttf"),
|
||||
Path(self._tmpdir.name, "fonts"))
|
||||
return super().__call__(orig, dest)
|
||||
|
||||
|
||||
def _update_converter():
|
||||
try:
|
||||
mpl._get_executable_info("gs")
|
||||
except mpl.ExecutableNotFoundError:
|
||||
pass
|
||||
else:
|
||||
converter['pdf'] = converter['eps'] = _GSConverter()
|
||||
try:
|
||||
mpl._get_executable_info("inkscape")
|
||||
except mpl.ExecutableNotFoundError:
|
||||
pass
|
||||
else:
|
||||
converter['svg'] = _SVGConverter()
|
||||
|
||||
|
||||
#: A dictionary that maps filename extensions to functions which themselves
|
||||
#: convert between arguments `old` and `new` (filenames).
|
||||
converter = {}
|
||||
_update_converter()
|
||||
_svg_with_matplotlib_fonts_converter = _SVGWithMatplotlibFontsConverter()
|
||||
|
||||
|
||||
def comparable_formats():
|
||||
"""
|
||||
Return the list of file formats that `.compare_images` can compare
|
||||
on this system.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of str
|
||||
E.g. ``['png', 'pdf', 'svg', 'eps']``.
|
||||
|
||||
"""
|
||||
return ['png', *converter]
|
||||
|
||||
|
||||
def convert(filename, cache):
|
||||
"""
|
||||
Convert the named file to png; return the name of the created file.
|
||||
|
||||
If *cache* is True, the result of the conversion is cached in
|
||||
`matplotlib.get_cachedir() + '/test_cache/'`. The caching is based on a
|
||||
hash of the exact contents of the input file. Old cache entries are
|
||||
automatically deleted as needed to keep the size of the cache capped to
|
||||
twice the size of all baseline images.
|
||||
"""
|
||||
path = Path(filename)
|
||||
if not path.exists():
|
||||
raise OSError(f"{path} does not exist")
|
||||
if path.suffix[1:] not in converter:
|
||||
import pytest
|
||||
pytest.skip(f"Don't know how to convert {path.suffix} files to png")
|
||||
newpath = path.parent / f"{path.stem}_{path.suffix[1:]}.png"
|
||||
|
||||
# Only convert the file if the destination doesn't already exist or
|
||||
# is out of date.
|
||||
if not newpath.exists() or newpath.stat().st_mtime < path.stat().st_mtime:
|
||||
cache_dir = _get_cache_path() if cache else None
|
||||
|
||||
if cache_dir is not None:
|
||||
_register_conversion_cache_cleaner_once()
|
||||
hash_value = get_file_hash(path)
|
||||
cached_path = cache_dir / (hash_value + newpath.suffix)
|
||||
if cached_path.exists():
|
||||
_log.debug("For %s: reusing cached conversion.", filename)
|
||||
shutil.copyfile(cached_path, newpath)
|
||||
return str(newpath)
|
||||
|
||||
_log.debug("For %s: converting to png.", filename)
|
||||
convert = converter[path.suffix[1:]]
|
||||
if path.suffix == ".svg":
|
||||
contents = path.read_text()
|
||||
# NOTE: This check should be kept in sync with font styling in
|
||||
# `lib/matplotlib/backends/backend_svg.py`. If it changes, then be sure to
|
||||
# re-generate any SVG test files using this mode, or else such tests will
|
||||
# fail to use the converter for the expected images (but will for the
|
||||
# results), and the tests will fail strangely.
|
||||
if 'style="font:' in contents:
|
||||
# for svg.fonttype = none, we explicitly patch the font search
|
||||
# path so that fonts shipped by Matplotlib are found.
|
||||
convert = _svg_with_matplotlib_fonts_converter
|
||||
convert(path, newpath)
|
||||
|
||||
if cache_dir is not None:
|
||||
_log.debug("For %s: caching conversion result.", filename)
|
||||
shutil.copyfile(newpath, cached_path)
|
||||
|
||||
return str(newpath)
|
||||
|
||||
|
||||
def _clean_conversion_cache():
|
||||
# This will actually ignore mpl_toolkits baseline images, but they're
|
||||
# relatively small.
|
||||
baseline_images_size = sum(
|
||||
path.stat().st_size
|
||||
for path in Path(mpl.__file__).parent.glob("**/baseline_images/**/*"))
|
||||
# 2x: one full copy of baselines, and one full copy of test results
|
||||
# (actually an overestimate: we don't convert png baselines and results).
|
||||
max_cache_size = 2 * baseline_images_size
|
||||
# Reduce cache until it fits.
|
||||
with cbook._lock_path(_get_cache_path()):
|
||||
cache_stat = {
|
||||
path: path.stat() for path in _get_cache_path().glob("*")}
|
||||
cache_size = sum(stat.st_size for stat in cache_stat.values())
|
||||
paths_by_atime = sorted( # Oldest at the end.
|
||||
cache_stat, key=lambda path: cache_stat[path].st_atime,
|
||||
reverse=True)
|
||||
while cache_size > max_cache_size:
|
||||
path = paths_by_atime.pop()
|
||||
cache_size -= cache_stat[path].st_size
|
||||
path.unlink()
|
||||
|
||||
|
||||
@functools.cache # Ensure this is only registered once.
|
||||
def _register_conversion_cache_cleaner_once():
|
||||
atexit.register(_clean_conversion_cache)
|
||||
|
||||
|
||||
def crop_to_same(actual_path, actual_image, expected_path, expected_image):
|
||||
# clip the images to the same size -- this is useful only when
|
||||
# comparing eps to pdf
|
||||
if actual_path[-7:-4] == 'eps' and expected_path[-7:-4] == 'pdf':
|
||||
aw, ah, ad = actual_image.shape
|
||||
ew, eh, ed = expected_image.shape
|
||||
actual_image = actual_image[int(aw / 2 - ew / 2):int(
|
||||
aw / 2 + ew / 2), int(ah / 2 - eh / 2):int(ah / 2 + eh / 2)]
|
||||
return actual_image, expected_image
|
||||
|
||||
|
||||
def calculate_rms(expected_image, actual_image):
|
||||
"""
|
||||
Calculate the per-pixel errors, then compute the root mean square error.
|
||||
"""
|
||||
if expected_image.shape != actual_image.shape:
|
||||
raise ImageComparisonFailure(
|
||||
f"Image sizes do not match expected size: {expected_image.shape} "
|
||||
f"actual size {actual_image.shape}")
|
||||
# Convert to float to avoid overflowing finite integer types.
|
||||
return np.sqrt(((expected_image - actual_image).astype(float) ** 2).mean())
|
||||
|
||||
|
||||
# NOTE: compare_image and save_diff_image assume that the image does not have
|
||||
# 16-bit depth, as Pillow converts these to RGB incorrectly.
|
||||
|
||||
|
||||
def _load_image(path):
|
||||
img = Image.open(path)
|
||||
# In an RGBA image, if the smallest value in the alpha channel is 255, all
|
||||
# values in it must be 255, meaning that the image is opaque. If so,
|
||||
# discard the alpha channel so that it may compare equal to an RGB image.
|
||||
if img.mode != "RGBA" or img.getextrema()[3][0] == 255:
|
||||
img = img.convert("RGB")
|
||||
return np.asarray(img)
|
||||
|
||||
|
||||
def compare_images(expected, actual, tol, in_decorator=False):
|
||||
"""
|
||||
Compare two "image" files checking differences within a tolerance.
|
||||
|
||||
The two given filenames may point to files which are convertible to
|
||||
PNG via the `.converter` dictionary. The underlying RMS is calculated
|
||||
with the `.calculate_rms` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expected : str
|
||||
The filename of the expected image.
|
||||
actual : str
|
||||
The filename of the actual image.
|
||||
tol : float
|
||||
The tolerance (a color value difference, where 255 is the
|
||||
maximal difference). The test fails if the average pixel
|
||||
difference is greater than this value.
|
||||
in_decorator : bool
|
||||
Determines the output format. If called from image_comparison
|
||||
decorator, this should be True. (default=False)
|
||||
|
||||
Returns
|
||||
-------
|
||||
None or dict or str
|
||||
Return *None* if the images are equal within the given tolerance.
|
||||
|
||||
If the images differ, the return value depends on *in_decorator*.
|
||||
If *in_decorator* is true, a dict with the following entries is
|
||||
returned:
|
||||
|
||||
- *rms*: The RMS of the image difference.
|
||||
- *expected*: The filename of the expected image.
|
||||
- *actual*: The filename of the actual image.
|
||||
- *diff_image*: The filename of the difference image.
|
||||
- *tol*: The comparison tolerance.
|
||||
|
||||
Otherwise, a human-readable multi-line string representation of this
|
||||
information is returned.
|
||||
|
||||
Examples
|
||||
--------
|
||||
::
|
||||
|
||||
img1 = "./baseline/plot.png"
|
||||
img2 = "./output/plot.png"
|
||||
compare_images(img1, img2, 0.001)
|
||||
|
||||
"""
|
||||
actual = os.fspath(actual)
|
||||
if not os.path.exists(actual):
|
||||
raise Exception(f"Output image {actual} does not exist.")
|
||||
if os.stat(actual).st_size == 0:
|
||||
raise Exception(f"Output image file {actual} is empty.")
|
||||
|
||||
# Convert the image to png
|
||||
expected = os.fspath(expected)
|
||||
if not os.path.exists(expected):
|
||||
raise OSError(f'Baseline image {expected!r} does not exist.')
|
||||
extension = expected.split('.')[-1]
|
||||
if extension != 'png':
|
||||
actual = convert(actual, cache=True)
|
||||
expected = convert(expected, cache=True)
|
||||
|
||||
# open the image files
|
||||
expected_image = _load_image(expected)
|
||||
actual_image = _load_image(actual)
|
||||
|
||||
actual_image, expected_image = crop_to_same(
|
||||
actual, actual_image, expected, expected_image)
|
||||
|
||||
diff_image = make_test_filename(actual, 'failed-diff')
|
||||
|
||||
if tol <= 0:
|
||||
if np.array_equal(expected_image, actual_image):
|
||||
return None
|
||||
|
||||
# convert to signed integers, so that the images can be subtracted without
|
||||
# overflow
|
||||
expected_image = expected_image.astype(np.int16)
|
||||
actual_image = actual_image.astype(np.int16)
|
||||
|
||||
rms = calculate_rms(expected_image, actual_image)
|
||||
|
||||
if rms <= tol:
|
||||
return None
|
||||
|
||||
save_diff_image(expected, actual, diff_image)
|
||||
|
||||
results = dict(rms=rms, expected=str(expected),
|
||||
actual=str(actual), diff=str(diff_image), tol=tol)
|
||||
|
||||
if not in_decorator:
|
||||
# Then the results should be a string suitable for stdout.
|
||||
template = ['Error: Image files did not match.',
|
||||
'RMS Value: {rms}',
|
||||
'Expected: \n {expected}',
|
||||
'Actual: \n {actual}',
|
||||
'Difference:\n {diff}',
|
||||
'Tolerance: \n {tol}', ]
|
||||
results = '\n '.join([line.format(**results) for line in template])
|
||||
return results
|
||||
|
||||
|
||||
def save_diff_image(expected, actual, output):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
expected : str
|
||||
File path of expected image.
|
||||
actual : str
|
||||
File path of actual image.
|
||||
output : str
|
||||
File path to save difference image to.
|
||||
"""
|
||||
expected_image = _load_image(expected)
|
||||
actual_image = _load_image(actual)
|
||||
actual_image, expected_image = crop_to_same(
|
||||
actual, actual_image, expected, expected_image)
|
||||
expected_image = np.array(expected_image, float)
|
||||
actual_image = np.array(actual_image, float)
|
||||
if expected_image.shape != actual_image.shape:
|
||||
raise ImageComparisonFailure(
|
||||
f"Image sizes do not match expected size: {expected_image.shape} "
|
||||
f"actual size {actual_image.shape}")
|
||||
abs_diff = np.abs(expected_image - actual_image)
|
||||
|
||||
# expand differences in luminance domain
|
||||
abs_diff *= 10
|
||||
abs_diff = np.clip(abs_diff, 0, 255).astype(np.uint8)
|
||||
|
||||
if abs_diff.shape[2] == 4: # Hard-code the alpha channel to fully solid
|
||||
abs_diff[:, :, 3] = 255
|
||||
|
||||
Image.fromarray(abs_diff).save(output, format="png")
|
||||
@ -0,0 +1,32 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Literal, overload
|
||||
|
||||
from numpy.typing import NDArray
|
||||
|
||||
__all__ = ["calculate_rms", "comparable_formats", "compare_images"]
|
||||
|
||||
def make_test_filename(fname: str, purpose: str) -> str: ...
|
||||
def get_cache_dir() -> str: ...
|
||||
def get_file_hash(path: str, block_size: int = ...) -> str: ...
|
||||
|
||||
converter: dict[str, Callable[[str, str], None]] = {}
|
||||
|
||||
def comparable_formats() -> list[str]: ...
|
||||
def convert(filename: str, cache: bool) -> str: ...
|
||||
def crop_to_same(
|
||||
actual_path: str, actual_image: NDArray, expected_path: str, expected_image: NDArray
|
||||
) -> tuple[NDArray, NDArray]: ...
|
||||
def calculate_rms(expected_image: NDArray, actual_image: NDArray) -> float: ...
|
||||
@overload
|
||||
def compare_images(
|
||||
expected: str, actual: str, tol: float, in_decorator: Literal[True]
|
||||
) -> None | dict[str, float | str]: ...
|
||||
@overload
|
||||
def compare_images(
|
||||
expected: str, actual: str, tol: float, in_decorator: Literal[False]
|
||||
) -> None | str: ...
|
||||
@overload
|
||||
def compare_images(
|
||||
expected: str, actual: str, tol: float, in_decorator: bool = ...
|
||||
) -> None | str | dict[str, float | str]: ...
|
||||
def save_diff_image(expected: str, actual: str, output: str) -> None: ...
|
||||
100
venv/lib/python3.12/site-packages/matplotlib/testing/conftest.py
Normal file
100
venv/lib/python3.12/site-packages/matplotlib/testing/conftest.py
Normal file
@ -0,0 +1,100 @@
|
||||
import pytest
|
||||
import sys
|
||||
import matplotlib
|
||||
from matplotlib import _api
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
# config is initialized here rather than in pytest.ini so that `pytest
|
||||
# --pyargs matplotlib` (which would not find pytest.ini) works. The only
|
||||
# entries in pytest.ini set minversion (which is checked earlier),
|
||||
# testpaths/python_files, as they are required to properly find the tests
|
||||
for key, value in [
|
||||
("markers", "flaky: (Provided by pytest-rerunfailures.)"),
|
||||
("markers", "timeout: (Provided by pytest-timeout.)"),
|
||||
("markers", "backend: Set alternate Matplotlib backend temporarily."),
|
||||
("markers", "baseline_images: Compare output against references."),
|
||||
("markers", "pytz: Tests that require pytz to be installed."),
|
||||
("filterwarnings", "error"),
|
||||
("filterwarnings",
|
||||
"ignore:.*The py23 module has been deprecated:DeprecationWarning"),
|
||||
("filterwarnings",
|
||||
r"ignore:DynamicImporter.find_spec\(\) not found; "
|
||||
r"falling back to find_module\(\):ImportWarning"),
|
||||
]:
|
||||
config.addinivalue_line(key, value)
|
||||
|
||||
matplotlib.use('agg', force=True)
|
||||
matplotlib._called_from_pytest = True
|
||||
matplotlib._init_tests()
|
||||
|
||||
|
||||
def pytest_unconfigure(config):
|
||||
matplotlib._called_from_pytest = False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mpl_test_settings(request):
|
||||
from matplotlib.testing.decorators import _cleanup_cm
|
||||
|
||||
with _cleanup_cm():
|
||||
|
||||
backend = None
|
||||
backend_marker = request.node.get_closest_marker('backend')
|
||||
prev_backend = matplotlib.get_backend()
|
||||
if backend_marker is not None:
|
||||
assert len(backend_marker.args) == 1, \
|
||||
"Marker 'backend' must specify 1 backend."
|
||||
backend, = backend_marker.args
|
||||
skip_on_importerror = backend_marker.kwargs.get(
|
||||
'skip_on_importerror', False)
|
||||
|
||||
# special case Qt backend importing to avoid conflicts
|
||||
if backend.lower().startswith('qt5'):
|
||||
if any(sys.modules.get(k) for k in ('PyQt4', 'PySide')):
|
||||
pytest.skip('Qt4 binding already imported')
|
||||
|
||||
matplotlib.testing.setup()
|
||||
with _api.suppress_matplotlib_deprecation_warning():
|
||||
if backend is not None:
|
||||
# This import must come after setup() so it doesn't load the
|
||||
# default backend prematurely.
|
||||
import matplotlib.pyplot as plt
|
||||
try:
|
||||
plt.switch_backend(backend)
|
||||
except ImportError as exc:
|
||||
# Should only occur for the cairo backend tests, if neither
|
||||
# pycairo nor cairocffi are installed.
|
||||
if 'cairo' in backend.lower() or skip_on_importerror:
|
||||
pytest.skip("Failed to switch to backend "
|
||||
f"{backend} ({exc}).")
|
||||
else:
|
||||
raise
|
||||
# Default of cleanup and image_comparison too.
|
||||
matplotlib.style.use(["classic", "_classic_test_patch"])
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if backend is not None:
|
||||
plt.close("all")
|
||||
matplotlib.use(prev_backend)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pd():
|
||||
"""Fixture to import and configure pandas."""
|
||||
pd = pytest.importorskip('pandas')
|
||||
try:
|
||||
from pandas.plotting import (
|
||||
deregister_matplotlib_converters as deregister)
|
||||
deregister()
|
||||
except ImportError:
|
||||
pass
|
||||
return pd
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def xr():
|
||||
"""Fixture to import xarray."""
|
||||
xr = pytest.importorskip('xarray')
|
||||
return xr
|
||||
@ -0,0 +1,12 @@
|
||||
from types import ModuleType
|
||||
|
||||
import pytest
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None: ...
|
||||
def pytest_unconfigure(config: pytest.Config) -> None: ...
|
||||
@pytest.fixture
|
||||
def mpl_test_settings(request: pytest.FixtureRequest) -> None: ...
|
||||
@pytest.fixture
|
||||
def pd() -> ModuleType: ...
|
||||
@pytest.fixture
|
||||
def xr() -> ModuleType: ...
|
||||
@ -0,0 +1,464 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
from platform import uname
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import string
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from packaging.version import parse as parse_version
|
||||
|
||||
import matplotlib.style
|
||||
import matplotlib.units
|
||||
import matplotlib.testing
|
||||
from matplotlib import _pylab_helpers, cbook, ft2font, pyplot as plt, ticker
|
||||
from .compare import comparable_formats, compare_images, make_test_filename
|
||||
from .exceptions import ImageComparisonFailure
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _cleanup_cm():
|
||||
orig_units_registry = matplotlib.units.registry.copy()
|
||||
try:
|
||||
with warnings.catch_warnings(), matplotlib.rc_context():
|
||||
yield
|
||||
finally:
|
||||
matplotlib.units.registry.clear()
|
||||
matplotlib.units.registry.update(orig_units_registry)
|
||||
plt.close("all")
|
||||
|
||||
|
||||
def _check_freetype_version(ver):
|
||||
if ver is None:
|
||||
return True
|
||||
|
||||
if isinstance(ver, str):
|
||||
ver = (ver, ver)
|
||||
ver = [parse_version(x) for x in ver]
|
||||
found = parse_version(ft2font.__freetype_version__)
|
||||
|
||||
return ver[0] <= found <= ver[1]
|
||||
|
||||
|
||||
def _checked_on_freetype_version(required_freetype_version):
|
||||
import pytest
|
||||
return pytest.mark.xfail(
|
||||
not _check_freetype_version(required_freetype_version),
|
||||
reason=f"Mismatched version of freetype. "
|
||||
f"Test requires '{required_freetype_version}', "
|
||||
f"you have '{ft2font.__freetype_version__}'",
|
||||
raises=ImageComparisonFailure, strict=False)
|
||||
|
||||
|
||||
def remove_ticks_and_titles(figure):
|
||||
figure.suptitle("")
|
||||
null_formatter = ticker.NullFormatter()
|
||||
def remove_ticks(ax):
|
||||
"""Remove ticks in *ax* and all its child Axes."""
|
||||
ax.set_title("")
|
||||
ax.xaxis.set_major_formatter(null_formatter)
|
||||
ax.xaxis.set_minor_formatter(null_formatter)
|
||||
ax.yaxis.set_major_formatter(null_formatter)
|
||||
ax.yaxis.set_minor_formatter(null_formatter)
|
||||
try:
|
||||
ax.zaxis.set_major_formatter(null_formatter)
|
||||
ax.zaxis.set_minor_formatter(null_formatter)
|
||||
except AttributeError:
|
||||
pass
|
||||
for child in ax.child_axes:
|
||||
remove_ticks(child)
|
||||
for ax in figure.get_axes():
|
||||
remove_ticks(ax)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _collect_new_figures():
|
||||
"""
|
||||
After::
|
||||
|
||||
with _collect_new_figures() as figs:
|
||||
some_code()
|
||||
|
||||
the list *figs* contains the figures that have been created during the
|
||||
execution of ``some_code``, sorted by figure number.
|
||||
"""
|
||||
managers = _pylab_helpers.Gcf.figs
|
||||
preexisting = [manager for manager in managers.values()]
|
||||
new_figs = []
|
||||
try:
|
||||
yield new_figs
|
||||
finally:
|
||||
new_managers = sorted([manager for manager in managers.values()
|
||||
if manager not in preexisting],
|
||||
key=lambda manager: manager.num)
|
||||
new_figs[:] = [manager.canvas.figure for manager in new_managers]
|
||||
|
||||
|
||||
def _raise_on_image_difference(expected, actual, tol):
|
||||
__tracebackhide__ = True
|
||||
|
||||
err = compare_images(expected, actual, tol, in_decorator=True)
|
||||
if err:
|
||||
for key in ["actual", "expected", "diff"]:
|
||||
err[key] = os.path.relpath(err[key])
|
||||
raise ImageComparisonFailure(
|
||||
('images not close (RMS %(rms).3f):'
|
||||
'\n\t%(actual)s\n\t%(expected)s\n\t%(diff)s') % err)
|
||||
|
||||
|
||||
class _ImageComparisonBase:
|
||||
"""
|
||||
Image comparison base class
|
||||
|
||||
This class provides *just* the comparison-related functionality and avoids
|
||||
any code that would be specific to any testing framework.
|
||||
"""
|
||||
|
||||
def __init__(self, func, tol, remove_text, savefig_kwargs):
|
||||
self.func = func
|
||||
self.baseline_dir, self.result_dir = _image_directories(func)
|
||||
self.tol = tol
|
||||
self.remove_text = remove_text
|
||||
self.savefig_kwargs = savefig_kwargs
|
||||
|
||||
def copy_baseline(self, baseline, extension):
|
||||
baseline_path = self.baseline_dir / baseline
|
||||
orig_expected_path = baseline_path.with_suffix(f'.{extension}')
|
||||
if extension == 'eps' and not orig_expected_path.exists():
|
||||
orig_expected_path = orig_expected_path.with_suffix('.pdf')
|
||||
expected_fname = make_test_filename(
|
||||
self.result_dir / orig_expected_path.name, 'expected')
|
||||
try:
|
||||
# os.symlink errors if the target already exists.
|
||||
with contextlib.suppress(OSError):
|
||||
os.remove(expected_fname)
|
||||
try:
|
||||
if 'microsoft' in uname().release.lower():
|
||||
raise OSError # On WSL, symlink breaks silently
|
||||
os.symlink(orig_expected_path, expected_fname)
|
||||
except OSError: # On Windows, symlink *may* be unavailable.
|
||||
shutil.copyfile(orig_expected_path, expected_fname)
|
||||
except OSError as err:
|
||||
raise ImageComparisonFailure(
|
||||
f"Missing baseline image {expected_fname} because the "
|
||||
f"following file cannot be accessed: "
|
||||
f"{orig_expected_path}") from err
|
||||
return expected_fname
|
||||
|
||||
def compare(self, fig, baseline, extension, *, _lock=False):
|
||||
__tracebackhide__ = True
|
||||
|
||||
if self.remove_text:
|
||||
remove_ticks_and_titles(fig)
|
||||
|
||||
actual_path = (self.result_dir / baseline).with_suffix(f'.{extension}')
|
||||
kwargs = self.savefig_kwargs.copy()
|
||||
if extension == 'pdf':
|
||||
kwargs.setdefault('metadata',
|
||||
{'Creator': None, 'Producer': None,
|
||||
'CreationDate': None})
|
||||
|
||||
lock = (cbook._lock_path(actual_path)
|
||||
if _lock else contextlib.nullcontext())
|
||||
with lock:
|
||||
try:
|
||||
fig.savefig(actual_path, **kwargs)
|
||||
finally:
|
||||
# Matplotlib has an autouse fixture to close figures, but this
|
||||
# makes things more convenient for third-party users.
|
||||
plt.close(fig)
|
||||
expected_path = self.copy_baseline(baseline, extension)
|
||||
_raise_on_image_difference(expected_path, actual_path, self.tol)
|
||||
|
||||
|
||||
def _pytest_image_comparison(baseline_images, extensions, tol,
|
||||
freetype_version, remove_text, savefig_kwargs,
|
||||
style):
|
||||
"""
|
||||
Decorate function with image comparison for pytest.
|
||||
|
||||
This function creates a decorator that wraps a figure-generating function
|
||||
with image comparison code.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
|
||||
|
||||
def decorator(func):
|
||||
old_sig = inspect.signature(func)
|
||||
|
||||
@functools.wraps(func)
|
||||
@pytest.mark.parametrize('extension', extensions)
|
||||
@matplotlib.style.context(style)
|
||||
@_checked_on_freetype_version(freetype_version)
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, extension, request, **kwargs):
|
||||
__tracebackhide__ = True
|
||||
if 'extension' in old_sig.parameters:
|
||||
kwargs['extension'] = extension
|
||||
if 'request' in old_sig.parameters:
|
||||
kwargs['request'] = request
|
||||
|
||||
if extension not in comparable_formats():
|
||||
reason = {
|
||||
'pdf': 'because Ghostscript is not installed',
|
||||
'eps': 'because Ghostscript is not installed',
|
||||
'svg': 'because Inkscape is not installed',
|
||||
}.get(extension, 'on this system')
|
||||
pytest.skip(f"Cannot compare {extension} files {reason}")
|
||||
|
||||
img = _ImageComparisonBase(func, tol=tol, remove_text=remove_text,
|
||||
savefig_kwargs=savefig_kwargs)
|
||||
matplotlib.testing.set_font_settings_for_testing()
|
||||
|
||||
with _collect_new_figures() as figs:
|
||||
func(*args, **kwargs)
|
||||
|
||||
# If the test is parametrized in any way other than applied via
|
||||
# this decorator, then we need to use a lock to prevent two
|
||||
# processes from touching the same output file.
|
||||
needs_lock = any(
|
||||
marker.args[0] != 'extension'
|
||||
for marker in request.node.iter_markers('parametrize'))
|
||||
|
||||
if baseline_images is not None:
|
||||
our_baseline_images = baseline_images
|
||||
else:
|
||||
# Allow baseline image list to be produced on the fly based on
|
||||
# current parametrization.
|
||||
our_baseline_images = request.getfixturevalue(
|
||||
'baseline_images')
|
||||
|
||||
assert len(figs) == len(our_baseline_images), (
|
||||
f"Test generated {len(figs)} images but there are "
|
||||
f"{len(our_baseline_images)} baseline images")
|
||||
for fig, baseline in zip(figs, our_baseline_images):
|
||||
img.compare(fig, baseline, extension, _lock=needs_lock)
|
||||
|
||||
parameters = list(old_sig.parameters.values())
|
||||
if 'extension' not in old_sig.parameters:
|
||||
parameters += [inspect.Parameter('extension', KEYWORD_ONLY)]
|
||||
if 'request' not in old_sig.parameters:
|
||||
parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
|
||||
new_sig = old_sig.replace(parameters=parameters)
|
||||
wrapper.__signature__ = new_sig
|
||||
|
||||
# Reach a bit into pytest internals to hoist the marks from our wrapped
|
||||
# function.
|
||||
new_marks = getattr(func, 'pytestmark', []) + wrapper.pytestmark
|
||||
wrapper.pytestmark = new_marks
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def image_comparison(baseline_images, extensions=None, tol=0,
|
||||
freetype_version=None, remove_text=False,
|
||||
savefig_kwarg=None,
|
||||
# Default of mpl_test_settings fixture and cleanup too.
|
||||
style=("classic", "_classic_test_patch")):
|
||||
"""
|
||||
Compare images generated by the test with those specified in
|
||||
*baseline_images*, which must correspond, else an `ImageComparisonFailure`
|
||||
exception will be raised.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
baseline_images : list or None
|
||||
A list of strings specifying the names of the images generated by
|
||||
calls to `.Figure.savefig`.
|
||||
|
||||
If *None*, the test function must use the ``baseline_images`` fixture,
|
||||
either as a parameter or with `pytest.mark.usefixtures`. This value is
|
||||
only allowed when using pytest.
|
||||
|
||||
extensions : None or list of str
|
||||
The list of extensions to test, e.g. ``['png', 'pdf']``.
|
||||
|
||||
If *None*, defaults to all supported extensions: png, pdf, and svg.
|
||||
|
||||
When testing a single extension, it can be directly included in the
|
||||
names passed to *baseline_images*. In that case, *extensions* must not
|
||||
be set.
|
||||
|
||||
In order to keep the size of the test suite from ballooning, we only
|
||||
include the ``svg`` or ``pdf`` outputs if the test is explicitly
|
||||
exercising a feature dependent on that backend (see also the
|
||||
`check_figures_equal` decorator for that purpose).
|
||||
|
||||
tol : float, default: 0
|
||||
The RMS threshold above which the test is considered failed.
|
||||
|
||||
Due to expected small differences in floating-point calculations, on
|
||||
32-bit systems an additional 0.06 is added to this threshold.
|
||||
|
||||
freetype_version : str or tuple
|
||||
The expected freetype version or range of versions for this test to
|
||||
pass.
|
||||
|
||||
remove_text : bool
|
||||
Remove the title and tick text from the figure before comparison. This
|
||||
is useful to make the baseline images independent of variations in text
|
||||
rendering between different versions of FreeType.
|
||||
|
||||
This does not remove other, more deliberate, text, such as legends and
|
||||
annotations.
|
||||
|
||||
savefig_kwarg : dict
|
||||
Optional arguments that are passed to the savefig method.
|
||||
|
||||
style : str, dict, or list
|
||||
The optional style(s) to apply to the image test. The test itself
|
||||
can also apply additional styles if desired. Defaults to ``["classic",
|
||||
"_classic_test_patch"]``.
|
||||
"""
|
||||
|
||||
if baseline_images is not None:
|
||||
# List of non-empty filename extensions.
|
||||
baseline_exts = [*filter(None, {Path(baseline).suffix[1:]
|
||||
for baseline in baseline_images})]
|
||||
if baseline_exts:
|
||||
if extensions is not None:
|
||||
raise ValueError(
|
||||
"When including extensions directly in 'baseline_images', "
|
||||
"'extensions' cannot be set as well")
|
||||
if len(baseline_exts) > 1:
|
||||
raise ValueError(
|
||||
"When including extensions directly in 'baseline_images', "
|
||||
"all baselines must share the same suffix")
|
||||
extensions = baseline_exts
|
||||
baseline_images = [ # Chop suffix out from baseline_images.
|
||||
Path(baseline).stem for baseline in baseline_images]
|
||||
if extensions is None:
|
||||
# Default extensions to test, if not set via baseline_images.
|
||||
extensions = ['png', 'pdf', 'svg']
|
||||
if savefig_kwarg is None:
|
||||
savefig_kwarg = dict() # default no kwargs to savefig
|
||||
if sys.maxsize <= 2**32:
|
||||
tol += 0.06
|
||||
return _pytest_image_comparison(
|
||||
baseline_images=baseline_images, extensions=extensions, tol=tol,
|
||||
freetype_version=freetype_version, remove_text=remove_text,
|
||||
savefig_kwargs=savefig_kwarg, style=style)
|
||||
|
||||
|
||||
def check_figures_equal(*, extensions=("png", "pdf", "svg"), tol=0):
|
||||
"""
|
||||
Decorator for test cases that generate and compare two figures.
|
||||
|
||||
The decorated function must take two keyword arguments, *fig_test*
|
||||
and *fig_ref*, and draw the test and reference images on them.
|
||||
After the function returns, the figures are saved and compared.
|
||||
|
||||
This decorator should be preferred over `image_comparison` when possible in
|
||||
order to keep the size of the test suite from ballooning.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
extensions : list, default: ["png", "pdf", "svg"]
|
||||
The extensions to test.
|
||||
tol : float
|
||||
The RMS threshold above which the test is considered failed.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If any new figures are created (and not subsequently closed) inside
|
||||
the test function.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Check that calling `.Axes.plot` with a single argument plots it against
|
||||
``[0, 1, 2, ...]``::
|
||||
|
||||
@check_figures_equal()
|
||||
def test_plot(fig_test, fig_ref):
|
||||
fig_test.subplots().plot([1, 3, 5])
|
||||
fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
|
||||
|
||||
"""
|
||||
ALLOWED_CHARS = set(string.digits + string.ascii_letters + '_-[]()')
|
||||
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
|
||||
|
||||
def decorator(func):
|
||||
import pytest
|
||||
|
||||
_, result_dir = _image_directories(func)
|
||||
old_sig = inspect.signature(func)
|
||||
|
||||
if not {"fig_test", "fig_ref"}.issubset(old_sig.parameters):
|
||||
raise ValueError("The decorated function must have at least the "
|
||||
"parameters 'fig_test' and 'fig_ref', but your "
|
||||
f"function has the signature {old_sig}")
|
||||
|
||||
@pytest.mark.parametrize("ext", extensions)
|
||||
def wrapper(*args, ext, request, **kwargs):
|
||||
if 'ext' in old_sig.parameters:
|
||||
kwargs['ext'] = ext
|
||||
if 'request' in old_sig.parameters:
|
||||
kwargs['request'] = request
|
||||
|
||||
file_name = "".join(c for c in request.node.name
|
||||
if c in ALLOWED_CHARS)
|
||||
try:
|
||||
fig_test = plt.figure("test")
|
||||
fig_ref = plt.figure("reference")
|
||||
with _collect_new_figures() as figs:
|
||||
func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
|
||||
if figs:
|
||||
raise RuntimeError('Number of open figures changed during '
|
||||
'test. Make sure you are plotting to '
|
||||
'fig_test or fig_ref, or if this is '
|
||||
'deliberate explicitly close the '
|
||||
'new figure(s) inside the test.')
|
||||
test_image_path = result_dir / (file_name + "." + ext)
|
||||
ref_image_path = result_dir / (file_name + "-expected." + ext)
|
||||
fig_test.savefig(test_image_path)
|
||||
fig_ref.savefig(ref_image_path)
|
||||
_raise_on_image_difference(
|
||||
ref_image_path, test_image_path, tol=tol
|
||||
)
|
||||
finally:
|
||||
plt.close(fig_test)
|
||||
plt.close(fig_ref)
|
||||
|
||||
parameters = [
|
||||
param
|
||||
for param in old_sig.parameters.values()
|
||||
if param.name not in {"fig_test", "fig_ref"}
|
||||
]
|
||||
if 'ext' not in old_sig.parameters:
|
||||
parameters += [inspect.Parameter("ext", KEYWORD_ONLY)]
|
||||
if 'request' not in old_sig.parameters:
|
||||
parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
|
||||
new_sig = old_sig.replace(parameters=parameters)
|
||||
wrapper.__signature__ = new_sig
|
||||
|
||||
# reach a bit into pytest internals to hoist the marks from
|
||||
# our wrapped function
|
||||
new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark
|
||||
wrapper.pytestmark = new_marks
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _image_directories(func):
|
||||
"""
|
||||
Compute the baseline and result image directories for testing *func*.
|
||||
|
||||
For test module ``foo.bar.test_baz``, the baseline directory is at
|
||||
``foo/bar/baseline_images/test_baz`` and the result directory at
|
||||
``$(pwd)/result_images/test_baz``. The result directory is created if it
|
||||
doesn't exist.
|
||||
"""
|
||||
module_path = Path(inspect.getfile(func))
|
||||
baseline_dir = module_path.parent / "baseline_images" / module_path.stem
|
||||
result_dir = Path().resolve() / "result_images" / module_path.stem
|
||||
result_dir.mkdir(parents=True, exist_ok=True)
|
||||
return baseline_dir, result_dir
|
||||
@ -0,0 +1,25 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.typing import RcStyleType
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
def remove_ticks_and_titles(figure: Figure) -> None: ...
|
||||
def image_comparison(
|
||||
baseline_images: list[str] | None,
|
||||
extensions: list[str] | None = ...,
|
||||
tol: float = ...,
|
||||
freetype_version: tuple[str, str] | str | None = ...,
|
||||
remove_text: bool = ...,
|
||||
savefig_kwarg: dict[str, Any] | None = ...,
|
||||
style: RcStyleType = ...,
|
||||
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
|
||||
def check_figures_equal(
|
||||
*, extensions: Sequence[str] = ..., tol: float = ...
|
||||
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
|
||||
def _image_directories(func: Callable) -> tuple[Path, Path]: ...
|
||||
@ -0,0 +1,4 @@
|
||||
class ImageComparisonFailure(AssertionError):
|
||||
"""
|
||||
Raise this exception to mark a test as a comparison between two images.
|
||||
"""
|
||||
@ -0,0 +1,138 @@
|
||||
"""Duration module."""
|
||||
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from matplotlib import _api
|
||||
|
||||
|
||||
class Duration:
|
||||
"""Class Duration in development."""
|
||||
|
||||
allowed = ["ET", "UTC"]
|
||||
|
||||
def __init__(self, frame, seconds):
|
||||
"""
|
||||
Create a new Duration object.
|
||||
|
||||
= ERROR CONDITIONS
|
||||
- If the input frame is not in the allowed list, an error is thrown.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- frame The frame of the duration. Must be 'ET' or 'UTC'
|
||||
- seconds The number of seconds in the Duration.
|
||||
"""
|
||||
_api.check_in_list(self.allowed, frame=frame)
|
||||
self._frame = frame
|
||||
self._seconds = seconds
|
||||
|
||||
def frame(self):
|
||||
"""Return the frame the duration is in."""
|
||||
return self._frame
|
||||
|
||||
def __abs__(self):
|
||||
"""Return the absolute value of the duration."""
|
||||
return Duration(self._frame, abs(self._seconds))
|
||||
|
||||
def __neg__(self):
|
||||
"""Return the negative value of this Duration."""
|
||||
return Duration(self._frame, -self._seconds)
|
||||
|
||||
def seconds(self):
|
||||
"""Return the number of seconds in the Duration."""
|
||||
return self._seconds
|
||||
|
||||
def __bool__(self):
|
||||
return self._seconds != 0
|
||||
|
||||
def _cmp(self, op, rhs):
|
||||
"""
|
||||
Check that *self* and *rhs* share frames; compare them using *op*.
|
||||
"""
|
||||
self.checkSameFrame(rhs, "compare")
|
||||
return op(self._seconds, rhs._seconds)
|
||||
|
||||
__eq__ = functools.partialmethod(_cmp, operator.eq)
|
||||
__ne__ = functools.partialmethod(_cmp, operator.ne)
|
||||
__lt__ = functools.partialmethod(_cmp, operator.lt)
|
||||
__le__ = functools.partialmethod(_cmp, operator.le)
|
||||
__gt__ = functools.partialmethod(_cmp, operator.gt)
|
||||
__ge__ = functools.partialmethod(_cmp, operator.ge)
|
||||
|
||||
def __add__(self, rhs):
|
||||
"""
|
||||
Add two Durations.
|
||||
|
||||
= ERROR CONDITIONS
|
||||
- If the input rhs is not in the same frame, an error is thrown.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- rhs The Duration to add.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns the sum of ourselves and the input Duration.
|
||||
"""
|
||||
# Delay-load due to circular dependencies.
|
||||
import matplotlib.testing.jpl_units as U
|
||||
|
||||
if isinstance(rhs, U.Epoch):
|
||||
return rhs + self
|
||||
|
||||
self.checkSameFrame(rhs, "add")
|
||||
return Duration(self._frame, self._seconds + rhs._seconds)
|
||||
|
||||
def __sub__(self, rhs):
|
||||
"""
|
||||
Subtract two Durations.
|
||||
|
||||
= ERROR CONDITIONS
|
||||
- If the input rhs is not in the same frame, an error is thrown.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- rhs The Duration to subtract.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns the difference of ourselves and the input Duration.
|
||||
"""
|
||||
self.checkSameFrame(rhs, "sub")
|
||||
return Duration(self._frame, self._seconds - rhs._seconds)
|
||||
|
||||
def __mul__(self, rhs):
|
||||
"""
|
||||
Scale a UnitDbl by a value.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- rhs The scalar to multiply by.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns the scaled Duration.
|
||||
"""
|
||||
return Duration(self._frame, self._seconds * float(rhs))
|
||||
|
||||
__rmul__ = __mul__
|
||||
|
||||
def __str__(self):
|
||||
"""Print the Duration."""
|
||||
return f"{self._seconds:g} {self._frame}"
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the Duration."""
|
||||
return f"Duration('{self._frame}', {self._seconds:g})"
|
||||
|
||||
def checkSameFrame(self, rhs, func):
|
||||
"""
|
||||
Check to see if frames are the same.
|
||||
|
||||
= ERROR CONDITIONS
|
||||
- If the frame of the rhs Duration is not the same as our frame,
|
||||
an error is thrown.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- rhs The Duration to check for the same frame
|
||||
- func The name of the function doing the check.
|
||||
"""
|
||||
if self._frame != rhs._frame:
|
||||
raise ValueError(
|
||||
f"Cannot {func} Durations with different frames.\n"
|
||||
f"LHS: {self._frame}\n"
|
||||
f"RHS: {rhs._frame}")
|
||||
@ -0,0 +1,211 @@
|
||||
"""Epoch module."""
|
||||
|
||||
import functools
|
||||
import operator
|
||||
import math
|
||||
import datetime as DT
|
||||
|
||||
from matplotlib import _api
|
||||
from matplotlib.dates import date2num
|
||||
|
||||
|
||||
class Epoch:
|
||||
# Frame conversion offsets in seconds
|
||||
# t(TO) = t(FROM) + allowed[ FROM ][ TO ]
|
||||
allowed = {
|
||||
"ET": {
|
||||
"UTC": +64.1839,
|
||||
},
|
||||
"UTC": {
|
||||
"ET": -64.1839,
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, frame, sec=None, jd=None, daynum=None, dt=None):
|
||||
"""
|
||||
Create a new Epoch object.
|
||||
|
||||
Build an epoch 1 of 2 ways:
|
||||
|
||||
Using seconds past a Julian date:
|
||||
# Epoch('ET', sec=1e8, jd=2451545)
|
||||
|
||||
or using a matplotlib day number
|
||||
# Epoch('ET', daynum=730119.5)
|
||||
|
||||
= ERROR CONDITIONS
|
||||
- If the input units are not in the allowed list, an error is thrown.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- frame The frame of the epoch. Must be 'ET' or 'UTC'
|
||||
- sec The number of seconds past the input JD.
|
||||
- jd The Julian date of the epoch.
|
||||
- daynum The matplotlib day number of the epoch.
|
||||
- dt A python datetime instance.
|
||||
"""
|
||||
if ((sec is None and jd is not None) or
|
||||
(sec is not None and jd is None) or
|
||||
(daynum is not None and
|
||||
(sec is not None or jd is not None)) or
|
||||
(daynum is None and dt is None and
|
||||
(sec is None or jd is None)) or
|
||||
(daynum is not None and dt is not None) or
|
||||
(dt is not None and (sec is not None or jd is not None)) or
|
||||
((dt is not None) and not isinstance(dt, DT.datetime))):
|
||||
raise ValueError(
|
||||
"Invalid inputs. Must enter sec and jd together, "
|
||||
"daynum by itself, or dt (must be a python datetime).\n"
|
||||
"Sec = %s\n"
|
||||
"JD = %s\n"
|
||||
"dnum= %s\n"
|
||||
"dt = %s" % (sec, jd, daynum, dt))
|
||||
|
||||
_api.check_in_list(self.allowed, frame=frame)
|
||||
self._frame = frame
|
||||
|
||||
if dt is not None:
|
||||
daynum = date2num(dt)
|
||||
|
||||
if daynum is not None:
|
||||
# 1-JAN-0001 in JD = 1721425.5
|
||||
jd = float(daynum) + 1721425.5
|
||||
self._jd = math.floor(jd)
|
||||
self._seconds = (jd - self._jd) * 86400.0
|
||||
|
||||
else:
|
||||
self._seconds = float(sec)
|
||||
self._jd = float(jd)
|
||||
|
||||
# Resolve seconds down to [ 0, 86400)
|
||||
deltaDays = math.floor(self._seconds / 86400)
|
||||
self._jd += deltaDays
|
||||
self._seconds -= deltaDays * 86400.0
|
||||
|
||||
def convert(self, frame):
|
||||
if self._frame == frame:
|
||||
return self
|
||||
|
||||
offset = self.allowed[self._frame][frame]
|
||||
|
||||
return Epoch(frame, self._seconds + offset, self._jd)
|
||||
|
||||
def frame(self):
|
||||
return self._frame
|
||||
|
||||
def julianDate(self, frame):
|
||||
t = self
|
||||
if frame != self._frame:
|
||||
t = self.convert(frame)
|
||||
|
||||
return t._jd + t._seconds / 86400.0
|
||||
|
||||
def secondsPast(self, frame, jd):
|
||||
t = self
|
||||
if frame != self._frame:
|
||||
t = self.convert(frame)
|
||||
|
||||
delta = t._jd - jd
|
||||
return t._seconds + delta * 86400
|
||||
|
||||
def _cmp(self, op, rhs):
|
||||
"""Compare Epochs *self* and *rhs* using operator *op*."""
|
||||
t = self
|
||||
if self._frame != rhs._frame:
|
||||
t = self.convert(rhs._frame)
|
||||
if t._jd != rhs._jd:
|
||||
return op(t._jd, rhs._jd)
|
||||
return op(t._seconds, rhs._seconds)
|
||||
|
||||
__eq__ = functools.partialmethod(_cmp, operator.eq)
|
||||
__ne__ = functools.partialmethod(_cmp, operator.ne)
|
||||
__lt__ = functools.partialmethod(_cmp, operator.lt)
|
||||
__le__ = functools.partialmethod(_cmp, operator.le)
|
||||
__gt__ = functools.partialmethod(_cmp, operator.gt)
|
||||
__ge__ = functools.partialmethod(_cmp, operator.ge)
|
||||
|
||||
def __add__(self, rhs):
|
||||
"""
|
||||
Add a duration to an Epoch.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- rhs The Epoch to subtract.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns the difference of ourselves and the input Epoch.
|
||||
"""
|
||||
t = self
|
||||
if self._frame != rhs.frame():
|
||||
t = self.convert(rhs._frame)
|
||||
|
||||
sec = t._seconds + rhs.seconds()
|
||||
|
||||
return Epoch(t._frame, sec, t._jd)
|
||||
|
||||
def __sub__(self, rhs):
|
||||
"""
|
||||
Subtract two Epoch's or a Duration from an Epoch.
|
||||
|
||||
Valid:
|
||||
Duration = Epoch - Epoch
|
||||
Epoch = Epoch - Duration
|
||||
|
||||
= INPUT VARIABLES
|
||||
- rhs The Epoch to subtract.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns either the duration between to Epoch's or the a new
|
||||
Epoch that is the result of subtracting a duration from an epoch.
|
||||
"""
|
||||
# Delay-load due to circular dependencies.
|
||||
import matplotlib.testing.jpl_units as U
|
||||
|
||||
# Handle Epoch - Duration
|
||||
if isinstance(rhs, U.Duration):
|
||||
return self + -rhs
|
||||
|
||||
t = self
|
||||
if self._frame != rhs._frame:
|
||||
t = self.convert(rhs._frame)
|
||||
|
||||
days = t._jd - rhs._jd
|
||||
sec = t._seconds - rhs._seconds
|
||||
|
||||
return U.Duration(rhs._frame, days*86400 + sec)
|
||||
|
||||
def __str__(self):
|
||||
"""Print the Epoch."""
|
||||
return f"{self.julianDate(self._frame):22.15e} {self._frame}"
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the Epoch."""
|
||||
return str(self)
|
||||
|
||||
@staticmethod
|
||||
def range(start, stop, step):
|
||||
"""
|
||||
Generate a range of Epoch objects.
|
||||
|
||||
Similar to the Python range() method. Returns the range [
|
||||
start, stop) at the requested step. Each element will be a
|
||||
Epoch object.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- start The starting value of the range.
|
||||
- stop The stop value of the range.
|
||||
- step Step to use.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns a list containing the requested Epoch values.
|
||||
"""
|
||||
elems = []
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
d = start + i * step
|
||||
if d >= stop:
|
||||
break
|
||||
|
||||
elems.append(d)
|
||||
i += 1
|
||||
|
||||
return elems
|
||||
@ -0,0 +1,94 @@
|
||||
"""EpochConverter module containing class EpochConverter."""
|
||||
|
||||
from matplotlib import cbook, units
|
||||
import matplotlib.dates as date_ticker
|
||||
|
||||
__all__ = ['EpochConverter']
|
||||
|
||||
|
||||
class EpochConverter(units.ConversionInterface):
|
||||
"""
|
||||
Provides Matplotlib conversion functionality for Monte Epoch and Duration
|
||||
classes.
|
||||
"""
|
||||
|
||||
jdRef = 1721425.5
|
||||
|
||||
@staticmethod
|
||||
def axisinfo(unit, axis):
|
||||
# docstring inherited
|
||||
majloc = date_ticker.AutoDateLocator()
|
||||
majfmt = date_ticker.AutoDateFormatter(majloc)
|
||||
return units.AxisInfo(majloc=majloc, majfmt=majfmt, label=unit)
|
||||
|
||||
@staticmethod
|
||||
def float2epoch(value, unit):
|
||||
"""
|
||||
Convert a Matplotlib floating-point date into an Epoch of the specified
|
||||
units.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- value The Matplotlib floating-point date.
|
||||
- unit The unit system to use for the Epoch.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns the value converted to an Epoch in the specified time system.
|
||||
"""
|
||||
# Delay-load due to circular dependencies.
|
||||
import matplotlib.testing.jpl_units as U
|
||||
|
||||
secPastRef = value * 86400.0 * U.UnitDbl(1.0, 'sec')
|
||||
return U.Epoch(unit, secPastRef, EpochConverter.jdRef)
|
||||
|
||||
@staticmethod
|
||||
def epoch2float(value, unit):
|
||||
"""
|
||||
Convert an Epoch value to a float suitable for plotting as a python
|
||||
datetime object.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- value An Epoch or list of Epochs that need to be converted.
|
||||
- unit The units to use for an axis with Epoch data.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns the value parameter converted to floats.
|
||||
"""
|
||||
return value.julianDate(unit) - EpochConverter.jdRef
|
||||
|
||||
@staticmethod
|
||||
def duration2float(value):
|
||||
"""
|
||||
Convert a Duration value to a float suitable for plotting as a python
|
||||
datetime object.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- value A Duration or list of Durations that need to be converted.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns the value parameter converted to floats.
|
||||
"""
|
||||
return value.seconds() / 86400.0
|
||||
|
||||
@staticmethod
|
||||
def convert(value, unit, axis):
|
||||
# docstring inherited
|
||||
|
||||
# Delay-load due to circular dependencies.
|
||||
import matplotlib.testing.jpl_units as U
|
||||
|
||||
if not cbook.is_scalar_or_string(value):
|
||||
return [EpochConverter.convert(x, unit, axis) for x in value]
|
||||
if unit is None:
|
||||
unit = EpochConverter.default_units(value, axis)
|
||||
if isinstance(value, U.Duration):
|
||||
return EpochConverter.duration2float(value)
|
||||
else:
|
||||
return EpochConverter.epoch2float(value, unit)
|
||||
|
||||
@staticmethod
|
||||
def default_units(value, axis):
|
||||
# docstring inherited
|
||||
if cbook.is_scalar_or_string(value):
|
||||
return value.frame()
|
||||
else:
|
||||
return EpochConverter.default_units(value[0], axis)
|
||||
@ -0,0 +1,97 @@
|
||||
"""StrConverter module containing class StrConverter."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.units as units
|
||||
|
||||
__all__ = ['StrConverter']
|
||||
|
||||
|
||||
class StrConverter(units.ConversionInterface):
|
||||
"""
|
||||
A Matplotlib converter class for string data values.
|
||||
|
||||
Valid units for string are:
|
||||
- 'indexed' : Values are indexed as they are specified for plotting.
|
||||
- 'sorted' : Values are sorted alphanumerically.
|
||||
- 'inverted' : Values are inverted so that the first value is on top.
|
||||
- 'sorted-inverted' : A combination of 'sorted' and 'inverted'
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def axisinfo(unit, axis):
|
||||
# docstring inherited
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def convert(value, unit, axis):
|
||||
# docstring inherited
|
||||
|
||||
if value == []:
|
||||
return []
|
||||
|
||||
# we delay loading to make matplotlib happy
|
||||
ax = axis.axes
|
||||
if axis is ax.xaxis:
|
||||
isXAxis = True
|
||||
else:
|
||||
isXAxis = False
|
||||
|
||||
axis.get_major_ticks()
|
||||
ticks = axis.get_ticklocs()
|
||||
labels = axis.get_ticklabels()
|
||||
|
||||
labels = [l.get_text() for l in labels if l.get_text()]
|
||||
|
||||
if not labels:
|
||||
ticks = []
|
||||
labels = []
|
||||
|
||||
if not np.iterable(value):
|
||||
value = [value]
|
||||
|
||||
newValues = []
|
||||
for v in value:
|
||||
if v not in labels and v not in newValues:
|
||||
newValues.append(v)
|
||||
|
||||
labels.extend(newValues)
|
||||
|
||||
# DISABLED: This is disabled because matplotlib bar plots do not
|
||||
# DISABLED: recalculate the unit conversion of the data values
|
||||
# DISABLED: this is due to design and is not really a bug.
|
||||
# DISABLED: If this gets changed, then we can activate the following
|
||||
# DISABLED: block of code. Note that this works for line plots.
|
||||
# DISABLED if unit:
|
||||
# DISABLED if unit.find("sorted") > -1:
|
||||
# DISABLED labels.sort()
|
||||
# DISABLED if unit.find("inverted") > -1:
|
||||
# DISABLED labels = labels[::-1]
|
||||
|
||||
# add padding (so they do not appear on the axes themselves)
|
||||
labels = [''] + labels + ['']
|
||||
ticks = list(range(len(labels)))
|
||||
ticks[0] = 0.5
|
||||
ticks[-1] = ticks[-1] - 0.5
|
||||
|
||||
axis.set_ticks(ticks)
|
||||
axis.set_ticklabels(labels)
|
||||
# we have to do the following lines to make ax.autoscale_view work
|
||||
loc = axis.get_major_locator()
|
||||
loc.set_bounds(ticks[0], ticks[-1])
|
||||
|
||||
if isXAxis:
|
||||
ax.set_xlim(ticks[0], ticks[-1])
|
||||
else:
|
||||
ax.set_ylim(ticks[0], ticks[-1])
|
||||
|
||||
result = [ticks[labels.index(v)] for v in value]
|
||||
|
||||
ax.viewLim.ignore(-1)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def default_units(value, axis):
|
||||
# docstring inherited
|
||||
# The default behavior for string indexing.
|
||||
return "indexed"
|
||||
@ -0,0 +1,180 @@
|
||||
"""UnitDbl module."""
|
||||
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from matplotlib import _api
|
||||
|
||||
|
||||
class UnitDbl:
|
||||
"""Class UnitDbl in development."""
|
||||
|
||||
# Unit conversion table. Small subset of the full one but enough
|
||||
# to test the required functions. First field is a scale factor to
|
||||
# convert the input units to the units of the second field. Only
|
||||
# units in this table are allowed.
|
||||
allowed = {
|
||||
"m": (0.001, "km"),
|
||||
"km": (1, "km"),
|
||||
"mile": (1.609344, "km"),
|
||||
|
||||
"rad": (1, "rad"),
|
||||
"deg": (1.745329251994330e-02, "rad"),
|
||||
|
||||
"sec": (1, "sec"),
|
||||
"min": (60.0, "sec"),
|
||||
"hour": (3600, "sec"),
|
||||
}
|
||||
|
||||
_types = {
|
||||
"km": "distance",
|
||||
"rad": "angle",
|
||||
"sec": "time",
|
||||
}
|
||||
|
||||
def __init__(self, value, units):
|
||||
"""
|
||||
Create a new UnitDbl object.
|
||||
|
||||
Units are internally converted to km, rad, and sec. The only
|
||||
valid inputs for units are [m, km, mile, rad, deg, sec, min, hour].
|
||||
|
||||
The field UnitDbl.value will contain the converted value. Use
|
||||
the convert() method to get a specific type of units back.
|
||||
|
||||
= ERROR CONDITIONS
|
||||
- If the input units are not in the allowed list, an error is thrown.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- value The numeric value of the UnitDbl.
|
||||
- units The string name of the units the value is in.
|
||||
"""
|
||||
data = _api.check_getitem(self.allowed, units=units)
|
||||
self._value = float(value * data[0])
|
||||
self._units = data[1]
|
||||
|
||||
def convert(self, units):
|
||||
"""
|
||||
Convert the UnitDbl to a specific set of units.
|
||||
|
||||
= ERROR CONDITIONS
|
||||
- If the input units are not in the allowed list, an error is thrown.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- units The string name of the units to convert to.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns the value of the UnitDbl in the requested units as a floating
|
||||
point number.
|
||||
"""
|
||||
if self._units == units:
|
||||
return self._value
|
||||
data = _api.check_getitem(self.allowed, units=units)
|
||||
if self._units != data[1]:
|
||||
raise ValueError(f"Error trying to convert to different units.\n"
|
||||
f" Invalid conversion requested.\n"
|
||||
f" UnitDbl: {self}\n"
|
||||
f" Units: {units}\n")
|
||||
return self._value / data[0]
|
||||
|
||||
def __abs__(self):
|
||||
"""Return the absolute value of this UnitDbl."""
|
||||
return UnitDbl(abs(self._value), self._units)
|
||||
|
||||
def __neg__(self):
|
||||
"""Return the negative value of this UnitDbl."""
|
||||
return UnitDbl(-self._value, self._units)
|
||||
|
||||
def __bool__(self):
|
||||
"""Return the truth value of a UnitDbl."""
|
||||
return bool(self._value)
|
||||
|
||||
def _cmp(self, op, rhs):
|
||||
"""Check that *self* and *rhs* share units; compare them using *op*."""
|
||||
self.checkSameUnits(rhs, "compare")
|
||||
return op(self._value, rhs._value)
|
||||
|
||||
__eq__ = functools.partialmethod(_cmp, operator.eq)
|
||||
__ne__ = functools.partialmethod(_cmp, operator.ne)
|
||||
__lt__ = functools.partialmethod(_cmp, operator.lt)
|
||||
__le__ = functools.partialmethod(_cmp, operator.le)
|
||||
__gt__ = functools.partialmethod(_cmp, operator.gt)
|
||||
__ge__ = functools.partialmethod(_cmp, operator.ge)
|
||||
|
||||
def _binop_unit_unit(self, op, rhs):
|
||||
"""Check that *self* and *rhs* share units; combine them using *op*."""
|
||||
self.checkSameUnits(rhs, op.__name__)
|
||||
return UnitDbl(op(self._value, rhs._value), self._units)
|
||||
|
||||
__add__ = functools.partialmethod(_binop_unit_unit, operator.add)
|
||||
__sub__ = functools.partialmethod(_binop_unit_unit, operator.sub)
|
||||
|
||||
def _binop_unit_scalar(self, op, scalar):
|
||||
"""Combine *self* and *scalar* using *op*."""
|
||||
return UnitDbl(op(self._value, scalar), self._units)
|
||||
|
||||
__mul__ = functools.partialmethod(_binop_unit_scalar, operator.mul)
|
||||
__rmul__ = functools.partialmethod(_binop_unit_scalar, operator.mul)
|
||||
|
||||
def __str__(self):
|
||||
"""Print the UnitDbl."""
|
||||
return f"{self._value:g} *{self._units}"
|
||||
|
||||
def __repr__(self):
|
||||
"""Print the UnitDbl."""
|
||||
return f"UnitDbl({self._value:g}, '{self._units}')"
|
||||
|
||||
def type(self):
|
||||
"""Return the type of UnitDbl data."""
|
||||
return self._types[self._units]
|
||||
|
||||
@staticmethod
|
||||
def range(start, stop, step=None):
|
||||
"""
|
||||
Generate a range of UnitDbl objects.
|
||||
|
||||
Similar to the Python range() method. Returns the range [
|
||||
start, stop) at the requested step. Each element will be a
|
||||
UnitDbl object.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- start The starting value of the range.
|
||||
- stop The stop value of the range.
|
||||
- step Optional step to use. If set to None, then a UnitDbl of
|
||||
value 1 w/ the units of the start is used.
|
||||
|
||||
= RETURN VALUE
|
||||
- Returns a list containing the requested UnitDbl values.
|
||||
"""
|
||||
if step is None:
|
||||
step = UnitDbl(1, start._units)
|
||||
|
||||
elems = []
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
d = start + i * step
|
||||
if d >= stop:
|
||||
break
|
||||
|
||||
elems.append(d)
|
||||
i += 1
|
||||
|
||||
return elems
|
||||
|
||||
def checkSameUnits(self, rhs, func):
|
||||
"""
|
||||
Check to see if units are the same.
|
||||
|
||||
= ERROR CONDITIONS
|
||||
- If the units of the rhs UnitDbl are not the same as our units,
|
||||
an error is thrown.
|
||||
|
||||
= INPUT VARIABLES
|
||||
- rhs The UnitDbl to check for the same units
|
||||
- func The name of the function doing the check.
|
||||
"""
|
||||
if self._units != rhs._units:
|
||||
raise ValueError(f"Cannot {func} units of different types.\n"
|
||||
f"LHS: {self._units}\n"
|
||||
f"RHS: {rhs._units}")
|
||||
@ -0,0 +1,85 @@
|
||||
"""UnitDblConverter module containing class UnitDblConverter."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from matplotlib import cbook, units
|
||||
import matplotlib.projections.polar as polar
|
||||
|
||||
__all__ = ['UnitDblConverter']
|
||||
|
||||
|
||||
# A special function for use with the matplotlib FuncFormatter class
|
||||
# for formatting axes with radian units.
|
||||
# This was copied from matplotlib example code.
|
||||
def rad_fn(x, pos=None):
|
||||
"""Radian function formatter."""
|
||||
n = int((x / np.pi) * 2.0 + 0.25)
|
||||
if n == 0:
|
||||
return str(x)
|
||||
elif n == 1:
|
||||
return r'$\pi/2$'
|
||||
elif n == 2:
|
||||
return r'$\pi$'
|
||||
elif n % 2 == 0:
|
||||
return fr'${n//2}\pi$'
|
||||
else:
|
||||
return fr'${n}\pi/2$'
|
||||
|
||||
|
||||
class UnitDblConverter(units.ConversionInterface):
|
||||
"""
|
||||
Provides Matplotlib conversion functionality for the Monte UnitDbl class.
|
||||
"""
|
||||
# default for plotting
|
||||
defaults = {
|
||||
"distance": 'km',
|
||||
"angle": 'deg',
|
||||
"time": 'sec',
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def axisinfo(unit, axis):
|
||||
# docstring inherited
|
||||
|
||||
# Delay-load due to circular dependencies.
|
||||
import matplotlib.testing.jpl_units as U
|
||||
|
||||
# Check to see if the value used for units is a string unit value
|
||||
# or an actual instance of a UnitDbl so that we can use the unit
|
||||
# value for the default axis label value.
|
||||
if unit:
|
||||
label = unit if isinstance(unit, str) else unit.label()
|
||||
else:
|
||||
label = None
|
||||
|
||||
if label == "deg" and isinstance(axis.axes, polar.PolarAxes):
|
||||
# If we want degrees for a polar plot, use the PolarPlotFormatter
|
||||
majfmt = polar.PolarAxes.ThetaFormatter()
|
||||
else:
|
||||
majfmt = U.UnitDblFormatter(useOffset=False)
|
||||
|
||||
return units.AxisInfo(majfmt=majfmt, label=label)
|
||||
|
||||
@staticmethod
|
||||
def convert(value, unit, axis):
|
||||
# docstring inherited
|
||||
if not cbook.is_scalar_or_string(value):
|
||||
return [UnitDblConverter.convert(x, unit, axis) for x in value]
|
||||
# If no units were specified, then get the default units to use.
|
||||
if unit is None:
|
||||
unit = UnitDblConverter.default_units(value, axis)
|
||||
# Convert the incoming UnitDbl value/values to float/floats
|
||||
if isinstance(axis.axes, polar.PolarAxes) and value.type() == "angle":
|
||||
# Guarantee that units are radians for polar plots.
|
||||
return value.convert("rad")
|
||||
return value.convert(unit)
|
||||
|
||||
@staticmethod
|
||||
def default_units(value, axis):
|
||||
# docstring inherited
|
||||
# Determine the default units based on the user preferences set for
|
||||
# default units when printing a UnitDbl.
|
||||
if cbook.is_scalar_or_string(value):
|
||||
return UnitDblConverter.defaults[value.type()]
|
||||
else:
|
||||
return UnitDblConverter.default_units(value[0], axis)
|
||||
@ -0,0 +1,28 @@
|
||||
"""UnitDblFormatter module containing class UnitDblFormatter."""
|
||||
|
||||
import matplotlib.ticker as ticker
|
||||
|
||||
__all__ = ['UnitDblFormatter']
|
||||
|
||||
|
||||
class UnitDblFormatter(ticker.ScalarFormatter):
|
||||
"""
|
||||
The formatter for UnitDbl data types.
|
||||
|
||||
This allows for formatting with the unit string.
|
||||
"""
|
||||
|
||||
def __call__(self, x, pos=None):
|
||||
# docstring inherited
|
||||
if len(self.locs) == 0:
|
||||
return ''
|
||||
else:
|
||||
return f'{x:.12}'
|
||||
|
||||
def format_data_short(self, value):
|
||||
# docstring inherited
|
||||
return f'{value:.12}'
|
||||
|
||||
def format_data(self, value):
|
||||
# docstring inherited
|
||||
return f'{value:.12}'
|
||||
@ -0,0 +1,76 @@
|
||||
"""
|
||||
A sample set of units for use with testing unit conversion
|
||||
of Matplotlib routines. These are used because they use very strict
|
||||
enforcement of unitized data which will test the entire spectrum of how
|
||||
unitized data might be used (it is not always meaningful to convert to
|
||||
a float without specific units given).
|
||||
|
||||
UnitDbl is essentially a unitized floating point number. It has a
|
||||
minimal set of supported units (enough for testing purposes). All
|
||||
of the mathematical operation are provided to fully test any behaviour
|
||||
that might occur with unitized data. Remember that unitized data has
|
||||
rules as to how it can be applied to one another (a value of distance
|
||||
cannot be added to a value of time). Thus we need to guard against any
|
||||
accidental "default" conversion that will strip away the meaning of the
|
||||
data and render it neutered.
|
||||
|
||||
Epoch is different than a UnitDbl of time. Time is something that can be
|
||||
measured where an Epoch is a specific moment in time. Epochs are typically
|
||||
referenced as an offset from some predetermined epoch.
|
||||
|
||||
A difference of two epochs is a Duration. The distinction between a Duration
|
||||
and a UnitDbl of time is made because an Epoch can have different frames (or
|
||||
units). In the case of our test Epoch class the two allowed frames are 'UTC'
|
||||
and 'ET' (Note that these are rough estimates provided for testing purposes
|
||||
and should not be used in production code where accuracy of time frames is
|
||||
desired). As such a Duration also has a frame of reference and therefore needs
|
||||
to be called out as different that a simple measurement of time since a delta-t
|
||||
in one frame may not be the same in another.
|
||||
"""
|
||||
|
||||
from .Duration import Duration
|
||||
from .Epoch import Epoch
|
||||
from .UnitDbl import UnitDbl
|
||||
|
||||
from .StrConverter import StrConverter
|
||||
from .EpochConverter import EpochConverter
|
||||
from .UnitDblConverter import UnitDblConverter
|
||||
|
||||
from .UnitDblFormatter import UnitDblFormatter
|
||||
|
||||
|
||||
__version__ = "1.0"
|
||||
|
||||
__all__ = [
|
||||
'register',
|
||||
'Duration',
|
||||
'Epoch',
|
||||
'UnitDbl',
|
||||
'UnitDblFormatter',
|
||||
]
|
||||
|
||||
|
||||
def register():
|
||||
"""Register the unit conversion classes with matplotlib."""
|
||||
import matplotlib.units as mplU
|
||||
|
||||
mplU.registry[str] = StrConverter()
|
||||
mplU.registry[Epoch] = EpochConverter()
|
||||
mplU.registry[Duration] = EpochConverter()
|
||||
mplU.registry[UnitDbl] = UnitDblConverter()
|
||||
|
||||
|
||||
# Some default unit instances
|
||||
# Distances
|
||||
m = UnitDbl(1.0, "m")
|
||||
km = UnitDbl(1.0, "km")
|
||||
mile = UnitDbl(1.0, "mile")
|
||||
# Angles
|
||||
deg = UnitDbl(1.0, "deg")
|
||||
rad = UnitDbl(1.0, "rad")
|
||||
# Time
|
||||
sec = UnitDbl(1.0, "sec")
|
||||
min = UnitDbl(1.0, "min")
|
||||
hr = UnitDbl(1.0, "hour")
|
||||
day = UnitDbl(24.0, "hour")
|
||||
sec = UnitDbl(1.0, "sec")
|
||||
119
venv/lib/python3.12/site-packages/matplotlib/testing/widgets.py
Normal file
119
venv/lib/python3.12/site-packages/matplotlib/testing/widgets.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""
|
||||
========================
|
||||
Widget testing utilities
|
||||
========================
|
||||
|
||||
See also :mod:`matplotlib.tests.test_widgets`.
|
||||
"""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def get_ax():
|
||||
"""Create a plot and return its Axes."""
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
ax.plot([0, 200], [0, 200])
|
||||
ax.set_aspect(1.0)
|
||||
ax.figure.canvas.draw()
|
||||
return ax
|
||||
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def mock_event(ax, button=1, xdata=0, ydata=0, key=None, step=1):
|
||||
r"""
|
||||
Create a mock event that can stand in for `.Event` and its subclasses.
|
||||
|
||||
This event is intended to be used in tests where it can be passed into
|
||||
event handling functions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : `~matplotlib.axes.Axes`
|
||||
The Axes the event will be in.
|
||||
xdata : float
|
||||
x coord of mouse in data coords.
|
||||
ydata : float
|
||||
y coord of mouse in data coords.
|
||||
button : None or `MouseButton` or {'up', 'down'}
|
||||
The mouse button pressed in this event (see also `.MouseEvent`).
|
||||
key : None or str
|
||||
The key pressed when the mouse event triggered (see also `.KeyEvent`).
|
||||
step : int
|
||||
Number of scroll steps (positive for 'up', negative for 'down').
|
||||
|
||||
Returns
|
||||
-------
|
||||
event
|
||||
A `.Event`\-like Mock instance.
|
||||
"""
|
||||
event = mock.Mock()
|
||||
event.button = button
|
||||
event.x, event.y = ax.transData.transform([(xdata, ydata),
|
||||
(xdata, ydata)])[0]
|
||||
event.xdata, event.ydata = xdata, ydata
|
||||
event.inaxes = ax
|
||||
event.canvas = ax.figure.canvas
|
||||
event.key = key
|
||||
event.step = step
|
||||
event.guiEvent = None
|
||||
event.name = 'Custom'
|
||||
return event
|
||||
|
||||
|
||||
def do_event(tool, etype, button=1, xdata=0, ydata=0, key=None, step=1):
|
||||
"""
|
||||
Trigger an event on the given tool.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tool : matplotlib.widgets.AxesWidget
|
||||
etype : str
|
||||
The event to trigger.
|
||||
xdata : float
|
||||
x coord of mouse in data coords.
|
||||
ydata : float
|
||||
y coord of mouse in data coords.
|
||||
button : None or `MouseButton` or {'up', 'down'}
|
||||
The mouse button pressed in this event (see also `.MouseEvent`).
|
||||
key : None or str
|
||||
The key pressed when the mouse event triggered (see also `.KeyEvent`).
|
||||
step : int
|
||||
Number of scroll steps (positive for 'up', negative for 'down').
|
||||
"""
|
||||
event = mock_event(tool.ax, button, xdata, ydata, key, step)
|
||||
func = getattr(tool, etype)
|
||||
func(event)
|
||||
|
||||
|
||||
def click_and_drag(tool, start, end, key=None):
|
||||
"""
|
||||
Helper to simulate a mouse drag operation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tool : `~matplotlib.widgets.Widget`
|
||||
start : [float, float]
|
||||
Starting point in data coordinates.
|
||||
end : [float, float]
|
||||
End point in data coordinates.
|
||||
key : None or str
|
||||
An optional key that is pressed during the whole operation
|
||||
(see also `.KeyEvent`).
|
||||
"""
|
||||
if key is not None:
|
||||
# Press key
|
||||
do_event(tool, 'on_key_press', xdata=start[0], ydata=start[1],
|
||||
button=1, key=key)
|
||||
# Click, move, and release mouse
|
||||
do_event(tool, 'press', xdata=start[0], ydata=start[1], button=1)
|
||||
do_event(tool, 'onmove', xdata=end[0], ydata=end[1], button=1)
|
||||
do_event(tool, 'release', xdata=end[0], ydata=end[1], button=1)
|
||||
if key is not None:
|
||||
# Release key
|
||||
do_event(tool, 'on_key_release', xdata=end[0], ydata=end[1],
|
||||
button=1, key=key)
|
||||
@ -0,0 +1,31 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.backend_bases import Event, MouseButton
|
||||
from matplotlib.widgets import AxesWidget, Widget
|
||||
|
||||
def get_ax() -> Axes: ...
|
||||
def noop(*args: Any, **kwargs: Any) -> None: ...
|
||||
def mock_event(
|
||||
ax: Axes,
|
||||
button: MouseButton | int | Literal["up", "down"] | None = ...,
|
||||
xdata: float = ...,
|
||||
ydata: float = ...,
|
||||
key: str | None = ...,
|
||||
step: int = ...,
|
||||
) -> Event: ...
|
||||
def do_event(
|
||||
tool: AxesWidget,
|
||||
etype: str,
|
||||
button: MouseButton | int | Literal["up", "down"] | None = ...,
|
||||
xdata: float = ...,
|
||||
ydata: float = ...,
|
||||
key: str | None = ...,
|
||||
step: int = ...,
|
||||
) -> None: ...
|
||||
def click_and_drag(
|
||||
tool: Widget,
|
||||
start: tuple[float, float],
|
||||
end: tuple[float, float],
|
||||
key: str | None = ...,
|
||||
) -> None: ...
|
||||
Reference in New Issue
Block a user