commit bc9415586e
parent 40aade2d8e
Date: 2024-11-29 18:15:30 +00:00
5298 changed files with 1938676 additions and 80 deletions


@@ -0,0 +1,161 @@
# This file is generated by SciPy's build process
# It contains system_info results at the time of building this package.
from enum import Enum
__all__ = ["show"]
_built_with_meson = True
class DisplayModes(Enum):
stdout = "stdout"
dicts = "dicts"
def _cleanup(d):
"""
Removes empty values in a `dict` recursively
This ensures we remove values that Meson could not provide to CONFIG
"""
if isinstance(d, dict):
return { k: _cleanup(v) for k, v in d.items() if v != '' and _cleanup(v) != '' }
else:
return d
CONFIG = _cleanup(
{
"Compilers": {
"c": {
"name": "gcc",
"linker": r"ld.bfd",
"version": "10.2.1",
"commands": r"cc",
"args": r"",
"linker args": r"",
},
"cython": {
"name": r"cython",
"linker": r"cython",
"version": r"3.0.11",
"commands": r"cython",
"args": r"",
"linker args": r"",
},
"c++": {
"name": "gcc",
"linker": r"ld.bfd",
"version": "10.2.1",
"commands": r"c++",
"args": r"",
"linker args": r"",
},
"fortran": {
"name": "gcc",
"linker": r"ld.bfd",
"version": "10.2.1",
"commands": r"gfortran",
"args": r"",
"linker args": r"",
},
"pythran": {
"version": r"0.16.1",
"include directory": r"../../tmp/pip-build-env-znxdftlb/overlay/lib/python3.12/site-packages/pythran"
},
},
"Machine Information": {
"host": {
"cpu": r"x86_64",
"family": r"x86_64",
"endian": r"little",
"system": r"linux",
},
"build": {
"cpu": r"x86_64",
"family": r"x86_64",
"endian": r"little",
"system": r"linux",
},
"cross-compiled": bool("False".lower().replace('false', '')),
},
"Build Dependencies": {
"blas": {
"name": "scipy-openblas",
"found": bool("True".lower().replace('false', '')),
"version": "0.3.27.dev",
"detection method": "pkgconfig",
"include directory": r"/opt/_internal/cpython-3.12.4/lib/python3.12/site-packages/scipy_openblas32/include",
"lib directory": r"/opt/_internal/cpython-3.12.4/lib/python3.12/site-packages/scipy_openblas32/lib",
"openblas configuration": r"OpenBLAS 0.3.27.dev DYNAMIC_ARCH NO_AFFINITY Zen MAX_THREADS=64",
"pc file directory": r"/project",
},
"lapack": {
"name": "scipy-openblas",
"found": bool("True".lower().replace('false', '')),
"version": "0.3.27.dev",
"detection method": "pkgconfig",
"include directory": r"/opt/_internal/cpython-3.12.4/lib/python3.12/site-packages/scipy_openblas32/include",
"lib directory": r"/opt/_internal/cpython-3.12.4/lib/python3.12/site-packages/scipy_openblas32/lib",
"openblas configuration": r"OpenBLAS 0.3.27.dev DYNAMIC_ARCH NO_AFFINITY Zen MAX_THREADS=64",
"pc file directory": r"/project",
},
"pybind11": {
"name": "pybind11",
"version": "2.12.0",
"detection method": "config-tool",
"include directory": r"unknown",
},
},
"Python Information": {
"path": r"/opt/python/cp312-cp312/bin/python",
"version": "3.12",
},
}
)
def _check_pyyaml():
import yaml
return yaml
def show(mode=DisplayModes.stdout.value):
"""
Show libraries and system information on which SciPy was built
and is being used
Parameters
----------
mode : {`'stdout'`, `'dicts'`}, optional.
Indicates how to display the config information.
`'stdout'` prints to console, `'dicts'` returns a dictionary
of the configuration.
Returns
-------
out : {`dict`, `None`}
If mode is `'dicts'`, a dict is returned, else None
Notes
-----
1. The `'stdout'` mode will give more readable
output if ``pyyaml`` is installed
"""
if mode == DisplayModes.stdout.value:
try: # Non-standard library, check import
yaml = _check_pyyaml()
print(yaml.dump(CONFIG))
except ModuleNotFoundError:
import warnings
import json
warnings.warn("Install `pyyaml` for better output", stacklevel=1)
print(json.dumps(CONFIG, indent=2))
elif mode == DisplayModes.dicts.value:
return CONFIG
else:
raise AttributeError(
f"Invalid `mode`, use one of: {', '.join([e.value for e in DisplayModes])}"
)
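# A minimal usage sketch (editor's illustration, not part of the generated
# file): `show` is also exposed publicly as ``scipy.show_config``; the values
# reported are the build-time ones baked into CONFIG above.
if __name__ == "__main__":
    # Grab the configuration as a plain dict and inspect one entry.
    cfg = show(mode=DisplayModes.dicts.value)
    print(cfg["Compilers"]["c"]["version"])
    # Pretty-print everything (YAML if pyyaml is installed, JSON otherwise).
    show(mode=DisplayModes.stdout.value)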


@@ -0,0 +1,141 @@
"""
SciPy: A scientific computing package for Python
================================================
Documentation is available in the docstrings and
online at https://docs.scipy.org.
Subpackages
-----------
Using any of these subpackages requires an explicit import. For example,
``import scipy.cluster``.
::
cluster --- Vector Quantization / Kmeans
constants --- Physical and mathematical constants and units
datasets --- Dataset methods
fft --- Discrete Fourier transforms
fftpack --- Legacy discrete Fourier transforms
integrate --- Integration routines
interpolate --- Interpolation Tools
io --- Data input and output
linalg --- Linear algebra routines
misc --- Utilities that don't have another home.
ndimage --- N-D image package
odr --- Orthogonal Distance Regression
optimize --- Optimization Tools
signal --- Signal Processing Tools
sparse --- Sparse Matrices
spatial --- Spatial data structures and algorithms
special --- Special functions
stats --- Statistical Functions
Public API in the main SciPy namespace
--------------------------------------
::
__version__ --- SciPy version string
LowLevelCallable --- Low-level callback function
show_config --- Show scipy build configuration
test --- Run scipy unittests
"""
import importlib as _importlib
from numpy import __version__ as __numpy_version__
try:
from scipy.__config__ import show as show_config
except ImportError as e:
msg = """Error importing SciPy: you cannot import SciPy while
being in scipy source directory; please exit the SciPy source
tree first and relaunch your Python interpreter."""
raise ImportError(msg) from e
from scipy.version import version as __version__
# Allow distributors to run custom init code
from . import _distributor_init
del _distributor_init
from scipy._lib import _pep440
# In maintenance branch, change to np_maxversion N+3 if numpy is at N
np_minversion = '1.23.5'
np_maxversion = '2.3.0'
if (_pep440.parse(__numpy_version__) < _pep440.Version(np_minversion) or
_pep440.parse(__numpy_version__) >= _pep440.Version(np_maxversion)):
import warnings
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
f" is required for this version of SciPy (detected "
f"version {__numpy_version__})",
UserWarning, stacklevel=2)
del _pep440
# This is the first import of an extension module within SciPy. If there's
# a general issue with the install, such that extension modules are missing
# or cannot be imported, this is where we'll get a failure - so give an
# informative error message.
try:
from scipy._lib._ccallback import LowLevelCallable
except ImportError as e:
msg = "The `scipy` install you are using seems to be broken, " + \
"(extension modules cannot be imported), " + \
"please try reinstalling."
raise ImportError(msg) from e
from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester
submodules = [
'cluster',
'constants',
'datasets',
'fft',
'fftpack',
'integrate',
'interpolate',
'io',
'linalg',
'misc',
'ndimage',
'odr',
'optimize',
'signal',
'sparse',
'spatial',
'special',
'stats'
]
__all__ = submodules + [
'LowLevelCallable',
'test',
'show_config',
'__version__',
]
def __dir__():
return __all__
def __getattr__(name):
if name in submodules:
return _importlib.import_module(f'scipy.{name}')
else:
try:
return globals()[name]
except KeyError:
raise AttributeError(
f"Module 'scipy' has no attribute '{name}'"
)
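# Usage sketch (editor's illustration): the submodules listed above are
# imported lazily through ``__getattr__``, so ``import scipy`` stays cheap.
#
#   >>> import scipy
#   >>> scipy.cluster                 # first attribute access triggers the real import
#   <module 'scipy.cluster' from ...>
#   >>> scipy.not_a_module
#   Traceback (most recent call last):
#   ...
#   AttributeError: Module 'scipy' has no attribute 'not_a_module'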


@@ -0,0 +1,18 @@
""" Distributor init file
Distributors: you can replace the contents of this file with your own custom
code to support particular distributions of SciPy.
For example, this is a good place to put any checks for hardware requirements
or BLAS/LAPACK library initialization.
The SciPy standard source distribution will not put code in this file beyond
the try-except import of `_distributor_init_local` (which is not part of a
standard source distribution), so you can safely replace this file with your
own version.
"""
try:
from . import _distributor_init_local # noqa: F401
except ImportError:
pass


@@ -0,0 +1,14 @@
"""
Module containing private utility functions
===========================================
The ``scipy._lib`` namespace is empty (for now). Tests for all
utilities in submodules of ``_lib`` can be run with::
from scipy import _lib
_lib.test()
"""
from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester


@@ -0,0 +1,524 @@
"""Utility functions to use Python Array API compatible libraries.
For the context about the Array API see:
https://data-apis.org/array-api/latest/purpose_and_scope.html
The SciPy use case of the Array API is described on the following page:
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
"""
from __future__ import annotations
import os
import warnings
from types import ModuleType
from typing import Any, Literal, TYPE_CHECKING
import numpy as np
import numpy.typing as npt
from scipy._lib import array_api_compat
from scipy._lib.array_api_compat import (
is_array_api_obj,
size,
numpy as np_compat,
device
)
__all__ = ['array_namespace', '_asarray', 'size', 'device']
# To enable array API and strict array-like input validation
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")
_GLOBAL_CONFIG = {
"SCIPY_ARRAY_API": SCIPY_ARRAY_API,
"SCIPY_DEVICE": SCIPY_DEVICE,
}
if TYPE_CHECKING:
Array = Any # To be changed to a Protocol later (see array-api#589)
ArrayLike = Array | npt.ArrayLike
def compliance_scipy(arrays: list[ArrayLike]) -> list[Array]:
"""Raise exceptions on known-bad subclasses.
    The following subclasses are not supported and raise an error:
- `numpy.ma.MaskedArray`
- `numpy.matrix`
- NumPy arrays which do not have a boolean or numerical dtype
- Any array-like which is neither array API compatible nor coercible by NumPy
- Any array-like which is coerced by NumPy to an unsupported dtype
"""
for i in range(len(arrays)):
array = arrays[i]
if isinstance(array, np.ma.MaskedArray):
raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.")
elif isinstance(array, np.matrix):
raise TypeError("Inputs of type `numpy.matrix` are not supported.")
if isinstance(array, (np.ndarray, np.generic)):
dtype = array.dtype
if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
raise TypeError(f"An argument has dtype `{dtype!r}`; "
f"only boolean and numerical dtypes are supported.")
elif not is_array_api_obj(array):
try:
array = np.asanyarray(array)
except TypeError:
raise TypeError("An argument is neither array API compatible nor "
"coercible by NumPy.")
dtype = array.dtype
if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
message = (
f"An argument was coerced to an unsupported dtype `{dtype!r}`; "
f"only boolean and numerical dtypes are supported."
)
raise TypeError(message)
arrays[i] = array
return arrays
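# Illustration (editor's sketch): the checks above surface as TypeErrors for
# the rejected inputs, e.g.
#
#   >>> compliance_scipy([np.matrix([[1, 2]])])
#   Traceback (most recent call last):
#   ...
#   TypeError: Inputs of type `numpy.matrix` are not supported.
#   >>> compliance_scipy([np.ma.masked_array([1, 2])])
#   Traceback (most recent call last):
#   ...
#   TypeError: Inputs of type `numpy.ma.MaskedArray` are not supported.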
def _check_finite(array: Array, xp: ModuleType) -> None:
"""Check for NaNs or Infs."""
msg = "array must not contain infs or NaNs"
try:
if not xp.all(xp.isfinite(array)):
raise ValueError(msg)
except TypeError:
raise ValueError(msg)
def array_namespace(*arrays: Array) -> ModuleType:
"""Get the array API compatible namespace for the arrays xs.
Parameters
----------
*arrays : sequence of array_like
Arrays used to infer the common namespace.
Returns
-------
namespace : module
Common namespace.
Notes
-----
Thin wrapper around `array_api_compat.array_namespace`.
1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed
dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``.
    2. `compliance_scipy` raises exceptions on known-bad subclasses. See
its definition for more details.
When the global switch is False, it defaults to the `numpy` namespace.
In that case, there is no compliance check. This is a convenience to
ease the adoption. Otherwise, arrays must comply with the new rules.
"""
if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
# here we could wrap the namespace if needed
return np_compat
_arrays = [array for array in arrays if array is not None]
_arrays = compliance_scipy(_arrays)
return array_api_compat.array_namespace(*_arrays)
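# Usage sketch (editor's illustration): with SCIPY_ARRAY_API unset this always
# returns the NumPy compatibility namespace; with SCIPY_ARRAY_API=1 the
# namespace is inferred from the inputs (e.g. torch tensors yield the torch
# wrapper).
#
#   >>> import numpy as np
#   >>> xp = array_namespace(np.asarray([1.0, 2.0]))
#   >>> xp.sum(xp.asarray([1.0, 2.0]))    # -> 3.0, using the inferred namespace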
def _asarray(
array: ArrayLike,
dtype: Any = None,
order: Literal['K', 'A', 'C', 'F'] | None = None,
copy: bool | None = None,
*,
xp: ModuleType | None = None,
check_finite: bool = False,
subok: bool = False,
) -> Array:
"""SciPy-specific replacement for `np.asarray` with `order`, `check_finite`, and
`subok`.
Memory layout parameter `order` is not exposed in the Array API standard.
`order` is only enforced if the input array implementation
is NumPy based, otherwise `order` is just silently ignored.
    `check_finite` is also not a keyword in the array API standard; it is
    included here for convenience, rather than requiring a separate function
    call inside SciPy functions.
`subok` is included to allow this function to preserve the behaviour of
`np.asanyarray` for NumPy based inputs.
"""
if xp is None:
xp = array_namespace(array)
if xp.__name__ in {"numpy", "scipy._lib.array_api_compat.numpy"}:
# Use NumPy API to support order
if copy is True:
array = np.array(array, order=order, dtype=dtype, subok=subok)
elif subok:
array = np.asanyarray(array, order=order, dtype=dtype)
else:
array = np.asarray(array, order=order, dtype=dtype)
# At this point array is a NumPy ndarray. We convert it to an array
# container that is consistent with the input's namespace.
array = xp.asarray(array)
else:
try:
array = xp.asarray(array, dtype=dtype, copy=copy)
except TypeError:
coerced_xp = array_namespace(xp.asarray(3))
array = coerced_xp.asarray(array, dtype=dtype, copy=copy)
if check_finite:
_check_finite(array, xp)
return array
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""Recursively expand the dimension to have at least `ndim`."""
if xp is None:
xp = array_namespace(x)
x = xp.asarray(x)
if x.ndim < ndim:
x = xp.expand_dims(x, axis=0)
x = atleast_nd(x, ndim=ndim, xp=xp)
return x
def copy(x: Array, *, xp: ModuleType | None = None) -> Array:
"""
Copies an array.
Parameters
----------
x : array
xp : array_namespace
Returns
-------
copy : array
Copied array
Notes
-----
This copy function does not offer all the semantics of `np.copy`, i.e. the
`subok` and `order` keywords are not used.
"""
# Note: xp.asarray fails if xp is numpy.
if xp is None:
xp = array_namespace(x)
return _asarray(x, copy=True, xp=xp)
def is_numpy(xp: ModuleType) -> bool:
return xp.__name__ in ('numpy', 'scipy._lib.array_api_compat.numpy')
def is_cupy(xp: ModuleType) -> bool:
return xp.__name__ in ('cupy', 'scipy._lib.array_api_compat.cupy')
def is_torch(xp: ModuleType) -> bool:
return xp.__name__ in ('torch', 'scipy._lib.array_api_compat.torch')
def is_jax(xp):
return xp.__name__ in ('jax.numpy', 'jax.experimental.array_api')
def _strict_check(actual, desired, xp,
check_namespace=True, check_dtype=True, check_shape=True):
__tracebackhide__ = True # Hide traceback for py.test
if check_namespace:
_assert_matching_namespace(actual, desired)
desired = xp.asarray(desired)
if check_dtype:
_msg = f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}"
assert actual.dtype == desired.dtype, _msg
if check_shape:
_msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}"
assert actual.shape == desired.shape, _msg
_check_scalar(actual, desired, xp)
desired = xp.broadcast_to(desired, actual.shape)
return desired
def _assert_matching_namespace(actual, desired):
__tracebackhide__ = True # Hide traceback for py.test
actual = actual if isinstance(actual, tuple) else (actual,)
desired_space = array_namespace(desired)
for arr in actual:
arr_space = array_namespace(arr)
_msg = (f"Namespaces do not match.\n"
f"Actual: {arr_space.__name__}\n"
f"Desired: {desired_space.__name__}")
assert arr_space == desired_space, _msg
def _check_scalar(actual, desired, xp):
__tracebackhide__ = True # Hide traceback for py.test
# Shape check alone is sufficient unless desired.shape == (). Also,
# only NumPy distinguishes between scalars and arrays.
if desired.shape != () or not is_numpy(xp):
return
# We want to follow the conventions of the `xp` library. Libraries like
# NumPy, for which `np.asarray(0)[()]` returns a scalar, tend to return
# a scalar even when a 0D array might be more appropriate:
# import numpy as np
# np.mean([1, 2, 3]) # scalar, not 0d array
# np.asarray(0)*2 # scalar, not 0d array
# np.sin(np.asarray(0)) # scalar, not 0d array
# Libraries like CuPy, for which `cp.asarray(0)[()]` returns a 0D array,
# tend to return a 0D array in scenarios like those above.
# Therefore, regardless of whether the developer provides a scalar or 0D
# array for `desired`, we would typically want the type of `actual` to be
# the type of `desired[()]`. If the developer wants to override this
# behavior, they can set `check_shape=False`.
desired = desired[()]
_msg = f"Types do not match:\n Actual: {type(actual)}\n Desired: {type(desired)}"
assert (xp.isscalar(actual) and xp.isscalar(desired)
or (not xp.isscalar(actual) and not xp.isscalar(desired))), _msg
def xp_assert_equal(actual, desired, check_namespace=True, check_dtype=True,
check_shape=True, err_msg='', xp=None):
__tracebackhide__ = True # Hide traceback for py.test
if xp is None:
xp = array_namespace(actual)
desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
check_dtype=check_dtype, check_shape=check_shape)
if is_cupy(xp):
return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
elif is_torch(xp):
# PyTorch recommends using `rtol=0, atol=0` like this
# to test for exact equality
err_msg = None if err_msg == '' else err_msg
return xp.testing.assert_close(actual, desired, rtol=0, atol=0, equal_nan=True,
check_dtype=False, msg=err_msg)
# JAX uses `np.testing`
return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
def xp_assert_close(actual, desired, rtol=None, atol=0, check_namespace=True,
check_dtype=True, check_shape=True, err_msg='', xp=None):
__tracebackhide__ = True # Hide traceback for py.test
if xp is None:
xp = array_namespace(actual)
desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
check_dtype=check_dtype, check_shape=check_shape)
floating = xp.isdtype(actual.dtype, ('real floating', 'complex floating'))
if rtol is None and floating:
# multiplier of 4 is used as for `np.float64` this puts the default `rtol`
# roughly half way between sqrt(eps) and the default for
# `numpy.testing.assert_allclose`, 1e-7
rtol = xp.finfo(actual.dtype).eps**0.5 * 4
elif rtol is None:
rtol = 1e-7
if is_cupy(xp):
return xp.testing.assert_allclose(actual, desired, rtol=rtol,
atol=atol, err_msg=err_msg)
elif is_torch(xp):
err_msg = None if err_msg == '' else err_msg
return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
equal_nan=True, check_dtype=False, msg=err_msg)
# JAX uses `np.testing`
return np.testing.assert_allclose(actual, desired, rtol=rtol,
atol=atol, err_msg=err_msg)
def xp_assert_less(actual, desired, check_namespace=True, check_dtype=True,
check_shape=True, err_msg='', verbose=True, xp=None):
__tracebackhide__ = True # Hide traceback for py.test
if xp is None:
xp = array_namespace(actual)
desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
check_dtype=check_dtype, check_shape=check_shape)
if is_cupy(xp):
return xp.testing.assert_array_less(actual, desired,
err_msg=err_msg, verbose=verbose)
elif is_torch(xp):
if actual.device.type != 'cpu':
actual = actual.cpu()
if desired.device.type != 'cpu':
desired = desired.cpu()
# JAX uses `np.testing`
return np.testing.assert_array_less(actual, desired,
err_msg=err_msg, verbose=verbose)
def cov(x: Array, *, xp: ModuleType | None = None) -> Array:
if xp is None:
xp = array_namespace(x)
X = copy(x, xp=xp)
dtype = xp.result_type(X, xp.float64)
X = atleast_nd(X, ndim=2, xp=xp)
X = xp.asarray(X, dtype=dtype)
avg = xp.mean(X, axis=1)
fact = X.shape[1] - 1
if fact <= 0:
warnings.warn("Degrees of freedom <= 0 for slice",
RuntimeWarning, stacklevel=2)
fact = 0.0
X -= avg[:, None]
X_T = X.T
if xp.isdtype(X_T.dtype, 'complex floating'):
X_T = xp.conj(X_T)
c = X @ X_T
c /= fact
axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
return xp.squeeze(c, axis=axes)
def xp_unsupported_param_msg(param: Any) -> str:
return f'Providing {param!r} is only supported for numpy arrays.'
def is_complex(x: Array, xp: ModuleType) -> bool:
return xp.isdtype(x.dtype, 'complex floating')
def get_xp_devices(xp: ModuleType) -> list[str] | list[None]:
"""Returns a list of available devices for the given namespace."""
devices: list[str] = []
if is_torch(xp):
devices += ['cpu']
import torch # type: ignore[import]
num_cuda = torch.cuda.device_count()
for i in range(0, num_cuda):
devices += [f'cuda:{i}']
if torch.backends.mps.is_available():
devices += ['mps']
return devices
elif is_cupy(xp):
import cupy # type: ignore[import]
num_cuda = cupy.cuda.runtime.getDeviceCount()
for i in range(0, num_cuda):
devices += [f'cuda:{i}']
return devices
elif is_jax(xp):
import jax # type: ignore[import]
num_cpu = jax.device_count(backend='cpu')
for i in range(0, num_cpu):
devices += [f'cpu:{i}']
num_gpu = jax.device_count(backend='gpu')
for i in range(0, num_gpu):
devices += [f'gpu:{i}']
num_tpu = jax.device_count(backend='tpu')
for i in range(0, num_tpu):
devices += [f'tpu:{i}']
return devices
# given namespace is not known to have a list of available devices;
# return `[None]` so that one can use this in tests for `device=None`.
return [None]
def scipy_namespace_for(xp: ModuleType) -> ModuleType:
"""
Return the `scipy` namespace for alternative backends, where it exists,
such as `cupyx.scipy` and `jax.scipy`. Useful for ad hoc dispatching.
Default: return `scipy` (this package).
"""
if is_cupy(xp):
import cupyx # type: ignore[import-not-found,import-untyped]
return cupyx.scipy
if is_jax(xp):
import jax # type: ignore[import-not-found]
return jax.scipy
import scipy
return scipy
# temporary substitute for xp.minimum, which is not yet in all backends
# or covered by array_api_compat.
def xp_minimum(x1: Array, x2: Array, /) -> Array:
# xp won't be passed in because it doesn't need to be passed in to xp.minimum
xp = array_namespace(x1, x2)
if hasattr(xp, 'minimum'):
return xp.minimum(x1, x2)
x1, x2 = xp.broadcast_arrays(x1, x2)
i = (x2 < x1) | xp.isnan(x2)
res = xp.where(i, x2, x1)
return res[()] if res.ndim == 0 else res
# temporary substitute for xp.clip, which is not yet in all backends
# or covered by array_api_compat.
def xp_clip(
x: Array,
/,
min: int | float | Array | None = None,
max: int | float | Array | None = None,
*,
xp: ModuleType | None = None) -> Array:
xp = array_namespace(x) if xp is None else xp
a, b = xp.asarray(min, dtype=x.dtype), xp.asarray(max, dtype=x.dtype)
if hasattr(xp, 'clip'):
return xp.clip(x, a, b)
x, a, b = xp.broadcast_arrays(x, a, b)
y = xp.asarray(x, copy=True)
ia = y < a
y[ia] = a[ia]
ib = y > b
y[ib] = b[ib]
return y[()] if y.ndim == 0 else y
# temporary substitute for xp.moveaxis, which is not yet in all backends
# or covered by array_api_compat.
def xp_moveaxis_to_end(
x: Array,
source: int,
/, *,
xp: ModuleType | None = None) -> Array:
    xp = array_namespace(x) if xp is None else xp
axes = list(range(x.ndim))
temp = axes.pop(source)
axes = axes + [temp]
return xp.permute_dims(x, axes)
# temporary substitute for xp.copysign, which is not yet in all backends
# or covered by array_api_compat.
def xp_copysign(x1: Array, x2: Array, /, *, xp: ModuleType | None = None) -> Array:
# no attempt to account for special cases
xp = array_namespace(x1, x2) if xp is None else xp
abs_x1 = xp.abs(x1)
return xp.where(x2 >= 0, abs_x1, -abs_x1)
# partial substitute for xp.sign, which does not cover the NaN special case
# that I need. (https://github.com/data-apis/array-api-compat/issues/136)
def xp_sign(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp = array_namespace(x) if xp is None else xp
if is_numpy(xp): # only NumPy implements the special cases correctly
return xp.sign(x)
sign = xp.full_like(x, xp.nan)
one = xp.asarray(1, dtype=x.dtype)
sign = xp.where(x > 0, one, sign)
sign = xp.where(x < 0, -one, sign)
sign = xp.where(x == 0, 0*one, sign)
return sign
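# Usage sketch (editor's illustration): the xp_* helpers above provide
# NumPy-like behaviour for namespaces that do not (yet) offer the operation
# natively.
#
#   >>> import numpy as np
#   >>> xp_minimum(np.asarray([1.0, np.nan]), np.asarray([2.0, 0.0]))   # -> [1., nan]
#   >>> xp_clip(np.asarray([-1.0, 0.5, 2.0]), 0.0, 1.0)                 # -> [0., 0.5, 1.]
#   >>> xp_copysign(np.asarray([1.0, 2.0]), np.asarray([-1.0, 1.0]))    # -> [-1., 2.]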


@@ -0,0 +1,225 @@
import sys as _sys
from keyword import iskeyword as _iskeyword
def _validate_names(typename, field_names, extra_field_names):
"""
Ensure that all the given names are valid Python identifiers that
do not start with '_'. Also check that there are no duplicates
among field_names + extra_field_names.
"""
for name in [typename] + field_names + extra_field_names:
if not isinstance(name, str):
raise TypeError('typename and all field names must be strings')
if not name.isidentifier():
raise ValueError('typename and all field names must be valid '
f'identifiers: {name!r}')
if _iskeyword(name):
raise ValueError('typename and all field names cannot be a '
f'keyword: {name!r}')
seen = set()
for name in field_names + extra_field_names:
if name.startswith('_'):
raise ValueError('Field names cannot start with an underscore: '
f'{name!r}')
if name in seen:
raise ValueError(f'Duplicate field name: {name!r}')
seen.add(name)
# Note: This code is adapted from CPython:Lib/collections/__init__.py
def _make_tuple_bunch(typename, field_names, extra_field_names=None,
module=None):
"""
Create a namedtuple-like class with additional attributes.
This function creates a subclass of tuple that acts like a namedtuple
and that has additional attributes.
The additional attributes are listed in `extra_field_names`. The
values assigned to these attributes are not part of the tuple.
The reason this function exists is to allow functions in SciPy
    that currently return a tuple or a namedtuple to return objects
that have additional attributes, while maintaining backwards
compatibility.
This should only be used to enhance *existing* functions in SciPy.
New functions are free to create objects as return values without
having to maintain backwards compatibility with an old tuple or
namedtuple return value.
Parameters
----------
typename : str
The name of the type.
field_names : list of str
List of names of the values to be stored in the tuple. These names
will also be attributes of instances, so the values in the tuple
can be accessed by indexing or as attributes. At least one name
is required. See the Notes for additional restrictions.
extra_field_names : list of str, optional
List of names of values that will be stored as attributes of the
object. See the notes for additional restrictions.
Returns
-------
cls : type
The new class.
Notes
-----
There are restrictions on the names that may be used in `field_names`
and `extra_field_names`:
* The names must be unique--no duplicates allowed.
* The names must be valid Python identifiers, and must not begin with
an underscore.
* The names must not be Python keywords (e.g. 'def', 'and', etc., are
not allowed).
Examples
--------
>>> from scipy._lib._bunch import _make_tuple_bunch
Create a class that acts like a namedtuple with length 2 (with field
names `x` and `y`) that will also have the attributes `w` and `beta`:
>>> Result = _make_tuple_bunch('Result', ['x', 'y'], ['w', 'beta'])
`Result` is the new class. We call it with keyword arguments to create
a new instance with given values.
>>> result1 = Result(x=1, y=2, w=99, beta=0.5)
>>> result1
Result(x=1, y=2, w=99, beta=0.5)
`result1` acts like a tuple of length 2:
>>> len(result1)
2
>>> result1[:]
(1, 2)
The values assigned when the instance was created are available as
attributes:
>>> result1.y
2
>>> result1.beta
0.5
"""
if len(field_names) == 0:
raise ValueError('field_names must contain at least one name')
if extra_field_names is None:
extra_field_names = []
_validate_names(typename, field_names, extra_field_names)
typename = _sys.intern(str(typename))
field_names = tuple(map(_sys.intern, field_names))
extra_field_names = tuple(map(_sys.intern, extra_field_names))
all_names = field_names + extra_field_names
arg_list = ', '.join(field_names)
full_list = ', '.join(all_names)
repr_fmt = ''.join(('(',
', '.join(f'{name}=%({name})r' for name in all_names),
')'))
tuple_new = tuple.__new__
_dict, _tuple, _zip = dict, tuple, zip
# Create all the named tuple methods to be added to the class namespace
s = f"""\
def __new__(_cls, {arg_list}, **extra_fields):
return _tuple_new(_cls, ({arg_list},))
def __init__(self, {arg_list}, **extra_fields):
for key in self._extra_fields:
if key not in extra_fields:
raise TypeError("missing keyword argument '%s'" % (key,))
for key, val in extra_fields.items():
if key not in self._extra_fields:
raise TypeError("unexpected keyword argument '%s'" % (key,))
self.__dict__[key] = val
def __setattr__(self, key, val):
if key in {repr(field_names)}:
raise AttributeError("can't set attribute %r of class %r"
% (key, self.__class__.__name__))
else:
self.__dict__[key] = val
"""
del arg_list
namespace = {'_tuple_new': tuple_new,
'__builtins__': dict(TypeError=TypeError,
AttributeError=AttributeError),
'__name__': f'namedtuple_{typename}'}
exec(s, namespace)
__new__ = namespace['__new__']
__new__.__doc__ = f'Create new instance of {typename}({full_list})'
__init__ = namespace['__init__']
__init__.__doc__ = f'Instantiate instance of {typename}({full_list})'
__setattr__ = namespace['__setattr__']
def __repr__(self):
'Return a nicely formatted representation string'
return self.__class__.__name__ + repr_fmt % self._asdict()
def _asdict(self):
'Return a new dict which maps field names to their values.'
out = _dict(_zip(self._fields, self))
out.update(self.__dict__)
return out
def __getnewargs_ex__(self):
'Return self as a plain tuple. Used by copy and pickle.'
return _tuple(self), self.__dict__
# Modify function metadata to help with introspection and debugging
for method in (__new__, __repr__, _asdict, __getnewargs_ex__):
method.__qualname__ = f'{typename}.{method.__name__}'
# Build-up the class namespace dictionary
# and use type() to build the result class
class_namespace = {
'__doc__': f'{typename}({full_list})',
'_fields': field_names,
'__new__': __new__,
'__init__': __init__,
'__repr__': __repr__,
'__setattr__': __setattr__,
'_asdict': _asdict,
'_extra_fields': extra_field_names,
'__getnewargs_ex__': __getnewargs_ex__,
}
for index, name in enumerate(field_names):
def _get(self, index=index):
return self[index]
class_namespace[name] = property(_get)
for name in extra_field_names:
def _get(self, name=name):
return self.__dict__[name]
class_namespace[name] = property(_get)
result = type(typename, (tuple,), class_namespace)
# For pickling to work, the __module__ variable needs to be set to the
# frame where the named tuple is created. Bypass this step in environments
# where sys._getframe is not defined (Jython for example) or sys._getframe
# is not defined for arguments greater than 0 (IronPython), or where the
# user has specified a particular module.
if module is None:
try:
module = _sys._getframe(1).f_globals.get('__name__', '__main__')
except (AttributeError, ValueError):
pass
if module is not None:
result.__module__ = module
__new__.__module__ = module
return result
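# Usage sketch (editor's illustration): tuple fields become read-only
# properties, while the extra fields stay assignable attributes.
#
#   >>> Result = _make_tuple_bunch('Result', ['x', 'y'], ['w'])
#   >>> r = Result(x=1, y=2, w=3)
#   >>> tuple(r), r.w
#   ((1, 2), 3)
#   >>> r.w = 10                      # extra field: allowed
#   >>> r.x = 0                       # tuple field: raises AttributeError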


@@ -0,0 +1,251 @@
from . import _ccallback_c
import ctypes
PyCFuncPtr = ctypes.CFUNCTYPE(ctypes.c_void_p).__bases__[0]
ffi = None
class CData:
pass
def _import_cffi():
global ffi, CData
if ffi is not None:
return
try:
import cffi
ffi = cffi.FFI()
CData = ffi.CData
except ImportError:
ffi = False
class LowLevelCallable(tuple):
"""
Low-level callback function.
Some functions in SciPy take as arguments callback functions, which
can either be python callables or low-level compiled functions. Using
compiled callback functions can improve performance somewhat by
avoiding wrapping data in Python objects.
Such low-level functions in SciPy are wrapped in `LowLevelCallable`
objects, which can be constructed from function pointers obtained from
ctypes, cffi, Cython, or contained in Python `PyCapsule` objects.
.. seealso::
Functions accepting low-level callables:
`scipy.integrate.quad`, `scipy.ndimage.generic_filter`,
`scipy.ndimage.generic_filter1d`, `scipy.ndimage.geometric_transform`
Usage examples:
:ref:`ndimage-ccallbacks`, :ref:`quad-callbacks`
Parameters
----------
function : {PyCapsule, ctypes function pointer, cffi function pointer}
Low-level callback function.
user_data : {PyCapsule, ctypes void pointer, cffi void pointer}
User data to pass on to the callback function.
signature : str, optional
Signature of the function. If omitted, determined from *function*,
if possible.
Attributes
----------
function
Callback function given.
user_data
User data given.
signature
Signature of the function.
Methods
-------
from_cython
Class method for constructing callables from Cython C-exported
functions.
Notes
-----
The argument ``function`` can be one of:
- PyCapsule, whose name contains the C function signature
- ctypes function pointer
- cffi function pointer
The signature of the low-level callback must match one of those expected
by the routine it is passed to.
If constructing low-level functions from a PyCapsule, the name of the
capsule must be the corresponding signature, in the format::
return_type (arg1_type, arg2_type, ...)
For example::
"void (double)"
"double (double, int *, void *)"
The context of a PyCapsule passed in as ``function`` is used as ``user_data``,
if an explicit value for ``user_data`` was not given.
"""
# Make the class immutable
__slots__ = ()
def __new__(cls, function, user_data=None, signature=None):
# We need to hold a reference to the function & user data,
# to prevent them going out of scope
item = cls._parse_callback(function, user_data, signature)
return tuple.__new__(cls, (item, function, user_data))
def __repr__(self):
return f"LowLevelCallable({self.function!r}, {self.user_data!r})"
@property
def function(self):
return tuple.__getitem__(self, 1)
@property
def user_data(self):
return tuple.__getitem__(self, 2)
@property
def signature(self):
return _ccallback_c.get_capsule_signature(tuple.__getitem__(self, 0))
def __getitem__(self, idx):
raise ValueError()
@classmethod
def from_cython(cls, module, name, user_data=None, signature=None):
"""
Create a low-level callback function from an exported Cython function.
Parameters
----------
module : module
Cython module where the exported function resides
name : str
Name of the exported function
user_data : {PyCapsule, ctypes void pointer, cffi void pointer}, optional
User data to pass on to the callback function.
signature : str, optional
Signature of the function. If omitted, determined from *function*.
"""
try:
function = module.__pyx_capi__[name]
except AttributeError as e:
message = "Given module is not a Cython module with __pyx_capi__ attribute"
raise ValueError(message) from e
except KeyError as e:
message = f"No function {name!r} found in __pyx_capi__ of the module"
raise ValueError(message) from e
return cls(function, user_data, signature)
@classmethod
def _parse_callback(cls, obj, user_data=None, signature=None):
_import_cffi()
if isinstance(obj, LowLevelCallable):
func = tuple.__getitem__(obj, 0)
elif isinstance(obj, PyCFuncPtr):
func, signature = _get_ctypes_func(obj, signature)
elif isinstance(obj, CData):
func, signature = _get_cffi_func(obj, signature)
elif _ccallback_c.check_capsule(obj):
func = obj
else:
raise ValueError("Given input is not a callable or a "
"low-level callable (pycapsule/ctypes/cffi)")
if isinstance(user_data, ctypes.c_void_p):
context = _get_ctypes_data(user_data)
elif isinstance(user_data, CData):
context = _get_cffi_data(user_data)
elif user_data is None:
context = 0
elif _ccallback_c.check_capsule(user_data):
context = user_data
else:
raise ValueError("Given user data is not a valid "
"low-level void* pointer (pycapsule/ctypes/cffi)")
return _ccallback_c.get_raw_capsule(func, signature, context)
#
# ctypes helpers
#
def _get_ctypes_func(func, signature=None):
# Get function pointer
func_ptr = ctypes.cast(func, ctypes.c_void_p).value
# Construct function signature
if signature is None:
signature = _typename_from_ctypes(func.restype) + " ("
for j, arg in enumerate(func.argtypes):
if j == 0:
signature += _typename_from_ctypes(arg)
else:
signature += ", " + _typename_from_ctypes(arg)
signature += ")"
return func_ptr, signature
def _typename_from_ctypes(item):
if item is None:
return "void"
elif item is ctypes.c_void_p:
return "void *"
name = item.__name__
pointer_level = 0
while name.startswith("LP_"):
pointer_level += 1
name = name[3:]
if name.startswith('c_'):
name = name[2:]
if pointer_level > 0:
name += " " + "*"*pointer_level
return name
def _get_ctypes_data(data):
# Get voidp pointer
return ctypes.cast(data, ctypes.c_void_p).value
#
# CFFI helpers
#
def _get_cffi_func(func, signature=None):
# Get function pointer
func_ptr = ffi.cast('uintptr_t', func)
# Get signature
if signature is None:
signature = ffi.getctype(ffi.typeof(func)).replace('(*)', ' ')
return func_ptr, signature
def _get_cffi_data(data):
# Get pointer
return ffi.cast('uintptr_t', data)
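# Usage sketch (editor's illustration; the shared library ``./examples.so``
# and its exported function ``f`` are hypothetical): a ctypes function pointer
# with signature ``double (double, void *)`` can be wrapped and handed to
# `scipy.integrate.quad`.
#
#   >>> import ctypes
#   >>> from scipy import LowLevelCallable, integrate
#   >>> lib = ctypes.CDLL('./examples.so')
#   >>> lib.f.restype = ctypes.c_double
#   >>> lib.f.argtypes = (ctypes.c_double, ctypes.c_void_p)
#   >>> integrate.quad(LowLevelCallable(lib.f), 0.0, 1.0)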


@@ -0,0 +1,254 @@
"""
Disjoint set data structure
"""
class DisjointSet:
""" Disjoint set data structure for incremental connectivity queries.
.. versionadded:: 1.6.0
Attributes
----------
n_subsets : int
The number of subsets.
Methods
-------
add
merge
connected
subset
subset_size
subsets
__getitem__
Notes
-----
This class implements the disjoint set [1]_, also known as the *union-find*
or *merge-find* data structure. The *find* operation (implemented in
`__getitem__`) implements the *path halving* variant. The *merge* method
implements the *merge by size* variant.
References
----------
.. [1] https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Examples
--------
>>> from scipy.cluster.hierarchy import DisjointSet
Initialize a disjoint set:
>>> disjoint_set = DisjointSet([1, 2, 3, 'a', 'b'])
Merge some subsets:
>>> disjoint_set.merge(1, 2)
True
>>> disjoint_set.merge(3, 'a')
True
>>> disjoint_set.merge('a', 'b')
True
>>> disjoint_set.merge('b', 'b')
False
Find root elements:
>>> disjoint_set[2]
1
>>> disjoint_set['b']
3
Test connectivity:
>>> disjoint_set.connected(1, 2)
True
>>> disjoint_set.connected(1, 'b')
False
List elements in disjoint set:
>>> list(disjoint_set)
[1, 2, 3, 'a', 'b']
Get the subset containing 'a':
>>> disjoint_set.subset('a')
{'a', 3, 'b'}
Get the size of the subset containing 'a' (without actually instantiating
the subset):
>>> disjoint_set.subset_size('a')
3
Get all subsets in the disjoint set:
>>> disjoint_set.subsets()
[{1, 2}, {'a', 3, 'b'}]
"""
def __init__(self, elements=None):
self.n_subsets = 0
self._sizes = {}
self._parents = {}
# _nbrs is a circular linked list which links connected elements.
self._nbrs = {}
# _indices tracks the element insertion order in `__iter__`.
self._indices = {}
if elements is not None:
for x in elements:
self.add(x)
def __iter__(self):
"""Returns an iterator of the elements in the disjoint set.
Elements are ordered by insertion order.
"""
return iter(self._indices)
def __len__(self):
return len(self._indices)
def __contains__(self, x):
return x in self._indices
def __getitem__(self, x):
"""Find the root element of `x`.
Parameters
----------
x : hashable object
Input element.
Returns
-------
root : hashable object
Root element of `x`.
"""
if x not in self._indices:
raise KeyError(x)
# find by "path halving"
parents = self._parents
while self._indices[x] != self._indices[parents[x]]:
parents[x] = parents[parents[x]]
x = parents[x]
return x
def add(self, x):
"""Add element `x` to disjoint set
"""
if x in self._indices:
return
self._sizes[x] = 1
self._parents[x] = x
self._nbrs[x] = x
self._indices[x] = len(self._indices)
self.n_subsets += 1
def merge(self, x, y):
"""Merge the subsets of `x` and `y`.
The smaller subset (the child) is merged into the larger subset (the
parent). If the subsets are of equal size, the root element which was
first inserted into the disjoint set is selected as the parent.
Parameters
----------
x, y : hashable object
Elements to merge.
Returns
-------
merged : bool
True if `x` and `y` were in disjoint sets, False otherwise.
"""
xr = self[x]
yr = self[y]
if self._indices[xr] == self._indices[yr]:
return False
sizes = self._sizes
if (sizes[xr], self._indices[yr]) < (sizes[yr], self._indices[xr]):
xr, yr = yr, xr
self._parents[yr] = xr
self._sizes[xr] += self._sizes[yr]
self._nbrs[xr], self._nbrs[yr] = self._nbrs[yr], self._nbrs[xr]
self.n_subsets -= 1
return True
def connected(self, x, y):
"""Test whether `x` and `y` are in the same subset.
Parameters
----------
x, y : hashable object
Elements to test.
Returns
-------
result : bool
True if `x` and `y` are in the same set, False otherwise.
"""
return self._indices[self[x]] == self._indices[self[y]]
def subset(self, x):
"""Get the subset containing `x`.
Parameters
----------
x : hashable object
Input element.
Returns
-------
result : set
Subset containing `x`.
"""
if x not in self._indices:
raise KeyError(x)
result = [x]
nxt = self._nbrs[x]
while self._indices[nxt] != self._indices[x]:
result.append(nxt)
nxt = self._nbrs[nxt]
return set(result)
def subset_size(self, x):
"""Get the size of the subset containing `x`.
Note that this method is faster than ``len(self.subset(x))`` because
the size is directly read off an internal field, without the need to
instantiate the full subset.
Parameters
----------
x : hashable object
Input element.
Returns
-------
result : int
Size of the subset containing `x`.
"""
return self._sizes[self[x]]
def subsets(self):
"""Get all the subsets in the disjoint set.
Returns
-------
result : list
Subsets in the disjoint set.
"""
result = []
visited = set()
for x in self:
if x not in visited:
xset = self.subset(x)
visited.update(xset)
result.append(xset)
return result


@@ -0,0 +1,679 @@
"""Extract reference documentation from the NumPy source tree.
"""
# copied from numpydoc/docscrape.py
import inspect
import textwrap
import re
import pydoc
from warnings import warn
from collections import namedtuple
from collections.abc import Callable, Mapping
import copy
import sys
def strip_blank_lines(l):
"Remove leading and trailing blank lines from a list of lines"
while l and not l[0].strip():
del l[0]
while l and not l[-1].strip():
del l[-1]
return l
class Reader:
"""A line-based string reader.
"""
def __init__(self, data):
"""
Parameters
----------
data : str
String with lines separated by '\\n'.
"""
if isinstance(data, list):
self._str = data
else:
self._str = data.split('\n') # store string as list of lines
self.reset()
def __getitem__(self, n):
return self._str[n]
def reset(self):
self._l = 0 # current line nr
def read(self):
if not self.eof():
out = self[self._l]
self._l += 1
return out
else:
return ''
def seek_next_non_empty_line(self):
for l in self[self._l:]:
if l.strip():
break
else:
self._l += 1
def eof(self):
return self._l >= len(self._str)
def read_to_condition(self, condition_func):
start = self._l
for line in self[start:]:
if condition_func(line):
return self[start:self._l]
self._l += 1
if self.eof():
return self[start:self._l+1]
return []
def read_to_next_empty_line(self):
self.seek_next_non_empty_line()
def is_empty(line):
return not line.strip()
return self.read_to_condition(is_empty)
def read_to_next_unindented_line(self):
def is_unindented(line):
return (line.strip() and (len(line.lstrip()) == len(line)))
return self.read_to_condition(is_unindented)
def peek(self, n=0):
if self._l + n < len(self._str):
return self[self._l + n]
else:
return ''
def is_empty(self):
return not ''.join(self._str).strip()
class ParseError(Exception):
def __str__(self):
message = self.args[0]
if hasattr(self, 'docstring'):
message = f"{message} in {self.docstring!r}"
return message
Parameter = namedtuple('Parameter', ['name', 'type', 'desc'])
class NumpyDocString(Mapping):
"""Parses a numpydoc string to an abstract representation
Instances define a mapping from section title to structured data.
"""
sections = {
'Signature': '',
'Summary': [''],
'Extended Summary': [],
'Parameters': [],
'Returns': [],
'Yields': [],
'Receives': [],
'Raises': [],
'Warns': [],
'Other Parameters': [],
'Attributes': [],
'Methods': [],
'See Also': [],
'Notes': [],
'Warnings': [],
'References': '',
'Examples': '',
'index': {}
}
def __init__(self, docstring, config={}):
orig_docstring = docstring
docstring = textwrap.dedent(docstring).split('\n')
self._doc = Reader(docstring)
self._parsed_data = copy.deepcopy(self.sections)
try:
self._parse()
except ParseError as e:
e.docstring = orig_docstring
raise
def __getitem__(self, key):
return self._parsed_data[key]
def __setitem__(self, key, val):
if key not in self._parsed_data:
self._error_location("Unknown section %s" % key, error=False)
else:
self._parsed_data[key] = val
def __iter__(self):
return iter(self._parsed_data)
def __len__(self):
return len(self._parsed_data)
def _is_at_section(self):
self._doc.seek_next_non_empty_line()
if self._doc.eof():
return False
l1 = self._doc.peek().strip() # e.g. Parameters
if l1.startswith('.. index::'):
return True
l2 = self._doc.peek(1).strip() # ---------- or ==========
return l2.startswith('-'*len(l1)) or l2.startswith('='*len(l1))
def _strip(self, doc):
i = 0
j = 0
for i, line in enumerate(doc):
if line.strip():
break
for j, line in enumerate(doc[::-1]):
if line.strip():
break
return doc[i:len(doc)-j]
def _read_to_next_section(self):
section = self._doc.read_to_next_empty_line()
while not self._is_at_section() and not self._doc.eof():
if not self._doc.peek(-1).strip(): # previous line was empty
section += ['']
section += self._doc.read_to_next_empty_line()
return section
def _read_sections(self):
while not self._doc.eof():
data = self._read_to_next_section()
name = data[0].strip()
if name.startswith('..'): # index section
yield name, data[1:]
elif len(data) < 2:
yield StopIteration
else:
yield name, self._strip(data[2:])
def _parse_param_list(self, content, single_element_is_type=False):
r = Reader(content)
params = []
while not r.eof():
header = r.read().strip()
if ' : ' in header:
arg_name, arg_type = header.split(' : ')[:2]
else:
if single_element_is_type:
arg_name, arg_type = '', header
else:
arg_name, arg_type = header, ''
desc = r.read_to_next_unindented_line()
desc = dedent_lines(desc)
desc = strip_blank_lines(desc)
params.append(Parameter(arg_name, arg_type, desc))
return params
# See also supports the following formats.
#
# <FUNCNAME>
# <FUNCNAME> SPACE* COLON SPACE+ <DESC> SPACE*
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)+ (COMMA | PERIOD)? SPACE*
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)* SPACE* COLON SPACE+ <DESC> SPACE*
# <FUNCNAME> is one of
# <PLAIN_FUNCNAME>
# COLON <ROLE> COLON BACKTICK <PLAIN_FUNCNAME> BACKTICK
# where
# <PLAIN_FUNCNAME> is a legal function name, and
# <ROLE> is any nonempty sequence of word characters.
# Examples: func_f1 :meth:`func_h1` :obj:`~baz.obj_r` :class:`class_j`
# <DESC> is a string describing the function.
_role = r":(?P<role>\w+):"
_funcbacktick = r"`(?P<name>(?:~\w+\.)?[a-zA-Z0-9_\.-]+)`"
_funcplain = r"(?P<name2>[a-zA-Z0-9_\.-]+)"
_funcname = r"(" + _role + _funcbacktick + r"|" + _funcplain + r")"
_funcnamenext = _funcname.replace('role', 'rolenext')
_funcnamenext = _funcnamenext.replace('name', 'namenext')
_description = r"(?P<description>\s*:(\s+(?P<desc>\S+.*))?)?\s*$"
_func_rgx = re.compile(r"^\s*" + _funcname + r"\s*")
_line_rgx = re.compile(
r"^\s*" +
r"(?P<allfuncs>" + # group for all function names
_funcname +
r"(?P<morefuncs>([,]\s+" + _funcnamenext + r")*)" +
r")" + # end of "allfuncs"
# Some function lists have a trailing comma (or period) '\s*'
r"(?P<trailing>[,\.])?" +
_description)
# Empty <DESC> elements are replaced with '..'
empty_description = '..'
def _parse_see_also(self, content):
"""
func_name : Descriptive text
continued text
another_func_name : Descriptive text
func_name1, func_name2, :meth:`func_name`, func_name3
"""
items = []
def parse_item_name(text):
"""Match ':role:`name`' or 'name'."""
m = self._func_rgx.match(text)
if not m:
                raise ParseError("%s is not an item name" % text)
role = m.group('role')
name = m.group('name') if role else m.group('name2')
return name, role, m.end()
rest = []
for line in content:
if not line.strip():
continue
line_match = self._line_rgx.match(line)
description = None
if line_match:
description = line_match.group('desc')
if line_match.group('trailing') and description:
self._error_location(
'Unexpected comma or period after function list at '
'index %d of line "%s"' % (line_match.end('trailing'),
line),
error=False)
if not description and line.startswith(' '):
rest.append(line.strip())
elif line_match:
funcs = []
text = line_match.group('allfuncs')
while True:
if not text.strip():
break
name, role, match_end = parse_item_name(text)
funcs.append((name, role))
text = text[match_end:].strip()
if text and text[0] == ',':
text = text[1:].strip()
rest = list(filter(None, [description]))
items.append((funcs, rest))
else:
                raise ParseError("%s is not an item name" % line)
return items
def _parse_index(self, section, content):
"""
.. index:: default
:refguide: something, else, and more
"""
def strip_each_in(lst):
return [s.strip() for s in lst]
out = {}
section = section.split('::')
if len(section) > 1:
out['default'] = strip_each_in(section[1].split(','))[0]
for line in content:
line = line.split(':')
if len(line) > 2:
out[line[1]] = strip_each_in(line[2].split(','))
return out
def _parse_summary(self):
"""Grab signature (if given) and summary"""
if self._is_at_section():
return
# If several signatures present, take the last one
while True:
summary = self._doc.read_to_next_empty_line()
summary_str = " ".join([s.strip() for s in summary]).strip()
compiled = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$')
if compiled.match(summary_str):
self['Signature'] = summary_str
if not self._is_at_section():
continue
break
if summary is not None:
self['Summary'] = summary
if not self._is_at_section():
self['Extended Summary'] = self._read_to_next_section()
def _parse(self):
self._doc.reset()
self._parse_summary()
sections = list(self._read_sections())
section_names = {section for section, content in sections}
has_returns = 'Returns' in section_names
has_yields = 'Yields' in section_names
# We could do more tests, but we are not. Arbitrarily.
if has_returns and has_yields:
msg = 'Docstring contains both a Returns and Yields section.'
raise ValueError(msg)
if not has_yields and 'Receives' in section_names:
msg = 'Docstring contains a Receives section but not Yields.'
raise ValueError(msg)
for (section, content) in sections:
if not section.startswith('..'):
section = (s.capitalize() for s in section.split(' '))
section = ' '.join(section)
if self.get(section):
self._error_location("The section %s appears twice"
% section)
if section in ('Parameters', 'Other Parameters', 'Attributes',
'Methods'):
self[section] = self._parse_param_list(content)
elif section in ('Returns', 'Yields', 'Raises', 'Warns',
'Receives'):
self[section] = self._parse_param_list(
content, single_element_is_type=True)
elif section.startswith('.. index::'):
self['index'] = self._parse_index(section, content)
elif section == 'See Also':
self['See Also'] = self._parse_see_also(content)
else:
self[section] = content
def _error_location(self, msg, error=True):
if hasattr(self, '_obj'):
# we know where the docs came from:
try:
filename = inspect.getsourcefile(self._obj)
except TypeError:
filename = None
msg = msg + (f" in the docstring of {self._obj} in {filename}.")
if error:
raise ValueError(msg)
else:
warn(msg, stacklevel=3)
# string conversion routines
def _str_header(self, name, symbol='-'):
return [name, len(name)*symbol]
def _str_indent(self, doc, indent=4):
out = []
for line in doc:
out += [' '*indent + line]
return out
def _str_signature(self):
if self['Signature']:
return [self['Signature'].replace('*', r'\*')] + ['']
else:
return ['']
def _str_summary(self):
if self['Summary']:
return self['Summary'] + ['']
else:
return []
def _str_extended_summary(self):
if self['Extended Summary']:
return self['Extended Summary'] + ['']
else:
return []
def _str_param_list(self, name):
out = []
if self[name]:
out += self._str_header(name)
for param in self[name]:
parts = []
if param.name:
parts.append(param.name)
if param.type:
parts.append(param.type)
out += [' : '.join(parts)]
if param.desc and ''.join(param.desc).strip():
out += self._str_indent(param.desc)
out += ['']
return out
def _str_section(self, name):
out = []
if self[name]:
out += self._str_header(name)
out += self[name]
out += ['']
return out
def _str_see_also(self, func_role):
if not self['See Also']:
return []
out = []
out += self._str_header("See Also")
out += ['']
last_had_desc = True
for funcs, desc in self['See Also']:
assert isinstance(funcs, list)
links = []
for func, role in funcs:
if role:
link = f':{role}:`{func}`'
elif func_role:
link = f':{func_role}:`{func}`'
else:
link = "`%s`_" % func
links.append(link)
link = ', '.join(links)
out += [link]
if desc:
out += self._str_indent([' '.join(desc)])
last_had_desc = True
else:
last_had_desc = False
out += self._str_indent([self.empty_description])
if last_had_desc:
out += ['']
out += ['']
return out
def _str_index(self):
idx = self['index']
out = []
output_index = False
default_index = idx.get('default', '')
if default_index:
output_index = True
out += ['.. index:: %s' % default_index]
for section, references in idx.items():
if section == 'default':
continue
output_index = True
out += [' :{}: {}'.format(section, ', '.join(references))]
if output_index:
return out
else:
return ''
def __str__(self, func_role=''):
out = []
out += self._str_signature()
out += self._str_summary()
out += self._str_extended_summary()
for param_list in ('Parameters', 'Returns', 'Yields', 'Receives',
'Other Parameters', 'Raises', 'Warns'):
out += self._str_param_list(param_list)
out += self._str_section('Warnings')
out += self._str_see_also(func_role)
for s in ('Notes', 'References', 'Examples'):
out += self._str_section(s)
for param_list in ('Attributes', 'Methods'):
out += self._str_param_list(param_list)
out += self._str_index()
return '\n'.join(out)
def indent(str, indent=4):
indent_str = ' '*indent
if str is None:
return indent_str
lines = str.split('\n')
return '\n'.join(indent_str + l for l in lines)
def dedent_lines(lines):
"""Deindent a list of lines maximally"""
return textwrap.dedent("\n".join(lines)).split("\n")
def header(text, style='-'):
return text + '\n' + style*len(text) + '\n'
class FunctionDoc(NumpyDocString):
def __init__(self, func, role='func', doc=None, config={}):
self._f = func
self._role = role # e.g. "func" or "meth"
if doc is None:
if func is None:
raise ValueError("No function or docstring given")
doc = inspect.getdoc(func) or ''
NumpyDocString.__init__(self, doc, config)
def get_func(self):
func_name = getattr(self._f, '__name__', self.__class__.__name__)
if inspect.isclass(self._f):
func = getattr(self._f, '__call__', self._f.__init__)
else:
func = self._f
return func, func_name
def __str__(self):
out = ''
func, func_name = self.get_func()
roles = {'func': 'function',
'meth': 'method'}
if self._role:
if self._role not in roles:
print("Warning: invalid role %s" % self._role)
out += '.. {}:: {}\n \n\n'.format(roles.get(self._role, ''),
func_name)
out += super().__str__(func_role=self._role)
return out
class ClassDoc(NumpyDocString):
extra_public_methods = ['__call__']
def __init__(self, cls, doc=None, modulename='', func_doc=FunctionDoc,
config={}):
if not inspect.isclass(cls) and cls is not None:
raise ValueError("Expected a class or None, but got %r" % cls)
self._cls = cls
if 'sphinx' in sys.modules:
from sphinx.ext.autodoc import ALL
else:
ALL = object()
self.show_inherited_members = config.get(
'show_inherited_class_members', True)
if modulename and not modulename.endswith('.'):
modulename += '.'
self._mod = modulename
if doc is None:
if cls is None:
raise ValueError("No class or documentation string given")
doc = pydoc.getdoc(cls)
NumpyDocString.__init__(self, doc)
_members = config.get('members', [])
if _members is ALL:
_members = None
_exclude = config.get('exclude-members', [])
if config.get('show_class_members', True) and _exclude is not ALL:
def splitlines_x(s):
if not s:
return []
else:
return s.splitlines()
for field, items in [('Methods', self.methods),
('Attributes', self.properties)]:
if not self[field]:
doc_list = []
for name in sorted(items):
if (name in _exclude or
(_members and name not in _members)):
continue
try:
doc_item = pydoc.getdoc(getattr(self._cls, name))
doc_list.append(
Parameter(name, '', splitlines_x(doc_item)))
except AttributeError:
pass # method doesn't exist
self[field] = doc_list
@property
def methods(self):
if self._cls is None:
return []
return [name for name, func in inspect.getmembers(self._cls)
if ((not name.startswith('_')
or name in self.extra_public_methods)
and isinstance(func, Callable)
and self._is_show_member(name))]
@property
def properties(self):
if self._cls is None:
return []
return [name for name, func in inspect.getmembers(self._cls)
if (not name.startswith('_') and
(func is None or isinstance(func, property) or
inspect.isdatadescriptor(func))
and self._is_show_member(name))]
def _is_show_member(self, name):
if self.show_inherited_members:
return True # show all class members
if name not in self._cls.__dict__:
return False # class member is inherited, we do not show it
return True
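# Illustrative only (not part of the vendored numpydoc): a minimal sketch of
# rendering an existing NumPy-style docstring with the classes above.
# `numpy.take` is just an arbitrary documented callable chosen for this sketch.
def _example_render_docstring():
    import numpy as np
    return str(FunctionDoc(np.take, role='func'))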

View File

@ -0,0 +1,348 @@
# `_elementwise_iterative_method.py` includes tools for writing functions that
# - are vectorized to work elementwise on arrays,
# - implement non-trivial, iterative algorithms with a callback interface, and
# - return rich objects with iteration count, termination status, etc.
#
# Examples include:
# `scipy.optimize._chandrupatla._chandrupatla` for scalar rootfinding,
# `scipy.optimize._chandrupatla._chandrupatla_minimize` for scalar minimization,
# `scipy.optimize._differentiate._differentiate` for numerical differentiation,
# `scipy.optimize._bracket._bracket_root` for finding rootfinding brackets,
# `scipy.optimize._bracket._bracket_minimize` for finding minimization brackets,
# `scipy.integrate._tanhsinh._tanhsinh` for numerical quadrature.
import math
import numpy as np
from ._util import _RichResult, _call_callback_maybe_halt
from ._array_api import array_namespace, size as xp_size
_ESIGNERR = -1
_ECONVERR = -2
_EVALUEERR = -3
_ECALLBACK = -4
_EINPUTERR = -5
_ECONVERGED = 0
_EINPROGRESS = 1
def _initialize(func, xs, args, complex_ok=False, preserve_shape=None):
"""Initialize abscissa, function, and args arrays for elementwise function
Parameters
----------
func : callable
An elementwise function with signature
func(x: ndarray, *args) -> ndarray
where each element of ``x`` is a finite real and ``args`` is a tuple,
which may contain an arbitrary number of arrays that are broadcastable
with ``x``.
xs : tuple of arrays
Finite real abscissa arrays. Must be broadcastable.
args : tuple, optional
Additional positional arguments to be passed to `func`.
preserve_shape : bool, default: False
When ``preserve_shape=False`` (default), `func` may be passed
arguments of any shape; `_scalar_optimization_loop` is permitted
to reshape and compress arguments at will. When
``preserve_shape=True``, arguments passed to `func` must have shape
`shape` or ``shape + (n,)``, where ``n`` is any integer.
Returns
-------
xs, fs, args : tuple of arrays
Broadcasted, writeable, 1D abscissa and function value arrays (or
NumPy floats, if appropriate). The dtypes of `xs` and `fs` are
`xfat`; the dtypes of `args` are unchanged.
shape : tuple of ints
Original shape of broadcasted arrays.
xfat : NumPy dtype
Result dtype of abscissae, function values, and args determined using
`np.result_type`, except integer types are promoted to `np.float64`.
Raises
------
ValueError
If the result dtype is not that of a real scalar
Notes
-----
Useful for initializing the input of SciPy functions that accept
an elementwise callable, abscissae, and arguments; e.g.
`scipy.optimize._chandrupatla`.
"""
nx = len(xs)
xp = array_namespace(*xs)
# Try to preserve `dtype`, but we need to ensure that the arguments are at
# least floats before passing them into the function; integers can overflow
# and cause failure.
# There might be benefit to combining the `xs` into a single array and
# calling `func` once on the combined array. For now, keep them separate.
xas = xp.broadcast_arrays(*xs, *args) # broadcast and rename
xat = xp.result_type(*[xa.dtype for xa in xas])
xat = xp.asarray(1.).dtype if xp.isdtype(xat, "integral") else xat
xs, args = xas[:nx], xas[nx:]
xs = [xp.asarray(x, dtype=xat) for x in xs] # use copy=False when implemented
fs = [xp.asarray(func(x, *args)) for x in xs]
shape = xs[0].shape
fshape = fs[0].shape
if preserve_shape:
# bind original shape/func now to avoid late-binding gotcha
def func(x, *args, shape=shape, func=func, **kwargs):
i = (0,)*(len(fshape) - len(shape))
return func(x[i], *args, **kwargs)
shape = np.broadcast_shapes(fshape, shape) # just shapes; use of NumPy OK
xs = [xp.broadcast_to(x, shape) for x in xs]
args = [xp.broadcast_to(arg, shape) for arg in args]
message = ("The shape of the array returned by `func` must be the same as "
"the broadcasted shape of `x` and all other `args`.")
if preserve_shape is not None: # only in tanhsinh for now
message = f"When `preserve_shape=False`, {message.lower()}"
shapes_equal = [f.shape == shape for f in fs]
if not all(shapes_equal): # use Python all to reduce overhead
raise ValueError(message)
# These algorithms tend to mix the dtypes of the abscissae and function
# values, so figure out what the result will be and convert them all to
# that type from the outset.
xfat = xp.result_type(*([f.dtype for f in fs] + [xat]))
if not complex_ok and not xp.isdtype(xfat, "real floating"):
raise ValueError("Abscissae and function output must be real numbers.")
xs = [xp.asarray(x, dtype=xfat, copy=True) for x in xs]
fs = [xp.asarray(f, dtype=xfat, copy=True) for f in fs]
# To ensure that we can do indexing, we'll work with at least 1d arrays,
# but remember the appropriate shape of the output.
xs = [xp.reshape(x, (-1,)) for x in xs]
fs = [xp.reshape(f, (-1,)) for f in fs]
args = [xp.reshape(xp.asarray(arg, copy=True), (-1,)) for arg in args]
return func, xs, fs, args, shape, xfat, xp
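# Illustrative only (not part of this module): a minimal sketch of how a
# hypothetical elementwise solver might call `_initialize` with NumPy inputs.
def _example_initialize_usage():
    import numpy as np

    def residual(x, c):                      # elementwise function of `x`
        return x**2 - c

    xs = (np.asarray([1.0, 2.0]),)           # abscissae
    args = (np.asarray(2.0),)                # extra argument, broadcastable with `x`
    func, xs, fs, args, shape, xfat, xp = _initialize(residual, xs, args)
    return shape, xfat                       # (2,) and float64 for these inputs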
def _loop(work, callback, shape, maxiter, func, args, dtype, pre_func_eval,
post_func_eval, check_termination, post_termination_check,
customize_result, res_work_pairs, xp, preserve_shape=False):
"""Main loop of a vectorized scalar optimization algorithm
Parameters
----------
work : _RichResult
All variables that need to be retained between iterations. Must
contain attributes `nit`, `nfev`, and `success`
callback : callable
User-specified callback function
shape : tuple of ints
The shape of all output arrays
maxiter : int
Maximum number of iterations of the algorithm
func : callable
The user-specified callable that is being optimized or solved
args : tuple
Additional positional arguments to be passed to `func`.
dtype : NumPy dtype
The common dtype of all abscissae and function values
pre_func_eval : callable
A function that accepts `work` and returns `x`, the active elements
of `x` at which `func` will be evaluated. May modify attributes
of `work` with any algorithmic steps that need to happen
at the beginning of an iteration, before `func` is evaluated.
post_func_eval : callable
A function that accepts `x`, `func(x)`, and `work`. May modify
attributes of `work` with any algorithmic steps that need to happen
in the middle of an iteration, after `func` is evaluated but before
the termination check.
check_termination : callable
A function that accepts `work` and returns `stop`, a boolean array
indicating which of the active elements have met a termination
condition.
post_termination_check : callable
A function that accepts `work`. May modify `work` with any algorithmic
steps that need to happen after the termination check and before the
end of the iteration.
customize_result : callable
A function that accepts `res` and `shape` and returns `shape`. May
modify `res` (in-place) according to preferences (e.g. rearrange
elements between attributes) and modify `shape` if needed.
res_work_pairs : list of (str, str)
Identifies correspondence between attributes of `res` and attributes
of `work`; i.e., attributes of active elements of `work` will be
copied to the appropriate indices of `res` when appropriate. The order
determines the order in which _RichResult attributes will be
pretty-printed.
Returns
-------
res : _RichResult
The final result object
Notes
-----
Besides providing structure, this framework provides several important
services for a vectorized optimization algorithm.
- It handles common tasks involving iteration count, function evaluation
count, a user-specified callback, and associated termination conditions.
- It compresses the attributes of `work` to eliminate unnecessary
computation on elements that have already converged.
"""
if xp is None:
raise NotImplementedError("Must provide xp.")
cb_terminate = False
# Initialize the result object and active element index array
n_elements = math.prod(shape)
active = xp.arange(n_elements) # in-progress element indices
res_dict = {i: xp.zeros(n_elements, dtype=dtype) for i, j in res_work_pairs}
res_dict['success'] = xp.zeros(n_elements, dtype=xp.bool)
res_dict['status'] = xp.full(n_elements, _EINPROGRESS, dtype=xp.int32)
res_dict['nit'] = xp.zeros(n_elements, dtype=xp.int32)
res_dict['nfev'] = xp.zeros(n_elements, dtype=xp.int32)
res = _RichResult(res_dict)
work.args = args
active = _check_termination(work, res, res_work_pairs, active,
check_termination, preserve_shape, xp)
if callback is not None:
temp = _prepare_result(work, res, res_work_pairs, active, shape,
customize_result, preserve_shape, xp)
if _call_callback_maybe_halt(callback, temp):
cb_terminate = True
while work.nit < maxiter and xp_size(active) and not cb_terminate and n_elements:
x = pre_func_eval(work)
if work.args and work.args[0].ndim != x.ndim:
# `x` always starts as 1D. If the SciPy function that uses
# _loop added dimensions to `x`, we need to
# add them to the elements of `args`.
args = []
for arg in work.args:
n_new_dims = x.ndim - arg.ndim
new_shape = arg.shape + (1,)*n_new_dims
args.append(xp.reshape(arg, new_shape))
work.args = args
x_shape = x.shape
if preserve_shape:
x = xp.reshape(x, (shape + (-1,)))
f = func(x, *work.args)
f = xp.asarray(f, dtype=dtype)
if preserve_shape:
x = xp.reshape(x, x_shape)
f = xp.reshape(f, x_shape)
work.nfev += 1 if x.ndim == 1 else x.shape[-1]
post_func_eval(x, f, work)
work.nit += 1
active = _check_termination(work, res, res_work_pairs, active,
check_termination, preserve_shape, xp)
if callback is not None:
temp = _prepare_result(work, res, res_work_pairs, active, shape,
customize_result, preserve_shape, xp)
if _call_callback_maybe_halt(callback, temp):
cb_terminate = True
break
if xp_size(active) == 0:
break
post_termination_check(work)
work.status[:] = _ECALLBACK if cb_terminate else _ECONVERR
return _prepare_result(work, res, res_work_pairs, active, shape,
customize_result, preserve_shape, xp)
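# Illustrative only (hypothetical names, not SciPy API): for a bisection-style
# solver, the callables wired into `_loop` would have roughly these roles:
#
#     def pre_func_eval(work):          # abscissae to evaluate this iteration
#         return (work.xl + work.xr) / 2
#
#     def post_func_eval(x, f, work):   # fold f(x) back into the bracket state
#         ...
#
#     def check_termination(work):      # elementwise "done" flags
#         return (work.xr - work.xl) < work.tol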
def _check_termination(work, res, res_work_pairs, active, check_termination,
preserve_shape, xp):
# Checks termination conditions, updates elements of `res` with
# corresponding elements of `work`, and compresses `work`.
stop = check_termination(work)
if xp.any(stop):
# update the active elements of the result object with the active
# elements for which a termination condition has been met
_update_active(work, res, res_work_pairs, active, stop, preserve_shape, xp)
if preserve_shape:
stop = stop[active]
proceed = ~stop
active = active[proceed]
if not preserve_shape:
# compress the arrays to avoid unnecessary computation
for key, val in work.items():
# Need to find a better way than these try/excepts
# Somehow need to keep compressible numerical args separate
if key == 'args':
continue
try:
work[key] = val[proceed]
except (IndexError, TypeError, KeyError): # not a compressible array
work[key] = val
work.args = [arg[proceed] for arg in work.args]
return active
def _update_active(work, res, res_work_pairs, active, mask, preserve_shape, xp):
# Update `active` indices of the arrays in result object `res` with the
# contents of the scalars and arrays in `update_dict`. When provided,
# `mask` is a boolean array applied both to the arrays in `update_dict`
# that are to be used and to the arrays in `res` that are to be updated.
update_dict = {key1: work[key2] for key1, key2 in res_work_pairs}
update_dict['success'] = work.status == 0
if mask is not None:
if preserve_shape:
active_mask = xp.zeros_like(mask)
active_mask[active] = 1
active_mask = active_mask & mask
for key, val in update_dict.items():
try:
res[key][active_mask] = val[active_mask]
except (IndexError, TypeError, KeyError):
res[key][active_mask] = val
else:
active_mask = active[mask]
for key, val in update_dict.items():
try:
res[key][active_mask] = val[mask]
except (IndexError, TypeError, KeyError):
res[key][active_mask] = val
else:
for key, val in update_dict.items():
if preserve_shape:
try:
val = val[active]
except (IndexError, TypeError, KeyError):
pass
res[key][active] = val
def _prepare_result(work, res, res_work_pairs, active, shape, customize_result,
preserve_shape, xp):
# Prepare the result object `res` by creating a copy, copying the latest
# data from work, running the provided result customization function,
# and reshaping the data to the original shapes.
res = res.copy()
_update_active(work, res, res_work_pairs, active, None, preserve_shape, xp)
shape = customize_result(res, shape)
for key, val in res.items():
# this looks like it won't work for xp != np if val is not numeric
temp = xp.reshape(val, shape)
res[key] = temp[()] if temp.ndim == 0 else temp
res['_order_keys'] = ['success'] + [i for i, j in res_work_pairs]
return _RichResult(**res)

View File

@ -0,0 +1,145 @@
from numpy import arange, newaxis, hstack, prod, array
def _central_diff_weights(Np, ndiv=1):
"""
Return weights for an Np-point central derivative.
Assumes equally-spaced function points.
If weights are in the vector w, then
the derivative is w[0] * f(x-ho*dx) + ... + w[-1] * f(x+ho*dx)
Parameters
----------
Np : int
Number of points for the central derivative.
ndiv : int, optional
Number of divisions. Default is 1.
Returns
-------
w : ndarray
Weights for an Np-point central derivative. Its size is `Np`.
Notes
-----
Can be inaccurate for a large number of points.
Examples
--------
We can calculate a derivative value of a function.
>>> def f(x):
... return 2 * x**2 + 3
>>> x = 3.0 # derivative point
>>> h = 0.1 # differential step
>>> Np = 3 # point number for central derivative
>>> weights = _central_diff_weights(Np) # weights for first derivative
>>> vals = [f(x + (i - Np/2) * h) for i in range(Np)]
>>> sum(w * v for (w, v) in zip(weights, vals))/h
11.79999999999998
This value is close to the analytical solution:
f'(x) = 4x, so f'(3) = 12
References
----------
.. [1] https://en.wikipedia.org/wiki/Finite_difference
"""
if Np < ndiv + 1:
raise ValueError(
"Number of points must be at least the derivative order + 1."
)
if Np % 2 == 0:
raise ValueError("The number of points must be odd.")
from scipy import linalg
ho = Np >> 1
x = arange(-ho, ho + 1.0)
x = x[:, newaxis]
X = x**0.0
for k in range(1, Np):
X = hstack([X, x**k])
w = prod(arange(1, ndiv + 1), axis=0) * linalg.inv(X)[ndiv]
return w
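# Illustrative cross-check (not part of this module): the solved weights agree
# with the hard-coded 5-point first-derivative stencil used in `_derivative`.
def _example_weights_check():
    import numpy as np
    w = _central_diff_weights(5, ndiv=1)
    return np.allclose(w, np.array([1, -8, 0, 8, -1]) / 12.0)  # True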
def _derivative(func, x0, dx=1.0, n=1, args=(), order=3):
"""
Find the nth derivative of a function at a point.
Given a function, use a central difference formula with spacing `dx` to
compute the nth derivative at `x0`.
Parameters
----------
func : function
Input function.
x0 : float
The point at which the nth derivative is found.
dx : float, optional
Spacing.
n : int, optional
Order of the derivative. Default is 1.
args : tuple, optional
Arguments passed to `func`.
order : int, optional
Number of points to use, must be odd.
Notes
-----
Decreasing the step size too far can result in round-off error.
Examples
--------
>>> def f(x):
... return x**3 + x**2
>>> _derivative(f, 1.0, dx=1e-6)
4.9999999999217337
"""
if order < n + 1:
raise ValueError(
"'order' (the number of points used to compute the derivative), "
"must be at least the derivative order 'n' + 1."
)
if order % 2 == 0:
raise ValueError(
"'order' (the number of points used to compute the derivative) "
"must be odd."
)
# pre-computed for n=1 and 2 and low-order for speed.
if n == 1:
if order == 3:
weights = array([-1, 0, 1]) / 2.0
elif order == 5:
weights = array([1, -8, 0, 8, -1]) / 12.0
elif order == 7:
weights = array([-1, 9, -45, 0, 45, -9, 1]) / 60.0
elif order == 9:
weights = array([3, -32, 168, -672, 0, 672, -168, 32, -3]) / 840.0
else:
weights = _central_diff_weights(order, 1)
elif n == 2:
if order == 3:
weights = array([1, -2.0, 1])
elif order == 5:
weights = array([-1, 16, -30, 16, -1]) / 12.0
elif order == 7:
weights = array([2, -27, 270, -490, 270, -27, 2]) / 180.0
elif order == 9:
weights = (
array([-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9])
/ 5040.0
)
else:
weights = _central_diff_weights(order, 2)
else:
weights = _central_diff_weights(order, n)
val = 0.0
ho = order >> 1
for k in range(order):
val += weights[k] * func(x0 + (k - ho) * dx, *args)
return val / prod((dx,) * n, axis=0)
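# Illustrative only: a hypothetical check of the second-derivative path with
# the docstring's polynomial; f''(x) = 6*x + 2, so f''(1) = 8.
def _example_second_derivative():
    def f(x):
        return x**3 + x**2
    return _derivative(f, 1.0, dx=1e-3, n=2, order=5)  # approximately 8.0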

View File

@ -0,0 +1,105 @@
"""
Module for testing automatic garbage collection of objects
.. autosummary::
:toctree: generated/
set_gc_state - enable or disable garbage collection
gc_state - context manager for given state of garbage collector
assert_deallocated - context manager to check for circular references on object
"""
import weakref
import gc
from contextlib import contextmanager
from platform import python_implementation
__all__ = ['set_gc_state', 'gc_state', 'assert_deallocated']
IS_PYPY = python_implementation() == 'PyPy'
class ReferenceError(AssertionError):
pass
def set_gc_state(state):
""" Set status of garbage collector """
if gc.isenabled() == state:
return
if state:
gc.enable()
else:
gc.disable()
@contextmanager
def gc_state(state):
""" Context manager to set state of garbage collector to `state`
Parameters
----------
state : bool
True for gc enabled, False for disabled
Examples
--------
>>> with gc_state(False):
... assert not gc.isenabled()
>>> with gc_state(True):
... assert gc.isenabled()
"""
orig_state = gc.isenabled()
set_gc_state(state)
yield
set_gc_state(orig_state)
@contextmanager
def assert_deallocated(func, *args, **kwargs):
"""Context manager to check that object is deallocated
This is useful for checking that an object can be freed directly by
reference counting, without requiring gc to break reference cycles.
GC is disabled inside the context manager.
This check is not available on PyPy.
Parameters
----------
func : callable
Callable to create object to check
\\*args : sequence
positional arguments to `func` in order to create object to check
\\*\\*kwargs : dict
keyword arguments to `func` in order to create object to check
Examples
--------
>>> class C: pass
>>> with assert_deallocated(C) as c:
... # do something
... del c
>>> class C:
... def __init__(self):
... self._circular = self # Make circular reference
>>> with assert_deallocated(C) as c: #doctest: +IGNORE_EXCEPTION_DETAIL
... # do something
... del c
Traceback (most recent call last):
...
ReferenceError: Remaining reference(s) to object
"""
if IS_PYPY:
raise RuntimeError("assert_deallocated is unavailable on PyPy")
with gc_state(False):
obj = func(*args, **kwargs)
ref = weakref.ref(obj)
yield obj
del obj
if ref() is not None:
raise ReferenceError("Remaining reference(s) to object")

View File

@ -0,0 +1,487 @@
"""Utility to compare pep440 compatible version strings.
The LooseVersion and StrictVersion classes that distutils provides don't
work; they don't recognize anything like alpha/beta/rc/dev versions.
"""
# Copyright (c) Donald Stufft and individual contributors.
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import collections
import itertools
import re
__all__ = [
"parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN",
]
# BEGIN packaging/_structures.py
class Infinity:
def __repr__(self):
return "Infinity"
def __hash__(self):
return hash(repr(self))
def __lt__(self, other):
return False
def __le__(self, other):
return False
def __eq__(self, other):
return isinstance(other, self.__class__)
def __ne__(self, other):
return not isinstance(other, self.__class__)
def __gt__(self, other):
return True
def __ge__(self, other):
return True
def __neg__(self):
return NegativeInfinity
Infinity = Infinity()
class NegativeInfinity:
def __repr__(self):
return "-Infinity"
def __hash__(self):
return hash(repr(self))
def __lt__(self, other):
return True
def __le__(self, other):
return True
def __eq__(self, other):
return isinstance(other, self.__class__)
def __ne__(self, other):
return not isinstance(other, self.__class__)
def __gt__(self, other):
return False
def __ge__(self, other):
return False
def __neg__(self):
return Infinity
# BEGIN packaging/version.py
NegativeInfinity = NegativeInfinity()
_Version = collections.namedtuple(
"_Version",
["epoch", "release", "dev", "pre", "post", "local"],
)
def parse(version):
"""
Parse the given version string and return either a :class:`Version` object
or a :class:`LegacyVersion` object depending on if the given version is
a valid PEP 440 version or a legacy version.
"""
try:
return Version(version)
except InvalidVersion:
return LegacyVersion(version)
class InvalidVersion(ValueError):
"""
An invalid version was found; users should refer to PEP 440.
"""
class _BaseVersion:
def __hash__(self):
return hash(self._key)
def __lt__(self, other):
return self._compare(other, lambda s, o: s < o)
def __le__(self, other):
return self._compare(other, lambda s, o: s <= o)
def __eq__(self, other):
return self._compare(other, lambda s, o: s == o)
def __ge__(self, other):
return self._compare(other, lambda s, o: s >= o)
def __gt__(self, other):
return self._compare(other, lambda s, o: s > o)
def __ne__(self, other):
return self._compare(other, lambda s, o: s != o)
def _compare(self, other, method):
if not isinstance(other, _BaseVersion):
return NotImplemented
return method(self._key, other._key)
class LegacyVersion(_BaseVersion):
def __init__(self, version):
self._version = str(version)
self._key = _legacy_cmpkey(self._version)
def __str__(self):
return self._version
def __repr__(self):
return f"<LegacyVersion({repr(str(self))})>"
@property
def public(self):
return self._version
@property
def base_version(self):
return self._version
@property
def local(self):
return None
@property
def is_prerelease(self):
return False
@property
def is_postrelease(self):
return False
_legacy_version_component_re = re.compile(
r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE,
)
_legacy_version_replacement_map = {
"pre": "c", "preview": "c", "-": "final-", "rc": "c", "dev": "@",
}
def _parse_version_parts(s):
for part in _legacy_version_component_re.split(s):
part = _legacy_version_replacement_map.get(part, part)
if not part or part == ".":
continue
if part[:1] in "0123456789":
# pad for numeric comparison
yield part.zfill(8)
else:
yield "*" + part
# ensure that alpha/beta/candidate are before final
yield "*final"
def _legacy_cmpkey(version):
# We hardcode an epoch of -1 here. A PEP 440 version can only have an epoch
# greater than or equal to 0. This will effectively put the LegacyVersion,
# which uses the de facto standard originally implemented by setuptools,
# before all PEP 440 versions.
epoch = -1
# This scheme is taken from pkg_resources.parse_version of setuptools, prior
# to its adoption of the packaging library.
parts = []
for part in _parse_version_parts(version.lower()):
if part.startswith("*"):
# remove "-" before a prerelease tag
if part < "*final":
while parts and parts[-1] == "*final-":
parts.pop()
# remove trailing zeros from each series of numeric parts
while parts and parts[-1] == "00000000":
parts.pop()
parts.append(part)
parts = tuple(parts)
return epoch, parts
# Deliberately not anchored to the start and end of the string, to make it
# easier for 3rd party code to reuse
VERSION_PATTERN = r"""
v?
(?:
(?:(?P<epoch>[0-9]+)!)? # epoch
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
(?P<pre> # pre-release
[-_\.]?
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
[-_\.]?
(?P<pre_n>[0-9]+)?
)?
(?P<post> # post release
(?:-(?P<post_n1>[0-9]+))
|
(?:
[-_\.]?
(?P<post_l>post|rev|r)
[-_\.]?
(?P<post_n2>[0-9]+)?
)
)?
(?P<dev> # dev release
[-_\.]?
(?P<dev_l>dev)
[-_\.]?
(?P<dev_n>[0-9]+)?
)?
)
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
"""
class Version(_BaseVersion):
_regex = re.compile(
r"^\s*" + VERSION_PATTERN + r"\s*$",
re.VERBOSE | re.IGNORECASE,
)
def __init__(self, version):
# Validate the version and parse it into pieces
match = self._regex.search(version)
if not match:
raise InvalidVersion(f"Invalid version: '{version}'")
# Store the parsed out pieces of the version
self._version = _Version(
epoch=int(match.group("epoch")) if match.group("epoch") else 0,
release=tuple(int(i) for i in match.group("release").split(".")),
pre=_parse_letter_version(
match.group("pre_l"),
match.group("pre_n"),
),
post=_parse_letter_version(
match.group("post_l"),
match.group("post_n1") or match.group("post_n2"),
),
dev=_parse_letter_version(
match.group("dev_l"),
match.group("dev_n"),
),
local=_parse_local_version(match.group("local")),
)
# Generate a key which will be used for sorting
self._key = _cmpkey(
self._version.epoch,
self._version.release,
self._version.pre,
self._version.post,
self._version.dev,
self._version.local,
)
def __repr__(self):
return f"<Version({repr(str(self))})>"
def __str__(self):
parts = []
# Epoch
if self._version.epoch != 0:
parts.append(f"{self._version.epoch}!")
# Release segment
parts.append(".".join(str(x) for x in self._version.release))
# Pre-release
if self._version.pre is not None:
parts.append("".join(str(x) for x in self._version.pre))
# Post-release
if self._version.post is not None:
parts.append(f".post{self._version.post[1]}")
# Development release
if self._version.dev is not None:
parts.append(f".dev{self._version.dev[1]}")
# Local version segment
if self._version.local is not None:
parts.append(
"+{}".format(".".join(str(x) for x in self._version.local))
)
return "".join(parts)
@property
def public(self):
return str(self).split("+", 1)[0]
@property
def base_version(self):
parts = []
# Epoch
if self._version.epoch != 0:
parts.append(f"{self._version.epoch}!")
# Release segment
parts.append(".".join(str(x) for x in self._version.release))
return "".join(parts)
@property
def local(self):
version_string = str(self)
if "+" in version_string:
return version_string.split("+", 1)[1]
@property
def is_prerelease(self):
return bool(self._version.dev or self._version.pre)
@property
def is_postrelease(self):
return bool(self._version.post)
def _parse_letter_version(letter, number):
if letter:
# We assume there is an implicit 0 in a pre-release if there is
# no numeral associated with it.
if number is None:
number = 0
# We normalize any letters to their lower-case form
letter = letter.lower()
# We consider some words to be alternate spellings of other words and
# in those cases we want to normalize the spellings to our preferred
# spelling.
if letter == "alpha":
letter = "a"
elif letter == "beta":
letter = "b"
elif letter in ["c", "pre", "preview"]:
letter = "rc"
elif letter in ["rev", "r"]:
letter = "post"
return letter, int(number)
if not letter and number:
# We assume that if we are given a number but not given a letter,
# then this is using the implicit post release syntax (e.g., 1.0-1)
letter = "post"
return letter, int(number)
_local_version_seperators = re.compile(r"[\._-]")
def _parse_local_version(local):
"""
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
"""
if local is not None:
return tuple(
part.lower() if not part.isdigit() else int(part)
for part in _local_version_seperators.split(local)
)
def _cmpkey(epoch, release, pre, post, dev, local):
# When we compare a release version, we want to compare it with all of the
# trailing zeros removed. So we'll reverse the list, drop all the now
# leading zeros until we come to something non-zero, then take the rest,
# re-reverse it back into the correct order, and make it a tuple and use
# that for our sorting key.
release = tuple(
reversed(list(
itertools.dropwhile(
lambda x: x == 0,
reversed(release),
)
))
)
# We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
# We'll do this by abusing the pre-segment, but we _only_ want to do this
# if there is no pre- or a post-segment. If we have one of those, then
# the normal sorting rules will handle this case correctly.
if pre is None and post is None and dev is not None:
pre = -Infinity
# Versions without a pre-release (except as noted above) should sort after
# those with one.
elif pre is None:
pre = Infinity
# Versions without a post-segment should sort before those with one.
if post is None:
post = -Infinity
# Versions without a development segment should sort after those with one.
if dev is None:
dev = Infinity
if local is None:
# Versions without a local segment should sort before those with one.
local = -Infinity
else:
# Versions with a local segment need that segment parsed to implement
# the sorting rules in PEP440.
# - Alphanumeric segments sort before numeric segments
# - Alphanumeric segments sort lexicographically
# - Numeric segments sort numerically
# - Shorter versions sort before longer versions when the prefixes
# match exactly
local = tuple(
(i, "") if isinstance(i, int) else (-Infinity, i)
for i in local
)
return epoch, release, pre, post, dev, local
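# Illustrative only (not part of the vendored module): the key construction
# above yields the PEP 440 ordering dev < pre < final < post.
def _example_version_ordering():
    versions = [parse(s) for s in ("1.0.dev0", "1.0a1", "1.0", "1.0.post1")]
    return versions == sorted(versions)  # True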

View File

@ -0,0 +1,337 @@
"""
Generic test utilities.
"""
import inspect
import os
import re
import shutil
import subprocess
import sys
import sysconfig
from importlib.util import module_from_spec, spec_from_file_location
import numpy as np
import scipy
try:
# Need type: ignore[import-untyped] for mypy >= 1.6
import cython # type: ignore[import-untyped]
from Cython.Compiler.Version import ( # type: ignore[import-untyped]
version as cython_version,
)
except ImportError:
cython = None
else:
from scipy._lib import _pep440
required_version = '3.0.8'
if _pep440.parse(cython_version) < _pep440.Version(required_version):
# too old or wrong cython, skip Cython API tests
cython = None
__all__ = ['PytestTester', 'check_free_memory', '_TestPythranFunc', 'IS_MUSL']
IS_MUSL = False
# alternate way is
# from packaging.tags import sys_tags
# _tags = list(sys_tags())
# if 'musllinux' in _tags[0].platform:
_v = sysconfig.get_config_var('HOST_GNU_TYPE') or ''
if 'musl' in _v:
IS_MUSL = True
IS_EDITABLE = 'editable' in scipy.__path__[0]
class FPUModeChangeWarning(RuntimeWarning):
"""Warning about FPU mode change"""
pass
class PytestTester:
"""
Run tests for this namespace
``scipy.test()`` runs tests for all of SciPy, with the default settings.
When used from a submodule (e.g., ``scipy.cluster.test()``), only the tests
for that namespace are run.
Parameters
----------
label : {'fast', 'full'}, optional
Whether to run only the fast tests, or also those marked as slow.
Default is 'fast'.
verbose : int, optional
Test output verbosity. Default is 1.
extra_argv : list, optional
Arguments to pass through to Pytest.
doctests : bool, optional
Whether to run doctests or not. Default is False.
coverage : bool, optional
Whether to run tests with code coverage measurements enabled.
Default is False.
tests : list of str, optional
List of module names to run tests for. By default, uses the module
from which the ``test`` function is called.
parallel : int, optional
Run tests in parallel with pytest-xdist, if the number given is larger than
1. Default is 1.
"""
def __init__(self, module_name):
self.module_name = module_name
def __call__(self, label="fast", verbose=1, extra_argv=None, doctests=False,
coverage=False, tests=None, parallel=None):
import pytest
module = sys.modules[self.module_name]
module_path = os.path.abspath(module.__path__[0])
pytest_args = ['--showlocals', '--tb=short']
if doctests:
pytest_args += [
"--doctest-modules",
"--ignore=scipy/interpolate/_interpnd_info.py",
"--ignore=scipy/_lib/array_api_compat",
"--ignore=scipy/_lib/highs",
"--ignore=scipy/_lib/unuran",
"--ignore=scipy/_lib/_gcutils.py",
"--ignore=scipy/_lib/doccer.py",
"--ignore=scipy/_lib/_uarray",
]
if extra_argv:
pytest_args += list(extra_argv)
if verbose and int(verbose) > 1:
pytest_args += ["-" + "v"*(int(verbose)-1)]
if coverage:
pytest_args += ["--cov=" + module_path]
if label == "fast":
pytest_args += ["-m", "not slow"]
elif label != "full":
pytest_args += ["-m", label]
if tests is None:
tests = [self.module_name]
if parallel is not None and parallel > 1:
if _pytest_has_xdist():
pytest_args += ['-n', str(parallel)]
else:
import warnings
warnings.warn('Could not run tests in parallel because '
'pytest-xdist plugin is not available.',
stacklevel=2)
pytest_args += ['--pyargs'] + list(tests)
try:
code = pytest.main(pytest_args)
except SystemExit as exc:
code = exc.code
return (code == 0)
class _TestPythranFunc:
'''
These are situations that can be tested in our pythran tests:
- A function with multiple array arguments and then
other positional and keyword arguments.
- A function with array-like keywords (e.g. `def somefunc(x0, x1=None)`).
Note: list/tuple input is not yet tested!
`self.arguments`: A dictionary whose keys are argument indices and whose
values are tuples of ``(array value, all supported dtypes)``.
`self.partialfunc`: A function used to freeze non-array arguments that are
of no interest in the original function.
'''
ALL_INTEGER = [np.int8, np.int16, np.int32, np.int64, np.intc, np.intp]
ALL_FLOAT = [np.float32, np.float64]
ALL_COMPLEX = [np.complex64, np.complex128]
def setup_method(self):
self.arguments = {}
self.partialfunc = None
self.expected = None
def get_optional_args(self, func):
# get optional arguments with its default value,
# used for testing keywords
signature = inspect.signature(func)
optional_args = {}
for k, v in signature.parameters.items():
if v.default is not inspect.Parameter.empty:
optional_args[k] = v.default
return optional_args
def get_max_dtype_list_length(self):
# get the max supported dtypes list length in all arguments
max_len = 0
for arg_idx in self.arguments:
cur_len = len(self.arguments[arg_idx][1])
if cur_len > max_len:
max_len = cur_len
return max_len
def get_dtype(self, dtype_list, dtype_idx):
# get the dtype from dtype_list via index
# if the index is out of range, then return the last dtype
if dtype_idx > len(dtype_list)-1:
return dtype_list[-1]
else:
return dtype_list[dtype_idx]
def test_all_dtypes(self):
for type_idx in range(self.get_max_dtype_list_length()):
args_array = []
for arg_idx in self.arguments:
new_dtype = self.get_dtype(self.arguments[arg_idx][1],
type_idx)
args_array.append(self.arguments[arg_idx][0].astype(new_dtype))
self.pythranfunc(*args_array)
def test_views(self):
args_array = []
for arg_idx in self.arguments:
args_array.append(self.arguments[arg_idx][0][::-1][::-1])
self.pythranfunc(*args_array)
def test_strided(self):
args_array = []
for arg_idx in self.arguments:
args_array.append(np.repeat(self.arguments[arg_idx][0],
2, axis=0)[::2])
self.pythranfunc(*args_array)
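# Illustrative only (not a real test): a hypothetical subclass wiring a single
# array argument into the dtype/view/stride checks above; `np.cumsum` merely
# stands in for a pythranized kernel.
class _ExampleCumsumChecks(_TestPythranFunc):
    def setup_method(self):
        super().setup_method()
        self.arguments = {0: (np.arange(4.0), self.ALL_FLOAT)}
        self.pythranfunc = np.cumsum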
def _pytest_has_xdist():
"""
Check if the pytest-xdist plugin is installed, providing parallel tests
"""
# Check xdist exists without importing, otherwise pytest emits warnings
from importlib.util import find_spec
return find_spec('xdist') is not None
def check_free_memory(free_mb):
"""
Check that *free_mb* MB of memory is available; otherwise, skip via pytest.skip
"""
import pytest
try:
mem_free = _parse_size(os.environ['SCIPY_AVAILABLE_MEM'])
msg = '{} MB memory required, but environment SCIPY_AVAILABLE_MEM={}'.format(
free_mb, os.environ['SCIPY_AVAILABLE_MEM'])
except KeyError:
mem_free = _get_mem_available()
if mem_free is None:
pytest.skip("Could not determine available memory; set SCIPY_AVAILABLE_MEM "
"variable to free memory in MB to run the test.")
msg = f'{free_mb} MB memory required, but {mem_free/1e6} MB available'
if mem_free < free_mb * 1e6:
pytest.skip(msg)
def _parse_size(size_str):
suffixes = {'': 1e6,
'b': 1.0,
'k': 1e3, 'M': 1e6, 'G': 1e9, 'T': 1e12,
'kb': 1e3, 'Mb': 1e6, 'Gb': 1e9, 'Tb': 1e12,
'kib': 1024.0, 'Mib': 1024.0**2, 'Gib': 1024.0**3, 'Tib': 1024.0**4}
m = re.match(r'^\s*(\d+)\s*({})\s*$'.format('|'.join(suffixes.keys())),
size_str,
re.I)
if not m or m.group(2) not in suffixes:
raise ValueError("Invalid size string")
return float(m.group(1)) * suffixes[m.group(2)]
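# Illustrative only: decimal vs. binary suffixes as handled above.
def _example_parse_size():
    return _parse_size('512M'), _parse_size('512Mib')  # (512000000.0, 536870912.0)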
def _get_mem_available():
"""
Get information about memory available, not counting swap.
"""
try:
import psutil
return psutil.virtual_memory().available
except (ImportError, AttributeError):
pass
if sys.platform.startswith('linux'):
info = {}
with open('/proc/meminfo') as f:
for line in f:
p = line.split()
info[p[0].strip(':').lower()] = float(p[1]) * 1e3
if 'memavailable' in info:
# Linux >= 3.14
return info['memavailable']
else:
return info['memfree'] + info['cached']
return None
def _test_cython_extension(tmp_path, srcdir):
"""
Helper function to test building and importing Cython modules that
make use of the Cython APIs for BLAS, LAPACK, optimize, and special.
"""
import pytest
try:
subprocess.check_call(["meson", "--version"])
except FileNotFoundError:
pytest.skip("No usable 'meson' found")
# build the examples in a temporary directory
mod_name = os.path.split(srcdir)[1]
shutil.copytree(srcdir, tmp_path / mod_name)
build_dir = tmp_path / mod_name / 'tests' / '_cython_examples'
target_dir = build_dir / 'build'
os.makedirs(target_dir, exist_ok=True)
# Ensure we use the correct Python interpreter even when `meson` is
# installed in a different Python environment (see numpy#24956)
native_file = str(build_dir / 'interpreter-native-file.ini')
with open(native_file, 'w') as f:
f.write("[binaries]\n")
f.write(f"python = '{sys.executable}'")
if sys.platform == "win32":
subprocess.check_call(["meson", "setup",
"--buildtype=release",
"--native-file", native_file,
"--vsenv", str(build_dir)],
cwd=target_dir,
)
else:
subprocess.check_call(["meson", "setup",
"--native-file", native_file, str(build_dir)],
cwd=target_dir
)
subprocess.check_call(["meson", "compile", "-vv"], cwd=target_dir)
# import without adding the directory to sys.path
suffix = sysconfig.get_config_var('EXT_SUFFIX')
def load(modname):
so = (target_dir / modname).with_suffix(suffix)
spec = spec_from_file_location(modname, so)
mod = module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
# test that the module can be imported
return load("extending"), load("extending_cpp")

View File

@ -0,0 +1,58 @@
import threading
import scipy._lib.decorator
__all__ = ['ReentrancyError', 'ReentrancyLock', 'non_reentrant']
class ReentrancyError(RuntimeError):
pass
class ReentrancyLock:
"""
Threading lock that raises an exception for reentrant calls.
Calls from different threads are serialized, and nested calls from the
same thread result in an error.
The object can be used as a context manager or to decorate functions
via the decorate() method.
"""
def __init__(self, err_msg):
self._rlock = threading.RLock()
self._entered = False
self._err_msg = err_msg
def __enter__(self):
self._rlock.acquire()
if self._entered:
self._rlock.release()
raise ReentrancyError(self._err_msg)
self._entered = True
def __exit__(self, type, value, traceback):
self._entered = False
self._rlock.release()
def decorate(self, func):
def caller(func, *a, **kw):
with self:
return func(*a, **kw)
return scipy._lib.decorator.decorate(func, caller)
def non_reentrant(err_msg=None):
"""
Decorate a function with a threading lock and prevent reentrant calls.
"""
def decorator(func):
msg = err_msg
if msg is None:
msg = "%s is not re-entrant" % func.__name__
lock = ReentrancyLock(msg)
return lock.decorate(func)
return decorator
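# Illustrative only (not part of this module): the two usage styles mentioned
# in the ReentrancyLock docstring; the `_example_*` names are made up.
_example_lock = ReentrancyLock("_example_op_direct is not re-entrant")

def _example_op_direct(x):
    with _example_lock:               # context-manager style
        return x + 1

@non_reentrant("_example_op_decorated is not re-entrant")
def _example_op_decorated(x):         # decorator style
    return x + 1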

View File

@ -0,0 +1,86 @@
''' Contexts for *with* statement providing temporary directories
'''
import os
from contextlib import contextmanager
from shutil import rmtree
from tempfile import mkdtemp
@contextmanager
def tempdir():
"""Create and return a temporary directory. This has the same
behavior as mkdtemp but can be used as a context manager.
Upon exiting the context, the directory and everything contained
in it are removed.
Examples
--------
>>> import os
>>> with tempdir() as tmpdir:
... fname = os.path.join(tmpdir, 'example_file.txt')
... with open(fname, 'wt') as fobj:
... _ = fobj.write('a string\\n')
>>> os.path.exists(tmpdir)
False
"""
d = mkdtemp()
yield d
rmtree(d)
@contextmanager
def in_tempdir():
''' Create, return, and change directory to a temporary directory
Examples
--------
>>> import os
>>> my_cwd = os.getcwd()
>>> with in_tempdir() as tmpdir:
... _ = open('test.txt', 'wt').write('some text')
... assert os.path.isfile('test.txt')
... assert os.path.isfile(os.path.join(tmpdir, 'test.txt'))
>>> os.path.exists(tmpdir)
False
>>> os.getcwd() == my_cwd
True
'''
pwd = os.getcwd()
d = mkdtemp()
os.chdir(d)
yield d
os.chdir(pwd)
rmtree(d)
@contextmanager
def in_dir(dir=None):
""" Change directory to given directory for duration of ``with`` block
Useful when you want to use `in_tempdir` for the final test, but
you are still debugging. For example, you may want to do this in the end:
>>> with in_tempdir() as tmpdir:
... # do something complicated which might break
... pass
But, indeed, the complicated thing does break, and meanwhile, the
``in_tempdir`` context manager wiped out the directory with the
temporary files that you wanted for debugging. So, while debugging, you
replace with something like:
>>> with in_dir() as tmpdir: # Use working directory by default
... # do something complicated which might break
... pass
You can then look at the temporary file outputs to debug what is happening,
fix, and finally replace ``in_dir`` with ``in_tempdir`` again.
"""
cwd = os.getcwd()
if dir is None:
yield cwd
return
os.chdir(dir)
yield dir
os.chdir(cwd)

View File

@ -0,0 +1,29 @@
BSD 3-Clause License
Copyright (c) 2018, Quansight-Labs
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,116 @@
"""
.. note:
If you are looking for overrides for NumPy-specific methods, see the
documentation for :obj:`unumpy`. This page explains how to write
back-ends and multimethods.
``uarray`` is built around a back-end protocol, and overridable multimethods.
It is necessary to define multimethods for back-ends to be able to override them.
See the documentation of :obj:`generate_multimethod` on how to write multimethods.
Let's start with the simplest:
``__ua_domain__`` defines the back-end *domain*. The domain is a period-
separated string consisting of the modules you extend plus the submodule. For
example, if a submodule ``module2.submodule`` extends ``module1``
(i.e., it exposes dispatchables marked as types available in ``module1``),
then the domain string should be ``"module1.module2.submodule"``.
For the purpose of this demonstration, we'll be creating an object and setting
its attributes directly. However, note that you can use a module or your own type
as a backend as well.
>>> class Backend: pass
>>> be = Backend()
>>> be.__ua_domain__ = "ua_examples"
It might be useful at this point to sidetrack to the documentation of
:obj:`generate_multimethod` to find out how to generate a multimethod
overridable by :obj:`uarray`. Needless to say, writing a backend and
creating multimethods are mostly orthogonal activities, and knowing
one doesn't necessarily require knowledge of the other, although it
is certainly helpful. We expect core API designers/specifiers to write the
multimethods, and implementors to override them. But, as is often the case,
similar people write both.
Without further ado, here's an example multimethod:
>>> import uarray as ua
>>> from uarray import Dispatchable
>>> def override_me(a, b):
... return Dispatchable(a, int),
>>> def override_replacer(args, kwargs, dispatchables):
... return (dispatchables[0], args[1]), {}
>>> overridden_me = ua.generate_multimethod(
... override_me, override_replacer, "ua_examples"
... )
Next comes the part about overriding the multimethod. This requires
the ``__ua_function__`` protocol, and the ``__ua_convert__``
protocol. The ``__ua_function__`` protocol has the signature
``(method, args, kwargs)`` where ``method`` is the passed
multimethod and ``args``/``kwargs`` specify the arguments, with any
converted dispatchables already substituted in.
>>> def __ua_function__(method, args, kwargs):
... return method.__name__, args, kwargs
>>> be.__ua_function__ = __ua_function__
The other protocol of interest is the ``__ua_convert__`` protocol. It has the
signature ``(dispatchables, coerce)``. When ``coerce`` is ``False``, conversion
between the formats should ideally be an ``O(1)`` operation; in particular, it
means that no memory copying should be involved, only views of the existing data.
>>> def __ua_convert__(dispatchables, coerce):
... for d in dispatchables:
... if d.type is int:
... if coerce and d.coercible:
... yield str(d.value)
... else:
... yield d.value
>>> be.__ua_convert__ = __ua_convert__
Now that we have defined the backend, the next thing to do is to call the multimethod.
>>> with ua.set_backend(be):
... overridden_me(1, "2")
('override_me', (1, '2'), {})
Note that the marked type has no effect on the actual type of the passed object.
We can also coerce the type of the input.
>>> with ua.set_backend(be, coerce=True):
... overridden_me(1, "2")
... overridden_me(1.0, "2")
('override_me', ('1', '2'), {})
('override_me', ('1.0', '2'), {})
Another feature is that if you remove ``__ua_convert__``, the arguments are not
converted at all and it's up to the backend to handle that.
>>> del be.__ua_convert__
>>> with ua.set_backend(be):
... overridden_me(1, "2")
('override_me', (1, '2'), {})
You also have the option to return ``NotImplemented``, in which case processing moves on
to the next back-end, which, in this case, doesn't exist. The same applies to
``__ua_convert__``.
>>> be.__ua_function__ = lambda *a, **kw: NotImplemented
>>> with ua.set_backend(be):
... overridden_me(1, "2")
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
The last possibility is if we don't have ``__ua_convert__``, in which case the job is
left up to ``__ua_function__``, but putting things back into arrays after conversion
will not be possible.
"""
from ._backend import *
__version__ = '0.8.8.dev0+aa94c5a4.scipy'

View File

@ -0,0 +1,704 @@
import typing
import types
import inspect
import functools
from . import _uarray
import copyreg
import pickle
import contextlib
from ._uarray import ( # type: ignore
BackendNotImplementedError,
_Function,
_SkipBackendContext,
_SetBackendContext,
_BackendState,
)
__all__ = [
"set_backend",
"set_global_backend",
"skip_backend",
"register_backend",
"determine_backend",
"determine_backend_multi",
"clear_backends",
"create_multimethod",
"generate_multimethod",
"_Function",
"BackendNotImplementedError",
"Dispatchable",
"wrap_single_convertor",
"wrap_single_convertor_instance",
"all_of_type",
"mark_as",
"set_state",
"get_state",
"reset_state",
"_BackendState",
"_SkipBackendContext",
"_SetBackendContext",
]
ArgumentExtractorType = typing.Callable[..., tuple["Dispatchable", ...]]
ArgumentReplacerType = typing.Callable[
[tuple, dict, tuple], tuple[tuple, dict]
]
def unpickle_function(mod_name, qname, self_):
import importlib
try:
module = importlib.import_module(mod_name)
qname = qname.split(".")
func = module
for q in qname:
func = getattr(func, q)
if self_ is not None:
func = types.MethodType(func, self_)
return func
except (ImportError, AttributeError) as e:
from pickle import UnpicklingError
raise UnpicklingError from e
def pickle_function(func):
mod_name = getattr(func, "__module__", None)
qname = getattr(func, "__qualname__", None)
self_ = getattr(func, "__self__", None)
try:
test = unpickle_function(mod_name, qname, self_)
except pickle.UnpicklingError:
test = None
if test is not func:
raise pickle.PicklingError(
f"Can't pickle {func}: it's not the same object as {test}"
)
return unpickle_function, (mod_name, qname, self_)
def pickle_state(state):
return _uarray._BackendState._unpickle, state._pickle()
def pickle_set_backend_context(ctx):
return _SetBackendContext, ctx._pickle()
def pickle_skip_backend_context(ctx):
return _SkipBackendContext, ctx._pickle()
copyreg.pickle(_Function, pickle_function)
copyreg.pickle(_uarray._BackendState, pickle_state)
copyreg.pickle(_SetBackendContext, pickle_set_backend_context)
copyreg.pickle(_SkipBackendContext, pickle_skip_backend_context)
def get_state():
"""
Returns an opaque object containing the current state of all the backends.
Can be used for synchronization between threads/processes.
See Also
--------
set_state
Sets the state returned by this function.
"""
return _uarray.get_state()
@contextlib.contextmanager
def reset_state():
"""
Returns a context manager that resets all state once exited.
See Also
--------
set_state
Context manager that sets the backend state.
get_state
Gets a state to be set by this context manager.
"""
with set_state(get_state()):
yield
@contextlib.contextmanager
def set_state(state):
"""
A context manager that sets the state of the backends to one returned by :obj:`get_state`.
See Also
--------
get_state
Gets a state to be set by this context manager.
""" # noqa: E501
old_state = get_state()
_uarray.set_state(state)
try:
yield
finally:
_uarray.set_state(old_state, True)
def create_multimethod(*args, **kwargs):
"""
Creates a decorator for generating multimethods.
This function creates a decorator that can be used with an argument
extractor in order to generate a multimethod. Other than for the
argument extractor, all arguments are passed on to
:obj:`generate_multimethod`.
See Also
--------
generate_multimethod
Generates a multimethod.
"""
def wrapper(a):
return generate_multimethod(a, *args, **kwargs)
return wrapper
def generate_multimethod(
argument_extractor: ArgumentExtractorType,
argument_replacer: ArgumentReplacerType,
domain: str,
default: typing.Optional[typing.Callable] = None,
):
"""
Generates a multimethod.
Parameters
----------
argument_extractor : ArgumentExtractorType
A callable which extracts the dispatchable arguments. Extracted arguments
should be marked by the :obj:`Dispatchable` class. It has the same signature
as the desired multimethod.
argument_replacer : ArgumentReplacerType
A callable with the signature (args, kwargs, dispatchables), which should also
return an (args, kwargs) pair with the dispatchables replaced inside the
args/kwargs.
domain : str
A string value indicating the domain of this multimethod.
default: Optional[Callable], optional
The default implementation of this multimethod, where ``None`` (the default)
specifies there is no default implementation.
Examples
--------
In this example, ``a`` is to be dispatched over, so we return it, while marking it
as an ``int``.
The trailing comma is needed because the args have to be returned as an iterable.
>>> def override_me(a, b):
... return Dispatchable(a, int),
Next, we define the argument replacer that replaces the dispatchables inside
args/kwargs with the supplied ones.
>>> def override_replacer(args, kwargs, dispatchables):
... return (dispatchables[0], args[1]), {}
Next, we define the multimethod.
>>> overridden_me = generate_multimethod(
... override_me, override_replacer, "ua_examples"
... )
Notice that there's no default implementation, unless you supply one.
>>> overridden_me(1, "a")
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
>>> overridden_me2 = generate_multimethod(
... override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
... )
>>> overridden_me2(1, "a")
(1, 'a')
See Also
--------
uarray
See the module documentation for how to override the method by creating
backends.
"""
kw_defaults, arg_defaults, opts = get_defaults(argument_extractor)
ua_func = _Function(
argument_extractor,
argument_replacer,
domain,
arg_defaults,
kw_defaults,
default,
)
return functools.update_wrapper(ua_func, argument_extractor)
def set_backend(backend, coerce=False, only=False):
"""
A context manager that sets the preferred backend.
Parameters
----------
backend
The backend to set.
coerce
Whether or not to coerce to a specific backend's types. Implies ``only``.
only
Whether or not this should be the last backend to try.
See Also
--------
skip_backend: A context manager that allows skipping of backends.
set_global_backend: Set a single, global backend for a domain.
"""
try:
return backend.__ua_cache__["set", coerce, only]
except AttributeError:
backend.__ua_cache__ = {}
except KeyError:
pass
ctx = _SetBackendContext(backend, coerce, only)
backend.__ua_cache__["set", coerce, only] = ctx
return ctx
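# Illustrative only, mirroring the example in the package docstring: the
# context manager returned above is typically used like this (`my_backend`
# and `overridden_multimethod` are hypothetical names):
#
#     with set_backend(my_backend, coerce=True):
#         overridden_multimethod(x)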
def skip_backend(backend):
"""
A context manager that allows one to skip a given backend from processing
entirely. This allows one to use another backend's code in a library that
is also a consumer of the same backend.
Parameters
----------
backend
The backend to skip.
See Also
--------
set_backend: A context manager that allows setting of backends.
set_global_backend: Set a single, global backend for a domain.
"""
try:
return backend.__ua_cache__["skip"]
except AttributeError:
backend.__ua_cache__ = {}
except KeyError:
pass
ctx = _SkipBackendContext(backend)
backend.__ua_cache__["skip"] = ctx
return ctx
def get_defaults(f):
sig = inspect.signature(f)
kw_defaults = {}
arg_defaults = []
opts = set()
for k, v in sig.parameters.items():
if v.default is not inspect.Parameter.empty:
kw_defaults[k] = v.default
if v.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
arg_defaults.append(v.default)
opts.add(k)
return kw_defaults, tuple(arg_defaults), opts
def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
"""
This utility method replaces the default backend for permanent use. It
will be tried in the list of backends automatically, unless the
``only`` flag is set on a backend. This will be the first tried
backend outside the :obj:`set_backend` context manager.
Note that this method is not thread-safe.
.. warning::
We caution library authors against using this function in
their code. We do *not* support this use-case. This function
is meant to be used only by users themselves, or by a reference
implementation, if one exists.
Parameters
----------
backend
The backend to register.
coerce : bool
Whether to coerce input types when trying this backend.
only : bool
If ``True``, no more backends will be tried if this fails.
Implied by ``coerce=True``.
try_last : bool
If ``True``, the global backend is tried after registered backends.
See Also
--------
set_backend: A context manager that allows setting of backends.
skip_backend: A context manager that allows skipping of backends.
"""
_uarray.set_global_backend(backend, coerce, only, try_last)
def register_backend(backend):
"""
This utility method registers a backend for permanent use. It
will be tried in the list of backends automatically, unless the
``only`` flag is set on a backend.
Note that this method is not thread-safe.
Parameters
----------
backend
The backend to register.
"""
_uarray.register_backend(backend)
def clear_backends(domain, registered=True, globals=False):
"""
This utility method clears registered backends.
.. warning::
We caution library authors against using this function in
their code. We do *not* support this use-case. This function
is meant to be used only by users themselves.
.. warning::
Do NOT use this method inside a multimethod call, or the
program is likely to crash.
Parameters
----------
domain : Optional[str]
The domain for which to de-register backends. ``None`` means
de-register for all domains.
registered : bool
Whether or not to clear registered backends. See :obj:`register_backend`.
globals : bool
Whether or not to clear global backends. See :obj:`set_global_backend`.
See Also
--------
register_backend : Register a backend globally.
set_global_backend : Set a global backend.
"""
_uarray.clear_backends(domain, registered, globals)
class Dispatchable:
"""
A utility class which marks an argument with a specific dispatch type.
Attributes
----------
value
The value of the Dispatchable.
type
The type of the Dispatchable.
Examples
--------
>>> x = Dispatchable(1, str)
>>> x
<Dispatchable: type=<class 'str'>, value=1>
See Also
--------
all_of_type
Marks all unmarked parameters of a function.
mark_as
Allows one to create a utility function to mark as a given type.
"""
def __init__(self, value, dispatch_type, coercible=True):
self.value = value
self.type = dispatch_type
self.coercible = coercible
def __getitem__(self, index):
return (self.type, self.value)[index]
def __str__(self):
return f"<{type(self).__name__}: type={self.type!r}, value={self.value!r}>"
__repr__ = __str__
def mark_as(dispatch_type):
"""
Creates a utility function to mark something as a specific type.
Examples
--------
>>> mark_int = mark_as(int)
>>> mark_int(1)
<Dispatchable: type=<class 'int'>, value=1>
"""
return functools.partial(Dispatchable, dispatch_type=dispatch_type)
def all_of_type(arg_type):
"""
Marks all unmarked arguments as a given type.
Examples
--------
>>> @all_of_type(str)
... def f(a, b):
... return a, Dispatchable(b, int)
>>> f('a', 1)
(<Dispatchable: type=<class 'str'>, value='a'>,
<Dispatchable: type=<class 'int'>, value=1>)
"""
def outer(func):
@functools.wraps(func)
def inner(*args, **kwargs):
extracted_args = func(*args, **kwargs)
return tuple(
Dispatchable(arg, arg_type)
if not isinstance(arg, Dispatchable)
else arg
for arg in extracted_args
)
return inner
return outer
def wrap_single_convertor(convert_single):
"""
Wraps a ``__ua_convert__`` defined for a single element to all elements.
If any of them return ``NotImplemented``, the operation is assumed to be
undefined.
Accepts a signature of (value, type, coerce).
"""
@functools.wraps(convert_single)
def __ua_convert__(dispatchables, coerce):
converted = []
for d in dispatchables:
c = convert_single(d.value, d.type, coerce and d.coercible)
if c is NotImplemented:
return NotImplemented
converted.append(c)
return converted
return __ua_convert__
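# Hedged sketch (not part of the original module): lifting a per-element
# converter to the full ``__ua_convert__`` signature.  The converter below is
# invented for illustration; it accepts ints and declines everything else.
def _demo_wrap_single_convertor():
    def _convert_single(value, dispatch_type, coerce):
        return value if dispatch_type is int else NotImplemented

    ua_convert = wrap_single_convertor(_convert_single)
    dispatchables = (Dispatchable(1, int), Dispatchable(2, int))
    return ua_convert(dispatchables, coerce=False)   # -> [1, 2]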
def wrap_single_convertor_instance(convert_single):
"""
Wraps a ``__ua_convert__`` defined for a single element to all elements.
If any of them return ``NotImplemented``, the operation is assumed to be
undefined.
Accepts a signature of (value, type, coerce).
"""
@functools.wraps(convert_single)
def __ua_convert__(self, dispatchables, coerce):
converted = []
for d in dispatchables:
c = convert_single(self, d.value, d.type, coerce and d.coercible)
if c is NotImplemented:
return NotImplemented
converted.append(c)
return converted
return __ua_convert__
def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
"""Set the backend to the first active backend that supports ``value``
This is useful for functions that call multimethods without any dispatchable
arguments. You can use :func:`determine_backend` to ensure the same backend
is used everywhere in a block of multimethod calls.
Parameters
----------
value
The value being tested
dispatch_type
The dispatch type associated with ``value``, aka
":ref:`marking <MarkingGlossary>`".
domain: string
The domain to query for backends and set.
coerce: bool
Whether or not to allow coercion to the backend's types. Implies ``only``.
only: bool
Whether or not this should be the last backend to try.
See Also
--------
set_backend: For when you know which backend to set
Notes
-----
Support is determined by the ``__ua_convert__`` protocol. Backends that
do not support the type must return ``NotImplemented`` from their
``__ua_convert__``.
Examples
--------
Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
different types, ``TypeA`` and ``TypeB``, respectively; neither supports the other's type:
>>> with ua.set_backend(ex.BackendA):
... ex.call_multimethod(ex.TypeB(), ex.TypeB())
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
Now consider a multimethod that creates a new object of ``TypeA``, or
``TypeB`` depending on the active backend.
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, ex.TypeA())
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
innermost with statement. So, ``call_multimethod`` fails since the types
don't match.
Instead, we need to first find a backend suitable for all of our objects.
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
... x = ex.TypeA()
... with ua.determine_backend(x, "mark", domain="ua_examples"):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, x)
TypeA
"""
dispatchables = (Dispatchable(value, dispatch_type, coerce),)
backend = _uarray.determine_backend(domain, dispatchables, coerce)
return set_backend(backend, coerce=coerce, only=only)
def determine_backend_multi(
dispatchables, *, domain, only=True, coerce=False, **kwargs
):
"""Set a backend supporting all ``dispatchables``
This is useful for functions that call multimethods without any dispatchable
arguments. You can use :func:`determine_backend_multi` to ensure the same
backend is used everywhere in a block of multimethod calls involving
multiple arrays.
Parameters
----------
dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
The dispatchables that must be supported
domain: string
The domain to query for backends and set.
coerce: bool
Whether or not to allow coercion to the backend's types. Implies ``only``.
only: bool
Whether or not this should be the last backend to try.
dispatch_type: Optional[Any]
The default dispatch type associated with ``dispatchables``, aka
":ref:`marking <MarkingGlossary>`".
See Also
--------
determine_backend: For a single dispatch value
set_backend: For when you know which backend to set
Notes
-----
Support is determined by the ``__ua_convert__`` protocol. Backends that
do not support the type must return ``NotImplemented`` from their
``__ua_convert__``.
Examples
--------
:func:`determine_backend` allows the backend to be set from a single
object. :func:`determine_backend_multi` allows multiple objects to be
checked simultaneously for support in the backend. Suppose we have a
``BackendAB`` which supports ``TypeA`` and ``TypeB`` in the same call,
and a ``BackendBC`` that doesn't support ``TypeA``.
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
... a, b = ex.TypeA(), ex.TypeB()
... with ua.determine_backend_multi(
... [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
... domain="ua_examples"
... ):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, a, b)
TypeA
This won't call ``BackendBC`` because it doesn't support ``TypeA``.
We can also leave out the ``ua.Dispatchable`` if we specify the
default ``dispatch_type`` for the ``dispatchables`` argument.
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
... a, b = ex.TypeA(), ex.TypeB()
... with ua.determine_backend_multi(
... [a, b], dispatch_type="mark", domain="ua_examples"
... ):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, a, b)
TypeA
"""
if "dispatch_type" in kwargs:
disp_type = kwargs.pop("dispatch_type")
dispatchables = tuple(
d if isinstance(d, Dispatchable) else Dispatchable(d, disp_type)
for d in dispatchables
)
else:
dispatchables = tuple(dispatchables)
if not all(isinstance(d, Dispatchable) for d in dispatchables):
raise TypeError("dispatchables must be instances of uarray.Dispatchable")
if len(kwargs) != 0:
raise TypeError(f"Received unexpected keyword arguments: {kwargs}")
backend = _uarray.determine_backend(domain, dispatchables, coerce)
return set_backend(backend, coerce=coerce, only=only)

View File

@ -0,0 +1,954 @@
import re
from contextlib import contextmanager
import functools
import operator
import warnings
import numbers
from collections import namedtuple
import inspect
import math
from typing import (
Optional,
Union,
TYPE_CHECKING,
TypeVar,
)
import numpy as np
from scipy._lib._array_api import array_namespace, is_numpy, size as xp_size
AxisError: type[Exception]
ComplexWarning: type[Warning]
VisibleDeprecationWarning: type[Warning]
if np.lib.NumpyVersion(np.__version__) >= '1.25.0':
from numpy.exceptions import (
AxisError, ComplexWarning, VisibleDeprecationWarning,
DTypePromotionError
)
else:
from numpy import ( # type: ignore[attr-defined, no-redef]
AxisError, ComplexWarning, VisibleDeprecationWarning # noqa: F401
)
DTypePromotionError = TypeError # type: ignore
np_long: type
np_ulong: type
if np.lib.NumpyVersion(np.__version__) >= "2.0.0.dev0":
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
r".*In the future `np\.long` will be defined as.*",
FutureWarning,
)
np_long = np.long # type: ignore[attr-defined]
np_ulong = np.ulong # type: ignore[attr-defined]
except AttributeError:
np_long = np.int_
np_ulong = np.uint
else:
np_long = np.int_
np_ulong = np.uint
IntNumber = Union[int, np.integer]
DecimalNumber = Union[float, np.floating, np.integer]
copy_if_needed: Optional[bool]
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
copy_if_needed = None
elif np.lib.NumpyVersion(np.__version__) < "1.28.0":
copy_if_needed = False
else:
# 2.0.0 dev versions, handle cases where copy may or may not exist
try:
np.array([1]).__array__(copy=None) # type: ignore[call-overload]
copy_if_needed = None
except TypeError:
copy_if_needed = False
# Since Generator was introduced in numpy 1.17, the following condition is needed for
# backward compatibility
if TYPE_CHECKING:
SeedType = Optional[Union[IntNumber, np.random.Generator,
np.random.RandomState]]
GeneratorType = TypeVar("GeneratorType", bound=Union[np.random.Generator,
np.random.RandomState])
try:
from numpy.random import Generator as Generator
except ImportError:
class Generator: # type: ignore[no-redef]
pass
def _lazywhere(cond, arrays, f, fillvalue=None, f2=None):
"""Return elements chosen from two possibilities depending on a condition
Equivalent to ``f(*arrays) if cond else fillvalue`` performed elementwise.
Parameters
----------
cond : array
The condition (expressed as a boolean array).
arrays : tuple of array
Arguments to `f` (and `f2`). Must be broadcastable with `cond`.
f : callable
Where `cond` is True, output will be ``f(arr1[cond], arr2[cond], ...)``
fillvalue : object
If provided, value with which to fill output array where `cond` is
not True.
f2 : callable
If provided, output will be ``f2(arr1[cond], arr2[cond], ...)`` where
`cond` is not True.
Returns
-------
out : array
An array with elements from the output of `f` where `cond` is True
and `fillvalue` (or elements from the output of `f2`) elsewhere. The
returned array has data type determined by Type Promotion Rules
with the output of `f` and `fillvalue` (or the output of `f2`).
Notes
-----
``xp.where(cond, x, fillvalue)`` requires explicitly forming `x` even where
`cond` is False. This function evaluates ``f(arr1[cond], arr2[cond], ...)``
only where `cond` is True.
Examples
--------
>>> import numpy as np
>>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8])
>>> def f(a, b):
... return a*b
>>> _lazywhere(a > 2, (a, b), f, np.nan)
array([ nan, nan, 21., 32.])
"""
xp = array_namespace(cond, *arrays)
if (f2 is fillvalue is None) or (f2 is not None and fillvalue is not None):
raise ValueError("Exactly one of `fillvalue` or `f2` must be given.")
args = xp.broadcast_arrays(cond, *arrays)
bool_dtype = xp.asarray([True]).dtype # numpy 1.xx doesn't have `bool`
cond, arrays = xp.astype(args[0], bool_dtype, copy=False), args[1:]
temp1 = xp.asarray(f(*(arr[cond] for arr in arrays)))
if f2 is None:
fillvalue = xp.asarray(fillvalue)
dtype = xp.result_type(temp1.dtype, fillvalue.dtype)
out = xp.full(cond.shape, fill_value=fillvalue, dtype=dtype)
else:
ncond = ~cond
temp2 = xp.asarray(f2(*(arr[ncond] for arr in arrays)))
dtype = xp.result_type(temp1, temp2)
out = xp.empty(cond.shape, dtype=dtype)
out[ncond] = temp2
out[cond] = temp1
return out
def _lazyselect(condlist, choicelist, arrays, default=0):
"""
Mimic `np.select(condlist, choicelist)`.
Notice, it assumes that all `arrays` are of the same shape or can be
broadcasted together.
All functions in `choicelist` must accept array arguments in the order
given in `arrays` and must return an array of the same shape as broadcasted
`arrays`.
Examples
--------
>>> import numpy as np
>>> x = np.arange(6)
>>> np.select([x < 3, x > 3], [x**2, x**3], default=0)
array([ 0, 1, 4, 0, 64, 125])
>>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,))
array([ 0., 1., 4., 0., 64., 125.])
>>> a = -np.ones_like(x)
>>> _lazyselect([x < 3, x > 3],
... [lambda x, a: x**2, lambda x, a: a * x**3],
... (x, a), default=np.nan)
array([ 0., 1., 4., nan, -64., -125.])
"""
arrays = np.broadcast_arrays(*arrays)
tcode = np.mintypecode([a.dtype.char for a in arrays])
out = np.full(np.shape(arrays[0]), fill_value=default, dtype=tcode)
for func, cond in zip(choicelist, condlist):
if np.all(cond is False):
continue
cond, _ = np.broadcast_arrays(cond, arrays[0])
temp = tuple(np.extract(cond, arr) for arr in arrays)
np.place(out, cond, func(*temp))
return out
def _aligned_zeros(shape, dtype=float, order="C", align=None):
"""Allocate a new ndarray with aligned memory.
Primary use case for this currently is working around a f2py issue
in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does
not necessarily create arrays aligned up to it.
"""
dtype = np.dtype(dtype)
if align is None:
align = dtype.alignment
if not hasattr(shape, '__len__'):
shape = (shape,)
size = functools.reduce(operator.mul, shape) * dtype.itemsize
buf = np.empty(size + align + 1, np.uint8)
offset = buf.__array_interface__['data'][0] % align
if offset != 0:
offset = align - offset
# Note: slices producing 0-size arrays do not necessarily change
# data pointer --- so we use and allocate size+1
buf = buf[offset:offset+size+1][:-1]
data = np.ndarray(shape, dtype, buf, order=order)
data.fill(0)
return data
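# Hedged sketch (illustrative only): verify that the returned buffer really
# starts at a multiple of the requested alignment.
def _demo_aligned_zeros():
    a = _aligned_zeros((3, 4), dtype=np.float64, align=64)
    address = a.__array_interface__['data'][0]
    return address % 64 == 0   # expected to be True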
def _prune_array(array):
"""Return an array equivalent to the input array. If the input
array is a view of a much larger array, copy its contents to a
newly allocated array. Otherwise, return the input unchanged.
"""
if array.base is not None and array.size < array.base.size // 2:
return array.copy()
return array
def float_factorial(n: int) -> float:
"""Compute the factorial and return as a float
Returns infinity when result is too large for a double
"""
return float(math.factorial(n)) if n < 171 else np.inf
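# Hedged sketch: 170! (about 7.26e306) still fits in a double, while anything
# from 171! upward is returned as ``inf``.
def _demo_float_factorial():
    return np.isfinite(float_factorial(170)), float_factorial(171)  # (True, inf)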
# copy-pasted from scikit-learn utils/validation.py
# change this to scipy.stats._qmc.check_random_state once numpy 1.16 is dropped
def check_random_state(seed):
"""Turn `seed` into a `np.random.RandomState` instance.
Parameters
----------
seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
If `seed` is None (or `np.random`), the `numpy.random.RandomState`
singleton is used.
If `seed` is an int, a new ``RandomState`` instance is used,
seeded with `seed`.
If `seed` is already a ``Generator`` or ``RandomState`` instance then
that instance is used.
Returns
-------
seed : {`numpy.random.Generator`, `numpy.random.RandomState`}
Random number generator.
"""
if seed is None or seed is np.random:
return np.random.mtrand._rand
if isinstance(seed, (numbers.Integral, np.integer)):
return np.random.RandomState(seed)
if isinstance(seed, (np.random.RandomState, np.random.Generator)):
return seed
raise ValueError(f"'{seed}' cannot be used to seed a numpy.random.RandomState"
" instance")
def _asarray_validated(a, check_finite=True,
sparse_ok=False, objects_ok=False, mask_ok=False,
as_inexact=False):
"""
Helper function for SciPy argument validation.
Many SciPy linear algebra functions do support arbitrary array-like
input arguments. Examples of commonly unsupported inputs include
matrices containing inf/nan, sparse matrix representations, and
matrices with complicated elements.
Parameters
----------
a : array_like
The array-like input.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Default: True
sparse_ok : bool, optional
True if scipy sparse matrices are allowed.
objects_ok : bool, optional
True if arrays with dtype('O') are allowed.
mask_ok : bool, optional
True if masked arrays are allowed.
as_inexact : bool, optional
True to convert the input array to a np.inexact dtype.
Returns
-------
ret : ndarray
The converted validated array.
"""
if not sparse_ok:
import scipy.sparse
if scipy.sparse.issparse(a):
msg = ('Sparse matrices are not supported by this function. '
'Perhaps one of the scipy.sparse.linalg functions '
'would work instead.')
raise ValueError(msg)
if not mask_ok:
if np.ma.isMaskedArray(a):
raise ValueError('masked arrays are not supported')
toarray = np.asarray_chkfinite if check_finite else np.asarray
a = toarray(a)
if not objects_ok:
if a.dtype is np.dtype('O'):
raise ValueError('object arrays are not supported')
if as_inexact:
if not np.issubdtype(a.dtype, np.inexact):
a = toarray(a, dtype=np.float64)
return a
def _validate_int(k, name, minimum=None):
"""
Validate a scalar integer.
This function can be used to validate an argument to a function
that expects the value to be an integer. It uses `operator.index`
to validate the value (so, for example, k=2.0 results in a
TypeError).
Parameters
----------
k : int
The value to be validated.
name : str
The name of the parameter.
minimum : int, optional
An optional lower bound.
"""
try:
k = operator.index(k)
except TypeError:
raise TypeError(f'{name} must be an integer.') from None
if minimum is not None and k < minimum:
raise ValueError(f'{name} must be an integer not less '
f'than {minimum}') from None
return k
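# Hedged sketch of the intended failure modes (illustrative only; the
# parameter name below is invented).
def _demo_validate_int():
    k = _validate_int(3, 'n_samples', minimum=0)    # returns 3
    try:
        _validate_int(2.0, 'n_samples')             # floats are rejected
    except TypeError:
        pass
    try:
        _validate_int(-1, 'n_samples', minimum=0)   # values below the bound too
    except ValueError:
        pass
    return k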
# Add a replacement for inspect.getfullargspec().
# The version below is borrowed from Django,
# https://github.com/django/django/pull/4846.
# Note an inconsistency between inspect.getfullargspec(func) and
# inspect.signature(func). If `func` is a bound method, the latter does *not*
# list `self` as a first argument, while the former *does*.
# Hence, cook up a common ground replacement: `getfullargspec_no_self` which
# mimics `inspect.getfullargspec` but does not list `self`.
#
# This way, the caller code does not need to know whether it uses a legacy
# .getfullargspec or a bright and shiny .signature.
FullArgSpec = namedtuple('FullArgSpec',
['args', 'varargs', 'varkw', 'defaults',
'kwonlyargs', 'kwonlydefaults', 'annotations'])
def getfullargspec_no_self(func):
"""inspect.getfullargspec replacement using inspect.signature.
If func is a bound method, do not list the 'self' parameter.
Parameters
----------
func : callable
A callable to inspect
Returns
-------
fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
kwonlydefaults, annotations)
NOTE: if the first argument of `func` is self, it is *not*, I repeat
*not*, included in fullargspec.args.
This is done for consistency between inspect.getargspec() under
Python 2.x, and inspect.signature() under Python 3.x.
"""
sig = inspect.signature(func)
args = [
p.name for p in sig.parameters.values()
if p.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY]
]
varargs = [
p.name for p in sig.parameters.values()
if p.kind == inspect.Parameter.VAR_POSITIONAL
]
varargs = varargs[0] if varargs else None
varkw = [
p.name for p in sig.parameters.values()
if p.kind == inspect.Parameter.VAR_KEYWORD
]
varkw = varkw[0] if varkw else None
defaults = tuple(
p.default for p in sig.parameters.values()
if (p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and
p.default is not p.empty)
) or None
kwonlyargs = [
p.name for p in sig.parameters.values()
if p.kind == inspect.Parameter.KEYWORD_ONLY
]
kwdefaults = {p.name: p.default for p in sig.parameters.values()
if p.kind == inspect.Parameter.KEYWORD_ONLY and
p.default is not p.empty}
annotations = {p.name: p.annotation for p in sig.parameters.values()
if p.annotation is not p.empty}
return FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
kwdefaults or None, annotations)
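# Hedged sketch: for a bound method, ``self`` is omitted from ``args``, unlike
# ``inspect.getfullargspec``.  ``_Demo`` is an invented class.
def _demo_getfullargspec_no_self():
    class _Demo:
        def method(self, x, y=1):
            return x + y

    spec = getfullargspec_no_self(_Demo().method)
    return spec.args, spec.defaults   # (['x', 'y'], (1,))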
class _FunctionWrapper:
"""
Object to wrap user's function, allowing picklability
"""
def __init__(self, f, args):
self.f = f
self.args = [] if args is None else args
def __call__(self, x):
return self.f(x, *self.args)
class MapWrapper:
"""
Parallelisation wrapper for working with map-like callables, such as
`multiprocessing.Pool.map`.
Parameters
----------
pool : int or map-like callable
If `pool` is an integer, then it specifies the number of threads to
use for parallelization. If ``int(pool) == 1``, then no parallel
processing is used and the map builtin is used.
If ``pool == -1``, then the pool will utilize all available CPUs.
If `pool` is a map-like callable that follows the same
calling sequence as the built-in map function, then this callable is
used for parallelization.
"""
def __init__(self, pool=1):
self.pool = None
self._mapfunc = map
self._own_pool = False
if callable(pool):
self.pool = pool
self._mapfunc = self.pool
else:
from multiprocessing import Pool
# user supplies a number
if int(pool) == -1:
# use as many processors as possible
self.pool = Pool()
self._mapfunc = self.pool.map
self._own_pool = True
elif int(pool) == 1:
pass
elif int(pool) > 1:
# use the number of processors requested
self.pool = Pool(processes=int(pool))
self._mapfunc = self.pool.map
self._own_pool = True
else:
raise RuntimeError("Number of workers specified must be -1,"
" an int >= 1, or an object with a 'map' "
"method")
def __enter__(self):
return self
def terminate(self):
if self._own_pool:
self.pool.terminate()
def join(self):
if self._own_pool:
self.pool.join()
def close(self):
if self._own_pool:
self.pool.close()
def __exit__(self, exc_type, exc_value, traceback):
if self._own_pool:
self.pool.close()
self.pool.terminate()
def __call__(self, func, iterable):
# only accept one iterable because that's all Pool.map accepts
try:
return self._mapfunc(func, iterable)
except TypeError as e:
# wrong number of arguments
raise TypeError("The map-like callable must be of the"
" form f(func, iterable)") from e
def rng_integers(gen, low, high=None, size=None, dtype='int64',
endpoint=False):
"""
Return random integers from low (inclusive) to high (exclusive), or if
endpoint=True, low (inclusive) to high (inclusive). Replaces
`RandomState.randint` (with endpoint=False) and
`RandomState.random_integers` (with endpoint=True).
Return random integers from the "discrete uniform" distribution of the
specified dtype. If high is None (the default), then results are from
0 to low.
Parameters
----------
gen : {None, np.random.RandomState, np.random.Generator}
Random number generator. If None, then the np.random.RandomState
singleton is used.
low : int or array-like of ints
Lowest (signed) integers to be drawn from the distribution (unless
high=None, in which case this parameter is 0 and this value is used
for high).
high : int or array-like of ints
If provided, one above the largest (signed) integer to be drawn from
the distribution (see above for behavior if high=None). If array-like,
must contain integer values.
size : array-like of ints, optional
Output shape. If the given shape is, e.g., (m, n, k), then m * n * k
samples are drawn. Default is None, in which case a single value is
returned.
dtype : {str, dtype}, optional
Desired dtype of the result. All dtypes are determined by their name,
i.e., 'int64', 'int', etc, so byteorder is not available and a specific
precision may have different C types depending on the platform.
The default value is 'int64'.
endpoint : bool, optional
If True, sample from the interval [low, high] instead of the default
[low, high). Defaults to False.
Returns
-------
out: int or ndarray of ints
size-shaped array of random integers from the appropriate distribution,
or a single such random int if size not provided.
"""
if isinstance(gen, Generator):
return gen.integers(low, high=high, size=size, dtype=dtype,
endpoint=endpoint)
else:
if gen is None:
# default is RandomState singleton used by np.random.
gen = np.random.mtrand._rand
if endpoint:
# inclusive of endpoint
# remember that low and high can be arrays, so don't modify in
# place
if high is None:
return gen.randint(low + 1, size=size, dtype=dtype)
if high is not None:
return gen.randint(low, high=high + 1, size=size, dtype=dtype)
# exclusive
return gen.randint(low, high=high, size=size, dtype=dtype)
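# Hedged sketch: one interface over both ``RandomState.randint`` and
# ``Generator.integers``, including the inclusive-endpoint variant.
def _demo_rng_integers():
    rs = np.random.RandomState(0)
    gen = np.random.default_rng(0)
    a = rng_integers(rs, 1, 6, size=3, endpoint=True)   # values drawn from [1, 6]
    b = rng_integers(gen, 1, 6, size=3, endpoint=True)  # values drawn from [1, 6]
    return a, b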
@contextmanager
def _fixed_default_rng(seed=1638083107694713882823079058616272161):
"""Context with a fixed np.random.default_rng seed."""
orig_fun = np.random.default_rng
np.random.default_rng = lambda seed=seed: orig_fun(seed)
try:
yield
finally:
np.random.default_rng = orig_fun
def _rng_html_rewrite(func):
"""Rewrite the HTML rendering of ``np.random.default_rng``.
This is intended to decorate
``numpydoc.docscrape_sphinx.SphinxDocString._str_examples``.
Examples are only run by Sphinx when a plot is involved. Even so,
this does not change the result values that get printed.
"""
# hexadecimal or number seed, case-insensitive
pattern = re.compile(r'np.random.default_rng\((0x[0-9A-F]+|\d+)\)', re.I)
def _wrapped(*args, **kwargs):
res = func(*args, **kwargs)
lines = [
re.sub(pattern, 'np.random.default_rng()', line)
for line in res
]
return lines
return _wrapped
def _argmin(a, keepdims=False, axis=None):
"""
argmin with a `keepdims` parameter.
See https://github.com/numpy/numpy/issues/8710
If axis is not None, a.shape[axis] must be greater than 0.
"""
res = np.argmin(a, axis=axis)
if keepdims and axis is not None:
res = np.expand_dims(res, axis=axis)
return res
def _first_nonnan(a, axis):
"""
Return the first non-nan value along the given axis.
If a slice is all nan, nan is returned for that slice.
The shape of the return value corresponds to ``keepdims=True``.
Examples
--------
>>> import numpy as np
>>> nan = np.nan
>>> a = np.array([[ 3., 3., nan, 3.],
[ 1., nan, 2., 4.],
[nan, nan, 9., -1.],
[nan, 5., 4., 3.],
[ 2., 2., 2., 2.],
[nan, nan, nan, nan]])
>>> _first_nonnan(a, axis=0)
array([[3., 3., 2., 3.]])
>>> _first_nonnan(a, axis=1)
array([[ 3.],
[ 1.],
[ 9.],
[ 5.],
[ 2.],
[nan]])
"""
k = _argmin(np.isnan(a), axis=axis, keepdims=True)
return np.take_along_axis(a, k, axis=axis)
def _nan_allsame(a, axis, keepdims=False):
"""
Determine if the values along an axis are all the same.
nan values are ignored.
`a` must be a numpy array.
`axis` is assumed to be normalized; that is, 0 <= axis < a.ndim.
For an axis of length 0, the result is True. That is, we adopt the
convention that ``allsame([])`` is True. (There are no values in the
input that are different.)
`True` is returned for slices that are all nan--not because all the
values are the same, but because this is equivalent to ``allsame([])``.
Examples
--------
>>> from numpy import nan, array
>>> a = array([[ 3., 3., nan, 3.],
... [ 1., nan, 2., 4.],
... [nan, nan, 9., -1.],
... [nan, 5., 4., 3.],
... [ 2., 2., 2., 2.],
... [nan, nan, nan, nan]])
>>> _nan_allsame(a, axis=1, keepdims=True)
array([[ True],
[False],
[False],
[False],
[ True],
[ True]])
"""
if axis is None:
if a.size == 0:
return True
a = a.ravel()
axis = 0
else:
shp = a.shape
if shp[axis] == 0:
shp = shp[:axis] + (1,)*keepdims + shp[axis + 1:]
return np.full(shp, fill_value=True, dtype=bool)
a0 = _first_nonnan(a, axis=axis)
return ((a0 == a) | np.isnan(a)).all(axis=axis, keepdims=keepdims)
def _contains_nan(a, nan_policy='propagate', policies=None, *, xp=None):
if xp is None:
xp = array_namespace(a)
not_numpy = not is_numpy(xp)
if policies is None:
policies = {'propagate', 'raise', 'omit'}
if nan_policy not in policies:
raise ValueError(f"nan_policy must be one of {set(policies)}.")
inexact = (xp.isdtype(a.dtype, "real floating")
or xp.isdtype(a.dtype, "complex floating"))
if xp_size(a) == 0:
contains_nan = False
elif inexact:
# Faster and less memory-intensive than xp.any(xp.isnan(a))
contains_nan = xp.isnan(xp.max(a))
elif is_numpy(xp) and np.issubdtype(a.dtype, object):
contains_nan = False
for el in a.ravel():
# isnan doesn't work on non-numeric elements
if np.issubdtype(type(el), np.number) and np.isnan(el):
contains_nan = True
break
else:
# Only `object` and `inexact` arrays can have NaNs
contains_nan = False
if contains_nan and nan_policy == 'raise':
raise ValueError("The input contains nan values")
if not_numpy and contains_nan and nan_policy == 'omit':
message = "`nan_policy='omit'` is incompatible with non-NumPy arrays."
raise ValueError(message)
return contains_nan, nan_policy
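# Hedged sketch: the helper reports NaN presence and enforces the requested
# policy in one place.
def _demo_contains_nan():
    a = np.asarray([1.0, np.nan, 3.0])
    flag, policy = _contains_nan(a, nan_policy='propagate')
    return bool(flag), policy   # (True, 'propagate')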
def _rename_parameter(old_name, new_name, dep_version=None):
"""
Generate decorator for backward-compatible keyword renaming.
Apply the decorator generated by `_rename_parameter` to functions with a
recently renamed parameter to maintain backward-compatibility.
After decoration, the function behaves as follows:
If only the new parameter is passed into the function, behave as usual.
If only the old parameter is passed into the function (as a keyword), raise
a DeprecationWarning if `dep_version` is provided, and behave as usual
otherwise.
If both old and new parameters are passed into the function, raise a
DeprecationWarning if `dep_version` is provided, and raise the appropriate
TypeError (function got multiple values for argument).
Parameters
----------
old_name : str
Old name of parameter
new_name : str
New name of parameter
dep_version : str, optional
Version of SciPy in which old parameter was deprecated in the format
'X.Y.Z'. If supplied, the deprecation message will indicate that
support for the old parameter will be removed in version 'X.Y+2.Z'
Notes
-----
Untested with functions that accept *args. Probably won't work as written.
"""
def decorator(fun):
@functools.wraps(fun)
def wrapper(*args, **kwargs):
if old_name in kwargs:
if dep_version:
end_version = dep_version.split('.')
end_version[1] = str(int(end_version[1]) + 2)
end_version = '.'.join(end_version)
message = (f"Use of keyword argument `{old_name}` is "
f"deprecated and replaced by `{new_name}`. "
f"Support for `{old_name}` will be removed "
f"in SciPy {end_version}.")
warnings.warn(message, DeprecationWarning, stacklevel=2)
if new_name in kwargs:
message = (f"{fun.__name__}() got multiple values for "
f"argument now known as `{new_name}`")
raise TypeError(message)
kwargs[new_name] = kwargs.pop(old_name)
return fun(*args, **kwargs)
return wrapper
return decorator
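# Hedged sketch (illustrative only): renaming a keyword while keeping the old
# spelling working.  ``_demo_scale`` and the version number are invented.
@_rename_parameter('amount', 'scale', dep_version='1.12.0')
def _demo_scale(x, *, scale=1.0):
    return x * scale
# _demo_scale(2.0, scale=3.0) -> 6.0
# _demo_scale(2.0, amount=3.0) -> 6.0, after a DeprecationWarning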
def _rng_spawn(rng, n_children):
# spawns independent RNGs from a parent RNG
bg = rng._bit_generator
ss = bg._seed_seq
child_rngs = [np.random.Generator(type(bg)(child_ss))
for child_ss in ss.spawn(n_children)]
return child_rngs
def _get_nan(*data, xp=None):
xp = array_namespace(*data) if xp is None else xp
# Get NaN of appropriate dtype for data
data = [xp.asarray(item) for item in data]
try:
min_float = getattr(xp, 'float16', xp.float32)
dtype = xp.result_type(*data, min_float) # must be at least a float
except DTypePromotionError:
# fallback to float64
dtype = xp.float64
return xp.asarray(xp.nan, dtype=dtype)[()]
def normalize_axis_index(axis, ndim):
# Check if `axis` is in the correct range and normalize it
if axis < -ndim or axis >= ndim:
msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
raise AxisError(msg)
if axis < 0:
axis = axis + ndim
return axis
def _call_callback_maybe_halt(callback, res):
"""Call wrapped callback; return True if algorithm should stop.
Parameters
----------
callback : callable or None
A user-provided callback wrapped with `_wrap_callback`
res : OptimizeResult
Information about the current iterate
Returns
-------
halt : bool
True if minimization should stop
"""
if callback is None:
return False
try:
callback(res)
return False
except StopIteration:
callback.stop_iteration = True
return True
class _RichResult(dict):
""" Container for multiple outputs with pretty-printing """
def __getattr__(self, name):
try:
return self[name]
except KeyError as e:
raise AttributeError(name) from e
__setattr__ = dict.__setitem__ # type: ignore[assignment]
__delattr__ = dict.__delitem__ # type: ignore[assignment]
def __repr__(self):
order_keys = ['message', 'success', 'status', 'fun', 'funl', 'x', 'xl',
'col_ind', 'nit', 'lower', 'upper', 'eqlin', 'ineqlin',
'converged', 'flag', 'function_calls', 'iterations',
'root']
order_keys = getattr(self, '_order_keys', order_keys)
# 'slack', 'con' are redundant with residuals
# 'crossover_nit' is probably not interesting to most users
omit_keys = {'slack', 'con', 'crossover_nit', '_order_keys'}
def key(item):
try:
return order_keys.index(item[0].lower())
except ValueError: # item not in list
return np.inf
def omit_redundant(items):
for item in items:
if item[0] in omit_keys:
continue
yield item
def item_sorter(d):
return sorted(omit_redundant(d.items()), key=key)
if self.keys():
return _dict_formatter(self, sorter=item_sorter)
else:
return self.__class__.__name__ + "()"
def __dir__(self):
return list(self.keys())
def _indenter(s, n=0):
"""
Ensures that lines after the first are indented by the specified amount
"""
split = s.split("\n")
indent = " "*n
return ("\n" + indent).join(split)
def _float_formatter_10(x):
"""
Returns a string representation of a float with exactly ten characters
"""
if np.isposinf(x):
return " inf"
elif np.isneginf(x):
return " -inf"
elif np.isnan(x):
return " nan"
return np.format_float_scientific(x, precision=3, pad_left=2, unique=False)
def _dict_formatter(d, n=0, mplus=1, sorter=None):
"""
Pretty printer for dictionaries
`n` keeps track of the starting indentation;
lines are indented by this much after a line break.
`mplus` is additional left padding applied to keys
"""
if isinstance(d, dict):
m = max(map(len, list(d.keys()))) + mplus # width to print keys
s = '\n'.join([k.rjust(m) + ': ' + # right justified, width m
_indenter(_dict_formatter(v, m+n+2, 0, sorter), m+2)
for k, v in sorter(d)]) # +2 for ': '
else:
# By default, NumPy arrays print with linewidth=76. `n` is
# the indent at which a line begins printing, so it is subtracted
# from the default to avoid exceeding 76 characters total.
# `edgeitems` is the number of elements to include before and after
# ellipses when arrays are not shown in full.
# `threshold` is the maximum number of elements for which an
# array is shown in full.
# These values tend to work well for use with OptimizeResult.
with np.printoptions(linewidth=76-n, edgeitems=2, threshold=12,
formatter={'float_kind': _float_formatter_10}):
s = str(d)
return s

View File

@ -0,0 +1,22 @@
"""
NumPy Array API compatibility library
This is a small wrapper around NumPy and CuPy that is compatible with the
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
https://numpy.org/neps/nep-0047-array-api-standard.html.
Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
changes needed to be compliant with the Array API. See
https://numpy.org/doc/stable/reference/array_api.html for a full list of
changes. In particular, unlike array_api_strict, this package does not use a
separate Array object, but rather just uses numpy.ndarray directly.
Library authors using the Array API may wish to test against array_api_strict
to ensure they are not using functionality outside of the standard, but prefer
this implementation as the default when working with NumPy arrays.
"""
__version__ = '1.5.1'
from .common import * # noqa: F401, F403

View File

@ -0,0 +1,46 @@
"""
Internal helpers
"""
from functools import wraps
from inspect import signature
def get_xp(xp):
"""
Decorator to automatically replace xp with the corresponding array module.
Use like
import numpy as np
@get_xp(np)
def func(x, /, xp, kwarg=None):
return xp.func(x, kwarg=kwarg)
Note that xp must be a keyword argument and come after all non-keyword
arguments.
"""
def inner(f):
@wraps(f)
def wrapped_f(*args, **kwargs):
return f(*args, xp=xp, **kwargs)
sig = signature(f)
new_sig = sig.replace(
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
)
if wrapped_f.__doc__ is None:
wrapped_f.__doc__ = f"""\
Array API compatibility wrapper for {f.__name__}.
See the corresponding documentation in NumPy/CuPy and/or the array API
specification for more details.
"""
wrapped_f.__signature__ = new_sig
return wrapped_f
return inner
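# Hedged sketch (illustrative only): partially applying the array module so a
# helper written against ``xp`` can be exposed without the extra argument.
def _demo_get_xp():
    import numpy as np

    @get_xp(np)
    def double(x, /, xp):
        return xp.multiply(x, 2)

    return double(np.asarray([1, 2, 3]))   # array([2, 4, 6])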

View File

@ -0,0 +1 @@
from ._helpers import * # noqa: F403

View File

@ -0,0 +1,554 @@
"""
These are functions that are just aliases of existing functions in NumPy.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from typing import Optional, Sequence, Tuple, Union
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
from typing import NamedTuple
from types import ModuleType
import inspect
from ._helpers import _check_device, is_numpy_array, array_namespace
# These functions are modified from the NumPy versions.
def arange(
start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
def empty(
shape: Union[int, Tuple[int, ...]],
xp,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
return xp.empty(shape, dtype=dtype, **kwargs)
def empty_like(
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
**kwargs
) -> ndarray:
_check_device(xp, device)
return xp.empty_like(x, dtype=dtype, **kwargs)
def eye(
n_rows: int,
n_cols: Optional[int] = None,
/,
*,
xp,
k: int = 0,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
def full(
shape: Union[int, Tuple[int, ...]],
fill_value: Union[int, float],
xp,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
def full_like(
x: ndarray,
/,
fill_value: Union[int, float],
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
def linspace(
start: Union[int, float],
stop: Union[int, float],
/,
num: int,
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
endpoint: bool = True,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
def ones(
shape: Union[int, Tuple[int, ...]],
xp,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.ones(shape, dtype=dtype, **kwargs)
def ones_like(
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.ones_like(x, dtype=dtype, **kwargs)
def zeros(
shape: Union[int, Tuple[int, ...]],
xp,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.zeros(shape, dtype=dtype, **kwargs)
def zeros_like(
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
**kwargs,
) -> ndarray:
_check_device(xp, device)
return xp.zeros_like(x, dtype=dtype, **kwargs)
# np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
# to remove polymorphic return types).
# The functions here return namedtuples (np.unique() returns a normal
# tuple).
# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
class UniqueAllResult(NamedTuple):
values: ndarray
indices: ndarray
inverse_indices: ndarray
counts: ndarray
class UniqueCountsResult(NamedTuple):
values: ndarray
counts: ndarray
class UniqueInverseResult(NamedTuple):
values: ndarray
inverse_indices: ndarray
def _unique_kwargs(xp):
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
# trying to parse version numbers, just check if equal_nan is in the
# signature.
s = inspect.signature(xp.unique)
if 'equal_nan' in s.parameters:
return {'equal_nan': False}
return {}
def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
kwargs = _unique_kwargs(xp)
values, indices, inverse_indices, counts = xp.unique(
x,
return_counts=True,
return_index=True,
return_inverse=True,
**kwargs,
)
# np.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
inverse_indices = inverse_indices.reshape(x.shape)
return UniqueAllResult(
values,
indices,
inverse_indices,
counts,
)
def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
kwargs = _unique_kwargs(xp)
res = xp.unique(
x,
return_counts=True,
return_index=False,
return_inverse=False,
**kwargs
)
return UniqueCountsResult(*res)
def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
kwargs = _unique_kwargs(xp)
values, inverse_indices = xp.unique(
x,
return_counts=False,
return_index=False,
return_inverse=True,
**kwargs,
)
# xp.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
inverse_indices = inverse_indices.reshape(x.shape)
return UniqueInverseResult(values, inverse_indices)
def unique_values(x: ndarray, /, xp) -> ndarray:
kwargs = _unique_kwargs(xp)
return xp.unique(
x,
return_counts=False,
return_index=False,
return_inverse=False,
**kwargs,
)
def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
if not copy and dtype == x.dtype:
return x
return x.astype(dtype=dtype, copy=copy)
# These functions have different keyword argument names
def std(
x: ndarray,
/,
xp,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
correction: Union[int, float] = 0.0, # correction instead of ddof
keepdims: bool = False,
**kwargs,
) -> ndarray:
return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
def var(
x: ndarray,
/,
xp,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
correction: Union[int, float] = 0.0, # correction instead of ddof
keepdims: bool = False,
**kwargs,
) -> ndarray:
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
# Unlike transpose(), the axes argument to permute_dims() is required.
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
return xp.transpose(x, axes)
# Creation functions add the device keyword (which does nothing for NumPy)
# asarray also adds the copy keyword
def _asarray(
obj: Union[
ndarray,
bool,
int,
float,
NestedSequence[bool | int | float],
SupportsBufferProtocol,
],
/,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
copy: "Optional[Union[bool, np._CopyMode]]" = None,
namespace = None,
**kwargs,
) -> ndarray:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in NumPy/CuPy and/or the array API
specification for more details.
"""
if namespace is None:
try:
xp = array_namespace(obj, _use_compat=False)
except ValueError:
# TODO: What about lists of arrays?
raise ValueError("A namespace must be specified for asarray() with non-array input")
elif isinstance(namespace, ModuleType):
xp = namespace
elif namespace == 'numpy':
import numpy as xp
elif namespace == 'cupy':
import cupy as xp
elif namespace == 'dask.array':
import dask.array as xp
else:
raise ValueError("Unrecognized namespace argument to asarray()")
_check_device(xp, device)
if is_numpy_array(obj):
import numpy as np
if hasattr(np, '_CopyMode'):
# Not present in older NumPys
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
COPY_TRUE = (True, np._CopyMode.ALWAYS)
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
else:
COPY_FALSE = (False,)
COPY_TRUE = (True,)
if copy in COPY_FALSE and namespace != "dask.array":
# copy=False is not yet implemented in xp.asarray
raise NotImplementedError("copy=False is not yet implemented")
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
if dtype is not None and obj.dtype != dtype:
copy = True
if copy in COPY_TRUE:
return xp.array(obj, copy=True, dtype=dtype)
return obj
elif namespace == "dask.array":
if copy in COPY_TRUE:
if dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
import numpy as np
obj = np.array(obj, dtype=dtype, copy=True)
return xp.array(obj, dtype=dtype)
else:
import dask.array as da
import numpy as np
if not isinstance(obj, da.Array):
obj = np.asarray(obj, dtype=dtype)
return da.from_array(obj)
return obj
return xp.asarray(obj, dtype=dtype, **kwargs)
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
def reshape(x: ndarray,
/,
shape: Tuple[int, ...],
xp, copy: Optional[bool] = None,
**kwargs) -> ndarray:
if copy is True:
x = x.copy()
elif copy is False:
y = x.view()
y.shape = shape
return y
return xp.reshape(x, shape, **kwargs)
# The descending keyword is new in sort and argsort, and 'kind' replaced with
# 'stable'
def argsort(
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
**kwargs,
) -> ndarray:
# Note: this keyword argument is different, and the default is different.
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
# as the default whereas cupy.sort uses kind=None.
if stable:
kwargs['kind'] = "stable"
if not descending:
res = xp.argsort(x, axis=axis, **kwargs)
else:
# As NumPy has no native descending sort, we imitate it here. Note that
# simply flipping the results of xp.argsort(x, ...) would not
# respect the relative order like it would in native descending sorts.
res = xp.flip(
xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs),
axis=axis,
)
# Rely on flip()/argsort() to validate axis
normalised_axis = axis if axis >= 0 else x.ndim + axis
max_i = x.shape[normalised_axis] - 1
res = max_i - res
return res
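# Hedged sketch (not part of the original module): with NumPy as ``xp``, the
# flip-based path keeps tied elements in their original order, which a plain
# flip of an ascending argsort would not.
def _demo_descending_argsort():
    import numpy as np
    x = np.asarray([1, 2, 2, 0])
    return argsort(x, np, descending=True, stable=True)   # array([1, 2, 0, 3])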
def sort(
x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
**kwargs,
) -> ndarray:
# Note: this keyword argument is different, and the default is different.
# We set it in kwargs like this because numpy.sort uses kind='quicksort'
# as the default whereas cupy.sort uses kind=None.
if stable:
kwargs['kind'] = "stable"
res = xp.sort(x, axis=axis, **kwargs)
if descending:
res = xp.flip(res, axis=axis)
return res
# nonzero should error for zero-dimensional arrays
def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return xp.nonzero(x, **kwargs)
# sum() and prod() should always upcast when dtype=None
def sum(
x: ndarray,
/,
xp,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[Dtype] = None,
keepdims: bool = False,
**kwargs,
) -> ndarray:
# `xp.sum` already upcasts integers, but not floats or complexes
if dtype is None:
if x.dtype == xp.float32:
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
def prod(
x: ndarray,
/,
xp,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[Dtype] = None,
keepdims: bool = False,
**kwargs,
) -> ndarray:
if dtype is None:
if x.dtype == xp.float32:
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
# ceil, floor, and trunc return integers for integer inputs
def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.ceil(x, **kwargs)
def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.floor(x, **kwargs)
def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
if xp.issubdtype(x.dtype, xp.integer):
return x
return xp.trunc(x, **kwargs)
# linear algebra functions
def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
return xp.matmul(x1, x2, **kwargs)
# Unlike transpose, matrix_transpose only transposes the last two axes.
def matrix_transpose(x: ndarray, /, xp) -> ndarray:
if x.ndim < 2:
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
return xp.swapaxes(x, -1, -2)
def tensordot(x1: ndarray,
x2: ndarray,
/,
xp,
*,
axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
**kwargs,
) -> ndarray:
return xp.tensordot(x1, x2, axes=axes, **kwargs)
def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
if hasattr(xp, 'broadcast_tensors'):
_broadcast = xp.broadcast_tensors
else:
_broadcast = xp.broadcast_arrays
x1_ = xp.moveaxis(x1, axis, -1)
x2_ = xp.moveaxis(x2, axis, -1)
x1_, x2_ = _broadcast(x1_, x2_)
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
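# Hedged sketch (illustrative only): the batched-matmul formulation broadcasts
# the non-contracted dimensions, so a (2, 3) by (3,) call reduces along the
# last axis.
def _demo_vecdot():
    import numpy as np
    x1 = np.asarray([[1., 2., 3.], [4., 5., 6.]])
    x2 = np.asarray([1., 1., 1.])
    return vecdot(x1, x2, np)   # array([ 6., 15.])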
# isdtype is a new function in the 2022.12 array API specification.
def isdtype(
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
*, _tuple=True, # Disallow nested tuples
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
Note that outside of this function, this compat library does not yet fully
support complex numbers.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple) and _tuple:
return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
elif isinstance(kind, str):
if kind == 'bool':
return dtype == xp.bool_
elif kind == 'signed integer':
return xp.issubdtype(dtype, xp.signedinteger)
elif kind == 'unsigned integer':
return xp.issubdtype(dtype, xp.unsignedinteger)
elif kind == 'integral':
return xp.issubdtype(dtype, xp.integer)
elif kind == 'real floating':
return xp.issubdtype(dtype, xp.floating)
elif kind == 'complex floating':
return xp.issubdtype(dtype, xp.complexfloating)
elif kind == 'numeric':
return xp.issubdtype(dtype, xp.number)
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
else:
# This will allow things that aren't required by the spec, like
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
# more strict here to match the type annotation? Note that the
# array_api_strict implementation will be very strict.
return dtype == kind
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

View File

@ -0,0 +1,183 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union, Optional, Literal
if TYPE_CHECKING:
from ._typing import Device, ndarray
from collections.abc import Sequence
# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.
def fft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def ifft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def fftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def ifftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def rfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res
def irfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res
def rfftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.float32:
return res.astype(xp.complex64)
return res
def irfftn(
x: ndarray,
/,
xp,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.complex64:
return res.astype(xp.float32)
return res
def hfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.float32)
return res
def ihfft(
x: ndarray,
/,
xp,
*,
n: Optional[int] = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
return res.astype(xp.complex64)
return res
def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return xp.fft.fftfreq(n, d=d)
def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return xp.fft.rfftfreq(n, d=d)
def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
return xp.fft.fftshift(x, axes=axes)
def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
return xp.fft.ifftshift(x, axes=axes)
__all__ = [
"fft",
"ifft",
"fftn",
"ifftn",
"rfft",
"irfft",
"rfftn",
"irfftn",
"hfft",
"ihfft",
"fftfreq",
"rfftfreq",
"fftshift",
"ifftshift",
]

View File

@ -0,0 +1,515 @@
"""
Various helper functions which are not part of the spec.
Functions which start with an underscore are for internal use only but helpers
that are in __all__ are intended as additional helper functions for use by end
users of the compat library.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union, Any
from ._typing import Array, Device
import sys
import math
import inspect
import warnings
def is_numpy_array(x):
"""
Return True if `x` is a NumPy array.
This function does not import NumPy if it has not already been imported
and is therefore cheap to use.
This also returns True for `ndarray` subclasses and NumPy scalar objects.
See Also
--------
array_namespace
is_array_api_obj
is_cupy_array
is_torch_array
is_dask_array
is_jax_array
"""
# Avoid importing NumPy if it isn't already
if 'numpy' not in sys.modules:
return False
import numpy as np
# TODO: Should we reject ndarray subclasses?
return isinstance(x, (np.ndarray, np.generic))
def is_cupy_array(x):
"""
Return True if `x` is a CuPy array.
This function does not import CuPy if it has not already been imported
and is therefore cheap to use.
This also returns True for `cupy.ndarray` subclasses and CuPy scalar objects.
See Also
--------
array_namespace
is_array_api_obj
is_numpy_array
is_torch_array
is_dask_array
is_jax_array
"""
    # Avoid importing CuPy if it isn't already
if 'cupy' not in sys.modules:
return False
import cupy as cp
# TODO: Should we reject ndarray subclasses?
return isinstance(x, (cp.ndarray, cp.generic))
def is_torch_array(x):
"""
Return True if `x` is a PyTorch tensor.
This function does not import PyTorch if it has not already been imported
and is therefore cheap to use.
See Also
--------
array_namespace
is_array_api_obj
is_numpy_array
is_cupy_array
is_dask_array
is_jax_array
"""
# Avoid importing torch if it isn't already
if 'torch' not in sys.modules:
return False
import torch
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)
def is_dask_array(x):
"""
Return True if `x` is a dask.array Array.
This function does not import dask if it has not already been imported
and is therefore cheap to use.
See Also
--------
array_namespace
is_array_api_obj
is_numpy_array
is_cupy_array
is_torch_array
is_jax_array
"""
# Avoid importing dask if it isn't already
if 'dask.array' not in sys.modules:
return False
import dask.array
return isinstance(x, dask.array.Array)
def is_jax_array(x):
"""
Return True if `x` is a JAX array.
This function does not import JAX if it has not already been imported
and is therefore cheap to use.
See Also
--------
array_namespace
is_array_api_obj
is_numpy_array
is_cupy_array
is_torch_array
is_dask_array
"""
# Avoid importing jax if it isn't already
if 'jax' not in sys.modules:
return False
import jax
return isinstance(x, jax.Array)
def is_array_api_obj(x):
"""
Return True if `x` is an array API compatible array object.
See Also
--------
array_namespace
is_numpy_array
is_cupy_array
is_torch_array
is_dask_array
is_jax_array
"""
return is_numpy_array(x) \
or is_cupy_array(x) \
or is_torch_array(x) \
or is_dask_array(x) \
or is_jax_array(x) \
or hasattr(x, '__array_namespace__')
def _check_api_version(api_version):
if api_version == '2021.12':
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
elif api_version is not None and api_version != '2022.12':
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
def array_namespace(*xs, api_version=None, _use_compat=True):
"""
Get the array API compatible namespace for the arrays `xs`.
Parameters
----------
xs: arrays
one or more arrays.
api_version: str
The newest version of the spec that you need support for (currently
the compat library wrapped APIs support v2022.12).
Returns
-------
out: namespace
The array API compatible namespace corresponding to the arrays in `xs`.
Raises
------
TypeError
If `xs` contains arrays from different array libraries or contains a
non-array.
    Typical usage is to pass the arguments of a function to
    `array_namespace()` at the top of the function to get the corresponding
    array API namespace:
.. code:: python
def your_function(x, y):
xp = array_api_compat.array_namespace(x, y)
# Now use xp as the array library namespace
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
Wrapped array namespaces can also be imported directly. For example,
`array_namespace(np.array(...))` will return `array_api_compat.numpy`.
This function will also work for any array library not wrapped by
array-api-compat if it explicitly defines `__array_namespace__
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html>`__
(the wrapped namespace is always preferred if it exists).
See Also
--------
is_array_api_obj
is_numpy_array
is_cupy_array
is_torch_array
is_dask_array
is_jax_array
"""
namespaces = set()
for x in xs:
if is_numpy_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import numpy as numpy_namespace
namespaces.add(numpy_namespace)
else:
import numpy as np
namespaces.add(np)
elif is_cupy_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import cupy as cupy_namespace
namespaces.add(cupy_namespace)
else:
import cupy as cp
namespaces.add(cp)
elif is_torch_array(x):
_check_api_version(api_version)
if _use_compat:
from .. import torch as torch_namespace
namespaces.add(torch_namespace)
else:
import torch
namespaces.add(torch)
elif is_dask_array(x):
_check_api_version(api_version)
if _use_compat:
from ..dask import array as dask_namespace
namespaces.add(dask_namespace)
else:
raise TypeError("_use_compat cannot be False if input array is a dask array!")
elif is_jax_array(x):
_check_api_version(api_version)
# jax.experimental.array_api is already an array namespace. We do
# not have a wrapper submodule for it.
import jax.experimental.array_api as jnp
namespaces.add(jnp)
elif hasattr(x, '__array_namespace__'):
namespaces.add(x.__array_namespace__(api_version=api_version))
else:
# TODO: Support Python scalars?
raise TypeError(f"{type(x).__name__} is not a supported array type")
if not namespaces:
raise TypeError("Unrecognized array input")
if len(namespaces) != 1:
raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
xp, = namespaces
return xp
# backwards compatibility alias
get_namespace = array_namespace
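# Hypothetical usage sketch (not part of the public helpers): all inputs must
# come from the same library; mixing libraries, or passing a non-array,
# raises TypeError as implemented above. The function name is made up for
# illustration only.
def _example_array_namespace():
    import numpy as np
    xp = array_namespace(np.asarray([1.0, 2.0]))  # the wrapped numpy namespace
    return xp.sum(xp.asarray([1.0, 2.0, 3.0]))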
def _check_device(xp, device):
if xp == sys.modules.get('numpy'):
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device for NumPy: {device!r}")
# Placeholder object to represent the dask device
# when the array backend is not the CPU.
# (since it is not easy to tell which device a dask array is on)
class _dask_device:
def __repr__(self):
return "DASK_DEVICE"
_DASK_DEVICE = _dask_device()
# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
def device(x: Array, /) -> Device:
"""
Hardware device the array data resides on.
This is equivalent to `x.device` according to the `standard
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
This helper is included because some array libraries either do not have
the `device` attribute or include it with an incompatible API.
Parameters
----------
x: array
array instance from an array API compatible library.
Returns
-------
out: device
a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
section of the array API specification).
Notes
-----
For NumPy the device is always `"cpu"`. For Dask, the device is always a
special `DASK_DEVICE` object.
See Also
--------
to_device : Move array data to a different device.
"""
if is_numpy_array(x):
return "cpu"
elif is_dask_array(x):
        # Peek at the metadata of the dask array to determine the backing array type
try:
import numpy as np
if isinstance(x._meta, np.ndarray):
# Must be on CPU since backed by numpy
return "cpu"
except ImportError:
pass
return _DASK_DEVICE
elif is_jax_array(x):
# JAX has .device() as a method, but it is being deprecated so that it
# can become a property, in accordance with the standard. In order for
# this function to not break when JAX makes the flip, we check for
# both here.
if inspect.ismethod(x.device):
return x.device()
else:
return x.device
return x.device
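# Hypothetical usage sketch of the helper above: NumPy arrays always report
# "cpu"; a dask array backed by NumPy chunks also reports "cpu", while other
# dask backends report the _DASK_DEVICE placeholder. The function name is
# made up for illustration only.
def _example_device():
    import numpy as np
    return device(np.zeros(3))  # "cpu"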
# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
import cupy as cp
from cupy.cuda import Device as _Device
from cupy.cuda import stream as stream_module
from cupy_backends.cuda.api import runtime
if device == x.device:
return x
elif device == "cpu":
# allowing us to use `to_device(x, "cpu")`
# is useful for portable test swapping between
# host and device backends
return x.get()
elif not isinstance(device, _Device):
raise ValueError(f"Unsupported device {device!r}")
else:
# see cupy/cupy#5985 for the reason how we handle device/stream here
prev_device = runtime.getDevice()
prev_stream: stream_module.Stream = None
if stream is not None:
prev_stream = stream_module.get_current_stream()
# stream can be an int as specified in __dlpack__, or a CuPy stream
if isinstance(stream, int):
stream = cp.cuda.ExternalStream(stream)
elif isinstance(stream, cp.cuda.Stream):
pass
else:
raise ValueError('the input stream is not recognized')
stream.use()
try:
runtime.setDevice(device.id)
arr = x.copy()
finally:
runtime.setDevice(prev_device)
if stream is not None:
prev_stream.use()
return arr
def _torch_to_device(x, device, /, stream=None):
if stream is not None:
raise NotImplementedError
return x.to(device)
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
This is equivalent to `x.to_device(device, stream=stream)` according to
the `standard
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html>`__.
This helper is included because some array libraries do not have the
`to_device` method.
Parameters
----------
x: array
array instance from an array API compatible library.
device: device
a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
section of the array API specification).
stream: Optional[Union[int, Any]]
stream object to use during copy. In addition to the types supported
in ``array.__dlpack__``, implementations may choose to support any
library-specific stream object with the caveat that any code using
such an object would not be portable.
Returns
-------
out: array
an array with the same data and data type as ``x`` and located on the
specified ``device``.
Notes
-----
For NumPy, this function effectively does nothing since the only supported
device is the CPU. For CuPy, this method supports CuPy CUDA
:external+cupy:class:`Device <cupy.cuda.Device>` and
:external+cupy:class:`Stream <cupy.cuda.Stream>` objects. For PyTorch,
this is the same as :external+torch:meth:`x.to(device) <torch.Tensor.to>`
(the ``stream`` argument is not supported in PyTorch).
See Also
--------
device : Hardware device the array data resides on.
"""
if is_numpy_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
if device == 'cpu':
return x
raise ValueError(f"Unsupported device {device!r}")
elif is_cupy_array(x):
# cupy does not yet have to_device
return _cupy_to_device(x, device, stream=stream)
elif is_torch_array(x):
return _torch_to_device(x, device, stream=stream)
elif is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
# TODO: What if our array is on the GPU already?
if device == 'cpu':
return x
raise ValueError(f"Unsupported device {device!r}")
elif is_jax_array(x):
# This import adds to_device to x
import jax.experimental.array_api # noqa: F401
return x.to_device(device, stream=stream)
return x.to_device(device, stream=stream)
def size(x):
"""
Return the total number of elements of x.
This is equivalent to `x.size` according to the `standard
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
This helper is included because PyTorch defines `size` in an
:external+torch:meth:`incompatible way <torch.Tensor.size>`.
"""
if None in x.shape:
return None
return math.prod(x.shape)
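# Hypothetical usage sketch: unlike torch.Tensor.size(), which is a method
# returning the shape, this helper returns the total element count. The
# function name is made up for illustration only.
def _example_size():
    import numpy as np
    return size(np.ones((2, 3)))  # 6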
__all__ = [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"is_cupy_array",
"is_dask_array",
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"size",
"to_device",
]
_all_ignore = ['sys', 'math', 'inspect', 'warnings']

View File

@ -0,0 +1,161 @@
from __future__ import annotations
from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from typing import Literal, Optional, Tuple, Union
from ._typing import ndarray
import math
import numpy as np
if np.__version__[0] == "2":
from numpy.lib.array_utils import normalize_axis_tuple
else:
from numpy.core.numeric import normalize_axis_tuple
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
from .._internal import get_xp
# These are in the main NumPy namespace but not in numpy.linalg
def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
return xp.cross(x1, x2, axis=axis, **kwargs)
def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
return xp.outer(x1, x2, **kwargs)
class EighResult(NamedTuple):
eigenvalues: ndarray
eigenvectors: ndarray
class QRResult(NamedTuple):
Q: ndarray
R: ndarray
class SlogdetResult(NamedTuple):
sign: ndarray
logabsdet: ndarray
class SVDResult(NamedTuple):
U: ndarray
S: ndarray
Vh: ndarray
# These functions are the same as their NumPy counterparts except they return
# a namedtuple.
def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
return EighResult(*xp.linalg.eigh(x, **kwargs))
def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
**kwargs) -> QRResult:
return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))
def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult:
return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))
def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult:
return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))
# These functions have additional keyword arguments
# The upper keyword argument is new from NumPy
def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
L = xp.linalg.cholesky(x, **kwargs)
if upper:
U = get_xp(xp)(matrix_transpose)(L)
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
U = xp.conj(U)
return U
return L
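# Hypothetical worked example of the upper keyword above (assuming NumPy as
# xp): for a real symmetric positive-definite matrix, upper=True returns the
# transpose of the lower-triangular factor. The function name is made up for
# illustration only.
def _example_cholesky_upper():
    import numpy as np
    a = np.array([[4.0, 2.0], [2.0, 3.0]])
    L = cholesky(a, np)               # lower-triangular factor
    U = cholesky(a, np, upper=True)   # equals L.T for real input
    return L, U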
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
# Note that it has a different semantic meaning from tol and rcond.
def matrix_rank(x: ndarray,
/,
xp,
*,
rtol: Optional[Union[float, ndarray]] = None,
**kwargs) -> ndarray:
# this is different from xp.linalg.matrix_rank, which supports 1
# dimensional arrays.
if x.ndim < 2:
raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
S = get_xp(xp)(svdvals)(x, **kwargs)
if rtol is None:
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
else:
# this is different from xp.linalg.matrix_rank, which does not
# multiply the tolerance by the largest singular value.
tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis]
return xp.count_nonzero(S > tol, axis=-1)
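# Hypothetical worked example of the rtol semantics above (assuming NumPy as
# xp): rtol is multiplied by the largest singular value, so with singular
# values [10, 1e-6] and rtol=1e-5 the tolerance is 10 * 1e-5 = 1e-4 and the
# reported rank is 1. The function name is made up for illustration only.
def _example_matrix_rank_rtol():
    import numpy as np
    x = np.diag([10.0, 1e-6])
    return matrix_rank(x, np, rtol=1e-5)  # 1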
def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray:
# this is different from xp.linalg.pinv, which does not multiply the
# default tolerance by max(M, N).
if rtol is None:
rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps
return xp.linalg.pinv(x, rcond=rtol, **kwargs)
# These functions are new in the array API spec
def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# xp.linalg.svd(compute_uv=False).
def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]:
return xp.linalg.svd(x, compute_uv=False)
def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
# when axis=None and the input is 2-D, so to force a vector norm, we make
# it so the input is 1-D (for axis=None), or reshape so that norm is done
# on a single dimension.
if axis is None:
# Note: xp.linalg.norm() doesn't handle 0-D arrays
_x = x.ravel()
_axis = 0
elif isinstance(axis, tuple):
# Note: The axis argument supports any number of axes, whereas
# xp.linalg.norm() only supports a single axis for vector norm.
normalized_axis = normalize_axis_tuple(axis, x.ndim)
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
newshape = axis + rest
_x = xp.transpose(x, newshape).reshape(
(math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
_axis = 0
else:
_x = x
_axis = axis
res = xp.linalg.norm(_x, axis=_axis, ord=ord)
if keepdims:
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
# above to avoid matrix norm logic.
shape = list(x.shape)
_axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
for i in _axis:
shape[i] = 1
res = xp.reshape(res, tuple(shape))
return res
# xp.diagonal and xp.trace operate on the first two axes whereas these
# operate on the last two
def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
if dtype is None:
if x.dtype == xp.float32:
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
'trace']

View File

@ -0,0 +1,23 @@
from __future__ import annotations
__all__ = [
"NestedSequence",
"SupportsBufferProtocol",
]
from typing import (
Any,
TypeVar,
Protocol,
)
_T_co = TypeVar("_T_co", covariant=True)
class NestedSequence(Protocol[_T_co]):
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...
SupportsBufferProtocol = Any
Array = Any
Device = Any

View File

@ -0,0 +1,16 @@
from cupy import * # noqa: F403
# from cupy import * doesn't overwrite these builtin names
from cupy import abs, max, min, round # noqa: F401
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
from ..common._helpers import * # noqa: F401,F403
__array_api_version__ = '2022.12'

View File

@ -0,0 +1,81 @@
from __future__ import annotations
from functools import partial
import cupy as cp
from ..common import _aliases
from .._internal import get_xp
asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial
bool = cp.bool_
# Basic renames
acos = cp.arccos
acosh = cp.arccosh
asin = cp.arcsin
asinh = cp.arcsinh
atan = cp.arctan
atan2 = cp.arctan2
atanh = cp.arctanh
bitwise_left_shift = cp.left_shift
bitwise_invert = cp.invert
bitwise_right_shift = cp.right_shift
concat = cp.concatenate
pow = cp.power
arange = get_xp(cp)(_aliases.arange)
empty = get_xp(cp)(_aliases.empty)
empty_like = get_xp(cp)(_aliases.empty_like)
eye = get_xp(cp)(_aliases.eye)
full = get_xp(cp)(_aliases.full)
full_like = get_xp(cp)(_aliases.full_like)
linspace = get_xp(cp)(_aliases.linspace)
ones = get_xp(cp)(_aliases.ones)
ones_like = get_xp(cp)(_aliases.ones_like)
zeros = get_xp(cp)(_aliases.zeros)
zeros_like = get_xp(cp)(_aliases.zeros_like)
UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult)
unique_all = get_xp(cp)(_aliases.unique_all)
unique_counts = get_xp(cp)(_aliases.unique_counts)
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
unique_values = get_xp(cp)(_aliases.unique_values)
astype = _aliases.astype
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
permute_dims = get_xp(cp)(_aliases.permute_dims)
reshape = get_xp(cp)(_aliases.reshape)
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
sum = get_xp(cp)(_aliases.sum)
prod = get_xp(cp)(_aliases.prod)
ceil = get_xp(cp)(_aliases.ceil)
floor = get_xp(cp)(_aliases.floor)
trunc = get_xp(cp)(_aliases.trunc)
matmul = get_xp(cp)(_aliases.matmul)
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
vecdot = cp.vecdot
else:
vecdot = get_xp(cp)(_aliases.vecdot)
if hasattr(cp, 'isdtype'):
isdtype = cp.isdtype
else:
isdtype = get_xp(cp)(_aliases.isdtype)
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
_all_ignore = ['cp', 'get_xp']

View File

@ -0,0 +1,46 @@
from __future__ import annotations
__all__ = [
"ndarray",
"Device",
"Dtype",
]
import sys
from typing import (
Union,
TYPE_CHECKING,
)
from cupy import (
ndarray,
dtype,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
)
from cupy.cuda.device import Device
if TYPE_CHECKING or sys.version_info >= (3, 9):
Dtype = dtype[Union[
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
]]
else:
Dtype = dtype

View File

@ -0,0 +1,36 @@
from cupy.fft import * # noqa: F403
# cupy.fft doesn't have __all__. If it is added, replace this with
#
# from cupy.fft import __all__ as fft_all
_n = {}
exec('from cupy.fft import *', _n)
del _n['__builtins__']
fft_all = list(_n)
del _n
from ..common import _fft
from .._internal import get_xp
import cupy as cp
fft = get_xp(cp)(_fft.fft)
ifft = get_xp(cp)(_fft.ifft)
fftn = get_xp(cp)(_fft.fftn)
ifftn = get_xp(cp)(_fft.ifftn)
rfft = get_xp(cp)(_fft.rfft)
irfft = get_xp(cp)(_fft.irfft)
rfftn = get_xp(cp)(_fft.rfftn)
irfftn = get_xp(cp)(_fft.irfftn)
hfft = get_xp(cp)(_fft.hfft)
ihfft = get_xp(cp)(_fft.ihfft)
fftfreq = get_xp(cp)(_fft.fftfreq)
rfftfreq = get_xp(cp)(_fft.rfftfreq)
fftshift = get_xp(cp)(_fft.fftshift)
ifftshift = get_xp(cp)(_fft.ifftshift)
__all__ = fft_all + _fft.__all__
del get_xp
del cp
del fft_all
del _fft

View File

@ -0,0 +1,49 @@
from cupy.linalg import * # noqa: F403
# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from cupy.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
del _n
from ..common import _linalg
from .._internal import get_xp
import cupy as cp
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
cross = get_xp(cp)(_linalg.cross)
outer = get_xp(cp)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(cp)(_linalg.eigh)
qr = get_xp(cp)(_linalg.qr)
slogdet = get_xp(cp)(_linalg.slogdet)
svd = get_xp(cp)(_linalg.svd)
cholesky = get_xp(cp)(_linalg.cholesky)
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
vector_norm = cp.linalg.vector_norm
else:
vector_norm = get_xp(cp)(_linalg.vector_norm)
__all__ = linalg_all + _linalg.__all__
del get_xp
del cp
del linalg_all
del _linalg

View File

@ -0,0 +1,8 @@
from dask.array import * # noqa: F403
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
__array_api_version__ = '2022.12'
__import__(__package__ + '.linalg')

View File

@ -0,0 +1,146 @@
from __future__ import annotations
from ...common import _aliases
from ...common._helpers import _check_device
from ..._internal import get_xp
import numpy as np
from numpy import (
# Constants
e,
inf,
nan,
pi,
newaxis,
# Dtypes
bool_ as bool,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
complex64,
complex128,
iinfo,
finfo,
can_cast,
result_type,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Union
from ...common._typing import Device, Dtype, Array
import dask.array as da
isdtype = get_xp(np)(_aliases.isdtype)
astype = _aliases.astype
# Common aliases
# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask
# TODO: delete the xp stuff, it shouldn't be necessary
def _dask_arange(
start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
xp,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
args = [start]
if stop is not None:
args.append(stop)
else:
# stop is None, so start is actually stop
# prepend the default value for start which is 0
args.insert(0, 0)
args.append(step)
return xp.arange(*args, dtype=dtype, **kwargs)
arange = get_xp(da)(_dask_arange)
eye = get_xp(da)(_aliases.eye)
from functools import partial
asarray = partial(_aliases._asarray, namespace='dask.array')
asarray.__doc__ = _aliases._asarray.__doc__
linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
unique_all = get_xp(da)(_aliases.unique_all)
unique_counts = get_xp(da)(_aliases.unique_counts)
unique_inverse = get_xp(da)(_aliases.unique_inverse)
unique_values = get_xp(da)(_aliases.unique_values)
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
full_like = get_xp(da)(_aliases.full_like)
ones = get_xp(da)(_aliases.ones)
ones_like = get_xp(da)(_aliases.ones_like)
zeros = get_xp(da)(_aliases.zeros)
zeros_like = get_xp(da)(_aliases.zeros_like)
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)
nonzero = get_xp(da)(_aliases.nonzero)
sum = get_xp(np)(_aliases.sum)
prod = get_xp(np)(_aliases.prod)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)
from dask.array import (
# Element wise aliases
arccos as acos,
arccosh as acosh,
arcsin as asin,
arcsinh as asinh,
arctan as atan,
arctan2 as atan2,
arctanh as atanh,
left_shift as bitwise_left_shift,
right_shift as bitwise_right_shift,
invert as bitwise_invert,
power as pow,
# Other
concatenate as concat,
)
# exclude these from __all__ since dask does not support them
_da_unsupported = ['sort', 'argsort']
common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
__all__ = common_aliases + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow',
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']

View File

@ -0,0 +1,72 @@
from __future__ import annotations
from ...common import _linalg
from ..._internal import get_xp
# Exports
from dask.array.linalg import * # noqa: F403
from dask.array import trace, outer
# These functions are in both the main and linalg namespaces
from dask.array import matmul, tensordot
from ._aliases import matrix_transpose, vecdot
import dask.array as da
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ...common._typing import Array
from typing import Literal
# dask.array.linalg doesn't have __all__. If it is added, replace this with
#
# from dask.array.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
del _n['__builtins__']
if 'annotations' in _n:
del _n['annotations']
linalg_all = list(_n)
del _n
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
# TODO: use the QR wrapper once dask
# supports the mode keyword on QR
# https://github.com/dask/dask/issues/10388
#qr = get_xp(da)(_linalg.qr)
def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
**kwargs) -> QRResult:
if mode != "reduced":
raise ValueError("dask arrays only support using mode='reduced'")
return QRResult(*da.linalg.qr(x, **kwargs))
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)
# Wrap the svd functions so that full_matrices is never passed through to
# dask: dask has no full_matrices keyword, and its behavior corresponds to
# full_matrices=False, which is therefore the only value supported here.
def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
if full_matrices:
raise ValueError("full_matrics=True is not supported by dask.")
return da.linalg.svd(x, coerce_signs=False, **kwargs)
def svdvals(x: Array) -> Array:
# TODO: can't avoid computing U or V for dask
_, s, _ = svd(x)
return s
vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)
__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
"matrix_transpose", "vecdot", "EighResult",
"QRResult", "SlogdetResult", "SVDResult", "qr",
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
"vector_norm", "diagonal"]
_all_ignore = ['get_xp', 'da', 'linalg_all']

View File

@ -0,0 +1,24 @@
from numpy import * # noqa: F403
# from numpy import * doesn't overwrite these builtin names
from numpy import abs, max, min, round # noqa: F401
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
# Don't know why, but we have to do an absolute import to import linalg. If we
# instead do
#
# from . import linalg
#
# It doesn't overwrite np.linalg from above. The import is generated
# dynamically so that the library can be vendored.
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
from .linalg import matrix_transpose, vecdot # noqa: F401
from ..common._helpers import * # noqa: F403
__array_api_version__ = '2022.12'

View File

@ -0,0 +1,81 @@
from __future__ import annotations
from functools import partial
from ..common import _aliases
from .._internal import get_xp
asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial
import numpy as np
bool = np.bool_
# Basic renames
acos = np.arccos
acosh = np.arccosh
asin = np.arcsin
asinh = np.arcsinh
atan = np.arctan
atan2 = np.arctan2
atanh = np.arctanh
bitwise_left_shift = np.left_shift
bitwise_invert = np.invert
bitwise_right_shift = np.right_shift
concat = np.concatenate
pow = np.power
arange = get_xp(np)(_aliases.arange)
empty = get_xp(np)(_aliases.empty)
empty_like = get_xp(np)(_aliases.empty_like)
eye = get_xp(np)(_aliases.eye)
full = get_xp(np)(_aliases.full)
full_like = get_xp(np)(_aliases.full_like)
linspace = get_xp(np)(_aliases.linspace)
ones = get_xp(np)(_aliases.ones)
ones_like = get_xp(np)(_aliases.ones_like)
zeros = get_xp(np)(_aliases.zeros)
zeros_like = get_xp(np)(_aliases.zeros_like)
UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult)
unique_all = get_xp(np)(_aliases.unique_all)
unique_counts = get_xp(np)(_aliases.unique_counts)
unique_inverse = get_xp(np)(_aliases.unique_inverse)
unique_values = get_xp(np)(_aliases.unique_values)
astype = _aliases.astype
std = get_xp(np)(_aliases.std)
var = get_xp(np)(_aliases.var)
permute_dims = get_xp(np)(_aliases.permute_dims)
reshape = get_xp(np)(_aliases.reshape)
argsort = get_xp(np)(_aliases.argsort)
sort = get_xp(np)(_aliases.sort)
nonzero = get_xp(np)(_aliases.nonzero)
sum = get_xp(np)(_aliases.sum)
prod = get_xp(np)(_aliases.prod)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, 'vecdot'):
vecdot = np.vecdot
else:
vecdot = get_xp(np)(_aliases.vecdot)
if hasattr(np, 'isdtype'):
isdtype = np.isdtype
else:
isdtype = get_xp(np)(_aliases.isdtype)
__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
_all_ignore = ['np', 'get_xp']

View File

@ -0,0 +1,46 @@
from __future__ import annotations
__all__ = [
"ndarray",
"Device",
"Dtype",
]
import sys
from typing import (
Literal,
Union,
TYPE_CHECKING,
)
from numpy import (
ndarray,
dtype,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
)
Device = Literal["cpu"]
if TYPE_CHECKING or sys.version_info >= (3, 9):
Dtype = dtype[Union[
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float32,
float64,
]]
else:
Dtype = dtype

View File

@ -0,0 +1,29 @@
from numpy.fft import * # noqa: F403
from numpy.fft import __all__ as fft_all
from ..common import _fft
from .._internal import get_xp
import numpy as np
fft = get_xp(np)(_fft.fft)
ifft = get_xp(np)(_fft.ifft)
fftn = get_xp(np)(_fft.fftn)
ifftn = get_xp(np)(_fft.ifftn)
rfft = get_xp(np)(_fft.rfft)
irfft = get_xp(np)(_fft.irfft)
rfftn = get_xp(np)(_fft.rfftn)
irfftn = get_xp(np)(_fft.irfftn)
hfft = get_xp(np)(_fft.hfft)
ihfft = get_xp(np)(_fft.ihfft)
fftfreq = get_xp(np)(_fft.fftfreq)
rfftfreq = get_xp(np)(_fft.rfftfreq)
fftshift = get_xp(np)(_fft.fftshift)
ifftshift = get_xp(np)(_fft.ifftshift)
__all__ = fft_all + _fft.__all__
del get_xp
del np
del fft_all
del _fft

View File

@ -0,0 +1,90 @@
from numpy.linalg import * # noqa: F403
from numpy.linalg import __all__ as linalg_all
import numpy as _np
from ..common import _linalg
from .._internal import get_xp
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
import numpy as np
cross = get_xp(np)(_linalg.cross)
outer = get_xp(np)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(np)(_linalg.eigh)
qr = get_xp(np)(_linalg.qr)
slogdet = get_xp(np)(_linalg.slogdet)
svd = get_xp(np)(_linalg.svd)
cholesky = get_xp(np)(_linalg.cholesky)
matrix_rank = get_xp(np)(_linalg.matrix_rank)
pinv = get_xp(np)(_linalg.pinv)
matrix_norm = get_xp(np)(_linalg.matrix_norm)
svdvals = get_xp(np)(_linalg.svdvals)
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)
# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous c.f.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.
# To workaround this, the below is the code from np.linalg.solve except
# only calling solve1 in the exactly 1D case.
# This code is here instead of in common because it is numpy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
try:
from numpy.linalg._linalg import (
_makearray, _assert_stacked_2d, _assert_stacked_square,
_commonType, isComplexType, _raise_linalgerror_singular
)
except ImportError:
from numpy.linalg.linalg import (
_makearray, _assert_stacked_2d, _assert_stacked_square,
_commonType, isComplexType, _raise_linalgerror_singular
)
from numpy.linalg import _umath_linalg
x1, _ = _makearray(x1)
_assert_stacked_2d(x1)
_assert_stacked_square(x1)
x2, wrap = _makearray(x2)
t, result_t = _commonType(x1, x2)
# This part is different from np.linalg.solve
if x2.ndim == 1:
gufunc = _umath_linalg.solve1
else:
gufunc = _umath_linalg.solve
# This does nothing currently but is left in because it will be relevant
# when complex dtype support is added to the spec in 2022.
signature = 'DD->D' if isComplexType(t) else 'dd->d'
with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
over='ignore', divide='ignore', under='ignore'):
r = gufunc(x1, x2, signature=signature)
return wrap(r.astype(result_t, copy=False))
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, 'vector_norm'):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)
__all__ = linalg_all + _linalg.__all__ + ['solve']
del get_xp
del np
del linalg_all
del _linalg

View File

@ -0,0 +1,24 @@
from torch import * # noqa: F403
# Several names are not included in the above import *
import torch
for n in dir(torch):
if (n.startswith('_')
or n.endswith('_')
or 'cuda' in n
or 'cpu' in n
or 'backward' in n):
continue
exec(n + ' = torch.' + n)
# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')
from ..common._helpers import * # noqa: F403
__array_api_version__ = '2022.12'

View File

@ -0,0 +1,718 @@
from __future__ import annotations
from functools import wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any
from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
vecdot as _aliases_vecdot)
from .._internal import get_xp
import torch
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import List, Optional, Sequence, Tuple, Union
from ..common._typing import Device
from torch import dtype as Dtype
array = torch.Tensor
_int_dtypes = {
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
}
_array_api_dtypes = {
torch.bool,
*_int_dtypes,
torch.float32,
torch.float64,
torch.complex64,
torch.complex128,
}
_promotion_table = {
# bool
(torch.bool, torch.bool): torch.bool,
# ints
(torch.int8, torch.int8): torch.int8,
(torch.int8, torch.int16): torch.int16,
(torch.int8, torch.int32): torch.int32,
(torch.int8, torch.int64): torch.int64,
(torch.int16, torch.int8): torch.int16,
(torch.int16, torch.int16): torch.int16,
(torch.int16, torch.int32): torch.int32,
(torch.int16, torch.int64): torch.int64,
(torch.int32, torch.int8): torch.int32,
(torch.int32, torch.int16): torch.int32,
(torch.int32, torch.int32): torch.int32,
(torch.int32, torch.int64): torch.int64,
(torch.int64, torch.int8): torch.int64,
(torch.int64, torch.int16): torch.int64,
(torch.int64, torch.int32): torch.int64,
(torch.int64, torch.int64): torch.int64,
# uints
(torch.uint8, torch.uint8): torch.uint8,
# ints and uints (mixed sign)
(torch.int8, torch.uint8): torch.int16,
(torch.int16, torch.uint8): torch.int16,
(torch.int32, torch.uint8): torch.int32,
(torch.int64, torch.uint8): torch.int64,
(torch.uint8, torch.int8): torch.int16,
(torch.uint8, torch.int16): torch.int16,
(torch.uint8, torch.int32): torch.int32,
(torch.uint8, torch.int64): torch.int64,
# floats
(torch.float32, torch.float32): torch.float32,
(torch.float32, torch.float64): torch.float64,
(torch.float64, torch.float32): torch.float64,
(torch.float64, torch.float64): torch.float64,
# complexes
(torch.complex64, torch.complex64): torch.complex64,
(torch.complex64, torch.complex128): torch.complex128,
(torch.complex128, torch.complex64): torch.complex128,
(torch.complex128, torch.complex128): torch.complex128,
# Mixed float and complex
(torch.float32, torch.complex64): torch.complex64,
(torch.float32, torch.complex128): torch.complex128,
(torch.float64, torch.complex64): torch.complex128,
(torch.float64, torch.complex128): torch.complex128,
}
def _two_arg(f):
@_wraps(f)
def _f(x1, x2, /, **kwargs):
x1, x2 = _fix_promotion(x1, x2)
return f(x1, x2, **kwargs)
if _f.__doc__ is None:
_f.__doc__ = f"""\
Array API compatibility wrapper for torch.{f.__name__}.
See the corresponding PyTorch documentation and/or the array API specification
for more details.
"""
return _f
def _fix_promotion(x1, x2, only_scalar=True):
if not isinstance(x1, torch.Tensor) or not isinstance(x2, torch.Tensor):
return x1, x2
if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
return x1, x2
# If an argument is 0-D pytorch downcasts the other argument
if not only_scalar or x1.shape == ():
dtype = result_type(x1, x2)
x2 = x2.to(dtype)
if not only_scalar or x2.shape == ():
dtype = result_type(x1, x2)
x1 = x1.to(dtype)
return x1, x2
def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
if len(arrays_and_dtypes) == 0:
raise TypeError("At least one array or dtype must be provided")
if len(arrays_and_dtypes) == 1:
x = arrays_and_dtypes[0]
if isinstance(x, torch.dtype):
return x
return x.dtype
if len(arrays_and_dtypes) > 2:
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
x, y = arrays_and_dtypes
xdt = x.dtype if not isinstance(x, torch.dtype) else x
ydt = y.dtype if not isinstance(y, torch.dtype) else y
if (xdt, ydt) in _promotion_table:
return _promotion_table[xdt, ydt]
    # This doesn't do result_type(dtype, dtype) for non-array API dtypes
    # because torch.result_type only accepts tensors. This does, however,
    # allow cross-kind promotion.
x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x
y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
return torch.result_type(x, y)
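# Hypothetical sketch of the promotion table above: mixed-sign integer
# promotion follows the array API table rather than torch.result_type, e.g.
# int8 combined with uint8 promotes to int16. The function name is made up
# for illustration only.
def _example_result_type():
    return result_type(torch.int8, torch.uint8)  # torch.int16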
def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
if not isinstance(from_, torch.dtype):
from_ = from_.dtype
return torch.can_cast(from_, to)
# Basic renames
bitwise_invert = torch.bitwise_not
newaxis = None
# Two-arg elementwise functions
# These require a wrapper to do the correct type promotion on 0-D tensors
add = _two_arg(torch.add)
atan2 = _two_arg(torch.atan2)
bitwise_and = _two_arg(torch.bitwise_and)
bitwise_left_shift = _two_arg(torch.bitwise_left_shift)
bitwise_or = _two_arg(torch.bitwise_or)
bitwise_right_shift = _two_arg(torch.bitwise_right_shift)
bitwise_xor = _two_arg(torch.bitwise_xor)
divide = _two_arg(torch.divide)
# Also a rename. torch.equal does not broadcast
equal = _two_arg(torch.eq)
floor_divide = _two_arg(torch.floor_divide)
greater = _two_arg(torch.greater)
greater_equal = _two_arg(torch.greater_equal)
less = _two_arg(torch.less)
less_equal = _two_arg(torch.less_equal)
logaddexp = _two_arg(torch.logaddexp)
# logical functions are not included here because they only accept bool in the
# spec, so type promotion is irrelevant.
multiply = _two_arg(torch.multiply)
not_equal = _two_arg(torch.not_equal)
pow = _two_arg(torch.pow)
remainder = _two_arg(torch.remainder)
subtract = _two_arg(torch.subtract)
# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
# of 'axis'.
# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745
def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.clone(x)
return torch.amax(x, axis, keepdims=keepdims)
def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.clone(x)
return torch.amin(x, axis, keepdims=keepdims)
# torch.sort also returns a tuple
# https://github.com/pytorch/pytorch/issues/70921
def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values
def _normalize_axes(axis, ndim):
axes = []
if ndim == 0 and axis:
# Better error message in this case
raise IndexError(f"Dimension out of range: {axis[0]}")
lower, upper = -ndim, ndim - 1
for a in axis:
if a < lower or a > upper:
# Match torch error message (e.g., from sum())
raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a}")
if a < 0:
a = a + ndim
if a in axes:
# Use IndexError instead of RuntimeError, and "axis" instead of "dim"
raise IndexError(f"Axis {a} appears multiple times in the list of axes")
axes.append(a)
return sorted(axes)
def _axis_none_keepdims(x, ndim, keepdims):
# Apply keepdims when axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
# Note that this is only valid for the axis=None case.
if keepdims:
for i in range(ndim):
x = torch.unsqueeze(x, 0)
return x
def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
# Some reductions don't support multiple axes
# (https://github.com/pytorch/pytorch/issues/56586).
axes = _normalize_axes(axis, x.ndim)
for a in reversed(axes):
x = torch.movedim(x, a, -1)
x = torch.flatten(x, -len(axes))
out = f(x, -1, **kwargs)
if keepdims:
for a in axes:
out = torch.unsqueeze(out, a)
return out
def prod(x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[Dtype] = None,
keepdims: bool = False,
**kwargs) -> array:
x = torch.asarray(x)
ndim = x.ndim
# https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
# below because it still needs to upcast.
if axis == ():
if dtype is None:
# We can't upcast uint8 according to the spec because there is no
# torch.uint64, so at least upcast to int64 which is what sum does
# when axis=None.
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
return x.to(torch.int64)
return x.clone()
return x.to(dtype)
# torch.prod doesn't support multiple axes
# (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.prod(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res
return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
def sum(x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype: Optional[Dtype] = None,
keepdims: bool = False,
**kwargs) -> array:
x = torch.asarray(x)
ndim = x.ndim
# https://github.com/pytorch/pytorch/issues/29137.
# Make sure it upcasts.
if axis == ():
if dtype is None:
# We can't upcast uint8 according to the spec because there is no
# torch.uint64, so at least upcast to int64 which is what sum does
# when axis=None.
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
return x.to(torch.int64)
return x.clone()
return x.to(dtype)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.sum(x, dtype=dtype, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res
return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
def any(x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
**kwargs) -> array:
x = torch.asarray(x)
ndim = x.ndim
if axis == ():
return x.to(torch.bool)
# torch.any doesn't support multiple axes
# (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs)
return res.to(torch.bool)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.any(x, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res.to(torch.bool)
# torch.any doesn't return bool for uint8
return torch.any(x, axis, keepdims=keepdims).to(torch.bool)
def all(x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
**kwargs) -> array:
x = torch.asarray(x)
ndim = x.ndim
if axis == ():
return x.to(torch.bool)
# torch.all doesn't support multiple axes
# (https://github.com/pytorch/pytorch/issues/56586).
if isinstance(axis, tuple):
res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs)
return res.to(torch.bool)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.all(x, **kwargs)
res = _axis_none_keepdims(res, ndim, keepdims)
return res.to(torch.bool)
# torch.all doesn't return bool for uint8
return torch.all(x, axis, keepdims=keepdims).to(torch.bool)
def mean(x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
**kwargs) -> array:
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.clone(x)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.mean(x, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.mean(x, axis, keepdims=keepdims, **kwargs)
def std(x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
correction: Union[int, float] = 0.0,
keepdims: bool = False,
**kwargs) -> array:
# Note, float correction is not supported
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
# implement it here for now.
if isinstance(correction, float):
_correction = int(correction)
if correction != _correction:
raise NotImplementedError("float correction in torch std() is not yet supported")
else:
_correction = correction
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.zeros_like(x)
if isinstance(axis, int):
axis = (axis,)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)
def var(x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
correction: Union[int, float] = 0.0,
keepdims: bool = False,
**kwargs) -> array:
# Note, float correction is not supported
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
# implement it here for now.
# if isinstance(correction, float):
# correction = int(correction)
# https://github.com/pytorch/pytorch/issues/29137
if axis == ():
return torch.zeros_like(x)
if isinstance(axis, int):
axis = (axis,)
if axis is None:
# torch doesn't support keepdims with axis=None
# (https://github.com/pytorch/pytorch/issues/71209)
res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs)
res = _axis_none_keepdims(res, x.ndim, keepdims)
return res
return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs)
# torch.concat doesn't support dim=None
# https://github.com/pytorch/pytorch/issues/70925
def concat(arrays: Union[Tuple[array, ...], List[array]],
/,
*,
axis: Optional[int] = 0,
**kwargs) -> array:
if axis is None:
arrays = tuple(ar.flatten() for ar in arrays)
axis = 0
return torch.concat(arrays, axis, **kwargs)
# torch.squeeze only accepts int dim and doesn't require it
# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
# added at https://github.com/pytorch/pytorch/pull/89017.
def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
if isinstance(axis, int):
axis = (axis,)
for a in axis:
if x.shape[a] != 1:
raise ValueError("squeezed dimensions must be equal to 1")
axes = _normalize_axes(axis, x.ndim)
# Remove this once pytorch 1.14 is released with the above PR #89017.
sequence = [a - i for i, a in enumerate(axes)]
for a in sequence:
x = torch.squeeze(x, a)
return x
# torch.broadcast_to uses size instead of shape
def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
return torch.broadcast_to(x, shape, **kwargs)
# torch.permute uses dims instead of axes
def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
return torch.permute(x, axes)
# The axis parameter doesn't work for flip() and roll()
# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
# accept axis=None
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
if axis is None:
axis = tuple(range(x.ndim))
# torch.flip doesn't accept dim as an int but the method does
# https://github.com/pytorch/pytorch/issues/18095
return x.flip(axis, **kwargs)
def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
return torch.roll(x, shift, axis, **kwargs)
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
if x.ndim == 0:
raise ValueError("nonzero() does not support zero-dimensional arrays")
return torch.nonzero(x, as_tuple=True, **kwargs)
def where(condition: array, x1: array, x2: array, /) -> array:
x1, x2 = _fix_promotion(x1, x2)
return torch.where(condition, x1, x2)
# torch.reshape doesn't have the copy keyword
def reshape(x: array,
/,
shape: Tuple[int, ...],
copy: Optional[bool] = None,
**kwargs) -> array:
if copy is not None:
raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
return torch.reshape(x, shape, **kwargs)
# torch.arange doesn't support returning empty arrays
# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
# keyword argument combinations
# (https://github.com/pytorch/pytorch/issues/70914)
def arange(start: Union[int, float],
/,
stop: Optional[Union[int, float]] = None,
step: Union[int, float] = 1,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs) -> array:
if stop is None:
start, stop = 0, start
if step > 0 and stop <= start or step < 0 and stop >= start:
if dtype is None:
if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
dtype = torch.int64
else:
dtype = torch.float32
return torch.empty(0, dtype=dtype, device=device, **kwargs)
return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs)
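# Hypothetical sketch of the empty-range handling above: where torch.arange
# would reject an inconsistent start/stop/step combination, the wrapper
# returns an empty tensor of the inferred dtype instead. The function name is
# made up for illustration only.
def _example_arange_empty():
    return arange(5, 3)  # tensor([], dtype=torch.int64)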
# torch.eye does not accept None as a default for the second argument and
# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
def eye(n_rows: int,
n_cols: Optional[int] = None,
/,
*,
k: int = 0,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs) -> array:
if n_cols is None:
n_cols = n_rows
z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs)
if abs(k) <= n_rows + n_cols:
z.diagonal(k).fill_(1)
return z
# torch.linspace doesn't have the endpoint parameter
def linspace(start: Union[int, float],
stop: Union[int, float],
/,
num: int,
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
endpoint: bool = True,
**kwargs) -> array:
if not endpoint:
return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1]
return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs)
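# Hypothetical check of the endpoint workaround above: with endpoint=False
# the wrapper requests num + 1 points from torch.linspace and drops the last
# one, so the spacing matches the array API definition. The function name is
# made up for illustration only.
def _example_linspace_endpoint():
    return linspace(0.0, 1.0, 4, endpoint=False)  # tensor([0.00, 0.25, 0.50, 0.75])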
# torch.full does not accept an int size
# https://github.com/pytorch/pytorch/issues/70906
def full(shape: Union[int, Tuple[int, ...]],
fill_value: Union[bool, int, float, complex],
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs) -> array:
if isinstance(shape, int):
shape = (shape,)
return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs)
# ones, zeros, and empty do not accept shape as a keyword argument
def ones(shape: Union[int, Tuple[int, ...]],
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs) -> array:
return torch.ones(shape, dtype=dtype, device=device, **kwargs)
def zeros(shape: Union[int, Tuple[int, ...]],
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs) -> array:
return torch.zeros(shape, dtype=dtype, device=device, **kwargs)
def empty(shape: Union[int, Tuple[int, ...]],
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
**kwargs) -> array:
return torch.empty(shape, dtype=dtype, device=device, **kwargs)
# torch.tril and torch.triu take the keyword argument `diagonal` rather than `k`
def tril(x: array, /, *, k: int = 0) -> array:
return torch.tril(x, k)
def triu(x: array, /, *, k: int = 0) -> array:
return torch.triu(x, k)
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
def expand_dims(x: array, /, *, axis: int = 0) -> array:
return torch.unsqueeze(x, axis)
def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
return x.to(dtype, copy=copy)
def broadcast_arrays(*arrays: array) -> List[array]:
shape = torch.broadcast_shapes(*[a.shape for a in arrays])
return [torch.broadcast_to(a, shape) for a in arrays]
# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
UniqueInverseResult)
# https://github.com/pytorch/pytorch/issues/70920
def unique_all(x: array) -> UniqueAllResult:
# torch.unique doesn't support returning indices.
# https://github.com/pytorch/pytorch/issues/36748. The workaround
# suggested in that issue doesn't actually function correctly (it relies
# on non-deterministic behavior of scatter()).
raise NotImplementedError("unique_all() not yet implemented for pytorch (see https://github.com/pytorch/pytorch/issues/36748)")
# values, inverse_indices, counts = torch.unique(x, return_counts=True, return_inverse=True)
# # torch.unique incorrectly gives a 0 count for nan values.
# # https://github.com/pytorch/pytorch/issues/94106
# counts[torch.isnan(values)] = 1
# return UniqueAllResult(values, indices, inverse_indices, counts)
def unique_counts(x: array) -> UniqueCountsResult:
values, counts = torch.unique(x, return_counts=True)
# torch.unique incorrectly gives a 0 count for nan values.
# https://github.com/pytorch/pytorch/issues/94106
counts[torch.isnan(values)] = 1
return UniqueCountsResult(values, counts)
def unique_inverse(x: array) -> UniqueInverseResult:
values, inverse = torch.unique(x, return_inverse=True)
return UniqueInverseResult(values, inverse)
def unique_values(x: array) -> array:
return torch.unique(x)
def matmul(x1: array, x2: array, /, **kwargs) -> array:
# torch.matmul doesn't type promote (but differently from _fix_promotion)
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.matmul(x1, x2, **kwargs)
matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
_vecdot = get_xp(torch)(_aliases_vecdot)
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return _vecdot(x1, x2, axis=axis)
# torch.tensordot uses dims instead of axes
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array:
# Note: torch.tensordot fails with integer dtypes when there is only 1
# element in the axis (https://github.com/pytorch/pytorch/issues/84530).
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.tensordot(x1, x2, dims=axes, **kwargs)
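# Illustrative sketch (an assumption, not part of the upstream module): `axes`
# is forwarded to torch's `dims` keyword, so both the integer form and the
# tuple-of-sequences form of the standard work, e.g.
#
#     >>> a = torch.ones((2, 3, 4)); b = torch.ones((4, 3, 5))
#     >>> tensordot(a, b, axes=([1, 2], [1, 0])).shape  # doctest: +SKIP
#     torch.Size([2, 5])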
def isdtype(
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]],
*, _tuple=True, # Disallow nested tuples
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
Note that outside of this function, this compat library does not yet fully
support complex numbers.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple) and _tuple:
return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
elif isinstance(kind, str):
if kind == 'bool':
return dtype == torch.bool
elif kind == 'signed integer':
return dtype in _int_dtypes and dtype.is_signed
elif kind == 'unsigned integer':
return dtype in _int_dtypes and not dtype.is_signed
elif kind == 'integral':
return dtype in _int_dtypes
elif kind == 'real floating':
return dtype.is_floating_point
elif kind == 'complex floating':
return dtype.is_complex
elif kind == 'numeric':
return isdtype(dtype, ('integral', 'real floating', 'complex floating'))
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
else:
return dtype == kind
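# Illustrative sketch (an assumption, not part of the upstream module): `kind`
# may be a dtype object, a named category, or a tuple of either, e.g.
#
#     >>> isdtype(torch.int32, 'signed integer')                   # doctest: +SKIP
#     True
#     >>> isdtype(torch.float64, ('integral', 'real floating'))    # doctest: +SKIP
#     True
#     >>> isdtype(torch.bool, 'numeric')                           # doctest: +SKIP
#     False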
def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
if axis is None:
if x.ndim != 1:
raise ValueError("axis must be specified when ndim > 1")
axis = 0
return torch.index_select(x, axis, indices, **kwargs)
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
'equal', 'floor_divide', 'greater', 'greater_equal', 'less',
'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum',
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
'take']
_all_ignore = ['torch', 'get_xp']

View File

@ -0,0 +1,86 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import torch
array = torch.Tensor
from typing import Union, Sequence, Literal
from torch.fft import * # noqa: F403
import torch.fft
# Several torch fft functions name the axes argument `dim` rather than `axes`
def fftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)
def ifftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
def rfftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def irfftn(
x: array,
/,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
**kwargs,
) -> array:
return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)
def fftshift(
x: array,
/,
*,
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
return torch.fft.fftshift(x, dim=axes, **kwargs)
def ifftshift(
x: array,
/,
*,
axes: Union[int, Sequence[int]] = None,
**kwargs,
) -> array:
return torch.fft.ifftshift(x, dim=axes, **kwargs)
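# Illustrative sketch (an assumption, not part of the upstream module): the
# wrappers simply translate the standard's `axes` keyword to torch's `dim`,
# e.g.
#
#     >>> x = torch.zeros((4, 8), dtype=torch.complex64)
#     >>> fftn(x, axes=(1,)).shape  # doctest: +SKIP
#     torch.Size([4, 8])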
__all__ = torch.fft.__all__ + [
"fftn",
"ifftn",
"rfftn",
"irfftn",
"fftshift",
"ifftshift",
]
_all_ignore = ['torch']

View File

@ -0,0 +1,89 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import torch
array = torch.Tensor
from torch import dtype as Dtype
from typing import Optional, Union, Tuple, Literal
inf = float('inf')
from ._aliases import _fix_promotion, sum
from torch.linalg import * # noqa: F403
# torch.linalg doesn't define __all__
# from torch.linalg import __all__ as linalg_all
from torch import linalg as torch_linalg
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
# outer is implemented in torch but isn't in the linalg namespace
from torch import outer
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
# torch.cross also does not support broadcasting when it would add new
# dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
if not (x1.shape[axis] == x2.shape[axis] == 3):
raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
x1, x2 = torch.broadcast_tensors(x1, x2)
return torch_linalg.cross(x1, x2, dim=axis)
def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
from ._aliases import isdtype
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
# torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
# torch.linalg.vecdot doesn't support integer dtypes
if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
if kwargs:
raise RuntimeError("vecdot kwargs not supported for integral dtypes")
x1_ = torch.moveaxis(x1, axis, -1)
x2_ = torch.moveaxis(x2, axis, -1)
x1_, x2_ = torch.broadcast_tensors(x1_, x2_)
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
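# Illustrative sketch (an assumption, not part of the upstream module): for
# integer inputs the contraction falls back to a batched matmul instead of
# torch.linalg.vecdot, e.g.
#
#     >>> x = torch.tensor([1, 2, 3]); y = torch.tensor([4, 5, 6])
#     >>> vecdot(x, y)  # doctest: +SKIP
#     tensor(32)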
def solve(x1: array, x2: array, /, **kwargs) -> array:
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.linalg.solve(x1, x2, **kwargs)
# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
# Use our wrapped sum to make sure it does upcasting correctly
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
def vector_norm(
x: array,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
ord: Union[int, float, Literal[inf, -inf]] = 2,
**kwargs,
) -> array:
# torch.vector_norm incorrectly treats axis=() the same as axis=None
if axis == ():
keepdims = True
    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
'cross', 'vecdot', 'solve', 'trace', 'vector_norm']
_all_ignore = ['torch_linalg', 'sum']
del linalg_all

View File

@ -0,0 +1,20 @@
from .main import minimize
from .utils import show_versions
# PEP0440 compatible formatted version, see:
# https://www.python.org/dev/peps/pep-0440/
#
# Final release markers:
# X.Y.0 # For first release after an increment in Y
# X.Y.Z # For bugfix releases
#
# Admissible pre-release markers:
# X.YaN # Alpha release
# X.YbN # Beta release
# X.YrcN # Release Candidate
#
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'.
__version__ = "1.1.1"
__all__ = ["minimize", "show_versions"]

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,132 @@
import sys
from enum import Enum
import numpy as np
# Exit status.
class ExitStatus(Enum):
"""
Exit statuses.
"""
RADIUS_SUCCESS = 0
TARGET_SUCCESS = 1
FIXED_SUCCESS = 2
CALLBACK_SUCCESS = 3
FEASIBLE_SUCCESS = 4
MAX_EVAL_WARNING = 5
MAX_ITER_WARNING = 6
INFEASIBLE_ERROR = -1
LINALG_ERROR = -2
class Options(str, Enum):
"""
Options.
"""
DEBUG = "debug"
FEASIBILITY_TOL = "feasibility_tol"
FILTER_SIZE = "filter_size"
HISTORY_SIZE = "history_size"
MAX_EVAL = "maxfev"
MAX_ITER = "maxiter"
NPT = "nb_points"
RHOBEG = "radius_init"
RHOEND = "radius_final"
SCALE = "scale"
STORE_HISTORY = "store_history"
TARGET = "target"
VERBOSE = "disp"
class Constants(str, Enum):
"""
Constants.
"""
DECREASE_RADIUS_FACTOR = "decrease_radius_factor"
INCREASE_RADIUS_FACTOR = "increase_radius_factor"
INCREASE_RADIUS_THRESHOLD = "increase_radius_threshold"
DECREASE_RADIUS_THRESHOLD = "decrease_radius_threshold"
DECREASE_RESOLUTION_FACTOR = "decrease_resolution_factor"
LARGE_RESOLUTION_THRESHOLD = "large_resolution_threshold"
MODERATE_RESOLUTION_THRESHOLD = "moderate_resolution_threshold"
LOW_RATIO = "low_ratio"
HIGH_RATIO = "high_ratio"
VERY_LOW_RATIO = "very_low_ratio"
PENALTY_INCREASE_THRESHOLD = "penalty_increase_threshold"
PENALTY_INCREASE_FACTOR = "penalty_increase_factor"
SHORT_STEP_THRESHOLD = "short_step_threshold"
LOW_RADIUS_FACTOR = "low_radius_factor"
BYRD_OMOJOKUN_FACTOR = "byrd_omojokun_factor"
THRESHOLD_RATIO_CONSTRAINTS = "threshold_ratio_constraints"
LARGE_SHIFT_FACTOR = "large_shift_factor"
LARGE_GRADIENT_FACTOR = "large_gradient_factor"
RESOLUTION_FACTOR = "resolution_factor"
IMPROVE_TCG = "improve_tcg"
# Default options.
DEFAULT_OPTIONS = {
Options.DEBUG.value: False,
Options.FEASIBILITY_TOL.value: np.sqrt(np.finfo(float).eps),
Options.FILTER_SIZE.value: sys.maxsize,
Options.HISTORY_SIZE.value: sys.maxsize,
Options.MAX_EVAL.value: lambda n: 500 * n,
Options.MAX_ITER.value: lambda n: 1000 * n,
Options.NPT.value: lambda n: 2 * n + 1,
Options.RHOBEG.value: 1.0,
Options.RHOEND.value: 1e-6,
Options.SCALE.value: False,
Options.STORE_HISTORY.value: False,
Options.TARGET.value: -np.inf,
Options.VERBOSE.value: False,
}
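# Illustrative sketch (an assumption, not part of the upstream module): the
# callable defaults are resolved against the problem dimension n before use,
# e.g. for n = 10,
#
#     >>> DEFAULT_OPTIONS[Options.MAX_EVAL.value](10)  # doctest: +SKIP
#     5000
#     >>> DEFAULT_OPTIONS[Options.NPT.value](10)       # doctest: +SKIP
#     21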
# Default constants.
DEFAULT_CONSTANTS = {
Constants.DECREASE_RADIUS_FACTOR.value: 0.5,
Constants.INCREASE_RADIUS_FACTOR.value: np.sqrt(2.0),
Constants.INCREASE_RADIUS_THRESHOLD.value: 2.0,
Constants.DECREASE_RADIUS_THRESHOLD.value: 1.4,
Constants.DECREASE_RESOLUTION_FACTOR.value: 0.1,
Constants.LARGE_RESOLUTION_THRESHOLD.value: 250.0,
Constants.MODERATE_RESOLUTION_THRESHOLD.value: 16.0,
Constants.LOW_RATIO.value: 0.1,
Constants.HIGH_RATIO.value: 0.7,
Constants.VERY_LOW_RATIO.value: 0.01,
Constants.PENALTY_INCREASE_THRESHOLD.value: 1.5,
Constants.PENALTY_INCREASE_FACTOR.value: 2.0,
Constants.SHORT_STEP_THRESHOLD.value: 0.5,
Constants.LOW_RADIUS_FACTOR.value: 0.1,
Constants.BYRD_OMOJOKUN_FACTOR.value: 0.8,
Constants.THRESHOLD_RATIO_CONSTRAINTS.value: 2.0,
Constants.LARGE_SHIFT_FACTOR.value: 10.0,
Constants.LARGE_GRADIENT_FACTOR.value: 10.0,
Constants.RESOLUTION_FACTOR.value: 2.0,
Constants.IMPROVE_TCG.value: True,
}
# Printing options.
PRINT_OPTIONS = {
"threshold": 6,
"edgeitems": 2,
"linewidth": sys.maxsize,
"formatter": {
"float_kind": lambda x: np.format_float_scientific(
x,
precision=3,
unique=False,
pad_left=2,
)
},
}
# Constants.
BARRIER = 2.0 ** min(
100,
np.finfo(float).maxexp // 2,
-np.finfo(float).minexp // 2,
)

View File

@ -0,0 +1,14 @@
from .geometry import cauchy_geometry, spider_geometry
from .optim import (
tangential_byrd_omojokun,
constrained_tangential_byrd_omojokun,
normal_byrd_omojokun,
)
__all__ = [
"cauchy_geometry",
"spider_geometry",
"tangential_byrd_omojokun",
"constrained_tangential_byrd_omojokun",
"normal_byrd_omojokun",
]

View File

@ -0,0 +1,387 @@
import inspect
import numpy as np
from ..utils import get_arrays_tol
TINY = np.finfo(float).tiny
def cauchy_geometry(const, grad, curv, xl, xu, delta, debug):
r"""
Maximize approximately the absolute value of a quadratic function subject
to bound constraints in a trust region.
This function solves approximately
.. math::
\max_{s \in \mathbb{R}^n} \quad \bigg\lvert c + g^{\mathsf{T}} s +
\frac{1}{2} s^{\mathsf{T}} H s \bigg\rvert \quad \text{s.t.} \quad
\left\{ \begin{array}{l}
l \le s \le u,\\
\lVert s \rVert \le \Delta,
\end{array} \right.
by maximizing the objective function along the constrained Cauchy
direction.
Parameters
----------
const : float
Constant :math:`c` as shown above.
grad : `numpy.ndarray`, shape (n,)
Gradient :math:`g` as shown above.
curv : callable
Curvature of :math:`H` along any vector.
``curv(s) -> float``
returns :math:`s^{\mathsf{T}} H s`.
xl : `numpy.ndarray`, shape (n,)
Lower bounds :math:`l` as shown above.
xu : `numpy.ndarray`, shape (n,)
Upper bounds :math:`u` as shown above.
delta : float
Trust-region radius :math:`\Delta` as shown above.
debug : bool
Whether to make debugging tests during the execution.
Returns
-------
`numpy.ndarray`, shape (n,)
Approximate solution :math:`s`.
Notes
-----
This function is described as the first alternative in Section 6.5 of [1]_.
It is assumed that the origin is feasible with respect to the bound
constraints and that `delta` is finite and positive.
References
----------
.. [1] T. M. Ragonneau. *Model-Based Derivative-Free Optimization Methods
and Software*. PhD thesis, Department of Applied Mathematics, The Hong
Kong Polytechnic University, Hong Kong, China, 2022. URL:
https://theses.lib.polyu.edu.hk/handle/200/12294.
"""
if debug:
assert isinstance(const, float)
assert isinstance(grad, np.ndarray) and grad.ndim == 1
assert inspect.signature(curv).bind(grad)
assert isinstance(xl, np.ndarray) and xl.shape == grad.shape
assert isinstance(xu, np.ndarray) and xu.shape == grad.shape
assert isinstance(delta, float)
assert isinstance(debug, bool)
tol = get_arrays_tol(xl, xu)
assert np.all(xl <= tol)
assert np.all(xu >= -tol)
assert np.isfinite(delta) and delta > 0.0
xl = np.minimum(xl, 0.0)
xu = np.maximum(xu, 0.0)
# To maximize the absolute value of a quadratic function, we maximize the
# function itself or its negative, and we choose the solution that provides
# the largest function value.
step1, q_val1 = _cauchy_geom(const, grad, curv, xl, xu, delta, debug)
step2, q_val2 = _cauchy_geom(
-const,
-grad,
lambda x: -curv(x),
xl,
xu,
delta,
debug,
)
step = step1 if abs(q_val1) >= abs(q_val2) else step2
if debug:
assert np.all(xl <= step)
assert np.all(step <= xu)
assert np.linalg.norm(step) < 1.1 * delta
return step
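# Illustrative usage sketch (an assumption, not part of the upstream module):
# maximize |c + g^T s + 0.5 s^T H s| over a box intersected with a trust
# region, with the curvature supplied as a callable,
#
#     >>> g = np.array([1.0, -1.0])
#     >>> H = np.array([[2.0, 0.0], [0.0, 2.0]])
#     >>> s = cauchy_geometry(1.0, g, lambda v: v @ H @ v,
#     ...                     np.array([-1.0, -1.0]), np.array([1.0, 1.0]),
#     ...                     0.5, True)  # doctest: +SKIP
#     >>> float(np.linalg.norm(s)) <= 1.1 * 0.5  # doctest: +SKIP
#     True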
def spider_geometry(const, grad, curv, xpt, xl, xu, delta, debug):
r"""
Maximize approximately the absolute value of a quadratic function subject
to bound constraints in a trust region.
This function solves approximately
.. math::
\max_{s \in \mathbb{R}^n} \quad \bigg\lvert c + g^{\mathsf{T}} s +
\frac{1}{2} s^{\mathsf{T}} H s \bigg\rvert \quad \text{s.t.} \quad
\left\{ \begin{array}{l}
l \le s \le u,\\
\lVert s \rVert \le \Delta,
\end{array} \right.
by maximizing the objective function along given straight lines.
Parameters
----------
const : float
Constant :math:`c` as shown above.
grad : `numpy.ndarray`, shape (n,)
Gradient :math:`g` as shown above.
curv : callable
Curvature of :math:`H` along any vector.
``curv(s) -> float``
returns :math:`s^{\mathsf{T}} H s`.
xpt : `numpy.ndarray`, shape (n, npt)
Points defining the straight lines. The straight lines considered are
the ones passing through the origin and the points in `xpt`.
xl : `numpy.ndarray`, shape (n,)
Lower bounds :math:`l` as shown above.
xu : `numpy.ndarray`, shape (n,)
Upper bounds :math:`u` as shown above.
delta : float
Trust-region radius :math:`\Delta` as shown above.
debug : bool
Whether to make debugging tests during the execution.
Returns
-------
`numpy.ndarray`, shape (n,)
Approximate solution :math:`s`.
Notes
-----
This function is described as the second alternative in Section 6.5 of
[1]_. It is assumed that the origin is feasible with respect to the bound
constraints and that `delta` is finite and positive.
References
----------
.. [1] T. M. Ragonneau. *Model-Based Derivative-Free Optimization Methods
and Software*. PhD thesis, Department of Applied Mathematics, The Hong
Kong Polytechnic University, Hong Kong, China, 2022. URL:
https://theses.lib.polyu.edu.hk/handle/200/12294.
"""
if debug:
assert isinstance(const, float)
assert isinstance(grad, np.ndarray) and grad.ndim == 1
assert inspect.signature(curv).bind(grad)
assert (
isinstance(xpt, np.ndarray)
and xpt.ndim == 2
and xpt.shape[0] == grad.size
)
assert isinstance(xl, np.ndarray) and xl.shape == grad.shape
assert isinstance(xu, np.ndarray) and xu.shape == grad.shape
assert isinstance(delta, float)
assert isinstance(debug, bool)
tol = get_arrays_tol(xl, xu)
assert np.all(xl <= tol)
assert np.all(xu >= -tol)
assert np.isfinite(delta) and delta > 0.0
xl = np.minimum(xl, 0.0)
xu = np.maximum(xu, 0.0)
# Iterate through the straight lines.
step = np.zeros_like(grad)
q_val = const
s_norm = np.linalg.norm(xpt, axis=0)
# Set alpha_xl to the step size for the lower-bound constraint and
# alpha_xu to the step size for the upper-bound constraint.
# xl.shape = (N,)
# xpt.shape = (N, M)
# i_xl_pos.shape = (M, N)
i_xl_pos = (xl > -np.inf) & (xpt.T > -TINY * xl)
i_xl_neg = (xl > -np.inf) & (xpt.T < TINY * xl)
i_xu_pos = (xu < np.inf) & (xpt.T > TINY * xu)
i_xu_neg = (xu < np.inf) & (xpt.T < -TINY * xu)
# (M, N)
alpha_xl_pos = np.atleast_2d(
np.broadcast_to(xl, i_xl_pos.shape)[i_xl_pos] / xpt.T[i_xl_pos]
)
# (M,)
alpha_xl_pos = np.max(alpha_xl_pos, axis=1, initial=-np.inf)
# make sure it's (M,)
alpha_xl_pos = np.broadcast_to(np.atleast_1d(alpha_xl_pos), xpt.shape[1])
alpha_xl_neg = np.atleast_2d(
np.broadcast_to(xl, i_xl_neg.shape)[i_xl_neg] / xpt.T[i_xl_neg]
)
alpha_xl_neg = np.max(alpha_xl_neg, axis=1, initial=np.inf)
alpha_xl_neg = np.broadcast_to(np.atleast_1d(alpha_xl_neg), xpt.shape[1])
alpha_xu_neg = np.atleast_2d(
np.broadcast_to(xu, i_xu_neg.shape)[i_xu_neg] / xpt.T[i_xu_neg]
)
alpha_xu_neg = np.max(alpha_xu_neg, axis=1, initial=-np.inf)
alpha_xu_neg = np.broadcast_to(np.atleast_1d(alpha_xu_neg), xpt.shape[1])
alpha_xu_pos = np.atleast_2d(
np.broadcast_to(xu, i_xu_pos.shape)[i_xu_pos] / xpt.T[i_xu_pos]
)
alpha_xu_pos = np.max(alpha_xu_pos, axis=1, initial=np.inf)
alpha_xu_pos = np.broadcast_to(np.atleast_1d(alpha_xu_pos), xpt.shape[1])
for k in range(xpt.shape[1]):
# Set alpha_tr to the step size for the trust-region constraint.
if s_norm[k] > TINY * delta:
alpha_tr = max(delta / s_norm[k], 0.0)
else:
# The current straight line is basically zero.
continue
alpha_bd_pos = max(min(alpha_xu_pos[k], alpha_xl_neg[k]), 0.0)
alpha_bd_neg = min(max(alpha_xl_pos[k], alpha_xu_neg[k]), 0.0)
# Set alpha_quad_pos and alpha_quad_neg to the step size to the extrema
# of the quadratic function along the positive and negative directions.
grad_step = grad @ xpt[:, k]
curv_step = curv(xpt[:, k])
if (
grad_step >= 0.0
and curv_step < -TINY * grad_step
or grad_step <= 0.0
and curv_step > -TINY * grad_step
):
alpha_quad_pos = max(-grad_step / curv_step, 0.0)
else:
alpha_quad_pos = np.inf
if (
grad_step >= 0.0
and curv_step > TINY * grad_step
or grad_step <= 0.0
and curv_step < TINY * grad_step
):
alpha_quad_neg = min(-grad_step / curv_step, 0.0)
else:
alpha_quad_neg = -np.inf
# Select the step that provides the largest value of the objective
# function if it improves the current best. The best positive step is
# either the one that reaches the constraints or the one that reaches
# the extremum of the objective function along the current direction
# (only possible if the resulting step is feasible). We test both, and
# we perform similar calculations along the negative step.
# N.B.: we select the largest possible step among all the ones that
# maximize the objective function. This is to avoid returning the zero
# step in some extreme cases.
alpha_pos = min(alpha_tr, alpha_bd_pos)
alpha_neg = max(-alpha_tr, alpha_bd_neg)
q_val_pos = (
const + alpha_pos * grad_step + 0.5 * alpha_pos**2.0 * curv_step
)
q_val_neg = (
const + alpha_neg * grad_step + 0.5 * alpha_neg**2.0 * curv_step
)
if alpha_quad_pos < alpha_pos:
q_val_quad_pos = (
const
+ alpha_quad_pos * grad_step
+ 0.5 * alpha_quad_pos**2.0 * curv_step
)
if abs(q_val_quad_pos) > abs(q_val_pos):
alpha_pos = alpha_quad_pos
q_val_pos = q_val_quad_pos
if alpha_quad_neg > alpha_neg:
q_val_quad_neg = (
const
+ alpha_quad_neg * grad_step
+ 0.5 * alpha_quad_neg**2.0 * curv_step
)
if abs(q_val_quad_neg) > abs(q_val_neg):
alpha_neg = alpha_quad_neg
q_val_neg = q_val_quad_neg
if abs(q_val_pos) >= abs(q_val_neg) and abs(q_val_pos) > abs(q_val):
step = np.clip(alpha_pos * xpt[:, k], xl, xu)
q_val = q_val_pos
elif abs(q_val_neg) > abs(q_val_pos) and abs(q_val_neg) > abs(q_val):
step = np.clip(alpha_neg * xpt[:, k], xl, xu)
q_val = q_val_neg
if debug:
assert np.all(xl <= step)
assert np.all(step <= xu)
assert np.linalg.norm(step) < 1.1 * delta
return step
def _cauchy_geom(const, grad, curv, xl, xu, delta, debug):
"""
    Same as `cauchy_geometry` without the absolute value.
"""
# Calculate the initial active set.
fixed_xl = (xl < 0.0) & (grad > 0.0)
fixed_xu = (xu > 0.0) & (grad < 0.0)
# Calculate the Cauchy step.
cauchy_step = np.zeros_like(grad)
cauchy_step[fixed_xl] = xl[fixed_xl]
cauchy_step[fixed_xu] = xu[fixed_xu]
if np.linalg.norm(cauchy_step) > delta:
working = fixed_xl | fixed_xu
while True:
# Calculate the Cauchy step for the directions in the working set.
g_norm = np.linalg.norm(grad[working])
delta_reduced = np.sqrt(
delta**2.0 - cauchy_step[~working] @ cauchy_step[~working]
)
if g_norm > TINY * abs(delta_reduced):
mu = max(delta_reduced / g_norm, 0.0)
else:
break
cauchy_step[working] = mu * grad[working]
# Update the working set.
fixed_xl = working & (cauchy_step < xl)
fixed_xu = working & (cauchy_step > xu)
if not np.any(fixed_xl) and not np.any(fixed_xu):
# Stop the calculations as the Cauchy step is now feasible.
break
cauchy_step[fixed_xl] = xl[fixed_xl]
cauchy_step[fixed_xu] = xu[fixed_xu]
working = working & ~(fixed_xl | fixed_xu)
# Calculate the step that maximizes the quadratic along the Cauchy step.
grad_step = grad @ cauchy_step
if grad_step >= 0.0:
# Set alpha_tr to the step size for the trust-region constraint.
s_norm = np.linalg.norm(cauchy_step)
if s_norm > TINY * delta:
alpha_tr = max(delta / s_norm, 0.0)
else:
# The Cauchy step is basically zero.
alpha_tr = 0.0
# Set alpha_quad to the step size for the maximization problem.
curv_step = curv(cauchy_step)
if curv_step < -TINY * grad_step:
alpha_quad = max(-grad_step / curv_step, 0.0)
else:
alpha_quad = np.inf
# Set alpha_bd to the step size for the bound constraints.
i_xl = (xl > -np.inf) & (cauchy_step < TINY * xl)
i_xu = (xu < np.inf) & (cauchy_step > TINY * xu)
alpha_xl = np.min(xl[i_xl] / cauchy_step[i_xl], initial=np.inf)
alpha_xu = np.min(xu[i_xu] / cauchy_step[i_xu], initial=np.inf)
alpha_bd = min(alpha_xl, alpha_xu)
# Calculate the solution and the corresponding function value.
alpha = min(alpha_tr, alpha_quad, alpha_bd)
step = np.clip(alpha * cauchy_step, xl, xu)
q_val = const + alpha * grad_step + 0.5 * alpha**2.0 * curv_step
else:
        # This case is never reached in exact arithmetic. It prevents this
        # function from returning a step that decreases the objective function.
step = np.zeros_like(grad)
q_val = const
if debug:
assert np.all(xl <= step)
assert np.all(step <= xu)
assert np.linalg.norm(step) < 1.1 * delta
return step, q_val

File diff suppressed because it is too large

View File

@ -0,0 +1,18 @@
from .exceptions import (
MaxEvalError,
TargetSuccess,
CallbackSuccess,
FeasibleSuccess,
)
from .math import get_arrays_tol, exact_1d_array
from .versions import show_versions
__all__ = [
"MaxEvalError",
"TargetSuccess",
"CallbackSuccess",
"FeasibleSuccess",
"get_arrays_tol",
"exact_1d_array",
"show_versions",
]

View File

@ -0,0 +1,22 @@
class MaxEvalError(Exception):
"""
Exception raised when the maximum number of evaluations is reached.
"""
class TargetSuccess(Exception):
"""
Exception raised when the target value is reached.
"""
class CallbackSuccess(StopIteration):
"""
Exception raised when the callback function raises a ``StopIteration``.
"""
class FeasibleSuccess(Exception):
"""
Exception raised when a feasible point of a feasible problem is found.
"""

View File

@ -0,0 +1,77 @@
import numpy as np
EPS = np.finfo(float).eps
def get_arrays_tol(*arrays):
"""
Get a relative tolerance for a set of arrays.
Parameters
----------
*arrays: tuple
Set of `numpy.ndarray` to get the tolerance for.
Returns
-------
float
Relative tolerance for the set of arrays.
Raises
------
ValueError
If no array is provided.
"""
if len(arrays) == 0:
raise ValueError("At least one array must be provided.")
size = max(array.size for array in arrays)
weight = max(
np.max(np.abs(array[np.isfinite(array)]), initial=1.0)
for array in arrays
)
return 10.0 * EPS * max(size, 1.0) * weight
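# Illustrative sketch (an assumption, not part of the upstream module): the
# tolerance grows with the largest array size and the largest finite entry,
# e.g.
#
#     >>> tol = get_arrays_tol(np.array([1.0, -3.0]), np.array([2.0, np.inf]))
#     >>> np.isclose(tol, 10.0 * EPS * 2 * 3.0)  # doctest: +SKIP
#     True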
def exact_1d_array(x, message):
"""
Preprocess a 1-dimensional array.
Parameters
----------
x : array_like
Array to be preprocessed.
message : str
        Error message if `x` cannot be interpreted as a 1-dimensional array.
Returns
-------
`numpy.ndarray`
Preprocessed array.
"""
x = np.atleast_1d(np.squeeze(x)).astype(float)
if x.ndim != 1:
raise ValueError(message)
return x
def exact_2d_array(x, message):
"""
Preprocess a 2-dimensional array.
Parameters
----------
x : array_like
Array to be preprocessed.
message : str
        Error message if `x` cannot be interpreted as a 2-dimensional array.
Returns
-------
`numpy.ndarray`
Preprocessed array.
"""
x = np.atleast_2d(x).astype(float)
if x.ndim != 2:
raise ValueError(message)
return x

View File

@ -0,0 +1,67 @@
import os
import platform
import sys
from importlib.metadata import PackageNotFoundError, version
def _get_sys_info():
"""
Get useful system information.
Returns
-------
dict
Useful system information.
"""
return {
"python": sys.version.replace(os.linesep, " "),
"executable": sys.executable,
"machine": platform.platform(),
}
def _get_deps_info():
"""
Get the versions of the dependencies.
Returns
-------
dict
Versions of the dependencies.
"""
deps = ["cobyqa", "numpy", "scipy", "setuptools", "pip"]
deps_info = {}
for module in deps:
try:
deps_info[module] = version(module)
except PackageNotFoundError:
deps_info[module] = None
return deps_info
def show_versions():
"""
Display useful system and dependencies information.
When reporting issues, please include this information.
"""
print("System settings")
print("---------------")
sys_info = _get_sys_info()
print(
"\n".join(
f"{k:>{max(map(len, sys_info.keys())) + 1}}: {v}"
for k, v in sys_info.items()
)
)
print()
print("Python dependencies")
print("-------------------")
deps_info = _get_deps_info()
print(
"\n".join(
f"{k:>{max(map(len, deps_info.keys())) + 1}}: {v}"
for k, v in deps_info.items()
)
)

View File

@ -0,0 +1,399 @@
# ######################### LICENSE ############################ #
# Copyright (c) 2005-2015, Michele Simionato
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# Redistributions in bytecode form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
"""
Decorator module, see https://pypi.python.org/pypi/decorator
for the documentation.
"""
import re
import sys
import inspect
import operator
import itertools
import collections
from inspect import getfullargspec
__version__ = '4.0.5'
def get_init(cls):
return cls.__init__
# getargspec has been deprecated in Python 3.5
ArgSpec = collections.namedtuple(
'ArgSpec', 'args varargs varkw defaults')
def getargspec(f):
"""A replacement for inspect.getargspec"""
spec = getfullargspec(f)
return ArgSpec(spec.args, spec.varargs, spec.varkw, spec.defaults)
DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(')
# basic functionality
class FunctionMaker:
"""
An object with the ability to create functions with a given signature.
It has attributes name, doc, module, signature, defaults, dict, and
methods update and make.
"""
# Atomic get-and-increment provided by the GIL
_compile_count = itertools.count()
def __init__(self, func=None, name=None, signature=None,
defaults=None, doc=None, module=None, funcdict=None):
self.shortsignature = signature
if func:
# func can be a class or a callable, but not an instance method
self.name = func.__name__
if self.name == '<lambda>': # small hack for lambda functions
self.name = '_lambda_'
self.doc = func.__doc__
self.module = func.__module__
if inspect.isfunction(func):
argspec = getfullargspec(func)
self.annotations = getattr(func, '__annotations__', {})
for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
'kwonlydefaults'):
setattr(self, a, getattr(argspec, a))
for i, arg in enumerate(self.args):
setattr(self, 'arg%d' % i, arg)
allargs = list(self.args)
allshortargs = list(self.args)
if self.varargs:
allargs.append('*' + self.varargs)
allshortargs.append('*' + self.varargs)
elif self.kwonlyargs:
allargs.append('*') # single star syntax
for a in self.kwonlyargs:
allargs.append('%s=None' % a)
allshortargs.append(f'{a}={a}')
if self.varkw:
allargs.append('**' + self.varkw)
allshortargs.append('**' + self.varkw)
self.signature = ', '.join(allargs)
self.shortsignature = ', '.join(allshortargs)
self.dict = func.__dict__.copy()
# func=None happens when decorating a caller
if name:
self.name = name
if signature is not None:
self.signature = signature
if defaults:
self.defaults = defaults
if doc:
self.doc = doc
if module:
self.module = module
if funcdict:
self.dict = funcdict
        # check the existence of required attributes
assert hasattr(self, 'name')
if not hasattr(self, 'signature'):
raise TypeError('You are decorating a non-function: %s' % func)
def update(self, func, **kw):
"Update the signature of func with the data in self"
func.__name__ = self.name
func.__doc__ = getattr(self, 'doc', None)
func.__dict__ = getattr(self, 'dict', {})
func.__defaults__ = getattr(self, 'defaults', ())
func.__kwdefaults__ = getattr(self, 'kwonlydefaults', None)
func.__annotations__ = getattr(self, 'annotations', None)
try:
frame = sys._getframe(3)
except AttributeError: # for IronPython and similar implementations
callermodule = '?'
else:
callermodule = frame.f_globals.get('__name__', '?')
func.__module__ = getattr(self, 'module', callermodule)
func.__dict__.update(kw)
def make(self, src_templ, evaldict=None, addsource=False, **attrs):
"Make a new function from a given template and update the signature"
src = src_templ % vars(self) # expand name and signature
evaldict = evaldict or {}
mo = DEF.match(src)
if mo is None:
raise SyntaxError('not a valid function template\n%s' % src)
name = mo.group(1) # extract the function name
names = set([name] + [arg.strip(' *') for arg in
self.shortsignature.split(',')])
for n in names:
if n in ('_func_', '_call_'):
raise NameError(f'{n} is overridden in\n{src}')
if not src.endswith('\n'): # add a newline just for safety
src += '\n' # this is needed in old versions of Python
# Ensure each generated function has a unique filename for profilers
# (such as cProfile) that depend on the tuple of (<filename>,
# <definition line>, <function name>) being unique.
filename = '<decorator-gen-%d>' % (next(self._compile_count),)
try:
code = compile(src, filename, 'single')
exec(code, evaldict)
except: # noqa: E722
print('Error in generated code:', file=sys.stderr)
print(src, file=sys.stderr)
raise
func = evaldict[name]
if addsource:
attrs['__source__'] = src
self.update(func, **attrs)
return func
@classmethod
def create(cls, obj, body, evaldict, defaults=None,
doc=None, module=None, addsource=True, **attrs):
"""
Create a function from the strings name, signature, and body.
evaldict is the evaluation dictionary. If addsource is true, an
attribute __source__ is added to the result. The attributes attrs
are added, if any.
"""
if isinstance(obj, str): # "name(signature)"
name, rest = obj.strip().split('(', 1)
signature = rest[:-1] # strip a right parens
func = None
else: # a function
name = None
signature = None
func = obj
self = cls(func, name, signature, defaults, doc, module)
ibody = '\n'.join(' ' + line for line in body.splitlines())
return self.make('def %(name)s(%(signature)s):\n' + ibody,
evaldict, addsource, **attrs)
def decorate(func, caller):
"""
decorate(func, caller) decorates a function using a caller.
"""
evaldict = func.__globals__.copy()
evaldict['_call_'] = caller
evaldict['_func_'] = func
fun = FunctionMaker.create(
func, "return _call_(_func_, %(shortsignature)s)",
evaldict, __wrapped__=func)
if hasattr(func, '__qualname__'):
fun.__qualname__ = func.__qualname__
return fun
def decorator(caller, _func=None):
"""decorator(caller) converts a caller function into a decorator"""
if _func is not None: # return a decorated function
# this is obsolete behavior; you should use decorate instead
return decorate(_func, caller)
# else return a decorator function
if inspect.isclass(caller):
name = caller.__name__.lower()
callerfunc = get_init(caller)
doc = (f'decorator({caller.__name__}) converts functions/generators into '
f'factories of {caller.__name__} objects')
elif inspect.isfunction(caller):
if caller.__name__ == '<lambda>':
name = '_lambda_'
else:
name = caller.__name__
callerfunc = caller
doc = caller.__doc__
else: # assume caller is an object with a __call__ method
name = caller.__class__.__name__.lower()
callerfunc = caller.__call__.__func__
doc = caller.__call__.__doc__
evaldict = callerfunc.__globals__.copy()
evaldict['_call_'] = caller
evaldict['_decorate_'] = decorate
return FunctionMaker.create(
'%s(func)' % name, 'return _decorate_(func, _call_)',
evaldict, doc=doc, module=caller.__module__,
__wrapped__=caller)
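# Illustrative usage sketch (an assumption, not part of the upstream module):
# a caller taking (func, *args, **kwargs) becomes a signature-preserving
# decorator, e.g.
#
#     >>> @decorator
#     ... def trace_calls(func, *args, **kw):
#     ...     print('calling', func.__name__)
#     ...     return func(*args, **kw)
#     >>> @trace_calls
#     ... def add(x, y):
#     ...     return x + y
#     >>> add(1, 2)            # doctest: +SKIP
#     calling add
#     3
#     >>> getargspec(add).args  # the wrapper keeps the original signature
#     ['x', 'y']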
# ####################### contextmanager ####################### #
try: # Python >= 3.2
from contextlib import _GeneratorContextManager
except ImportError: # Python >= 2.5
from contextlib import GeneratorContextManager as _GeneratorContextManager
class ContextManager(_GeneratorContextManager):
def __call__(self, func):
"""Context manager decorator"""
return FunctionMaker.create(
func, "with _self_: return _func_(%(shortsignature)s)",
dict(_self_=self, _func_=func), __wrapped__=func)
init = getfullargspec(_GeneratorContextManager.__init__)
n_args = len(init.args)
if n_args == 2 and not init.varargs: # (self, genobj) Python 2.7
def __init__(self, g, *a, **k):
return _GeneratorContextManager.__init__(self, g(*a, **k))
ContextManager.__init__ = __init__
elif n_args == 2 and init.varargs: # (self, gen, *a, **k) Python 3.4
pass
elif n_args == 4: # (self, gen, args, kwds) Python 3.5
def __init__(self, g, *a, **k):
return _GeneratorContextManager.__init__(self, g, a, k)
ContextManager.__init__ = __init__
contextmanager = decorator(ContextManager)
# ############################ dispatch_on ############################ #
def append(a, vancestors):
"""
Append ``a`` to the list of the virtual ancestors, unless it is already
included.
"""
add = True
for j, va in enumerate(vancestors):
if issubclass(va, a):
add = False
break
if issubclass(a, va):
vancestors[j] = a
add = False
if add:
vancestors.append(a)
# inspired from simplegeneric by P.J. Eby and functools.singledispatch
def dispatch_on(*dispatch_args):
"""
Factory of decorators turning a function into a generic function
dispatching on the given arguments.
"""
assert dispatch_args, 'No dispatch args passed'
dispatch_str = '(%s,)' % ', '.join(dispatch_args)
def check(arguments, wrong=operator.ne, msg=''):
"""Make sure one passes the expected number of arguments"""
if wrong(len(arguments), len(dispatch_args)):
raise TypeError('Expected %d arguments, got %d%s' %
(len(dispatch_args), len(arguments), msg))
def gen_func_dec(func):
"""Decorator turning a function into a generic function"""
# first check the dispatch arguments
argset = set(getfullargspec(func).args)
if not set(dispatch_args) <= argset:
raise NameError('Unknown dispatch arguments %s' % dispatch_str)
typemap = {}
def vancestors(*types):
"""
Get a list of sets of virtual ancestors for the given types
"""
check(types)
ras = [[] for _ in range(len(dispatch_args))]
for types_ in typemap:
for t, type_, ra in zip(types, types_, ras):
if issubclass(t, type_) and type_ not in t.__mro__:
append(type_, ra)
return [set(ra) for ra in ras]
def ancestors(*types):
"""
Get a list of virtual MROs, one for each type
"""
check(types)
lists = []
for t, vas in zip(types, vancestors(*types)):
n_vas = len(vas)
if n_vas > 1:
raise RuntimeError(
f'Ambiguous dispatch for {t}: {vas}')
elif n_vas == 1:
va, = vas
mro = type('t', (t, va), {}).__mro__[1:]
else:
mro = t.__mro__
lists.append(mro[:-1]) # discard t and object
return lists
def register(*types):
"""
Decorator to register an implementation for the given types
"""
check(types)
def dec(f):
check(getfullargspec(f).args, operator.lt, ' in ' + f.__name__)
typemap[types] = f
return f
return dec
def dispatch_info(*types):
"""
            A utility to introspect the dispatch algorithm
"""
check(types)
lst = [tuple(a.__name__ for a in anc)
for anc in itertools.product(*ancestors(*types))]
return lst
def _dispatch(dispatch_args, *args, **kw):
types = tuple(type(arg) for arg in dispatch_args)
try: # fast path
f = typemap[types]
except KeyError:
pass
else:
return f(*args, **kw)
combinations = itertools.product(*ancestors(*types))
next(combinations) # the first one has been already tried
for types_ in combinations:
f = typemap.get(types_)
if f is not None:
return f(*args, **kw)
# else call the default implementation
return func(*args, **kw)
return FunctionMaker.create(
func, 'return _f_(%s, %%(shortsignature)s)' % dispatch_str,
dict(_f_=_dispatch), register=register, default=func,
typemap=typemap, vancestors=vancestors, ancestors=ancestors,
dispatch_info=dispatch_info, __wrapped__=func)
gen_func_dec.__name__ = 'dispatch_on' + dispatch_str
return gen_func_dec

View File

@ -0,0 +1,239 @@
from inspect import Parameter, signature
import functools
import warnings
from importlib import import_module
__all__ = ["_deprecated"]
# Object to use as default value for arguments to be deprecated. This should
# be used over 'None' as the user could parse 'None' as a positional argument
_NoValue = object()
def _sub_module_deprecation(*, sub_package, module, private_modules, all,
attribute, correct_module=None):
"""Helper function for deprecating modules that are public but were
intended to be private.
Parameters
----------
sub_package : str
        Subpackage the module belongs to, e.g. ``stats``.
module : str
Public but intended private module to deprecate
private_modules : list
Private replacement(s) for `module`; should contain the
content of ``all``, possibly spread over several modules.
all : list
``__all__`` belonging to `module`
attribute : str
The attribute in `module` being accessed
correct_module : str, optional
Module in `sub_package` that `attribute` should be imported from.
Default is that `attribute` should be imported from ``scipy.sub_package``.
"""
if correct_module is not None:
correct_import = f"scipy.{sub_package}.{correct_module}"
else:
correct_import = f"scipy.{sub_package}"
if attribute not in all:
raise AttributeError(
f"`scipy.{sub_package}.{module}` has no attribute `{attribute}`; "
f"furthermore, `scipy.{sub_package}.{module}` is deprecated "
f"and will be removed in SciPy 2.0.0."
)
attr = getattr(import_module(correct_import), attribute, None)
if attr is not None:
message = (
f"Please import `{attribute}` from the `{correct_import}` namespace; "
f"the `scipy.{sub_package}.{module}` namespace is deprecated "
f"and will be removed in SciPy 2.0.0."
)
else:
message = (
f"`scipy.{sub_package}.{module}.{attribute}` is deprecated along with "
f"the `scipy.{sub_package}.{module}` namespace. "
f"`scipy.{sub_package}.{module}.{attribute}` will be removed "
f"in SciPy 1.14.0, and the `scipy.{sub_package}.{module}` namespace "
f"will be removed in SciPy 2.0.0."
)
warnings.warn(message, category=DeprecationWarning, stacklevel=3)
for module in private_modules:
try:
return getattr(import_module(f"scipy.{sub_package}.{module}"), attribute)
except AttributeError as e:
# still raise an error if the attribute isn't in any of the expected
# private modules
if module == private_modules[-1]:
raise e
continue
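# Illustrative usage sketch (an assumption; the sub-package and module names
# below are placeholders): a deprecated public module typically forwards
# attribute access through a module-level ``__getattr__``, e.g.
#
#     def __getattr__(name):
#         return _sub_module_deprecation(
#             sub_package="signal", module="spectral",
#             private_modules=["_spectral_py"], all=__all__,
#             attribute=name,
#         )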
def _deprecated(msg, stacklevel=2):
"""Deprecate a function by emitting a warning on use."""
def wrap(fun):
if isinstance(fun, type):
warnings.warn(
f"Trying to deprecate class {fun!r}",
category=RuntimeWarning, stacklevel=2)
return fun
@functools.wraps(fun)
def call(*args, **kwargs):
warnings.warn(msg, category=DeprecationWarning,
stacklevel=stacklevel)
return fun(*args, **kwargs)
call.__doc__ = fun.__doc__
return call
return wrap
class _DeprecationHelperStr:
"""
Helper class used by deprecate_cython_api
"""
def __init__(self, content, message):
self._content = content
self._message = message
def __hash__(self):
return hash(self._content)
def __eq__(self, other):
res = (self._content == other)
if res:
warnings.warn(self._message, category=DeprecationWarning,
stacklevel=2)
return res
def deprecate_cython_api(module, routine_name, new_name=None, message=None):
"""
Deprecate an exported cdef function in a public Cython API module.
Only functions can be deprecated; typedefs etc. cannot.
Parameters
----------
module : module
Public Cython API module (e.g. scipy.linalg.cython_blas).
routine_name : str
Name of the routine to deprecate. May also be a fused-type
        routine (in which case all of its specializations are deprecated).
new_name : str
New name to include in the deprecation warning message
message : str
Additional text in the deprecation warning message
Examples
--------
Usually, this function would be used in the top-level of the
module ``.pyx`` file:
>>> from scipy._lib.deprecation import deprecate_cython_api
>>> import scipy.linalg.cython_blas as mod
>>> deprecate_cython_api(mod, "dgemm", "dgemm_new",
... message="Deprecated in Scipy 1.5.0")
>>> del deprecate_cython_api, mod
After this, Cython modules that use the deprecated function emit a
deprecation warning when they are imported.
"""
old_name = f"{module.__name__}.{routine_name}"
if new_name is None:
depdoc = "`%s` is deprecated!" % old_name
else:
depdoc = f"`{old_name}` is deprecated, use `{new_name}` instead!"
if message is not None:
depdoc += "\n" + message
d = module.__pyx_capi__
# Check if the function is a fused-type function with a mangled name
j = 0
has_fused = False
while True:
fused_name = f"__pyx_fuse_{j}{routine_name}"
if fused_name in d:
has_fused = True
d[_DeprecationHelperStr(fused_name, depdoc)] = d.pop(fused_name)
j += 1
else:
break
# If not, apply deprecation to the named routine
if not has_fused:
d[_DeprecationHelperStr(routine_name, depdoc)] = d.pop(routine_name)
# taken from scikit-learn, see
# https://github.com/scikit-learn/scikit-learn/blob/1.3.0/sklearn/utils/validation.py#L38
def _deprecate_positional_args(func=None, *, version=None):
"""Decorator for methods that issues warnings for positional arguments.
Using the keyword-only argument syntax in pep 3102, arguments after the
* will issue a warning when passed as a positional argument.
Parameters
----------
func : callable, default=None
Function to check arguments on.
    version : str, default=None
        The version in which passing positional arguments will result in an error.
"""
if version is None:
msg = "Need to specify a version where signature will be changed"
raise ValueError(msg)
def _inner_deprecate_positional_args(f):
sig = signature(f)
kwonly_args = []
all_args = []
for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)
@functools.wraps(f)
def inner_f(*args, **kwargs):
extra_args = len(args) - len(all_args)
if extra_args <= 0:
return f(*args, **kwargs)
# extra_args > 0
args_msg = [
f"{name}={arg}"
for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:])
]
args_msg = ", ".join(args_msg)
warnings.warn(
(
f"You are passing {args_msg} as a positional argument. "
"Please change your invocation to use keyword arguments. "
f"From SciPy {version}, passing these as positional "
"arguments will result in an error."
),
DeprecationWarning,
stacklevel=2,
)
kwargs.update(zip(sig.parameters, args))
return f(**kwargs)
return inner_f
if func is not None:
return _inner_deprecate_positional_args(func)
return _inner_deprecate_positional_args
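# Illustrative usage sketch (an assumption, not part of the upstream module):
#
#     >>> @_deprecate_positional_args(version="1.14.0")
#     ... def resample(data, *, rate=1.0):
#     ...     return data
#     >>> resample([1, 2, 3], 2.0)  # emits a DeprecationWarning  # doctest: +SKIP
#     [1, 2, 3]
#
# Calling ``resample([1, 2, 3], rate=2.0)`` raises no warning, since the
# keyword-only argument is passed by name.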

View File

@ -0,0 +1,275 @@
''' Utilities to allow inserting docstring fragments for common
parameters into function and method docstrings'''
import sys
__all__ = [
'docformat', 'inherit_docstring_from', 'indentcount_lines',
'filldoc', 'unindent_dict', 'unindent_string', 'extend_notes_in_docstring',
'replace_notes_in_docstring', 'doc_replace'
]
def docformat(docstring, docdict=None):
''' Fill a function docstring from variables in dictionary
Adapt the indent of the inserted docs
Parameters
----------
docstring : string
docstring from function, possibly with dict formatting strings
docdict : dict, optional
dictionary with keys that match the dict formatting strings
and values that are docstring fragments to be inserted. The
indentation of the inserted docstrings is set to match the
minimum indentation of the ``docstring`` by adding this
indentation to all lines of the inserted string, except the
first.
Returns
-------
outstring : string
string with requested ``docdict`` strings inserted
Examples
--------
>>> docformat(' Test string with %(value)s', {'value':'inserted value'})
' Test string with inserted value'
>>> docstring = 'First line\\n Second line\\n %(value)s'
>>> inserted_string = "indented\\nstring"
>>> docdict = {'value': inserted_string}
>>> docformat(docstring, docdict)
'First line\\n Second line\\n indented\\n string'
'''
if not docstring:
return docstring
if docdict is None:
docdict = {}
if not docdict:
return docstring
lines = docstring.expandtabs().splitlines()
# Find the minimum indent of the main docstring, after first line
if len(lines) < 2:
icount = 0
else:
icount = indentcount_lines(lines[1:])
indent = ' ' * icount
# Insert this indent to dictionary docstrings
indented = {}
for name, dstr in docdict.items():
lines = dstr.expandtabs().splitlines()
try:
newlines = [lines[0]]
for line in lines[1:]:
newlines.append(indent+line)
indented[name] = '\n'.join(newlines)
except IndexError:
indented[name] = dstr
return docstring % indented
def inherit_docstring_from(cls):
"""
This decorator modifies the decorated function's docstring by
replacing occurrences of '%(super)s' with the docstring of the
method of the same name from the class `cls`.
If the decorated method has no docstring, it is simply given the
    docstring of the corresponding method from `cls`.
Parameters
----------
cls : Python class or instance
A class with a method with the same name as the decorated method.
The docstring of the method in this class replaces '%(super)s' in the
docstring of the decorated method.
Returns
-------
f : function
The decorator function that modifies the __doc__ attribute
of its argument.
Examples
--------
    In the following, the docstring for `Bar.func` is created using the
    docstring of `Foo.func`.
>>> class Foo:
... def func(self):
... '''Do something useful.'''
... return
...
>>> class Bar(Foo):
... @inherit_docstring_from(Foo)
... def func(self):
... '''%(super)s
... Do it fast.
... '''
... return
...
>>> b = Bar()
>>> b.func.__doc__
'Do something useful.\n Do it fast.\n '
"""
def _doc(func):
cls_docstring = getattr(cls, func.__name__).__doc__
func_docstring = func.__doc__
if func_docstring is None:
func.__doc__ = cls_docstring
else:
new_docstring = func_docstring % dict(super=cls_docstring)
func.__doc__ = new_docstring
return func
return _doc
def extend_notes_in_docstring(cls, notes):
"""
This decorator replaces the decorated function's docstring
with the docstring from corresponding method in `cls`.
It extends the 'Notes' section of that docstring to include
the given `notes`.
"""
def _doc(func):
cls_docstring = getattr(cls, func.__name__).__doc__
# If python is called with -OO option,
# there is no docstring
if cls_docstring is None:
return func
end_of_notes = cls_docstring.find(' References\n')
if end_of_notes == -1:
end_of_notes = cls_docstring.find(' Examples\n')
if end_of_notes == -1:
end_of_notes = len(cls_docstring)
func.__doc__ = (cls_docstring[:end_of_notes] + notes +
cls_docstring[end_of_notes:])
return func
return _doc
def replace_notes_in_docstring(cls, notes):
"""
This decorator replaces the decorated function's docstring
with the docstring from corresponding method in `cls`.
It replaces the 'Notes' section of that docstring with
the given `notes`.
"""
def _doc(func):
cls_docstring = getattr(cls, func.__name__).__doc__
notes_header = ' Notes\n -----\n'
# If python is called with -OO option,
# there is no docstring
if cls_docstring is None:
return func
start_of_notes = cls_docstring.find(notes_header)
end_of_notes = cls_docstring.find(' References\n')
if end_of_notes == -1:
end_of_notes = cls_docstring.find(' Examples\n')
if end_of_notes == -1:
end_of_notes = len(cls_docstring)
func.__doc__ = (cls_docstring[:start_of_notes + len(notes_header)] +
notes +
cls_docstring[end_of_notes:])
return func
return _doc
def indentcount_lines(lines):
''' Minimum indent for all lines in line list
>>> lines = [' one', ' two', ' three']
>>> indentcount_lines(lines)
1
>>> lines = []
>>> indentcount_lines(lines)
0
>>> lines = [' one']
>>> indentcount_lines(lines)
1
>>> indentcount_lines([' '])
0
'''
indentno = sys.maxsize
for line in lines:
stripped = line.lstrip()
if stripped:
indentno = min(indentno, len(line) - len(stripped))
if indentno == sys.maxsize:
return 0
return indentno
def filldoc(docdict, unindent_params=True):
''' Return docstring decorator using docdict variable dictionary
Parameters
----------
docdict : dictionary
dictionary containing name, docstring fragment pairs
unindent_params : {False, True}, boolean, optional
If True, strip common indentation from all parameters in
docdict
Returns
-------
decfunc : function
decorator that applies dictionary to input function docstring
'''
if unindent_params:
docdict = unindent_dict(docdict)
def decorate(f):
f.__doc__ = docformat(f.__doc__, docdict)
return f
return decorate
def unindent_dict(docdict):
''' Unindent all strings in a docdict '''
can_dict = {}
for name, dstr in docdict.items():
can_dict[name] = unindent_string(dstr)
return can_dict
def unindent_string(docstring):
''' Set docstring to minimum indent for all lines, including first
>>> unindent_string(' two')
'two'
>>> unindent_string(' two\\n three')
'two\\n three'
'''
lines = docstring.expandtabs().splitlines()
icount = indentcount_lines(lines)
if icount == 0:
return docstring
return '\n'.join([line[icount:] for line in lines])
def doc_replace(obj, oldval, newval):
"""Decorator to take the docstring from obj, with oldval replaced by newval
Equivalent to ``func.__doc__ = obj.__doc__.replace(oldval, newval)``
Parameters
----------
obj : object
The object to take the docstring from.
oldval : string
The string to replace from the original docstring.
newval : string
The string to replace ``oldval`` with.
"""
# __doc__ may be None for optimized Python (-OO)
doc = (obj.__doc__ or '').replace(oldval, newval)
def inner(func):
func.__doc__ = doc
return func
return inner

View File

@ -0,0 +1,101 @@
""" Test for assert_deallocated context manager and gc utilities
"""
import gc
from scipy._lib._gcutils import (set_gc_state, gc_state, assert_deallocated,
ReferenceError, IS_PYPY)
from numpy.testing import assert_equal
import pytest
def test_set_gc_state():
gc_status = gc.isenabled()
try:
for state in (True, False):
gc.enable()
set_gc_state(state)
assert_equal(gc.isenabled(), state)
gc.disable()
set_gc_state(state)
assert_equal(gc.isenabled(), state)
finally:
if gc_status:
gc.enable()
def test_gc_state():
# Test gc_state context manager
gc_status = gc.isenabled()
try:
for pre_state in (True, False):
set_gc_state(pre_state)
for with_state in (True, False):
# Check the gc state is with_state in with block
with gc_state(with_state):
assert_equal(gc.isenabled(), with_state)
# And returns to previous state outside block
assert_equal(gc.isenabled(), pre_state)
# Even if the gc state is set explicitly within the block
with gc_state(with_state):
assert_equal(gc.isenabled(), with_state)
set_gc_state(not with_state)
assert_equal(gc.isenabled(), pre_state)
finally:
if gc_status:
gc.enable()
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated():
# Ordinary use
class C:
def __init__(self, arg0, arg1, name='myname'):
self.name = name
for gc_current in (True, False):
with gc_state(gc_current):
# We are deleting from with-block context, so that's OK
with assert_deallocated(C, 0, 2, 'another name') as c:
assert_equal(c.name, 'another name')
del c
# Or not using the thing in with-block context, also OK
with assert_deallocated(C, 0, 2, name='third name'):
pass
assert_equal(gc.isenabled(), gc_current)
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_nodel():
class C:
pass
with pytest.raises(ReferenceError):
# Need to delete after using if in with-block context
# Note: assert_deallocated(C) needs to be assigned for the test
# to function correctly. It is assigned to _, but _ itself is
# not referenced in the body of the with, it is only there for
# the refcount.
with assert_deallocated(C) as _:
pass
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_circular():
class C:
def __init__(self):
self._circular = self
with pytest.raises(ReferenceError):
# Circular reference, no automatic garbage collection
with assert_deallocated(C) as c:
del c
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_circular2():
class C:
def __init__(self):
self._circular = self
with pytest.raises(ReferenceError):
# Still circular reference, no automatic garbage collection
with assert_deallocated(C):
pass

View File

@ -0,0 +1,67 @@
from pytest import raises as assert_raises
from scipy._lib._pep440 import Version, parse
def test_main_versions():
assert Version('1.8.0') == Version('1.8.0')
for ver in ['1.9.0', '2.0.0', '1.8.1']:
assert Version('1.8.0') < Version(ver)
for ver in ['1.7.0', '1.7.1', '0.9.9']:
assert Version('1.8.0') > Version(ver)
def test_version_1_point_10():
# regression test for gh-2998.
assert Version('1.9.0') < Version('1.10.0')
assert Version('1.11.0') < Version('1.11.1')
assert Version('1.11.0') == Version('1.11.0')
assert Version('1.99.11') < Version('1.99.12')
def test_alpha_beta_rc():
assert Version('1.8.0rc1') == Version('1.8.0rc1')
for ver in ['1.8.0', '1.8.0rc2']:
assert Version('1.8.0rc1') < Version(ver)
for ver in ['1.8.0a2', '1.8.0b3', '1.7.2rc4']:
assert Version('1.8.0rc1') > Version(ver)
assert Version('1.8.0b1') > Version('1.8.0a2')
def test_dev_version():
assert Version('1.9.0.dev+Unknown') < Version('1.9.0')
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev+ffffffff', '1.9.0.dev1']:
assert Version('1.9.0.dev+f16acvda') < Version(ver)
assert Version('1.9.0.dev+f16acvda') == Version('1.9.0.dev+f16acvda')
def test_dev_a_b_rc_mixed():
assert Version('1.9.0a2.dev+f16acvda') == Version('1.9.0a2.dev+f16acvda')
assert Version('1.9.0a2.dev+6acvda54') < Version('1.9.0a2')
def test_dev0_version():
assert Version('1.9.0.dev0+Unknown') < Version('1.9.0')
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev0+ffffffff']:
assert Version('1.9.0.dev0+f16acvda') < Version(ver)
assert Version('1.9.0.dev0+f16acvda') == Version('1.9.0.dev0+f16acvda')
def test_dev0_a_b_rc_mixed():
assert Version('1.9.0a2.dev0+f16acvda') == Version('1.9.0a2.dev0+f16acvda')
assert Version('1.9.0a2.dev0+6acvda54') < Version('1.9.0a2')
def test_raises():
for ver in ['1,9.0', '1.7.x']:
assert_raises(ValueError, Version, ver)
def test_legacy_version():
# Non-PEP-440 version identifiers always compare less. For NumPy this only
# occurs on dev builds prior to 1.10.0 which are unsupported anyway.
assert parse('invalid') < Version('0.0.0')
assert parse('1.9.0-f16acvda') < Version('1.0.0')

View File

@ -0,0 +1,32 @@
import sys
from scipy._lib._testutils import _parse_size, _get_mem_available
import pytest
def test__parse_size():
expected = {
'12': 12e6,
'12 b': 12,
'12k': 12e3,
' 12 M ': 12e6,
' 12 G ': 12e9,
' 12Tb ': 12e12,
'12 Mib ': 12 * 1024.0**2,
'12Tib': 12 * 1024.0**4,
}
for inp, outp in sorted(expected.items()):
if outp is None:
with pytest.raises(ValueError):
_parse_size(inp)
else:
assert _parse_size(inp) == outp
def test__mem_available():
# May return None on non-Linux platforms
available = _get_mem_available()
if sys.platform.startswith('linux'):
assert available >= 0
else:
assert available is None or available >= 0

View File

@ -0,0 +1,51 @@
import threading
import time
import traceback
from numpy.testing import assert_
from pytest import raises as assert_raises
from scipy._lib._threadsafety import ReentrancyLock, non_reentrant, ReentrancyError
def test_parallel_threads():
# Check that ReentrancyLock serializes work in parallel threads.
#
# The test is not fully deterministic, and may succeed falsely if
# the timings go wrong.
lock = ReentrancyLock("failure")
failflag = [False]
exceptions_raised = []
def worker(k):
try:
with lock:
assert_(not failflag[0])
failflag[0] = True
time.sleep(0.1 * k)
assert_(failflag[0])
failflag[0] = False
except Exception:
exceptions_raised.append(traceback.format_exc(2))
threads = [threading.Thread(target=lambda k=k: worker(k))
for k in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
exceptions_raised = "\n".join(exceptions_raised)
assert_(not exceptions_raised, exceptions_raised)
def test_reentering():
# Check that ReentrancyLock prevents re-entering from the same thread.
@non_reentrant()
def func(x):
return func(x)
assert_raises(ReentrancyError, func, 0)
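# Editor's sketch (illustrative, not part of the original file): the pattern
# exercised above.  A module-level ReentrancyLock serializes calls to a routine
# that must not run concurrently or recursively, and `non_reentrant()` wraps a
# function in such a lock.  The `_example_*` names are hypothetical.
_example_lock = ReentrancyLock("_example_routine is not re-entrant")

def _example_routine(x):
    with _example_lock:  # concurrent callers wait; re-entry raises ReentrancyError
        return x + 1

@non_reentrant()
def _example_recursive(x):
    # calling _example_recursive from within itself would raise ReentrancyError
    return x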

View File

@ -0,0 +1,447 @@
from multiprocessing import Pool
from multiprocessing.pool import Pool as PWL
import re
import math
from fractions import Fraction
import numpy as np
from numpy.testing import assert_equal, assert_
import pytest
from pytest import raises as assert_raises
import hypothesis.extra.numpy as npst
from hypothesis import given, strategies, reproduce_failure # noqa: F401
from scipy.conftest import array_api_compatible, skip_xp_invalid_arg
from scipy._lib._array_api import (xp_assert_equal, xp_assert_close, is_numpy,
copy as xp_copy)
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
getfullargspec_no_self, FullArgSpec,
rng_integers, _validate_int, _rename_parameter,
_contains_nan, _rng_html_rewrite, _lazywhere)
skip_xp_backends = pytest.mark.skip_xp_backends
@pytest.mark.slow
def test__aligned_zeros():
niter = 10
def check(shape, dtype, order, align):
err_msg = repr((shape, dtype, order, align))
x = _aligned_zeros(shape, dtype, order, align=align)
if align is None:
align = np.dtype(dtype).alignment
assert_equal(x.__array_interface__['data'][0] % align, 0)
if hasattr(shape, '__len__'):
assert_equal(x.shape, shape, err_msg)
else:
assert_equal(x.shape, (shape,), err_msg)
assert_equal(x.dtype, dtype)
if order == "C":
assert_(x.flags.c_contiguous, err_msg)
elif order == "F":
if x.size > 0:
# Size-0 arrays get invalid flags on NumPy 1.5
assert_(x.flags.f_contiguous, err_msg)
elif order is None:
assert_(x.flags.c_contiguous, err_msg)
else:
raise ValueError()
# try various alignments
for align in [1, 2, 3, 4, 8, 16, 32, 64, None]:
for n in [0, 1, 3, 11]:
for order in ["C", "F", None]:
for dtype in [np.uint8, np.float64]:
for shape in [n, (1, 2, 3, n)]:
for j in range(niter):
check(shape, dtype, order, align)
def test_check_random_state():
# If seed is None, return the RandomState singleton used by np.random.
# If seed is an int, return a new RandomState instance seeded with seed.
# If seed is already a RandomState instance, return it.
# Otherwise raise ValueError.
rsi = check_random_state(1)
assert_equal(type(rsi), np.random.RandomState)
rsi = check_random_state(rsi)
assert_equal(type(rsi), np.random.RandomState)
rsi = check_random_state(None)
assert_equal(type(rsi), np.random.RandomState)
assert_raises(ValueError, check_random_state, 'a')
rg = np.random.Generator(np.random.PCG64())
rsi = check_random_state(rg)
assert_equal(type(rsi), np.random.Generator)
def test_getfullargspec_no_self():
p = MapWrapper(1)
argspec = getfullargspec_no_self(p.__init__)
assert_equal(argspec, FullArgSpec(['pool'], None, None, (1,), [],
None, {}))
argspec = getfullargspec_no_self(p.__call__)
assert_equal(argspec, FullArgSpec(['func', 'iterable'], None, None, None,
[], None, {}))
class _rv_generic:
def _rvs(self, a, b=2, c=3, *args, size=None, **kwargs):
return None
rv_obj = _rv_generic()
argspec = getfullargspec_no_self(rv_obj._rvs)
assert_equal(argspec, FullArgSpec(['a', 'b', 'c'], 'args', 'kwargs',
(2, 3), ['size'], {'size': None}, {}))
def test_mapwrapper_serial():
in_arg = np.arange(10.)
out_arg = np.sin(in_arg)
p = MapWrapper(1)
assert_(p._mapfunc is map)
assert_(p.pool is None)
assert_(p._own_pool is False)
out = list(p(np.sin, in_arg))
assert_equal(out, out_arg)
with assert_raises(RuntimeError):
p = MapWrapper(0)
def test_pool():
with Pool(2) as p:
p.map(math.sin, [1, 2, 3, 4])
def test_mapwrapper_parallel():
in_arg = np.arange(10.)
out_arg = np.sin(in_arg)
with MapWrapper(2) as p:
out = p(np.sin, in_arg)
assert_equal(list(out), out_arg)
assert_(p._own_pool is True)
assert_(isinstance(p.pool, PWL))
assert_(p._mapfunc is not None)
# the context manager should've closed the internal pool
# check that it has by asking it to calculate again.
with assert_raises(Exception) as excinfo:
p(np.sin, in_arg)
assert_(excinfo.type is ValueError)
# can also set a MapWrapper up with a map-like callable instance
with Pool(2) as p:
q = MapWrapper(p.map)
assert_(q._own_pool is False)
q.close()
# closing the MapWrapper shouldn't close the internal pool
# because it didn't create it
out = p.map(np.sin, in_arg)
assert_equal(list(out), out_arg)
def test_rng_integers():
rng = np.random.RandomState()
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 0
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 0
assert arr.shape == (100, )
# now try with np.random.Generator
try:
rng = np.random.default_rng()
except AttributeError:
return
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 0
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 0
assert arr.shape == (100, )
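# Editor's sketch (hypothetical, not SciPy's implementation): the behaviour the
# assertions above pin down.  `rng_integers` has to paper over the fact that
# RandomState.randint excludes `high` while Generator.integers takes an
# `endpoint` flag; a minimal shim could look like this.
def _rng_integers_sketch(gen, low, high=None, size=None, endpoint=False):
    if isinstance(gen, np.random.Generator):
        return gen.integers(low, high=high, size=size, endpoint=endpoint)
    # RandomState path: emulate endpoint=True by widening the half-open range
    if endpoint:
        low, high = (0, low + 1) if high is None else (low, high + 1)
    return gen.randint(low, high=high, size=size)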
class TestValidateInt:
@pytest.mark.parametrize('n', [4, np.uint8(4), np.int16(4), np.array(4)])
def test_validate_int(self, n):
n = _validate_int(n, 'n')
assert n == 4
@pytest.mark.parametrize('n', [4.0, np.array([4]), Fraction(4, 1)])
def test_validate_int_bad(self, n):
with pytest.raises(TypeError, match='n must be an integer'):
_validate_int(n, 'n')
def test_validate_int_below_min(self):
with pytest.raises(ValueError, match='n must be an integer not '
'less than 0'):
_validate_int(-1, 'n', 0)
class TestRenameParameter:
# check that wrapper `_rename_parameter` for backward-compatible
# keyword renaming works correctly
# Example method/function that still accepts keyword `old`
@_rename_parameter("old", "new")
def old_keyword_still_accepted(self, new):
return new
# Example method/function for which keyword `old` is deprecated
@_rename_parameter("old", "new", dep_version="1.9.0")
def old_keyword_deprecated(self, new):
return new
def test_old_keyword_still_accepted(self):
# positional argument and both keyword work identically
res1 = self.old_keyword_still_accepted(10)
res2 = self.old_keyword_still_accepted(new=10)
res3 = self.old_keyword_still_accepted(old=10)
assert res1 == res2 == res3 == 10
# unexpected keyword raises an error
message = re.escape("old_keyword_still_accepted() got an unexpected")
with pytest.raises(TypeError, match=message):
self.old_keyword_still_accepted(unexpected=10)
# multiple values for the same parameter raises an error
message = re.escape("old_keyword_still_accepted() got multiple")
with pytest.raises(TypeError, match=message):
self.old_keyword_still_accepted(10, new=10)
with pytest.raises(TypeError, match=message):
self.old_keyword_still_accepted(10, old=10)
with pytest.raises(TypeError, match=message):
self.old_keyword_still_accepted(new=10, old=10)
def test_old_keyword_deprecated(self):
# positional argument and both keyword work identically,
# but use of old keyword results in DeprecationWarning
dep_msg = "Use of keyword argument `old` is deprecated"
res1 = self.old_keyword_deprecated(10)
res2 = self.old_keyword_deprecated(new=10)
with pytest.warns(DeprecationWarning, match=dep_msg):
res3 = self.old_keyword_deprecated(old=10)
assert res1 == res2 == res3 == 10
# unexpected keyword raises an error
message = re.escape("old_keyword_deprecated() got an unexpected")
with pytest.raises(TypeError, match=message):
self.old_keyword_deprecated(unexpected=10)
# multiple values for the same parameter raises an error and,
# if old keyword is used, results in DeprecationWarning
message = re.escape("old_keyword_deprecated() got multiple")
with pytest.raises(TypeError, match=message):
self.old_keyword_deprecated(10, new=10)
with pytest.raises(TypeError, match=message), \
pytest.warns(DeprecationWarning, match=dep_msg):
self.old_keyword_deprecated(10, old=10)
with pytest.raises(TypeError, match=message), \
pytest.warns(DeprecationWarning, match=dep_msg):
self.old_keyword_deprecated(new=10, old=10)
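# Editor's sketch (hypothetical, not SciPy's `_rename_parameter`): a minimal
# decorator with the behaviour described above -- the old keyword is still
# accepted and forwarded to the new one, and a DeprecationWarning is emitted
# when a deprecation version is given.
import functools
import warnings as _warnings

def _rename_parameter_sketch(old, new, dep_version=None):
    def decorator(fun):
        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            if old in kwargs:
                if dep_version:
                    _warnings.warn(f"Use of keyword argument `{old}` is "
                                   "deprecated", DeprecationWarning,
                                   stacklevel=2)
                if new in kwargs:
                    raise TypeError(f"{fun.__name__}() got multiple values "
                                    f"for argument `{new}`")
                kwargs[new] = kwargs.pop(old)
            return fun(*args, **kwargs)
        return wrapper
    return decorator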
class TestContainsNaNTest:
def test_policy(self):
data = np.array([1, 2, 3, np.nan])
contains_nan, nan_policy = _contains_nan(data, nan_policy="propagate")
assert contains_nan
assert nan_policy == "propagate"
contains_nan, nan_policy = _contains_nan(data, nan_policy="omit")
assert contains_nan
assert nan_policy == "omit"
msg = "The input contains nan values"
with pytest.raises(ValueError, match=msg):
_contains_nan(data, nan_policy="raise")
msg = "nan_policy must be one of"
with pytest.raises(ValueError, match=msg):
_contains_nan(data, nan_policy="nan")
def test_contains_nan(self):
data1 = np.array([1, 2, 3])
assert not _contains_nan(data1)[0]
data2 = np.array([1, 2, 3, np.nan])
assert _contains_nan(data2)[0]
data3 = np.array([np.nan, 2, 3, np.nan])
assert _contains_nan(data3)[0]
data4 = np.array([[1, 2], [3, 4]])
assert not _contains_nan(data4)[0]
data5 = np.array([[1, 2], [3, np.nan]])
assert _contains_nan(data5)[0]
@skip_xp_invalid_arg
def test_contains_nan_with_strings(self):
data1 = np.array([1, 2, "3", np.nan]) # converted to string "nan"
assert not _contains_nan(data1)[0]
data2 = np.array([1, 2, "3", np.nan], dtype='object')
assert _contains_nan(data2)[0]
data3 = np.array([["1", 2], [3, np.nan]]) # converted to string "nan"
assert not _contains_nan(data3)[0]
data4 = np.array([["1", 2], [3, np.nan]], dtype='object')
assert _contains_nan(data4)[0]
@skip_xp_backends('jax.numpy',
reasons=["JAX arrays do not support item assignment"])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
@pytest.mark.parametrize("nan_policy", ['propagate', 'omit', 'raise'])
def test_array_api(self, xp, nan_policy):
rng = np.random.default_rng(932347235892482)
x0 = rng.random(size=(2, 3, 4))
x = xp.asarray(x0)
x_nan = xp_copy(x, xp=xp)
x_nan[1, 2, 1] = np.nan
contains_nan, nan_policy_out = _contains_nan(x, nan_policy=nan_policy)
assert not contains_nan
assert nan_policy_out == nan_policy
if nan_policy == 'raise':
message = 'The input contains...'
with pytest.raises(ValueError, match=message):
_contains_nan(x_nan, nan_policy=nan_policy)
elif nan_policy == 'omit' and not is_numpy(xp):
message = "`nan_policy='omit' is incompatible..."
with pytest.raises(ValueError, match=message):
_contains_nan(x_nan, nan_policy=nan_policy)
elif nan_policy == 'propagate':
contains_nan, nan_policy_out = _contains_nan(
x_nan, nan_policy=nan_policy)
assert contains_nan
assert nan_policy_out == nan_policy
def test__rng_html_rewrite():
def mock_str():
lines = [
'np.random.default_rng(8989843)',
'np.random.default_rng(seed)',
'np.random.default_rng(0x9a71b21474694f919882289dc1559ca)',
' bob ',
]
return lines
res = _rng_html_rewrite(mock_str)()
ref = [
'np.random.default_rng()',
'np.random.default_rng(seed)',
'np.random.default_rng()',
' bob ',
]
assert res == ref
class TestLazywhere:
n_arrays = strategies.integers(min_value=1, max_value=3)
rng_seed = strategies.integers(min_value=1000000000, max_value=9999999999)
dtype = strategies.sampled_from((np.float32, np.float64))
p = strategies.floats(min_value=0, max_value=1)
data = strategies.data()
@pytest.mark.fail_slow(5)
@pytest.mark.filterwarnings('ignore::RuntimeWarning') # overflows, etc.
@skip_xp_backends('jax.numpy',
reasons=["JAX arrays do not support item assignment"])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
@given(n_arrays=n_arrays, rng_seed=rng_seed, dtype=dtype, p=p, data=data)
def test_basic(self, n_arrays, rng_seed, dtype, p, data, xp):
mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays+1,
min_side=0)
input_shapes, result_shape = data.draw(mbs)
cond_shape, *shapes = input_shapes
fillvalue = xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=tuple())))
arrays = [xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=shape)))
for shape in shapes]
def f(*args):
return sum(arg for arg in args)
def f2(*args):
return sum(arg for arg in args) / 2
rng = np.random.default_rng(rng_seed)
cond = xp.asarray(rng.random(size=cond_shape) > p)
res1 = _lazywhere(cond, arrays, f, fillvalue)
res2 = _lazywhere(cond, arrays, f, f2=f2)
# Ensure arrays are at least 1d to follow sane type promotion rules.
if xp == np:
cond, fillvalue, *arrays = np.atleast_1d(cond, fillvalue, *arrays)
ref1 = xp.where(cond, f(*arrays), fillvalue)
ref2 = xp.where(cond, f(*arrays), f2(*arrays))
if xp == np:
ref1 = ref1.reshape(result_shape)
ref2 = ref2.reshape(result_shape)
res1 = xp.asarray(res1)[()]
res2 = xp.asarray(res2)[()]
assert isinstance(res1, type(xp.asarray([])))
xp_assert_close(res1, ref1, rtol=2e-16)
assert_equal(res1.shape, ref1.shape)
assert_equal(res1.dtype, ref1.dtype)
assert isinstance(res2, type(xp.asarray([])))
xp_assert_equal(res2, ref2)
assert_equal(res2.shape, ref2.shape)
assert_equal(res2.dtype, ref2.dtype)
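# Editor's note (sketch, not part of the original file): for the NumPy backend
# the reference values above already spell out the contract --
# _lazywhere(cond, arrays, f, fillvalue) matches np.where(cond, f(*arrays), fillvalue)
# and _lazywhere(cond, arrays, f, f2=f2) matches np.where(cond, f(*arrays), f2(*arrays)),
# except that f is only evaluated where cond is True.  A NumPy-only toy version:
def _lazywhere_numpy_sketch(cond, arrays, f, fillvalue):
    shape = np.broadcast_shapes(np.shape(cond), *[np.shape(a) for a in arrays])
    cond = np.broadcast_to(np.asarray(cond, dtype=bool), shape)
    out = np.full(shape, fillvalue, dtype=np.result_type(*arrays))
    args = [np.broadcast_to(a, shape)[cond] for a in arrays]
    out[cond] = f(*args)  # f only ever sees the elements selected by cond
    return out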

View File

@ -0,0 +1,114 @@
import numpy as np
import pytest
from scipy.conftest import array_api_compatible
from scipy._lib._array_api import (
_GLOBAL_CONFIG, array_namespace, _asarray, copy, xp_assert_equal, is_numpy
)
import scipy._lib.array_api_compat.numpy as np_compat
skip_xp_backends = pytest.mark.skip_xp_backends
@pytest.mark.skipif(not _GLOBAL_CONFIG["SCIPY_ARRAY_API"],
reason="Array API test; set environment variable SCIPY_ARRAY_API=1 to run it")
class TestArrayAPI:
def test_array_namespace(self):
x, y = np.array([0, 1, 2]), np.array([0, 1, 2])
xp = array_namespace(x, y)
assert 'array_api_compat.numpy' in xp.__name__
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = False
xp = array_namespace(x, y)
assert 'array_api_compat.numpy' in xp.__name__
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True
@array_api_compatible
def test_asarray(self, xp):
x, y = _asarray([0, 1, 2], xp=xp), _asarray(np.arange(3), xp=xp)
ref = xp.asarray([0, 1, 2])
xp_assert_equal(x, ref)
xp_assert_equal(y, ref)
@pytest.mark.filterwarnings("ignore: the matrix subclass")
def test_raises(self):
msg = "of type `numpy.ma.MaskedArray` are not supported"
with pytest.raises(TypeError, match=msg):
array_namespace(np.ma.array(1), np.array(1))
msg = "of type `numpy.matrix` are not supported"
with pytest.raises(TypeError, match=msg):
array_namespace(np.array(1), np.matrix(1))
msg = "only boolean and numerical dtypes are supported"
with pytest.raises(TypeError, match=msg):
array_namespace([object()])
with pytest.raises(TypeError, match=msg):
array_namespace('abc')
def test_array_likes(self):
# should be no exceptions
array_namespace([0, 1, 2])
array_namespace(1, 2, 3)
array_namespace(1)
@skip_xp_backends('jax.numpy',
reasons=["JAX arrays do not support item assignment"])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
def test_copy(self, xp):
for _xp in [xp, None]:
x = xp.asarray([1, 2, 3])
y = copy(x, xp=_xp)
# with numpy we'd want to use np.shares_memory, but that's not specified
# in the array-api
x[0] = 10
x[1] = 11
x[2] = 12
assert x[0] != y[0]
assert x[1] != y[1]
assert x[2] != y[2]
assert id(x) != id(y)
@array_api_compatible
@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float32', 'float64'])
@pytest.mark.parametrize('shape', [(), (3,)])
def test_strict_checks(self, xp, dtype, shape):
# Check that `_strict_check` behaves as expected
dtype = getattr(xp, dtype)
x = xp.broadcast_to(xp.asarray(1, dtype=dtype), shape)
x = x if shape else x[()]
y = np_compat.asarray(1)[()]
options = dict(check_namespace=True, check_dtype=False, check_shape=False)
if xp == np:
xp_assert_equal(x, y, **options)
else:
with pytest.raises(AssertionError, match="Namespaces do not match."):
xp_assert_equal(x, y, **options)
options = dict(check_namespace=False, check_dtype=True, check_shape=False)
if y.dtype.name in str(x.dtype):
xp_assert_equal(x, y, **options)
else:
with pytest.raises(AssertionError, match="dtypes do not match."):
xp_assert_equal(x, y, **options)
options = dict(check_namespace=False, check_dtype=False, check_shape=True)
if x.shape == y.shape:
xp_assert_equal(x, y, **options)
else:
with pytest.raises(AssertionError, match="Shapes do not match."):
xp_assert_equal(x, y, **options)
@array_api_compatible
def test_check_scalar(self, xp):
if not is_numpy(xp):
pytest.skip("Scalars only exist in NumPy")
if is_numpy(xp):
with pytest.raises(AssertionError, match="Types do not match."):
xp_assert_equal(xp.asarray(0.), xp.float64(0))
xp_assert_equal(xp.float64(0), xp.asarray(0.))

View File

@ -0,0 +1,162 @@
import pytest
import pickle
from numpy.testing import assert_equal
from scipy._lib._bunch import _make_tuple_bunch
# `Result` is defined at the top level of the module so it can be
# used to test pickling.
Result = _make_tuple_bunch('Result', ['x', 'y', 'z'], ['w', 'beta'])
class TestMakeTupleBunch:
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Tests with Result
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def setup_method(self):
# Set up an instance of Result.
self.result = Result(x=1, y=2, z=3, w=99, beta=0.5)
def test_attribute_access(self):
assert_equal(self.result.x, 1)
assert_equal(self.result.y, 2)
assert_equal(self.result.z, 3)
assert_equal(self.result.w, 99)
assert_equal(self.result.beta, 0.5)
def test_indexing(self):
assert_equal(self.result[0], 1)
assert_equal(self.result[1], 2)
assert_equal(self.result[2], 3)
assert_equal(self.result[-1], 3)
with pytest.raises(IndexError, match='index out of range'):
self.result[3]
def test_unpacking(self):
x0, y0, z0 = self.result
assert_equal((x0, y0, z0), (1, 2, 3))
assert_equal(self.result, (1, 2, 3))
def test_slice(self):
assert_equal(self.result[1:], (2, 3))
assert_equal(self.result[::2], (1, 3))
assert_equal(self.result[::-1], (3, 2, 1))
def test_len(self):
assert_equal(len(self.result), 3)
def test_repr(self):
s = repr(self.result)
assert_equal(s, 'Result(x=1, y=2, z=3, w=99, beta=0.5)')
def test_hash(self):
assert_equal(hash(self.result), hash((1, 2, 3)))
def test_pickle(self):
s = pickle.dumps(self.result)
obj = pickle.loads(s)
assert isinstance(obj, Result)
assert_equal(obj.x, self.result.x)
assert_equal(obj.y, self.result.y)
assert_equal(obj.z, self.result.z)
assert_equal(obj.w, self.result.w)
assert_equal(obj.beta, self.result.beta)
def test_read_only_existing(self):
with pytest.raises(AttributeError, match="can't set attribute"):
self.result.x = -1
def test_read_only_new(self):
self.result.plate_of_shrimp = "lattice of coincidence"
assert self.result.plate_of_shrimp == "lattice of coincidence"
def test_constructor_missing_parameter(self):
with pytest.raises(TypeError, match='missing'):
# `w` is missing.
Result(x=1, y=2, z=3, beta=0.75)
def test_constructor_incorrect_parameter(self):
with pytest.raises(TypeError, match='unexpected'):
# `foo` is not an existing field.
Result(x=1, y=2, z=3, w=123, beta=0.75, foo=999)
def test_module(self):
m = 'scipy._lib.tests.test_bunch'
assert_equal(Result.__module__, m)
assert_equal(self.result.__module__, m)
def test_extra_fields_per_instance(self):
# This test exists to ensure that instances of the same class
# store their own values for the extra fields. That is, the values
# are stored per instance and not in the class.
result1 = Result(x=1, y=2, z=3, w=-1, beta=0.0)
result2 = Result(x=4, y=5, z=6, w=99, beta=1.0)
assert_equal(result1.w, -1)
assert_equal(result1.beta, 0.0)
# The rest of these checks aren't essential, but let's check
# them anyway.
assert_equal(result1[:], (1, 2, 3))
assert_equal(result2.w, 99)
assert_equal(result2.beta, 1.0)
assert_equal(result2[:], (4, 5, 6))
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Other tests
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def test_extra_field_names_is_optional(self):
Square = _make_tuple_bunch('Square', ['width', 'height'])
sq = Square(width=1, height=2)
assert_equal(sq.width, 1)
assert_equal(sq.height, 2)
s = repr(sq)
assert_equal(s, 'Square(width=1, height=2)')
def test_tuple_like(self):
Tup = _make_tuple_bunch('Tup', ['a', 'b'])
tu = Tup(a=1, b=2)
assert isinstance(tu, tuple)
assert isinstance(tu + (1,), tuple)
def test_explicit_module(self):
m = 'some.module.name'
Foo = _make_tuple_bunch('Foo', ['x'], ['a', 'b'], module=m)
foo = Foo(x=1, a=355, b=113)
assert_equal(Foo.__module__, m)
assert_equal(foo.__module__, m)
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Argument validation
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
@pytest.mark.parametrize('args', [('123', ['a'], ['b']),
('Foo', ['-3'], ['x']),
('Foo', ['a'], ['+-*/'])])
def test_identifiers_not_allowed(self, args):
with pytest.raises(ValueError, match='identifiers'):
_make_tuple_bunch(*args)
@pytest.mark.parametrize('args', [('Foo', ['a', 'b', 'a'], ['x']),
('Foo', ['a', 'b'], ['b', 'x'])])
def test_repeated_field_names(self, args):
with pytest.raises(ValueError, match='Duplicate'):
_make_tuple_bunch(*args)
@pytest.mark.parametrize('args', [('Foo', ['_a'], ['x']),
('Foo', ['a'], ['_x'])])
def test_leading_underscore_not_allowed(self, args):
with pytest.raises(ValueError, match='underscore'):
_make_tuple_bunch(*args)
@pytest.mark.parametrize('args', [('Foo', ['def'], ['x']),
('Foo', ['a'], ['or']),
('and', ['a'], ['x'])])
def test_keyword_not_allowed_in_fields(self, args):
with pytest.raises(ValueError, match='keyword'):
_make_tuple_bunch(*args)
def test_at_least_one_field_name_required(self):
with pytest.raises(ValueError, match='at least one name'):
_make_tuple_bunch('Qwerty', [], ['a', 'b'])
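# Editor's sketch (not part of the original file): the behaviour pinned down by
# the tests above, in one place.  The field names used here are hypothetical.
def _example_usage():
    Fit = _make_tuple_bunch('Fit', ['slope', 'intercept'], ['rvalue'])
    res = Fit(slope=2.0, intercept=0.5, rvalue=0.99)
    slope, intercept = res         # unpacking uses the primary fields only
    assert (slope, intercept) == (2.0, 0.5)
    assert res == (2.0, 0.5)       # compares equal to the plain tuple
    assert res.rvalue == 0.99      # extra field is attribute-only
    return res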

View File

@ -0,0 +1,204 @@
from numpy.testing import assert_equal, assert_
from pytest import raises as assert_raises
import time
import pytest
import ctypes
import threading
from scipy._lib import _ccallback_c as _test_ccallback_cython
from scipy._lib import _test_ccallback
from scipy._lib._ccallback import LowLevelCallable
try:
import cffi
HAVE_CFFI = True
except ImportError:
HAVE_CFFI = False
ERROR_VALUE = 2.0
def callback_python(a, user_data=None):
if a == ERROR_VALUE:
raise ValueError("bad value")
if user_data is None:
return a + 1
else:
return a + user_data
def _get_cffi_func(base, signature):
if not HAVE_CFFI:
pytest.skip("cffi not installed")
# Get function address
voidp = ctypes.cast(base, ctypes.c_void_p)
address = voidp.value
# Create corresponding cffi handle
ffi = cffi.FFI()
func = ffi.cast(signature, address)
return func
def _get_ctypes_data():
value = ctypes.c_double(2.0)
return ctypes.cast(ctypes.pointer(value), ctypes.c_voidp)
def _get_cffi_data():
if not HAVE_CFFI:
pytest.skip("cffi not installed")
ffi = cffi.FFI()
return ffi.new('double *', 2.0)
CALLERS = {
'simple': _test_ccallback.test_call_simple,
'nodata': _test_ccallback.test_call_nodata,
'nonlocal': _test_ccallback.test_call_nonlocal,
'cython': _test_ccallback_cython.test_call_cython,
}
# These functions have signatures known to the callers
FUNCS = {
'python': lambda: callback_python,
'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
"plus1_cython"),
'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
'double (*)(double, int *, void *)'),
'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
"plus1b_cython"),
'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
'double (*)(double, double, int *, void *)'),
}
# These functions have signatures the callers don't know
BAD_FUNCS = {
'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
"plus1bc_cython"),
'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
'cffi_bc': lambda: _get_cffi_func(
_test_ccallback_cython.plus1bc_ctypes,
'double (*)(double, double, double, int *, void *)'
),
}
USER_DATAS = {
'ctypes': _get_ctypes_data,
'cffi': _get_cffi_data,
'capsule': _test_ccallback.test_get_data_capsule,
}
def test_callbacks():
def check(caller, func, user_data):
caller = CALLERS[caller]
func = FUNCS[func]()
user_data = USER_DATAS[user_data]()
if func is callback_python:
def func2(x):
return func(x, 2.0)
else:
func2 = LowLevelCallable(func, user_data)
func = LowLevelCallable(func)
# Test basic call
assert_equal(caller(func, 1.0), 2.0)
# Test 'bad' value resulting to an error
assert_raises(ValueError, caller, func, ERROR_VALUE)
# Test passing in user_data
assert_equal(caller(func2, 1.0), 3.0)
for caller in sorted(CALLERS.keys()):
for func in sorted(FUNCS.keys()):
for user_data in sorted(USER_DATAS.keys()):
check(caller, func, user_data)
def test_bad_callbacks():
def check(caller, func, user_data):
caller = CALLERS[caller]
user_data = USER_DATAS[user_data]()
func = BAD_FUNCS[func]()
if func is callback_python:
def func2(x):
return func(x, 2.0)
else:
func2 = LowLevelCallable(func, user_data)
func = LowLevelCallable(func)
# Test that basic call fails
assert_raises(ValueError, caller, LowLevelCallable(func), 1.0)
# Test that passing in user_data also fails
assert_raises(ValueError, caller, func2, 1.0)
# Test error message
llfunc = LowLevelCallable(func)
try:
caller(llfunc, 1.0)
except ValueError as err:
msg = str(err)
assert_(llfunc.signature in msg, msg)
assert_('double (double, double, int *, void *)' in msg, msg)
for caller in sorted(CALLERS.keys()):
for func in sorted(BAD_FUNCS.keys()):
for user_data in sorted(USER_DATAS.keys()):
check(caller, func, user_data)
def test_signature_override():
caller = _test_ccallback.test_call_simple
func = _test_ccallback.test_get_plus1_capsule()
llcallable = LowLevelCallable(func, signature="bad signature")
assert_equal(llcallable.signature, "bad signature")
assert_raises(ValueError, caller, llcallable, 3)
llcallable = LowLevelCallable(func, signature="double (double, int *, void *)")
assert_equal(llcallable.signature, "double (double, int *, void *)")
assert_equal(caller(llcallable, 3), 4)
def test_threadsafety():
def callback(a, caller):
if a <= 0:
return 1
else:
res = caller(lambda x: callback(x, caller), a - 1)
return 2*res
def check(caller):
caller = CALLERS[caller]
results = []
count = 10
def run():
time.sleep(0.01)
r = caller(lambda x: callback(x, caller), count)
results.append(r)
threads = [threading.Thread(target=run) for j in range(20)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert_equal(results, [2.0**count]*len(threads))
for caller in CALLERS.keys():
check(caller)

View File

@ -0,0 +1,10 @@
import pytest
def test_cython_api_deprecation():
match = ("`scipy._lib._test_deprecation_def.foo_deprecated` "
"is deprecated, use `foo` instead!\n"
"Deprecated in Scipy 42.0.0")
with pytest.warns(DeprecationWarning, match=match):
from .. import _test_deprecation_call
assert _test_deprecation_call.call() == (1, 1)

View File

@ -0,0 +1,17 @@
import pytest
import sys
import subprocess
from .test_public_api import PUBLIC_MODULES
# Regression tests for gh-6793.
# Check that all modules are importable in a new Python process.
# This is not necessarily true if there are import cycles present.
@pytest.mark.fail_slow(20)
@pytest.mark.slow
def test_public_modules_importable():
pids = [subprocess.Popen([sys.executable, '-c', f'import {module}'])
for module in PUBLIC_MODULES]
for i, pid in enumerate(pids):
assert pid.wait() == 0, f'Failed to import {PUBLIC_MODULES[i]}'

View File

@ -0,0 +1,496 @@
"""
This test script is adapted from:
https://github.com/numpy/numpy/blob/main/numpy/tests/test_public_api.py
"""
import pkgutil
import types
import importlib
import warnings
from importlib import import_module
import pytest
import scipy
from scipy.conftest import xp_available_backends
def test_dir_testing():
"""Assert that output of dir has only one "testing/tester"
attribute without duplicate"""
assert len(dir(scipy)) == len(set(dir(scipy)))
# Historically SciPy has not used leading underscores for private submodules
# much. This has resulted in lots of things that look like public modules
# (i.e. things that can be imported as `import scipy.somesubmodule.somefile`),
# but were never intended to be public. The PUBLIC_MODULES list contains
# modules that are either public because they were meant to be, or because they
# contain public functions/objects that aren't present in any other namespace
# for whatever reason and therefore should be treated as public.
PUBLIC_MODULES = ["scipy." + s for s in [
"cluster",
"cluster.vq",
"cluster.hierarchy",
"constants",
"datasets",
"fft",
"fftpack",
"integrate",
"interpolate",
"io",
"io.arff",
"io.matlab",
"io.wavfile",
"linalg",
"linalg.blas",
"linalg.cython_blas",
"linalg.lapack",
"linalg.cython_lapack",
"linalg.interpolative",
"misc",
"ndimage",
"odr",
"optimize",
"signal",
"signal.windows",
"sparse",
"sparse.linalg",
"sparse.csgraph",
"spatial",
"spatial.distance",
"spatial.transform",
"special",
"stats",
"stats.contingency",
"stats.distributions",
"stats.mstats",
"stats.qmc",
"stats.sampling"
]]
# The PRIVATE_BUT_PRESENT_MODULES list contains modules that lacked underscores
# in their name and hence looked public, but weren't meant to be. All these
# namespaces were deprecated in the 1.8.0 release - see "clear split between
# public and private API" in the 1.8.0 release notes.
# Support for these private modules will be removed in SciPy v2.0.0, as the
# deprecation messages emitted by each of these modules state.
PRIVATE_BUT_PRESENT_MODULES = [
'scipy.constants.codata',
'scipy.constants.constants',
'scipy.fftpack.basic',
'scipy.fftpack.convolve',
'scipy.fftpack.helper',
'scipy.fftpack.pseudo_diffs',
'scipy.fftpack.realtransforms',
'scipy.integrate.dop',
'scipy.integrate.lsoda',
'scipy.integrate.odepack',
'scipy.integrate.quadpack',
'scipy.integrate.vode',
'scipy.interpolate.dfitpack',
'scipy.interpolate.fitpack',
'scipy.interpolate.fitpack2',
'scipy.interpolate.interpnd',
'scipy.interpolate.interpolate',
'scipy.interpolate.ndgriddata',
'scipy.interpolate.polyint',
'scipy.interpolate.rbf',
'scipy.io.arff.arffread',
'scipy.io.harwell_boeing',
'scipy.io.idl',
'scipy.io.matlab.byteordercodes',
'scipy.io.matlab.mio',
'scipy.io.matlab.mio4',
'scipy.io.matlab.mio5',
'scipy.io.matlab.mio5_params',
'scipy.io.matlab.mio5_utils',
'scipy.io.matlab.mio_utils',
'scipy.io.matlab.miobase',
'scipy.io.matlab.streams',
'scipy.io.mmio',
'scipy.io.netcdf',
'scipy.linalg.basic',
'scipy.linalg.decomp',
'scipy.linalg.decomp_cholesky',
'scipy.linalg.decomp_lu',
'scipy.linalg.decomp_qr',
'scipy.linalg.decomp_schur',
'scipy.linalg.decomp_svd',
'scipy.linalg.matfuncs',
'scipy.linalg.misc',
'scipy.linalg.special_matrices',
'scipy.misc.common',
'scipy.misc.doccer',
'scipy.ndimage.filters',
'scipy.ndimage.fourier',
'scipy.ndimage.interpolation',
'scipy.ndimage.measurements',
'scipy.ndimage.morphology',
'scipy.odr.models',
'scipy.odr.odrpack',
'scipy.optimize.cobyla',
'scipy.optimize.cython_optimize',
'scipy.optimize.lbfgsb',
'scipy.optimize.linesearch',
'scipy.optimize.minpack',
'scipy.optimize.minpack2',
'scipy.optimize.moduleTNC',
'scipy.optimize.nonlin',
'scipy.optimize.optimize',
'scipy.optimize.slsqp',
'scipy.optimize.tnc',
'scipy.optimize.zeros',
'scipy.signal.bsplines',
'scipy.signal.filter_design',
'scipy.signal.fir_filter_design',
'scipy.signal.lti_conversion',
'scipy.signal.ltisys',
'scipy.signal.signaltools',
'scipy.signal.spectral',
'scipy.signal.spline',
'scipy.signal.waveforms',
'scipy.signal.wavelets',
'scipy.signal.windows.windows',
'scipy.sparse.base',
'scipy.sparse.bsr',
'scipy.sparse.compressed',
'scipy.sparse.construct',
'scipy.sparse.coo',
'scipy.sparse.csc',
'scipy.sparse.csr',
'scipy.sparse.data',
'scipy.sparse.dia',
'scipy.sparse.dok',
'scipy.sparse.extract',
'scipy.sparse.lil',
'scipy.sparse.linalg.dsolve',
'scipy.sparse.linalg.eigen',
'scipy.sparse.linalg.interface',
'scipy.sparse.linalg.isolve',
'scipy.sparse.linalg.matfuncs',
'scipy.sparse.sparsetools',
'scipy.sparse.spfuncs',
'scipy.sparse.sputils',
'scipy.spatial.ckdtree',
'scipy.spatial.kdtree',
'scipy.spatial.qhull',
'scipy.spatial.transform.rotation',
'scipy.special.add_newdocs',
'scipy.special.basic',
'scipy.special.cython_special',
'scipy.special.orthogonal',
'scipy.special.sf_error',
'scipy.special.specfun',
'scipy.special.spfun_stats',
'scipy.stats.biasedurn',
'scipy.stats.kde',
'scipy.stats.morestats',
'scipy.stats.mstats_basic',
'scipy.stats.mstats_extras',
'scipy.stats.mvn',
'scipy.stats.stats',
]
def is_unexpected(name):
"""Check if this needs to be considered."""
if '._' in name or '.tests' in name or '.setup' in name:
return False
if name in PUBLIC_MODULES:
return False
if name in PRIVATE_BUT_PRESENT_MODULES:
return False
return True
SKIP_LIST = [
'scipy.conftest',
'scipy.version',
'scipy.special.libsf_error_state'
]
# XXX: this test does more than it says on the tin - in using `pkgutil.walk_packages`,
# it will raise if it encounters any exceptions which are not handled by `ignore_errors`
# while attempting to import each discovered package.
# For now, `ignore_errors` only ignores what is necessary, but this could be expanded -
# for example, to all errors from private modules or git subpackages - if desired.
def test_all_modules_are_expected():
"""
Test that we don't add anything that looks like a new public module by
accident. Check is based on filenames.
"""
def ignore_errors(name):
# if versions of other array libraries are installed which are incompatible
# with the installed NumPy version, there can be errors on importing
# `array_api_compat`. This should only raise if SciPy is configured with
# that library as an available backend.
backends = {'cupy': 'cupy',
'pytorch': 'torch',
'dask.array': 'dask.array'}
for backend, dir_name in backends.items():
path = f'array_api_compat.{dir_name}'
if path in name and backend not in xp_available_backends:
return
raise
modnames = []
for _, modname, _ in pkgutil.walk_packages(path=scipy.__path__,
prefix=scipy.__name__ + '.',
onerror=ignore_errors):
if is_unexpected(modname) and modname not in SKIP_LIST:
# We have a name that is new. If that's on purpose, add it to
# PUBLIC_MODULES. We don't expect to have to add anything to
# PRIVATE_BUT_PRESENT_MODULES. Use an underscore in the name!
modnames.append(modname)
if modnames:
raise AssertionError(f'Found unexpected modules: {modnames}')
# Stuff that clearly shouldn't be in the API and is detected by the next test
# below
SKIP_LIST_2 = [
'scipy.char',
'scipy.rec',
'scipy.emath',
'scipy.math',
'scipy.random',
'scipy.ctypeslib',
'scipy.ma'
]
def test_all_modules_are_expected_2():
"""
Method checking all objects. The pkgutil-based method in
`test_all_modules_are_expected` does not catch imports into a namespace,
only filenames.
"""
def find_unexpected_members(mod_name):
members = []
module = importlib.import_module(mod_name)
if hasattr(module, '__all__'):
objnames = module.__all__
else:
objnames = dir(module)
for objname in objnames:
if not objname.startswith('_'):
fullobjname = mod_name + '.' + objname
if isinstance(getattr(module, objname), types.ModuleType):
if is_unexpected(fullobjname) and fullobjname not in SKIP_LIST_2:
members.append(fullobjname)
return members
unexpected_members = find_unexpected_members("scipy")
for modname in PUBLIC_MODULES:
unexpected_members.extend(find_unexpected_members(modname))
if unexpected_members:
raise AssertionError("Found unexpected object(s) that look like "
f"modules: {unexpected_members}")
def test_api_importable():
"""
Check that all submodules listed higher up in this file can be imported
Note that if a PRIVATE_BUT_PRESENT_MODULES entry goes missing, it may
simply need to be removed from the list (deprecation may or may not be
needed - apply common sense).
"""
def check_importable(module_name):
try:
importlib.import_module(module_name)
except (ImportError, AttributeError):
return False
return True
module_names = []
for module_name in PUBLIC_MODULES:
if not check_importable(module_name):
module_names.append(module_name)
if module_names:
raise AssertionError("Modules in the public API that cannot be "
f"imported: {module_names}")
with warnings.catch_warnings(record=True):
warnings.filterwarnings('always', category=DeprecationWarning)
warnings.filterwarnings('always', category=ImportWarning)
for module_name in PRIVATE_BUT_PRESENT_MODULES:
if not check_importable(module_name):
module_names.append(module_name)
if module_names:
raise AssertionError("Modules that are not really public but looked "
"public and can not be imported: "
f"{module_names}")
@pytest.mark.parametrize(("module_name", "correct_module"),
[('scipy.constants.codata', None),
('scipy.constants.constants', None),
('scipy.fftpack.basic', None),
('scipy.fftpack.helper', None),
('scipy.fftpack.pseudo_diffs', None),
('scipy.fftpack.realtransforms', None),
('scipy.integrate.dop', None),
('scipy.integrate.lsoda', None),
('scipy.integrate.odepack', None),
('scipy.integrate.quadpack', None),
('scipy.integrate.vode', None),
('scipy.interpolate.fitpack', None),
('scipy.interpolate.fitpack2', None),
('scipy.interpolate.interpolate', None),
('scipy.interpolate.ndgriddata', None),
('scipy.interpolate.polyint', None),
('scipy.interpolate.rbf', None),
('scipy.io.harwell_boeing', None),
('scipy.io.idl', None),
('scipy.io.mmio', None),
('scipy.io.netcdf', None),
('scipy.io.arff.arffread', 'arff'),
('scipy.io.matlab.byteordercodes', 'matlab'),
('scipy.io.matlab.mio_utils', 'matlab'),
('scipy.io.matlab.mio', 'matlab'),
('scipy.io.matlab.mio4', 'matlab'),
('scipy.io.matlab.mio5_params', 'matlab'),
('scipy.io.matlab.mio5_utils', 'matlab'),
('scipy.io.matlab.mio5', 'matlab'),
('scipy.io.matlab.miobase', 'matlab'),
('scipy.io.matlab.streams', 'matlab'),
('scipy.linalg.basic', None),
('scipy.linalg.decomp', None),
('scipy.linalg.decomp_cholesky', None),
('scipy.linalg.decomp_lu', None),
('scipy.linalg.decomp_qr', None),
('scipy.linalg.decomp_schur', None),
('scipy.linalg.decomp_svd', None),
('scipy.linalg.matfuncs', None),
('scipy.linalg.misc', None),
('scipy.linalg.special_matrices', None),
('scipy.misc.common', None),
('scipy.ndimage.filters', None),
('scipy.ndimage.fourier', None),
('scipy.ndimage.interpolation', None),
('scipy.ndimage.measurements', None),
('scipy.ndimage.morphology', None),
('scipy.odr.models', None),
('scipy.odr.odrpack', None),
('scipy.optimize.cobyla', None),
('scipy.optimize.lbfgsb', None),
('scipy.optimize.linesearch', None),
('scipy.optimize.minpack', None),
('scipy.optimize.minpack2', None),
('scipy.optimize.moduleTNC', None),
('scipy.optimize.nonlin', None),
('scipy.optimize.optimize', None),
('scipy.optimize.slsqp', None),
('scipy.optimize.tnc', None),
('scipy.optimize.zeros', None),
('scipy.signal.bsplines', None),
('scipy.signal.filter_design', None),
('scipy.signal.fir_filter_design', None),
('scipy.signal.lti_conversion', None),
('scipy.signal.ltisys', None),
('scipy.signal.signaltools', None),
('scipy.signal.spectral', None),
('scipy.signal.waveforms', None),
('scipy.signal.wavelets', None),
('scipy.signal.windows.windows', 'windows'),
('scipy.sparse.lil', None),
('scipy.sparse.linalg.dsolve', 'linalg'),
('scipy.sparse.linalg.eigen', 'linalg'),
('scipy.sparse.linalg.interface', 'linalg'),
('scipy.sparse.linalg.isolve', 'linalg'),
('scipy.sparse.linalg.matfuncs', 'linalg'),
('scipy.sparse.sparsetools', None),
('scipy.sparse.spfuncs', None),
('scipy.sparse.sputils', None),
('scipy.spatial.ckdtree', None),
('scipy.spatial.kdtree', None),
('scipy.spatial.qhull', None),
('scipy.spatial.transform.rotation', 'transform'),
('scipy.special.add_newdocs', None),
('scipy.special.basic', None),
('scipy.special.orthogonal', None),
('scipy.special.sf_error', None),
('scipy.special.specfun', None),
('scipy.special.spfun_stats', None),
('scipy.stats.biasedurn', None),
('scipy.stats.kde', None),
('scipy.stats.morestats', None),
('scipy.stats.mstats_basic', 'mstats'),
('scipy.stats.mstats_extras', 'mstats'),
('scipy.stats.mvn', None),
('scipy.stats.stats', None)])
def test_private_but_present_deprecation(module_name, correct_module):
# gh-18279, gh-17572, gh-17771 noted that deprecation warnings
# for imports from private modules
# were misleading. Check that this is resolved.
module = import_module(module_name)
if correct_module is None:
import_name = f'scipy.{module_name.split(".")[1]}'
else:
import_name = f'scipy.{module_name.split(".")[1]}.{correct_module}'
correct_import = import_module(import_name)
# Attributes that were formerly in `module_name` can still be imported from
# `module_name`, albeit with a deprecation warning.
for attr_name in module.__all__:
if attr_name == "varmats_from_mat":
# defer handling this case, see
# https://github.com/scipy/scipy/issues/19223
continue
# ensure attribute is present where the warning is pointing
assert getattr(correct_import, attr_name, None) is not None
message = f"Please import `{attr_name}` from the `{import_name}`..."
with pytest.deprecated_call(match=message):
getattr(module, attr_name)
# Attributes that were not in `module_name` get an error notifying the user
# that the attribute is not in `module_name` and that `module_name` is deprecated.
message = f"`{module_name}` is deprecated..."
with pytest.raises(AttributeError, match=message):
getattr(module, "ekki")
def test_misc_doccer_deprecation():
# gh-18279, gh-17572, gh-17771 noted that deprecation warnings
# for imports from private modules were misleading.
# Check that this is resolved.
# `test_private_but_present_deprecation` cannot be used since `correct_import`
# is a different subpackage (`_lib` instead of `misc`).
module = import_module('scipy.misc.doccer')
correct_import = import_module('scipy._lib.doccer')
# Attributes that were formerly in `scipy.misc.doccer` can still be imported from
# `scipy.misc.doccer`, albeit with a deprecation warning. The specific message
# depends on whether the attribute is in `scipy._lib.doccer` or not.
for attr_name in module.__all__:
attr = getattr(correct_import, attr_name, None)
if attr is None:
message = f"`scipy.misc.{attr_name}` is deprecated..."
else:
message = f"Please import `{attr_name}` from the `scipy._lib.doccer`..."
with pytest.deprecated_call(match=message):
getattr(module, attr_name)
# Attributes that were not in `scipy.misc.doccer` get an error
# notifying the user that the attribute is not in `scipy.misc.doccer`
# and that `scipy.misc.doccer` is deprecated.
message = "`scipy.misc.doccer` is deprecated..."
with pytest.raises(AttributeError, match=message):
getattr(module, "ekki")

View File

@ -0,0 +1,18 @@
import re
import scipy
from numpy.testing import assert_
def test_valid_scipy_version():
# Verify that the SciPy version is a valid one (no .post suffix or other
# nonsense). See NumPy issue gh-6431 for an issue caused by an invalid
# version.
version_pattern = r"^[0-9]+\.[0-9]+\.[0-9]+(|a[0-9]|b[0-9]|rc[0-9])"
dev_suffix = r"(\.dev0\+.+([0-9a-f]{7}|Unknown))"
if scipy.version.release:
res = re.match(version_pattern, scipy.__version__)
else:
res = re.match(version_pattern + dev_suffix, scipy.__version__)
assert_(res is not None, scipy.__version__)

View File

@ -0,0 +1,42 @@
""" Test tmpdirs module """
from os import getcwd
from os.path import realpath, abspath, dirname, isfile, join as pjoin, exists
from scipy._lib._tmpdirs import tempdir, in_tempdir, in_dir
from numpy.testing import assert_, assert_equal
MY_PATH = abspath(__file__)
MY_DIR = dirname(MY_PATH)
def test_tempdir():
with tempdir() as tmpdir:
fname = pjoin(tmpdir, 'example_file.txt')
with open(fname, "w") as fobj:
fobj.write('a string\\n')
assert_(not exists(tmpdir))
def test_in_tempdir():
my_cwd = getcwd()
with in_tempdir() as tmpdir:
with open('test.txt', "w") as f:
f.write('some text')
assert_(isfile('test.txt'))
assert_(isfile(pjoin(tmpdir, 'test.txt')))
assert_(not exists(tmpdir))
assert_equal(getcwd(), my_cwd)
def test_given_directory():
# Test InGivenDirectory
cwd = getcwd()
with in_dir() as tmpdir:
assert_equal(tmpdir, abspath(cwd))
assert_equal(tmpdir, abspath(getcwd()))
with in_dir(MY_DIR) as tmpdir:
assert_equal(tmpdir, MY_DIR)
assert_equal(realpath(MY_DIR), realpath(abspath(getcwd())))
# We were deleting the given directory! Check not so now.
assert_(isfile(MY_PATH))

View File

@ -0,0 +1,135 @@
"""
Tests which scan for certain occurrences in the code; they may not find
all of these occurrences but should catch almost all. This file was adapted
from NumPy.
"""
import os
from pathlib import Path
import ast
import tokenize
import scipy
import pytest
class ParseCall(ast.NodeVisitor):
def __init__(self):
self.ls = []
def visit_Attribute(self, node):
ast.NodeVisitor.generic_visit(self, node)
self.ls.append(node.attr)
def visit_Name(self, node):
self.ls.append(node.id)
class FindFuncs(ast.NodeVisitor):
def __init__(self, filename):
super().__init__()
self.__filename = filename
self.bad_filters = []
self.bad_stacklevels = []
def visit_Call(self, node):
p = ParseCall()
p.visit(node.func)
ast.NodeVisitor.generic_visit(self, node)
if p.ls[-1] == 'simplefilter' or p.ls[-1] == 'filterwarnings':
# get first argument of the `args` node of the filter call
match node.args[0]:
case ast.Constant() as c:
argtext = c.value
case ast.JoinedStr() as js:
# if we get an f-string, discard the templated pieces, which
# are likely the type or specific message; we're interested
# in the action, which is less likely to use a template
argtext = "".join(
x.value for x in js.values if isinstance(x, ast.Constant)
)
case _:
raise ValueError("unknown ast node type")
# check if filter is set to ignore
if argtext == "ignore":
self.bad_filters.append(
f"{self.__filename}:{node.lineno}")
if p.ls[-1] == 'warn' and (
len(p.ls) == 1 or p.ls[-2] == 'warnings'):
if self.__filename == "_lib/tests/test_warnings.py":
# This file
return
# See if stacklevel exists:
if len(node.args) == 3:
return
args = {kw.arg for kw in node.keywords}
if "stacklevel" not in args:
self.bad_stacklevels.append(
f"{self.__filename}:{node.lineno}")
@pytest.fixture(scope="session")
def warning_calls():
# combined "ignore" and stacklevel error
base = Path(scipy.__file__).parent
bad_filters = []
bad_stacklevels = []
for path in base.rglob("*.py"):
# use tokenize to auto-detect encoding on systems where no
# default encoding is defined (e.g., LANG='C')
with tokenize.open(str(path)) as file:
tree = ast.parse(file.read(), filename=str(path))
finder = FindFuncs(path.relative_to(base))
finder.visit(tree)
bad_filters.extend(finder.bad_filters)
bad_stacklevels.extend(finder.bad_stacklevels)
return bad_filters, bad_stacklevels
@pytest.mark.fail_slow(20)
@pytest.mark.slow
def test_warning_calls_filters(warning_calls):
bad_filters, bad_stacklevels = warning_calls
# We try not to add filters in the code base, because those filters aren't
# thread-safe. We aim to only filter in tests with
# np.testing.suppress_warnings. However, in some cases it may prove
# necessary to filter out warnings, because we can't (easily) fix the root
# cause for them and we don't want users to see some warnings when they use
# SciPy correctly. So we list exceptions here. Add new entries only if
# there's a good reason.
allowed_filters = (
os.path.join('datasets', '_fetchers.py'),
os.path.join('datasets', '__init__.py'),
os.path.join('optimize', '_optimize.py'),
os.path.join('optimize', '_constraints.py'),
os.path.join('optimize', '_nnls.py'),
os.path.join('signal', '_ltisys.py'),
os.path.join('sparse', '__init__.py'), # np.matrix pending-deprecation
os.path.join('stats', '_discrete_distns.py'), # gh-14901
os.path.join('stats', '_continuous_distns.py'),
os.path.join('stats', '_binned_statistic.py'), # gh-19345
os.path.join('stats', 'tests', 'test_axis_nan_policy.py'), # gh-20694
os.path.join('_lib', '_util.py'), # gh-19341
os.path.join('sparse', 'linalg', '_dsolve', 'linsolve.py'), # gh-17924
"conftest.py",
)
bad_filters = [item for item in bad_filters if item.split(':')[0] not in
allowed_filters]
if bad_filters:
raise AssertionError(
"warning ignore filter should not be used, instead, use\n"
"numpy.testing.suppress_warnings (in tests only);\n"
"found in:\n {}".format(
"\n ".join(bad_filters)))

View File

@ -0,0 +1,31 @@
"""`uarray` provides functions for generating multimethods that dispatch to
multiple different backends
This module should be imported rather than `_uarray`, so that an installed
version can be used instead, if available. This means that users can call
`uarray.set_backend` directly instead of going through SciPy.
"""
# Prefer an installed version of uarray, if available
try:
import uarray as _uarray
except ImportError:
_has_uarray = False
else:
from scipy._lib._pep440 import Version as _Version
_has_uarray = _Version(_uarray.__version__) >= _Version("0.8")
del _uarray
del _Version
if _has_uarray:
from uarray import * # noqa: F403
from uarray import _Function
else:
from ._uarray import * # noqa: F403
from ._uarray import _Function # noqa: F401
del _has_uarray

View File

@ -0,0 +1,31 @@
"""
=========================================
Clustering package (:mod:`scipy.cluster`)
=========================================
.. currentmodule:: scipy.cluster
.. toctree::
:hidden:
cluster.vq
cluster.hierarchy
Clustering algorithms are useful in information theory, target detection,
communications, compression, and other areas. The `vq` module only
supports vector quantization and the k-means algorithms.
The `hierarchy` module provides functions for hierarchical and
agglomerative clustering. Its features include generating hierarchical
clusters from distance matrices,
calculating statistics on clusters, cutting linkages
to generate flat clusters, and visualizing clusters with dendrograms.
"""
__all__ = ['vq', 'hierarchy']
from . import vq, hierarchy
from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester
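# Editor's sketch (illustrative usage, not part of the original file): a small
# round trip through the two submodules documented above.  The helper name and
# sample data are hypothetical.
def _example_usage():
    import numpy as np
    from scipy.cluster.vq import kmeans2
    from scipy.cluster.hierarchy import linkage, fcluster

    rng = np.random.default_rng(0)
    pts = np.vstack([rng.normal(0.0, 1.0, size=(20, 2)),
                     rng.normal(5.0, 1.0, size=(20, 2))])
    centroids, labels = kmeans2(pts, 2, seed=1234)   # vector quantization (k-means)
    Z = linkage(pts, method='ward')                  # agglomerative linkage matrix
    flat = fcluster(Z, t=2, criterion='maxclust')    # cut the tree into 2 flat clusters
    return centroids, labels, Z, flat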

File diff suppressed because it is too large

View File

@ -0,0 +1,145 @@
from numpy import array
Q_X = array([[5.26563660e-01, 3.14160190e-01, 8.00656370e-02],
[7.50205180e-01, 4.60299830e-01, 8.98696460e-01],
[6.65461230e-01, 6.94011420e-01, 9.10465700e-01],
[9.64047590e-01, 1.43082200e-03, 7.39874220e-01],
[1.08159060e-01, 5.53028790e-01, 6.63804780e-02],
[9.31359130e-01, 8.25424910e-01, 9.52315440e-01],
[6.78086960e-01, 3.41903970e-01, 5.61481950e-01],
[9.82730940e-01, 7.04605210e-01, 8.70978630e-02],
[6.14691610e-01, 4.69989230e-02, 6.02406450e-01],
[5.80161260e-01, 9.17354970e-01, 5.88163850e-01],
[1.38246310e+00, 1.96358160e+00, 1.94437880e+00],
[2.10675860e+00, 1.67148730e+00, 1.34854480e+00],
[1.39880070e+00, 1.66142050e+00, 1.32224550e+00],
[1.71410460e+00, 1.49176380e+00, 1.45432170e+00],
[1.54102340e+00, 1.84374950e+00, 1.64658950e+00],
[2.08512480e+00, 1.84524350e+00, 2.17340850e+00],
[1.30748740e+00, 1.53801650e+00, 2.16007740e+00],
[1.41447700e+00, 1.99329070e+00, 1.99107420e+00],
[1.61943490e+00, 1.47703280e+00, 1.89788160e+00],
[1.59880600e+00, 1.54988980e+00, 1.57563350e+00],
[3.37247380e+00, 2.69635310e+00, 3.39981700e+00],
[3.13705120e+00, 3.36528090e+00, 3.06089070e+00],
[3.29413250e+00, 3.19619500e+00, 2.90700170e+00],
[2.65510510e+00, 3.06785900e+00, 2.97198540e+00],
[3.30941040e+00, 2.59283970e+00, 2.57714110e+00],
[2.59557220e+00, 3.33477370e+00, 3.08793190e+00],
[2.58206180e+00, 3.41615670e+00, 3.26441990e+00],
[2.71127000e+00, 2.77032450e+00, 2.63466500e+00],
[2.79617850e+00, 3.25473720e+00, 3.41801560e+00],
[2.64741750e+00, 2.54538040e+00, 3.25354110e+00]])
ytdist = array([662., 877., 255., 412., 996., 295., 468., 268., 400., 754.,
564., 138., 219., 869., 669.])
linkage_ytdist_single = array([[2., 5., 138., 2.],
[3., 4., 219., 2.],
[0., 7., 255., 3.],
[1., 8., 268., 4.],
[6., 9., 295., 6.]])
linkage_ytdist_complete = array([[2., 5., 138., 2.],
[3., 4., 219., 2.],
[1., 6., 400., 3.],
[0., 7., 412., 3.],
[8., 9., 996., 6.]])
linkage_ytdist_average = array([[2., 5., 138., 2.],
[3., 4., 219., 2.],
[0., 7., 333.5, 3.],
[1., 6., 347.5, 3.],
[8., 9., 680.77777778, 6.]])
linkage_ytdist_weighted = array([[2., 5., 138., 2.],
[3., 4., 219., 2.],
[0., 7., 333.5, 3.],
[1., 6., 347.5, 3.],
[8., 9., 670.125, 6.]])
# the optimal leaf ordering of linkage_ytdist_single
linkage_ytdist_single_olo = array([[5., 2., 138., 2.],
[4., 3., 219., 2.],
[7., 0., 255., 3.],
[1., 8., 268., 4.],
[6., 9., 295., 6.]])
X = array([[1.43054825, -7.5693489],
[6.95887839, 6.82293382],
[2.87137846, -9.68248579],
[7.87974764, -6.05485803],
[8.24018364, -6.09495602],
[7.39020262, 8.54004355]])
linkage_X_centroid = array([[3., 4., 0.36265956, 2.],
[1., 5., 1.77045373, 2.],
[0., 2., 2.55760419, 2.],
[6., 8., 6.43614494, 4.],
[7., 9., 15.17363237, 6.]])
linkage_X_median = array([[3., 4., 0.36265956, 2.],
[1., 5., 1.77045373, 2.],
[0., 2., 2.55760419, 2.],
[6., 8., 6.43614494, 4.],
[7., 9., 15.17363237, 6.]])
linkage_X_ward = array([[3., 4., 0.36265956, 2.],
[1., 5., 1.77045373, 2.],
[0., 2., 2.55760419, 2.],
[6., 8., 9.10208346, 4.],
[7., 9., 24.7784379, 6.]])
# the optimal leaf ordering of linkage_X_ward
linkage_X_ward_olo = array([[4., 3., 0.36265956, 2.],
[5., 1., 1.77045373, 2.],
[2., 0., 2.55760419, 2.],
[6., 8., 9.10208346, 4.],
[7., 9., 24.7784379, 6.]])
inconsistent_ytdist = {
1: array([[138., 0., 1., 0.],
[219., 0., 1., 0.],
[255., 0., 1., 0.],
[268., 0., 1., 0.],
[295., 0., 1., 0.]]),
2: array([[138., 0., 1., 0.],
[219., 0., 1., 0.],
[237., 25.45584412, 2., 0.70710678],
[261.5, 9.19238816, 2., 0.70710678],
[233.66666667, 83.9424406, 3., 0.7306594]]),
3: array([[138., 0., 1., 0.],
[219., 0., 1., 0.],
[237., 25.45584412, 2., 0.70710678],
[247.33333333, 25.38372182, 3., 0.81417007],
[239., 69.36377537, 4., 0.80733783]]),
4: array([[138., 0., 1., 0.],
[219., 0., 1., 0.],
[237., 25.45584412, 2., 0.70710678],
[247.33333333, 25.38372182, 3., 0.81417007],
[235., 60.73302232, 5., 0.98793042]])}
fcluster_inconsistent = {
0.8: array([6, 2, 2, 4, 6, 2, 3, 7, 3, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1]),
1.0: array([6, 2, 2, 4, 6, 2, 3, 7, 3, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1]),
2.0: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1])}
fcluster_distance = {
0.6: array([4, 4, 4, 4, 4, 4, 4, 5, 4, 4, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 3,
1, 1, 1, 2, 1, 1, 1, 1, 1]),
1.0: array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1]),
2.0: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1])}
fcluster_maxclust = {
8.0: array([5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 7, 7, 7, 7, 7, 8, 7, 7, 7, 7, 4,
1, 1, 1, 3, 1, 1, 1, 1, 2]),
4.0: array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2,
1, 1, 1, 1, 1, 1, 1, 1, 1]),
1.0: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1])}
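
The reference arrays above are consumed by the hierarchy tests roughly as follows (a sketch; the import path is assumed to be `scipy.cluster.tests.hierarchy_test_data`):

import numpy as np
from scipy.cluster.hierarchy import linkage
from scipy.cluster.tests.hierarchy_test_data import ytdist, linkage_ytdist_single

# ytdist is a condensed distance matrix for 6 points (15 pairwise distances).
Z = linkage(ytdist, method='single')
np.testing.assert_allclose(Z, linkage_ytdist_single)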

View File

@ -0,0 +1,202 @@
import pytest
from pytest import raises as assert_raises
import numpy as np
from scipy.cluster.hierarchy import DisjointSet
import string
def generate_random_token():
k = len(string.ascii_letters)
tokens = list(np.arange(k, dtype=int))
tokens += list(np.arange(k, dtype=float))
tokens += list(string.ascii_letters)
tokens += [None for i in range(k)]
tokens = np.array(tokens, dtype=object)
rng = np.random.RandomState(seed=0)
while 1:
size = rng.randint(1, 3)
element = rng.choice(tokens, size)
if size == 1:
yield element[0]
else:
yield tuple(element)
def get_elements(n):
# use a dict for deterministic (insertion-ordered) uniqueness without the
# difficulty of comparing numpy ints
elements = {}
for element in generate_random_token():
if element not in elements:
elements[element] = len(elements)
if len(elements) >= n:
break
return list(elements.keys())
def test_init():
n = 10
elements = get_elements(n)
dis = DisjointSet(elements)
assert dis.n_subsets == n
assert list(dis) == elements
def test_len():
n = 10
elements = get_elements(n)
dis = DisjointSet(elements)
assert len(dis) == n
dis.add("dummy")
assert len(dis) == n + 1
@pytest.mark.parametrize("n", [10, 100])
def test_contains(n):
elements = get_elements(n)
dis = DisjointSet(elements)
for x in elements:
assert x in dis
assert "dummy" not in dis
@pytest.mark.parametrize("n", [10, 100])
def test_add(n):
elements = get_elements(n)
dis1 = DisjointSet(elements)
dis2 = DisjointSet()
for i, x in enumerate(elements):
dis2.add(x)
assert len(dis2) == i + 1
# test idempotency by adding element again
dis2.add(x)
assert len(dis2) == i + 1
assert list(dis1) == list(dis2)
def test_element_not_present():
elements = get_elements(n=10)
dis = DisjointSet(elements)
with assert_raises(KeyError):
dis["dummy"]
with assert_raises(KeyError):
dis.merge(elements[0], "dummy")
with assert_raises(KeyError):
dis.connected(elements[0], "dummy")
@pytest.mark.parametrize("direction", ["forwards", "backwards"])
@pytest.mark.parametrize("n", [10, 100])
def test_linear_union_sequence(n, direction):
elements = get_elements(n)
dis = DisjointSet(elements)
assert elements == list(dis)
indices = list(range(n - 1))
if direction == "backwards":
indices = indices[::-1]
for it, i in enumerate(indices):
assert not dis.connected(elements[i], elements[i + 1])
assert dis.merge(elements[i], elements[i + 1])
assert dis.connected(elements[i], elements[i + 1])
assert dis.n_subsets == n - 1 - it
roots = [dis[i] for i in elements]
if direction == "forwards":
assert all(elements[0] == r for r in roots)
else:
assert all(elements[-2] == r for r in roots)
assert not dis.merge(elements[0], elements[-1])
@pytest.mark.parametrize("n", [10, 100])
def test_self_unions(n):
elements = get_elements(n)
dis = DisjointSet(elements)
for x in elements:
assert dis.connected(x, x)
assert not dis.merge(x, x)
assert dis.connected(x, x)
assert dis.n_subsets == len(elements)
assert elements == list(dis)
roots = [dis[x] for x in elements]
assert elements == roots
@pytest.mark.parametrize("order", ["ab", "ba"])
@pytest.mark.parametrize("n", [10, 100])
def test_equal_size_ordering(n, order):
elements = get_elements(n)
dis = DisjointSet(elements)
rng = np.random.RandomState(seed=0)
indices = np.arange(n)
rng.shuffle(indices)
for i in range(0, len(indices), 2):
a, b = elements[indices[i]], elements[indices[i + 1]]
if order == "ab":
assert dis.merge(a, b)
else:
assert dis.merge(b, a)
expected = elements[min(indices[i], indices[i + 1])]
assert dis[a] == expected
assert dis[b] == expected
@pytest.mark.parametrize("kmax", [5, 10])
def test_binary_tree(kmax):
n = 2**kmax
elements = get_elements(n)
dis = DisjointSet(elements)
rng = np.random.RandomState(seed=0)
for k in 2**np.arange(kmax):
for i in range(0, n, 2 * k):
r1, r2 = rng.randint(0, k, size=2)
a, b = elements[i + r1], elements[i + k + r2]
assert not dis.connected(a, b)
assert dis.merge(a, b)
assert dis.connected(a, b)
assert elements == list(dis)
roots = [dis[i] for i in elements]
expected_indices = np.arange(n) - np.arange(n) % (2 * k)
expected = [elements[i] for i in expected_indices]
assert roots == expected
@pytest.mark.parametrize("n", [10, 100])
def test_subsets(n):
elements = get_elements(n)
dis = DisjointSet(elements)
rng = np.random.RandomState(seed=0)
for i, j in rng.randint(0, n, (n, 2)):
x = elements[i]
y = elements[j]
expected = {element for element in dis if {dis[element]} == {dis[x]}}
assert dis.subset_size(x) == len(dis.subset(x))
assert expected == dis.subset(x)
expected = {dis[element]: set() for element in dis}
for element in dis:
expected[dis[element]].add(element)
expected = list(expected.values())
assert expected == dis.subsets()
dis.merge(x, y)
assert dis.subset(x) == dis.subset(y)
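
For reference, a short usage sketch of the `DisjointSet` API exercised by these tests (illustrative only):

from scipy.cluster.hierarchy import DisjointSet

ds = DisjointSet(['a', 'b', 'c', 'd'])
ds.merge('a', 'b')             # True: two subsets were joined
ds.merge('a', 'b')             # False: already connected
assert ds.connected('a', 'b')
assert ds.subset('a') == {'a', 'b'}
assert ds.n_subsets == 3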

File diff suppressed because it is too large

View File

@ -0,0 +1,435 @@
import warnings
import sys
from copy import deepcopy
import numpy as np
from numpy.testing import (
assert_array_equal, assert_equal, assert_, suppress_warnings
)
import pytest
from pytest import raises as assert_raises
from scipy.cluster.vq import (kmeans, kmeans2, py_vq, vq, whiten,
ClusterError, _krandinit)
from scipy.cluster import _vq
from scipy.conftest import array_api_compatible
from scipy.sparse._sputils import matrix
from scipy._lib._array_api import (
SCIPY_ARRAY_API, copy, cov, xp_assert_close, xp_assert_equal
)
pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends
TESTDATA_2D = np.array([
-2.2, 1.17, -1.63, 1.69, -2.04, 4.38, -3.09, 0.95, -1.7, 4.79, -1.68, 0.68,
-2.26, 3.34, -2.29, 2.55, -1.72, -0.72, -1.99, 2.34, -2.75, 3.43, -2.45,
2.41, -4.26, 3.65, -1.57, 1.87, -1.96, 4.03, -3.01, 3.86, -2.53, 1.28,
-4.0, 3.95, -1.62, 1.25, -3.42, 3.17, -1.17, 0.12, -3.03, -0.27, -2.07,
-0.55, -1.17, 1.34, -2.82, 3.08, -2.44, 0.24, -1.71, 2.48, -5.23, 4.29,
-2.08, 3.69, -1.89, 3.62, -2.09, 0.26, -0.92, 1.07, -2.25, 0.88, -2.25,
2.02, -4.31, 3.86, -2.03, 3.42, -2.76, 0.3, -2.48, -0.29, -3.42, 3.21,
-2.3, 1.73, -2.84, 0.69, -1.81, 2.48, -5.24, 4.52, -2.8, 1.31, -1.67,
-2.34, -1.18, 2.17, -2.17, 2.82, -1.85, 2.25, -2.45, 1.86, -6.79, 3.94,
-2.33, 1.89, -1.55, 2.08, -1.36, 0.93, -2.51, 2.74, -2.39, 3.92, -3.33,
2.99, -2.06, -0.9, -2.83, 3.35, -2.59, 3.05, -2.36, 1.85, -1.69, 1.8,
-1.39, 0.66, -2.06, 0.38, -1.47, 0.44, -4.68, 3.77, -5.58, 3.44, -2.29,
2.24, -1.04, -0.38, -1.85, 4.23, -2.88, 0.73, -2.59, 1.39, -1.34, 1.75,
-1.95, 1.3, -2.45, 3.09, -1.99, 3.41, -5.55, 5.21, -1.73, 2.52, -2.17,
0.85, -2.06, 0.49, -2.54, 2.07, -2.03, 1.3, -3.23, 3.09, -1.55, 1.44,
-0.81, 1.1, -2.99, 2.92, -1.59, 2.18, -2.45, -0.73, -3.12, -1.3, -2.83,
0.2, -2.77, 3.24, -1.98, 1.6, -4.59, 3.39, -4.85, 3.75, -2.25, 1.71, -3.28,
3.38, -1.74, 0.88, -2.41, 1.92, -2.24, 1.19, -2.48, 1.06, -1.68, -0.62,
-1.3, 0.39, -1.78, 2.35, -3.54, 2.44, -1.32, 0.66, -2.38, 2.76, -2.35,
3.95, -1.86, 4.32, -2.01, -1.23, -1.79, 2.76, -2.13, -0.13, -5.25, 3.84,
-2.24, 1.59, -4.85, 2.96, -2.41, 0.01, -0.43, 0.13, -3.92, 2.91, -1.75,
-0.53, -1.69, 1.69, -1.09, 0.15, -2.11, 2.17, -1.53, 1.22, -2.1, -0.86,
-2.56, 2.28, -3.02, 3.33, -1.12, 3.86, -2.18, -1.19, -3.03, 0.79, -0.83,
0.97, -3.19, 1.45, -1.34, 1.28, -2.52, 4.22, -4.53, 3.22, -1.97, 1.75,
-2.36, 3.19, -0.83, 1.53, -1.59, 1.86, -2.17, 2.3, -1.63, 2.71, -2.03,
3.75, -2.57, -0.6, -1.47, 1.33, -1.95, 0.7, -1.65, 1.27, -1.42, 1.09, -3.0,
3.87, -2.51, 3.06, -2.6, 0.74, -1.08, -0.03, -2.44, 1.31, -2.65, 2.99,
-1.84, 1.65, -4.76, 3.75, -2.07, 3.98, -2.4, 2.67, -2.21, 1.49, -1.21,
1.22, -5.29, 2.38, -2.85, 2.28, -5.6, 3.78, -2.7, 0.8, -1.81, 3.5, -3.75,
4.17, -1.29, 2.99, -5.92, 3.43, -1.83, 1.23, -1.24, -1.04, -2.56, 2.37,
-3.26, 0.39, -4.63, 2.51, -4.52, 3.04, -1.7, 0.36, -1.41, 0.04, -2.1, 1.0,
-1.87, 3.78, -4.32, 3.59, -2.24, 1.38, -1.99, -0.22, -1.87, 1.95, -0.84,
2.17, -5.38, 3.56, -1.27, 2.9, -1.79, 3.31, -5.47, 3.85, -1.44, 3.69,
-2.02, 0.37, -1.29, 0.33, -2.34, 2.56, -1.74, -1.27, -1.97, 1.22, -2.51,
-0.16, -1.64, -0.96, -2.99, 1.4, -1.53, 3.31, -2.24, 0.45, -2.46, 1.71,
-2.88, 1.56, -1.63, 1.46, -1.41, 0.68, -1.96, 2.76, -1.61,
2.11]).reshape((200, 2))
# Global data
X = np.array([[3.0, 3], [4, 3], [4, 2],
[9, 2], [5, 1], [6, 2], [9, 4],
[5, 2], [5, 4], [7, 4], [6, 5]])
CODET1 = np.array([[3.0000, 3.0000],
[6.2000, 4.0000],
[5.8000, 1.8000]])
CODET2 = np.array([[11.0/3, 8.0/3],
[6.7500, 4.2500],
[6.2500, 1.7500]])
LABEL1 = np.array([0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1])
class TestWhiten:
def test_whiten(self, xp):
desired = xp.asarray([[5.08738849, 2.97091878],
[3.19909255, 0.69660580],
[4.51041982, 0.02640918],
[4.38567074, 0.95120889],
[2.32191480, 1.63195503]])
obs = xp.asarray([[0.98744510, 0.82766775],
[0.62093317, 0.19406729],
[0.87545741, 0.00735733],
[0.85124403, 0.26499712],
[0.45067590, 0.45464607]])
xp_assert_close(whiten(obs), desired, rtol=1e-5)
@skip_xp_backends('jax.numpy',
reasons=['jax arrays do not support item assignment'])
def test_whiten_zero_std(self, xp):
desired = xp.asarray([[0., 1.0, 2.86666544],
[0., 1.0, 1.32460034],
[0., 1.0, 3.74382172]])
obs = xp.asarray([[0., 1., 0.74109533],
[0., 1., 0.34243798],
[0., 1., 0.96785929]])
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
xp_assert_close(whiten(obs), desired, rtol=1e-5)
assert_equal(len(w), 1)
assert_(issubclass(w[-1].category, RuntimeWarning))
def test_whiten_not_finite(self, xp):
for bad_value in xp.nan, xp.inf, -xp.inf:
obs = xp.asarray([[0.98744510, bad_value],
[0.62093317, 0.19406729],
[0.87545741, 0.00735733],
[0.85124403, 0.26499712],
[0.45067590, 0.45464607]])
assert_raises(ValueError, whiten, obs)
@pytest.mark.skipif(SCIPY_ARRAY_API,
reason='`np.matrix` unsupported in array API mode')
def test_whiten_not_finite_matrix(self, xp):
for bad_value in np.nan, np.inf, -np.inf:
obs = matrix([[0.98744510, bad_value],
[0.62093317, 0.19406729],
[0.87545741, 0.00735733],
[0.85124403, 0.26499712],
[0.45067590, 0.45464607]])
assert_raises(ValueError, whiten, obs)
class TestVq:
@skip_xp_backends(cpu_only=True)
def test_py_vq(self, xp):
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
# label1.dtype varies between int32 and int64 over platforms
label1 = py_vq(xp.asarray(X), xp.asarray(initc))[0]
xp_assert_equal(label1, xp.asarray(LABEL1, dtype=xp.int64),
check_dtype=False)
@pytest.mark.skipif(SCIPY_ARRAY_API,
reason='`np.matrix` unsupported in array API mode')
def test_py_vq_matrix(self, xp):
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
# label1.dtype varies between int32 and int64 over platforms
label1 = py_vq(matrix(X), matrix(initc))[0]
assert_array_equal(label1, LABEL1)
@skip_xp_backends(np_only=True, reasons=['`_vq` only supports NumPy backend'])
def test_vq(self, xp):
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
label1, _ = _vq.vq(xp.asarray(X), xp.asarray(initc))
assert_array_equal(label1, LABEL1)
_, _ = vq(xp.asarray(X), xp.asarray(initc))
@pytest.mark.skipif(SCIPY_ARRAY_API,
reason='`np.matrix` unsupported in array API mode')
def test_vq_matrix(self, xp):
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
label1, _ = _vq.vq(matrix(X), matrix(initc))
assert_array_equal(label1, LABEL1)
_, _ = vq(matrix(X), matrix(initc))
@skip_xp_backends(cpu_only=True)
def test_vq_1d(self, xp):
# Test special rank 1 vq algo, python implementation.
data = X[:, 0]
initc = data[:3]
a, b = _vq.vq(data, initc)
data = xp.asarray(data)
initc = xp.asarray(initc)
ta, tb = py_vq(data[:, np.newaxis], initc[:, np.newaxis])
# ta.dtype varies between int32 and int64 over platforms
xp_assert_equal(ta, xp.asarray(a, dtype=xp.int64), check_dtype=False)
xp_assert_equal(tb, xp.asarray(b))
@skip_xp_backends(np_only=True, reasons=['`_vq` only supports NumPy backend'])
def test__vq_sametype(self, xp):
a = xp.asarray([1.0, 2.0], dtype=xp.float64)
b = a.astype(xp.float32)
assert_raises(TypeError, _vq.vq, a, b)
@skip_xp_backends(np_only=True, reasons=['`_vq` only supports NumPy backend'])
def test__vq_invalid_type(self, xp):
a = xp.asarray([1, 2], dtype=int)
assert_raises(TypeError, _vq.vq, a, a)
@skip_xp_backends(cpu_only=True)
def test_vq_large_nfeat(self, xp):
X = np.random.rand(20, 20)
code_book = np.random.rand(3, 20)
codes0, dis0 = _vq.vq(X, code_book)
codes1, dis1 = py_vq(
xp.asarray(X), xp.asarray(code_book)
)
xp_assert_close(dis1, xp.asarray(dis0), rtol=1e-5)
# codes1.dtype varies between int32 and int64 over platforms
xp_assert_equal(codes1, xp.asarray(codes0, dtype=xp.int64), check_dtype=False)
X = X.astype(np.float32)
code_book = code_book.astype(np.float32)
codes0, dis0 = _vq.vq(X, code_book)
codes1, dis1 = py_vq(
xp.asarray(X), xp.asarray(code_book)
)
xp_assert_close(dis1, xp.asarray(dis0, dtype=xp.float64), rtol=1e-5)
# codes1.dtype varies between int32 and int64 over platforms
xp_assert_equal(codes1, xp.asarray(codes0, dtype=xp.int64), check_dtype=False)
@skip_xp_backends(cpu_only=True)
def test_vq_large_features(self, xp):
X = np.random.rand(10, 5) * 1000000
code_book = np.random.rand(2, 5) * 1000000
codes0, dis0 = _vq.vq(X, code_book)
codes1, dis1 = py_vq(
xp.asarray(X), xp.asarray(code_book)
)
xp_assert_close(dis1, xp.asarray(dis0), rtol=1e-5)
# codes1.dtype varies between int32 and int64 over platforms
xp_assert_equal(codes1, xp.asarray(codes0, dtype=xp.int64), check_dtype=False)
# Whole class skipped on GPU for now;
# once pdist/cdist are hooked up for CuPy, more tests will work
@skip_xp_backends(cpu_only=True)
class TestKMean:
def test_large_features(self, xp):
# Generate a data set with large values and run kmeans on it
# (regression test for ticket 1077).
d = 300
n = 100
m1 = np.random.randn(d)
m2 = np.random.randn(d)
x = 10000 * np.random.randn(n, d) - 20000 * m1
y = 10000 * np.random.randn(n, d) + 20000 * m2
data = np.empty((x.shape[0] + y.shape[0], d), np.float64)
data[:x.shape[0]] = x
data[x.shape[0]:] = y
kmeans(xp.asarray(data), 2)
def test_kmeans_simple(self, xp):
np.random.seed(54321)
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
code1 = kmeans(xp.asarray(X), xp.asarray(initc), iter=1)[0]
xp_assert_close(code1, xp.asarray(CODET2))
@pytest.mark.skipif(SCIPY_ARRAY_API,
reason='`np.matrix` unsupported in array API mode')
def test_kmeans_simple_matrix(self, xp):
np.random.seed(54321)
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
code1 = kmeans(matrix(X), matrix(initc), iter=1)[0]
xp_assert_close(code1, CODET2)
def test_kmeans_lost_cluster(self, xp):
# This will cause kmeans to have a cluster with no points.
data = xp.asarray(TESTDATA_2D)
initk = xp.asarray([[-1.8127404, -0.67128041],
[2.04621601, 0.07401111],
[-2.31149087, -0.05160469]])
kmeans(data, initk)
with suppress_warnings() as sup:
sup.filter(UserWarning,
"One of the clusters is empty. Re-run kmeans with a "
"different initialization")
kmeans2(data, initk, missing='warn')
assert_raises(ClusterError, kmeans2, data, initk, missing='raise')
def test_kmeans2_simple(self, xp):
np.random.seed(12345678)
initc = xp.asarray(np.concatenate([[X[0]], [X[1]], [X[2]]]))
arrays = [xp.asarray] if SCIPY_ARRAY_API else [np.asarray, matrix]
for tp in arrays:
code1 = kmeans2(tp(X), tp(initc), iter=1)[0]
code2 = kmeans2(tp(X), tp(initc), iter=2)[0]
xp_assert_close(code1, xp.asarray(CODET1))
xp_assert_close(code2, xp.asarray(CODET2))
@pytest.mark.skipif(SCIPY_ARRAY_API,
reason='`np.matrix` unsupported in array API mode')
def test_kmeans2_simple_matrix(self, xp):
np.random.seed(12345678)
initc = xp.asarray(np.concatenate([[X[0]], [X[1]], [X[2]]]))
code1 = kmeans2(matrix(X), matrix(initc), iter=1)[0]
code2 = kmeans2(matrix(X), matrix(initc), iter=2)[0]
xp_assert_close(code1, CODET1)
xp_assert_close(code2, CODET2)
def test_kmeans2_rank1(self, xp):
data = xp.asarray(TESTDATA_2D)
data1 = data[:, 0]
initc = data1[:3]
code = copy(initc, xp=xp)
kmeans2(data1, code, iter=1)[0]
kmeans2(data1, code, iter=2)[0]
def test_kmeans2_rank1_2(self, xp):
data = xp.asarray(TESTDATA_2D)
data1 = data[:, 0]
kmeans2(data1, 2, iter=1)
def test_kmeans2_high_dim(self, xp):
# test kmeans2 when the number of dimensions exceeds the number
# of input points
data = xp.asarray(TESTDATA_2D)
data = xp.reshape(data, (20, 20))[:10, :]
kmeans2(data, 2)
@skip_xp_backends('jax.numpy',
reasons=['jax arrays do not support item assignment'],
cpu_only=True)
def test_kmeans2_init(self, xp):
np.random.seed(12345)
data = xp.asarray(TESTDATA_2D)
k = 3
kmeans2(data, k, minit='points')
kmeans2(data[:, 1], k, minit='points') # special case (1-D)
kmeans2(data, k, minit='++')
kmeans2(data[:, 1], k, minit='++') # special case (1-D)
# minit='random' can give warnings, filter those
with suppress_warnings() as sup:
sup.filter(message="One of the clusters is empty. Re-run.")
kmeans2(data, k, minit='random')
kmeans2(data[:, 1], k, minit='random') # special case (1-D)
@pytest.mark.skipif(sys.platform == 'win32',
reason='Fails with MemoryError in Wine.')
def test_krandinit(self, xp):
data = xp.asarray(TESTDATA_2D)
datas = [xp.reshape(data, (200, 2)),
xp.reshape(data, (20, 20))[:10, :]]
k = int(1e6)
for data in datas:
rng = np.random.default_rng(1234)
init = _krandinit(data, k, rng, xp)
orig_cov = cov(data.T)
init_cov = cov(init.T)
xp_assert_close(orig_cov, init_cov, atol=1.1e-2)
def test_kmeans2_empty(self, xp):
# Regression test for gh-1032.
assert_raises(ValueError, kmeans2, xp.asarray([]), 2)
def test_kmeans_0k(self, xp):
# Regression test for gh-1073: fail when k arg is 0.
assert_raises(ValueError, kmeans, xp.asarray(X), 0)
assert_raises(ValueError, kmeans2, xp.asarray(X), 0)
assert_raises(ValueError, kmeans2, xp.asarray(X), xp.asarray([]))
def test_kmeans_large_thres(self, xp):
# Regression test for gh-1774
x = xp.asarray([1, 2, 3, 4, 10], dtype=xp.float64)
res = kmeans(x, 1, thresh=1e16)
xp_assert_close(res[0], xp.asarray([4.], dtype=xp.float64))
xp_assert_close(res[1], xp.asarray(2.3999999999999999, dtype=xp.float64)[()])
@skip_xp_backends('jax.numpy',
reasons=['jax arrays do not support item assignment'],
cpu_only=True)
def test_kmeans2_kpp_low_dim(self, xp):
# Regression test for gh-11462
prev_res = xp.asarray([[-1.95266667, 0.898],
[-3.153375, 3.3945]], dtype=xp.float64)
np.random.seed(42)
res, _ = kmeans2(xp.asarray(TESTDATA_2D), 2, minit='++')
xp_assert_close(res, prev_res)
@skip_xp_backends('jax.numpy',
reasons=['jax arrays do not support item assignment'],
cpu_only=True)
def test_kmeans2_kpp_high_dim(self, xp):
# Regression test for gh-11462
n_dim = 100
size = 10
centers = np.vstack([5 * np.ones(n_dim),
-5 * np.ones(n_dim)])
np.random.seed(42)
data = np.vstack([
np.random.multivariate_normal(centers[0], np.eye(n_dim), size=size),
np.random.multivariate_normal(centers[1], np.eye(n_dim), size=size)
])
data = xp.asarray(data)
res, _ = kmeans2(data, 2, minit='++')
xp_assert_equal(xp.sign(res), xp.sign(xp.asarray(centers)))
def test_kmeans_diff_convergence(self, xp):
# Regression test for gh-8727
obs = xp.asarray([-3, -1, 0, 1, 1, 8], dtype=xp.float64)
res = kmeans(obs, xp.asarray([-3., 0.99]))
xp_assert_close(res[0], xp.asarray([-0.4, 8.], dtype=xp.float64))
xp_assert_close(res[1], xp.asarray(1.0666666666666667, dtype=xp.float64)[()])
@skip_xp_backends('jax.numpy',
reasons=['jax arrays do not support item assignment'],
cpu_only=True)
def test_kmeans_and_kmeans2_random_seed(self, xp):
seed_list = [
1234, np.random.RandomState(1234), np.random.default_rng(1234)
]
for seed in seed_list:
seed1 = deepcopy(seed)
seed2 = deepcopy(seed)
data = xp.asarray(TESTDATA_2D)
# test for kmeans
res1, _ = kmeans(data, 2, seed=seed1)
res2, _ = kmeans(data, 2, seed=seed2)
xp_assert_close(res1, res2) # should be same results
# test for kmeans2
for minit in ["random", "points", "++"]:
res1, _ = kmeans2(data, 2, minit=minit, seed=seed1)
res2, _ = kmeans2(data, 2, minit=minit, seed=seed2)
xp_assert_close(res1, res2) # should be same results
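
The seeding contract checked by `test_kmeans_and_kmeans2_random_seed` boils down to the following (a sketch, not part of the test module): equivalent seeds must produce identical centroids.

import numpy as np
from scipy.cluster.vq import kmeans2

data = np.random.default_rng(0).normal(size=(50, 2))
res1, _ = kmeans2(data, 3, minit='++', seed=np.random.default_rng(1234))
res2, _ = kmeans2(data, 3, minit='++', seed=np.random.default_rng(1234))
assert np.allclose(res1, res2)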

View File

@ -0,0 +1,835 @@
"""
K-means clustering and vector quantization (:mod:`scipy.cluster.vq`)
====================================================================
Provides routines for k-means clustering, generating code books
from k-means models and quantizing vectors by comparing them with
centroids in a code book.
.. autosummary::
:toctree: generated/
whiten -- Normalize a group of observations so each feature has unit variance
vq -- Calculate code book membership of a set of observation vectors
kmeans -- Perform k-means on a set of observation vectors forming k clusters
kmeans2 -- A different implementation of k-means with more methods
-- for initializing centroids
Background information
----------------------
The k-means algorithm takes as input the number of clusters to
generate, k, and a set of observation vectors to cluster. It
returns a set of centroids, one for each of the k clusters. An
observation vector is classified with the cluster number or
centroid index of the centroid closest to it.
A vector v belongs to cluster i if it is closer to centroid i than
any other centroid. If v belongs to i, we say centroid i is the
dominating centroid of v. The k-means algorithm tries to
minimize distortion, which is defined as the sum of the squared distances
between each observation vector and its dominating centroid.
The minimization is achieved by iteratively reclassifying
the observations into clusters and recalculating the centroids until
a configuration is reached in which the centroids are stable. One can
also define a maximum number of iterations.
Since vector quantization is a natural application for k-means,
information theory terminology is often used. The centroid index
or cluster index is also referred to as a "code" and the table
mapping codes to centroids and, vice versa, is often referred to as a
"code book". The result of k-means, a set of centroids, can be
used to quantize vectors. Quantization aims to find an encoding of
vectors that reduces the expected distortion.
All routines expect obs to be an M by N array, where the rows are
the observation vectors. The codebook is a k by N array, where the
ith row is the centroid of code word i. The observation vectors
and centroids have the same feature dimension.
As an example, suppose we wish to compress a 24-bit color image
(each pixel is represented by one byte for red, one for blue, and
one for green) before sending it over the web. By using a smaller
8-bit encoding, we can reduce the amount of data by two
thirds. Ideally, the colors for each of the 256 possible 8-bit
encoding values should be chosen to minimize distortion of the
color. Running k-means with k=256 generates a code book of 256
codes, which fills up all possible 8-bit sequences. Instead of
sending a 3-byte value for each pixel, the 8-bit centroid index
(or code word) of the dominating centroid is transmitted. The code
book is also sent over the wire so each 8-bit code can be
translated back to a 24-bit pixel value representation. If the
image of interest was of an ocean, we would expect many 24-bit
blues to be represented by 8-bit codes. If it was an image of a
human face, more flesh-tone colors would be represented in the
code book.
"""
import warnings
import numpy as np
from collections import deque
from scipy._lib._array_api import (
_asarray, array_namespace, size, atleast_nd, copy, cov
)
from scipy._lib._util import check_random_state, rng_integers
from scipy.spatial.distance import cdist
from . import _vq
__docformat__ = 'restructuredtext'
__all__ = ['whiten', 'vq', 'kmeans', 'kmeans2']
class ClusterError(Exception):
pass
def whiten(obs, check_finite=True):
"""
Normalize a group of observations on a per feature basis.
Before running k-means, it is beneficial to rescale each feature
dimension of the observation set by its standard deviation (i.e. "whiten"
it - as in "white noise" where each frequency has equal power).
Each feature is divided by its standard deviation across all observations
to give it unit variance.
Parameters
----------
obs : ndarray
Each row of the array is an observation. The
columns are the features seen during each observation.
>>> # f0 f1 f2
>>> obs = [[ 1., 1., 1.], #o0
... [ 2., 2., 2.], #o1
... [ 3., 3., 3.], #o2
... [ 4., 4., 4.]] #o3
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Default: True
Returns
-------
result : ndarray
Contains the values in `obs` scaled by the standard deviation
of each column.
Examples
--------
>>> import numpy as np
>>> from scipy.cluster.vq import whiten
>>> features = np.array([[1.9, 2.3, 1.7],
... [1.5, 2.5, 2.2],
... [0.8, 0.6, 1.7,]])
>>> whiten(features)
array([[ 4.17944278, 2.69811351, 7.21248917],
[ 3.29956009, 2.93273208, 9.33380951],
[ 1.75976538, 0.7038557 , 7.21248917]])
"""
xp = array_namespace(obs)
obs = _asarray(obs, check_finite=check_finite, xp=xp)
std_dev = xp.std(obs, axis=0)
zero_std_mask = std_dev == 0
if xp.any(zero_std_mask):
std_dev[zero_std_mask] = 1.0
warnings.warn("Some columns have standard deviation zero. "
"The values of these columns will not change.",
RuntimeWarning, stacklevel=2)
return obs / std_dev
def vq(obs, code_book, check_finite=True):
"""
Assign codes from a code book to observations.
Assigns a code from a code book to each observation. Each
observation vector in the 'M' by 'N' `obs` array is compared with the
centroids in the code book and assigned the code of the closest
centroid.
The features in `obs` should have unit variance, which can be
achieved by passing them through the whiten function. The code
book can be created with the k-means algorithm or a different
encoding algorithm.
Parameters
----------
obs : ndarray
Each row of the 'M' x 'N' array is an observation. The columns are
the "features" seen during each observation. The features must be
whitened first using the whiten function or something equivalent.
code_book : ndarray
The code book is usually generated using the k-means algorithm.
Each row of the array holds a different code, and the columns are
the features of the code.
>>> # f0 f1 f2 f3
>>> code_book = [
... [ 1., 2., 3., 4.], #c0
... [ 1., 2., 3., 4.], #c1
... [ 1., 2., 3., 4.]] #c2
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Default: True
Returns
-------
code : ndarray
A length M array holding the code book index for each observation.
dist : ndarray
The distortion (distance) between the observation and its nearest
code.
Examples
--------
>>> import numpy as np
>>> from scipy.cluster.vq import vq
>>> code_book = np.array([[1., 1., 1.],
... [2., 2., 2.]])
>>> features = np.array([[1.9, 2.3, 1.7],
... [1.5, 2.5, 2.2],
... [0.8, 0.6, 1.7]])
>>> vq(features, code_book)
(array([1, 1, 0], dtype=int32), array([0.43588989, 0.73484692, 0.83066239]))
"""
xp = array_namespace(obs, code_book)
obs = _asarray(obs, xp=xp, check_finite=check_finite)
code_book = _asarray(code_book, xp=xp, check_finite=check_finite)
ct = xp.result_type(obs, code_book)
c_obs = xp.astype(obs, ct, copy=False)
c_code_book = xp.astype(code_book, ct, copy=False)
if xp.isdtype(ct, kind='real floating'):
c_obs = np.asarray(c_obs)
c_code_book = np.asarray(c_code_book)
result = _vq.vq(c_obs, c_code_book)
return xp.asarray(result[0]), xp.asarray(result[1])
return py_vq(obs, code_book, check_finite=False)
def py_vq(obs, code_book, check_finite=True):
""" Python version of vq algorithm.
The algorithm computes the Euclidean distance between each
observation and every frame in the code_book.
Parameters
----------
obs : ndarray
Expects a rank 2 array. Each row is one observation.
code_book : ndarray
Code book to use. Same format as obs. Should have the same number of
features (e.g., columns) as obs.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Default: True
Returns
-------
code : ndarray
code[i] gives the label of the ith observation; its code is
code_book[code[i]].
min_dist : ndarray
min_dist[i] gives the distance between the ith observation and its
corresponding code.
Notes
-----
This function is slower than the C version but works for
all input types. If the inputs have the wrong types for the
C versions of the function, this one is called as a last resort.
It is about 20 times slower than the C version.
"""
xp = array_namespace(obs, code_book)
obs = _asarray(obs, xp=xp, check_finite=check_finite)
code_book = _asarray(code_book, xp=xp, check_finite=check_finite)
if obs.ndim != code_book.ndim:
raise ValueError("Observation and code_book should have the same rank")
if obs.ndim == 1:
obs = obs[:, xp.newaxis]
code_book = code_book[:, xp.newaxis]
# Once `cdist` has array API support, this `xp.asarray` call can be removed
dist = xp.asarray(cdist(obs, code_book))
code = xp.argmin(dist, axis=1)
min_dist = xp.min(dist, axis=1)
return code, min_dist
def _kmeans(obs, guess, thresh=1e-5, xp=None):
""" "raw" version of k-means.
Returns
-------
code_book
The lowest distortion codebook found.
avg_dist
The average distance an observation is from a code in the book.
Lower means the code_book matches the data better.
See Also
--------
kmeans : wrapper around k-means
Examples
--------
Note: not whitened in this example.
>>> import numpy as np
>>> from scipy.cluster.vq import _kmeans
>>> features = np.array([[ 1.9,2.3],
... [ 1.5,2.5],
... [ 0.8,0.6],
... [ 0.4,1.8],
... [ 1.0,1.0]])
>>> book = np.array((features[0],features[2]))
>>> _kmeans(features,book)
(array([[ 1.7 , 2.4 ],
[ 0.73333333, 1.13333333]]), 0.40563916697728591)
"""
xp = np if xp is None else xp
code_book = guess
diff = xp.inf
prev_avg_dists = deque([diff], maxlen=2)
while diff > thresh:
# compute membership and distances between obs and code_book
obs_code, distort = vq(obs, code_book, check_finite=False)
prev_avg_dists.append(xp.mean(distort, axis=-1))
# recalc code_book as centroids of associated obs
obs = np.asarray(obs)
obs_code = np.asarray(obs_code)
code_book, has_members = _vq.update_cluster_means(obs, obs_code,
code_book.shape[0])
obs = xp.asarray(obs)
obs_code = xp.asarray(obs_code)
code_book = xp.asarray(code_book)
has_members = xp.asarray(has_members)
code_book = code_book[has_members]
diff = xp.abs(prev_avg_dists[0] - prev_avg_dists[1])
return code_book, prev_avg_dists[1]
def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True,
*, seed=None):
"""
Performs k-means on a set of observation vectors forming k clusters.
The k-means algorithm adjusts the classification of the observations
into clusters and updates the cluster centroids until the position of
the centroids is stable over successive iterations. In this
implementation of the algorithm, the stability of the centroids is
determined by comparing the absolute value of the change in the average
Euclidean distance between the observations and their corresponding
centroids against a threshold. This yields
a code book mapping centroids to codes and vice versa.
Parameters
----------
obs : ndarray
Each row of the M by N array is an observation vector. The
columns are the features seen during each observation.
The features must be whitened first with the `whiten` function.
k_or_guess : int or ndarray
The number of centroids to generate. A code is assigned to
each centroid, which is also the row index of the centroid
in the code_book matrix generated.
The initial k centroids are chosen by randomly selecting
observations from the observation matrix. Alternatively,
passing a k by N array specifies the initial k centroids.
iter : int, optional
The number of times to run k-means, returning the codebook
with the lowest distortion. This argument is ignored if
initial centroids are specified with an array for the
``k_or_guess`` parameter. This parameter does not represent the
number of iterations of the k-means algorithm.
thresh : float, optional
Terminates the k-means algorithm if the change in
distortion since the last k-means iteration is less than
or equal to threshold.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Default: True
seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
Seed for initializing the pseudo-random number generator.
If `seed` is None (or `numpy.random`), the `numpy.random.RandomState`
singleton is used.
If `seed` is an int, a new ``RandomState`` instance is used,
seeded with `seed`.
If `seed` is already a ``Generator`` or ``RandomState`` instance then
that instance is used.
The default is None.
Returns
-------
codebook : ndarray
A k by N array of k centroids. The ith centroid
codebook[i] is represented with the code i. The centroids
and codes generated represent the lowest distortion seen,
not necessarily the globally minimal distortion.
Note that the number of centroids is not necessarily the same as the
``k_or_guess`` parameter, because centroids assigned to no observations
are removed during iterations.
distortion : float
The mean (non-squared) Euclidean distance between the observations
passed and the centroids generated. Note the difference to the standard
definition of distortion in the context of the k-means algorithm, which
is the sum of the squared distances.
See Also
--------
kmeans2 : a different implementation of k-means clustering
with more methods for generating initial centroids but without
using a distortion change threshold as a stopping criterion.
whiten : must be called prior to passing an observation matrix
to kmeans.
Notes
-----
For more functionalities or optimal performance, you can use
`sklearn.cluster.KMeans <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html>`_.
`This <https://hdbscan.readthedocs.io/en/latest/performance_and_scalability.html#comparison-of-high-performance-implementations>`_
is a benchmark result of several implementations.
Examples
--------
>>> import numpy as np
>>> from scipy.cluster.vq import vq, kmeans, whiten
>>> import matplotlib.pyplot as plt
>>> features = np.array([[ 1.9,2.3],
... [ 1.5,2.5],
... [ 0.8,0.6],
... [ 0.4,1.8],
... [ 0.1,0.1],
... [ 0.2,1.8],
... [ 2.0,0.5],
... [ 0.3,1.5],
... [ 1.0,1.0]])
>>> whitened = whiten(features)
>>> book = np.array((whitened[0],whitened[2]))
>>> kmeans(whitened,book)
(array([[ 2.3110306 , 2.86287398], # random
[ 0.93218041, 1.24398691]]), 0.85684700941625547)
>>> codes = 3
>>> kmeans(whitened,codes)
(array([[ 2.3110306 , 2.86287398], # random
[ 1.32544402, 0.65607529],
[ 0.40782893, 2.02786907]]), 0.5196582527686241)
>>> # Create 50 datapoints in two clusters a and b
>>> pts = 50
>>> rng = np.random.default_rng()
>>> a = rng.multivariate_normal([0, 0], [[4, 1], [1, 4]], size=pts)
>>> b = rng.multivariate_normal([30, 10],
... [[10, 2], [2, 1]],
... size=pts)
>>> features = np.concatenate((a, b))
>>> # Whiten data
>>> whitened = whiten(features)
>>> # Find 2 clusters in the data
>>> codebook, distortion = kmeans(whitened, 2)
>>> # Plot whitened data and cluster centers in red
>>> plt.scatter(whitened[:, 0], whitened[:, 1])
>>> plt.scatter(codebook[:, 0], codebook[:, 1], c='r')
>>> plt.show()
"""
if isinstance(k_or_guess, int):
xp = array_namespace(obs)
else:
xp = array_namespace(obs, k_or_guess)
obs = _asarray(obs, xp=xp, check_finite=check_finite)
guess = _asarray(k_or_guess, xp=xp, check_finite=check_finite)
if iter < 1:
raise ValueError("iter must be at least 1, got %s" % iter)
# Determine whether a count (scalar) or an initial guess (array) was passed.
if size(guess) != 1:
if size(guess) < 1:
raise ValueError("Asked for 0 clusters. Initial book was %s" %
guess)
return _kmeans(obs, guess, thresh=thresh, xp=xp)
# k_or_guess is a scalar, now verify that it's an integer
k = int(guess)
if k != guess:
raise ValueError("If k_or_guess is a scalar, it must be an integer.")
if k < 1:
raise ValueError("Asked for %d clusters." % k)
rng = check_random_state(seed)
# initialize best distance value to a large value
best_dist = xp.inf
for i in range(iter):
# the initial code book is randomly selected from observations
guess = _kpoints(obs, k, rng, xp)
book, dist = _kmeans(obs, guess, thresh=thresh, xp=xp)
if dist < best_dist:
best_book = book
best_dist = dist
return best_book, best_dist
def _kpoints(data, k, rng, xp):
"""Pick k points at random in data (one row = one observation).
Parameters
----------
data : ndarray
Expects a rank 1 or 2 array. Rank 1 is assumed to describe
one-dimensional data, rank 2 multidimensional data, in which case
one row is one observation.
k : int
Number of samples to generate.
rng : `numpy.random.Generator` or `numpy.random.RandomState`
Random number generator.
Returns
-------
x : ndarray
A 'k' by 'N' array containing the initial centroids.
"""
idx = rng.choice(data.shape[0], size=int(k), replace=False)
# convert to array with default integer dtype (avoids numpy#25607)
idx = xp.asarray(idx, dtype=xp.asarray([1]).dtype)
return xp.take(data, idx, axis=0)
def _krandinit(data, k, rng, xp):
"""Returns k samples of a random variable whose parameters depend on data.
More precisely, it returns k observations sampled from a Gaussian random
variable whose mean and covariances are the ones estimated from the data.
Parameters
----------
data : ndarray
Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
data, rank 2 multidimensional data, in which case one
row is one observation.
k : int
Number of samples to generate.
rng : `numpy.random.Generator` or `numpy.random.RandomState`
Random number generator.
Returns
-------
x : ndarray
A 'k' by 'N' array containing the initial centroids.
"""
mu = xp.mean(data, axis=0)
k = np.asarray(k)
if data.ndim == 1:
_cov = cov(data)
x = rng.standard_normal(size=k)
x = xp.asarray(x)
x *= xp.sqrt(_cov)
elif data.shape[1] > data.shape[0]:
# initialize when the covariance matrix is rank deficient
_, s, vh = xp.linalg.svd(data - mu, full_matrices=False)
x = rng.standard_normal(size=(k, size(s)))
x = xp.asarray(x)
sVh = s[:, None] * vh / xp.sqrt(data.shape[0] - xp.asarray(1.))
x = x @ sVh
else:
_cov = atleast_nd(cov(data.T), ndim=2)
# k rows, d cols (one row = one obs)
# Generate k sample of a random variable ~ Gaussian(mu, cov)
x = rng.standard_normal(size=(k, size(mu)))
x = xp.asarray(x)
x = x @ xp.linalg.cholesky(_cov).T
x += mu
return x
def _kpp(data, k, rng, xp):
""" Picks k points in the data based on the kmeans++ method.
Parameters
----------
data : ndarray
Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
data, rank 2 multidimensional data, in which case one
row is one observation.
k : int
Number of samples to generate.
rng : `numpy.random.Generator` or `numpy.random.RandomState`
Random number generator.
Returns
-------
init : ndarray
A 'k' by 'N' array containing the initial centroids.
References
----------
.. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
on Discrete Algorithms, 2007.
"""
ndim = len(data.shape)
if ndim == 1:
data = data[:, None]
dims = data.shape[1]
init = xp.empty((int(k), dims))
for i in range(k):
if i == 0:
init[i, :] = data[rng_integers(rng, data.shape[0]), :]
else:
D2 = cdist(init[:i,:], data, metric='sqeuclidean').min(axis=0)
probs = D2/D2.sum()
cumprobs = probs.cumsum()
r = rng.uniform()
cumprobs = np.asarray(cumprobs)
init[i, :] = data[np.searchsorted(cumprobs, r), :]
if ndim == 1:
init = init[:, 0]
return init
_valid_init_meth = {'random': _krandinit, 'points': _kpoints, '++': _kpp}
def _missing_warn():
"""Print a warning when called."""
warnings.warn("One of the clusters is empty. "
"Re-run kmeans with a different initialization.",
stacklevel=3)
def _missing_raise():
"""Raise a ClusterError when called."""
raise ClusterError("One of the clusters is empty. "
"Re-run kmeans with a different initialization.")
_valid_miss_meth = {'warn': _missing_warn, 'raise': _missing_raise}
def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
missing='warn', check_finite=True, *, seed=None):
"""
Classify a set of observations into k clusters using the k-means algorithm.
The algorithm attempts to minimize the Euclidean distance between
observations and centroids. Several initialization methods are
included.
Parameters
----------
data : ndarray
An 'M' by 'N' array of 'M' observations in 'N' dimensions, or a
length-'M' array of 'M' 1-D observations.
k : int or ndarray
The number of clusters to form as well as the number of
centroids to generate. If the `minit` initialization string is
'matrix', or if an ndarray is given instead, it is
interpreted as the initial clusters to use.
iter : int, optional
Number of iterations of the k-means algorithm to run. Note
that this differs in meaning from the iters parameter to
the kmeans function.
thresh : float, optional
(not used yet)
minit : str, optional
Method for initialization. Available methods are 'random',
'points', '++' and 'matrix':
'random': generate k centroids from a Gaussian with mean and
variance estimated from the data.
'points': choose k observations (rows) at random from data for
the initial centroids.
'++': choose k observations according to the kmeans++ method
(careful seeding)
'matrix': interpret the k parameter as a k by M (or length k
array for 1-D data) array of initial centroids.
missing : str, optional
Method to deal with empty clusters. Available methods are
'warn' and 'raise':
'warn': give a warning and continue.
'raise': raise a ClusterError and terminate the algorithm.
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
Default: True
seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
Seed for initializing the pseudo-random number generator.
If `seed` is None (or `numpy.random`), the `numpy.random.RandomState`
singleton is used.
If `seed` is an int, a new ``RandomState`` instance is used,
seeded with `seed`.
If `seed` is already a ``Generator`` or ``RandomState`` instance then
that instance is used.
The default is None.
Returns
-------
centroid : ndarray
A 'k' by 'N' array of centroids found at the last iteration of
k-means.
label : ndarray
label[i] is the code or index of the centroid the
ith observation is closest to.
See Also
--------
kmeans
References
----------
.. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
on Discrete Algorithms, 2007.
Examples
--------
>>> from scipy.cluster.vq import kmeans2
>>> import matplotlib.pyplot as plt
>>> import numpy as np
Create z, an array with shape (100, 2) containing a mixture of samples
from three multivariate normal distributions.
>>> rng = np.random.default_rng()
>>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45)
>>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30)
>>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25)
>>> z = np.concatenate((a, b, c))
>>> rng.shuffle(z)
Compute three clusters.
>>> centroid, label = kmeans2(z, 3, minit='points')
>>> centroid
array([[ 2.22274463, -0.61666946], # may vary
[ 0.54069047, 5.86541444],
[ 6.73846769, 4.01991898]])
How many points are in each cluster?
>>> counts = np.bincount(label)
>>> counts
array([29, 51, 20]) # may vary
Plot the clusters.
>>> w0 = z[label == 0]
>>> w1 = z[label == 1]
>>> w2 = z[label == 2]
>>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0')
>>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1')
>>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2')
>>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids')
>>> plt.axis('equal')
>>> plt.legend(shadow=True)
>>> plt.show()
"""
if int(iter) < 1:
raise ValueError("Invalid iter (%s), "
"must be a positive integer." % iter)
try:
miss_meth = _valid_miss_meth[missing]
except KeyError as e:
raise ValueError(f"Unknown missing method {missing!r}") from e
if isinstance(k, int):
xp = array_namespace(data)
else:
xp = array_namespace(data, k)
data = _asarray(data, xp=xp, check_finite=check_finite)
code_book = copy(k, xp=xp)
if data.ndim == 1:
d = 1
elif data.ndim == 2:
d = data.shape[1]
else:
raise ValueError("Input of rank > 2 is not supported.")
if size(data) < 1 or size(code_book) < 1:
raise ValueError("Empty input is not supported.")
# If k is not a single value, it should be compatible with data's shape
if minit == 'matrix' or size(code_book) > 1:
if data.ndim != code_book.ndim:
raise ValueError("k array doesn't match data rank")
nc = code_book.shape[0]
if data.ndim > 1 and code_book.shape[1] != d:
raise ValueError("k array doesn't match data dimension")
else:
nc = int(code_book)
if nc < 1:
raise ValueError("Cannot ask kmeans2 for %d clusters"
" (k was %s)" % (nc, code_book))
elif nc != code_book:
warnings.warn("k was not an integer, was converted.", stacklevel=2)
try:
init_meth = _valid_init_meth[minit]
except KeyError as e:
raise ValueError(f"Unknown init method {minit!r}") from e
else:
rng = check_random_state(seed)
code_book = init_meth(data, code_book, rng, xp)
data = np.asarray(data)
code_book = np.asarray(code_book)
for i in range(iter):
# Compute the nearest neighbor for each obs using the current code book
label = vq(data, code_book, check_finite=check_finite)[0]
# Update the code book by computing centroids
new_code_book, has_members = _vq.update_cluster_means(data, label, nc)
if not has_members.all():
miss_meth()
# Set the empty clusters to their previous positions
new_code_book[~has_members] = code_book[~has_members]
code_book = new_code_book
return xp.asarray(code_book), xp.asarray(label)
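
The colour-quantization use case sketched in the module docstring maps onto this API roughly as follows (illustrative; `image` is a stand-in for real pixel data):

import numpy as np
from scipy.cluster.vq import kmeans, vq

rng = np.random.default_rng(0)
image = rng.random((64, 64, 3))      # hypothetical H x W x 3 float image
pixels = image.reshape(-1, 3)

codebook, _ = kmeans(pixels, 16)     # 16-colour code book
codes, _ = vq(pixels, codebook)      # one small code per pixel
quantized = codebook[codes].reshape(image.shape)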

View File

@ -0,0 +1,413 @@
# Pytest customization
import json
import os
import warnings
import tempfile
from contextlib import contextmanager
import numpy as np
import numpy.testing as npt
import pytest
import hypothesis
from scipy._lib._fpumode import get_fpu_mode
from scipy._lib._testutils import FPUModeChangeWarning
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE
from scipy._lib import _pep440
try:
from scipy_doctest.conftest import dt_config
HAVE_SCPDT = True
except ModuleNotFoundError:
HAVE_SCPDT = False
def pytest_configure(config):
config.addinivalue_line("markers",
"slow: Tests that are very slow.")
config.addinivalue_line("markers",
"xslow: mark test as extremely slow (not run unless explicitly requested)")
config.addinivalue_line("markers",
"xfail_on_32bit: mark test as failing on 32-bit platforms")
try:
import pytest_timeout # noqa:F401
except Exception:
config.addinivalue_line(
"markers", 'timeout: mark a test for a non-default timeout')
try:
# This is a more reliable test of whether pytest_fail_slow is installed
# When I uninstalled it, `import pytest_fail_slow` didn't fail!
from pytest_fail_slow import parse_duration # type: ignore[import-not-found] # noqa:F401,E501
except Exception:
config.addinivalue_line(
"markers", 'fail_slow: mark a test for a non-default timeout failure')
config.addinivalue_line("markers",
"skip_xp_backends(*backends, reasons=None, np_only=False, cpu_only=False): "
"mark the desired skip configuration for the `skip_xp_backends` fixture.")
def pytest_runtest_setup(item):
mark = item.get_closest_marker("xslow")
if mark is not None:
try:
v = int(os.environ.get('SCIPY_XSLOW', '0'))
except ValueError:
v = False
if not v:
pytest.skip("very slow test; "
"set environment variable SCIPY_XSLOW=1 to run it")
mark = item.get_closest_marker("xfail_on_32bit")
if mark is not None and np.intp(0).itemsize < 8:
pytest.xfail(f'Fails on our 32-bit test platform(s): {mark.args[0]}')
# Older versions of threadpoolctl have an issue that may lead to this
# warning being emitted, see gh-14441
with npt.suppress_warnings() as sup:
sup.filter(pytest.PytestUnraisableExceptionWarning)
try:
from threadpoolctl import threadpool_limits
HAS_THREADPOOLCTL = True
except Exception: # observed in gh-14441: (ImportError, AttributeError)
# Optional dependency only. All exceptions are caught, for robustness
HAS_THREADPOOLCTL = False
if HAS_THREADPOOLCTL:
# Set the number of openmp threads based on the number of workers
# xdist is using to prevent oversubscription. Simplified version of what
# sklearn does (it can rely on threadpoolctl and its builtin OpenMP helper
# functions)
try:
xdist_worker_count = int(os.environ['PYTEST_XDIST_WORKER_COUNT'])
except KeyError:
# raises when pytest-xdist is not installed
return
if not os.getenv('OMP_NUM_THREADS'):
max_openmp_threads = os.cpu_count() // 2 # use nr of physical cores
threads_per_worker = max(max_openmp_threads // xdist_worker_count, 1)
try:
threadpool_limits(threads_per_worker, user_api='blas')
except Exception:
# May raise AttributeError for older versions of OpenBLAS.
# Catch any error for robustness.
return
@pytest.fixture(scope="function", autouse=True)
def check_fpu_mode(request):
"""
Check FPU mode was not changed during the test.
"""
old_mode = get_fpu_mode()
yield
new_mode = get_fpu_mode()
if old_mode != new_mode:
warnings.warn(f"FPU mode changed from {old_mode:#x} to {new_mode:#x} during "
"the test",
category=FPUModeChangeWarning, stacklevel=0)
# Array API backend handling
xp_available_backends = {'numpy': np}
if SCIPY_ARRAY_API and isinstance(SCIPY_ARRAY_API, str):
# fill the dict of backends with available libraries
try:
import array_api_strict
xp_available_backends.update({'array_api_strict': array_api_strict})
if _pep440.parse(array_api_strict.__version__) < _pep440.Version('2.0'):
raise ImportError("array-api-strict must be >= version 2.0")
array_api_strict.set_array_api_strict_flags(
api_version='2023.12'
)
except ImportError:
pass
try:
import torch # type: ignore[import-not-found]
xp_available_backends.update({'pytorch': torch})
# can use `mps` or `cpu`
torch.set_default_device(SCIPY_DEVICE)
except ImportError:
pass
try:
import cupy # type: ignore[import-not-found]
xp_available_backends.update({'cupy': cupy})
except ImportError:
pass
try:
import jax.numpy # type: ignore[import-not-found]
xp_available_backends.update({'jax.numpy': jax.numpy})
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_device", jax.devices(SCIPY_DEVICE)[0])
except ImportError:
pass
# by default, use all available backends
if SCIPY_ARRAY_API.lower() not in ("1", "true"):
SCIPY_ARRAY_API_ = json.loads(SCIPY_ARRAY_API)
if 'all' in SCIPY_ARRAY_API_:
pass # same as True
else:
# only select a subset of backend by filtering out the dict
try:
xp_available_backends = {
backend: xp_available_backends[backend]
for backend in SCIPY_ARRAY_API_
}
except KeyError:
msg = f"'--array-api-backend' must be in {xp_available_backends.keys()}"
raise ValueError(msg)
if 'cupy' in xp_available_backends:
SCIPY_DEVICE = 'cuda'
array_api_compatible = pytest.mark.parametrize("xp", xp_available_backends.values())
skip_xp_invalid_arg = pytest.mark.skipif(SCIPY_ARRAY_API,
reason = ('Test involves masked arrays, object arrays, or other types '
'that are not valid input when `SCIPY_ARRAY_API` is used.'))
@pytest.fixture
def skip_xp_backends(xp, request):
"""
Skip based on the ``skip_xp_backends`` marker.
Parameters
----------
*backends : tuple
Backends to skip, e.g. ``("array_api_strict", "torch")``.
These are overridden when ``np_only`` is ``True``, and are not
necessary to provide for non-CPU backends when ``cpu_only`` is ``True``.
reasons : list, optional
A list of reasons for each skip. When ``np_only`` is ``True``,
this should be a singleton list. Otherwise, this should be a list
of reasons, one for each corresponding backend in ``backends``.
If unprovided, default reasons are used. Note that it is not possible
to specify a custom reason with ``cpu_only``. Default: ``None``.
np_only : bool, optional
When ``True``, the test is skipped for all backends other
than the default NumPy backend. There is no need to provide
any ``backends`` in this case. To specify a reason, pass a
singleton list to ``reasons``. Default: ``False``.
cpu_only : bool, optional
When ``True``, the test is skipped on non-CPU devices.
There is no need to provide any ``backends`` in this case,
but any ``backends`` will also be skipped on the CPU.
Default: ``False``.
"""
if "skip_xp_backends" not in request.keywords:
return
backends = request.keywords["skip_xp_backends"].args
kwargs = request.keywords["skip_xp_backends"].kwargs
np_only = kwargs.get("np_only", False)
cpu_only = kwargs.get("cpu_only", False)
if np_only:
reasons = kwargs.get("reasons", ["do not run with non-NumPy backends."])
reason = reasons[0]
if xp.__name__ != 'numpy':
pytest.skip(reason=reason)
return
if cpu_only:
reason = "do not run with `SCIPY_ARRAY_API` set and not on CPU"
if SCIPY_ARRAY_API and SCIPY_DEVICE != 'cpu':
if xp.__name__ == 'cupy':
pytest.skip(reason=reason)
elif xp.__name__ == 'torch':
if 'cpu' not in xp.empty(0).device.type:
pytest.skip(reason=reason)
elif xp.__name__ == 'jax.numpy':
for d in xp.empty(0).devices():
if 'cpu' not in d.device_kind:
pytest.skip(reason=reason)
if backends is not None:
reasons = kwargs.get("reasons", False)
for i, backend in enumerate(backends):
if xp.__name__ == backend:
if not reasons:
reason = f"do not run with array API backend: {backend}"
else:
reason = reasons[i]
pytest.skip(reason=reason)
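# Illustrative usage of this fixture via its companion marker, as in the
# cluster test_vq tests above (sketch only):
#
#     pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
#     skip_xp_backends = pytest.mark.skip_xp_backends
#
#     @skip_xp_backends(np_only=True, reasons=['`_vq` only supports NumPy backend'])
#     def test_vq(self, xp):
#         ...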
# Following the approach of NumPy's conftest.py...
# Use a known and persistent tmpdir for hypothesis' caches, which
# can be automatically cleared by the OS or user.
hypothesis.configuration.set_hypothesis_home_dir(
os.path.join(tempfile.gettempdir(), ".hypothesis")
)
# We register two custom profiles for SciPy - for details see
# https://hypothesis.readthedocs.io/en/latest/settings.html
# The first is designed for our own CI runs; the second also
# forces determinism and is designed for use via scipy.test()
hypothesis.settings.register_profile(
name="nondeterministic", deadline=None, print_blob=True,
)
hypothesis.settings.register_profile(
name="deterministic",
deadline=None, print_blob=True, database=None, derandomize=True,
suppress_health_check=list(hypothesis.HealthCheck),
)
# Profile is currently set by environment variable `SCIPY_HYPOTHESIS_PROFILE`
# In the future, it would be good to work the choice into dev.py.
SCIPY_HYPOTHESIS_PROFILE = os.environ.get("SCIPY_HYPOTHESIS_PROFILE",
"deterministic")
hypothesis.settings.load_profile(SCIPY_HYPOTHESIS_PROFILE)
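# e.g. (illustrative shell usage, not an official interface):
#   SCIPY_HYPOTHESIS_PROFILE=nondeterministic python -m pytest scipy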
############################################################################
# doctesting stuff
if HAVE_SCPDT:
# FIXME: populate the dict once
@contextmanager
def warnings_errors_and_rng(test=None):
"""Temporarily turn (almost) all warnings to errors.
Filter out known warnings which we allow.
"""
known_warnings = dict()
# these functions are known to emit "divide by zero" RuntimeWarnings
divide_by_zero = [
'scipy.linalg.norm', 'scipy.ndimage.center_of_mass',
]
for name in divide_by_zero:
known_warnings[name] = dict(category=RuntimeWarning,
message='divide by zero')
# Deprecated stuff in scipy.signal and elsewhere
deprecated = [
'scipy.signal.cwt', 'scipy.signal.morlet', 'scipy.signal.morlet2',
'scipy.signal.ricker',
'scipy.integrate.simpson',
'scipy.interpolate.interp2d',
]
for name in deprecated:
known_warnings[name] = dict(category=DeprecationWarning)
from scipy import integrate
# these functions are known to emit IntegrationWarnings
integration_w = ['scipy.special.ellip_normal',
'scipy.special.ellip_harm_2',
]
for name in integration_w:
known_warnings[name] = dict(category=integrate.IntegrationWarning,
message='The occurrence of roundoff')
# scipy.stats deliberately emits UserWarnings sometimes
user_w = ['scipy.stats.anderson_ksamp', 'scipy.stats.kurtosistest',
'scipy.stats.normaltest', 'scipy.sparse.linalg.norm']
for name in user_w:
known_warnings[name] = dict(category=UserWarning)
# additional one-off warnings to filter
dct = {
'scipy.sparse.linalg.norm':
dict(category=UserWarning, message="Exited at iteration"),
# tutorials
'linalg.rst':
dict(message='the matrix subclass is not',
category=PendingDeprecationWarning),
'stats.rst':
dict(message='The maximum number of subdivisions',
category=integrate.IntegrationWarning),
}
known_warnings.update(dct)
# these legitimately emit warnings in examples
legit = {'scipy.signal.normalize'}  # a set literal; set('...') would split the string into characters
# Now, the meat of the matter: filter warnings,
# also control the random seed for each doctest.
# XXX: this matches the refguide-check behavior, but is a tad strange:
# makes sure that the seed of the old-fashioned np.random* methods is
# *NOT* reproducible but the new-style `default_rng()` *IS* reproducible.
# Should these two be either both repro or both not repro?
from scipy._lib._util import _fixed_default_rng
import numpy as np
with _fixed_default_rng():
np.random.seed(None)
with warnings.catch_warnings():
if test and test.name in known_warnings:
warnings.filterwarnings('ignore',
**known_warnings[test.name])
yield
elif test and test.name in legit:
yield
else:
warnings.simplefilter('error', Warning)
yield
dt_config.user_context_mgr = warnings_errors_and_rng
dt_config.skiplist = set([
'scipy.linalg.LinAlgError', # comes from numpy
'scipy.fftpack.fftshift', # fftpack stuff is also from numpy
'scipy.fftpack.ifftshift',
'scipy.fftpack.fftfreq',
'scipy.special.sinc', # sinc is from numpy
'scipy.optimize.show_options', # does not have much to doctest
'scipy.signal.normalize', # manipulates warnings (XXX temp skip)
'scipy.sparse.linalg.norm', # XXX temp skip
])
# these are affected by NumPy 2.0 scalar repr: rely on string comparison
if np.__version__ < "2":
dt_config.skiplist.update(set([
'scipy.io.hb_read',
'scipy.io.hb_write',
'scipy.sparse.csgraph.connected_components',
'scipy.sparse.csgraph.depth_first_order',
'scipy.sparse.csgraph.shortest_path',
'scipy.sparse.csgraph.floyd_warshall',
'scipy.sparse.csgraph.dijkstra',
'scipy.sparse.csgraph.bellman_ford',
'scipy.sparse.csgraph.johnson',
'scipy.sparse.csgraph.yen',
'scipy.sparse.csgraph.breadth_first_order',
'scipy.sparse.csgraph.reverse_cuthill_mckee',
'scipy.sparse.csgraph.structural_rank',
'scipy.sparse.csgraph.construct_dist_matrix',
'scipy.sparse.csgraph.reconstruct_path',
'scipy.ndimage.value_indices',
'scipy.stats.mstats.describe',
]))
# help pytest collection a bit: these names are either private
# (distributions), or just do not need doctesting.
dt_config.pytest_extra_ignore = [
"scipy.stats.distributions",
"scipy.optimize.cython_optimize",
"scipy.test",
"scipy.show_config",
]
dt_config.pytest_extra_xfail = {
# name: reason
"io.rst": "",
"ND_regular_grid.rst": "ReST parser limitation",
"extrapolation_examples.rst": "ReST parser limitation",
"sampling_pinv.rst": "__cinit__ unexpected argument",
"sampling_srou.rst": "nan in scalar_power",
"probability_distributions.rst": "integration warning",
}
# tutorials
dt_config.pseudocode = set(['integrate.nquad(func,'])
dt_config.local_resources = {'io.rst': ["octave_a.mat"]}
############################################################################

View File

@ -0,0 +1,347 @@
r"""
==================================
Constants (:mod:`scipy.constants`)
==================================
.. currentmodule:: scipy.constants
Physical and mathematical constants and units.
Mathematical constants
======================
================ =================================================================
``pi`` Pi
``golden`` Golden ratio
``golden_ratio`` Golden ratio
================ =================================================================
Physical constants
==================
=========================== =================================================================
``c`` speed of light in vacuum
``speed_of_light`` speed of light in vacuum
``mu_0`` the magnetic constant :math:`\mu_0`
``epsilon_0`` the electric constant (vacuum permittivity), :math:`\epsilon_0`
``h`` the Planck constant :math:`h`
``Planck`` the Planck constant :math:`h`
``hbar`` :math:`\hbar = h/(2\pi)`
``G`` Newtonian constant of gravitation
``gravitational_constant`` Newtonian constant of gravitation
``g`` standard acceleration of gravity
``e`` elementary charge
``elementary_charge`` elementary charge
``R`` molar gas constant
``gas_constant`` molar gas constant
``alpha`` fine-structure constant
``fine_structure`` fine-structure constant
``N_A`` Avogadro constant
``Avogadro`` Avogadro constant
``k`` Boltzmann constant
``Boltzmann`` Boltzmann constant
``sigma`` Stefan-Boltzmann constant :math:`\sigma`
``Stefan_Boltzmann`` Stefan-Boltzmann constant :math:`\sigma`
``Wien`` Wien displacement law constant
``Rydberg`` Rydberg constant
``m_e`` electron mass
``electron_mass`` electron mass
``m_p`` proton mass
``proton_mass`` proton mass
``m_n`` neutron mass
``neutron_mass`` neutron mass
=========================== =================================================================
Constants database
------------------
In addition to the above variables, :mod:`scipy.constants` also contains the
2018 CODATA recommended values [CODATA2018]_ database containing more physical
constants.
.. autosummary::
:toctree: generated/
value -- Value in physical_constants indexed by key
unit -- Unit in physical_constants indexed by key
precision -- Relative precision in physical_constants indexed by key
find -- Return list of physical_constant keys with a given string
ConstantWarning -- Constant sought not in newest CODATA data set
.. data:: physical_constants
Dictionary of physical constants, of the format
``physical_constants[name] = (value, unit, uncertainty)``.
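For example, an entry can be unpacked as follows (a minimal sketch; only
the access pattern is shown, numerical values are omitted)::

    from scipy.constants import physical_constants
    val, unit, uncertainty = physical_constants['Planck constant']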
Available constants:
====================================================================== ====
%(constant_names)s
====================================================================== ====
Units
=====
SI prefixes
-----------
============ =================================================================
``quetta`` :math:`10^{30}`
``ronna`` :math:`10^{27}`
``yotta`` :math:`10^{24}`
``zetta`` :math:`10^{21}`
``exa`` :math:`10^{18}`
``peta`` :math:`10^{15}`
``tera`` :math:`10^{12}`
``giga`` :math:`10^{9}`
``mega`` :math:`10^{6}`
``kilo`` :math:`10^{3}`
``hecto`` :math:`10^{2}`
``deka`` :math:`10^{1}`
``deci`` :math:`10^{-1}`
``centi`` :math:`10^{-2}`
``milli`` :math:`10^{-3}`
``micro`` :math:`10^{-6}`
``nano`` :math:`10^{-9}`
``pico`` :math:`10^{-12}`
``femto`` :math:`10^{-15}`
``atto`` :math:`10^{-18}`
``zepto`` :math:`10^{-21}`
``yocto`` :math:`10^{-24}`
``ronto`` :math:`10^{-27}`
``quecto`` :math:`10^{-30}`
============ =================================================================
Binary prefixes
---------------
============ =================================================================
``kibi`` :math:`2^{10}`
``mebi`` :math:`2^{20}`
``gibi`` :math:`2^{30}`
``tebi`` :math:`2^{40}`
``pebi`` :math:`2^{50}`
``exbi`` :math:`2^{60}`
``zebi`` :math:`2^{70}`
``yobi`` :math:`2^{80}`
============ =================================================================
Mass
----
================= ============================================================
``gram`` :math:`10^{-3}` kg
``metric_ton`` :math:`10^{3}` kg
``grain`` one grain in kg
``lb`` one pound (avoirdupois) in kg
``pound`` one pound (avoirdupois) in kg
``blob`` one inch version of a slug in kg (added in 1.0.0)
``slinch`` one inch version of a slug in kg (added in 1.0.0)
``slug`` one slug in kg (added in 1.0.0)
``oz`` one ounce in kg
``ounce`` one ounce in kg
``stone`` one stone in kg
``long_ton`` one long ton in kg
``short_ton`` one short ton in kg
``troy_ounce`` one Troy ounce in kg
``troy_pound`` one Troy pound in kg
``carat`` one carat in kg
``m_u`` atomic mass constant (in kg)
``u`` atomic mass constant (in kg)
``atomic_mass`` atomic mass constant (in kg)
================= ============================================================
Angle
-----
================= ============================================================
``degree`` degree in radians
``arcmin`` arc minute in radians
``arcminute`` arc minute in radians
``arcsec`` arc second in radians
``arcsecond`` arc second in radians
================= ============================================================
Time
----
================= ============================================================
``minute`` one minute in seconds
``hour`` one hour in seconds
``day`` one day in seconds
``week`` one week in seconds
``year`` one year (365 days) in seconds
``Julian_year`` one Julian year (365.25 days) in seconds
================= ============================================================
Length
------
===================== ============================================================
``inch`` one inch in meters
``foot`` one foot in meters
``yard`` one yard in meters
``mile`` one mile in meters
``mil`` one mil in meters
``pt`` one point in meters
``point`` one point in meters
``survey_foot`` one survey foot in meters
``survey_mile`` one survey mile in meters
``nautical_mile`` one nautical mile in meters
``fermi`` one Fermi in meters
``angstrom`` one Angstrom in meters
``micron`` one micron in meters
``au`` one astronomical unit in meters
``astronomical_unit`` one astronomical unit in meters
``light_year`` one light year in meters
``parsec`` one parsec in meters
===================== ============================================================
Pressure
--------
================= ============================================================
``atm`` standard atmosphere in pascals
``atmosphere`` standard atmosphere in pascals
``bar`` one bar in pascals
``torr`` one torr (mmHg) in pascals
``mmHg`` one torr (mmHg) in pascals
``psi`` one psi in pascals
================= ============================================================
Area
----
================= ============================================================
``hectare`` one hectare in square meters
``acre`` one acre in square meters
================= ============================================================
Volume
------
=================== ========================================================
``liter`` one liter in cubic meters
``litre`` one liter in cubic meters
``gallon`` one gallon (US) in cubic meters
``gallon_US`` one gallon (US) in cubic meters
``gallon_imp`` one gallon (UK) in cubic meters
``fluid_ounce`` one fluid ounce (US) in cubic meters
``fluid_ounce_US`` one fluid ounce (US) in cubic meters
``fluid_ounce_imp`` one fluid ounce (UK) in cubic meters
``bbl`` one barrel in cubic meters
``barrel`` one barrel in cubic meters
=================== ========================================================
Speed
-----
================== ==========================================================
``kmh`` kilometers per hour in meters per second
``mph`` miles per hour in meters per second
``mach`` one Mach (approx., at 15 C, 1 atm) in meters per second
``speed_of_sound`` one Mach (approx., at 15 C, 1 atm) in meters per second
``knot`` one knot in meters per second
================== ==========================================================
Temperature
-----------
===================== =======================================================
``zero_Celsius`` zero of Celsius scale in Kelvin
``degree_Fahrenheit`` one Fahrenheit (only differences) in Kelvins
===================== =======================================================
.. autosummary::
:toctree: generated/
convert_temperature
Energy
------
==================== =======================================================
``eV`` one electron volt in Joules
``electron_volt`` one electron volt in Joules
``calorie`` one calorie (thermochemical) in Joules
``calorie_th`` one calorie (thermochemical) in Joules
``calorie_IT`` one calorie (International Steam Table calorie, 1956) in Joules
``erg`` one erg in Joules
``Btu`` one British thermal unit (International Steam Table) in Joules
``Btu_IT`` one British thermal unit (International Steam Table) in Joules
``Btu_th`` one British thermal unit (thermochemical) in Joules
``ton_TNT`` one ton of TNT in Joules
==================== =======================================================
Power
-----
==================== =======================================================
``hp`` one horsepower in watts
``horsepower`` one horsepower in watts
==================== =======================================================
Force
-----
==================== =======================================================
``dyn`` one dyne in newtons
``dyne`` one dyne in newtons
``lbf`` one pound force in newtons
``pound_force`` one pound force in newtons
``kgf`` one kilogram force in newtons
``kilogram_force`` one kilogram force in newtons
==================== =======================================================
Optics
------
.. autosummary::
:toctree: generated/
lambda2nu
nu2lambda
References
==========
.. [CODATA2018] CODATA Recommended Values of the Fundamental
Physical Constants 2018.
https://physics.nist.gov/cuu/Constants/
""" # noqa: E501
# Modules contributed by BasSw (wegwerp@gmail.com)
from ._codata import *
from ._constants import *
from ._codata import _obsolete_constants, physical_constants
# Deprecated namespaces, to be removed in v2.0.0
from . import codata, constants
_constant_names_list = [(_k.lower(), _k, _v)
for _k, _v in physical_constants.items()
if _k not in _obsolete_constants]
_constant_names = "\n".join(["``{}``{} {} {}".format(_x[1], " "*(66-len(_x[1])),
_x[2][0], _x[2][1])
for _x in sorted(_constant_names_list)])
if __doc__:
__doc__ = __doc__ % dict(constant_names=_constant_names)
del _constant_names
del _constant_names_list
__all__ = [s for s in dir() if not s.startswith('_')]
from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester

File diff suppressed because it is too large

View File

@ -0,0 +1,368 @@
"""
Collection of physical constants and conversion factors.
Most constants are in SI units, so you can do
print('10 mile per minute is', 10*mile/minute, 'm/s or', 10*mile/(minute*knot), 'knots')
The list is not meant to be comprehensive, but just convenient for everyday use.
"""
from __future__ import annotations
import math as _math
from typing import TYPE_CHECKING, Any
from ._codata import value as _cd
if TYPE_CHECKING:
import numpy.typing as npt
from scipy._lib._array_api import array_namespace, _asarray
"""
BasSw 2006
physical constants: imported from CODATA
unit conversion: see e.g., NIST special publication 811
Use at own risk: double-check values before calculating your Mars orbit-insertion burn.
Some constants exist in a few variants, which are marked with suffixes.
The ones without any suffix should be the most common ones.
"""
__all__ = [
'Avogadro', 'Boltzmann', 'Btu', 'Btu_IT', 'Btu_th', 'G',
'Julian_year', 'N_A', 'Planck', 'R', 'Rydberg',
'Stefan_Boltzmann', 'Wien', 'acre', 'alpha',
'angstrom', 'arcmin', 'arcminute', 'arcsec',
'arcsecond', 'astronomical_unit', 'atm',
'atmosphere', 'atomic_mass', 'atto', 'au', 'bar',
'barrel', 'bbl', 'blob', 'c', 'calorie',
'calorie_IT', 'calorie_th', 'carat', 'centi',
'convert_temperature', 'day', 'deci', 'degree',
'degree_Fahrenheit', 'deka', 'dyn', 'dyne', 'e',
'eV', 'electron_mass', 'electron_volt',
'elementary_charge', 'epsilon_0', 'erg',
'exa', 'exbi', 'femto', 'fermi', 'fine_structure',
'fluid_ounce', 'fluid_ounce_US', 'fluid_ounce_imp',
'foot', 'g', 'gallon', 'gallon_US', 'gallon_imp',
'gas_constant', 'gibi', 'giga', 'golden', 'golden_ratio',
'grain', 'gram', 'gravitational_constant', 'h', 'hbar',
'hectare', 'hecto', 'horsepower', 'hour', 'hp',
'inch', 'k', 'kgf', 'kibi', 'kilo', 'kilogram_force',
'kmh', 'knot', 'lambda2nu', 'lb', 'lbf',
'light_year', 'liter', 'litre', 'long_ton', 'm_e',
'm_n', 'm_p', 'm_u', 'mach', 'mebi', 'mega',
'metric_ton', 'micro', 'micron', 'mil', 'mile',
'milli', 'minute', 'mmHg', 'mph', 'mu_0', 'nano',
'nautical_mile', 'neutron_mass', 'nu2lambda',
'ounce', 'oz', 'parsec', 'pebi', 'peta',
'pi', 'pico', 'point', 'pound', 'pound_force',
'proton_mass', 'psi', 'pt', 'quecto', 'quetta', 'ronna', 'ronto',
'short_ton', 'sigma', 'slinch', 'slug', 'speed_of_light',
'speed_of_sound', 'stone', 'survey_foot',
'survey_mile', 'tebi', 'tera', 'ton_TNT',
'torr', 'troy_ounce', 'troy_pound', 'u',
'week', 'yard', 'year', 'yobi', 'yocto',
'yotta', 'zebi', 'zepto', 'zero_Celsius', 'zetta'
]
# mathematical constants
pi = _math.pi
golden = golden_ratio = (1 + _math.sqrt(5)) / 2
# SI prefixes
quetta = 1e30
ronna = 1e27
yotta = 1e24
zetta = 1e21
exa = 1e18
peta = 1e15
tera = 1e12
giga = 1e9
mega = 1e6
kilo = 1e3
hecto = 1e2
deka = 1e1
deci = 1e-1
centi = 1e-2
milli = 1e-3
micro = 1e-6
nano = 1e-9
pico = 1e-12
femto = 1e-15
atto = 1e-18
zepto = 1e-21
yocto = 1e-24
ronto = 1e-27
quecto = 1e-30
# binary prefixes
kibi = 2**10
mebi = 2**20
gibi = 2**30
tebi = 2**40
pebi = 2**50
exbi = 2**60
zebi = 2**70
yobi = 2**80
# physical constants
c = speed_of_light = _cd('speed of light in vacuum')
mu_0 = _cd('vacuum mag. permeability')
epsilon_0 = _cd('vacuum electric permittivity')
h = Planck = _cd('Planck constant')
hbar = h / (2 * pi)
G = gravitational_constant = _cd('Newtonian constant of gravitation')
g = _cd('standard acceleration of gravity')
e = elementary_charge = _cd('elementary charge')
R = gas_constant = _cd('molar gas constant')
alpha = fine_structure = _cd('fine-structure constant')
N_A = Avogadro = _cd('Avogadro constant')
k = Boltzmann = _cd('Boltzmann constant')
sigma = Stefan_Boltzmann = _cd('Stefan-Boltzmann constant')
Wien = _cd('Wien wavelength displacement law constant')
Rydberg = _cd('Rydberg constant')
# mass in kg
gram = 1e-3
metric_ton = 1e3
grain = 64.79891e-6
lb = pound = 7000 * grain # avoirdupois
blob = slinch = pound * g / 0.0254 # lbf*s**2/in (added in 1.0.0)
slug = blob / 12 # lbf*s**2/foot (added in 1.0.0)
oz = ounce = pound / 16
stone = 14 * pound
long_ton = 2240 * pound
short_ton = 2000 * pound
troy_ounce = 480 * grain # only for metals / gems
troy_pound = 12 * troy_ounce
carat = 200e-6
m_e = electron_mass = _cd('electron mass')
m_p = proton_mass = _cd('proton mass')
m_n = neutron_mass = _cd('neutron mass')
m_u = u = atomic_mass = _cd('atomic mass constant')
# angle in rad
degree = pi / 180
arcmin = arcminute = degree / 60
arcsec = arcsecond = arcmin / 60
# time in second
minute = 60.0
hour = 60 * minute
day = 24 * hour
week = 7 * day
year = 365 * day
Julian_year = 365.25 * day
# length in meter
inch = 0.0254
foot = 12 * inch
yard = 3 * foot
mile = 1760 * yard
mil = inch / 1000
pt = point = inch / 72 # typography
survey_foot = 1200.0 / 3937
survey_mile = 5280 * survey_foot
nautical_mile = 1852.0
fermi = 1e-15
angstrom = 1e-10
micron = 1e-6
au = astronomical_unit = 149597870700.0
light_year = Julian_year * c
parsec = au / arcsec
# pressure in pascal
atm = atmosphere = _cd('standard atmosphere')
bar = 1e5
torr = mmHg = atm / 760
psi = pound * g / (inch * inch)
# area in meter**2
hectare = 1e4
acre = 43560 * foot**2
# volume in meter**3
litre = liter = 1e-3
gallon = gallon_US = 231 * inch**3 # US
# pint = gallon_US / 8
fluid_ounce = fluid_ounce_US = gallon_US / 128
bbl = barrel = 42 * gallon_US # for oil
gallon_imp = 4.54609e-3 # UK
fluid_ounce_imp = gallon_imp / 160
# speed in meter per second
kmh = 1e3 / hour
mph = mile / hour
# approximate value of Mach 1 at 15 degrees Celsius, 1 atm
mach = speed_of_sound = 340.5
knot = nautical_mile / hour
# temperature in kelvin
zero_Celsius = 273.15
degree_Fahrenheit = 1/1.8 # only for differences
# energy in joule
eV = electron_volt = elementary_charge # * 1 Volt
calorie = calorie_th = 4.184
calorie_IT = 4.1868
erg = 1e-7
Btu_th = pound * degree_Fahrenheit * calorie_th / gram
Btu = Btu_IT = pound * degree_Fahrenheit * calorie_IT / gram
ton_TNT = 1e9 * calorie_th
# Wh = watt_hour
# power in watt
hp = horsepower = 550 * foot * pound * g
# force in newton
dyn = dyne = 1e-5
lbf = pound_force = pound * g
kgf = kilogram_force = g # * 1 kg
# functions for conversions that are not linear
def convert_temperature(
val: npt.ArrayLike,
old_scale: str,
new_scale: str,
) -> Any:
"""
Convert from a temperature scale to another one among Celsius, Kelvin,
Fahrenheit, and Rankine scales.
Parameters
----------
val : array_like
Value(s) of the temperature(s) to be converted expressed in the
original scale.
old_scale : str
Specifies as a string the original scale from which the temperature
value(s) will be converted. Supported scales are Celsius ('Celsius',
'celsius', 'C' or 'c'), Kelvin ('Kelvin', 'kelvin', 'K', 'k'),
Fahrenheit ('Fahrenheit', 'fahrenheit', 'F' or 'f'), and Rankine
('Rankine', 'rankine', 'R', 'r').
new_scale : str
Specifies as a string the new scale to which the temperature
value(s) will be converted. Supported scales are Celsius ('Celsius',
'celsius', 'C' or 'c'), Kelvin ('Kelvin', 'kelvin', 'K', 'k'),
Fahrenheit ('Fahrenheit', 'fahrenheit', 'F' or 'f'), and Rankine
('Rankine', 'rankine', 'R', 'r').
Returns
-------
res : float or array of floats
Value(s) of the converted temperature(s) expressed in the new scale.
Notes
-----
.. versionadded:: 0.18.0
Examples
--------
>>> from scipy.constants import convert_temperature
>>> import numpy as np
>>> convert_temperature(np.array([-40, 40]), 'Celsius', 'Kelvin')
array([ 233.15, 313.15])
"""
xp = array_namespace(val)
_val = _asarray(val, xp=xp, subok=True)
# Convert from `old_scale` to Kelvin
if old_scale.lower() in ['celsius', 'c']:
tempo = _val + zero_Celsius
elif old_scale.lower() in ['kelvin', 'k']:
tempo = _val
elif old_scale.lower() in ['fahrenheit', 'f']:
tempo = (_val - 32) * 5 / 9 + zero_Celsius
elif old_scale.lower() in ['rankine', 'r']:
tempo = _val * 5 / 9
else:
raise NotImplementedError(f"{old_scale=} is unsupported: supported scales "
"are Celsius, Kelvin, Fahrenheit, and "
"Rankine")
# and from Kelvin to `new_scale`.
if new_scale.lower() in ['celsius', 'c']:
res = tempo - zero_Celsius
elif new_scale.lower() in ['kelvin', 'k']:
res = tempo
elif new_scale.lower() in ['fahrenheit', 'f']:
res = (tempo - zero_Celsius) * 9 / 5 + 32
elif new_scale.lower() in ['rankine', 'r']:
res = tempo * 9 / 5
else:
raise NotImplementedError(f"{new_scale=} is unsupported: supported "
"scales are 'Celsius', 'Kelvin', "
"'Fahrenheit', and 'Rankine'")
return res
# optics
def lambda2nu(lambda_: npt.ArrayLike) -> Any:
"""
Convert wavelength to optical frequency
Parameters
----------
lambda_ : array_like
Wavelength(s) to be converted.
Returns
-------
nu : float or array of floats
Equivalent optical frequency.
Notes
-----
Computes ``nu = c / lambda`` where c = 299792458.0, i.e., the
(vacuum) speed of light in meters/second.
Examples
--------
>>> from scipy.constants import lambda2nu, speed_of_light
>>> import numpy as np
>>> lambda2nu(np.array((1, speed_of_light)))
array([ 2.99792458e+08, 1.00000000e+00])
"""
xp = array_namespace(lambda_)
return c / _asarray(lambda_, xp=xp, subok=True)
def nu2lambda(nu: npt.ArrayLike) -> Any:
"""
Convert optical frequency to wavelength.
Parameters
----------
nu : array_like
Optical frequency to be converted.
Returns
-------
lambda : float or array of floats
Equivalent wavelength(s).
Notes
-----
Computes ``lambda = c / nu`` where c = 299792458.0, i.e., the
(vacuum) speed of light in meters/second.
Examples
--------
>>> from scipy.constants import nu2lambda, speed_of_light
>>> import numpy as np
>>> nu2lambda(np.array((1, speed_of_light)))
array([ 2.99792458e+08, 1.00000000e+00])
"""
xp = array_namespace(nu)
return c / _asarray(nu, xp=xp, subok=True)

Some files were not shown because too many files have changed in this diff