161  venv/lib/python3.12/site-packages/scipy/__config__.py  Normal file
@@ -0,0 +1,161 @@
# This file is generated by SciPy's build process
# It contains system_info results at the time of building this package.
from enum import Enum

__all__ = ["show"]
_built_with_meson = True


class DisplayModes(Enum):
    stdout = "stdout"
    dicts = "dicts"


def _cleanup(d):
    """
    Removes empty values in a `dict` recursively
    This ensures we remove values that Meson could not provide to CONFIG
    """
    if isinstance(d, dict):
        return { k: _cleanup(v) for k, v in d.items() if v != '' and _cleanup(v) != '' }
    else:
        return d


CONFIG = _cleanup(
    {
        "Compilers": {
            "c": {
                "name": "gcc",
                "linker": r"ld.bfd",
                "version": "10.2.1",
                "commands": r"cc",
                "args": r"",
                "linker args": r"",
            },
            "cython": {
                "name": r"cython",
                "linker": r"cython",
                "version": r"3.0.11",
                "commands": r"cython",
                "args": r"",
                "linker args": r"",
            },
            "c++": {
                "name": "gcc",
                "linker": r"ld.bfd",
                "version": "10.2.1",
                "commands": r"c++",
                "args": r"",
                "linker args": r"",
            },
            "fortran": {
                "name": "gcc",
                "linker": r"ld.bfd",
                "version": "10.2.1",
                "commands": r"gfortran",
                "args": r"",
                "linker args": r"",
            },
            "pythran": {
                "version": r"0.16.1",
                "include directory": r"../../tmp/pip-build-env-znxdftlb/overlay/lib/python3.12/site-packages/pythran"
            },
        },
        "Machine Information": {
            "host": {
                "cpu": r"x86_64",
                "family": r"x86_64",
                "endian": r"little",
                "system": r"linux",
            },
            "build": {
                "cpu": r"x86_64",
                "family": r"x86_64",
                "endian": r"little",
                "system": r"linux",
            },
            "cross-compiled": bool("False".lower().replace('false', '')),
        },
        "Build Dependencies": {
            "blas": {
                "name": "scipy-openblas",
                "found": bool("True".lower().replace('false', '')),
                "version": "0.3.27.dev",
                "detection method": "pkgconfig",
                "include directory": r"/opt/_internal/cpython-3.12.4/lib/python3.12/site-packages/scipy_openblas32/include",
                "lib directory": r"/opt/_internal/cpython-3.12.4/lib/python3.12/site-packages/scipy_openblas32/lib",
                "openblas configuration": r"OpenBLAS 0.3.27.dev DYNAMIC_ARCH NO_AFFINITY Zen MAX_THREADS=64",
                "pc file directory": r"/project",
            },
            "lapack": {
                "name": "scipy-openblas",
                "found": bool("True".lower().replace('false', '')),
                "version": "0.3.27.dev",
                "detection method": "pkgconfig",
                "include directory": r"/opt/_internal/cpython-3.12.4/lib/python3.12/site-packages/scipy_openblas32/include",
                "lib directory": r"/opt/_internal/cpython-3.12.4/lib/python3.12/site-packages/scipy_openblas32/lib",
                "openblas configuration": r"OpenBLAS 0.3.27.dev DYNAMIC_ARCH NO_AFFINITY Zen MAX_THREADS=64",
                "pc file directory": r"/project",
            },
            "pybind11": {
                "name": "pybind11",
                "version": "2.12.0",
                "detection method": "config-tool",
                "include directory": r"unknown",
            },
        },
        "Python Information": {
            "path": r"/opt/python/cp312-cp312/bin/python",
            "version": "3.12",
        },
    }
)


def _check_pyyaml():
    import yaml

    return yaml


def show(mode=DisplayModes.stdout.value):
    """
    Show libraries and system information on which SciPy was built
    and is being used

    Parameters
    ----------
    mode : {`'stdout'`, `'dicts'`}, optional.
        Indicates how to display the config information.
        `'stdout'` prints to console, `'dicts'` returns a dictionary
        of the configuration.

    Returns
    -------
    out : {`dict`, `None`}
        If mode is `'dicts'`, a dict is returned, else None

    Notes
    -----
    1. The `'stdout'` mode will give more readable
       output if ``pyyaml`` is installed

    """
    if mode == DisplayModes.stdout.value:
        try:  # Non-standard library, check import
            yaml = _check_pyyaml()

            print(yaml.dump(CONFIG))
        except ModuleNotFoundError:
            import warnings
            import json

            warnings.warn("Install `pyyaml` for better output", stacklevel=1)
            print(json.dumps(CONFIG, indent=2))
    elif mode == DisplayModes.dicts.value:
        return CONFIG
    else:
        raise AttributeError(
            f"Invalid `mode`, use one of: {', '.join([e.value for e in DisplayModes])}"
        )
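Note: for reference, a minimal sketch of how this generated module is consumed (the values printed come from this particular wheel build and will differ elsewhere):

# Sketch: scipy re-exports `show` as `scipy.show_config` (see __init__.py below).
import scipy

scipy.show_config()                      # pretty-prints CONFIG (YAML if pyyaml is present)
cfg = scipy.show_config(mode="dicts")    # returns the nested dict instead of printing
print(cfg["Build Dependencies"]["blas"]["name"])   # e.g. "scipy-openblas" for this build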
141  venv/lib/python3.12/site-packages/scipy/__init__.py  Normal file
@@ -0,0 +1,141 @@
"""
SciPy: A scientific computing package for Python
================================================

Documentation is available in the docstrings and
online at https://docs.scipy.org.

Subpackages
-----------
Using any of these subpackages requires an explicit import. For example,
``import scipy.cluster``.

::

 cluster      --- Vector Quantization / Kmeans
 constants    --- Physical and mathematical constants and units
 datasets     --- Dataset methods
 fft          --- Discrete Fourier transforms
 fftpack      --- Legacy discrete Fourier transforms
 integrate    --- Integration routines
 interpolate  --- Interpolation Tools
 io           --- Data input and output
 linalg       --- Linear algebra routines
 misc         --- Utilities that don't have another home.
 ndimage      --- N-D image package
 odr          --- Orthogonal Distance Regression
 optimize     --- Optimization Tools
 signal       --- Signal Processing Tools
 sparse       --- Sparse Matrices
 spatial      --- Spatial data structures and algorithms
 special      --- Special functions
 stats        --- Statistical Functions

Public API in the main SciPy namespace
--------------------------------------
::

 __version__       --- SciPy version string
 LowLevelCallable  --- Low-level callback function
 show_config       --- Show scipy build configuration
 test              --- Run scipy unittests

"""

import importlib as _importlib

from numpy import __version__ as __numpy_version__


try:
    from scipy.__config__ import show as show_config
except ImportError as e:
    msg = """Error importing SciPy: you cannot import SciPy while
    being in scipy source directory; please exit the SciPy source
    tree first and relaunch your Python interpreter."""
    raise ImportError(msg) from e


from scipy.version import version as __version__


# Allow distributors to run custom init code
from . import _distributor_init
del _distributor_init


from scipy._lib import _pep440
# In maintenance branch, change to np_maxversion N+3 if numpy is at N
np_minversion = '1.23.5'
np_maxversion = '2.3.0'
if (_pep440.parse(__numpy_version__) < _pep440.Version(np_minversion) or
        _pep440.parse(__numpy_version__) >= _pep440.Version(np_maxversion)):
    import warnings
    warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
                  f" is required for this version of SciPy (detected "
                  f"version {__numpy_version__})",
                  UserWarning, stacklevel=2)
del _pep440


# This is the first import of an extension module within SciPy. If there's
# a general issue with the install, such that extension modules are missing
# or cannot be imported, this is where we'll get a failure - so give an
# informative error message.
try:
    from scipy._lib._ccallback import LowLevelCallable
except ImportError as e:
    msg = "The `scipy` install you are using seems to be broken, " + \
          "(extension modules cannot be imported), " + \
          "please try reinstalling."
    raise ImportError(msg) from e


from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester


submodules = [
    'cluster',
    'constants',
    'datasets',
    'fft',
    'fftpack',
    'integrate',
    'interpolate',
    'io',
    'linalg',
    'misc',
    'ndimage',
    'odr',
    'optimize',
    'signal',
    'sparse',
    'spatial',
    'special',
    'stats'
]

__all__ = submodules + [
    'LowLevelCallable',
    'test',
    'show_config',
    '__version__',
]


def __dir__():
    return __all__


def __getattr__(name):
    if name in submodules:
        return _importlib.import_module(f'scipy.{name}')
    else:
        try:
            return globals()[name]
        except KeyError:
            raise AttributeError(
                f"Module 'scipy' has no attribute '{name}'"
            )
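Note: the `__getattr__`/`__dir__` pair above is PEP 562 module-level lazy loading — subpackages are imported only on first attribute access. A minimal sketch of the observable behavior (assumes a working SciPy install):

import scipy

la = scipy.linalg          # triggers _importlib.import_module('scipy.linalg')
scipy.cluster              # same for any name listed in `submodules`
try:
    scipy.not_a_module     # anything else falls through to AttributeError
except AttributeError as exc:
    print(exc)             # "Module 'scipy' has no attribute 'not_a_module'"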
18  venv/lib/python3.12/site-packages/scipy/_distributor_init.py  Normal file
@@ -0,0 +1,18 @@
""" Distributor init file

Distributors: you can replace the contents of this file with your own custom
code to support particular distributions of SciPy.

For example, this is a good place to put any checks for hardware requirements
or BLAS/LAPACK library initialization.

The SciPy standard source distribution will not put code in this file beyond
the try-except import of `_distributor_init_local` (which is not part of a
standard source distribution), so you can safely replace this file with your
own version.
"""

try:
    from . import _distributor_init_local  # noqa: F401
except ImportError:
    pass
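Note: `_distributor_init_local` is deliberately absent from the standard distribution. Purely as a hypothetical illustration (not SciPy code), a distributor's `_distributor_init_local.py` might look like this:

# Hypothetical _distributor_init_local.py a distributor could ship next to
# this file; the module name is the only contract, contents are free-form.
import os
import sys

if sys.platform == 'win32':
    # e.g. help a Windows build locate bundled BLAS/LAPACK DLLs before any
    # extension module is imported (os.add_dll_directory exists on Windows only)
    _dll_dir = os.path.join(os.path.dirname(__file__), '.libs')
    if os.path.isdir(_dll_dir):
        os.add_dll_directory(_dll_dir)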
14  venv/lib/python3.12/site-packages/scipy/_lib/__init__.py  Normal file
@@ -0,0 +1,14 @@
"""
Module containing private utility functions
===========================================

The ``scipy._lib`` namespace is empty (for now). Tests for all
utilities in submodules of ``_lib`` can be run with::

    from scipy import _lib
    _lib.test()

"""
from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester
524  venv/lib/python3.12/site-packages/scipy/_lib/_array_api.py  Normal file
@@ -0,0 +1,524 @@
"""Utility functions to use Python Array API compatible libraries.

For the context about the Array API see:
https://data-apis.org/array-api/latest/purpose_and_scope.html

The SciPy use case of the Array API is described on the following page:
https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy
"""
from __future__ import annotations

import os
import warnings

from types import ModuleType
from typing import Any, Literal, TYPE_CHECKING

import numpy as np
import numpy.typing as npt

from scipy._lib import array_api_compat
from scipy._lib.array_api_compat import (
    is_array_api_obj,
    size,
    numpy as np_compat,
    device
)

__all__ = ['array_namespace', '_asarray', 'size', 'device']


# To enable array API and strict array-like input validation
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")

_GLOBAL_CONFIG = {
    "SCIPY_ARRAY_API": SCIPY_ARRAY_API,
    "SCIPY_DEVICE": SCIPY_DEVICE,
}


if TYPE_CHECKING:
    Array = Any  # To be changed to a Protocol later (see array-api#589)
    ArrayLike = Array | npt.ArrayLike


def compliance_scipy(arrays: list[ArrayLike]) -> list[Array]:
    """Raise exceptions on known-bad subclasses.

    The following subclasses are not supported and raise an error:
    - `numpy.ma.MaskedArray`
    - `numpy.matrix`
    - NumPy arrays which do not have a boolean or numerical dtype
    - Any array-like which is neither array API compatible nor coercible by NumPy
    - Any array-like which is coerced by NumPy to an unsupported dtype
    """
    for i in range(len(arrays)):
        array = arrays[i]
        if isinstance(array, np.ma.MaskedArray):
            raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.")
        elif isinstance(array, np.matrix):
            raise TypeError("Inputs of type `numpy.matrix` are not supported.")
        if isinstance(array, (np.ndarray, np.generic)):
            dtype = array.dtype
            if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
                raise TypeError(f"An argument has dtype `{dtype!r}`; "
                                f"only boolean and numerical dtypes are supported.")
        elif not is_array_api_obj(array):
            try:
                array = np.asanyarray(array)
            except TypeError:
                raise TypeError("An argument is neither array API compatible nor "
                                "coercible by NumPy.")
            dtype = array.dtype
            if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
                message = (
                    f"An argument was coerced to an unsupported dtype `{dtype!r}`; "
                    f"only boolean and numerical dtypes are supported."
                )
                raise TypeError(message)
            arrays[i] = array
    return arrays


def _check_finite(array: Array, xp: ModuleType) -> None:
    """Check for NaNs or Infs."""
    msg = "array must not contain infs or NaNs"
    try:
        if not xp.all(xp.isfinite(array)):
            raise ValueError(msg)
    except TypeError:
        raise ValueError(msg)


def array_namespace(*arrays: Array) -> ModuleType:
    """Get the array API compatible namespace for the arrays xs.

    Parameters
    ----------
    *arrays : sequence of array_like
        Arrays used to infer the common namespace.

    Returns
    -------
    namespace : module
        Common namespace.

    Notes
    -----
    Thin wrapper around `array_api_compat.array_namespace`.

    1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed
       dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``.
    2. `compliance_scipy` raises exceptions on known-bad subclasses. See
       its definition for more details.

    When the global switch is False, it defaults to the `numpy` namespace.
    In that case, there is no compliance check. This is a convenience to
    ease the adoption. Otherwise, arrays must comply with the new rules.
    """
    if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
        # here we could wrap the namespace if needed
        return np_compat

    _arrays = [array for array in arrays if array is not None]

    _arrays = compliance_scipy(_arrays)

    return array_api_compat.array_namespace(*_arrays)


def _asarray(
        array: ArrayLike,
        dtype: Any = None,
        order: Literal['K', 'A', 'C', 'F'] | None = None,
        copy: bool | None = None,
        *,
        xp: ModuleType | None = None,
        check_finite: bool = False,
        subok: bool = False,
    ) -> Array:
    """SciPy-specific replacement for `np.asarray` with `order`, `check_finite`, and
    `subok`.

    Memory layout parameter `order` is not exposed in the Array API standard.
    `order` is only enforced if the input array implementation
    is NumPy based, otherwise `order` is just silently ignored.

    `check_finite` is also not a keyword in the array API standard; included
    here for convenience rather than that having to be a separate function
    call inside SciPy functions.

    `subok` is included to allow this function to preserve the behaviour of
    `np.asanyarray` for NumPy based inputs.
    """
    if xp is None:
        xp = array_namespace(array)
    if xp.__name__ in {"numpy", "scipy._lib.array_api_compat.numpy"}:
        # Use NumPy API to support order
        if copy is True:
            array = np.array(array, order=order, dtype=dtype, subok=subok)
        elif subok:
            array = np.asanyarray(array, order=order, dtype=dtype)
        else:
            array = np.asarray(array, order=order, dtype=dtype)

        # At this point array is a NumPy ndarray. We convert it to an array
        # container that is consistent with the input's namespace.
        array = xp.asarray(array)
    else:
        try:
            array = xp.asarray(array, dtype=dtype, copy=copy)
        except TypeError:
            coerced_xp = array_namespace(xp.asarray(3))
            array = coerced_xp.asarray(array, dtype=dtype, copy=copy)

    if check_finite:
        _check_finite(array, xp)

    return array


def atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
    """Recursively expand the dimension to have at least `ndim`."""
    if xp is None:
        xp = array_namespace(x)
    x = xp.asarray(x)
    if x.ndim < ndim:
        x = xp.expand_dims(x, axis=0)
        x = atleast_nd(x, ndim=ndim, xp=xp)
    return x


def copy(x: Array, *, xp: ModuleType | None = None) -> Array:
    """
    Copies an array.

    Parameters
    ----------
    x : array

    xp : array_namespace

    Returns
    -------
    copy : array
        Copied array

    Notes
    -----
    This copy function does not offer all the semantics of `np.copy`, i.e. the
    `subok` and `order` keywords are not used.
    """
    # Note: xp.asarray fails if xp is numpy.
    if xp is None:
        xp = array_namespace(x)

    return _asarray(x, copy=True, xp=xp)


def is_numpy(xp: ModuleType) -> bool:
    return xp.__name__ in ('numpy', 'scipy._lib.array_api_compat.numpy')


def is_cupy(xp: ModuleType) -> bool:
    return xp.__name__ in ('cupy', 'scipy._lib.array_api_compat.cupy')


def is_torch(xp: ModuleType) -> bool:
    return xp.__name__ in ('torch', 'scipy._lib.array_api_compat.torch')


def is_jax(xp):
    return xp.__name__ in ('jax.numpy', 'jax.experimental.array_api')


def _strict_check(actual, desired, xp,
                  check_namespace=True, check_dtype=True, check_shape=True):
    __tracebackhide__ = True  # Hide traceback for py.test
    if check_namespace:
        _assert_matching_namespace(actual, desired)

    desired = xp.asarray(desired)

    if check_dtype:
        _msg = f"dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}"
        assert actual.dtype == desired.dtype, _msg

    if check_shape:
        _msg = f"Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}"
        assert actual.shape == desired.shape, _msg
        _check_scalar(actual, desired, xp)

    desired = xp.broadcast_to(desired, actual.shape)
    return desired


def _assert_matching_namespace(actual, desired):
    __tracebackhide__ = True  # Hide traceback for py.test
    actual = actual if isinstance(actual, tuple) else (actual,)
    desired_space = array_namespace(desired)
    for arr in actual:
        arr_space = array_namespace(arr)
        _msg = (f"Namespaces do not match.\n"
                f"Actual: {arr_space.__name__}\n"
                f"Desired: {desired_space.__name__}")
        assert arr_space == desired_space, _msg


def _check_scalar(actual, desired, xp):
    __tracebackhide__ = True  # Hide traceback for py.test
    # Shape check alone is sufficient unless desired.shape == (). Also,
    # only NumPy distinguishes between scalars and arrays.
    if desired.shape != () or not is_numpy(xp):
        return
    # We want to follow the conventions of the `xp` library. Libraries like
    # NumPy, for which `np.asarray(0)[()]` returns a scalar, tend to return
    # a scalar even when a 0D array might be more appropriate:
    # import numpy as np
    # np.mean([1, 2, 3])  # scalar, not 0d array
    # np.asarray(0)*2  # scalar, not 0d array
    # np.sin(np.asarray(0))  # scalar, not 0d array
    # Libraries like CuPy, for which `cp.asarray(0)[()]` returns a 0D array,
    # tend to return a 0D array in scenarios like those above.
    # Therefore, regardless of whether the developer provides a scalar or 0D
    # array for `desired`, we would typically want the type of `actual` to be
    # the type of `desired[()]`. If the developer wants to override this
    # behavior, they can set `check_shape=False`.
    desired = desired[()]
    _msg = f"Types do not match:\n Actual: {type(actual)}\n Desired: {type(desired)}"
    assert (xp.isscalar(actual) and xp.isscalar(desired)
            or (not xp.isscalar(actual) and not xp.isscalar(desired))), _msg


def xp_assert_equal(actual, desired, check_namespace=True, check_dtype=True,
                    check_shape=True, err_msg='', xp=None):
    __tracebackhide__ = True  # Hide traceback for py.test
    if xp is None:
        xp = array_namespace(actual)
    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)
    if is_cupy(xp):
        return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
    elif is_torch(xp):
        # PyTorch recommends using `rtol=0, atol=0` like this
        # to test for exact equality
        err_msg = None if err_msg == '' else err_msg
        return xp.testing.assert_close(actual, desired, rtol=0, atol=0, equal_nan=True,
                                       check_dtype=False, msg=err_msg)
    # JAX uses `np.testing`
    return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)


def xp_assert_close(actual, desired, rtol=None, atol=0, check_namespace=True,
                    check_dtype=True, check_shape=True, err_msg='', xp=None):
    __tracebackhide__ = True  # Hide traceback for py.test
    if xp is None:
        xp = array_namespace(actual)
    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)

    floating = xp.isdtype(actual.dtype, ('real floating', 'complex floating'))
    if rtol is None and floating:
        # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
        # roughly half way between sqrt(eps) and the default for
        # `numpy.testing.assert_allclose`, 1e-7
        rtol = xp.finfo(actual.dtype).eps**0.5 * 4
    elif rtol is None:
        rtol = 1e-7

    if is_cupy(xp):
        return xp.testing.assert_allclose(actual, desired, rtol=rtol,
                                          atol=atol, err_msg=err_msg)
    elif is_torch(xp):
        err_msg = None if err_msg == '' else err_msg
        return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
                                       equal_nan=True, check_dtype=False, msg=err_msg)
    # JAX uses `np.testing`
    return np.testing.assert_allclose(actual, desired, rtol=rtol,
                                      atol=atol, err_msg=err_msg)


def xp_assert_less(actual, desired, check_namespace=True, check_dtype=True,
                   check_shape=True, err_msg='', verbose=True, xp=None):
    __tracebackhide__ = True  # Hide traceback for py.test
    if xp is None:
        xp = array_namespace(actual)
    desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
                            check_dtype=check_dtype, check_shape=check_shape)
    if is_cupy(xp):
        return xp.testing.assert_array_less(actual, desired,
                                            err_msg=err_msg, verbose=verbose)
    elif is_torch(xp):
        if actual.device.type != 'cpu':
            actual = actual.cpu()
        if desired.device.type != 'cpu':
            desired = desired.cpu()
    # JAX uses `np.testing`
    return np.testing.assert_array_less(actual, desired,
                                        err_msg=err_msg, verbose=verbose)


def cov(x: Array, *, xp: ModuleType | None = None) -> Array:
    if xp is None:
        xp = array_namespace(x)

    X = copy(x, xp=xp)
    dtype = xp.result_type(X, xp.float64)

    X = atleast_nd(X, ndim=2, xp=xp)
    X = xp.asarray(X, dtype=dtype)

    avg = xp.mean(X, axis=1)
    fact = X.shape[1] - 1

    if fact <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice",
                      RuntimeWarning, stacklevel=2)
        fact = 0.0

    X -= avg[:, None]
    X_T = X.T
    if xp.isdtype(X_T.dtype, 'complex floating'):
        X_T = xp.conj(X_T)
    c = X @ X_T
    c /= fact
    axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
    return xp.squeeze(c, axis=axes)


def xp_unsupported_param_msg(param: Any) -> str:
    return f'Providing {param!r} is only supported for numpy arrays.'


def is_complex(x: Array, xp: ModuleType) -> bool:
    return xp.isdtype(x.dtype, 'complex floating')


def get_xp_devices(xp: ModuleType) -> list[str] | list[None]:
    """Returns a list of available devices for the given namespace."""
    devices: list[str] = []
    if is_torch(xp):
        devices += ['cpu']
        import torch  # type: ignore[import]
        num_cuda = torch.cuda.device_count()
        for i in range(0, num_cuda):
            devices += [f'cuda:{i}']
        if torch.backends.mps.is_available():
            devices += ['mps']
        return devices
    elif is_cupy(xp):
        import cupy  # type: ignore[import]
        num_cuda = cupy.cuda.runtime.getDeviceCount()
        for i in range(0, num_cuda):
            devices += [f'cuda:{i}']
        return devices
    elif is_jax(xp):
        import jax  # type: ignore[import]
        num_cpu = jax.device_count(backend='cpu')
        for i in range(0, num_cpu):
            devices += [f'cpu:{i}']
        num_gpu = jax.device_count(backend='gpu')
        for i in range(0, num_gpu):
            devices += [f'gpu:{i}']
        num_tpu = jax.device_count(backend='tpu')
        for i in range(0, num_tpu):
            devices += [f'tpu:{i}']
        return devices

    # given namespace is not known to have a list of available devices;
    # return `[None]` so that one can use this in tests for `device=None`.
    return [None]


def scipy_namespace_for(xp: ModuleType) -> ModuleType:
    """
    Return the `scipy` namespace for alternative backends, where it exists,
    such as `cupyx.scipy` and `jax.scipy`. Useful for ad hoc dispatching.

    Default: return `scipy` (this package).
    """

    if is_cupy(xp):
        import cupyx  # type: ignore[import-not-found,import-untyped]
        return cupyx.scipy

    if is_jax(xp):
        import jax  # type: ignore[import-not-found]
        return jax.scipy

    import scipy
    return scipy


# temporary substitute for xp.minimum, which is not yet in all backends
# or covered by array_api_compat.
def xp_minimum(x1: Array, x2: Array, /) -> Array:
    # xp won't be passed in because it doesn't need to be passed in to xp.minimum
    xp = array_namespace(x1, x2)
    if hasattr(xp, 'minimum'):
        return xp.minimum(x1, x2)
    x1, x2 = xp.broadcast_arrays(x1, x2)
    i = (x2 < x1) | xp.isnan(x2)
    res = xp.where(i, x2, x1)
    return res[()] if res.ndim == 0 else res


# temporary substitute for xp.clip, which is not yet in all backends
# or covered by array_api_compat.
def xp_clip(
        x: Array,
        /,
        min: int | float | Array | None = None,
        max: int | float | Array | None = None,
        *,
        xp: ModuleType | None = None) -> Array:
    xp = array_namespace(x) if xp is None else xp
    a, b = xp.asarray(min, dtype=x.dtype), xp.asarray(max, dtype=x.dtype)
    if hasattr(xp, 'clip'):
        return xp.clip(x, a, b)
    x, a, b = xp.broadcast_arrays(x, a, b)
    y = xp.asarray(x, copy=True)
    ia = y < a
    y[ia] = a[ia]
    ib = y > b
    y[ib] = b[ib]
    return y[()] if y.ndim == 0 else y


# temporary substitute for xp.moveaxis, which is not yet in all backends
# or covered by array_api_compat.
def xp_moveaxis_to_end(
        x: Array,
        source: int,
        /, *,
        xp: ModuleType | None = None) -> Array:
    xp = array_namespace(x) if xp is None else xp
    axes = list(range(x.ndim))
    temp = axes.pop(source)
    axes = axes + [temp]
    return xp.permute_dims(x, axes)


# temporary substitute for xp.copysign, which is not yet in all backends
# or covered by array_api_compat.
def xp_copysign(x1: Array, x2: Array, /, *, xp: ModuleType | None = None) -> Array:
    # no attempt to account for special cases
    xp = array_namespace(x1, x2) if xp is None else xp
    abs_x1 = xp.abs(x1)
    return xp.where(x2 >= 0, abs_x1, -abs_x1)


# partial substitute for xp.sign, which does not cover the NaN special case
# that I need. (https://github.com/data-apis/array-api-compat/issues/136)
def xp_sign(x: Array, /, *, xp: ModuleType | None = None) -> Array:
    xp = array_namespace(x) if xp is None else xp
    if is_numpy(xp):  # only NumPy implements the special cases correctly
        return xp.sign(x)
    sign = xp.full_like(x, xp.nan)
    one = xp.asarray(1, dtype=x.dtype)
    sign = xp.where(x > 0, one, sign)
    sign = xp.where(x < 0, -one, sign)
    sign = xp.where(x == 0, 0*one, sign)
    return sign
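Note: a minimal sketch of writing a backend-agnostic helper on top of `array_namespace` and `_asarray` (the `rms` function name is hypothetical; the non-NumPy path is only taken when the SCIPY_ARRAY_API environment variable is set):

import numpy as np
from scipy._lib._array_api import array_namespace, _asarray

def rms(x):
    xp = array_namespace(x)                    # numpy, torch, cupy, ... as applicable
    x = _asarray(x, xp=xp, check_finite=True)  # reject NaN/inf up front
    return xp.sqrt(xp.mean(x * x))

print(rms(np.array([3.0, 4.0])))  # ~3.5355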
225  venv/lib/python3.12/site-packages/scipy/_lib/_bunch.py  Normal file
@@ -0,0 +1,225 @@
import sys as _sys
from keyword import iskeyword as _iskeyword


def _validate_names(typename, field_names, extra_field_names):
    """
    Ensure that all the given names are valid Python identifiers that
    do not start with '_'. Also check that there are no duplicates
    among field_names + extra_field_names.
    """
    for name in [typename] + field_names + extra_field_names:
        if not isinstance(name, str):
            raise TypeError('typename and all field names must be strings')
        if not name.isidentifier():
            raise ValueError('typename and all field names must be valid '
                             f'identifiers: {name!r}')
        if _iskeyword(name):
            raise ValueError('typename and all field names cannot be a '
                             f'keyword: {name!r}')

    seen = set()
    for name in field_names + extra_field_names:
        if name.startswith('_'):
            raise ValueError('Field names cannot start with an underscore: '
                             f'{name!r}')
        if name in seen:
            raise ValueError(f'Duplicate field name: {name!r}')
        seen.add(name)


# Note: This code is adapted from CPython:Lib/collections/__init__.py
def _make_tuple_bunch(typename, field_names, extra_field_names=None,
                      module=None):
    """
    Create a namedtuple-like class with additional attributes.

    This function creates a subclass of tuple that acts like a namedtuple
    and that has additional attributes.

    The additional attributes are listed in `extra_field_names`. The
    values assigned to these attributes are not part of the tuple.

    The reason this function exists is to allow functions in SciPy
    that currently return a tuple or a namedtuple to return objects
    that have additional attributes, while maintaining backwards
    compatibility.

    This should only be used to enhance *existing* functions in SciPy.
    New functions are free to create objects as return values without
    having to maintain backwards compatibility with an old tuple or
    namedtuple return value.

    Parameters
    ----------
    typename : str
        The name of the type.
    field_names : list of str
        List of names of the values to be stored in the tuple. These names
        will also be attributes of instances, so the values in the tuple
        can be accessed by indexing or as attributes. At least one name
        is required. See the Notes for additional restrictions.
    extra_field_names : list of str, optional
        List of names of values that will be stored as attributes of the
        object. See the notes for additional restrictions.

    Returns
    -------
    cls : type
        The new class.

    Notes
    -----
    There are restrictions on the names that may be used in `field_names`
    and `extra_field_names`:

    * The names must be unique--no duplicates allowed.
    * The names must be valid Python identifiers, and must not begin with
      an underscore.
    * The names must not be Python keywords (e.g. 'def', 'and', etc., are
      not allowed).

    Examples
    --------
    >>> from scipy._lib._bunch import _make_tuple_bunch

    Create a class that acts like a namedtuple with length 2 (with field
    names `x` and `y`) that will also have the attributes `w` and `beta`:

    >>> Result = _make_tuple_bunch('Result', ['x', 'y'], ['w', 'beta'])

    `Result` is the new class. We call it with keyword arguments to create
    a new instance with given values.

    >>> result1 = Result(x=1, y=2, w=99, beta=0.5)
    >>> result1
    Result(x=1, y=2, w=99, beta=0.5)

    `result1` acts like a tuple of length 2:

    >>> len(result1)
    2
    >>> result1[:]
    (1, 2)

    The values assigned when the instance was created are available as
    attributes:

    >>> result1.y
    2
    >>> result1.beta
    0.5
    """
    if len(field_names) == 0:
        raise ValueError('field_names must contain at least one name')

    if extra_field_names is None:
        extra_field_names = []
    _validate_names(typename, field_names, extra_field_names)

    typename = _sys.intern(str(typename))
    field_names = tuple(map(_sys.intern, field_names))
    extra_field_names = tuple(map(_sys.intern, extra_field_names))

    all_names = field_names + extra_field_names
    arg_list = ', '.join(field_names)
    full_list = ', '.join(all_names)
    repr_fmt = ''.join(('(',
                        ', '.join(f'{name}=%({name})r' for name in all_names),
                        ')'))
    tuple_new = tuple.__new__
    _dict, _tuple, _zip = dict, tuple, zip

    # Create all the named tuple methods to be added to the class namespace

    s = f"""\
def __new__(_cls, {arg_list}, **extra_fields):
    return _tuple_new(_cls, ({arg_list},))

def __init__(self, {arg_list}, **extra_fields):
    for key in self._extra_fields:
        if key not in extra_fields:
            raise TypeError("missing keyword argument '%s'" % (key,))
    for key, val in extra_fields.items():
        if key not in self._extra_fields:
            raise TypeError("unexpected keyword argument '%s'" % (key,))
        self.__dict__[key] = val

def __setattr__(self, key, val):
    if key in {repr(field_names)}:
        raise AttributeError("can't set attribute %r of class %r"
                             % (key, self.__class__.__name__))
    else:
        self.__dict__[key] = val
"""
    del arg_list
    namespace = {'_tuple_new': tuple_new,
                 '__builtins__': dict(TypeError=TypeError,
                                      AttributeError=AttributeError),
                 '__name__': f'namedtuple_{typename}'}
    exec(s, namespace)
    __new__ = namespace['__new__']
    __new__.__doc__ = f'Create new instance of {typename}({full_list})'
    __init__ = namespace['__init__']
    __init__.__doc__ = f'Instantiate instance of {typename}({full_list})'
    __setattr__ = namespace['__setattr__']

    def __repr__(self):
        'Return a nicely formatted representation string'
        return self.__class__.__name__ + repr_fmt % self._asdict()

    def _asdict(self):
        'Return a new dict which maps field names to their values.'
        out = _dict(_zip(self._fields, self))
        out.update(self.__dict__)
        return out

    def __getnewargs_ex__(self):
        'Return self as a plain tuple. Used by copy and pickle.'
        return _tuple(self), self.__dict__

    # Modify function metadata to help with introspection and debugging
    for method in (__new__, __repr__, _asdict, __getnewargs_ex__):
        method.__qualname__ = f'{typename}.{method.__name__}'

    # Build-up the class namespace dictionary
    # and use type() to build the result class
    class_namespace = {
        '__doc__': f'{typename}({full_list})',
        '_fields': field_names,
        '__new__': __new__,
        '__init__': __init__,
        '__repr__': __repr__,
        '__setattr__': __setattr__,
        '_asdict': _asdict,
        '_extra_fields': extra_field_names,
        '__getnewargs_ex__': __getnewargs_ex__,
    }
    for index, name in enumerate(field_names):

        def _get(self, index=index):
            return self[index]
        class_namespace[name] = property(_get)
    for name in extra_field_names:

        def _get(self, name=name):
            return self.__dict__[name]
        class_namespace[name] = property(_get)

    result = type(typename, (tuple,), class_namespace)

    # For pickling to work, the __module__ variable needs to be set to the
    # frame where the named tuple is created. Bypass this step in environments
    # where sys._getframe is not defined (Jython for example) or sys._getframe
    # is not defined for arguments greater than 0 (IronPython), or where the
    # user has specified a particular module.
    if module is None:
        try:
            module = _sys._getframe(1).f_globals.get('__name__', '__main__')
        except (AttributeError, ValueError):
            pass
    if module is not None:
        result.__module__ = module
        __new__.__module__ = module

    return result
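Note: beyond the docstring examples, the backwards-compatibility claim rests on two things shown above: `__new__` stores only the tuple fields, and `__getnewargs_ex__` carries the extras through copy/pickle. A small sketch (assumes the generated class is defined at an import-reachable top level, which pickle requires):

import pickle
from scipy._lib._bunch import _make_tuple_bunch

Result = _make_tuple_bunch('Result', ['x', 'y'], ['w'])
r = Result(x=1, y=2, w=99)
x, y = r                            # old tuple-unpacking call sites keep working
r2 = pickle.loads(pickle.dumps(r))  # __getnewargs_ex__ preserves the extra field
assert (r2.x, r2.y, r2.w) == (1, 2, 99)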
251  venv/lib/python3.12/site-packages/scipy/_lib/_ccallback.py  Normal file
@@ -0,0 +1,251 @@
from . import _ccallback_c

import ctypes

PyCFuncPtr = ctypes.CFUNCTYPE(ctypes.c_void_p).__bases__[0]

ffi = None


class CData:
    pass


def _import_cffi():
    global ffi, CData

    if ffi is not None:
        return

    try:
        import cffi
        ffi = cffi.FFI()
        CData = ffi.CData
    except ImportError:
        ffi = False


class LowLevelCallable(tuple):
    """
    Low-level callback function.

    Some functions in SciPy take as arguments callback functions, which
    can either be python callables or low-level compiled functions. Using
    compiled callback functions can improve performance somewhat by
    avoiding wrapping data in Python objects.

    Such low-level functions in SciPy are wrapped in `LowLevelCallable`
    objects, which can be constructed from function pointers obtained from
    ctypes, cffi, Cython, or contained in Python `PyCapsule` objects.

    .. seealso::

       Functions accepting low-level callables:

       `scipy.integrate.quad`, `scipy.ndimage.generic_filter`,
       `scipy.ndimage.generic_filter1d`, `scipy.ndimage.geometric_transform`

       Usage examples:

       :ref:`ndimage-ccallbacks`, :ref:`quad-callbacks`

    Parameters
    ----------
    function : {PyCapsule, ctypes function pointer, cffi function pointer}
        Low-level callback function.
    user_data : {PyCapsule, ctypes void pointer, cffi void pointer}
        User data to pass on to the callback function.
    signature : str, optional
        Signature of the function. If omitted, determined from *function*,
        if possible.

    Attributes
    ----------
    function
        Callback function given.
    user_data
        User data given.
    signature
        Signature of the function.

    Methods
    -------
    from_cython
        Class method for constructing callables from Cython C-exported
        functions.

    Notes
    -----
    The argument ``function`` can be one of:

    - PyCapsule, whose name contains the C function signature
    - ctypes function pointer
    - cffi function pointer

    The signature of the low-level callback must match one of those expected
    by the routine it is passed to.

    If constructing low-level functions from a PyCapsule, the name of the
    capsule must be the corresponding signature, in the format::

        return_type (arg1_type, arg2_type, ...)

    For example::

        "void (double)"
        "double (double, int *, void *)"

    The context of a PyCapsule passed in as ``function`` is used as ``user_data``,
    if an explicit value for ``user_data`` was not given.

    """

    # Make the class immutable
    __slots__ = ()

    def __new__(cls, function, user_data=None, signature=None):
        # We need to hold a reference to the function & user data,
        # to prevent them going out of scope
        item = cls._parse_callback(function, user_data, signature)
        return tuple.__new__(cls, (item, function, user_data))

    def __repr__(self):
        return f"LowLevelCallable({self.function!r}, {self.user_data!r})"

    @property
    def function(self):
        return tuple.__getitem__(self, 1)

    @property
    def user_data(self):
        return tuple.__getitem__(self, 2)

    @property
    def signature(self):
        return _ccallback_c.get_capsule_signature(tuple.__getitem__(self, 0))

    def __getitem__(self, idx):
        raise ValueError()

    @classmethod
    def from_cython(cls, module, name, user_data=None, signature=None):
        """
        Create a low-level callback function from an exported Cython function.

        Parameters
        ----------
        module : module
            Cython module where the exported function resides
        name : str
            Name of the exported function
        user_data : {PyCapsule, ctypes void pointer, cffi void pointer}, optional
            User data to pass on to the callback function.
        signature : str, optional
            Signature of the function. If omitted, determined from *function*.

        """
        try:
            function = module.__pyx_capi__[name]
        except AttributeError as e:
            message = "Given module is not a Cython module with __pyx_capi__ attribute"
            raise ValueError(message) from e
        except KeyError as e:
            message = f"No function {name!r} found in __pyx_capi__ of the module"
            raise ValueError(message) from e
        return cls(function, user_data, signature)

    @classmethod
    def _parse_callback(cls, obj, user_data=None, signature=None):
        _import_cffi()

        if isinstance(obj, LowLevelCallable):
            func = tuple.__getitem__(obj, 0)
        elif isinstance(obj, PyCFuncPtr):
            func, signature = _get_ctypes_func(obj, signature)
        elif isinstance(obj, CData):
            func, signature = _get_cffi_func(obj, signature)
        elif _ccallback_c.check_capsule(obj):
            func = obj
        else:
            raise ValueError("Given input is not a callable or a "
                             "low-level callable (pycapsule/ctypes/cffi)")

        if isinstance(user_data, ctypes.c_void_p):
            context = _get_ctypes_data(user_data)
        elif isinstance(user_data, CData):
            context = _get_cffi_data(user_data)
        elif user_data is None:
            context = 0
        elif _ccallback_c.check_capsule(user_data):
            context = user_data
        else:
            raise ValueError("Given user data is not a valid "
                             "low-level void* pointer (pycapsule/ctypes/cffi)")

        return _ccallback_c.get_raw_capsule(func, signature, context)


#
# ctypes helpers
#

def _get_ctypes_func(func, signature=None):
    # Get function pointer
    func_ptr = ctypes.cast(func, ctypes.c_void_p).value

    # Construct function signature
    if signature is None:
        signature = _typename_from_ctypes(func.restype) + " ("
        for j, arg in enumerate(func.argtypes):
            if j == 0:
                signature += _typename_from_ctypes(arg)
            else:
                signature += ", " + _typename_from_ctypes(arg)
        signature += ")"

    return func_ptr, signature


def _typename_from_ctypes(item):
    if item is None:
        return "void"
    elif item is ctypes.c_void_p:
        return "void *"

    name = item.__name__

    pointer_level = 0
    while name.startswith("LP_"):
        pointer_level += 1
        name = name[3:]

    if name.startswith('c_'):
        name = name[2:]

    if pointer_level > 0:
        name += " " + "*"*pointer_level

    return name


def _get_ctypes_data(data):
    # Get voidp pointer
    return ctypes.cast(data, ctypes.c_void_p).value


#
# CFFI helpers
#

def _get_cffi_func(func, signature=None):
    # Get function pointer
    func_ptr = ffi.cast('uintptr_t', func)

    # Get signature
    if signature is None:
        signature = ffi.getctype(ffi.typeof(func)).replace('(*)', ' ')

    return func_ptr, signature


def _get_cffi_data(data):
    # Get pointer
    return ffi.cast('uintptr_t', data)
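Note: a minimal sketch of the ctypes path through `_parse_callback`/`_get_ctypes_func` above, handing a compiled function to `scipy.integrate.quad` (assumes a POSIX system where `ctypes.CDLL(None)` exposes libm's `sin`; "double (double)" is one of the signatures quad accepts):

import ctypes
from scipy import LowLevelCallable, integrate

clib = ctypes.CDLL(None)             # symbols already loaded into the process
clib.sin.restype = ctypes.c_double   # libm's sin: double sin(double)
clib.sin.argtypes = (ctypes.c_double,)

f = LowLevelCallable(clib.sin)       # signature inferred as "double (double)"
val, err = integrate.quad(f, 0, 3.141592653589793)
print(val)                           # ~2.0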
Binary file not shown.
254  venv/lib/python3.12/site-packages/scipy/_lib/_disjoint_set.py  Normal file
@@ -0,0 +1,254 @@
"""
Disjoint set data structure
"""


class DisjointSet:
    """ Disjoint set data structure for incremental connectivity queries.

    .. versionadded:: 1.6.0

    Attributes
    ----------
    n_subsets : int
        The number of subsets.

    Methods
    -------
    add
    merge
    connected
    subset
    subset_size
    subsets
    __getitem__

    Notes
    -----
    This class implements the disjoint set [1]_, also known as the *union-find*
    or *merge-find* data structure. The *find* operation (implemented in
    `__getitem__`) implements the *path halving* variant. The *merge* method
    implements the *merge by size* variant.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Disjoint-set_data_structure

    Examples
    --------
    >>> from scipy.cluster.hierarchy import DisjointSet

    Initialize a disjoint set:

    >>> disjoint_set = DisjointSet([1, 2, 3, 'a', 'b'])

    Merge some subsets:

    >>> disjoint_set.merge(1, 2)
    True
    >>> disjoint_set.merge(3, 'a')
    True
    >>> disjoint_set.merge('a', 'b')
    True
    >>> disjoint_set.merge('b', 'b')
    False

    Find root elements:

    >>> disjoint_set[2]
    1
    >>> disjoint_set['b']
    3

    Test connectivity:

    >>> disjoint_set.connected(1, 2)
    True
    >>> disjoint_set.connected(1, 'b')
    False

    List elements in disjoint set:

    >>> list(disjoint_set)
    [1, 2, 3, 'a', 'b']

    Get the subset containing 'a':

    >>> disjoint_set.subset('a')
    {'a', 3, 'b'}

    Get the size of the subset containing 'a' (without actually instantiating
    the subset):

    >>> disjoint_set.subset_size('a')
    3

    Get all subsets in the disjoint set:

    >>> disjoint_set.subsets()
    [{1, 2}, {'a', 3, 'b'}]
    """
    def __init__(self, elements=None):
        self.n_subsets = 0
        self._sizes = {}
        self._parents = {}
        # _nbrs is a circular linked list which links connected elements.
        self._nbrs = {}
        # _indices tracks the element insertion order in `__iter__`.
        self._indices = {}
        if elements is not None:
            for x in elements:
                self.add(x)

    def __iter__(self):
        """Returns an iterator of the elements in the disjoint set.

        Elements are ordered by insertion order.
        """
        return iter(self._indices)

    def __len__(self):
        return len(self._indices)

    def __contains__(self, x):
        return x in self._indices

    def __getitem__(self, x):
        """Find the root element of `x`.

        Parameters
        ----------
        x : hashable object
            Input element.

        Returns
        -------
        root : hashable object
            Root element of `x`.
        """
        if x not in self._indices:
            raise KeyError(x)

        # find by "path halving"
        parents = self._parents
        while self._indices[x] != self._indices[parents[x]]:
            parents[x] = parents[parents[x]]
            x = parents[x]
        return x

    def add(self, x):
        """Add element `x` to disjoint set
        """
        if x in self._indices:
            return

        self._sizes[x] = 1
        self._parents[x] = x
        self._nbrs[x] = x
        self._indices[x] = len(self._indices)
        self.n_subsets += 1

    def merge(self, x, y):
        """Merge the subsets of `x` and `y`.

        The smaller subset (the child) is merged into the larger subset (the
        parent). If the subsets are of equal size, the root element which was
        first inserted into the disjoint set is selected as the parent.

        Parameters
        ----------
        x, y : hashable object
            Elements to merge.

        Returns
        -------
        merged : bool
            True if `x` and `y` were in disjoint sets, False otherwise.
        """
        xr = self[x]
        yr = self[y]
        if self._indices[xr] == self._indices[yr]:
            return False

        sizes = self._sizes
        if (sizes[xr], self._indices[yr]) < (sizes[yr], self._indices[xr]):
            xr, yr = yr, xr
        self._parents[yr] = xr
        self._sizes[xr] += self._sizes[yr]
        self._nbrs[xr], self._nbrs[yr] = self._nbrs[yr], self._nbrs[xr]
        self.n_subsets -= 1
        return True

    def connected(self, x, y):
        """Test whether `x` and `y` are in the same subset.

        Parameters
        ----------
        x, y : hashable object
            Elements to test.

        Returns
        -------
        result : bool
            True if `x` and `y` are in the same set, False otherwise.
        """
        return self._indices[self[x]] == self._indices[self[y]]

    def subset(self, x):
        """Get the subset containing `x`.

        Parameters
        ----------
        x : hashable object
            Input element.

        Returns
        -------
        result : set
            Subset containing `x`.
        """
        if x not in self._indices:
            raise KeyError(x)

        result = [x]
        nxt = self._nbrs[x]
        while self._indices[nxt] != self._indices[x]:
            result.append(nxt)
            nxt = self._nbrs[nxt]
        return set(result)

    def subset_size(self, x):
        """Get the size of the subset containing `x`.

        Note that this method is faster than ``len(self.subset(x))`` because
        the size is directly read off an internal field, without the need to
        instantiate the full subset.

        Parameters
        ----------
        x : hashable object
            Input element.

        Returns
        -------
        result : int
            Size of the subset containing `x`.
        """
        return self._sizes[self[x]]

    def subsets(self):
        """Get all the subsets in the disjoint set.

        Returns
        -------
        result : list
            Subsets in the disjoint set.
        """
        result = []
        visited = set()
        for x in self:
            if x not in visited:
                xset = self.subset(x)
                visited.update(xset)
                result.append(xset)
        return result
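Note: a classic use of union-find that the docstring doesn't show is incremental cycle detection while scanning graph edges — `merge` returns False exactly when an edge closes a cycle. A minimal sketch using the public re-export named in the docstring:

from scipy.cluster.hierarchy import DisjointSet

edges = [(0, 1), (1, 2), (0, 2), (2, 3)]
ds = DisjointSet(range(4))
for u, v in edges:
    if not ds.merge(u, v):
        print(f"edge {(u, v)} closes a cycle")   # printed for (0, 2)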
679  venv/lib/python3.12/site-packages/scipy/_lib/_docscrape.py  Normal file
@@ -0,0 +1,679 @@
|
||||
"""Extract reference documentation from the NumPy source tree.
|
||||
|
||||
"""
|
||||
# copied from numpydoc/docscrape.py
|
||||
import inspect
|
||||
import textwrap
|
||||
import re
|
||||
import pydoc
|
||||
from warnings import warn
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable, Mapping
|
||||
import copy
|
||||
import sys
|
||||
|
||||
|
||||
def strip_blank_lines(l):
|
||||
"Remove leading and trailing blank lines from a list of lines"
|
||||
while l and not l[0].strip():
|
||||
del l[0]
|
||||
while l and not l[-1].strip():
|
||||
del l[-1]
|
||||
return l
|
||||
|
||||
|
||||
class Reader:
|
||||
"""A line-based string reader.
|
||||
|
||||
"""
|
||||
def __init__(self, data):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
data : str
|
||||
String with lines separated by '\\n'.
|
||||
|
||||
"""
|
||||
if isinstance(data, list):
|
||||
self._str = data
|
||||
else:
|
||||
self._str = data.split('\n') # store string as list of lines
|
||||
|
||||
self.reset()
|
||||
|
||||
def __getitem__(self, n):
|
||||
return self._str[n]
|
||||
|
||||
def reset(self):
|
||||
self._l = 0 # current line nr
|
||||
|
||||
def read(self):
|
||||
if not self.eof():
|
||||
out = self[self._l]
|
||||
self._l += 1
|
||||
return out
|
||||
else:
|
||||
return ''
|
||||
|
||||
def seek_next_non_empty_line(self):
|
||||
for l in self[self._l:]:
|
||||
if l.strip():
|
||||
break
|
||||
else:
|
||||
self._l += 1
|
||||
|
||||
def eof(self):
|
||||
return self._l >= len(self._str)
|
||||
|
||||
def read_to_condition(self, condition_func):
|
||||
start = self._l
|
||||
for line in self[start:]:
|
||||
if condition_func(line):
|
||||
return self[start:self._l]
|
||||
self._l += 1
|
||||
if self.eof():
|
||||
return self[start:self._l+1]
|
||||
return []
|
||||
|
||||
def read_to_next_empty_line(self):
|
||||
self.seek_next_non_empty_line()
|
||||
|
||||
def is_empty(line):
|
||||
return not line.strip()
|
||||
|
||||
return self.read_to_condition(is_empty)
|
||||
|
||||
def read_to_next_unindented_line(self):
|
||||
def is_unindented(line):
|
||||
return (line.strip() and (len(line.lstrip()) == len(line)))
|
||||
return self.read_to_condition(is_unindented)
|
||||
|
||||
def peek(self, n=0):
|
||||
if self._l + n < len(self._str):
|
||||
return self[self._l + n]
|
||||
else:
|
||||
return ''
|
||||
|
||||
def is_empty(self):
|
||||
return not ''.join(self._str).strip()
|
||||
|
||||
|
||||
class ParseError(Exception):
|
||||
def __str__(self):
|
||||
message = self.args[0]
|
||||
if hasattr(self, 'docstring'):
|
||||
message = f"{message} in {self.docstring!r}"
|
||||
return message
|
||||
|
||||
|
||||
Parameter = namedtuple('Parameter', ['name', 'type', 'desc'])
|
||||
|
||||
|
||||
class NumpyDocString(Mapping):
|
||||
"""Parses a numpydoc string to an abstract representation
|
||||
|
||||
Instances define a mapping from section title to structured data.
|
||||
|
||||
"""
|
||||
|
||||
sections = {
|
||||
'Signature': '',
|
||||
'Summary': [''],
|
||||
'Extended Summary': [],
|
||||
'Parameters': [],
|
||||
'Returns': [],
|
||||
'Yields': [],
|
||||
'Receives': [],
|
||||
'Raises': [],
|
||||
'Warns': [],
|
||||
'Other Parameters': [],
|
||||
'Attributes': [],
|
||||
'Methods': [],
|
||||
'See Also': [],
|
||||
'Notes': [],
|
||||
'Warnings': [],
|
||||
'References': '',
|
||||
'Examples': '',
|
||||
'index': {}
|
||||
}
|
||||
|
||||
def __init__(self, docstring, config={}):
|
||||
orig_docstring = docstring
|
||||
docstring = textwrap.dedent(docstring).split('\n')
|
||||
|
||||
self._doc = Reader(docstring)
|
||||
self._parsed_data = copy.deepcopy(self.sections)
|
||||
|
||||
try:
|
||||
self._parse()
|
||||
except ParseError as e:
|
||||
e.docstring = orig_docstring
|
||||
raise
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._parsed_data[key]
|
||||
|
||||
def __setitem__(self, key, val):
|
||||
if key not in self._parsed_data:
|
||||
self._error_location("Unknown section %s" % key, error=False)
|
||||
else:
|
||||
self._parsed_data[key] = val
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._parsed_data)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._parsed_data)
|
||||
|
||||
def _is_at_section(self):
|
||||
self._doc.seek_next_non_empty_line()
|
||||
|
||||
if self._doc.eof():
|
||||
return False
|
||||
|
||||
l1 = self._doc.peek().strip() # e.g. Parameters
|
||||
|
||||
if l1.startswith('.. index::'):
|
||||
return True
|
||||
|
||||
l2 = self._doc.peek(1).strip() # ---------- or ==========
|
||||
return l2.startswith('-'*len(l1)) or l2.startswith('='*len(l1))
|
||||
|
||||
def _strip(self, doc):
|
||||
i = 0
|
||||
j = 0
|
||||
for i, line in enumerate(doc):
|
||||
if line.strip():
|
||||
break
|
||||
|
||||
for j, line in enumerate(doc[::-1]):
|
||||
if line.strip():
|
||||
break
|
||||
|
||||
return doc[i:len(doc)-j]
|
||||
|
||||
def _read_to_next_section(self):
|
||||
section = self._doc.read_to_next_empty_line()
|
||||
|
||||
while not self._is_at_section() and not self._doc.eof():
|
||||
if not self._doc.peek(-1).strip(): # previous line was empty
|
||||
section += ['']
|
||||
|
||||
section += self._doc.read_to_next_empty_line()
|
||||
|
||||
return section
|
||||
|
||||
def _read_sections(self):
|
||||
while not self._doc.eof():
|
||||
data = self._read_to_next_section()
|
||||
name = data[0].strip()
|
||||
|
||||
if name.startswith('..'): # index section
|
||||
yield name, data[1:]
|
||||
elif len(data) < 2:
|
||||
                return  # malformed trailing section; end iteration cleanly
|
||||
else:
|
||||
yield name, self._strip(data[2:])
|
||||
|
||||
def _parse_param_list(self, content, single_element_is_type=False):
|
||||
r = Reader(content)
|
||||
params = []
|
||||
while not r.eof():
|
||||
header = r.read().strip()
|
||||
if ' : ' in header:
|
||||
arg_name, arg_type = header.split(' : ')[:2]
|
||||
else:
|
||||
if single_element_is_type:
|
||||
arg_name, arg_type = '', header
|
||||
else:
|
||||
arg_name, arg_type = header, ''
|
||||
|
||||
desc = r.read_to_next_unindented_line()
|
||||
desc = dedent_lines(desc)
|
||||
desc = strip_blank_lines(desc)
|
||||
|
||||
params.append(Parameter(arg_name, arg_type, desc))
|
||||
|
||||
return params
|
||||
|
||||
# See also supports the following formats.
|
||||
#
|
||||
# <FUNCNAME>
|
||||
# <FUNCNAME> SPACE* COLON SPACE+ <DESC> SPACE*
|
||||
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)+ (COMMA | PERIOD)? SPACE*
|
||||
# <FUNCNAME> ( COMMA SPACE+ <FUNCNAME>)* SPACE* COLON SPACE+ <DESC> SPACE*
|
||||
|
||||
# <FUNCNAME> is one of
|
||||
# <PLAIN_FUNCNAME>
|
||||
# COLON <ROLE> COLON BACKTICK <PLAIN_FUNCNAME> BACKTICK
|
||||
# where
|
||||
# <PLAIN_FUNCNAME> is a legal function name, and
|
||||
# <ROLE> is any nonempty sequence of word characters.
|
||||
# Examples: func_f1 :meth:`func_h1` :obj:`~baz.obj_r` :class:`class_j`
|
||||
# <DESC> is a string describing the function.
|
||||
|
||||
_role = r":(?P<role>\w+):"
|
||||
_funcbacktick = r"`(?P<name>(?:~\w+\.)?[a-zA-Z0-9_\.-]+)`"
|
||||
_funcplain = r"(?P<name2>[a-zA-Z0-9_\.-]+)"
|
||||
_funcname = r"(" + _role + _funcbacktick + r"|" + _funcplain + r")"
|
||||
_funcnamenext = _funcname.replace('role', 'rolenext')
|
||||
_funcnamenext = _funcnamenext.replace('name', 'namenext')
|
||||
_description = r"(?P<description>\s*:(\s+(?P<desc>\S+.*))?)?\s*$"
|
||||
_func_rgx = re.compile(r"^\s*" + _funcname + r"\s*")
|
||||
_line_rgx = re.compile(
|
||||
r"^\s*" +
|
||||
r"(?P<allfuncs>" + # group for all function names
|
||||
_funcname +
|
||||
r"(?P<morefuncs>([,]\s+" + _funcnamenext + r")*)" +
|
||||
r")" + # end of "allfuncs"
|
||||
# Some function lists have a trailing comma (or period) '\s*'
|
||||
r"(?P<trailing>[,\.])?" +
|
||||
_description)
|
||||
|
||||
# Empty <DESC> elements are replaced with '..'
|
||||
empty_description = '..'
|
||||
|
||||
def _parse_see_also(self, content):
|
||||
"""
|
||||
func_name : Descriptive text
|
||||
continued text
|
||||
another_func_name : Descriptive text
|
||||
func_name1, func_name2, :meth:`func_name`, func_name3
|
||||
|
||||
"""
|
||||
|
||||
items = []
|
||||
|
||||
def parse_item_name(text):
|
||||
"""Match ':role:`name`' or 'name'."""
|
||||
m = self._func_rgx.match(text)
|
||||
if not m:
|
||||
                raise ParseError("%s is not an item name" % text)
|
||||
role = m.group('role')
|
||||
name = m.group('name') if role else m.group('name2')
|
||||
return name, role, m.end()
|
||||
|
||||
rest = []
|
||||
for line in content:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
line_match = self._line_rgx.match(line)
|
||||
description = None
|
||||
if line_match:
|
||||
description = line_match.group('desc')
|
||||
if line_match.group('trailing') and description:
|
||||
self._error_location(
|
||||
'Unexpected comma or period after function list at '
|
||||
'index %d of line "%s"' % (line_match.end('trailing'),
|
||||
line),
|
||||
error=False)
|
||||
if not description and line.startswith(' '):
|
||||
rest.append(line.strip())
|
||||
elif line_match:
|
||||
funcs = []
|
||||
text = line_match.group('allfuncs')
|
||||
while True:
|
||||
if not text.strip():
|
||||
break
|
||||
name, role, match_end = parse_item_name(text)
|
||||
funcs.append((name, role))
|
||||
text = text[match_end:].strip()
|
||||
if text and text[0] == ',':
|
||||
text = text[1:].strip()
|
||||
rest = list(filter(None, [description]))
|
||||
items.append((funcs, rest))
|
||||
else:
|
||||
                raise ParseError("%s is not an item name" % line)
|
||||
return items
|
||||
|
||||
def _parse_index(self, section, content):
|
||||
"""
|
||||
.. index:: default
|
||||
:refguide: something, else, and more
|
||||
|
||||
"""
|
||||
def strip_each_in(lst):
|
||||
return [s.strip() for s in lst]
|
||||
|
||||
out = {}
|
||||
section = section.split('::')
|
||||
if len(section) > 1:
|
||||
out['default'] = strip_each_in(section[1].split(','))[0]
|
||||
for line in content:
|
||||
line = line.split(':')
|
||||
if len(line) > 2:
|
||||
out[line[1]] = strip_each_in(line[2].split(','))
|
||||
return out
|
||||
|
||||
def _parse_summary(self):
|
||||
"""Grab signature (if given) and summary"""
|
||||
if self._is_at_section():
|
||||
return
|
||||
|
||||
# If several signatures present, take the last one
|
||||
while True:
|
||||
summary = self._doc.read_to_next_empty_line()
|
||||
summary_str = " ".join([s.strip() for s in summary]).strip()
|
||||
compiled = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$')
|
||||
if compiled.match(summary_str):
|
||||
self['Signature'] = summary_str
|
||||
if not self._is_at_section():
|
||||
continue
|
||||
break
|
||||
|
||||
if summary is not None:
|
||||
self['Summary'] = summary
|
||||
|
||||
if not self._is_at_section():
|
||||
self['Extended Summary'] = self._read_to_next_section()
|
||||
|
||||
def _parse(self):
|
||||
self._doc.reset()
|
||||
self._parse_summary()
|
||||
|
||||
sections = list(self._read_sections())
|
||||
section_names = {section for section, content in sections}
|
||||
|
||||
has_returns = 'Returns' in section_names
|
||||
has_yields = 'Yields' in section_names
|
||||
        # We could do more consistency checks here, but we don't; this set
        # was chosen somewhat arbitrarily.
|
||||
if has_returns and has_yields:
|
||||
msg = 'Docstring contains both a Returns and Yields section.'
|
||||
raise ValueError(msg)
|
||||
if not has_yields and 'Receives' in section_names:
|
||||
msg = 'Docstring contains a Receives section but not Yields.'
|
||||
raise ValueError(msg)
|
||||
|
||||
for (section, content) in sections:
|
||||
if not section.startswith('..'):
|
||||
section = (s.capitalize() for s in section.split(' '))
|
||||
section = ' '.join(section)
|
||||
if self.get(section):
|
||||
self._error_location("The section %s appears twice"
|
||||
% section)
|
||||
|
||||
if section in ('Parameters', 'Other Parameters', 'Attributes',
|
||||
'Methods'):
|
||||
self[section] = self._parse_param_list(content)
|
||||
elif section in ('Returns', 'Yields', 'Raises', 'Warns',
|
||||
'Receives'):
|
||||
self[section] = self._parse_param_list(
|
||||
content, single_element_is_type=True)
|
||||
elif section.startswith('.. index::'):
|
||||
self['index'] = self._parse_index(section, content)
|
||||
elif section == 'See Also':
|
||||
self['See Also'] = self._parse_see_also(content)
|
||||
else:
|
||||
self[section] = content
|
||||
|
||||
def _error_location(self, msg, error=True):
|
||||
if hasattr(self, '_obj'):
|
||||
# we know where the docs came from:
|
||||
try:
|
||||
filename = inspect.getsourcefile(self._obj)
|
||||
except TypeError:
|
||||
filename = None
|
||||
msg = msg + (f" in the docstring of {self._obj} in {filename}.")
|
||||
if error:
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
warn(msg, stacklevel=3)
|
||||
|
||||
# string conversion routines
|
||||
|
||||
def _str_header(self, name, symbol='-'):
|
||||
return [name, len(name)*symbol]
|
||||
|
||||
def _str_indent(self, doc, indent=4):
|
||||
out = []
|
||||
for line in doc:
|
||||
out += [' '*indent + line]
|
||||
return out
|
||||
|
||||
def _str_signature(self):
|
||||
if self['Signature']:
|
||||
return [self['Signature'].replace('*', r'\*')] + ['']
|
||||
else:
|
||||
return ['']
|
||||
|
||||
def _str_summary(self):
|
||||
if self['Summary']:
|
||||
return self['Summary'] + ['']
|
||||
else:
|
||||
return []
|
||||
|
||||
def _str_extended_summary(self):
|
||||
if self['Extended Summary']:
|
||||
return self['Extended Summary'] + ['']
|
||||
else:
|
||||
return []
|
||||
|
||||
def _str_param_list(self, name):
|
||||
out = []
|
||||
if self[name]:
|
||||
out += self._str_header(name)
|
||||
for param in self[name]:
|
||||
parts = []
|
||||
if param.name:
|
||||
parts.append(param.name)
|
||||
if param.type:
|
||||
parts.append(param.type)
|
||||
out += [' : '.join(parts)]
|
||||
if param.desc and ''.join(param.desc).strip():
|
||||
out += self._str_indent(param.desc)
|
||||
out += ['']
|
||||
return out
|
||||
|
||||
def _str_section(self, name):
|
||||
out = []
|
||||
if self[name]:
|
||||
out += self._str_header(name)
|
||||
out += self[name]
|
||||
out += ['']
|
||||
return out
|
||||
|
||||
def _str_see_also(self, func_role):
|
||||
if not self['See Also']:
|
||||
return []
|
||||
out = []
|
||||
out += self._str_header("See Also")
|
||||
out += ['']
|
||||
last_had_desc = True
|
||||
for funcs, desc in self['See Also']:
|
||||
assert isinstance(funcs, list)
|
||||
links = []
|
||||
for func, role in funcs:
|
||||
if role:
|
||||
link = f':{role}:`{func}`'
|
||||
elif func_role:
|
||||
link = f':{func_role}:`{func}`'
|
||||
else:
|
||||
link = "`%s`_" % func
|
||||
links.append(link)
|
||||
link = ', '.join(links)
|
||||
out += [link]
|
||||
if desc:
|
||||
out += self._str_indent([' '.join(desc)])
|
||||
last_had_desc = True
|
||||
else:
|
||||
last_had_desc = False
|
||||
out += self._str_indent([self.empty_description])
|
||||
|
||||
if last_had_desc:
|
||||
out += ['']
|
||||
out += ['']
|
||||
return out
|
||||
|
||||
def _str_index(self):
|
||||
idx = self['index']
|
||||
out = []
|
||||
output_index = False
|
||||
default_index = idx.get('default', '')
|
||||
if default_index:
|
||||
output_index = True
|
||||
out += ['.. index:: %s' % default_index]
|
||||
for section, references in idx.items():
|
||||
if section == 'default':
|
||||
continue
|
||||
output_index = True
|
||||
out += [' :{}: {}'.format(section, ', '.join(references))]
|
||||
if output_index:
|
||||
return out
|
||||
else:
|
||||
return ''
|
||||
|
||||
def __str__(self, func_role=''):
|
||||
out = []
|
||||
out += self._str_signature()
|
||||
out += self._str_summary()
|
||||
out += self._str_extended_summary()
|
||||
for param_list in ('Parameters', 'Returns', 'Yields', 'Receives',
|
||||
'Other Parameters', 'Raises', 'Warns'):
|
||||
out += self._str_param_list(param_list)
|
||||
out += self._str_section('Warnings')
|
||||
out += self._str_see_also(func_role)
|
||||
for s in ('Notes', 'References', 'Examples'):
|
||||
out += self._str_section(s)
|
||||
for param_list in ('Attributes', 'Methods'):
|
||||
out += self._str_param_list(param_list)
|
||||
out += self._str_index()
|
||||
return '\n'.join(out)
|
||||
|
||||
|
||||
def indent(str, indent=4):
|
||||
indent_str = ' '*indent
|
||||
if str is None:
|
||||
return indent_str
|
||||
lines = str.split('\n')
|
||||
return '\n'.join(indent_str + l for l in lines)
|
||||
|
||||
|
||||
def dedent_lines(lines):
|
||||
"""Deindent a list of lines maximally"""
|
||||
return textwrap.dedent("\n".join(lines)).split("\n")
|
||||
|
||||
|
||||
def header(text, style='-'):
|
||||
return text + '\n' + style*len(text) + '\n'
|
||||
|
||||
|
||||
class FunctionDoc(NumpyDocString):
|
||||
def __init__(self, func, role='func', doc=None, config={}):
|
||||
self._f = func
|
||||
self._role = role # e.g. "func" or "meth"
|
||||
|
||||
if doc is None:
|
||||
if func is None:
|
||||
raise ValueError("No function or docstring given")
|
||||
doc = inspect.getdoc(func) or ''
|
||||
NumpyDocString.__init__(self, doc, config)
|
||||
|
||||
def get_func(self):
|
||||
func_name = getattr(self._f, '__name__', self.__class__.__name__)
|
||||
if inspect.isclass(self._f):
|
||||
func = getattr(self._f, '__call__', self._f.__init__)
|
||||
else:
|
||||
func = self._f
|
||||
return func, func_name
|
||||
|
||||
def __str__(self):
|
||||
out = ''
|
||||
|
||||
func, func_name = self.get_func()
|
||||
|
||||
roles = {'func': 'function',
|
||||
'meth': 'method'}
|
||||
|
||||
if self._role:
|
||||
if self._role not in roles:
|
||||
print("Warning: invalid role %s" % self._role)
|
||||
out += '.. {}:: {}\n \n\n'.format(roles.get(self._role, ''),
|
||||
func_name)
|
||||
|
||||
out += super().__str__(func_role=self._role)
|
||||
return out
|
||||
|
||||
|
||||
class ClassDoc(NumpyDocString):
|
||||
|
||||
extra_public_methods = ['__call__']
|
||||
|
||||
def __init__(self, cls, doc=None, modulename='', func_doc=FunctionDoc,
|
||||
config={}):
|
||||
if not inspect.isclass(cls) and cls is not None:
|
||||
raise ValueError("Expected a class or None, but got %r" % cls)
|
||||
self._cls = cls
|
||||
|
||||
if 'sphinx' in sys.modules:
|
||||
from sphinx.ext.autodoc import ALL
|
||||
else:
|
||||
ALL = object()
|
||||
|
||||
self.show_inherited_members = config.get(
|
||||
'show_inherited_class_members', True)
|
||||
|
||||
if modulename and not modulename.endswith('.'):
|
||||
modulename += '.'
|
||||
self._mod = modulename
|
||||
|
||||
if doc is None:
|
||||
if cls is None:
|
||||
raise ValueError("No class or documentation string given")
|
||||
doc = pydoc.getdoc(cls)
|
||||
|
||||
NumpyDocString.__init__(self, doc)
|
||||
|
||||
_members = config.get('members', [])
|
||||
if _members is ALL:
|
||||
_members = None
|
||||
_exclude = config.get('exclude-members', [])
|
||||
|
||||
if config.get('show_class_members', True) and _exclude is not ALL:
|
||||
def splitlines_x(s):
|
||||
if not s:
|
||||
return []
|
||||
else:
|
||||
return s.splitlines()
|
||||
for field, items in [('Methods', self.methods),
|
||||
('Attributes', self.properties)]:
|
||||
if not self[field]:
|
||||
doc_list = []
|
||||
for name in sorted(items):
|
||||
if (name in _exclude or
|
||||
(_members and name not in _members)):
|
||||
continue
|
||||
try:
|
||||
doc_item = pydoc.getdoc(getattr(self._cls, name))
|
||||
doc_list.append(
|
||||
Parameter(name, '', splitlines_x(doc_item)))
|
||||
except AttributeError:
|
||||
pass # method doesn't exist
|
||||
self[field] = doc_list
|
||||
|
||||
@property
|
||||
def methods(self):
|
||||
if self._cls is None:
|
||||
return []
|
||||
return [name for name, func in inspect.getmembers(self._cls)
|
||||
if ((not name.startswith('_')
|
||||
or name in self.extra_public_methods)
|
||||
and isinstance(func, Callable)
|
||||
and self._is_show_member(name))]
|
||||
|
||||
@property
|
||||
def properties(self):
|
||||
if self._cls is None:
|
||||
return []
|
||||
return [name for name, func in inspect.getmembers(self._cls)
|
||||
if (not name.startswith('_') and
|
||||
(func is None or isinstance(func, property) or
|
||||
inspect.isdatadescriptor(func))
|
||||
and self._is_show_member(name))]
|
||||
|
||||
def _is_show_member(self, name):
|
||||
if self.show_inherited_members:
|
||||
return True # show all class members
|
||||
if name not in self._cls.__dict__:
|
||||
return False # class member is inherited, we do not show it
|
||||
return True
|
||||
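A brief usage sketch of the parser defined above (illustrative; note that `NumpyDocString` expects an already-dedented docstring, as `inspect.getdoc` returns):

    import textwrap

    text = textwrap.dedent("""\
        Add two numbers.

        Parameters
        ----------
        a : float
            First addend.

        Returns
        -------
        float
            The sum.
        """)
    doc = NumpyDocString(text)
    print(doc['Summary'])             # ['Add two numbers.']
    print(doc['Parameters'][0].name)  # 'a'
    print(doc['Parameters'][0].type)  # 'float'
    print(doc['Returns'][0].type)     # 'float' (single element parsed as a type)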
348
venv/lib/python3.12/site-packages/scipy/_lib/_elementwise_iterative_method.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# `_elementwise_iterative_method.py` includes tools for writing functions that
|
||||
# - are vectorized to work elementwise on arrays,
|
||||
# - implement non-trivial, iterative algorithms with a callback interface, and
|
||||
# - return rich objects with iteration count, termination status, etc.
|
||||
#
|
||||
# Examples include:
|
||||
# `scipy.optimize._chandrupatla._chandrupatla` for scalar rootfinding,
# `scipy.optimize._chandrupatla._chandrupatla_minimize` for scalar minimization,
# `scipy.optimize._differentiate._differentiate` for numerical differentiation,
# `scipy.optimize._bracket._bracket_root` for finding rootfinding brackets,
# `scipy.optimize._bracket._bracket_minimize` for finding minimization brackets,
# `scipy.integrate._tanhsinh._tanhsinh` for numerical quadrature.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from ._util import _RichResult, _call_callback_maybe_halt
|
||||
from ._array_api import array_namespace, size as xp_size
|
||||
|
||||
_ESIGNERR = -1
|
||||
_ECONVERR = -2
|
||||
_EVALUEERR = -3
|
||||
_ECALLBACK = -4
|
||||
_EINPUTERR = -5
|
||||
_ECONVERGED = 0
|
||||
_EINPROGRESS = 1
|
||||
|
||||
def _initialize(func, xs, args, complex_ok=False, preserve_shape=None):
|
||||
"""Initialize abscissa, function, and args arrays for elementwise function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
An elementwise function with signature
|
||||
|
||||
func(x: ndarray, *args) -> ndarray
|
||||
|
||||
where each element of ``x`` is a finite real and ``args`` is a tuple,
|
||||
which may contain an arbitrary number of arrays that are broadcastable
|
||||
with ``x``.
|
||||
xs : tuple of arrays
|
||||
Finite real abscissa arrays. Must be broadcastable.
|
||||
args : tuple, optional
|
||||
Additional positional arguments to be passed to `func`.
|
||||
    preserve_shape : bool, default: False
        When ``preserve_shape=False`` (default), `func` may be passed
        arguments of any shape; `_loop` is permitted
        to reshape and compress arguments at will. When
        ``preserve_shape=True``, arguments passed to `func` must have shape
        ``shape`` or ``shape + (n,)``, where ``n`` is any integer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xs, fs, args : tuple of arrays
|
||||
Broadcasted, writeable, 1D abscissa and function value arrays (or
|
||||
NumPy floats, if appropriate). The dtypes of the `xs` and `fs` are
|
||||
        `xfat`; the dtypes of the `args` are unchanged.
|
||||
shape : tuple of ints
|
||||
Original shape of broadcasted arrays.
|
||||
xfat : NumPy dtype
|
||||
Result dtype of abscissae, function values, and args determined using
|
||||
`np.result_type`, except integer types are promoted to `np.float64`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the result dtype is not that of a real scalar
|
||||
|
||||
Notes
|
||||
-----
|
||||
Useful for initializing the input of SciPy functions that accept
|
||||
an elementwise callable, abscissae, and arguments; e.g.
|
||||
`scipy.optimize._chandrupatla`.
|
||||
"""
|
||||
nx = len(xs)
|
||||
xp = array_namespace(*xs)
|
||||
|
||||
# Try to preserve `dtype`, but we need to ensure that the arguments are at
|
||||
# least floats before passing them into the function; integers can overflow
|
||||
# and cause failure.
|
||||
# There might be benefit to combining the `xs` into a single array and
|
||||
# calling `func` once on the combined array. For now, keep them separate.
|
||||
xas = xp.broadcast_arrays(*xs, *args) # broadcast and rename
|
||||
xat = xp.result_type(*[xa.dtype for xa in xas])
|
||||
xat = xp.asarray(1.).dtype if xp.isdtype(xat, "integral") else xat
|
||||
xs, args = xas[:nx], xas[nx:]
|
||||
xs = [xp.asarray(x, dtype=xat) for x in xs] # use copy=False when implemented
|
||||
fs = [xp.asarray(func(x, *args)) for x in xs]
|
||||
shape = xs[0].shape
|
||||
fshape = fs[0].shape
|
||||
|
||||
if preserve_shape:
|
||||
# bind original shape/func now to avoid late-binding gotcha
|
||||
def func(x, *args, shape=shape, func=func, **kwargs):
|
||||
i = (0,)*(len(fshape) - len(shape))
|
||||
return func(x[i], *args, **kwargs)
|
||||
shape = np.broadcast_shapes(fshape, shape) # just shapes; use of NumPy OK
|
||||
xs = [xp.broadcast_to(x, shape) for x in xs]
|
||||
args = [xp.broadcast_to(arg, shape) for arg in args]
|
||||
|
||||
message = ("The shape of the array returned by `func` must be the same as "
|
||||
"the broadcasted shape of `x` and all other `args`.")
|
||||
if preserve_shape is not None: # only in tanhsinh for now
|
||||
message = f"When `preserve_shape=False`, {message.lower()}"
|
||||
shapes_equal = [f.shape == shape for f in fs]
|
||||
if not all(shapes_equal): # use Python all to reduce overhead
|
||||
raise ValueError(message)
|
||||
|
||||
# These algorithms tend to mix the dtypes of the abscissae and function
|
||||
# values, so figure out what the result will be and convert them all to
|
||||
# that type from the outset.
|
||||
xfat = xp.result_type(*([f.dtype for f in fs] + [xat]))
|
||||
if not complex_ok and not xp.isdtype(xfat, "real floating"):
|
||||
raise ValueError("Abscissae and function output must be real numbers.")
|
||||
xs = [xp.asarray(x, dtype=xfat, copy=True) for x in xs]
|
||||
fs = [xp.asarray(f, dtype=xfat, copy=True) for f in fs]
|
||||
|
||||
# To ensure that we can do indexing, we'll work with at least 1d arrays,
|
||||
# but remember the appropriate shape of the output.
|
||||
xs = [xp.reshape(x, (-1,)) for x in xs]
|
||||
fs = [xp.reshape(f, (-1,)) for f in fs]
|
||||
args = [xp.reshape(xp.asarray(arg, copy=True), (-1,)) for arg in args]
|
||||
return func, xs, fs, args, shape, xfat, xp
|
||||
|
||||
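A short behavioral sketch of `_initialize` (illustrative only; it assumes a NumPy-backed namespace and that it runs in this module's namespace):

    import numpy as np

    def f(x, c):
        return x**2 - c                  # simple elementwise residual

    xs = (np.array([1, 2, 3]),)          # integer abscissae
    args = (np.asarray(2),)              # scalar arg, broadcast against x
    f2, xs, fs, args, shape, xfat, xp = _initialize(f, xs, args)
    print(xfat)                          # float64: ints promoted before f is called
    print(xs[0].shape, fs[0].shape)      # (3,) (3,): broadcast, copied, flattened to 1-D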
|
||||
def _loop(work, callback, shape, maxiter, func, args, dtype, pre_func_eval,
|
||||
post_func_eval, check_termination, post_termination_check,
|
||||
customize_result, res_work_pairs, xp, preserve_shape=False):
|
||||
"""Main loop of a vectorized scalar optimization algorithm
|
||||
|
||||
Parameters
|
||||
----------
|
||||
work : _RichResult
|
||||
All variables that need to be retained between iterations. Must
|
||||
contain attributes `nit`, `nfev`, and `success`
|
||||
callback : callable
|
||||
User-specified callback function
|
||||
shape : tuple of ints
|
||||
The shape of all output arrays
|
||||
    maxiter : int
|
||||
Maximum number of iterations of the algorithm
|
||||
func : callable
|
||||
The user-specified callable that is being optimized or solved
|
||||
args : tuple
|
||||
Additional positional arguments to be passed to `func`.
|
||||
dtype : NumPy dtype
|
||||
The common dtype of all abscissae and function values
|
||||
pre_func_eval : callable
|
||||
A function that accepts `work` and returns `x`, the active elements
|
||||
of `x` at which `func` will be evaluated. May modify attributes
|
||||
of `work` with any algorithmic steps that need to happen
|
||||
at the beginning of an iteration, before `func` is evaluated,
|
||||
post_func_eval : callable
|
||||
A function that accepts `x`, `func(x)`, and `work`. May modify
|
||||
attributes of `work` with any algorithmic steps that need to happen
|
||||
in the middle of an iteration, after `func` is evaluated but before
|
||||
the termination check.
|
||||
check_termination : callable
|
||||
A function that accepts `work` and returns `stop`, a boolean array
|
||||
indicating which of the active elements have met a termination
|
||||
condition.
|
||||
post_termination_check : callable
|
||||
A function that accepts `work`. May modify `work` with any algorithmic
|
||||
steps that need to happen after the termination check and before the
|
||||
end of the iteration.
|
||||
customize_result : callable
|
||||
A function that accepts `res` and `shape` and returns `shape`. May
|
||||
modify `res` (in-place) according to preferences (e.g. rearrange
|
||||
elements between attributes) and modify `shape` if needed.
|
||||
res_work_pairs : list of (str, str)
|
||||
Identifies correspondence between attributes of `res` and attributes
|
||||
of `work`; i.e., attributes of active elements of `work` will be
|
||||
copied to the appropriate indices of `res` when appropriate. The order
|
||||
determines the order in which _RichResult attributes will be
|
||||
pretty-printed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
res : _RichResult
|
||||
The final result object
|
||||
|
||||
Notes
|
||||
-----
|
||||
Besides providing structure, this framework provides several important
|
||||
services for a vectorized optimization algorithm.
|
||||
|
||||
- It handles common tasks involving iteration count, function evaluation
|
||||
count, a user-specified callback, and associated termination conditions.
|
||||
- It compresses the attributes of `work` to eliminate unnecessary
|
||||
computation on elements that have already converged.
|
||||
|
||||
"""
|
||||
if xp is None:
|
||||
raise NotImplementedError("Must provide xp.")
|
||||
|
||||
cb_terminate = False
|
||||
|
||||
# Initialize the result object and active element index array
|
||||
n_elements = math.prod(shape)
|
||||
active = xp.arange(n_elements) # in-progress element indices
|
||||
res_dict = {i: xp.zeros(n_elements, dtype=dtype) for i, j in res_work_pairs}
|
||||
res_dict['success'] = xp.zeros(n_elements, dtype=xp.bool)
|
||||
res_dict['status'] = xp.full(n_elements, _EINPROGRESS, dtype=xp.int32)
|
||||
res_dict['nit'] = xp.zeros(n_elements, dtype=xp.int32)
|
||||
res_dict['nfev'] = xp.zeros(n_elements, dtype=xp.int32)
|
||||
res = _RichResult(res_dict)
|
||||
work.args = args
|
||||
|
||||
active = _check_termination(work, res, res_work_pairs, active,
|
||||
check_termination, preserve_shape, xp)
|
||||
|
||||
if callback is not None:
|
||||
temp = _prepare_result(work, res, res_work_pairs, active, shape,
|
||||
customize_result, preserve_shape, xp)
|
||||
if _call_callback_maybe_halt(callback, temp):
|
||||
cb_terminate = True
|
||||
|
||||
while work.nit < maxiter and xp_size(active) and not cb_terminate and n_elements:
|
||||
x = pre_func_eval(work)
|
||||
|
||||
if work.args and work.args[0].ndim != x.ndim:
|
||||
# `x` always starts as 1D. If the SciPy function that uses
|
||||
# _loop added dimensions to `x`, we need to
|
||||
# add them to the elements of `args`.
|
||||
args = []
|
||||
for arg in work.args:
|
||||
n_new_dims = x.ndim - arg.ndim
|
||||
new_shape = arg.shape + (1,)*n_new_dims
|
||||
args.append(xp.reshape(arg, new_shape))
|
||||
work.args = args
|
||||
|
||||
x_shape = x.shape
|
||||
if preserve_shape:
|
||||
x = xp.reshape(x, (shape + (-1,)))
|
||||
f = func(x, *work.args)
|
||||
f = xp.asarray(f, dtype=dtype)
|
||||
if preserve_shape:
|
||||
x = xp.reshape(x, x_shape)
|
||||
f = xp.reshape(f, x_shape)
|
||||
work.nfev += 1 if x.ndim == 1 else x.shape[-1]
|
||||
|
||||
post_func_eval(x, f, work)
|
||||
|
||||
work.nit += 1
|
||||
active = _check_termination(work, res, res_work_pairs, active,
|
||||
check_termination, preserve_shape, xp)
|
||||
|
||||
if callback is not None:
|
||||
temp = _prepare_result(work, res, res_work_pairs, active, shape,
|
||||
customize_result, preserve_shape, xp)
|
||||
if _call_callback_maybe_halt(callback, temp):
|
||||
cb_terminate = True
|
||||
break
|
||||
if xp_size(active) == 0:
|
||||
break
|
||||
|
||||
post_termination_check(work)
|
||||
|
||||
work.status[:] = _ECALLBACK if cb_terminate else _ECONVERR
|
||||
return _prepare_result(work, res, res_work_pairs, active, shape,
|
||||
customize_result, preserve_shape, xp)
|
||||
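To make the `_loop` contract above concrete, here is a hypothetical usage sketch: a vectorized fixed-point solver for x = cos(x). It is illustrative only; it assumes it runs in this module's namespace and that `_RichResult` accepts keyword arguments (it behaves like `scipy.optimize.OptimizeResult`).

    import numpy as np

    def fixed_point_cos(x0, xtol=1e-10, maxiter=200):
        xp = array_namespace(x0)
        shape = x0.shape
        x = xp.reshape(xp.asarray(x0, dtype=xp.float64), (-1,))
        work = _RichResult(
            x=x, dx=xp.full(x.shape, xp.inf), nit=0, nfev=0,
            success=xp.zeros(x.shape, dtype=xp.bool),
            status=xp.full(x.shape, _EINPROGRESS, dtype=xp.int32))

        def pre_func_eval(work):         # abscissae at which to evaluate func
            return work.x

        def post_func_eval(x, f, work):  # record step size, advance the iterate
            work.dx = xp.abs(f - x)
            work.x = f

        def check_termination(work):     # converged where the step is tiny
            stop = work.dx < xtol
            work.status[stop] = _ECONVERGED
            return stop

        def post_termination_check(work):
            pass                         # nothing to do between iterations

        def customize_result(res, shape):
            return shape                 # no rearrangement needed

        return _loop(work, None, shape, maxiter, lambda x: xp.cos(x), (),
                     xp.float64, pre_func_eval, post_func_eval,
                     check_termination, post_termination_check,
                     customize_result, [('x', 'x'), ('dx', 'dx')], xp)

    res = fixed_point_cos(np.linspace(0.0, 1.0, 5))
    print(res.x)  # all entries near 0.7390851, the fixed point of cos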
|
||||
|
||||
def _check_termination(work, res, res_work_pairs, active, check_termination,
|
||||
preserve_shape, xp):
|
||||
# Checks termination conditions, updates elements of `res` with
|
||||
# corresponding elements of `work`, and compresses `work`.
|
||||
|
||||
stop = check_termination(work)
|
||||
|
||||
if xp.any(stop):
|
||||
# update the active elements of the result object with the active
|
||||
# elements for which a termination condition has been met
|
||||
_update_active(work, res, res_work_pairs, active, stop, preserve_shape, xp)
|
||||
|
||||
if preserve_shape:
|
||||
stop = stop[active]
|
||||
|
||||
proceed = ~stop
|
||||
active = active[proceed]
|
||||
|
||||
if not preserve_shape:
|
||||
# compress the arrays to avoid unnecessary computation
|
||||
for key, val in work.items():
|
||||
# Need to find a better way than these try/excepts
|
||||
# Somehow need to keep compressible numerical args separate
|
||||
if key == 'args':
|
||||
continue
|
||||
try:
|
||||
work[key] = val[proceed]
|
||||
except (IndexError, TypeError, KeyError): # not a compressible array
|
||||
work[key] = val
|
||||
work.args = [arg[proceed] for arg in work.args]
|
||||
|
||||
return active
|
||||
|
||||
|
||||
def _update_active(work, res, res_work_pairs, active, mask, preserve_shape, xp):
|
||||
# Update `active` indices of the arrays in result object `res` with the
|
||||
# contents of the scalars and arrays in `update_dict`. When provided,
|
||||
# `mask` is a boolean array applied both to the arrays in `update_dict`
|
||||
# that are to be used and to the arrays in `res` that are to be updated.
|
||||
update_dict = {key1: work[key2] for key1, key2 in res_work_pairs}
|
||||
update_dict['success'] = work.status == 0
|
||||
|
||||
if mask is not None:
|
||||
if preserve_shape:
|
||||
active_mask = xp.zeros_like(mask)
|
||||
active_mask[active] = 1
|
||||
active_mask = active_mask & mask
|
||||
for key, val in update_dict.items():
|
||||
try:
|
||||
res[key][active_mask] = val[active_mask]
|
||||
except (IndexError, TypeError, KeyError):
|
||||
res[key][active_mask] = val
|
||||
else:
|
||||
active_mask = active[mask]
|
||||
for key, val in update_dict.items():
|
||||
try:
|
||||
res[key][active_mask] = val[mask]
|
||||
except (IndexError, TypeError, KeyError):
|
||||
res[key][active_mask] = val
|
||||
else:
|
||||
for key, val in update_dict.items():
|
||||
if preserve_shape:
|
||||
try:
|
||||
val = val[active]
|
||||
except (IndexError, TypeError, KeyError):
|
||||
pass
|
||||
res[key][active] = val
|
||||
|
||||
|
||||
def _prepare_result(work, res, res_work_pairs, active, shape, customize_result,
|
||||
preserve_shape, xp):
|
||||
# Prepare the result object `res` by creating a copy, copying the latest
|
||||
# data from work, running the provided result customization function,
|
||||
# and reshaping the data to the original shapes.
|
||||
res = res.copy()
|
||||
_update_active(work, res, res_work_pairs, active, None, preserve_shape, xp)
|
||||
|
||||
shape = customize_result(res, shape)
|
||||
|
||||
for key, val in res.items():
|
||||
# this looks like it won't work for xp != np if val is not numeric
|
||||
temp = xp.reshape(val, shape)
|
||||
res[key] = temp[()] if temp.ndim == 0 else temp
|
||||
|
||||
res['_order_keys'] = ['success'] + [i for i, j in res_work_pairs]
|
||||
return _RichResult(**res)
|
||||
145
venv/lib/python3.12/site-packages/scipy/_lib/_finite_differences.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from numpy import arange, newaxis, hstack, prod, array
|
||||
|
||||
|
||||
def _central_diff_weights(Np, ndiv=1):
|
||||
"""
|
||||
Return weights for an Np-point central derivative.
|
||||
|
||||
Assumes equally-spaced function points.
|
||||
|
||||
    If the weights are in the vector ``w``, the derivative at ``x`` is
    approximated by ``w[0]*f(x - ho*dx) + ... + w[-1]*f(x + ho*dx)``,
    where ``ho = Np // 2``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
Np : int
|
||||
Number of points for the central derivative.
|
||||
ndiv : int, optional
|
||||
Number of divisions. Default is 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
w : ndarray
|
||||
Weights for an Np-point central derivative. Its size is `Np`.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Can be inaccurate for a large number of points.
|
||||
|
||||
Examples
|
||||
--------
|
||||
We can calculate a derivative value of a function.
|
||||
|
||||
>>> def f(x):
|
||||
... return 2 * x**2 + 3
|
||||
>>> x = 3.0 # derivative point
|
||||
>>> h = 0.1 # differential step
|
||||
>>> Np = 3 # point number for central derivative
|
||||
>>> weights = _central_diff_weights(Np) # weights for first derivative
|
||||
>>> vals = [f(x + (i - Np/2) * h) for i in range(Np)]
|
||||
>>> sum(w * v for (w, v) in zip(weights, vals))/h
|
||||
11.79999999999998
|
||||
|
||||
This value is close to the analytical solution:
|
||||
f'(x) = 4x, so f'(3) = 12
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] https://en.wikipedia.org/wiki/Finite_difference
|
||||
|
||||
"""
|
||||
if Np < ndiv + 1:
|
||||
raise ValueError(
|
||||
"Number of points must be at least the derivative order + 1."
|
||||
)
|
||||
if Np % 2 == 0:
|
||||
raise ValueError("The number of points must be odd.")
|
||||
from scipy import linalg
|
||||
|
||||
ho = Np >> 1
|
||||
x = arange(-ho, ho + 1.0)
|
||||
x = x[:, newaxis]
|
||||
X = x**0.0
|
||||
for k in range(1, Np):
|
||||
X = hstack([X, x**k])
|
||||
w = prod(arange(1, ndiv + 1), axis=0) * linalg.inv(X)[ndiv]
|
||||
return w
|
||||
|
||||
|
||||
def _derivative(func, x0, dx=1.0, n=1, args=(), order=3):
|
||||
"""
|
||||
Find the nth derivative of a function at a point.
|
||||
|
||||
Given a function, use a central difference formula with spacing `dx` to
|
||||
compute the nth derivative at `x0`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : function
|
||||
Input function.
|
||||
x0 : float
|
||||
The point at which the nth derivative is found.
|
||||
dx : float, optional
|
||||
Spacing.
|
||||
n : int, optional
|
||||
Order of the derivative. Default is 1.
|
||||
args : tuple, optional
|
||||
Arguments
|
||||
order : int, optional
|
||||
Number of points to use, must be odd.
|
||||
|
||||
Notes
|
||||
-----
|
||||
    Making the step size too small can result in round-off error
    dominating the result.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> def f(x):
|
||||
... return x**3 + x**2
|
||||
>>> _derivative(f, 1.0, dx=1e-6)
|
||||
4.9999999999217337
|
||||
|
||||
"""
|
||||
if order < n + 1:
|
||||
raise ValueError(
|
||||
"'order' (the number of points used to compute the derivative), "
|
||||
"must be at least the derivative order 'n' + 1."
|
||||
)
|
||||
if order % 2 == 0:
|
||||
raise ValueError(
|
||||
"'order' (the number of points used to compute the derivative) "
|
||||
"must be odd."
|
||||
)
|
||||
# pre-computed for n=1 and 2 and low-order for speed.
|
||||
if n == 1:
|
||||
if order == 3:
|
||||
weights = array([-1, 0, 1]) / 2.0
|
||||
elif order == 5:
|
||||
weights = array([1, -8, 0, 8, -1]) / 12.0
|
||||
elif order == 7:
|
||||
weights = array([-1, 9, -45, 0, 45, -9, 1]) / 60.0
|
||||
elif order == 9:
|
||||
weights = array([3, -32, 168, -672, 0, 672, -168, 32, -3]) / 840.0
|
||||
else:
|
||||
weights = _central_diff_weights(order, 1)
|
||||
elif n == 2:
|
||||
if order == 3:
|
||||
weights = array([1, -2.0, 1])
|
||||
elif order == 5:
|
||||
weights = array([-1, 16, -30, 16, -1]) / 12.0
|
||||
elif order == 7:
|
||||
weights = array([2, -27, 270, -490, 270, -27, 2]) / 180.0
|
||||
elif order == 9:
|
||||
weights = (
|
||||
array([-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9])
|
||||
/ 5040.0
|
||||
)
|
||||
else:
|
||||
weights = _central_diff_weights(order, 2)
|
||||
else:
|
||||
weights = _central_diff_weights(order, n)
|
||||
val = 0.0
|
||||
ho = order >> 1
|
||||
for k in range(order):
|
||||
val += weights[k] * func(x0 + (k - ho) * dx, *args)
|
||||
return val / prod((dx,) * n, axis=0)
|
||||
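A quick sanity check of the second-derivative path (a sketch using the private helpers above): for f(x) = x**3 + x**2 we have f''(x) = 6*x + 2, so f''(1) = 8.

    val = _derivative(lambda x: x**3 + x**2, 1.0, dx=1e-3, n=2, order=5)
    print(val)  # ~8.0, matching f''(1) = 6*1 + 2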
Binary file not shown.
105
venv/lib/python3.12/site-packages/scipy/_lib/_gcutils.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Module for testing automatic garbage collection of objects
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated/
|
||||
|
||||
set_gc_state - enable or disable garbage collection
|
||||
gc_state - context manager for given state of garbage collector
|
||||
assert_deallocated - context manager to check for circular references on object
|
||||
|
||||
"""
|
||||
import weakref
|
||||
import gc
|
||||
|
||||
from contextlib import contextmanager
|
||||
from platform import python_implementation
|
||||
|
||||
__all__ = ['set_gc_state', 'gc_state', 'assert_deallocated']
|
||||
|
||||
|
||||
IS_PYPY = python_implementation() == 'PyPy'
|
||||
|
||||
|
||||
class ReferenceError(AssertionError):
|
||||
pass
|
||||
|
||||
|
||||
def set_gc_state(state):
|
||||
""" Set status of garbage collector """
|
||||
if gc.isenabled() == state:
|
||||
return
|
||||
if state:
|
||||
gc.enable()
|
||||
else:
|
||||
gc.disable()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def gc_state(state):
|
||||
""" Context manager to set state of garbage collector to `state`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state : bool
|
||||
True for gc enabled, False for disabled
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> with gc_state(False):
|
||||
... assert not gc.isenabled()
|
||||
>>> with gc_state(True):
|
||||
... assert gc.isenabled()
|
||||
"""
|
||||
orig_state = gc.isenabled()
|
||||
set_gc_state(state)
|
||||
yield
|
||||
set_gc_state(orig_state)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def assert_deallocated(func, *args, **kwargs):
|
||||
"""Context manager to check that object is deallocated
|
||||
|
||||
This is useful for checking that an object can be freed directly by
|
||||
reference counting, without requiring gc to break reference cycles.
|
||||
GC is disabled inside the context manager.
|
||||
|
||||
This check is not available on PyPy.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
Callable to create object to check
|
||||
\\*args : sequence
|
||||
positional arguments to `func` in order to create object to check
|
||||
\\*\\*kwargs : dict
|
||||
keyword arguments to `func` in order to create object to check
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> class C: pass
|
||||
>>> with assert_deallocated(C) as c:
|
||||
... # do something
|
||||
... del c
|
||||
|
||||
>>> class C:
|
||||
... def __init__(self):
|
||||
... self._circular = self # Make circular reference
|
||||
>>> with assert_deallocated(C) as c: #doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
... # do something
|
||||
... del c
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ReferenceError: Remaining reference(s) to object
|
||||
"""
|
||||
if IS_PYPY:
|
||||
raise RuntimeError("assert_deallocated is unavailable on PyPy")
|
||||
|
||||
with gc_state(False):
|
||||
obj = func(*args, **kwargs)
|
||||
ref = weakref.ref(obj)
|
||||
yield obj
|
||||
del obj
|
||||
if ref() is not None:
|
||||
raise ReferenceError("Remaining reference(s) to object")
|
||||
487
venv/lib/python3.12/site-packages/scipy/_lib/_pep440.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Utility to compare pep440 compatible version strings.
|
||||
|
||||
The LooseVersion and StrictVersion classes that distutils provides don't
|
||||
work; they don't recognize anything like alpha/beta/rc/dev versions.
|
||||
"""
|
||||
|
||||
# Copyright (c) Donald Stufft and individual contributors.
|
||||
# All rights reserved.
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright
|
||||
# notice, this list of conditions and the following disclaimer in the
|
||||
# documentation and/or other materials provided with the distribution.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
# POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
import re
|
||||
|
||||
|
||||
__all__ = [
|
||||
"parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN",
|
||||
]
|
||||
|
||||
|
||||
# BEGIN packaging/_structures.py
|
||||
|
||||
|
||||
class Infinity:
|
||||
def __repr__(self):
|
||||
return "Infinity"
|
||||
|
||||
def __hash__(self):
|
||||
return hash(repr(self))
|
||||
|
||||
def __lt__(self, other):
|
||||
return False
|
||||
|
||||
def __le__(self, other):
|
||||
return False
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self.__class__)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not isinstance(other, self.__class__)
|
||||
|
||||
def __gt__(self, other):
|
||||
return True
|
||||
|
||||
def __ge__(self, other):
|
||||
return True
|
||||
|
||||
def __neg__(self):
|
||||
return NegativeInfinity
|
||||
|
||||
|
||||
Infinity = Infinity()
|
||||
|
||||
|
||||
class NegativeInfinity:
|
||||
def __repr__(self):
|
||||
return "-Infinity"
|
||||
|
||||
def __hash__(self):
|
||||
return hash(repr(self))
|
||||
|
||||
def __lt__(self, other):
|
||||
return True
|
||||
|
||||
def __le__(self, other):
|
||||
return True
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self.__class__)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not isinstance(other, self.__class__)
|
||||
|
||||
def __gt__(self, other):
|
||||
return False
|
||||
|
||||
def __ge__(self, other):
|
||||
return False
|
||||
|
||||
def __neg__(self):
|
||||
return Infinity
|
||||
|
||||
|
||||
# BEGIN packaging/version.py
|
||||
|
||||
|
||||
NegativeInfinity = NegativeInfinity()
|
||||
|
||||
_Version = collections.namedtuple(
|
||||
"_Version",
|
||||
["epoch", "release", "dev", "pre", "post", "local"],
|
||||
)
|
||||
|
||||
|
||||
def parse(version):
|
||||
"""
|
||||
Parse the given version string and return either a :class:`Version` object
|
||||
or a :class:`LegacyVersion` object depending on if the given version is
|
||||
a valid PEP 440 version or a legacy version.
|
||||
"""
|
||||
try:
|
||||
return Version(version)
|
||||
except InvalidVersion:
|
||||
return LegacyVersion(version)
|
||||
|
||||
|
||||
class InvalidVersion(ValueError):
|
||||
"""
|
||||
An invalid version was found, users should refer to PEP 440.
|
||||
"""
|
||||
|
||||
|
||||
class _BaseVersion:
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._key)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self._compare(other, lambda s, o: s < o)
|
||||
|
||||
def __le__(self, other):
|
||||
return self._compare(other, lambda s, o: s <= o)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._compare(other, lambda s, o: s == o)
|
||||
|
||||
def __ge__(self, other):
|
||||
return self._compare(other, lambda s, o: s >= o)
|
||||
|
||||
def __gt__(self, other):
|
||||
return self._compare(other, lambda s, o: s > o)
|
||||
|
||||
def __ne__(self, other):
|
||||
return self._compare(other, lambda s, o: s != o)
|
||||
|
||||
def _compare(self, other, method):
|
||||
if not isinstance(other, _BaseVersion):
|
||||
return NotImplemented
|
||||
|
||||
return method(self._key, other._key)
|
||||
|
||||
|
||||
class LegacyVersion(_BaseVersion):
|
||||
|
||||
def __init__(self, version):
|
||||
self._version = str(version)
|
||||
self._key = _legacy_cmpkey(self._version)
|
||||
|
||||
def __str__(self):
|
||||
return self._version
|
||||
|
||||
def __repr__(self):
|
||||
return f"<LegacyVersion({repr(str(self))})>"
|
||||
|
||||
@property
|
||||
def public(self):
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def base_version(self):
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def local(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_prerelease(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_postrelease(self):
|
||||
return False
|
||||
|
||||
|
||||
_legacy_version_component_re = re.compile(
|
||||
r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE,
|
||||
)
|
||||
|
||||
_legacy_version_replacement_map = {
|
||||
"pre": "c", "preview": "c", "-": "final-", "rc": "c", "dev": "@",
|
||||
}
|
||||
|
||||
|
||||
def _parse_version_parts(s):
|
||||
for part in _legacy_version_component_re.split(s):
|
||||
part = _legacy_version_replacement_map.get(part, part)
|
||||
|
||||
if not part or part == ".":
|
||||
continue
|
||||
|
||||
if part[:1] in "0123456789":
|
||||
# pad for numeric comparison
|
||||
yield part.zfill(8)
|
||||
else:
|
||||
yield "*" + part
|
||||
|
||||
# ensure that alpha/beta/candidate are before final
|
||||
yield "*final"
|
||||
|
||||
|
||||
def _legacy_cmpkey(version):
|
||||
# We hardcode an epoch of -1 here. A PEP 440 version can only have an epoch
|
||||
    # greater than or equal to 0. This effectively sorts every LegacyVersion,
    # which uses the de facto standard originally implemented by setuptools,
    # before all PEP 440 versions.
|
||||
epoch = -1
|
||||
|
||||
# This scheme is taken from pkg_resources.parse_version setuptools prior to
|
||||
# its adoption of the packaging library.
|
||||
parts = []
|
||||
for part in _parse_version_parts(version.lower()):
|
||||
if part.startswith("*"):
|
||||
# remove "-" before a prerelease tag
|
||||
if part < "*final":
|
||||
while parts and parts[-1] == "*final-":
|
||||
parts.pop()
|
||||
|
||||
# remove trailing zeros from each series of numeric parts
|
||||
while parts and parts[-1] == "00000000":
|
||||
parts.pop()
|
||||
|
||||
parts.append(part)
|
||||
parts = tuple(parts)
|
||||
|
||||
return epoch, parts
|
||||
|
||||
|
||||
# Deliberately not anchored to the start and end of the string, to make it
|
||||
# easier for 3rd party code to reuse
|
||||
VERSION_PATTERN = r"""
|
||||
v?
|
||||
(?:
|
||||
(?:(?P<epoch>[0-9]+)!)? # epoch
|
||||
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
|
||||
(?P<pre> # pre-release
|
||||
[-_\.]?
|
||||
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
|
||||
[-_\.]?
|
||||
(?P<pre_n>[0-9]+)?
|
||||
)?
|
||||
(?P<post> # post release
|
||||
(?:-(?P<post_n1>[0-9]+))
|
||||
|
|
||||
(?:
|
||||
[-_\.]?
|
||||
(?P<post_l>post|rev|r)
|
||||
[-_\.]?
|
||||
(?P<post_n2>[0-9]+)?
|
||||
)
|
||||
)?
|
||||
(?P<dev> # dev release
|
||||
[-_\.]?
|
||||
(?P<dev_l>dev)
|
||||
[-_\.]?
|
||||
(?P<dev_n>[0-9]+)?
|
||||
)?
|
||||
)
|
||||
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
|
||||
"""
|
||||
|
||||
|
||||
class Version(_BaseVersion):
|
||||
|
||||
_regex = re.compile(
|
||||
r"^\s*" + VERSION_PATTERN + r"\s*$",
|
||||
re.VERBOSE | re.IGNORECASE,
|
||||
)
|
||||
|
||||
def __init__(self, version):
|
||||
# Validate the version and parse it into pieces
|
||||
match = self._regex.search(version)
|
||||
if not match:
|
||||
raise InvalidVersion(f"Invalid version: '{version}'")
|
||||
|
||||
# Store the parsed out pieces of the version
|
||||
self._version = _Version(
|
||||
epoch=int(match.group("epoch")) if match.group("epoch") else 0,
|
||||
release=tuple(int(i) for i in match.group("release").split(".")),
|
||||
pre=_parse_letter_version(
|
||||
match.group("pre_l"),
|
||||
match.group("pre_n"),
|
||||
),
|
||||
post=_parse_letter_version(
|
||||
match.group("post_l"),
|
||||
match.group("post_n1") or match.group("post_n2"),
|
||||
),
|
||||
dev=_parse_letter_version(
|
||||
match.group("dev_l"),
|
||||
match.group("dev_n"),
|
||||
),
|
||||
local=_parse_local_version(match.group("local")),
|
||||
)
|
||||
|
||||
# Generate a key which will be used for sorting
|
||||
self._key = _cmpkey(
|
||||
self._version.epoch,
|
||||
self._version.release,
|
||||
self._version.pre,
|
||||
self._version.post,
|
||||
self._version.dev,
|
||||
self._version.local,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Version({repr(str(self))})>"
|
||||
|
||||
def __str__(self):
|
||||
parts = []
|
||||
|
||||
# Epoch
|
||||
if self._version.epoch != 0:
|
||||
parts.append(f"{self._version.epoch}!")
|
||||
|
||||
# Release segment
|
||||
parts.append(".".join(str(x) for x in self._version.release))
|
||||
|
||||
# Pre-release
|
||||
if self._version.pre is not None:
|
||||
parts.append("".join(str(x) for x in self._version.pre))
|
||||
|
||||
# Post-release
|
||||
if self._version.post is not None:
|
||||
parts.append(f".post{self._version.post[1]}")
|
||||
|
||||
# Development release
|
||||
if self._version.dev is not None:
|
||||
parts.append(f".dev{self._version.dev[1]}")
|
||||
|
||||
# Local version segment
|
||||
if self._version.local is not None:
|
||||
parts.append(
|
||||
"+{}".format(".".join(str(x) for x in self._version.local))
|
||||
)
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@property
|
||||
def public(self):
|
||||
return str(self).split("+", 1)[0]
|
||||
|
||||
@property
|
||||
def base_version(self):
|
||||
parts = []
|
||||
|
||||
# Epoch
|
||||
if self._version.epoch != 0:
|
||||
parts.append(f"{self._version.epoch}!")
|
||||
|
||||
# Release segment
|
||||
parts.append(".".join(str(x) for x in self._version.release))
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@property
|
||||
def local(self):
|
||||
version_string = str(self)
|
||||
if "+" in version_string:
|
||||
return version_string.split("+", 1)[1]
|
||||
|
||||
@property
|
||||
def is_prerelease(self):
|
||||
return bool(self._version.dev or self._version.pre)
|
||||
|
||||
@property
|
||||
def is_postrelease(self):
|
||||
return bool(self._version.post)
|
||||
|
||||
|
||||
def _parse_letter_version(letter, number):
|
||||
if letter:
|
||||
# We assume there is an implicit 0 in a pre-release if there is
|
||||
# no numeral associated with it.
|
||||
if number is None:
|
||||
number = 0
|
||||
|
||||
# We normalize any letters to their lower-case form
|
||||
letter = letter.lower()
|
||||
|
||||
# We consider some words to be alternate spellings of other words and
|
||||
# in those cases we want to normalize the spellings to our preferred
|
||||
# spelling.
|
||||
if letter == "alpha":
|
||||
letter = "a"
|
||||
elif letter == "beta":
|
||||
letter = "b"
|
||||
elif letter in ["c", "pre", "preview"]:
|
||||
letter = "rc"
|
||||
elif letter in ["rev", "r"]:
|
||||
letter = "post"
|
||||
|
||||
return letter, int(number)
|
||||
if not letter and number:
|
||||
# We assume that if we are given a number but not given a letter,
|
||||
# then this is using the implicit post release syntax (e.g., 1.0-1)
|
||||
letter = "post"
|
||||
|
||||
return letter, int(number)
|
||||
|
||||
|
||||
_local_version_seperators = re.compile(r"[\._-]")
|
||||
|
||||
|
||||
def _parse_local_version(local):
|
||||
"""
|
||||
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
|
||||
"""
|
||||
if local is not None:
|
||||
return tuple(
|
||||
part.lower() if not part.isdigit() else int(part)
|
||||
for part in _local_version_seperators.split(local)
|
||||
)
|
||||
|
||||
|
||||
def _cmpkey(epoch, release, pre, post, dev, local):
|
||||
    # When we compare a release version, we want to compare it with all of the
    # trailing zeros removed. So we'll reverse the list, drop all the now-
    # leading zeros until we come to something non-zero, take the rest,
    # re-reverse it into the correct order, and use the resulting tuple as
    # part of our sorting key.
|
||||
release = tuple(
|
||||
reversed(list(
|
||||
itertools.dropwhile(
|
||||
lambda x: x == 0,
|
||||
reversed(release),
|
||||
)
|
||||
))
|
||||
)
|
||||
|
||||
# We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
|
||||
# We'll do this by abusing the pre-segment, but we _only_ want to do this
|
||||
# if there is no pre- or a post-segment. If we have one of those, then
|
||||
# the normal sorting rules will handle this case correctly.
|
||||
if pre is None and post is None and dev is not None:
|
||||
pre = -Infinity
|
||||
# Versions without a pre-release (except as noted above) should sort after
|
||||
# those with one.
|
||||
elif pre is None:
|
||||
pre = Infinity
|
||||
|
||||
# Versions without a post-segment should sort before those with one.
|
||||
if post is None:
|
||||
post = -Infinity
|
||||
|
||||
# Versions without a development segment should sort after those with one.
|
||||
if dev is None:
|
||||
dev = Infinity
|
||||
|
||||
if local is None:
|
||||
# Versions without a local segment should sort before those with one.
|
||||
local = -Infinity
|
||||
else:
|
||||
# Versions with a local segment need that segment parsed to implement
|
||||
# the sorting rules in PEP440.
|
||||
# - Alphanumeric segments sort before numeric segments
|
||||
# - Alphanumeric segments sort lexicographically
|
||||
# - Numeric segments sort numerically
|
||||
# - Shorter versions sort before longer versions when the prefixes
|
||||
# match exactly
|
||||
local = tuple(
|
||||
(i, "") if isinstance(i, int) else (-Infinity, i)
|
||||
for i in local
|
||||
)
|
||||
|
||||
return epoch, release, pre, post, dev, local
|
||||
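The ordering semantics implemented above can be summarized with a few illustrative comparisons (a sketch using this module's `parse` helper):

    assert parse("1.9.0") < parse("1.10.0")      # numeric, not lexicographic
    assert parse("1.0rc1") < parse("1.0")        # pre-releases sort before finals
    assert parse("1.0") < parse("1.0.post1")     # post-releases sort after finals
    assert parse("1.0.dev0") < parse("1.0a0")    # dev releases sort earliest
    assert isinstance(parse("not a version"), LegacyVersion)  # fallback path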
Binary file not shown.
Binary file not shown.
Binary file not shown.
337
venv/lib/python3.12/site-packages/scipy/_lib/_testutils.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
Generic test utilities.
|
||||
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import sysconfig
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
|
||||
try:
|
||||
# Need type: ignore[import-untyped] for mypy >= 1.6
|
||||
import cython # type: ignore[import-untyped]
|
||||
from Cython.Compiler.Version import ( # type: ignore[import-untyped]
|
||||
version as cython_version,
|
||||
)
|
||||
except ImportError:
|
||||
cython = None
|
||||
else:
|
||||
from scipy._lib import _pep440
|
||||
required_version = '3.0.8'
|
||||
if _pep440.parse(cython_version) < _pep440.Version(required_version):
|
||||
# too old or wrong cython, skip Cython API tests
|
||||
cython = None
|
||||
|
||||
|
||||
__all__ = ['PytestTester', 'check_free_memory', '_TestPythranFunc', 'IS_MUSL']
|
||||
|
||||
|
||||
IS_MUSL = False
|
||||
# alternate way is
|
||||
# from packaging.tags import sys_tags
|
||||
# _tags = list(sys_tags())
|
||||
# if 'musllinux' in _tags[0].platform:
|
||||
_v = sysconfig.get_config_var('HOST_GNU_TYPE') or ''
|
||||
if 'musl' in _v:
|
||||
IS_MUSL = True
|
||||
|
||||
|
||||
IS_EDITABLE = 'editable' in scipy.__path__[0]
|
||||
|
||||
|
||||
class FPUModeChangeWarning(RuntimeWarning):
|
||||
"""Warning about FPU mode change"""
|
||||
pass
|
||||
|
||||
|
||||
class PytestTester:
|
||||
"""
|
||||
Run tests for this namespace
|
||||
|
||||
``scipy.test()`` runs tests for all of SciPy, with the default settings.
|
||||
    When used from a submodule (e.g., ``scipy.cluster.test()``), only the tests
|
||||
for that namespace are run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : {'fast', 'full'}, optional
|
||||
Whether to run only the fast tests, or also those marked as slow.
|
||||
Default is 'fast'.
|
||||
verbose : int, optional
|
||||
Test output verbosity. Default is 1.
|
||||
extra_argv : list, optional
|
||||
Arguments to pass through to Pytest.
|
||||
doctests : bool, optional
|
||||
Whether to run doctests or not. Default is False.
|
||||
coverage : bool, optional
|
||||
Whether to run tests with code coverage measurements enabled.
|
||||
Default is False.
|
||||
tests : list of str, optional
|
||||
List of module names to run tests for. By default, uses the module
|
||||
from which the ``test`` function is called.
|
||||
parallel : int, optional
|
||||
Run tests in parallel with pytest-xdist, if number given is larger than
|
||||
1. Default is 1.
|
||||
|
||||
"""
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
|
||||
def __call__(self, label="fast", verbose=1, extra_argv=None, doctests=False,
|
||||
coverage=False, tests=None, parallel=None):
|
||||
import pytest
|
||||
|
||||
module = sys.modules[self.module_name]
|
||||
module_path = os.path.abspath(module.__path__[0])
|
||||
|
||||
pytest_args = ['--showlocals', '--tb=short']
|
||||
|
||||
if doctests:
|
||||
pytest_args += [
|
||||
"--doctest-modules",
|
||||
"--ignore=scipy/interpolate/_interpnd_info.py",
|
||||
"--ignore=scipy/_lib/array_api_compat",
|
||||
"--ignore=scipy/_lib/highs",
|
||||
"--ignore=scipy/_lib/unuran",
|
||||
"--ignore=scipy/_lib/_gcutils.py",
|
||||
"--ignore=scipy/_lib/doccer.py",
|
||||
"--ignore=scipy/_lib/_uarray",
|
||||
]
|
||||
|
||||
if extra_argv:
|
||||
pytest_args += list(extra_argv)
|
||||
|
||||
if verbose and int(verbose) > 1:
|
||||
pytest_args += ["-" + "v"*(int(verbose)-1)]
|
||||
|
||||
if coverage:
|
||||
pytest_args += ["--cov=" + module_path]
|
||||
|
||||
if label == "fast":
|
||||
pytest_args += ["-m", "not slow"]
|
||||
elif label != "full":
|
||||
pytest_args += ["-m", label]
|
||||
|
||||
if tests is None:
|
||||
tests = [self.module_name]
|
||||
|
||||
if parallel is not None and parallel > 1:
|
||||
if _pytest_has_xdist():
|
||||
pytest_args += ['-n', str(parallel)]
|
||||
else:
|
||||
import warnings
|
||||
warnings.warn('Could not run tests in parallel because '
|
||||
'pytest-xdist plugin is not available.',
|
||||
stacklevel=2)
|
||||
|
||||
pytest_args += ['--pyargs'] + list(tests)
|
||||
|
||||
try:
|
||||
code = pytest.main(pytest_args)
|
||||
except SystemExit as exc:
|
||||
code = exc.code
|
||||
|
||||
return (code == 0)
|
||||
|
||||
|
||||
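# --- illustration (added; not part of the original file) ---------------
# PytestTester is normally exposed as a module-level ``test`` callable; a
# hypothetical usage sketch:
#
#   from scipy._lib._testutils import PytestTester
#   test = PytestTester(__name__)
#   # ok = test(label="fast", verbose=2)   # True if the pytest run passed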
class _TestPythranFunc:
|
||||
    '''
    These are situations that can be tested in our pythran tests:
    - A function with multiple array arguments and then
      other positional and keyword arguments.
    - A function with array-like keywords (e.g. ``def somefunc(x0, x1=None)``).
    Note: list/tuple input is not yet tested!

    `self.arguments`: a dictionary whose keys are the indices of the arguments
    and whose values are tuples of (array value, all supported dtypes).
    `self.partialfunc`: a function used to freeze the non-array arguments
    that are of no interest in the original function.
    '''
|
||||
ALL_INTEGER = [np.int8, np.int16, np.int32, np.int64, np.intc, np.intp]
|
||||
ALL_FLOAT = [np.float32, np.float64]
|
||||
ALL_COMPLEX = [np.complex64, np.complex128]
|
||||
|
||||
def setup_method(self):
|
||||
self.arguments = {}
|
||||
self.partialfunc = None
|
||||
self.expected = None
|
||||
|
||||
def get_optional_args(self, func):
|
||||
        # Get optional arguments with their default values,
        # used for testing keywords.
|
||||
signature = inspect.signature(func)
|
||||
optional_args = {}
|
||||
for k, v in signature.parameters.items():
|
||||
if v.default is not inspect.Parameter.empty:
|
||||
optional_args[k] = v.default
|
||||
return optional_args
|
||||
|
||||
def get_max_dtype_list_length(self):
|
||||
# get the max supported dtypes list length in all arguments
|
||||
max_len = 0
|
||||
for arg_idx in self.arguments:
|
||||
cur_len = len(self.arguments[arg_idx][1])
|
||||
if cur_len > max_len:
|
||||
max_len = cur_len
|
||||
return max_len
|
||||
|
||||
def get_dtype(self, dtype_list, dtype_idx):
|
||||
# get the dtype from dtype_list via index
|
||||
# if the index is out of range, then return the last dtype
|
||||
if dtype_idx > len(dtype_list)-1:
|
||||
return dtype_list[-1]
|
||||
else:
|
||||
return dtype_list[dtype_idx]
|
||||
|
||||
def test_all_dtypes(self):
|
||||
for type_idx in range(self.get_max_dtype_list_length()):
|
||||
args_array = []
|
||||
for arg_idx in self.arguments:
|
||||
new_dtype = self.get_dtype(self.arguments[arg_idx][1],
|
||||
type_idx)
|
||||
args_array.append(self.arguments[arg_idx][0].astype(new_dtype))
|
||||
self.pythranfunc(*args_array)
|
||||
|
||||
def test_views(self):
|
||||
args_array = []
|
||||
for arg_idx in self.arguments:
|
||||
args_array.append(self.arguments[arg_idx][0][::-1][::-1])
|
||||
self.pythranfunc(*args_array)
|
||||
|
||||
def test_strided(self):
|
||||
args_array = []
|
||||
for arg_idx in self.arguments:
|
||||
args_array.append(np.repeat(self.arguments[arg_idx][0],
|
||||
2, axis=0)[::2])
|
||||
self.pythranfunc(*args_array)
|
||||
|
||||
|
||||
def _pytest_has_xdist():
|
||||
"""
|
||||
    Check whether the pytest-xdist plugin, which provides parallel test runs, is installed.
|
||||
"""
|
||||
    # Check that xdist exists without importing it; otherwise pytest emits warnings
|
||||
from importlib.util import find_spec
|
||||
return find_spec('xdist') is not None
|
||||
|
||||
|
||||
def check_free_memory(free_mb):
|
||||
"""
|
||||
    Check that *free_mb* MB of memory are available; otherwise do pytest.skip.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
try:
|
||||
mem_free = _parse_size(os.environ['SCIPY_AVAILABLE_MEM'])
|
||||
msg = '{} MB memory required, but environment SCIPY_AVAILABLE_MEM={}'.format(
|
||||
free_mb, os.environ['SCIPY_AVAILABLE_MEM'])
|
||||
except KeyError:
|
||||
mem_free = _get_mem_available()
|
||||
if mem_free is None:
|
||||
pytest.skip("Could not determine available memory; set SCIPY_AVAILABLE_MEM "
|
||||
"variable to free memory in MB to run the test.")
|
||||
msg = f'{free_mb} MB memory required, but {mem_free/1e6} MB available'
|
||||
|
||||
if mem_free < free_mb * 1e6:
|
||||
pytest.skip(msg)
|
||||
|
||||
|
||||
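# --- illustration (added; not part of the original file) ---------------
# Typical use at the top of a memory-hungry test (pytest context assumed):
#
#   def test_big_problem():
#       check_free_memory(8000)   # skip unless ~8 GB appears to be free
#       ...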
def _parse_size(size_str):
|
||||
suffixes = {'': 1e6,
|
||||
'b': 1.0,
|
||||
'k': 1e3, 'M': 1e6, 'G': 1e9, 'T': 1e12,
|
||||
'kb': 1e3, 'Mb': 1e6, 'Gb': 1e9, 'Tb': 1e12,
|
||||
'kib': 1024.0, 'Mib': 1024.0**2, 'Gib': 1024.0**3, 'Tib': 1024.0**4}
|
||||
m = re.match(r'^\s*(\d+)\s*({})\s*$'.format('|'.join(suffixes.keys())),
|
||||
size_str,
|
||||
re.I)
|
||||
if not m or m.group(2) not in suffixes:
|
||||
raise ValueError("Invalid size string")
|
||||
|
||||
return float(m.group(1)) * suffixes[m.group(2)]
|
||||
|
||||
|
||||
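# --- illustration (added; not part of the original file) ---------------
# _parse_size returns bytes as a float; bare numbers are interpreted as MB:
#
#   _parse_size("100")    == 1e8            # 100 MB
#   _parse_size("512 kb") == 512e3
#   _parse_size("2 Gib")  == 2 * 1024.0**3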
def _get_mem_available():
|
||||
"""
|
||||
Get information about memory available, not counting swap.
|
||||
"""
|
||||
try:
|
||||
import psutil
|
||||
return psutil.virtual_memory().available
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
if sys.platform.startswith('linux'):
|
||||
info = {}
|
||||
with open('/proc/meminfo') as f:
|
||||
for line in f:
|
||||
p = line.split()
|
||||
info[p[0].strip(':').lower()] = float(p[1]) * 1e3
|
||||
|
||||
if 'memavailable' in info:
|
||||
# Linux >= 3.14
|
||||
return info['memavailable']
|
||||
else:
|
||||
return info['memfree'] + info['cached']
|
||||
|
||||
return None
|
||||
|
||||
def _test_cython_extension(tmp_path, srcdir):
|
||||
"""
|
||||
Helper function to test building and importing Cython modules that
|
||||
make use of the Cython APIs for BLAS, LAPACK, optimize, and special.
|
||||
"""
|
||||
import pytest
|
||||
try:
|
||||
subprocess.check_call(["meson", "--version"])
|
||||
except FileNotFoundError:
|
||||
pytest.skip("No usable 'meson' found")
|
||||
|
||||
# build the examples in a temporary directory
|
||||
mod_name = os.path.split(srcdir)[1]
|
||||
shutil.copytree(srcdir, tmp_path / mod_name)
|
||||
build_dir = tmp_path / mod_name / 'tests' / '_cython_examples'
|
||||
target_dir = build_dir / 'build'
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Ensure we use the correct Python interpreter even when `meson` is
|
||||
# installed in a different Python environment (see numpy#24956)
|
||||
native_file = str(build_dir / 'interpreter-native-file.ini')
|
||||
with open(native_file, 'w') as f:
|
||||
f.write("[binaries]\n")
|
||||
f.write(f"python = '{sys.executable}'")
|
||||
|
||||
if sys.platform == "win32":
|
||||
subprocess.check_call(["meson", "setup",
|
||||
"--buildtype=release",
|
||||
"--native-file", native_file,
|
||||
"--vsenv", str(build_dir)],
|
||||
cwd=target_dir,
|
||||
)
|
||||
else:
|
||||
subprocess.check_call(["meson", "setup",
|
||||
"--native-file", native_file, str(build_dir)],
|
||||
cwd=target_dir
|
||||
)
|
||||
subprocess.check_call(["meson", "compile", "-vv"], cwd=target_dir)
|
||||
|
||||
# import without adding the directory to sys.path
|
||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||
|
||||
def load(modname):
|
||||
so = (target_dir / modname).with_suffix(suffix)
|
||||
spec = spec_from_file_location(modname, so)
|
||||
mod = module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
# test that the module can be imported
|
||||
return load("extending"), load("extending_cpp")
|
||||
58
venv/lib/python3.12/site-packages/scipy/_lib/_threadsafety.py
Normal file
@@ -0,0 +1,58 @@
import threading

import scipy._lib.decorator


__all__ = ['ReentrancyError', 'ReentrancyLock', 'non_reentrant']


class ReentrancyError(RuntimeError):
    pass


class ReentrancyLock:
    """
    Threading lock that raises an exception for reentrant calls.

    Calls from different threads are serialized, and nested calls from the
    same thread result in an error.

    The object can be used as a context manager or to decorate functions
    via the decorate() method.

    """

    def __init__(self, err_msg):
        self._rlock = threading.RLock()
        self._entered = False
        self._err_msg = err_msg

    def __enter__(self):
        self._rlock.acquire()
        if self._entered:
            self._rlock.release()
            raise ReentrancyError(self._err_msg)
        self._entered = True

    def __exit__(self, type, value, traceback):
        self._entered = False
        self._rlock.release()

    def decorate(self, func):
        def caller(func, *a, **kw):
            with self:
                return func(*a, **kw)
        return scipy._lib.decorator.decorate(func, caller)


def non_reentrant(err_msg=None):
    """
    Decorate a function with a threading lock and prevent reentrant calls.
    """
    def decorator(func):
        msg = err_msg
        if msg is None:
            msg = "%s is not re-entrant" % func.__name__
        lock = ReentrancyLock(msg)
        return lock.decorate(func)
    return decorator
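# --- illustration (added; not part of the original file) ---------------
# Sketch of guarding a function against reentrant calls:
#
#   @non_reentrant()
#   def solve():
#       ...   # calling solve() again from inside raises ReentrancyError;
#             # calls from other threads are serialized by the RLock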
86
venv/lib/python3.12/site-packages/scipy/_lib/_tmpdirs.py
Normal file
@@ -0,0 +1,86 @@
''' Contexts for *with* statement providing temporary directories
'''
import os
from contextlib import contextmanager
from shutil import rmtree
from tempfile import mkdtemp


@contextmanager
def tempdir():
    """Create and return a temporary directory. This has the same
    behavior as mkdtemp but can be used as a context manager.

    Upon exiting the context, the directory and everything contained
    in it are removed.

    Examples
    --------
    >>> import os
    >>> with tempdir() as tmpdir:
    ...     fname = os.path.join(tmpdir, 'example_file.txt')
    ...     with open(fname, 'wt') as fobj:
    ...         _ = fobj.write('a string\\n')
    >>> os.path.exists(tmpdir)
    False
    """
    d = mkdtemp()
    yield d
    rmtree(d)


@contextmanager
def in_tempdir():
    ''' Create, return, and change directory to a temporary directory

    Examples
    --------
    >>> import os
    >>> my_cwd = os.getcwd()
    >>> with in_tempdir() as tmpdir:
    ...     _ = open('test.txt', 'wt').write('some text')
    ...     assert os.path.isfile('test.txt')
    ...     assert os.path.isfile(os.path.join(tmpdir, 'test.txt'))
    >>> os.path.exists(tmpdir)
    False
    >>> os.getcwd() == my_cwd
    True
    '''
    pwd = os.getcwd()
    d = mkdtemp()
    os.chdir(d)
    yield d
    os.chdir(pwd)
    rmtree(d)


@contextmanager
def in_dir(dir=None):
    """ Change directory to given directory for duration of ``with`` block

    Useful when you want to use `in_tempdir` for the final test, but
    you are still debugging. For example, you may want to do this in the end:

    >>> with in_tempdir() as tmpdir:
    ...     # do something complicated which might break
    ...     pass

    But, indeed, the complicated thing does break, and meanwhile, the
    ``in_tempdir`` context manager wiped out the directory with the
    temporary files that you wanted for debugging. So, while debugging, you
    replace with something like:

    >>> with in_dir() as tmpdir: # Use working directory by default
    ...     # do something complicated which might break
    ...     pass

    You can then look at the temporary file outputs to debug what is happening,
    fix, and finally replace ``in_dir`` with ``in_tempdir`` again.
    """
    cwd = os.getcwd()
    if dir is None:
        yield cwd
        return
    os.chdir(dir)
    yield dir
    os.chdir(cwd)
29
venv/lib/python3.12/site-packages/scipy/_lib/_uarray/LICENSE
Normal file
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2018, Quansight-Labs
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
116
venv/lib/python3.12/site-packages/scipy/_lib/_uarray/__init__.py
Normal file
@@ -0,0 +1,116 @@
"""
.. note::
    If you are looking for overrides for NumPy-specific methods, see the
    documentation for :obj:`unumpy`. This page explains how to write
    back-ends and multimethods.

``uarray`` is built around a back-end protocol and overridable multimethods.
It is necessary to define multimethods for back-ends to be able to override them.
See the documentation of :obj:`generate_multimethod` on how to write multimethods.


Let's start with the simplest:

``__ua_domain__`` defines the back-end *domain*. The domain is a period-
separated string naming the modules you extend plus the submodule. For
example, if a submodule ``module2.submodule`` extends ``module1``
(i.e., it exposes dispatchables marked as types available in ``module1``),
then the domain string should be ``"module1.module2.submodule"``.


For the purpose of this demonstration, we'll be creating an object and setting
its attributes directly. However, note that you can use a module or your own type
as a backend as well.

>>> class Backend: pass
>>> be = Backend()
>>> be.__ua_domain__ = "ua_examples"

It might be useful at this point to sidetrack to the documentation of
:obj:`generate_multimethod` to find out how to generate a multimethod
overridable by :obj:`uarray`. Needless to say, writing a backend and
creating multimethods are mostly orthogonal activities, and knowing
one doesn't necessarily require knowledge of the other, although it
is certainly helpful. We expect core API designers/specifiers to write the
multimethods, and implementors to override them. But, as is often the case,
similar people write both.

Without further ado, here's an example multimethod:

>>> import uarray as ua
>>> from uarray import Dispatchable
>>> def override_me(a, b):
...     return Dispatchable(a, int),
>>> def override_replacer(args, kwargs, dispatchables):
...     return (dispatchables[0], args[1]), {}
>>> overridden_me = ua.generate_multimethod(
...     override_me, override_replacer, "ua_examples"
... )

Next comes the part about overriding the multimethod. This requires
the ``__ua_function__`` protocol and the ``__ua_convert__``
protocol. The ``__ua_function__`` protocol has the signature
``(method, args, kwargs)`` where ``method`` is the multimethod being
called and ``args``/``kwargs`` are its arguments, with the converted
dispatchables already substituted in.

>>> def __ua_function__(method, args, kwargs):
...     return method.__name__, args, kwargs
>>> be.__ua_function__ = __ua_function__

The other protocol of interest is the ``__ua_convert__`` protocol. It has the
signature ``(dispatchables, coerce)``. When ``coerce`` is ``False``, conversion
between the formats should ideally be an ``O(1)`` operation; in particular, no
memory copying should be involved, only views of the existing data.

>>> def __ua_convert__(dispatchables, coerce):
...     for d in dispatchables:
...         if d.type is int:
...             if coerce and d.coercible:
...                 yield str(d.value)
...             else:
...                 yield d.value
>>> be.__ua_convert__ = __ua_convert__

Now that we have defined the backend, the next thing to do is to call the multimethod.

>>> with ua.set_backend(be):
...     overridden_me(1, "2")
('override_me', (1, '2'), {})

Note that the marked type has no effect on the actual type of the passed object.
We can also coerce the type of the input.

>>> with ua.set_backend(be, coerce=True):
...     overridden_me(1, "2")
...     overridden_me(1.0, "2")
('override_me', ('1', '2'), {})
('override_me', ('1.0', '2'), {})

Another feature is that if you remove ``__ua_convert__``, the arguments are not
converted at all and it's up to the backend to handle that.

>>> del be.__ua_convert__
>>> with ua.set_backend(be):
...     overridden_me(1, "2")
('override_me', (1, '2'), {})

You also have the option to return ``NotImplemented``, in which case processing
moves on to the next back-end, which in this case, doesn't exist. The same
applies to ``__ua_convert__``.

>>> be.__ua_function__ = lambda *a, **kw: NotImplemented
>>> with ua.set_backend(be):
...     overridden_me(1, "2")
Traceback (most recent call last):
    ...
uarray.BackendNotImplementedError: ...

The last possibility is if we don't have ``__ua_convert__``, in which case the
job is left up to ``__ua_function__``, but putting things back into arrays
after conversion will not be possible.
"""

from ._backend import *
__version__ = '0.8.8.dev0+aa94c5a4.scipy'
704
venv/lib/python3.12/site-packages/scipy/_lib/_uarray/_backend.py
Normal file
@@ -0,0 +1,704 @@
|
||||
import typing
|
||||
import types
|
||||
import inspect
|
||||
import functools
|
||||
from . import _uarray
|
||||
import copyreg
|
||||
import pickle
|
||||
import contextlib
|
||||
|
||||
from ._uarray import ( # type: ignore
|
||||
BackendNotImplementedError,
|
||||
_Function,
|
||||
_SkipBackendContext,
|
||||
_SetBackendContext,
|
||||
_BackendState,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"set_backend",
|
||||
"set_global_backend",
|
||||
"skip_backend",
|
||||
"register_backend",
|
||||
"determine_backend",
|
||||
"determine_backend_multi",
|
||||
"clear_backends",
|
||||
"create_multimethod",
|
||||
"generate_multimethod",
|
||||
"_Function",
|
||||
"BackendNotImplementedError",
|
||||
"Dispatchable",
|
||||
"wrap_single_convertor",
|
||||
"wrap_single_convertor_instance",
|
||||
"all_of_type",
|
||||
"mark_as",
|
||||
"set_state",
|
||||
"get_state",
|
||||
"reset_state",
|
||||
"_BackendState",
|
||||
"_SkipBackendContext",
|
||||
"_SetBackendContext",
|
||||
]
|
||||
|
||||
ArgumentExtractorType = typing.Callable[..., tuple["Dispatchable", ...]]
|
||||
ArgumentReplacerType = typing.Callable[
|
||||
[tuple, dict, tuple], tuple[tuple, dict]
|
||||
]
|
||||
|
||||
def unpickle_function(mod_name, qname, self_):
|
||||
import importlib
|
||||
|
||||
try:
|
||||
module = importlib.import_module(mod_name)
|
||||
qname = qname.split(".")
|
||||
func = module
|
||||
for q in qname:
|
||||
func = getattr(func, q)
|
||||
|
||||
if self_ is not None:
|
||||
func = types.MethodType(func, self_)
|
||||
|
||||
return func
|
||||
except (ImportError, AttributeError) as e:
|
||||
from pickle import UnpicklingError
|
||||
|
||||
raise UnpicklingError from e
|
||||
|
||||
|
||||
def pickle_function(func):
|
||||
mod_name = getattr(func, "__module__", None)
|
||||
qname = getattr(func, "__qualname__", None)
|
||||
self_ = getattr(func, "__self__", None)
|
||||
|
||||
try:
|
||||
test = unpickle_function(mod_name, qname, self_)
|
||||
except pickle.UnpicklingError:
|
||||
test = None
|
||||
|
||||
if test is not func:
|
||||
raise pickle.PicklingError(
|
||||
f"Can't pickle {func}: it's not the same object as {test}"
|
||||
)
|
||||
|
||||
return unpickle_function, (mod_name, qname, self_)
|
||||
|
||||
|
||||
def pickle_state(state):
|
||||
return _uarray._BackendState._unpickle, state._pickle()
|
||||
|
||||
|
||||
def pickle_set_backend_context(ctx):
|
||||
return _SetBackendContext, ctx._pickle()
|
||||
|
||||
|
||||
def pickle_skip_backend_context(ctx):
|
||||
return _SkipBackendContext, ctx._pickle()
|
||||
|
||||
|
||||
copyreg.pickle(_Function, pickle_function)
|
||||
copyreg.pickle(_uarray._BackendState, pickle_state)
|
||||
copyreg.pickle(_SetBackendContext, pickle_set_backend_context)
|
||||
copyreg.pickle(_SkipBackendContext, pickle_skip_backend_context)
|
||||
|
||||
|
||||
def get_state():
|
||||
"""
|
||||
Returns an opaque object containing the current state of all the backends.
|
||||
|
||||
Can be used for synchronization between threads/processes.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_state
|
||||
Sets the state returned by this function.
|
||||
"""
|
||||
return _uarray.get_state()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def reset_state():
|
||||
"""
|
||||
Returns a context manager that resets all state once exited.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_state
|
||||
Context manager that sets the backend state.
|
||||
get_state
|
||||
Gets a state to be set by this context manager.
|
||||
"""
|
||||
with set_state(get_state()):
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_state(state):
|
||||
"""
|
||||
A context manager that sets the state of the backends to one returned by :obj:`get_state`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
get_state
|
||||
Gets a state to be set by this context manager.
|
||||
""" # noqa: E501
|
||||
old_state = get_state()
|
||||
_uarray.set_state(state)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_uarray.set_state(old_state, True)
|
||||
|
||||
|
||||
def create_multimethod(*args, **kwargs):
|
||||
"""
|
||||
Creates a decorator for generating multimethods.
|
||||
|
||||
This function creates a decorator that can be used with an argument
|
||||
extractor in order to generate a multimethod. Other than for the
|
||||
argument extractor, all arguments are passed on to
|
||||
:obj:`generate_multimethod`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
generate_multimethod
|
||||
Generates a multimethod.
|
||||
"""
|
||||
|
||||
def wrapper(a):
|
||||
return generate_multimethod(a, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def generate_multimethod(
|
||||
argument_extractor: ArgumentExtractorType,
|
||||
argument_replacer: ArgumentReplacerType,
|
||||
domain: str,
|
||||
default: typing.Optional[typing.Callable] = None,
|
||||
):
|
||||
"""
|
||||
Generates a multimethod.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
argument_extractor : ArgumentExtractorType
|
||||
A callable which extracts the dispatchable arguments. Extracted arguments
|
||||
should be marked by the :obj:`Dispatchable` class. It has the same signature
|
||||
as the desired multimethod.
|
||||
argument_replacer : ArgumentReplacerType
|
||||
A callable with the signature (args, kwargs, dispatchables), which should also
|
||||
return an (args, kwargs) pair with the dispatchables replaced inside the
|
||||
args/kwargs.
|
||||
domain : str
|
||||
A string value indicating the domain of this multimethod.
|
||||
default: Optional[Callable], optional
|
||||
The default implementation of this multimethod, where ``None`` (the default)
|
||||
specifies there is no default implementation.
|
||||
|
||||
Examples
|
||||
--------
|
||||
In this example, ``a`` is to be dispatched over, so we return it, while marking it
|
||||
as an ``int``.
|
||||
The trailing comma is needed because the args have to be returned as an iterable.
|
||||
|
||||
>>> def override_me(a, b):
|
||||
... return Dispatchable(a, int),
|
||||
|
||||
Next, we define the argument replacer that replaces the dispatchables inside
|
||||
args/kwargs with the supplied ones.
|
||||
|
||||
>>> def override_replacer(args, kwargs, dispatchables):
|
||||
... return (dispatchables[0], args[1]), {}
|
||||
|
||||
Next, we define the multimethod.
|
||||
|
||||
>>> overridden_me = generate_multimethod(
|
||||
... override_me, override_replacer, "ua_examples"
|
||||
... )
|
||||
|
||||
Notice that there's no default implementation, unless you supply one.
|
||||
|
||||
>>> overridden_me(1, "a")
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
uarray.BackendNotImplementedError: ...
|
||||
|
||||
>>> overridden_me2 = generate_multimethod(
|
||||
... override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
|
||||
... )
|
||||
>>> overridden_me2(1, "a")
|
||||
(1, 'a')
|
||||
|
||||
See Also
|
||||
--------
|
||||
uarray
|
||||
See the module documentation for how to override the method by creating
|
||||
backends.
|
||||
"""
|
||||
kw_defaults, arg_defaults, opts = get_defaults(argument_extractor)
|
||||
ua_func = _Function(
|
||||
argument_extractor,
|
||||
argument_replacer,
|
||||
domain,
|
||||
arg_defaults,
|
||||
kw_defaults,
|
||||
default,
|
||||
)
|
||||
|
||||
return functools.update_wrapper(ua_func, argument_extractor)
|
||||
|
||||
|
||||
def set_backend(backend, coerce=False, only=False):
|
||||
"""
|
||||
A context manager that sets the preferred backend.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend
|
||||
The backend to set.
|
||||
coerce
|
||||
Whether or not to coerce to a specific backend's types. Implies ``only``.
|
||||
only
|
||||
Whether or not this should be the last backend to try.
|
||||
|
||||
See Also
|
||||
--------
|
||||
skip_backend: A context manager that allows skipping of backends.
|
||||
set_global_backend: Set a single, global backend for a domain.
|
||||
"""
|
||||
try:
|
||||
return backend.__ua_cache__["set", coerce, only]
|
||||
except AttributeError:
|
||||
backend.__ua_cache__ = {}
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
ctx = _SetBackendContext(backend, coerce, only)
|
||||
backend.__ua_cache__["set", coerce, only] = ctx
|
||||
return ctx
|
||||
|
||||
|
||||
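# --- illustration (added; not part of the original file) ---------------
# set_backend caches one context object per (coerce, only) combination on
# the backend itself, so repeated calls are cheap:
#
#   ctx1 = set_backend(be)
#   ctx2 = set_backend(be)
#   assert ctx1 is ctx2       # same cached _SetBackendContext instance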
def skip_backend(backend):
|
||||
"""
|
||||
A context manager that allows one to skip a given backend from processing
|
||||
entirely. This allows one to use another backend's code in a library that
|
||||
is also a consumer of the same backend.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend
|
||||
The backend to skip.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_backend: A context manager that allows setting of backends.
|
||||
set_global_backend: Set a single, global backend for a domain.
|
||||
"""
|
||||
try:
|
||||
return backend.__ua_cache__["skip"]
|
||||
except AttributeError:
|
||||
backend.__ua_cache__ = {}
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
ctx = _SkipBackendContext(backend)
|
||||
backend.__ua_cache__["skip"] = ctx
|
||||
return ctx
|
||||
|
||||
|
||||
def get_defaults(f):
|
||||
sig = inspect.signature(f)
|
||||
kw_defaults = {}
|
||||
arg_defaults = []
|
||||
opts = set()
|
||||
for k, v in sig.parameters.items():
|
||||
if v.default is not inspect.Parameter.empty:
|
||||
kw_defaults[k] = v.default
|
||||
if v.kind in (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
):
|
||||
arg_defaults.append(v.default)
|
||||
opts.add(k)
|
||||
|
||||
return kw_defaults, tuple(arg_defaults), opts
|
||||
|
||||
|
||||
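# --- illustration (added; not part of the original file) ---------------
# get_defaults inspects a callable's signature; a hypothetical example:
#
#   def f(a, b=1, *, c=2): ...
#   kw_defaults, arg_defaults, opts = get_defaults(f)
#   # kw_defaults == {'b': 1, 'c': 2}   defaults keyed by parameter name
#   # arg_defaults == (1,)              defaults of positional parameters
#   # opts: the set of parameter names recorded alongside them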
def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
|
||||
"""
|
||||
This utility method replaces the default backend for permanent use. It
|
||||
will be tried in the list of backends automatically, unless the
|
||||
``only`` flag is set on a backend. This will be the first tried
|
||||
backend outside the :obj:`set_backend` context manager.
|
||||
|
||||
Note that this method is not thread-safe.
|
||||
|
||||
.. warning::
|
||||
We caution library authors against using this function in
|
||||
their code. We do *not* support this use-case. This function
|
||||
is meant to be used only by users themselves, or by a reference
|
||||
implementation, if one exists.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend
|
||||
The backend to register.
|
||||
coerce : bool
|
||||
Whether to coerce input types when trying this backend.
|
||||
only : bool
|
||||
If ``True``, no more backends will be tried if this fails.
|
||||
Implied by ``coerce=True``.
|
||||
try_last : bool
|
||||
If ``True``, the global backend is tried after registered backends.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_backend: A context manager that allows setting of backends.
|
||||
skip_backend: A context manager that allows skipping of backends.
|
||||
"""
|
||||
_uarray.set_global_backend(backend, coerce, only, try_last)
|
||||
|
||||
|
||||
def register_backend(backend):
|
||||
"""
|
||||
    This utility method registers a backend for permanent use. It
    will be tried in the list of backends automatically, unless the
    ``only`` flag is set on a backend.
|
||||
|
||||
Note that this method is not thread-safe.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backend
|
||||
The backend to register.
|
||||
"""
|
||||
_uarray.register_backend(backend)
|
||||
|
||||
|
||||
def clear_backends(domain, registered=True, globals=False):
|
||||
"""
|
||||
This utility method clears registered backends.
|
||||
|
||||
.. warning::
|
||||
We caution library authors against using this function in
|
||||
their code. We do *not* support this use-case. This function
|
||||
is meant to be used only by users themselves.
|
||||
|
||||
.. warning::
|
||||
Do NOT use this method inside a multimethod call, or the
|
||||
program is likely to crash.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
domain : Optional[str]
|
||||
The domain for which to de-register backends. ``None`` means
|
||||
de-register for all domains.
|
||||
registered : bool
|
||||
Whether or not to clear registered backends. See :obj:`register_backend`.
|
||||
globals : bool
|
||||
Whether or not to clear global backends. See :obj:`set_global_backend`.
|
||||
|
||||
See Also
|
||||
--------
|
||||
register_backend : Register a backend globally.
|
||||
set_global_backend : Set a global backend.
|
||||
"""
|
||||
_uarray.clear_backends(domain, registered, globals)
|
||||
|
||||
|
||||
class Dispatchable:
|
||||
"""
|
||||
A utility class which marks an argument with a specific dispatch type.
|
||||
|
||||
|
||||
Attributes
|
||||
----------
|
||||
value
|
||||
The value of the Dispatchable.
|
||||
|
||||
type
|
||||
The type of the Dispatchable.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> x = Dispatchable(1, str)
|
||||
>>> x
|
||||
<Dispatchable: type=<class 'str'>, value=1>
|
||||
|
||||
See Also
|
||||
--------
|
||||
all_of_type
|
||||
Marks all unmarked parameters of a function.
|
||||
|
||||
mark_as
|
||||
Allows one to create a utility function to mark as a given type.
|
||||
"""
|
||||
|
||||
def __init__(self, value, dispatch_type, coercible=True):
|
||||
self.value = value
|
||||
self.type = dispatch_type
|
||||
self.coercible = coercible
|
||||
|
||||
def __getitem__(self, index):
|
||||
return (self.type, self.value)[index]
|
||||
|
||||
def __str__(self):
|
||||
return f"<{type(self).__name__}: type={self.type!r}, value={self.value!r}>"
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
def mark_as(dispatch_type):
|
||||
"""
|
||||
Creates a utility function to mark something as a specific type.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> mark_int = mark_as(int)
|
||||
>>> mark_int(1)
|
||||
<Dispatchable: type=<class 'int'>, value=1>
|
||||
"""
|
||||
return functools.partial(Dispatchable, dispatch_type=dispatch_type)
|
||||
|
||||
|
||||
def all_of_type(arg_type):
|
||||
"""
|
||||
Marks all unmarked arguments as a given type.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> @all_of_type(str)
|
||||
... def f(a, b):
|
||||
... return a, Dispatchable(b, int)
|
||||
>>> f('a', 1)
|
||||
(<Dispatchable: type=<class 'str'>, value='a'>,
|
||||
<Dispatchable: type=<class 'int'>, value=1>)
|
||||
"""
|
||||
|
||||
def outer(func):
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
extracted_args = func(*args, **kwargs)
|
||||
return tuple(
|
||||
Dispatchable(arg, arg_type)
|
||||
if not isinstance(arg, Dispatchable)
|
||||
else arg
|
||||
for arg in extracted_args
|
||||
)
|
||||
|
||||
return inner
|
||||
|
||||
return outer
|
||||
|
||||
|
||||
def wrap_single_convertor(convert_single):
|
||||
"""
|
||||
Wraps a ``__ua_convert__`` defined for a single element to all elements.
|
||||
If any of them return ``NotImplemented``, the operation is assumed to be
|
||||
undefined.
|
||||
|
||||
Accepts a signature of (value, type, coerce).
|
||||
"""
|
||||
|
||||
@functools.wraps(convert_single)
|
||||
def __ua_convert__(dispatchables, coerce):
|
||||
converted = []
|
||||
for d in dispatchables:
|
||||
c = convert_single(d.value, d.type, coerce and d.coercible)
|
||||
|
||||
if c is NotImplemented:
|
||||
return NotImplemented
|
||||
|
||||
converted.append(c)
|
||||
|
||||
return converted
|
||||
|
||||
return __ua_convert__
|
||||
|
||||
|
||||
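# --- illustration (added; not part of the original file) ---------------
# Sketch: a per-value convertor lifted to a full __ua_convert__ for a
# hypothetical backend that dispatches on int:
#
#   @wrap_single_convertor
#   def __ua_convert__(value, dispatch_type, coerce):
#       if dispatch_type is int:
#           return int(value) if coerce else value
#       return NotImplemented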
def wrap_single_convertor_instance(convert_single):
|
||||
"""
|
||||
Wraps a ``__ua_convert__`` defined for a single element to all elements.
|
||||
If any of them return ``NotImplemented``, the operation is assumed to be
|
||||
undefined.
|
||||
|
||||
Accepts a signature of (value, type, coerce).
|
||||
"""
|
||||
|
||||
@functools.wraps(convert_single)
|
||||
def __ua_convert__(self, dispatchables, coerce):
|
||||
converted = []
|
||||
for d in dispatchables:
|
||||
c = convert_single(self, d.value, d.type, coerce and d.coercible)
|
||||
|
||||
if c is NotImplemented:
|
||||
return NotImplemented
|
||||
|
||||
converted.append(c)
|
||||
|
||||
return converted
|
||||
|
||||
return __ua_convert__
|
||||
|
||||
|
||||
def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
|
||||
"""Set the backend to the first active backend that supports ``value``
|
||||
|
||||
This is useful for functions that call multimethods without any dispatchable
|
||||
arguments. You can use :func:`determine_backend` to ensure the same backend
|
||||
is used everywhere in a block of multimethod calls.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value
|
||||
The value being tested
|
||||
dispatch_type
|
||||
The dispatch type associated with ``value``, aka
|
||||
":ref:`marking <MarkingGlossary>`".
|
||||
domain: string
|
||||
The domain to query for backends and set.
|
||||
coerce: bool
|
||||
Whether or not to allow coercion to the backend's types. Implies ``only``.
|
||||
only: bool
|
||||
Whether or not this should be the last backend to try.
|
||||
|
||||
See Also
|
||||
--------
|
||||
set_backend: For when you know which backend to set
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
Support is determined by the ``__ua_convert__`` protocol. Backends not
|
||||
supporting the type must return ``NotImplemented`` from their
|
||||
``__ua_convert__`` if they don't support input of that type.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
    Suppose we have two backends, ``BackendA`` and ``BackendB``, each supporting
    a different type, ``TypeA`` and ``TypeB`` respectively, and neither
    supporting the other's type:
|
||||
|
||||
>>> with ua.set_backend(ex.BackendA):
|
||||
... ex.call_multimethod(ex.TypeB(), ex.TypeB())
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
uarray.BackendNotImplementedError: ...
|
||||
|
||||
Now consider a multimethod that creates a new object of ``TypeA``, or
|
||||
``TypeB`` depending on the active backend.
|
||||
|
||||
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
|
||||
... res = ex.creation_multimethod()
|
||||
... ex.call_multimethod(res, ex.TypeA())
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
uarray.BackendNotImplementedError: ...
|
||||
|
||||
``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
|
||||
innermost with statement. So, ``call_multimethod`` fails since the types
|
||||
don't match.
|
||||
|
||||
Instead, we need to first find a backend suitable for all of our objects.
|
||||
|
||||
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
|
||||
... x = ex.TypeA()
|
||||
... with ua.determine_backend(x, "mark", domain="ua_examples"):
|
||||
... res = ex.creation_multimethod()
|
||||
... ex.call_multimethod(res, x)
|
||||
TypeA
|
||||
|
||||
"""
|
||||
dispatchables = (Dispatchable(value, dispatch_type, coerce),)
|
||||
backend = _uarray.determine_backend(domain, dispatchables, coerce)
|
||||
|
||||
return set_backend(backend, coerce=coerce, only=only)
|
||||
|
||||
|
||||
def determine_backend_multi(
|
||||
dispatchables, *, domain, only=True, coerce=False, **kwargs
|
||||
):
|
||||
"""Set a backend supporting all ``dispatchables``
|
||||
|
||||
This is useful for functions that call multimethods without any dispatchable
|
||||
arguments. You can use :func:`determine_backend_multi` to ensure the same
|
||||
backend is used everywhere in a block of multimethod calls involving
|
||||
multiple arrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
|
||||
The dispatchables that must be supported
|
||||
domain: string
|
||||
The domain to query for backends and set.
|
||||
coerce: bool
|
||||
Whether or not to allow coercion to the backend's types. Implies ``only``.
|
||||
only: bool
|
||||
Whether or not this should be the last backend to try.
|
||||
dispatch_type: Optional[Any]
|
||||
The default dispatch type associated with ``dispatchables``, aka
|
||||
":ref:`marking <MarkingGlossary>`".
|
||||
|
||||
See Also
|
||||
--------
|
||||
determine_backend: For a single dispatch value
|
||||
set_backend: For when you know which backend to set
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
Support is determined by the ``__ua_convert__`` protocol. Backends not
|
||||
supporting the type must return ``NotImplemented`` from their
|
||||
``__ua_convert__`` if they don't support input of that type.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
:func:`determine_backend` allows the backend to be set from a single
|
||||
object. :func:`determine_backend_multi` allows multiple objects to be
|
||||
checked simultaneously for support in the backend. Suppose we have a
|
||||
``BackendAB`` which supports ``TypeA`` and ``TypeB`` in the same call,
|
||||
and a ``BackendBC`` that doesn't support ``TypeA``.
|
||||
|
||||
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
|
||||
... a, b = ex.TypeA(), ex.TypeB()
|
||||
... with ua.determine_backend_multi(
|
||||
... [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
|
||||
... domain="ua_examples"
|
||||
... ):
|
||||
... res = ex.creation_multimethod()
|
||||
... ex.call_multimethod(res, a, b)
|
||||
TypeA
|
||||
|
||||
This won't call ``BackendBC`` because it doesn't support ``TypeA``.
|
||||
|
||||
    We can also leave out the ``ua.Dispatchable`` markers if we specify the
    default ``dispatch_type`` for the ``dispatchables`` argument.
|
||||
|
||||
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
|
||||
... a, b = ex.TypeA(), ex.TypeB()
|
||||
... with ua.determine_backend_multi(
|
||||
... [a, b], dispatch_type="mark", domain="ua_examples"
|
||||
... ):
|
||||
... res = ex.creation_multimethod()
|
||||
... ex.call_multimethod(res, a, b)
|
||||
TypeA
|
||||
|
||||
"""
|
||||
if "dispatch_type" in kwargs:
|
||||
disp_type = kwargs.pop("dispatch_type")
|
||||
dispatchables = tuple(
|
||||
d if isinstance(d, Dispatchable) else Dispatchable(d, disp_type)
|
||||
for d in dispatchables
|
||||
)
|
||||
else:
|
||||
dispatchables = tuple(dispatchables)
|
||||
if not all(isinstance(d, Dispatchable) for d in dispatchables):
|
||||
raise TypeError("dispatchables must be instances of uarray.Dispatchable")
|
||||
|
||||
if len(kwargs) != 0:
|
||||
raise TypeError(f"Received unexpected keyword arguments: {kwargs}")
|
||||
|
||||
backend = _uarray.determine_backend(domain, dispatchables, coerce)
|
||||
|
||||
return set_backend(backend, coerce=coerce, only=only)
|
||||
954
venv/lib/python3.12/site-packages/scipy/_lib/_util.py
Normal file
@@ -0,0 +1,954 @@
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
import functools
|
||||
import operator
|
||||
import warnings
|
||||
import numbers
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
import math
|
||||
from typing import (
|
||||
Optional,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from scipy._lib._array_api import array_namespace, is_numpy, size as xp_size
|
||||
|
||||
|
||||
AxisError: type[Exception]
|
||||
ComplexWarning: type[Warning]
|
||||
VisibleDeprecationWarning: type[Warning]
|
||||
|
||||
if np.lib.NumpyVersion(np.__version__) >= '1.25.0':
|
||||
from numpy.exceptions import (
|
||||
AxisError, ComplexWarning, VisibleDeprecationWarning,
|
||||
DTypePromotionError
|
||||
)
|
||||
else:
|
||||
from numpy import ( # type: ignore[attr-defined, no-redef]
|
||||
AxisError, ComplexWarning, VisibleDeprecationWarning # noqa: F401
|
||||
)
|
||||
DTypePromotionError = TypeError # type: ignore
|
||||
|
||||
np_long: type
|
||||
np_ulong: type
|
||||
|
||||
if np.lib.NumpyVersion(np.__version__) >= "2.0.0.dev0":
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
r".*In the future `np\.long` will be defined as.*",
|
||||
FutureWarning,
|
||||
)
|
||||
np_long = np.long # type: ignore[attr-defined]
|
||||
np_ulong = np.ulong # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
np_long = np.int_
|
||||
np_ulong = np.uint
|
||||
else:
|
||||
np_long = np.int_
|
||||
np_ulong = np.uint
|
||||
|
||||
IntNumber = Union[int, np.integer]
|
||||
DecimalNumber = Union[float, np.floating, np.integer]
|
||||
|
||||
copy_if_needed: Optional[bool]
|
||||
|
||||
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
|
||||
copy_if_needed = None
|
||||
elif np.lib.NumpyVersion(np.__version__) < "1.28.0":
|
||||
copy_if_needed = False
|
||||
else:
|
||||
# 2.0.0 dev versions, handle cases where copy may or may not exist
|
||||
try:
|
||||
np.array([1]).__array__(copy=None) # type: ignore[call-overload]
|
||||
copy_if_needed = None
|
||||
except TypeError:
|
||||
copy_if_needed = False
|
||||
|
||||
# Since Generator was introduced in numpy 1.17, the following condition is needed for
|
||||
# backward compatibility
|
||||
if TYPE_CHECKING:
|
||||
SeedType = Optional[Union[IntNumber, np.random.Generator,
|
||||
np.random.RandomState]]
|
||||
GeneratorType = TypeVar("GeneratorType", bound=Union[np.random.Generator,
|
||||
np.random.RandomState])
|
||||
|
||||
try:
|
||||
from numpy.random import Generator as Generator
|
||||
except ImportError:
|
||||
class Generator: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
|
||||
def _lazywhere(cond, arrays, f, fillvalue=None, f2=None):
|
||||
"""Return elements chosen from two possibilities depending on a condition
|
||||
|
||||
Equivalent to ``f(*arrays) if cond else fillvalue`` performed elementwise.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cond : array
|
||||
The condition (expressed as a boolean array).
|
||||
arrays : tuple of array
|
||||
Arguments to `f` (and `f2`). Must be broadcastable with `cond`.
|
||||
f : callable
|
||||
Where `cond` is True, output will be ``f(arr1[cond], arr2[cond], ...)``
|
||||
fillvalue : object
|
||||
If provided, value with which to fill output array where `cond` is
|
||||
not True.
|
||||
f2 : callable
|
||||
If provided, output will be ``f2(arr1[cond], arr2[cond], ...)`` where
|
||||
`cond` is not True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : array
|
||||
An array with elements from the output of `f` where `cond` is True
|
||||
and `fillvalue` (or elements from the output of `f2`) elsewhere. The
|
||||
returned array has data type determined by Type Promotion Rules
|
||||
with the output of `f` and `fillvalue` (or the output of `f2`).
|
||||
|
||||
Notes
|
||||
-----
|
||||
    ``xp.where(cond, x, fillvalue)`` requires explicitly forming `x` even where
    `cond` is False. This function evaluates ``f(arr1[cond], arr2[cond], ...)``
    only where `cond` is True.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8])
|
||||
>>> def f(a, b):
|
||||
... return a*b
|
||||
>>> _lazywhere(a > 2, (a, b), f, np.nan)
|
||||
array([ nan, nan, 21., 32.])
|
||||
|
||||
"""
|
||||
xp = array_namespace(cond, *arrays)
|
||||
|
||||
if (f2 is fillvalue is None) or (f2 is not None and fillvalue is not None):
|
||||
raise ValueError("Exactly one of `fillvalue` or `f2` must be given.")
|
||||
|
||||
args = xp.broadcast_arrays(cond, *arrays)
|
||||
bool_dtype = xp.asarray([True]).dtype # numpy 1.xx doesn't have `bool`
|
||||
cond, arrays = xp.astype(args[0], bool_dtype, copy=False), args[1:]
|
||||
|
||||
temp1 = xp.asarray(f(*(arr[cond] for arr in arrays)))
|
||||
|
||||
if f2 is None:
|
||||
fillvalue = xp.asarray(fillvalue)
|
||||
dtype = xp.result_type(temp1.dtype, fillvalue.dtype)
|
||||
out = xp.full(cond.shape, fill_value=fillvalue, dtype=dtype)
|
||||
else:
|
||||
ncond = ~cond
|
||||
temp2 = xp.asarray(f2(*(arr[ncond] for arr in arrays)))
|
||||
dtype = xp.result_type(temp1, temp2)
|
||||
out = xp.empty(cond.shape, dtype=dtype)
|
||||
out[ncond] = temp2
|
||||
|
||||
out[cond] = temp1
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _lazyselect(condlist, choicelist, arrays, default=0):
|
||||
"""
|
||||
Mimic `np.select(condlist, choicelist)`.
|
||||
|
||||
    Notice that it assumes all `arrays` are of the same shape, or can be
    broadcast together.
|
||||
|
||||
All functions in `choicelist` must accept array arguments in the order
|
||||
given in `arrays` and must return an array of the same shape as broadcasted
|
||||
`arrays`.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> x = np.arange(6)
|
||||
    >>> np.select([x < 3, x > 3], [x**2, x**3], default=0)
|
||||
array([ 0, 1, 4, 0, 64, 125])
|
||||
|
||||
>>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,))
|
||||
array([ 0., 1., 4., 0., 64., 125.])
|
||||
|
||||
>>> a = -np.ones_like(x)
|
||||
>>> _lazyselect([x < 3, x > 3],
|
||||
... [lambda x, a: x**2, lambda x, a: a * x**3],
|
||||
... (x, a), default=np.nan)
|
||||
array([ 0., 1., 4., nan, -64., -125.])
|
||||
|
||||
"""
|
||||
arrays = np.broadcast_arrays(*arrays)
|
||||
tcode = np.mintypecode([a.dtype.char for a in arrays])
|
||||
out = np.full(np.shape(arrays[0]), fill_value=default, dtype=tcode)
|
||||
for func, cond in zip(choicelist, condlist):
|
||||
if np.all(cond is False):
|
||||
continue
|
||||
cond, _ = np.broadcast_arrays(cond, arrays[0])
|
||||
temp = tuple(np.extract(cond, arr) for arr in arrays)
|
||||
np.place(out, cond, func(*temp))
|
||||
return out
|
||||
|
||||
|
||||
def _aligned_zeros(shape, dtype=float, order="C", align=None):
|
||||
"""Allocate a new ndarray with aligned memory.
|
||||
|
||||
Primary use case for this currently is working around a f2py issue
|
||||
in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does
|
||||
not necessarily create arrays aligned up to it.
|
||||
|
||||
"""
|
||||
dtype = np.dtype(dtype)
|
||||
if align is None:
|
||||
align = dtype.alignment
|
||||
if not hasattr(shape, '__len__'):
|
||||
shape = (shape,)
|
||||
size = functools.reduce(operator.mul, shape) * dtype.itemsize
|
||||
buf = np.empty(size + align + 1, np.uint8)
|
||||
offset = buf.__array_interface__['data'][0] % align
|
||||
if offset != 0:
|
||||
offset = align - offset
|
||||
# Note: slices producing 0-size arrays do not necessarily change
|
||||
# data pointer --- so we use and allocate size+1
|
||||
buf = buf[offset:offset+size+1][:-1]
|
||||
data = np.ndarray(shape, dtype, buf, order=order)
|
||||
data.fill(0)
|
||||
return data
|
||||
|
||||
|
||||
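# --- illustration (added; not part of the original file) ---------------
#   a = _aligned_zeros((3, 4), dtype=np.float64, align=64)
#   assert a.ctypes.data % 64 == 0    # data pointer is 64-byte aligned
#   assert a.shape == (3, 4) and not a.any()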
def _prune_array(array):
|
||||
"""Return an array equivalent to the input array. If the input
|
||||
array is a view of a much larger array, copy its contents to a
|
||||
newly allocated array. Otherwise, return the input unchanged.
|
||||
"""
|
||||
if array.base is not None and array.size < array.base.size // 2:
|
||||
return array.copy()
|
||||
return array
|
||||
|
||||
|
||||
def float_factorial(n: int) -> float:
|
||||
"""Compute the factorial and return as a float
|
||||
|
||||
Returns infinity when result is too large for a double
|
||||
"""
|
||||
return float(math.factorial(n)) if n < 171 else np.inf
|
||||
|
||||
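# --- illustration (added; not part of the original file) ---------------
#   float_factorial(170)   # about 7.26e306, the largest finite result
#   float_factorial(171)   # np.inf, since 171! overflows a double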
|
||||
# copy-pasted from scikit-learn utils/validation.py
|
||||
# change this to scipy.stats._qmc.check_random_state once numpy 1.16 is dropped
|
||||
def check_random_state(seed):
|
||||
"""Turn `seed` into a `np.random.RandomState` instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
|
||||
If `seed` is None (or `np.random`), the `numpy.random.RandomState`
|
||||
singleton is used.
|
||||
If `seed` is an int, a new ``RandomState`` instance is used,
|
||||
seeded with `seed`.
|
||||
If `seed` is already a ``Generator`` or ``RandomState`` instance then
|
||||
that instance is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
seed : {`numpy.random.Generator`, `numpy.random.RandomState`}
|
||||
Random number generator.
|
||||
|
||||
"""
|
||||
if seed is None or seed is np.random:
|
||||
return np.random.mtrand._rand
|
||||
if isinstance(seed, (numbers.Integral, np.integer)):
|
||||
return np.random.RandomState(seed)
|
||||
if isinstance(seed, (np.random.RandomState, np.random.Generator)):
|
||||
return seed
|
||||
|
||||
raise ValueError(f"'{seed}' cannot be used to seed a numpy.random.RandomState"
|
||||
" instance")
|
||||
|
||||
|
||||
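# --- illustration (added; not part of the original file) ---------------
#   check_random_state(None)          # the global RandomState singleton
#   check_random_state(42)            # a RandomState seeded with 42
#   rng = np.random.default_rng()
#   check_random_state(rng) is rng    # Generator instances pass through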
def _asarray_validated(a, check_finite=True,
|
||||
sparse_ok=False, objects_ok=False, mask_ok=False,
|
||||
as_inexact=False):
|
||||
"""
|
||||
Helper function for SciPy argument validation.
|
||||
|
||||
Many SciPy linear algebra functions do support arbitrary array-like
|
||||
input arguments. Examples of commonly unsupported inputs include
|
||||
matrices containing inf/nan, sparse matrix representations, and
|
||||
matrices with complicated elements.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : array_like
|
||||
The array-like input.
|
||||
check_finite : bool, optional
|
||||
Whether to check that the input matrices contain only finite numbers.
|
||||
Disabling may give a performance gain, but may result in problems
|
||||
(crashes, non-termination) if the inputs do contain infinities or NaNs.
|
||||
Default: True
|
||||
sparse_ok : bool, optional
|
||||
True if scipy sparse matrices are allowed.
|
||||
objects_ok : bool, optional
|
||||
        True if arrays with dtype('O') are allowed.
|
||||
mask_ok : bool, optional
|
||||
True if masked arrays are allowed.
|
||||
as_inexact : bool, optional
|
||||
True to convert the input array to a np.inexact dtype.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ret : ndarray
|
||||
The converted validated array.
|
||||
|
||||
"""
|
||||
if not sparse_ok:
|
||||
import scipy.sparse
|
||||
if scipy.sparse.issparse(a):
|
||||
msg = ('Sparse matrices are not supported by this function. '
|
||||
'Perhaps one of the scipy.sparse.linalg functions '
|
||||
'would work instead.')
|
||||
raise ValueError(msg)
|
||||
if not mask_ok:
|
||||
if np.ma.isMaskedArray(a):
|
||||
raise ValueError('masked arrays are not supported')
|
||||
toarray = np.asarray_chkfinite if check_finite else np.asarray
|
||||
a = toarray(a)
|
||||
if not objects_ok:
|
||||
if a.dtype is np.dtype('O'):
|
||||
raise ValueError('object arrays are not supported')
|
||||
if as_inexact:
|
||||
if not np.issubdtype(a.dtype, np.inexact):
|
||||
a = toarray(a, dtype=np.float64)
|
||||
return a
|
||||
|
||||
|
||||
def _validate_int(k, name, minimum=None):
|
||||
"""
|
||||
Validate a scalar integer.
|
||||
|
||||
This function can be used to validate an argument to a function
|
||||
that expects the value to be an integer. It uses `operator.index`
|
||||
to validate the value (so, for example, k=2.0 results in a
|
||||
TypeError).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
k : int
|
||||
The value to be validated.
|
||||
name : str
|
||||
The name of the parameter.
|
||||
minimum : int, optional
|
||||
An optional lower bound.
|
||||
"""
|
||||
try:
|
||||
k = operator.index(k)
|
||||
except TypeError:
|
||||
raise TypeError(f'{name} must be an integer.') from None
|
||||
if minimum is not None and k < minimum:
|
||||
raise ValueError(f'{name} must be an integer not less '
|
||||
f'than {minimum}') from None
|
||||
return k
|
||||
|
||||
|
||||
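# --- illustration (added; not part of the original file) ---------------
#   _validate_int(3, 'n')             -> 3
#   _validate_int(np.int64(3), 'n')   -> 3   (operator.index unwraps it)
#   _validate_int(2.0, 'n')           -> TypeError: n must be an integer.
#   _validate_int(1, 'n', minimum=2)  -> ValueError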
# Add a replacement for inspect.getfullargspec().
|
||||
# The version below is borrowed from Django,
|
||||
# https://github.com/django/django/pull/4846.
|
||||
|
||||
# Note an inconsistency between inspect.getfullargspec(func) and
|
||||
# inspect.signature(func). If `func` is a bound method, the latter does *not*
|
||||
# list `self` as a first argument, while the former *does*.
|
||||
# Hence, cook up a common ground replacement: `getfullargspec_no_self` which
|
||||
# mimics `inspect.getfullargspec` but does not list `self`.
|
||||
#
|
||||
# This way, the caller code does not need to know whether it uses a legacy
|
||||
# .getfullargspec or a bright and shiny .signature.
|
||||
|
||||
FullArgSpec = namedtuple('FullArgSpec',
|
||||
['args', 'varargs', 'varkw', 'defaults',
|
||||
'kwonlyargs', 'kwonlydefaults', 'annotations'])


def getfullargspec_no_self(func):
    """inspect.getfullargspec replacement using inspect.signature.

    If func is a bound method, do not list the 'self' parameter.

    Parameters
    ----------
    func : callable
        A callable to inspect

    Returns
    -------
    fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                              kwonlydefaults, annotations)

        NOTE: if the first argument of `func` is self, it is *not*, I repeat
        *not*, included in fullargspec.args.
        This is done for consistency between inspect.getargspec() under
        Python 2.x, and inspect.signature() under Python 3.x.

    """
    sig = inspect.signature(func)
    args = [
        p.name for p in sig.parameters.values()
        if p.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.POSITIONAL_ONLY]
    ]
    varargs = [
        p.name for p in sig.parameters.values()
        if p.kind == inspect.Parameter.VAR_POSITIONAL
    ]
    varargs = varargs[0] if varargs else None
    varkw = [
        p.name for p in sig.parameters.values()
        if p.kind == inspect.Parameter.VAR_KEYWORD
    ]
    varkw = varkw[0] if varkw else None
    defaults = tuple(
        p.default for p in sig.parameters.values()
        if (p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and
            p.default is not p.empty)
    ) or None
    kwonlyargs = [
        p.name for p in sig.parameters.values()
        if p.kind == inspect.Parameter.KEYWORD_ONLY
    ]
    kwdefaults = {p.name: p.default for p in sig.parameters.values()
                  if p.kind == inspect.Parameter.KEYWORD_ONLY and
                  p.default is not p.empty}
    annotations = {p.name: p.annotation for p in sig.parameters.values()
                   if p.annotation is not p.empty}
    return FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                       kwdefaults or None, annotations)
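
# Example of the bound-method behavior described above (illustrative sketch):
#
#     >>> class A:
#     ...     def m(self, x, y=1):
#     ...         pass
#     >>> getfullargspec_no_self(A.m).args    # plain function: 'self' listed
#     ['self', 'x', 'y']
#     >>> getfullargspec_no_self(A().m).args  # bound method: 'self' dropped
#     ['x', 'y']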


class _FunctionWrapper:
    """
    Object to wrap user's function, allowing picklability
    """
    def __init__(self, f, args):
        self.f = f
        self.args = [] if args is None else args

    def __call__(self, x):
        return self.f(x, *self.args)


class MapWrapper:
    """
    Parallelisation wrapper for working with map-like callables, such as
    `multiprocessing.Pool.map`.

    Parameters
    ----------
    pool : int or map-like callable
        If `pool` is an integer, then it specifies the number of processes to
        use for parallelization. If ``int(pool) == 1``, then no parallel
        processing is used and the map builtin is used.
        If ``pool == -1``, then the pool will utilize all available CPUs.
        If `pool` is a map-like callable that follows the same
        calling sequence as the built-in map function, then this callable is
        used for parallelization.
    """
    def __init__(self, pool=1):
        self.pool = None
        self._mapfunc = map
        self._own_pool = False

        if callable(pool):
            self.pool = pool
            self._mapfunc = self.pool
        else:
            from multiprocessing import Pool
            # user supplies a number
            if int(pool) == -1:
                # use as many processors as possible
                self.pool = Pool()
                self._mapfunc = self.pool.map
                self._own_pool = True
            elif int(pool) == 1:
                pass
            elif int(pool) > 1:
                # use the number of processors requested
                self.pool = Pool(processes=int(pool))
                self._mapfunc = self.pool.map
                self._own_pool = True
            else:
                raise RuntimeError("Number of workers specified must be -1,"
                                   " an int >= 1, or an object with a 'map' "
                                   "method")

    def __enter__(self):
        return self

    def terminate(self):
        if self._own_pool:
            self.pool.terminate()

    def join(self):
        if self._own_pool:
            self.pool.join()

    def close(self):
        if self._own_pool:
            self.pool.close()

    def __exit__(self, exc_type, exc_value, traceback):
        if self._own_pool:
            self.pool.close()
            self.pool.terminate()

    def __call__(self, func, iterable):
        # only accept one iterable because that's all Pool.map accepts
        try:
            return self._mapfunc(func, iterable)
        except TypeError as e:
            # wrong number of arguments
            raise TypeError("The map-like callable must be of the"
                            " form f(func, iterable)") from e


def rng_integers(gen, low, high=None, size=None, dtype='int64',
                 endpoint=False):
    """
    Return random integers from low (inclusive) to high (exclusive), or if
    endpoint=True, low (inclusive) to high (inclusive). Replaces
    `RandomState.randint` (with endpoint=False) and
    `RandomState.random_integers` (with endpoint=True).

    Return random integers from the "discrete uniform" distribution of the
    specified dtype. If high is None (the default), then results are from
    0 to low.

    Parameters
    ----------
    gen : {None, np.random.RandomState, np.random.Generator}
        Random number generator. If None, then the np.random.RandomState
        singleton is used.
    low : int or array-like of ints
        Lowest (signed) integers to be drawn from the distribution (unless
        high=None, in which case this parameter is 0 and this value is used
        for high).
    high : int or array-like of ints
        If provided, one above the largest (signed) integer to be drawn from
        the distribution (see above for behavior if high=None). If array-like,
        must contain integer values.
    size : array-like of ints, optional
        Output shape. If the given shape is, e.g., (m, n, k), then m * n * k
        samples are drawn. Default is None, in which case a single value is
        returned.
    dtype : {str, dtype}, optional
        Desired dtype of the result. All dtypes are determined by their name,
        i.e., 'int64', 'int', etc., so byteorder is not available and a
        specific precision may have different C types depending on the
        platform. The default value is 'int64'.
    endpoint : bool, optional
        If True, sample from the interval [low, high] instead of the default
        [low, high). Defaults to False.

    Returns
    -------
    out : int or ndarray of ints
        size-shaped array of random integers from the appropriate
        distribution, or a single such random int if size not provided.
    """
    if isinstance(gen, Generator):
        return gen.integers(low, high=high, size=size, dtype=dtype,
                            endpoint=endpoint)
    else:
        if gen is None:
            # default is RandomState singleton used by np.random.
            gen = np.random.mtrand._rand
        if endpoint:
            # inclusive of endpoint
            # remember that low and high can be arrays, so don't modify in
            # place
            if high is None:
                return gen.randint(low + 1, size=size, dtype=dtype)
            return gen.randint(low, high=high + 1, size=size, dtype=dtype)

        # exclusive
        return gen.randint(low, high=high, size=size, dtype=dtype)
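
# Example (illustrative): the same call works for both RNG flavors, with
# ``endpoint=True`` making the upper bound inclusive.
#
#     >>> rng_integers(np.random.default_rng(), 1, 6, size=3, endpoint=True).shape
#     (3,)
#     >>> rng_integers(np.random.RandomState(0), 1, 6, size=3, endpoint=True).shape
#     (3,)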


@contextmanager
def _fixed_default_rng(seed=1638083107694713882823079058616272161):
    """Context with a fixed np.random.default_rng seed."""
    orig_fun = np.random.default_rng
    np.random.default_rng = lambda seed=seed: orig_fun(seed)
    try:
        yield
    finally:
        np.random.default_rng = orig_fun
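
# Example (illustrative): inside the context every ``default_rng()`` call is
# seeded identically, so repeated draws match.
#
#     with _fixed_default_rng():
#         a = np.random.default_rng().random(3)
#         b = np.random.default_rng().random(3)
#     # a and b are equal because both generators received the same seed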


def _rng_html_rewrite(func):
    """Rewrite the HTML rendering of ``np.random.default_rng``.

    This is intended to decorate
    ``numpydoc.docscrape_sphinx.SphinxDocString._str_examples``.

    Examples are only run by Sphinx when there are plots involved. Even so,
    this rewrite does not change the result values that get printed.
    """
    # hexadecimal or number seed, case-insensitive
    pattern = re.compile(r'np.random.default_rng\((0x[0-9A-F]+|\d+)\)', re.I)

    def _wrapped(*args, **kwargs):
        res = func(*args, **kwargs)
        lines = [
            re.sub(pattern, 'np.random.default_rng()', line)
            for line in res
        ]
        return lines

    return _wrapped


def _argmin(a, keepdims=False, axis=None):
    """
    argmin with a `keepdims` parameter.

    See https://github.com/numpy/numpy/issues/8710

    If axis is not None, a.shape[axis] must be greater than 0.
    """
    res = np.argmin(a, axis=axis)
    if keepdims and axis is not None:
        res = np.expand_dims(res, axis=axis)
    return res


def _first_nonnan(a, axis):
    """
    Return the first non-nan value along the given axis.

    If a slice is all nan, nan is returned for that slice.

    The shape of the return value corresponds to ``keepdims=True``.

    Examples
    --------
    >>> import numpy as np
    >>> nan = np.nan
    >>> a = np.array([[ 3.,  3., nan,  3.],
    ...               [ 1., nan,  2.,  4.],
    ...               [nan, nan,  9., -1.],
    ...               [nan,  5.,  4.,  3.],
    ...               [ 2.,  2.,  2.,  2.],
    ...               [nan, nan, nan, nan]])
    >>> _first_nonnan(a, axis=0)
    array([[3., 3., 2., 3.]])
    >>> _first_nonnan(a, axis=1)
    array([[ 3.],
           [ 1.],
           [ 9.],
           [ 5.],
           [ 2.],
           [nan]])
    """
    k = _argmin(np.isnan(a), axis=axis, keepdims=True)
    return np.take_along_axis(a, k, axis=axis)


def _nan_allsame(a, axis, keepdims=False):
    """
    Determine if the values along an axis are all the same.

    nan values are ignored.

    `a` must be a numpy array.

    `axis` is assumed to be normalized; that is, 0 <= axis < a.ndim.

    For an axis of length 0, the result is True. That is, we adopt the
    convention that ``allsame([])`` is True. (There are no values in the
    input that are different.)

    `True` is returned for slices that are all nan--not because all the
    values are the same, but because this is equivalent to ``allsame([])``.

    Examples
    --------
    >>> from numpy import nan, array
    >>> a = array([[ 3.,  3., nan,  3.],
    ...            [ 1., nan,  2.,  4.],
    ...            [nan, nan,  9., -1.],
    ...            [nan,  5.,  4.,  3.],
    ...            [ 2.,  2.,  2.,  2.],
    ...            [nan, nan, nan, nan]])
    >>> _nan_allsame(a, axis=1, keepdims=True)
    array([[ True],
           [False],
           [False],
           [False],
           [ True],
           [ True]])
    """
    if axis is None:
        if a.size == 0:
            return True
        a = a.ravel()
        axis = 0
    else:
        shp = a.shape
        if shp[axis] == 0:
            shp = shp[:axis] + (1,)*keepdims + shp[axis + 1:]
            return np.full(shp, fill_value=True, dtype=bool)
    a0 = _first_nonnan(a, axis=axis)
    return ((a0 == a) | np.isnan(a)).all(axis=axis, keepdims=keepdims)


def _contains_nan(a, nan_policy='propagate', policies=None, *, xp=None):
    if xp is None:
        xp = array_namespace(a)
    not_numpy = not is_numpy(xp)

    if policies is None:
        policies = {'propagate', 'raise', 'omit'}
    if nan_policy not in policies:
        raise ValueError(f"nan_policy must be one of {set(policies)}.")

    inexact = (xp.isdtype(a.dtype, "real floating")
               or xp.isdtype(a.dtype, "complex floating"))
    if xp_size(a) == 0:
        contains_nan = False
    elif inexact:
        # Faster and less memory-intensive than xp.any(xp.isnan(a))
        contains_nan = xp.isnan(xp.max(a))
    elif is_numpy(xp) and np.issubdtype(a.dtype, object):
        contains_nan = False
        for el in a.ravel():
            # isnan doesn't work on non-numeric elements
            if np.issubdtype(type(el), np.number) and np.isnan(el):
                contains_nan = True
                break
    else:
        # Only `object` and `inexact` arrays can have NaNs
        contains_nan = False

    if contains_nan and nan_policy == 'raise':
        raise ValueError("The input contains nan values")

    if not_numpy and contains_nan and nan_policy == 'omit':
        message = "`nan_policy='omit'` is incompatible with non-NumPy arrays."
        raise ValueError(message)

    return contains_nan, nan_policy


def _rename_parameter(old_name, new_name, dep_version=None):
    """
    Generate decorator for backward-compatible keyword renaming.

    Apply the decorator generated by `_rename_parameter` to functions with a
    recently renamed parameter to maintain backward-compatibility.

    After decoration, the function behaves as follows:
    If only the new parameter is passed into the function, behave as usual.
    If only the old parameter is passed into the function (as a keyword),
    raise a DeprecationWarning if `dep_version` is provided, and behave as
    usual otherwise.
    If both old and new parameters are passed into the function, raise a
    DeprecationWarning if `dep_version` is provided, and raise the appropriate
    TypeError (function got multiple values for argument).

    Parameters
    ----------
    old_name : str
        Old name of parameter
    new_name : str
        New name of parameter
    dep_version : str, optional
        Version of SciPy in which old parameter was deprecated in the format
        'X.Y.Z'. If supplied, the deprecation message will indicate that
        support for the old parameter will be removed in version 'X.Y+2.Z'.

    Notes
    -----
    Untested with functions that accept *args. Probably won't work as written.

    """
    def decorator(fun):
        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                if dep_version:
                    end_version = dep_version.split('.')
                    end_version[1] = str(int(end_version[1]) + 2)
                    end_version = '.'.join(end_version)
                    message = (f"Use of keyword argument `{old_name}` is "
                               f"deprecated and replaced by `{new_name}`. "
                               f"Support for `{old_name}` will be removed "
                               f"in SciPy {end_version}.")
                    warnings.warn(message, DeprecationWarning, stacklevel=2)
                if new_name in kwargs:
                    message = (f"{fun.__name__}() got multiple values for "
                               f"argument now known as `{new_name}`")
                    raise TypeError(message)
                kwargs[new_name] = kwargs.pop(old_name)
            return fun(*args, **kwargs)
        return wrapper
    return decorator
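
# Example (illustrative, with a hypothetical function): renaming ``alpha``
# to ``confidence`` while still accepting the old keyword.
#
#     @_rename_parameter('alpha', 'confidence', dep_version='1.9.0')
#     def interval(confidence):
#         return confidence
#
#     interval(confidence=0.95)  # behaves as usual
#     interval(alpha=0.95)       # DeprecationWarning: removed in SciPy 1.11.0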


def _rng_spawn(rng, n_children):
    # spawns independent RNGs from a parent RNG
    bg = rng._bit_generator
    ss = bg._seed_seq
    child_rngs = [np.random.Generator(type(bg)(child_ss))
                  for child_ss in ss.spawn(n_children)]
    return child_rngs
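
# Example (illustrative): spawn three statistically independent child
# generators from one parent, e.g. for parallel workers.
#
#     parent = np.random.default_rng(42)
#     children = _rng_spawn(parent, 3)
#     draws = [rng.random() for rng in children]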


def _get_nan(*data, xp=None):
    xp = array_namespace(*data) if xp is None else xp
    # Get NaN of appropriate dtype for data
    data = [xp.asarray(item) for item in data]
    try:
        min_float = getattr(xp, 'float16', xp.float32)
        dtype = xp.result_type(*data, min_float)  # must be at least a float
    except DTypePromotionError:
        # fallback to float64
        dtype = xp.float64
    return xp.asarray(xp.nan, dtype=dtype)[()]


def normalize_axis_index(axis, ndim):
    # Check if `axis` is in the correct range and normalize it
    if axis < -ndim or axis >= ndim:
        msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
        raise AxisError(msg)

    if axis < 0:
        axis = axis + ndim
    return axis


def _call_callback_maybe_halt(callback, res):
    """Call wrapped callback; return True if algorithm should stop.

    Parameters
    ----------
    callback : callable or None
        A user-provided callback wrapped with `_wrap_callback`
    res : OptimizeResult
        Information about the current iterate

    Returns
    -------
    halt : bool
        True if minimization should stop

    """
    if callback is None:
        return False
    try:
        callback(res)
        return False
    except StopIteration:
        callback.stop_iteration = True
        return True


class _RichResult(dict):
    """ Container for multiple outputs with pretty-printing """
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e

    __setattr__ = dict.__setitem__  # type: ignore[assignment]
    __delattr__ = dict.__delitem__  # type: ignore[assignment]

    def __repr__(self):
        order_keys = ['message', 'success', 'status', 'fun', 'funl', 'x', 'xl',
                      'col_ind', 'nit', 'lower', 'upper', 'eqlin', 'ineqlin',
                      'converged', 'flag', 'function_calls', 'iterations',
                      'root']
        order_keys = getattr(self, '_order_keys', order_keys)
        # 'slack', 'con' are redundant with residuals
        # 'crossover_nit' is probably not interesting to most users
        omit_keys = {'slack', 'con', 'crossover_nit', '_order_keys'}

        def key(item):
            try:
                return order_keys.index(item[0].lower())
            except ValueError:  # item not in list
                return np.inf

        def omit_redundant(items):
            for item in items:
                if item[0] in omit_keys:
                    continue
                yield item

        def item_sorter(d):
            return sorted(omit_redundant(d.items()), key=key)

        if self.keys():
            return _dict_formatter(self, sorter=item_sorter)
        else:
            return self.__class__.__name__ + "()"

    def __dir__(self):
        return list(self.keys())
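
# Example (illustrative): keys are readable as attributes, and known keys are
# ordered first in the repr.
#
#     res = _RichResult(message='done', success=True, nit=3)
#     res.success      # True, same as res['success']
#     repr(res)        # 'message', 'success', ... rendered via _dict_formatter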


def _indenter(s, n=0):
    """
    Ensures that lines after the first are indented by the specified amount
    """
    split = s.split("\n")
    indent = " "*n
    return ("\n" + indent).join(split)


def _float_formatter_10(x):
    """
    Returns a string representation of a float with exactly ten characters
    """
    if np.isposinf(x):
        return "       inf"
    elif np.isneginf(x):
        return "      -inf"
    elif np.isnan(x):
        return "       nan"
    return np.format_float_scientific(x, precision=3, pad_left=2, unique=False)


def _dict_formatter(d, n=0, mplus=1, sorter=None):
    """
    Pretty printer for dictionaries

    `n` keeps track of the starting indentation;
    lines are indented by this much after a line break.
    `mplus` is additional left padding applied to keys
    """
    if isinstance(d, dict):
        m = max(map(len, list(d.keys()))) + mplus  # width to print keys
        s = '\n'.join([k.rjust(m) + ': ' +  # right justified, width m
                       _indenter(_dict_formatter(v, m+n+2, 0, sorter), m+2)
                       for k, v in sorter(d)])  # +2 for ': '
    else:
        # By default, NumPy arrays print with linewidth=76. `n` is
        # the indent at which a line begins printing, so it is subtracted
        # from the default to avoid exceeding 76 characters total.
        # `edgeitems` is the number of elements to include before and after
        # ellipses when arrays are not shown in full.
        # `threshold` is the maximum number of elements for which an
        # array is shown in full.
        # These values tend to work well for use with OptimizeResult.
        with np.printoptions(linewidth=76-n, edgeitems=2, threshold=12,
                             formatter={'float_kind': _float_formatter_10}):
            s = str(d)
    return s
@ -0,0 +1,22 @@
"""
NumPy Array API compatibility library

This is a small wrapper around NumPy and CuPy that is compatible with the
Array API standard https://data-apis.org/array-api/latest/. See also NEP 47
https://numpy.org/neps/nep-0047-array-api-standard.html.

Unlike array_api_strict, this is not a strict minimal implementation of the
Array API, but rather just an extension of the main NumPy namespace with
changes needed to be compliant with the Array API. See
https://numpy.org/doc/stable/reference/array_api.html for a full list of
changes. In particular, unlike array_api_strict, this package does not use a
separate Array object, but rather just uses numpy.ndarray directly.

Library authors using the Array API may wish to test against array_api_strict
to ensure they are not using functionality outside of the standard, but prefer
this implementation for the default when working with NumPy arrays.

"""
__version__ = '1.5.1'

from .common import *  # noqa: F401, F403
@ -0,0 +1,46 @@
"""
Internal helpers
"""

from functools import wraps
from inspect import signature


def get_xp(xp):
    """
    Decorator to automatically replace xp with the corresponding array module.

    Use like

    import numpy as np

    @get_xp(np)
    def func(x, /, xp, kwarg=None):
        return xp.func(x, kwarg=kwarg)

    Note that xp must be a keyword argument and come after all non-keyword
    arguments.

    """

    def inner(f):
        @wraps(f)
        def wrapped_f(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)

        sig = signature(f)
        new_sig = sig.replace(
            parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
        )

        if wrapped_f.__doc__ is None:
            wrapped_f.__doc__ = f"""\
Array API compatibility wrapper for {f.__name__}.

See the corresponding documentation in NumPy/CuPy and/or the array API
specification for more details.

"""
        wrapped_f.__signature__ = new_sig
        return wrapped_f

    return inner
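
# Example (illustrative): a hypothetical helper bound to NumPy; callers omit
# the ``xp`` argument, which the decorator injects.
#
#     import numpy as np
#
#     @get_xp(np)
#     def double(x, /, xp):
#         return xp.multiply(x, 2)
#
#     double(np.arange(3))  # array([0, 2, 4])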
@ -0,0 +1 @@
from ._helpers import *  # noqa: F403
@ -0,0 +1,554 @@
"""
These are functions that are just aliases of existing functions in NumPy.
"""

from __future__ import annotations

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    import numpy as np
    from typing import Optional, Sequence, Tuple, Union
    from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol

from typing import NamedTuple
from types import ModuleType
import inspect

from ._helpers import _check_device, is_numpy_array, array_namespace

# These functions are modified from the NumPy versions.

def arange(
    start: Union[int, float],
    /,
    stop: Optional[Union[int, float]] = None,
    step: Union[int, float] = 1,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    _check_device(xp, device)
    return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)

def empty(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    _check_device(xp, device)
    return xp.empty(shape, dtype=dtype, **kwargs)

def empty_like(
    x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
    **kwargs
) -> ndarray:
    _check_device(xp, device)
    return xp.empty_like(x, dtype=dtype, **kwargs)

def eye(
    n_rows: int,
    n_cols: Optional[int] = None,
    /,
    *,
    xp,
    k: int = 0,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    _check_device(xp, device)
    return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)

def full(
    shape: Union[int, Tuple[int, ...]],
    fill_value: Union[int, float],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    _check_device(xp, device)
    return xp.full(shape, fill_value, dtype=dtype, **kwargs)

def full_like(
    x: ndarray,
    /,
    fill_value: Union[int, float],
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    _check_device(xp, device)
    return xp.full_like(x, fill_value, dtype=dtype, **kwargs)

def linspace(
    start: Union[int, float],
    stop: Union[int, float],
    /,
    num: int,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    endpoint: bool = True,
    **kwargs,
) -> ndarray:
    _check_device(xp, device)
    return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)

def ones(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    _check_device(xp, device)
    return xp.ones(shape, dtype=dtype, **kwargs)

def ones_like(
    x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    _check_device(xp, device)
    return xp.ones_like(x, dtype=dtype, **kwargs)

def zeros(
    shape: Union[int, Tuple[int, ...]],
    xp,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    _check_device(xp, device)
    return xp.zeros(shape, dtype=dtype, **kwargs)

def zeros_like(
    x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None,
    **kwargs,
) -> ndarray:
    _check_device(xp, device)
    return xp.zeros_like(x, dtype=dtype, **kwargs)

# np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
# to remove polymorphic return types).

# The functions here return namedtuples (np.unique() returns a normal
# tuple).

# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
class UniqueAllResult(NamedTuple):
    values: ndarray
    indices: ndarray
    inverse_indices: ndarray
    counts: ndarray


class UniqueCountsResult(NamedTuple):
    values: ndarray
    counts: ndarray


class UniqueInverseResult(NamedTuple):
    values: ndarray
    inverse_indices: ndarray


def _unique_kwargs(xp):
    # Older versions of NumPy and CuPy do not have equal_nan. Rather than
    # trying to parse version numbers, just check if equal_nan is in the
    # signature.
    s = inspect.signature(xp.unique)
    if 'equal_nan' in s.parameters:
        return {'equal_nan': False}
    return {}

def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
    kwargs = _unique_kwargs(xp)
    values, indices, inverse_indices, counts = xp.unique(
        x,
        return_counts=True,
        return_index=True,
        return_inverse=True,
        **kwargs,
    )
    # np.unique() flattens inverse indices, but they need to share x's shape
    # See https://github.com/numpy/numpy/issues/20638
    inverse_indices = inverse_indices.reshape(x.shape)
    return UniqueAllResult(
        values,
        indices,
        inverse_indices,
        counts,
    )


def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
    kwargs = _unique_kwargs(xp)
    res = xp.unique(
        x,
        return_counts=True,
        return_index=False,
        return_inverse=False,
        **kwargs
    )

    return UniqueCountsResult(*res)


def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
    kwargs = _unique_kwargs(xp)
    values, inverse_indices = xp.unique(
        x,
        return_counts=False,
        return_index=False,
        return_inverse=True,
        **kwargs,
    )
    # xp.unique() flattens inverse indices, but they need to share x's shape
    # See https://github.com/numpy/numpy/issues/20638
    inverse_indices = inverse_indices.reshape(x.shape)
    return UniqueInverseResult(values, inverse_indices)


def unique_values(x: ndarray, /, xp) -> ndarray:
    kwargs = _unique_kwargs(xp)
    return xp.unique(
        x,
        return_counts=False,
        return_index=False,
        return_inverse=False,
        **kwargs,
    )

def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
    if not copy and dtype == x.dtype:
        return x
    return x.astype(dtype=dtype, copy=copy)

# These functions have different keyword argument names

def std(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,  # correction instead of ddof
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)

def var(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    correction: Union[int, float] = 0.0,  # correction instead of ddof
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)

# Unlike transpose(), the axes argument to permute_dims() is required.
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
    return xp.transpose(x, axes)

# Creation functions add the device keyword (which does nothing for NumPy)

# asarray also adds the copy keyword
def _asarray(
    obj: Union[
        ndarray,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: "Optional[Union[bool, np._CopyMode]]" = None,
    namespace=None,
    **kwargs,
) -> ndarray:
    """
    Array API compatibility wrapper for asarray().

    See the corresponding documentation in NumPy/CuPy and/or the array API
    specification for more details.

    """
    if namespace is None:
        try:
            xp = array_namespace(obj, _use_compat=False)
        except ValueError:
            # TODO: What about lists of arrays?
            raise ValueError("A namespace must be specified for asarray() with non-array input")
    elif isinstance(namespace, ModuleType):
        xp = namespace
    elif namespace == 'numpy':
        import numpy as xp
    elif namespace == 'cupy':
        import cupy as xp
    elif namespace == 'dask.array':
        import dask.array as xp
    else:
        raise ValueError("Unrecognized namespace argument to asarray()")

    _check_device(xp, device)
    if is_numpy_array(obj):
        import numpy as np
        if hasattr(np, '_CopyMode'):
            # Not present in older NumPys
            COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
            COPY_TRUE = (True, np._CopyMode.ALWAYS)
        else:
            COPY_FALSE = (False,)
            COPY_TRUE = (True,)
    else:
        COPY_FALSE = (False,)
        COPY_TRUE = (True,)
    if copy in COPY_FALSE and namespace != "dask.array":
        # copy=False is not yet implemented in xp.asarray
        raise NotImplementedError("copy=False is not yet implemented")
    if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
        if dtype is not None and obj.dtype != dtype:
            copy = True
        if copy in COPY_TRUE:
            return xp.array(obj, copy=True, dtype=dtype)
        return obj
    elif namespace == "dask.array":
        if copy in COPY_TRUE:
            if dtype is None:
                return obj.copy()
            # Go through numpy, since dask copy is no-op by default
            import numpy as np
            obj = np.array(obj, dtype=dtype, copy=True)
            return xp.array(obj, dtype=dtype)
        else:
            import dask.array as da
            import numpy as np
            if not isinstance(obj, da.Array):
                obj = np.asarray(obj, dtype=dtype)
                return da.from_array(obj)
            return obj

    return xp.asarray(obj, dtype=dtype, **kwargs)

# np.reshape calls the keyword argument 'newshape' instead of 'shape'
def reshape(x: ndarray,
            /,
            shape: Tuple[int, ...],
            xp, copy: Optional[bool] = None,
            **kwargs) -> ndarray:
    if copy is True:
        x = x.copy()
    elif copy is False:
        y = x.view()
        y.shape = shape
        return y
    return xp.reshape(x, shape, **kwargs)

# The descending keyword is new in sort and argsort, and 'kind' replaced with
# 'stable'
def argsort(
    x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
    **kwargs,
) -> ndarray:
    # Note: this keyword argument is different, and the default is different.
    # We set it in kwargs like this because numpy.sort uses kind='quicksort'
    # as the default whereas cupy.sort uses kind=None.
    if stable:
        kwargs['kind'] = "stable"
    if not descending:
        res = xp.argsort(x, axis=axis, **kwargs)
    else:
        # As NumPy has no native descending sort, we imitate it here. Note that
        # simply flipping the results of xp.argsort(x, ...) would not
        # respect the relative order like it would in native descending sorts.
        res = xp.flip(
            xp.argsort(xp.flip(x, axis=axis), axis=axis, **kwargs),
            axis=axis,
        )
        # Rely on flip()/argsort() to validate axis
        normalised_axis = axis if axis >= 0 else x.ndim + axis
        max_i = x.shape[normalised_axis] - 1
        res = max_i - res
    return res
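
# Example (illustrative) of why the flip dance is needed: a stable descending
# argsort keeps tied elements in their original order.
#
#     import numpy as np
#     x = np.array([2, 1, 2])
#     argsort(x, np, descending=True)   # array([0, 2, 1]); naively flipping
#                                       # an ascending argsort would reverse
#                                       # the order of the tied indices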

def sort(
    x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True,
    **kwargs,
) -> ndarray:
    # Note: this keyword argument is different, and the default is different.
    # We set it in kwargs like this because numpy.sort uses kind='quicksort'
    # as the default whereas cupy.sort uses kind=None.
    if stable:
        kwargs['kind'] = "stable"
    res = xp.sort(x, axis=axis, **kwargs)
    if descending:
        res = xp.flip(res, axis=axis)
    return res

# nonzero should error for zero-dimensional arrays
def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
    if x.ndim == 0:
        raise ValueError("nonzero() does not support zero-dimensional arrays")
    return xp.nonzero(x, **kwargs)

# sum() and prod() should always upcast when dtype=None
def sum(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    # `xp.sum` already upcasts integers, but not floats or complexes
    if dtype is None:
        if x.dtype == xp.float32:
            dtype = xp.float64
        elif x.dtype == xp.complex64:
            dtype = xp.complex128
    return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)

def prod(
    x: ndarray,
    /,
    xp,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: Optional[Dtype] = None,
    keepdims: bool = False,
    **kwargs,
) -> ndarray:
    if dtype is None:
        if x.dtype == xp.float32:
            dtype = xp.float64
        elif x.dtype == xp.complex64:
            dtype = xp.complex128
    return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)

# ceil, floor, and trunc return integers for integer inputs

def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
    if xp.issubdtype(x.dtype, xp.integer):
        return x
    return xp.ceil(x, **kwargs)

def floor(x: ndarray, /, xp, **kwargs) -> ndarray:
    if xp.issubdtype(x.dtype, xp.integer):
        return x
    return xp.floor(x, **kwargs)

def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
    if xp.issubdtype(x.dtype, xp.integer):
        return x
    return xp.trunc(x, **kwargs)

# linear algebra functions

def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
    return xp.matmul(x1, x2, **kwargs)

# Unlike transpose, matrix_transpose only transposes the last two axes.
def matrix_transpose(x: ndarray, /, xp) -> ndarray:
    if x.ndim < 2:
        raise ValueError("x must be at least 2-dimensional for matrix_transpose")
    return xp.swapaxes(x, -1, -2)

def tensordot(x1: ndarray,
              x2: ndarray,
              /,
              xp,
              *,
              axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
              **kwargs,
) -> ndarray:
    return xp.tensordot(x1, x2, axes=axes, **kwargs)

def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
    if x1.shape[axis] != x2.shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    if hasattr(xp, 'broadcast_tensors'):
        _broadcast = xp.broadcast_tensors
    else:
        _broadcast = xp.broadcast_arrays

    x1_ = xp.moveaxis(x1, axis, -1)
    x2_ = xp.moveaxis(x2, axis, -1)
    x1_, x2_ = _broadcast(x1_, x2_)

    res = x1_[..., None, :] @ x2_[..., None]
    return res[..., 0, 0]
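
# Example (illustrative): a batched dot product over the last axis, with
# broadcasting of the non-contracted dimensions.
#
#     import numpy as np
#     x1 = np.asarray([[1., 2.], [3., 4.]])
#     x2 = np.asarray([1., 1.])
#     vecdot(x1, x2, np)   # array([3., 7.])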

# isdtype is a new function in the 2022.12 array API specification.

def isdtype(
    dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
    *, _tuple=True,  # Disallow nested tuples
) -> bool:
    """
    Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.

    Note that outside of this function, this compat library does not yet fully
    support complex numbers.

    See
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
    for more details.
    """
    if isinstance(kind, tuple) and _tuple:
        return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
    elif isinstance(kind, str):
        if kind == 'bool':
            return dtype == xp.bool_
        elif kind == 'signed integer':
            return xp.issubdtype(dtype, xp.signedinteger)
        elif kind == 'unsigned integer':
            return xp.issubdtype(dtype, xp.unsignedinteger)
        elif kind == 'integral':
            return xp.issubdtype(dtype, xp.integer)
        elif kind == 'real floating':
            return xp.issubdtype(dtype, xp.floating)
        elif kind == 'complex floating':
            return xp.issubdtype(dtype, xp.complexfloating)
        elif kind == 'numeric':
            return xp.issubdtype(dtype, xp.number)
        else:
            raise ValueError(f"Unrecognized data type kind: {kind!r}")
    else:
        # This will allow things that aren't required by the spec, like
        # isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
        # more strict here to match the type annotation? Note that the
        # array_api_strict implementation will be very strict.
        return dtype == kind
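
# Example (illustrative): ``kind`` may be a single string or a tuple of kinds.
#
#     import numpy as np
#     isdtype(np.dtype('float32'), 'real floating', np)       # True
#     isdtype(np.dtype('int32'), ('integral', 'bool'), np)    # True
#     isdtype(np.dtype('complex64'), 'real floating', np)     # False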

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
           'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
           'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
           'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
           'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
           'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
           'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
@ -0,0 +1,183 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Union, Optional, Literal

if TYPE_CHECKING:
    from ._typing import Device, ndarray
    from collections.abc import Sequence

# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.

def fft(
    x: ndarray,
    /,
    xp,
    *,
    n: Optional[int] = None,
    axis: int = -1,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
    if x.dtype in [xp.float32, xp.complex64]:
        return res.astype(xp.complex64)
    return res

def ifft(
    x: ndarray,
    /,
    xp,
    *,
    n: Optional[int] = None,
    axis: int = -1,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
    if x.dtype in [xp.float32, xp.complex64]:
        return res.astype(xp.complex64)
    return res

def fftn(
    x: ndarray,
    /,
    xp,
    *,
    s: Sequence[int] = None,
    axes: Sequence[int] = None,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
    if x.dtype in [xp.float32, xp.complex64]:
        return res.astype(xp.complex64)
    return res

def ifftn(
    x: ndarray,
    /,
    xp,
    *,
    s: Sequence[int] = None,
    axes: Sequence[int] = None,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
    if x.dtype in [xp.float32, xp.complex64]:
        return res.astype(xp.complex64)
    return res

def rfft(
    x: ndarray,
    /,
    xp,
    *,
    n: Optional[int] = None,
    axis: int = -1,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
    if x.dtype == xp.float32:
        return res.astype(xp.complex64)
    return res

def irfft(
    x: ndarray,
    /,
    xp,
    *,
    n: Optional[int] = None,
    axis: int = -1,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
    if x.dtype == xp.complex64:
        return res.astype(xp.float32)
    return res

def rfftn(
    x: ndarray,
    /,
    xp,
    *,
    s: Sequence[int] = None,
    axes: Sequence[int] = None,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
    if x.dtype == xp.float32:
        return res.astype(xp.complex64)
    return res

def irfftn(
    x: ndarray,
    /,
    xp,
    *,
    s: Sequence[int] = None,
    axes: Sequence[int] = None,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
    if x.dtype == xp.complex64:
        return res.astype(xp.float32)
    return res

def hfft(
    x: ndarray,
    /,
    xp,
    *,
    n: Optional[int] = None,
    axis: int = -1,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
    if x.dtype in [xp.float32, xp.complex64]:
        return res.astype(xp.float32)
    return res

def ihfft(
    x: ndarray,
    /,
    xp,
    *,
    n: Optional[int] = None,
    axis: int = -1,
    norm: Literal["backward", "ortho", "forward"] = "backward",
) -> ndarray:
    res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
    if x.dtype in [xp.float32, xp.complex64]:
        return res.astype(xp.complex64)
    return res

def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
    if device not in ["cpu", None]:
        raise ValueError(f"Unsupported device {device!r}")
    return xp.fft.fftfreq(n, d=d)

def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
    if device not in ["cpu", None]:
        raise ValueError(f"Unsupported device {device!r}")
    return xp.fft.rfftfreq(n, d=d)

def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
    return xp.fft.fftshift(x, axes=axes)

def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
    return xp.fft.ifftshift(x, axes=axes)

__all__ = [
    "fft",
    "ifft",
    "fftn",
    "ifftn",
    "rfft",
    "irfft",
    "rfftn",
    "irfftn",
    "hfft",
    "ihfft",
    "fftfreq",
    "rfftfreq",
    "fftshift",
    "ifftshift",
]
|
||||
@ -0,0 +1,515 @@
|
||||
"""
|
||||
Various helper functions which are not part of the spec.
|
||||
|
||||
Functions which start with an underscore are for internal use only but helpers
|
||||
that are in __all__ are intended as additional helper functions for use by end
|
||||
users of the compat library.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Optional, Union, Any
|
||||
from ._typing import Array, Device
|
||||
|
||||
import sys
|
||||
import math
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
def is_numpy_array(x):
|
||||
"""
|
||||
Return True if `x` is a NumPy array.
|
||||
|
||||
This function does not import NumPy if it has not already been imported
|
||||
and is therefore cheap to use.
|
||||
|
||||
This also returns True for `ndarray` subclasses and NumPy scalar objects.
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
array_namespace
|
||||
is_array_api_obj
|
||||
is_cupy_array
|
||||
is_torch_array
|
||||
is_dask_array
|
||||
is_jax_array
|
||||
"""
|
||||
# Avoid importing NumPy if it isn't already
|
||||
if 'numpy' not in sys.modules:
|
||||
return False
|
||||
|
||||
import numpy as np
|
||||
|
||||
# TODO: Should we reject ndarray subclasses?
|
||||
return isinstance(x, (np.ndarray, np.generic))
|
||||
|
||||
def is_cupy_array(x):
|
||||
"""
|
||||
Return True if `x` is a CuPy array.
|
||||
|
||||
This function does not import CuPy if it has not already been imported
|
||||
and is therefore cheap to use.
|
||||
|
||||
This also returns True for `cupy.ndarray` subclasses and CuPy scalar objects.
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
array_namespace
|
||||
is_array_api_obj
|
||||
is_numpy_array
|
||||
is_torch_array
|
||||
is_dask_array
|
||||
is_jax_array
|
||||
"""
|
||||
# Avoid importing NumPy if it isn't already
|
||||
if 'cupy' not in sys.modules:
|
||||
return False
|
||||
|
||||
import cupy as cp
|
||||
|
||||
# TODO: Should we reject ndarray subclasses?
|
||||
return isinstance(x, (cp.ndarray, cp.generic))
|
||||
|
||||
def is_torch_array(x):
|
||||
"""
|
||||
Return True if `x` is a PyTorch tensor.
|
||||
|
||||
This function does not import PyTorch if it has not already been imported
|
||||
and is therefore cheap to use.
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
array_namespace
|
||||
is_array_api_obj
|
||||
is_numpy_array
|
||||
is_cupy_array
|
||||
is_dask_array
|
||||
is_jax_array
|
||||
"""
|
||||
# Avoid importing torch if it isn't already
|
||||
if 'torch' not in sys.modules:
|
||||
return False
|
||||
|
||||
import torch
|
||||
|
||||
# TODO: Should we reject ndarray subclasses?
|
||||
return isinstance(x, torch.Tensor)
|
||||
|
||||
def is_dask_array(x):
|
||||
"""
|
||||
Return True if `x` is a dask.array Array.
|
||||
|
||||
This function does not import dask if it has not already been imported
|
||||
and is therefore cheap to use.
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
array_namespace
|
||||
is_array_api_obj
|
||||
is_numpy_array
|
||||
is_cupy_array
|
||||
is_torch_array
|
||||
is_jax_array
|
||||
"""
|
||||
# Avoid importing dask if it isn't already
|
||||
if 'dask.array' not in sys.modules:
|
||||
return False
|
||||
|
||||
import dask.array
|
||||
|
||||
return isinstance(x, dask.array.Array)
|
||||
|
||||
def is_jax_array(x):
|
||||
"""
|
||||
Return True if `x` is a JAX array.
|
||||
|
||||
This function does not import JAX if it has not already been imported
|
||||
and is therefore cheap to use.
|
||||
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
array_namespace
|
||||
is_array_api_obj
|
||||
is_numpy_array
|
||||
is_cupy_array
|
||||
is_torch_array
|
||||
is_dask_array
|
||||
"""
|
||||
# Avoid importing jax if it isn't already
|
||||
if 'jax' not in sys.modules:
|
||||
return False
|
||||
|
||||
import jax
|
||||
|
||||
return isinstance(x, jax.Array)
|
||||
|
||||
def is_array_api_obj(x):
|
||||
"""
|
||||
Return True if `x` is an array API compatible array object.
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
array_namespace
|
||||
is_numpy_array
|
||||
is_cupy_array
|
||||
is_torch_array
|
||||
is_dask_array
|
||||
is_jax_array
|
||||
"""
|
||||
return is_numpy_array(x) \
|
||||
or is_cupy_array(x) \
|
||||
or is_torch_array(x) \
|
||||
or is_dask_array(x) \
|
||||
or is_jax_array(x) \
|
||||
or hasattr(x, '__array_namespace__')
|
||||
|
||||
def _check_api_version(api_version):
|
||||
if api_version == '2021.12':
|
||||
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
|
||||
elif api_version is not None and api_version != '2022.12':
|
||||
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
|
||||
|
||||
def array_namespace(*xs, api_version=None, _use_compat=True):
|
||||
"""
|
||||
Get the array API compatible namespace for the arrays `xs`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
xs: arrays
|
||||
one or more arrays.
|
||||
|
||||
api_version: str
|
||||
The newest version of the spec that you need support for (currently
|
||||
the compat library wrapped APIs support v2022.12).
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
out: namespace
|
||||
The array API compatible namespace corresponding to the arrays in `xs`.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
If `xs` contains arrays from different array libraries or contains a
|
||||
non-array.
|
||||
|
||||
|
||||
Typical usage is to pass the arguments of a function to
|
||||
`array_namespace()` at the top of a function to get the corresponding
|
||||
array API namespace:
|
||||
|
||||
.. code:: python
|
||||
|
||||
def your_function(x, y):
|
||||
xp = array_api_compat.array_namespace(x, y)
|
||||
# Now use xp as the array library namespace
|
||||
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
|
||||
|
||||
|
||||
Wrapped array namespaces can also be imported directly. For example,
|
||||
`array_namespace(np.array(...))` will return `array_api_compat.numpy`.
|
||||
This function will also work for any array library not wrapped by
|
||||
array-api-compat if it explicitly defines `__array_namespace__
|
||||
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html>`__
|
||||
(the wrapped namespace is always preferred if it exists).
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
is_array_api_obj
|
||||
is_numpy_array
|
||||
is_cupy_array
|
||||
is_torch_array
|
||||
is_dask_array
|
||||
is_jax_array
|
||||
|
||||
"""
|
||||
namespaces = set()
|
||||
for x in xs:
|
||||
if is_numpy_array(x):
|
||||
_check_api_version(api_version)
|
||||
if _use_compat:
|
||||
from .. import numpy as numpy_namespace
|
||||
namespaces.add(numpy_namespace)
|
||||
else:
|
||||
import numpy as np
|
||||
namespaces.add(np)
|
||||
elif is_cupy_array(x):
|
||||
_check_api_version(api_version)
|
||||
if _use_compat:
|
||||
from .. import cupy as cupy_namespace
|
||||
namespaces.add(cupy_namespace)
|
||||
else:
|
||||
import cupy as cp
|
||||
namespaces.add(cp)
|
||||
elif is_torch_array(x):
|
||||
_check_api_version(api_version)
|
||||
if _use_compat:
|
||||
from .. import torch as torch_namespace
|
||||
namespaces.add(torch_namespace)
|
||||
else:
|
||||
import torch
|
||||
namespaces.add(torch)
|
||||
elif is_dask_array(x):
|
||||
_check_api_version(api_version)
|
||||
if _use_compat:
|
||||
from ..dask import array as dask_namespace
|
||||
namespaces.add(dask_namespace)
|
||||
else:
|
||||
raise TypeError("_use_compat cannot be False if input array is a dask array!")
|
||||
elif is_jax_array(x):
|
||||
_check_api_version(api_version)
|
||||
# jax.experimental.array_api is already an array namespace. We do
|
||||
# not have a wrapper submodule for it.
|
||||
import jax.experimental.array_api as jnp
|
||||
namespaces.add(jnp)
|
||||
elif hasattr(x, '__array_namespace__'):
|
||||
namespaces.add(x.__array_namespace__(api_version=api_version))
|
||||
else:
|
||||
# TODO: Support Python scalars?
|
||||
raise TypeError(f"{type(x).__name__} is not a supported array type")
|
||||
|
||||
if not namespaces:
|
||||
raise TypeError("Unrecognized array input")
|
||||
|
||||
if len(namespaces) != 1:
|
||||
raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
|
||||
|
||||
xp, = namespaces
|
||||
|
||||
return xp
|
||||
|
||||
# backwards compatibility alias
|
||||
get_namespace = array_namespace

def _check_device(xp, device):
    if xp == sys.modules.get('numpy'):
        if device not in ["cpu", None]:
            raise ValueError(f"Unsupported device for NumPy: {device!r}")

# Placeholder object to represent the dask device
# when the array backend is not the CPU.
# (since it is not easy to tell which device a dask array is on)
class _dask_device:
    def __repr__(self):
        return "DASK_DEVICE"

_DASK_DEVICE = _dask_device()

# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
def device(x: Array, /) -> Device:
    """
    Hardware device the array data resides on.

    This is equivalent to `x.device` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
    This helper is included because some array libraries either do not have
    the `device` attribute or include it with an incompatible API.

    Parameters
    ----------
    x: array
        array instance from an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    Notes
    -----

    For NumPy the device is always `"cpu"`. For Dask, the device is always a
    special `DASK_DEVICE` object.

    See Also
    --------

    to_device : Move array data to a different device.

    """
    if is_numpy_array(x):
        return "cpu"
    elif is_dask_array(x):
        # Peek at the metadata of the dask array to determine the backing type
        try:
            import numpy as np
            if isinstance(x._meta, np.ndarray):
                # Must be on CPU since backed by numpy
                return "cpu"
        except ImportError:
            pass
        return _DASK_DEVICE
    elif is_jax_array(x):
        # JAX has .device() as a method, but it is being deprecated so that it
        # can become a property, in accordance with the standard. In order for
        # this function to not break when JAX makes the flip, we check for
        # both here.
        if inspect.ismethod(x.device):
            return x.device()
        else:
            return x.device
    return x.device
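
# A minimal usage sketch (assuming NumPy and Dask are installed; the Dask
# result depends on the backing chunk type, here NumPy-backed, so "cpu"):
#
#     >>> import numpy as np, dask.array as da
#     >>> device(np.asarray([1, 2]))
#     'cpu'
#     >>> device(da.ones(4))
#     'cpu'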

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    import cupy as cp
    from cupy.cuda import Device as _Device
    from cupy.cuda import stream as stream_module
    from cupy_backends.cuda.api import runtime

    if device == x.device:
        return x
    elif device == "cpu":
        # allowing us to use `to_device(x, "cpu")`
        # is useful for portable test swapping between
        # host and device backends
        return x.get()
    elif not isinstance(device, _Device):
        raise ValueError(f"Unsupported device {device!r}")
    else:
        # see cupy/cupy#5985 for the reasoning behind how we handle
        # device/stream here
        prev_device = runtime.getDevice()
        prev_stream: stream_module.Stream = None
        if stream is not None:
            prev_stream = stream_module.get_current_stream()
            # stream can be an int as specified in __dlpack__, or a CuPy stream
            if isinstance(stream, int):
                stream = cp.cuda.ExternalStream(stream)
            elif isinstance(stream, cp.cuda.Stream):
                pass
            else:
                raise ValueError('the input stream is not recognized')
            stream.use()
        try:
            runtime.setDevice(device.id)
            arr = x.copy()
        finally:
            runtime.setDevice(prev_device)
            if stream is not None:
                prev_stream.use()
        return arr

def _torch_to_device(x, device, /, stream=None):
    if stream is not None:
        raise NotImplementedError
    return x.to(device)

def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
    """
    Copy the array from the device on which it currently resides to the specified ``device``.

    This is equivalent to `x.to_device(device, stream=stream)` according to
    the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html>`__.
    This helper is included because some array libraries do not have the
    `to_device` method.

    Parameters
    ----------

    x: array
        array instance from an array API compatible library.

    device: device
        a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
        section of the array API specification).

    stream: Optional[Union[int, Any]]
        stream object to use during copy. In addition to the types supported
        in ``array.__dlpack__``, implementations may choose to support any
        library-specific stream object with the caveat that any code using
        such an object would not be portable.

    Returns
    -------

    out: array
        an array with the same data and data type as ``x`` and located on the
        specified ``device``.

    Notes
    -----

    For NumPy, this function effectively does nothing since the only supported
    device is the CPU. For CuPy, this method supports CuPy CUDA
    :external+cupy:class:`Device <cupy.cuda.Device>` and
    :external+cupy:class:`Stream <cupy.cuda.Stream>` objects. For PyTorch,
    this is the same as :external+torch:meth:`x.to(device) <torch.Tensor.to>`
    (the ``stream`` argument is not supported in PyTorch).

    See Also
    --------

    device : Hardware device the array data resides on.

    """
    if is_numpy_array(x):
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        if device == 'cpu':
            return x
        raise ValueError(f"Unsupported device {device!r}")
    elif is_cupy_array(x):
        # cupy does not yet have to_device
        return _cupy_to_device(x, device, stream=stream)
    elif is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)
    elif is_dask_array(x):
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        # TODO: What if our array is on the GPU already?
        if device == 'cpu':
            return x
        raise ValueError(f"Unsupported device {device!r}")
    elif is_jax_array(x):
        # This import adds to_device to x
        import jax.experimental.array_api  # noqa: F401
        return x.to_device(device, stream=stream)
    return x.to_device(device, stream=stream)
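
# A hedged usage sketch (NumPy-only, no accelerator required):
#
#     >>> import numpy as np
#     >>> x = np.asarray([1.0, 2.0])
#     >>> to_device(x, 'cpu') is x        # NumPy arrays are returned as-is
#     True
#     >>> to_device(x, 'gpu')             # any other device raises
#     Traceback (most recent call last):
#         ...
#     ValueError: Unsupported device 'gpu'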

def size(x):
    """
    Return the total number of elements of x.

    This is equivalent to `x.size` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
    This helper is included because PyTorch defines `size` in an
    :external+torch:meth:`incompatible way <torch.Tensor.size>`.

    """
    if None in x.shape:
        return None
    return math.prod(x.shape)
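
# Quick sketch of why this helper exists (assuming torch is installed):
# Tensor.size() is a method returning the shape, not an element count, so
# `x.size` cannot be used portably:
#
#     >>> import torch
#     >>> t = torch.ones(2, 3)
#     >>> t.size()                 # torch: the shape, not the element count
#     torch.Size([2, 3])
#     >>> size(t)                  # this helper: element count per the standard
#     6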

__all__ = [
    "array_namespace",
    "device",
    "get_namespace",
    "is_array_api_obj",
    "is_cupy_array",
    "is_dask_array",
    "is_jax_array",
    "is_numpy_array",
    "is_torch_array",
    "size",
    "to_device",
]

_all_ignore = ['sys', 'math', 'inspect', 'warnings']

@ -0,0 +1,161 @@
from __future__ import annotations

from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
    from typing import Literal, Optional, Tuple, Union
    from ._typing import ndarray

import math

import numpy as np
if np.__version__[0] == "2":
    from numpy.lib.array_utils import normalize_axis_tuple
else:
    from numpy.core.numeric import normalize_axis_tuple

from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
from .._internal import get_xp

# These are in the main NumPy namespace but not in numpy.linalg
def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
    return xp.cross(x1, x2, axis=axis, **kwargs)

def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
    return xp.outer(x1, x2, **kwargs)

class EighResult(NamedTuple):
    eigenvalues: ndarray
    eigenvectors: ndarray

class QRResult(NamedTuple):
    Q: ndarray
    R: ndarray

class SlogdetResult(NamedTuple):
    sign: ndarray
    logabsdet: ndarray

class SVDResult(NamedTuple):
    U: ndarray
    S: ndarray
    Vh: ndarray

# These functions are the same as their NumPy counterparts except they return
# a namedtuple.
def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
    return EighResult(*xp.linalg.eigh(x, **kwargs))

def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
       **kwargs) -> QRResult:
    return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs))

def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult:
    return SlogdetResult(*xp.linalg.slogdet(x, **kwargs))

def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult:
    return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs))

# These functions have additional keyword arguments

# The upper keyword argument is new from NumPy
def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
    L = xp.linalg.cholesky(x, **kwargs)
    if upper:
        U = get_xp(xp)(matrix_transpose)(L)
        if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
            U = xp.conj(U)
        return U
    return L

# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
# Note that it has a different semantic meaning from tol and rcond.
def matrix_rank(x: ndarray,
                /,
                xp,
                *,
                rtol: Optional[Union[float, ndarray]] = None,
                **kwargs) -> ndarray:
    # this is different from xp.linalg.matrix_rank, which supports 1
    # dimensional arrays.
    if x.ndim < 2:
        raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
    S = get_xp(xp)(svdvals)(x, **kwargs)
    if rtol is None:
        tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps
    else:
        # this is different from xp.linalg.matrix_rank, which does not
        # multiply the tolerance by the largest singular value.
        tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis]
    return xp.count_nonzero(S > tol, axis=-1)
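
# Hedged numeric sketch of the rtol semantics (values are illustrative):
# with rtol given, the cutoff is rtol * largest_singular_value, unlike the
# absolute tol/rcond of np.linalg.matrix_rank:
#
#     >>> import numpy as np
#     >>> A = np.diag([1.0, 1e-3, 0.0])
#     >>> int(matrix_rank(A, np))               # default eps-based cutoff
#     2
#     >>> int(matrix_rank(A, np, rtol=1e-2))    # cutoff = 1e-2 * 1.0, drops 1e-3
#     1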

def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray:
    # this is different from xp.linalg.pinv, which does not multiply the
    # default tolerance by max(M, N).
    if rtol is None:
        rtol = max(x.shape[-2:]) * xp.finfo(x.dtype).eps
    return xp.linalg.pinv(x, rcond=rtol, **kwargs)

# These functions are new in the array API spec

def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
    return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)

# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
# xp.linalg.svd(compute_uv=False).
def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]:
    return xp.linalg.svd(x, compute_uv=False)

def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
    # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
    # when axis=None and the input is 2-D, so to force a vector norm, we make
    # it so the input is 1-D (for axis=None), or reshape so that norm is done
    # on a single dimension.
    if axis is None:
        # Note: xp.linalg.norm() doesn't handle 0-D arrays
        _x = x.ravel()
        _axis = 0
    elif isinstance(axis, tuple):
        # Note: The axis argument supports any number of axes, whereas
        # xp.linalg.norm() only supports a single axis for vector norm.
        normalized_axis = normalize_axis_tuple(axis, x.ndim)
        rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
        newshape = axis + rest
        _x = xp.transpose(x, newshape).reshape(
            (math.prod([x.shape[i] for i in axis]), *[x.shape[i] for i in rest]))
        _axis = 0
    else:
        _x = x
        _axis = axis

    res = xp.linalg.norm(_x, axis=_axis, ord=ord)

    if keepdims:
        # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
        # above to avoid matrix norm logic.
        shape = list(x.shape)
        _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
        for i in _axis:
            shape[i] = 1
        res = xp.reshape(res, tuple(shape))

    return res
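
# A hedged sketch of the multi-axis behavior (shapes are illustrative):
# a tuple of axes is flattened into one leading dimension so that
# xp.linalg.norm computes a vector norm rather than a matrix norm:
#
#     >>> import numpy as np
#     >>> x = np.ones((2, 3, 4))
#     >>> vector_norm(x, np, axis=(0, 1)).shape      # norm over 2*3=6 elements
#     (4,)
#     >>> vector_norm(x, np, axis=(0, 1), keepdims=True).shape
#     (1, 1, 4)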

# xp.diagonal and xp.trace operate on the first two axes whereas these
# operate on the last two

def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
    return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
    if dtype is None:
        if x.dtype == xp.float32:
            dtype = xp.float64
        elif x.dtype == xp.complex64:
            dtype = xp.complex128
    return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))

__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
           'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
           'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
           'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
           'trace']

@ -0,0 +1,23 @@
from __future__ import annotations

__all__ = [
    "NestedSequence",
    "SupportsBufferProtocol",
]

from typing import (
    Any,
    TypeVar,
    Protocol,
)

_T_co = TypeVar("_T_co", covariant=True)

class NestedSequence(Protocol[_T_co]):
    def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
    def __len__(self, /) -> int: ...

SupportsBufferProtocol = Any

Array = Any
Device = Any

@ -0,0 +1,16 @@
from cupy import *  # noqa: F403

# from cupy import * doesn't overwrite these builtin names
from cupy import abs, max, min, round  # noqa: F401

# These imports may overwrite names from the import * above.
from ._aliases import *  # noqa: F403

# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')

__import__(__package__ + '.fft')

from ..common._helpers import *  # noqa: F401,F403

__array_api_version__ = '2022.12'

@ -0,0 +1,81 @@
from __future__ import annotations

from functools import partial

import cupy as cp

from ..common import _aliases
from .._internal import get_xp

asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial

bool = cp.bool_

# Basic renames
acos = cp.arccos
acosh = cp.arccosh
asin = cp.arcsin
asinh = cp.arcsinh
atan = cp.arctan
atan2 = cp.arctan2
atanh = cp.arctanh
bitwise_left_shift = cp.left_shift
bitwise_invert = cp.invert
bitwise_right_shift = cp.right_shift
concat = cp.concatenate
pow = cp.power

arange = get_xp(cp)(_aliases.arange)
empty = get_xp(cp)(_aliases.empty)
empty_like = get_xp(cp)(_aliases.empty_like)
eye = get_xp(cp)(_aliases.eye)
full = get_xp(cp)(_aliases.full)
full_like = get_xp(cp)(_aliases.full_like)
linspace = get_xp(cp)(_aliases.linspace)
ones = get_xp(cp)(_aliases.ones)
ones_like = get_xp(cp)(_aliases.ones_like)
zeros = get_xp(cp)(_aliases.zeros)
zeros_like = get_xp(cp)(_aliases.zeros_like)
UniqueAllResult = get_xp(cp)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(cp)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(cp)(_aliases.UniqueInverseResult)
unique_all = get_xp(cp)(_aliases.unique_all)
unique_counts = get_xp(cp)(_aliases.unique_counts)
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
unique_values = get_xp(cp)(_aliases.unique_values)
astype = _aliases.astype
std = get_xp(cp)(_aliases.std)
var = get_xp(cp)(_aliases.var)
permute_dims = get_xp(cp)(_aliases.permute_dims)
reshape = get_xp(cp)(_aliases.reshape)
argsort = get_xp(cp)(_aliases.argsort)
sort = get_xp(cp)(_aliases.sort)
nonzero = get_xp(cp)(_aliases.nonzero)
sum = get_xp(cp)(_aliases.sum)
prod = get_xp(cp)(_aliases.prod)
ceil = get_xp(cp)(_aliases.ceil)
floor = get_xp(cp)(_aliases.floor)
trunc = get_xp(cp)(_aliases.trunc)
matmul = get_xp(cp)(_aliases.matmul)
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
    vecdot = cp.vecdot
else:
    vecdot = get_xp(cp)(_aliases.vecdot)
if hasattr(cp, 'isdtype'):
    isdtype = cp.isdtype
else:
    isdtype = get_xp(cp)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
                              'acosh', 'asin', 'asinh', 'atan', 'atan2',
                              'atanh', 'bitwise_left_shift', 'bitwise_invert',
                              'bitwise_right_shift', 'concat', 'pow']

_all_ignore = ['cp', 'get_xp']

@ -0,0 +1,46 @@
from __future__ import annotations

__all__ = [
    "ndarray",
    "Device",
    "Dtype",
]

import sys
from typing import (
    Union,
    TYPE_CHECKING,
)

from cupy import (
    ndarray,
    dtype,
    int8,
    int16,
    int32,
    int64,
    uint8,
    uint16,
    uint32,
    uint64,
    float32,
    float64,
)

from cupy.cuda.device import Device

if TYPE_CHECKING or sys.version_info >= (3, 9):
    Dtype = dtype[Union[
        int8,
        int16,
        int32,
        int64,
        uint8,
        uint16,
        uint32,
        uint64,
        float32,
        float64,
    ]]
else:
    Dtype = dtype

@ -0,0 +1,36 @@
from cupy.fft import *  # noqa: F403
# cupy.fft doesn't have __all__. If it is added, replace this with
#
# from cupy.fft import __all__ as fft_all
_n = {}
exec('from cupy.fft import *', _n)
del _n['__builtins__']
fft_all = list(_n)
del _n

from ..common import _fft
from .._internal import get_xp

import cupy as cp

fft = get_xp(cp)(_fft.fft)
ifft = get_xp(cp)(_fft.ifft)
fftn = get_xp(cp)(_fft.fftn)
ifftn = get_xp(cp)(_fft.ifftn)
rfft = get_xp(cp)(_fft.rfft)
irfft = get_xp(cp)(_fft.irfft)
rfftn = get_xp(cp)(_fft.rfftn)
irfftn = get_xp(cp)(_fft.irfftn)
hfft = get_xp(cp)(_fft.hfft)
ihfft = get_xp(cp)(_fft.ihfft)
fftfreq = get_xp(cp)(_fft.fftfreq)
rfftfreq = get_xp(cp)(_fft.rfftfreq)
fftshift = get_xp(cp)(_fft.fftshift)
ifftshift = get_xp(cp)(_fft.ifftshift)

__all__ = fft_all + _fft.__all__

del get_xp
del cp
del fft_all
del _fft

@ -0,0 +1,49 @@
from cupy.linalg import *  # noqa: F403
# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
exec('from cupy.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
del _n

from ..common import _linalg
from .._internal import get_xp

import cupy as cp

# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot  # noqa: F401

cross = get_xp(cp)(_linalg.cross)
outer = get_xp(cp)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(cp)(_linalg.eigh)
qr = get_xp(cp)(_linalg.qr)
slogdet = get_xp(cp)(_linalg.slogdet)
svd = get_xp(cp)(_linalg.svd)
cholesky = get_xp(cp)(_linalg.cholesky)
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
    vector_norm = cp.linalg.vector_norm
else:
    vector_norm = get_xp(cp)(_linalg.vector_norm)

__all__ = linalg_all + _linalg.__all__

del get_xp
del cp
del linalg_all
del _linalg

@ -0,0 +1,8 @@
from dask.array import *  # noqa: F403

# These imports may overwrite names from the import * above.
from ._aliases import *  # noqa: F403

__array_api_version__ = '2022.12'

__import__(__package__ + '.linalg')

@ -0,0 +1,146 @@
from __future__ import annotations

from ...common import _aliases
from ...common._helpers import _check_device

from ..._internal import get_xp

import numpy as np
from numpy import (
    # Constants
    e,
    inf,
    nan,
    pi,
    newaxis,
    # Dtypes
    bool_ as bool,
    float32,
    float64,
    int8,
    int16,
    int32,
    int64,
    uint8,
    uint16,
    uint32,
    uint64,
    complex64,
    complex128,
    iinfo,
    finfo,
    can_cast,
    result_type,
)

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from typing import Optional, Union

    from ...common._typing import Device, Dtype, Array

import dask.array as da

isdtype = get_xp(np)(_aliases.isdtype)
astype = _aliases.astype

# Common aliases

# This arange func is modified from the common one to
# not pass stop/step as keyword arguments, which will cause
# an error with dask

# TODO: delete the xp stuff, it shouldn't be necessary
def _dask_arange(
    start: Union[int, float],
    /,
    stop: Optional[Union[int, float]] = None,
    step: Union[int, float] = 1,
    *,
    xp,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    **kwargs,
) -> Array:
    _check_device(xp, device)
    args = [start]
    if stop is not None:
        args.append(stop)
    else:
        # stop is None, so start is actually stop
        # prepend the default value for start which is 0
        args.insert(0, 0)
    args.append(step)
    return xp.arange(*args, dtype=dtype, **kwargs)
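
# Hedged sketch of the positional-args shuffle above (illustrative values):
# arange(5) becomes xp.arange(0, 5, 1) and arange(2, 8, 2) becomes
# xp.arange(2, 8, 2), so dask never sees stop/step as keywords:
#
#     >>> import dask.array as da
#     >>> _dask_arange(5, xp=da).compute()
#     array([0, 1, 2, 3, 4])
#     >>> _dask_arange(2, 8, 2, xp=da).compute()
#     array([2, 4, 6])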

arange = get_xp(da)(_dask_arange)
eye = get_xp(da)(_aliases.eye)

from functools import partial
asarray = partial(_aliases._asarray, namespace='dask.array')
asarray.__doc__ = _aliases._asarray.__doc__

linspace = get_xp(da)(_aliases.linspace)
eye = get_xp(da)(_aliases.eye)
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
unique_all = get_xp(da)(_aliases.unique_all)
unique_counts = get_xp(da)(_aliases.unique_counts)
unique_inverse = get_xp(da)(_aliases.unique_inverse)
unique_values = get_xp(da)(_aliases.unique_values)
permute_dims = get_xp(da)(_aliases.permute_dims)
std = get_xp(da)(_aliases.std)
var = get_xp(da)(_aliases.var)
empty = get_xp(da)(_aliases.empty)
empty_like = get_xp(da)(_aliases.empty_like)
full = get_xp(da)(_aliases.full)
full_like = get_xp(da)(_aliases.full_like)
ones = get_xp(da)(_aliases.ones)
ones_like = get_xp(da)(_aliases.ones_like)
zeros = get_xp(da)(_aliases.zeros)
zeros_like = get_xp(da)(_aliases.zeros_like)
reshape = get_xp(da)(_aliases.reshape)
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
vecdot = get_xp(da)(_aliases.vecdot)

nonzero = get_xp(da)(_aliases.nonzero)
sum = get_xp(np)(_aliases.sum)
prod = get_xp(np)(_aliases.prod)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
tensordot = get_xp(np)(_aliases.tensordot)

from dask.array import (
    # Element wise aliases
    arccos as acos,
    arccosh as acosh,
    arcsin as asin,
    arcsinh as asinh,
    arctan as atan,
    arctan2 as atan2,
    arctanh as atanh,
    left_shift as bitwise_left_shift,
    right_shift as bitwise_right_shift,
    invert as bitwise_invert,
    power as pow,
    # Other
    concatenate as concat,
)

# exclude these from __all__ since dask does not support them
_da_unsupported = ['sort', 'argsort']

common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]

__all__ = common_aliases + ['asarray', 'bool', 'acos',
                            'acosh', 'asin', 'asinh', 'atan', 'atan2',
                            'atanh', 'bitwise_left_shift', 'bitwise_invert',
                            'bitwise_right_shift', 'concat', 'pow',
                            'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
                            'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
                            'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']

_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']

@ -0,0 +1,72 @@
from __future__ import annotations

from ...common import _linalg
from ..._internal import get_xp

# Exports
from dask.array.linalg import *  # noqa: F403
from dask.array import trace, outer

# These functions are in both the main and linalg namespaces
from dask.array import matmul, tensordot
from ._aliases import matrix_transpose, vecdot

import dask.array as da

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from ...common._typing import Array
    from typing import Literal

# dask.array.linalg doesn't have __all__. If it is added, replace this with
#
# from dask.array.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
del _n['__builtins__']
if 'annotations' in _n:
    del _n['annotations']
linalg_all = list(_n)
del _n

EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
# TODO: use the QR wrapper once dask
# supports the mode keyword on QR
# https://github.com/dask/dask/issues/10388
#qr = get_xp(da)(_linalg.qr)
def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
       **kwargs) -> QRResult:
    if mode != "reduced":
        raise ValueError("dask arrays only support using mode='reduced'")
    return QRResult(*da.linalg.qr(x, **kwargs))
cholesky = get_xp(da)(_linalg.cholesky)
matrix_rank = get_xp(da)(_linalg.matrix_rank)
matrix_norm = get_xp(da)(_linalg.matrix_norm)


# Wrap the svd functions to not pass full_matrices to dask
# when full_matrices=False (as that is the default behavior for dask),
# and dask doesn't have the full_matrices keyword
def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult:
    if full_matrices:
        raise ValueError("full_matrices=True is not supported by dask.")
    return da.linalg.svd(x, coerce_signs=False, **kwargs)

def svdvals(x: Array) -> Array:
    # TODO: can't avoid computing U or V for dask
    # (full_matrices=False is required: svd() above rejects the default True)
    _, s, _ = svd(x, full_matrices=False)
    return s

vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)

__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
                        "matrix_transpose", "vecdot", "EighResult",
                        "QRResult", "SlogdetResult", "SVDResult", "qr",
                        "cholesky", "matrix_rank", "matrix_norm", "svdvals",
                        "vector_norm", "diagonal"]

_all_ignore = ['get_xp', 'da', 'linalg_all']

@ -0,0 +1,24 @@
from numpy import *  # noqa: F403

# from numpy import * doesn't overwrite these builtin names
from numpy import abs, max, min, round  # noqa: F401

# These imports may overwrite names from the import * above.
from ._aliases import *  # noqa: F403

# Don't know why, but we have to do an absolute import to import linalg. If we
# instead do
#
# from . import linalg
#
# It doesn't overwrite np.linalg from above. The import is generated
# dynamically so that the library can be vendored.
__import__(__package__ + '.linalg')

__import__(__package__ + '.fft')

from .linalg import matrix_transpose, vecdot  # noqa: F401

from ..common._helpers import *  # noqa: F403

__array_api_version__ = '2022.12'

@ -0,0 +1,81 @@
from __future__ import annotations

from functools import partial

from ..common import _aliases

from .._internal import get_xp

asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial

import numpy as np
bool = np.bool_

# Basic renames
acos = np.arccos
acosh = np.arccosh
asin = np.arcsin
asinh = np.arcsinh
atan = np.arctan
atan2 = np.arctan2
atanh = np.arctanh
bitwise_left_shift = np.left_shift
bitwise_invert = np.invert
bitwise_right_shift = np.right_shift
concat = np.concatenate
pow = np.power

arange = get_xp(np)(_aliases.arange)
empty = get_xp(np)(_aliases.empty)
empty_like = get_xp(np)(_aliases.empty_like)
eye = get_xp(np)(_aliases.eye)
full = get_xp(np)(_aliases.full)
full_like = get_xp(np)(_aliases.full_like)
linspace = get_xp(np)(_aliases.linspace)
ones = get_xp(np)(_aliases.ones)
ones_like = get_xp(np)(_aliases.ones_like)
zeros = get_xp(np)(_aliases.zeros)
zeros_like = get_xp(np)(_aliases.zeros_like)
UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult)
unique_all = get_xp(np)(_aliases.unique_all)
unique_counts = get_xp(np)(_aliases.unique_counts)
unique_inverse = get_xp(np)(_aliases.unique_inverse)
unique_values = get_xp(np)(_aliases.unique_values)
astype = _aliases.astype
std = get_xp(np)(_aliases.std)
var = get_xp(np)(_aliases.var)
permute_dims = get_xp(np)(_aliases.permute_dims)
reshape = get_xp(np)(_aliases.reshape)
argsort = get_xp(np)(_aliases.argsort)
sort = get_xp(np)(_aliases.sort)
nonzero = get_xp(np)(_aliases.nonzero)
sum = get_xp(np)(_aliases.sum)
prod = get_xp(np)(_aliases.prod)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, 'vecdot'):
    vecdot = np.vecdot
else:
    vecdot = get_xp(np)(_aliases.vecdot)
if hasattr(np, 'isdtype'):
    isdtype = np.isdtype
else:
    isdtype = get_xp(np)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
                              'acosh', 'asin', 'asinh', 'atan', 'atan2',
                              'atanh', 'bitwise_left_shift', 'bitwise_invert',
                              'bitwise_right_shift', 'concat', 'pow']

_all_ignore = ['np', 'get_xp']

@ -0,0 +1,46 @@
from __future__ import annotations

__all__ = [
    "ndarray",
    "Device",
    "Dtype",
]

import sys
from typing import (
    Literal,
    Union,
    TYPE_CHECKING,
)

from numpy import (
    ndarray,
    dtype,
    int8,
    int16,
    int32,
    int64,
    uint8,
    uint16,
    uint32,
    uint64,
    float32,
    float64,
)

Device = Literal["cpu"]
if TYPE_CHECKING or sys.version_info >= (3, 9):
    Dtype = dtype[Union[
        int8,
        int16,
        int32,
        int64,
        uint8,
        uint16,
        uint32,
        uint64,
        float32,
        float64,
    ]]
else:
    Dtype = dtype

@ -0,0 +1,29 @@
from numpy.fft import *  # noqa: F403
from numpy.fft import __all__ as fft_all

from ..common import _fft
from .._internal import get_xp

import numpy as np

fft = get_xp(np)(_fft.fft)
ifft = get_xp(np)(_fft.ifft)
fftn = get_xp(np)(_fft.fftn)
ifftn = get_xp(np)(_fft.ifftn)
rfft = get_xp(np)(_fft.rfft)
irfft = get_xp(np)(_fft.irfft)
rfftn = get_xp(np)(_fft.rfftn)
irfftn = get_xp(np)(_fft.irfftn)
hfft = get_xp(np)(_fft.hfft)
ihfft = get_xp(np)(_fft.ihfft)
fftfreq = get_xp(np)(_fft.fftfreq)
rfftfreq = get_xp(np)(_fft.rfftfreq)
fftshift = get_xp(np)(_fft.fftshift)
ifftshift = get_xp(np)(_fft.ifftshift)

__all__ = fft_all + _fft.__all__

del get_xp
del np
del fft_all
del _fft

@ -0,0 +1,90 @@
from numpy.linalg import *  # noqa: F403
from numpy.linalg import __all__ as linalg_all
import numpy as _np

from ..common import _linalg
from .._internal import get_xp

# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot  # noqa: F401

import numpy as np

cross = get_xp(np)(_linalg.cross)
outer = get_xp(np)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(np)(_linalg.eigh)
qr = get_xp(np)(_linalg.qr)
slogdet = get_xp(np)(_linalg.slogdet)
svd = get_xp(np)(_linalg.svd)
cholesky = get_xp(np)(_linalg.cholesky)
matrix_rank = get_xp(np)(_linalg.matrix_rank)
pinv = get_xp(np)(_linalg.pinv)
matrix_norm = get_xp(np)(_linalg.matrix_norm)
svdvals = get_xp(np)(_linalg.svdvals)
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)

# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous c.f.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.

# To workaround this, the below is the code from np.linalg.solve except
# only calling solve1 in the exactly 1D case.

# This code is here instead of in common because it is numpy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
    try:
        from numpy.linalg._linalg import (
            _makearray, _assert_stacked_2d, _assert_stacked_square,
            _commonType, isComplexType, _raise_linalgerror_singular
        )
    except ImportError:
        from numpy.linalg.linalg import (
            _makearray, _assert_stacked_2d, _assert_stacked_square,
            _commonType, isComplexType, _raise_linalgerror_singular
        )
    from numpy.linalg import _umath_linalg

    x1, _ = _makearray(x1)
    _assert_stacked_2d(x1)
    _assert_stacked_square(x1)
    x2, wrap = _makearray(x2)
    t, result_t = _commonType(x1, x2)

    # This part is different from np.linalg.solve
    if x2.ndim == 1:
        gufunc = _umath_linalg.solve1
    else:
        gufunc = _umath_linalg.solve

    # This does nothing currently but is left in because it will be relevant
    # when complex dtype support is added to the spec in 2022.
    signature = 'DD->D' if isComplexType(t) else 'dd->d'
    with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
                      over='ignore', divide='ignore', under='ignore'):
        r = gufunc(x1, x2, signature=signature)

    return wrap(r.astype(result_t, copy=False))
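
# Hedged sketch of the 1-D vs. stacked dispatch (shapes are illustrative):
# an exactly 1-D x2 is solved as a single vector; anything else is treated
# as a stack of matrices, unlike np.linalg.solve's ambiguous broadcasting:
#
#     >>> import numpy as np
#     >>> A = np.eye(3)
#     >>> solve(A, np.ones(3)).shape        # 1-D: one right-hand-side vector
#     (3,)
#     >>> solve(A, np.ones((3, 2))).shape   # 2-D: a 3x2 matrix of columns
#     (3, 2)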

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, 'vector_norm'):
    vector_norm = np.linalg.vector_norm
else:
    vector_norm = get_xp(np)(_linalg.vector_norm)

__all__ = linalg_all + _linalg.__all__ + ['solve']

del get_xp
del np
del linalg_all
del _linalg

@ -0,0 +1,24 @@
from torch import *  # noqa: F403

# Several names are not included in the above import *
import torch
for n in dir(torch):
    if (n.startswith('_')
        or n.endswith('_')
        or 'cuda' in n
        or 'cpu' in n
        or 'backward' in n):
        continue
    exec(n + ' = torch.' + n)

# These imports may overwrite names from the import * above.
from ._aliases import *  # noqa: F403

# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')

__import__(__package__ + '.fft')

from ..common._helpers import *  # noqa: F403

__array_api_version__ = '2022.12'

@ -0,0 +1,718 @@
from __future__ import annotations

from functools import wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any

from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
                               vecdot as _aliases_vecdot)
from .._internal import get_xp

import torch

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from typing import List, Optional, Sequence, Tuple, Union
    from ..common._typing import Device
    from torch import dtype as Dtype

array = torch.Tensor

_int_dtypes = {
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
}

_array_api_dtypes = {
    torch.bool,
    *_int_dtypes,
    torch.float32,
    torch.float64,
    torch.complex64,
    torch.complex128,
}

_promotion_table = {
    # bool
    (torch.bool, torch.bool): torch.bool,
    # ints
    (torch.int8, torch.int8): torch.int8,
    (torch.int8, torch.int16): torch.int16,
    (torch.int8, torch.int32): torch.int32,
    (torch.int8, torch.int64): torch.int64,
    (torch.int16, torch.int8): torch.int16,
    (torch.int16, torch.int16): torch.int16,
    (torch.int16, torch.int32): torch.int32,
    (torch.int16, torch.int64): torch.int64,
    (torch.int32, torch.int8): torch.int32,
    (torch.int32, torch.int16): torch.int32,
    (torch.int32, torch.int32): torch.int32,
    (torch.int32, torch.int64): torch.int64,
    (torch.int64, torch.int8): torch.int64,
    (torch.int64, torch.int16): torch.int64,
    (torch.int64, torch.int32): torch.int64,
    (torch.int64, torch.int64): torch.int64,
    # uints
    (torch.uint8, torch.uint8): torch.uint8,
    # ints and uints (mixed sign)
    (torch.int8, torch.uint8): torch.int16,
    (torch.int16, torch.uint8): torch.int16,
    (torch.int32, torch.uint8): torch.int32,
    (torch.int64, torch.uint8): torch.int64,
    (torch.uint8, torch.int8): torch.int16,
    (torch.uint8, torch.int16): torch.int16,
    (torch.uint8, torch.int32): torch.int32,
    (torch.uint8, torch.int64): torch.int64,
    # floats
    (torch.float32, torch.float32): torch.float32,
    (torch.float32, torch.float64): torch.float64,
    (torch.float64, torch.float32): torch.float64,
    (torch.float64, torch.float64): torch.float64,
    # complexes
    (torch.complex64, torch.complex64): torch.complex64,
    (torch.complex64, torch.complex128): torch.complex128,
    (torch.complex128, torch.complex64): torch.complex128,
    (torch.complex128, torch.complex128): torch.complex128,
    # Mixed float and complex
    (torch.float32, torch.complex64): torch.complex64,
    (torch.float32, torch.complex128): torch.complex128,
    (torch.float64, torch.complex64): torch.complex128,
    (torch.float64, torch.complex128): torch.complex128,
}


def _two_arg(f):
    @_wraps(f)
    def _f(x1, x2, /, **kwargs):
        x1, x2 = _fix_promotion(x1, x2)
        return f(x1, x2, **kwargs)
    if _f.__doc__ is None:
        _f.__doc__ = f"""\
Array API compatibility wrapper for torch.{f.__name__}.

See the corresponding PyTorch documentation and/or the array API specification
for more details.

"""
    return _f

def _fix_promotion(x1, x2, only_scalar=True):
    if not isinstance(x1, torch.Tensor) or not isinstance(x2, torch.Tensor):
        return x1, x2
    if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
        return x1, x2
    # If an argument is 0-D pytorch downcasts the other argument
    if not only_scalar or x1.shape == ():
        dtype = result_type(x1, x2)
        x2 = x2.to(dtype)
    if not only_scalar or x2.shape == ():
        dtype = result_type(x1, x2)
        x1 = x1.to(dtype)
    return x1, x2

def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
    if len(arrays_and_dtypes) == 0:
        raise TypeError("At least one array or dtype must be provided")
    if len(arrays_and_dtypes) == 1:
        x = arrays_and_dtypes[0]
        if isinstance(x, torch.dtype):
            return x
        return x.dtype
    if len(arrays_and_dtypes) > 2:
        return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))

    x, y = arrays_and_dtypes
    xdt = x.dtype if not isinstance(x, torch.dtype) else x
    ydt = y.dtype if not isinstance(y, torch.dtype) else y

    if (xdt, ydt) in _promotion_table:
        return _promotion_table[xdt, ydt]

    # We can't pass non-array API dtypes directly to torch.result_type because
    # it only accepts tensors, so materialize empty tensors instead. This does,
    # however, allow cross-kind promotion.
    x = torch.tensor([], dtype=x) if isinstance(x, torch.dtype) else x
    y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
    return torch.result_type(x, y)
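
# Hedged sketch of the promotion behavior (assuming torch is installed):
#
#     >>> import torch
#     >>> result_type(torch.int8, torch.uint8)    # from _promotion_table
#     torch.int16
#     >>> result_type(torch.ones(2, dtype=torch.float32), torch.complex64)
#     torch.complex64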

def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
    if not isinstance(from_, torch.dtype):
        from_ = from_.dtype
    return torch.can_cast(from_, to)

# Basic renames
bitwise_invert = torch.bitwise_not
newaxis = None

# Two-arg elementwise functions
# These require a wrapper to do the correct type promotion on 0-D tensors
add = _two_arg(torch.add)
atan2 = _two_arg(torch.atan2)
bitwise_and = _two_arg(torch.bitwise_and)
bitwise_left_shift = _two_arg(torch.bitwise_left_shift)
bitwise_or = _two_arg(torch.bitwise_or)
bitwise_right_shift = _two_arg(torch.bitwise_right_shift)
bitwise_xor = _two_arg(torch.bitwise_xor)
divide = _two_arg(torch.divide)
# Also a rename. torch.equal does not broadcast
equal = _two_arg(torch.eq)
floor_divide = _two_arg(torch.floor_divide)
greater = _two_arg(torch.greater)
greater_equal = _two_arg(torch.greater_equal)
less = _two_arg(torch.less)
less_equal = _two_arg(torch.less_equal)
logaddexp = _two_arg(torch.logaddexp)
# logical functions are not included here because they only accept bool in the
# spec, so type promotion is irrelevant.
multiply = _two_arg(torch.multiply)
not_equal = _two_arg(torch.not_equal)
pow = _two_arg(torch.pow)
remainder = _two_arg(torch.remainder)
subtract = _two_arg(torch.subtract)

# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
# of 'axis'.

# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745
def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
    # https://github.com/pytorch/pytorch/issues/29137
    if axis == ():
        return torch.clone(x)
    return torch.amax(x, axis, keepdims=keepdims)

def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
    # https://github.com/pytorch/pytorch/issues/29137
    if axis == ():
        return torch.clone(x)
    return torch.amin(x, axis, keepdims=keepdims)

# torch.sort also returns a tuple
# https://github.com/pytorch/pytorch/issues/70921
def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
    return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values

def _normalize_axes(axis, ndim):
    axes = []
    if ndim == 0 and axis:
        # Better error message in this case
        raise IndexError(f"Dimension out of range: {axis[0]}")
    lower, upper = -ndim, ndim - 1
    for a in axis:
        if a < lower or a > upper:
            # Match torch error message (e.g., from sum())
            raise IndexError(f"Dimension out of range (expected to be in range of [{lower}, {upper}], but got {a})")
        if a < 0:
            a = a + ndim
        if a in axes:
            # Use IndexError instead of RuntimeError, and "axis" instead of "dim"
            raise IndexError(f"Axis {a} appears multiple times in the list of axes")
        axes.append(a)
    return sorted(axes)

def _axis_none_keepdims(x, ndim, keepdims):
    # Apply keepdims when axis=None
    # (https://github.com/pytorch/pytorch/issues/71209)
    # Note that this is only valid for the axis=None case.
    if keepdims:
        for i in range(ndim):
            x = torch.unsqueeze(x, 0)
    return x

def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
    # Some reductions don't support multiple axes
    # (https://github.com/pytorch/pytorch/issues/56586).
    axes = _normalize_axes(axis, x.ndim)
    for a in reversed(axes):
        x = torch.movedim(x, a, -1)
    x = torch.flatten(x, -len(axes))

    out = f(x, -1, **kwargs)

    if keepdims:
        for a in axes:
            out = torch.unsqueeze(out, a)
    return out
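
# Hedged sketch of the movedim/flatten trick (shapes are illustrative):
# the reduced axes are moved to the end, flattened into one axis, and the
# single-axis reduction is applied there:
#
#     >>> import torch
#     >>> x = torch.arange(24).reshape(2, 3, 4)
#     >>> _reduce_multiple_axes(torch.amax, x, (0, 2)).shape
#     torch.Size([3])
#     >>> _reduce_multiple_axes(torch.amax, x, (0, 2), keepdims=True).shape
#     torch.Size([1, 3, 1])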
|
||||
|
||||
def prod(x: array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype: Optional[Dtype] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> array:
|
||||
x = torch.asarray(x)
|
||||
ndim = x.ndim
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
|
||||
# below because it still needs to upcast.
|
||||
if axis == ():
|
||||
if dtype is None:
|
||||
# We can't upcast uint8 according to the spec because there is no
|
||||
# torch.uint64, so at least upcast to int64 which is what sum does
|
||||
# when axis=None.
|
||||
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
|
||||
return x.to(torch.int64)
|
||||
return x.clone()
|
||||
return x.to(dtype)
|
||||
|
||||
# torch.prod doesn't support multiple axes
|
||||
# (https://github.com/pytorch/pytorch/issues/56586).
|
||||
if isinstance(axis, tuple):
|
||||
return _reduce_multiple_axes(torch.prod, x, axis, keepdims=keepdims, dtype=dtype, **kwargs)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.prod(x, dtype=dtype, **kwargs)
|
||||
res = _axis_none_keepdims(res, ndim, keepdims)
|
||||
return res
|
||||
|
||||
return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
|
||||
|
||||
|
||||
def sum(x: array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
dtype: Optional[Dtype] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> array:
|
||||
x = torch.asarray(x)
|
||||
ndim = x.ndim
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/29137.
|
||||
# Make sure it upcasts.
|
||||
if axis == ():
|
||||
if dtype is None:
|
||||
# We can't upcast uint8 according to the spec because there is no
|
||||
# torch.uint64, so at least upcast to int64 which is what sum does
|
||||
# when axis=None.
|
||||
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
|
||||
return x.to(torch.int64)
|
||||
return x.clone()
|
||||
return x.to(dtype)
|
||||
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.sum(x, dtype=dtype, **kwargs)
|
||||
res = _axis_none_keepdims(res, ndim, keepdims)
|
||||
return res
|
||||
|
||||
return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
|
||||
|
||||
def any(x: array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> array:
|
||||
x = torch.asarray(x)
|
||||
ndim = x.ndim
|
||||
if axis == ():
|
||||
return x.to(torch.bool)
|
||||
# torch.any doesn't support multiple axes
|
||||
# (https://github.com/pytorch/pytorch/issues/56586).
|
||||
if isinstance(axis, tuple):
|
||||
res = _reduce_multiple_axes(torch.any, x, axis, keepdims=keepdims, **kwargs)
|
||||
return res.to(torch.bool)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.any(x, **kwargs)
|
||||
res = _axis_none_keepdims(res, ndim, keepdims)
|
||||
return res.to(torch.bool)
|
||||
|
||||
# torch.any doesn't return bool for uint8
|
||||
return torch.any(x, axis, keepdims=keepdims).to(torch.bool)
|
||||
|
||||
def all(x: array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> array:
|
||||
x = torch.asarray(x)
|
||||
ndim = x.ndim
|
||||
if axis == ():
|
||||
return x.to(torch.bool)
|
||||
# torch.all doesn't support multiple axes
|
||||
# (https://github.com/pytorch/pytorch/issues/56586).
|
||||
if isinstance(axis, tuple):
|
||||
res = _reduce_multiple_axes(torch.all, x, axis, keepdims=keepdims, **kwargs)
|
||||
return res.to(torch.bool)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.all(x, **kwargs)
|
||||
res = _axis_none_keepdims(res, ndim, keepdims)
|
||||
return res.to(torch.bool)
|
||||
|
||||
# torch.all doesn't return bool for uint8
|
||||
return torch.all(x, axis, keepdims=keepdims).to(torch.bool)
|
||||
|
||||
def mean(x: array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> array:
|
||||
# https://github.com/pytorch/pytorch/issues/29137
|
||||
if axis == ():
|
||||
return torch.clone(x)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.mean(x, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res
|
||||
return torch.mean(x, axis, keepdims=keepdims, **kwargs)
|
||||
|
||||
def std(x: array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
correction: Union[int, float] = 0.0,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> array:
|
||||
# Note, float correction is not supported
|
||||
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
|
||||
# implement it here for now.
|
||||
|
||||
if isinstance(correction, float):
|
||||
_correction = int(correction)
|
||||
if correction != _correction:
|
||||
raise NotImplementedError("float correction in torch std() is not yet supported")
|
||||
else:
|
||||
_correction = correction
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/29137
|
||||
if axis == ():
|
||||
return torch.zeros_like(x)
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.std(x, tuple(range(x.ndim)), correction=_correction, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res
|
||||
return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs)
|
||||
|
||||
def var(x: array,
|
||||
/,
|
||||
*,
|
||||
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
correction: Union[int, float] = 0.0,
|
||||
keepdims: bool = False,
|
||||
**kwargs) -> array:
|
||||
# Note, float correction is not supported
|
||||
# https://github.com/pytorch/pytorch/issues/61492. We don't try to
|
||||
# implement it here for now.
|
||||
|
||||
# if isinstance(correction, float):
|
||||
# correction = int(correction)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/29137
|
||||
if axis == ():
|
||||
return torch.zeros_like(x)
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
if axis is None:
|
||||
# torch doesn't support keepdims with axis=None
|
||||
# (https://github.com/pytorch/pytorch/issues/71209)
|
||||
res = torch.var(x, tuple(range(x.ndim)), correction=correction, **kwargs)
|
||||
res = _axis_none_keepdims(res, x.ndim, keepdims)
|
||||
return res
|
||||
return torch.var(x, axis, correction=correction, keepdims=keepdims, **kwargs)
|
||||
|
||||
# torch.concat doesn't support dim=None
|
||||
# https://github.com/pytorch/pytorch/issues/70925
|
||||
def concat(arrays: Union[Tuple[array, ...], List[array]],
|
||||
/,
|
||||
*,
|
||||
axis: Optional[int] = 0,
|
||||
**kwargs) -> array:
|
||||
if axis is None:
|
||||
arrays = tuple(ar.flatten() for ar in arrays)
|
||||
axis = 0
|
||||
return torch.concat(arrays, axis, **kwargs)
|
||||
|
||||
# torch.squeeze only accepts int dim and doesn't require it
|
||||
# https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was
|
||||
# added at https://github.com/pytorch/pytorch/pull/89017.
|
||||
def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
|
||||
if isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
for a in axis:
|
||||
if x.shape[a] != 1:
|
||||
raise ValueError("squeezed dimensions must be equal to 1")
|
||||
axes = _normalize_axes(axis, x.ndim)
|
||||
# Remove this once pytorch 1.14 is released with the above PR #89017.
|
||||
sequence = [a - i for i, a in enumerate(axes)]
|
||||
for a in sequence:
|
||||
x = torch.squeeze(x, a)
|
||||
return x
|
||||
|
||||
# torch.broadcast_to uses size instead of shape
|
||||
def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array:
|
||||
return torch.broadcast_to(x, shape, **kwargs)
|
||||
|
||||
# torch.permute uses dims instead of axes
|
||||
def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array:
|
||||
return torch.permute(x, axes)
|
||||
|
||||


# The axis parameter doesn't work for flip() and roll()
# https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't
# accept axis=None
def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
    if axis is None:
        axis = tuple(range(x.ndim))
    # torch.flip doesn't accept dim as an int but the method does
    # https://github.com/pytorch/pytorch/issues/18095
    return x.flip(axis, **kwargs)


def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array:
    return torch.roll(x, shift, axis, **kwargs)
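
# Illustrative example (editor's note, not part of the upstream file):
#   flip(torch.tensor([1, 2, 3]))                ->  tensor([3, 2, 1])
#   flip(torch.arange(4).reshape(2, 2), axis=0)  reverses the rows only.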


def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
    if x.ndim == 0:
        raise ValueError("nonzero() does not support zero-dimensional arrays")
    return torch.nonzero(x, as_tuple=True, **kwargs)


def where(condition: array, x1: array, x2: array, /) -> array:
    x1, x2 = _fix_promotion(x1, x2)
    return torch.where(condition, x1, x2)


# torch.reshape doesn't have the copy keyword
def reshape(x: array,
            /,
            shape: Tuple[int, ...],
            copy: Optional[bool] = None,
            **kwargs) -> array:
    if copy is not None:
        raise NotImplementedError("torch.reshape doesn't yet support the copy keyword")
    return torch.reshape(x, shape, **kwargs)


# torch.arange doesn't support returning empty arrays
# (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some
# keyword argument combinations
# (https://github.com/pytorch/pytorch/issues/70914)
def arange(start: Union[int, float],
           /,
           stop: Optional[Union[int, float]] = None,
           step: Union[int, float] = 1,
           *,
           dtype: Optional[Dtype] = None,
           device: Optional[Device] = None,
           **kwargs) -> array:
    if stop is None:
        start, stop = 0, start
    if step > 0 and stop <= start or step < 0 and stop >= start:
        if dtype is None:
            if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
                dtype = torch.int64
            else:
                dtype = torch.float32
        return torch.empty(0, dtype=dtype, device=device, **kwargs)
    return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs)
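
# Illustrative example (editor's note, not part of the upstream file): an
# empty range such as arange(5, 1) returns torch.empty(0) with dtype int64,
# since start, stop, and step are all builtin ints in that call.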


# torch.eye does not accept None as a default for the second argument and
# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)
def eye(n_rows: int,
        n_cols: Optional[int] = None,
        /,
        *,
        k: int = 0,
        dtype: Optional[Dtype] = None,
        device: Optional[Device] = None,
        **kwargs) -> array:
    if n_cols is None:
        n_cols = n_rows
    z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs)
    if abs(k) <= n_rows + n_cols:
        z.diagonal(k).fill_(1)
    return z


# torch.linspace doesn't have the endpoint parameter
def linspace(start: Union[int, float],
             stop: Union[int, float],
             /,
             num: int,
             *,
             dtype: Optional[Dtype] = None,
             device: Optional[Device] = None,
             endpoint: bool = True,
             **kwargs) -> array:
    if not endpoint:
        return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1]
    return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs)
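
# Illustrative example (editor's note, not part of the upstream file): with
# endpoint=False the wrapper asks torch for num+1 points and drops the last,
# e.g. linspace(0.0, 1.0, 4, endpoint=False) gives
# tensor([0.00, 0.25, 0.50, 0.75]).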


# torch.full does not accept an int size
# https://github.com/pytorch/pytorch/issues/70906
def full(shape: Union[int, Tuple[int, ...]],
         fill_value: Union[bool, int, float, complex],
         *,
         dtype: Optional[Dtype] = None,
         device: Optional[Device] = None,
         **kwargs) -> array:
    if isinstance(shape, int):
        shape = (shape,)

    return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs)


# ones, zeros, and empty do not accept shape as a keyword argument
def ones(shape: Union[int, Tuple[int, ...]],
         *,
         dtype: Optional[Dtype] = None,
         device: Optional[Device] = None,
         **kwargs) -> array:
    return torch.ones(shape, dtype=dtype, device=device, **kwargs)


def zeros(shape: Union[int, Tuple[int, ...]],
          *,
          dtype: Optional[Dtype] = None,
          device: Optional[Device] = None,
          **kwargs) -> array:
    return torch.zeros(shape, dtype=dtype, device=device, **kwargs)


def empty(shape: Union[int, Tuple[int, ...]],
          *,
          dtype: Optional[Dtype] = None,
          device: Optional[Device] = None,
          **kwargs) -> array:
    return torch.empty(shape, dtype=dtype, device=device, **kwargs)


# torch.tril and torch.triu name the offset keyword argument `diagonal`
# rather than `k`

def tril(x: array, /, *, k: int = 0) -> array:
    return torch.tril(x, k)


def triu(x: array, /, *, k: int = 0) -> array:
    return torch.triu(x, k)


# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
def expand_dims(x: array, /, *, axis: int = 0) -> array:
    return torch.unsqueeze(x, axis)


def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
    return x.to(dtype, copy=copy)


def broadcast_arrays(*arrays: array) -> List[array]:
    shape = torch.broadcast_shapes(*[a.shape for a in arrays])
    return [torch.broadcast_to(a, shape) for a in arrays]


# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
                               UniqueInverseResult)


# https://github.com/pytorch/pytorch/issues/70920
def unique_all(x: array) -> UniqueAllResult:
    # torch.unique doesn't support returning indices.
    # https://github.com/pytorch/pytorch/issues/36748. The workaround
    # suggested in that issue doesn't actually function correctly (it relies
    # on non-deterministic behavior of scatter()).
    raise NotImplementedError("unique_all() not yet implemented for pytorch (see https://github.com/pytorch/pytorch/issues/36748)")

    # values, inverse_indices, counts = torch.unique(x, return_counts=True, return_inverse=True)
    # # torch.unique incorrectly gives a 0 count for nan values.
    # # https://github.com/pytorch/pytorch/issues/94106
    # counts[torch.isnan(values)] = 1
    # return UniqueAllResult(values, indices, inverse_indices, counts)


def unique_counts(x: array) -> UniqueCountsResult:
    values, counts = torch.unique(x, return_counts=True)

    # torch.unique incorrectly gives a 0 count for nan values.
    # https://github.com/pytorch/pytorch/issues/94106
    counts[torch.isnan(values)] = 1
    return UniqueCountsResult(values, counts)


def unique_inverse(x: array) -> UniqueInverseResult:
    values, inverse = torch.unique(x, return_inverse=True)
    return UniqueInverseResult(values, inverse)


def unique_values(x: array) -> array:
    return torch.unique(x)
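
# Illustrative example (editor's note, not part of the upstream file):
# because NaN != NaN, torch.unique keeps NaN among the values but
# (incorrectly) reports its count as 0; the fixup in unique_counts() above
# resets those counts to 1, so torch.tensor([1., float('nan'), 1.]) yields
# values (1., nan) with counts (2, 1).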


def matmul(x1: array, x2: array, /, **kwargs) -> array:
    # torch.matmul doesn't type promote (but differently from _fix_promotion)
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    return torch.matmul(x1, x2, **kwargs)


matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
_vecdot = get_xp(torch)(_aliases_vecdot)


def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    return _vecdot(x1, x2, axis=axis)


# torch.tensordot uses dims instead of axes
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array:
    # Note: torch.tensordot fails with integer dtypes when there is only 1
    # element in the axis (https://github.com/pytorch/pytorch/issues/84530).
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    return torch.tensordot(x1, x2, dims=axes, **kwargs)


def isdtype(
    dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]],
    *, _tuple=True,  # Disallow nested tuples
) -> bool:
    """
    Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.

    Note that outside of this function, this compat library does not yet fully
    support complex numbers.

    See
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
    for more details
    """
    if isinstance(kind, tuple) and _tuple:
        return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
    elif isinstance(kind, str):
        if kind == 'bool':
            return dtype == torch.bool
        elif kind == 'signed integer':
            return dtype in _int_dtypes and dtype.is_signed
        elif kind == 'unsigned integer':
            return dtype in _int_dtypes and not dtype.is_signed
        elif kind == 'integral':
            return dtype in _int_dtypes
        elif kind == 'real floating':
            return dtype.is_floating_point
        elif kind == 'complex floating':
            return dtype.is_complex
        elif kind == 'numeric':
            return isdtype(dtype, ('integral', 'real floating', 'complex floating'))
        else:
            raise ValueError(f"Unrecognized data type kind: {kind!r}")
    else:
        return dtype == kind
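
# Illustrative example (editor's note, not part of the upstream file):
#   isdtype(torch.float32, 'real floating')    ->  True
#   isdtype(torch.int8, ('integral', 'bool'))  ->  True (matches 'integral')
#   isdtype(torch.complex64, 'numeric')        ->  True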


def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
    if axis is None:
        if x.ndim != 1:
            raise ValueError("axis must be specified when ndim > 1")
        axis = 0
    return torch.index_select(x, axis, indices, **kwargs)


__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
           'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
           'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
           'equal', 'floor_divide', 'greater', 'greater_equal', 'less',
           'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
           'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum',
           'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
           'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
           'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
           'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
           'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
           'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
           'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
           'take']

_all_ignore = ['torch', 'get_xp']
@ -0,0 +1,86 @@
from __future__ import annotations

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    import torch
    array = torch.Tensor
    from typing import Union, Sequence, Literal

from torch.fft import *  # noqa: F403
import torch.fft

# Several torch fft functions do not map axes to dim


def fftn(
    x: array,
    /,
    *,
    s: Sequence[int] = None,
    axes: Sequence[int] = None,
    norm: Literal["backward", "ortho", "forward"] = "backward",
    **kwargs,
) -> array:
    return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs)


def ifftn(
    x: array,
    /,
    *,
    s: Sequence[int] = None,
    axes: Sequence[int] = None,
    norm: Literal["backward", "ortho", "forward"] = "backward",
    **kwargs,
) -> array:
    return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs)
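
# Illustrative note (editor's note, not part of the upstream file): these
# wrappers only rename the array API's `axes` keyword to torch's `dim`, so
# fftn(x, axes=(0, 1)) is equivalent to torch.fft.fftn(x, dim=(0, 1)).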


def rfftn(
    x: array,
    /,
    *,
    s: Sequence[int] = None,
    axes: Sequence[int] = None,
    norm: Literal["backward", "ortho", "forward"] = "backward",
    **kwargs,
) -> array:
    return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs)


def irfftn(
    x: array,
    /,
    *,
    s: Sequence[int] = None,
    axes: Sequence[int] = None,
    norm: Literal["backward", "ortho", "forward"] = "backward",
    **kwargs,
) -> array:
    return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs)


def fftshift(
    x: array,
    /,
    *,
    axes: Union[int, Sequence[int]] = None,
    **kwargs,
) -> array:
    return torch.fft.fftshift(x, dim=axes, **kwargs)


def ifftshift(
    x: array,
    /,
    *,
    axes: Union[int, Sequence[int]] = None,
    **kwargs,
) -> array:
    return torch.fft.ifftshift(x, dim=axes, **kwargs)


__all__ = torch.fft.__all__ + [
    "fftn",
    "ifftn",
    "rfftn",
    "irfftn",
    "fftshift",
    "ifftshift",
]

_all_ignore = ['torch']
@ -0,0 +1,89 @@
from __future__ import annotations

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    import torch
    array = torch.Tensor
    from torch import dtype as Dtype
    from typing import Optional, Union, Tuple, Literal
    inf = float('inf')

from ._aliases import _fix_promotion, sum

from torch.linalg import *  # noqa: F403

# torch.linalg doesn't define __all__
# from torch.linalg import __all__ as linalg_all
from torch import linalg as torch_linalg
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]

# outer is implemented in torch but isn't in the linalg namespace
from torch import outer
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot


# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743

# torch.cross also does not support broadcasting when it would add new
# dimensions https://github.com/pytorch/pytorch/issues/39656
def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)):
        raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}")
    if not (x1.shape[axis] == x2.shape[axis] == 3):
        raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}")
    x1, x2 = torch.broadcast_tensors(x1, x2)
    return torch_linalg.cross(x1, x2, dim=axis)
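
# Illustrative example (editor's note, not part of the upstream file): the
# explicit broadcast above lets cross() accept shapes like (3,) and (2, 3);
# both inputs are broadcast to (2, 3) before torch.linalg.cross is called
# with dim=-1.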


def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
    from ._aliases import isdtype

    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

    # torch.linalg.vecdot incorrectly allows broadcasting along the contracted dimension
    if x1.shape[axis] != x2.shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    # torch.linalg.vecdot doesn't support integer dtypes
    if isdtype(x1.dtype, 'integral') or isdtype(x2.dtype, 'integral'):
        if kwargs:
            raise RuntimeError("vecdot kwargs not supported for integral dtypes")

        x1_ = torch.moveaxis(x1, axis, -1)
        x2_ = torch.moveaxis(x2, axis, -1)
        x1_, x2_ = torch.broadcast_tensors(x1_, x2_)

        res = x1_[..., None, :] @ x2_[..., None]
        return res[..., 0, 0]
    return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs)
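
# Illustrative note (editor's note, not part of the upstream file): for
# integer inputs the matmul fallback above computes the same contraction,
# e.g. vecdot(torch.tensor([1, 2]), torch.tensor([3, 4])) returns tensor(11)
# even though torch.linalg.vecdot itself rejects integer dtypes.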


def solve(x1: array, x2: array, /, **kwargs) -> array:
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    return torch.linalg.solve(x1, x2, **kwargs)


# torch.trace doesn't support the offset argument and doesn't support stacking
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
    # Use our wrapped sum to make sure it does upcasting correctly
    return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)


def vector_norm(
    x: array,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    keepdims: bool = False,
    ord: Union[int, float, Literal[inf, -inf]] = 2,
    **kwargs,
) -> array:
    # torch.vector_norm incorrectly treats axis=() the same as axis=None
    if axis == ():
        keepdims = True
    # torch.linalg.vector_norm names the axis argument `dim`, not `axis`
    return torch.linalg.vector_norm(x, ord=ord, dim=axis, keepdim=keepdims, **kwargs)


__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
                        'cross', 'vecdot', 'solve', 'trace', 'vector_norm']

_all_ignore = ['torch_linalg', 'sum']

del linalg_all
@ -0,0 +1,20 @@
from .main import minimize
from .utils import show_versions

# PEP0440 compatible formatted version, see:
# https://www.python.org/dev/peps/pep-0440/
#
# Final release markers:
#   X.Y.0   # For first release after an increment in Y
#   X.Y.Z   # For bugfix releases
#
# Admissible pre-release markers:
#   X.YaN   # Alpha release
#   X.YbN   # Beta release
#   X.YrcN  # Release Candidate
#
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'.
__version__ = "1.1.1"

__all__ = ["minimize", "show_versions"]
1240
venv/lib/python3.12/site-packages/scipy/_lib/cobyqa/framework.py
Normal file
File diff suppressed because it is too large
1488
venv/lib/python3.12/site-packages/scipy/_lib/cobyqa/main.py
Normal file
File diff suppressed because it is too large
1525
venv/lib/python3.12/site-packages/scipy/_lib/cobyqa/models.py
Normal file
File diff suppressed because it is too large
1287
venv/lib/python3.12/site-packages/scipy/_lib/cobyqa/problem.py
Normal file
File diff suppressed because it is too large
132
venv/lib/python3.12/site-packages/scipy/_lib/cobyqa/settings.py
Normal file
@ -0,0 +1,132 @@
import sys
from enum import Enum

import numpy as np


# Exit status.
class ExitStatus(Enum):
    """
    Exit statuses.
    """

    RADIUS_SUCCESS = 0
    TARGET_SUCCESS = 1
    FIXED_SUCCESS = 2
    CALLBACK_SUCCESS = 3
    FEASIBLE_SUCCESS = 4
    MAX_EVAL_WARNING = 5
    MAX_ITER_WARNING = 6
    INFEASIBLE_ERROR = -1
    LINALG_ERROR = -2


class Options(str, Enum):
    """
    Options.
    """

    DEBUG = "debug"
    FEASIBILITY_TOL = "feasibility_tol"
    FILTER_SIZE = "filter_size"
    HISTORY_SIZE = "history_size"
    MAX_EVAL = "maxfev"
    MAX_ITER = "maxiter"
    NPT = "nb_points"
    RHOBEG = "radius_init"
    RHOEND = "radius_final"
    SCALE = "scale"
    STORE_HISTORY = "store_history"
    TARGET = "target"
    VERBOSE = "disp"


class Constants(str, Enum):
    """
    Constants.
    """

    DECREASE_RADIUS_FACTOR = "decrease_radius_factor"
    INCREASE_RADIUS_FACTOR = "increase_radius_factor"
    INCREASE_RADIUS_THRESHOLD = "increase_radius_threshold"
    DECREASE_RADIUS_THRESHOLD = "decrease_radius_threshold"
    DECREASE_RESOLUTION_FACTOR = "decrease_resolution_factor"
    LARGE_RESOLUTION_THRESHOLD = "large_resolution_threshold"
    MODERATE_RESOLUTION_THRESHOLD = "moderate_resolution_threshold"
    LOW_RATIO = "low_ratio"
    HIGH_RATIO = "high_ratio"
    VERY_LOW_RATIO = "very_low_ratio"
    PENALTY_INCREASE_THRESHOLD = "penalty_increase_threshold"
    PENALTY_INCREASE_FACTOR = "penalty_increase_factor"
    SHORT_STEP_THRESHOLD = "short_step_threshold"
    LOW_RADIUS_FACTOR = "low_radius_factor"
    BYRD_OMOJOKUN_FACTOR = "byrd_omojokun_factor"
    THRESHOLD_RATIO_CONSTRAINTS = "threshold_ratio_constraints"
    LARGE_SHIFT_FACTOR = "large_shift_factor"
    LARGE_GRADIENT_FACTOR = "large_gradient_factor"
    RESOLUTION_FACTOR = "resolution_factor"
    IMPROVE_TCG = "improve_tcg"


# Default options.
DEFAULT_OPTIONS = {
    Options.DEBUG.value: False,
    Options.FEASIBILITY_TOL.value: np.sqrt(np.finfo(float).eps),
    Options.FILTER_SIZE.value: sys.maxsize,
    Options.HISTORY_SIZE.value: sys.maxsize,
    Options.MAX_EVAL.value: lambda n: 500 * n,
    Options.MAX_ITER.value: lambda n: 1000 * n,
    Options.NPT.value: lambda n: 2 * n + 1,
    Options.RHOBEG.value: 1.0,
    Options.RHOEND.value: 1e-6,
    Options.SCALE.value: False,
    Options.STORE_HISTORY.value: False,
    Options.TARGET.value: -np.inf,
    Options.VERBOSE.value: False,
}
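
# Illustrative note (editor's note, not part of the upstream file): the
# callable defaults above are resolved against the problem dimension n. For
# n = 10 they give maxfev = 5000, maxiter = 10000, and
# nb_points = 2 * 10 + 1 = 21.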

# Default constants.
DEFAULT_CONSTANTS = {
    Constants.DECREASE_RADIUS_FACTOR.value: 0.5,
    Constants.INCREASE_RADIUS_FACTOR.value: np.sqrt(2.0),
    Constants.INCREASE_RADIUS_THRESHOLD.value: 2.0,
    Constants.DECREASE_RADIUS_THRESHOLD.value: 1.4,
    Constants.DECREASE_RESOLUTION_FACTOR.value: 0.1,
    Constants.LARGE_RESOLUTION_THRESHOLD.value: 250.0,
    Constants.MODERATE_RESOLUTION_THRESHOLD.value: 16.0,
    Constants.LOW_RATIO.value: 0.1,
    Constants.HIGH_RATIO.value: 0.7,
    Constants.VERY_LOW_RATIO.value: 0.01,
    Constants.PENALTY_INCREASE_THRESHOLD.value: 1.5,
    Constants.PENALTY_INCREASE_FACTOR.value: 2.0,
    Constants.SHORT_STEP_THRESHOLD.value: 0.5,
    Constants.LOW_RADIUS_FACTOR.value: 0.1,
    Constants.BYRD_OMOJOKUN_FACTOR.value: 0.8,
    Constants.THRESHOLD_RATIO_CONSTRAINTS.value: 2.0,
    Constants.LARGE_SHIFT_FACTOR.value: 10.0,
    Constants.LARGE_GRADIENT_FACTOR.value: 10.0,
    Constants.RESOLUTION_FACTOR.value: 2.0,
    Constants.IMPROVE_TCG.value: True,
}

# Printing options.
PRINT_OPTIONS = {
    "threshold": 6,
    "edgeitems": 2,
    "linewidth": sys.maxsize,
    "formatter": {
        "float_kind": lambda x: np.format_float_scientific(
            x,
            precision=3,
            unique=False,
            pad_left=2,
        )
    },
}

# Constants.
BARRIER = 2.0 ** min(
    100,
    np.finfo(float).maxexp // 2,
    -np.finfo(float).minexp // 2,
)
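
# Illustrative note (editor's note, not part of the upstream file): for
# IEEE 754 double precision, maxexp == 1024 and minexp == -1021, so the
# min() above selects 100 and BARRIER evaluates to 2.0**100.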
@ -0,0 +1,14 @@
from .geometry import cauchy_geometry, spider_geometry
from .optim import (
    tangential_byrd_omojokun,
    constrained_tangential_byrd_omojokun,
    normal_byrd_omojokun,
)

__all__ = [
    "cauchy_geometry",
    "spider_geometry",
    "tangential_byrd_omojokun",
    "constrained_tangential_byrd_omojokun",
    "normal_byrd_omojokun",
]
@ -0,0 +1,387 @@
import inspect

import numpy as np

from ..utils import get_arrays_tol


TINY = np.finfo(float).tiny


def cauchy_geometry(const, grad, curv, xl, xu, delta, debug):
    r"""
    Maximize approximately the absolute value of a quadratic function subject
    to bound constraints in a trust region.

    This function solves approximately

    .. math::

        \max_{s \in \mathbb{R}^n} \quad \bigg\lvert c + g^{\mathsf{T}} s +
        \frac{1}{2} s^{\mathsf{T}} H s \bigg\rvert \quad \text{s.t.} \quad
        \left\{ \begin{array}{l}
            l \le s \le u,\\
            \lVert s \rVert \le \Delta,
        \end{array} \right.

    by maximizing the objective function along the constrained Cauchy
    direction.

    Parameters
    ----------
    const : float
        Constant :math:`c` as shown above.
    grad : `numpy.ndarray`, shape (n,)
        Gradient :math:`g` as shown above.
    curv : callable
        Curvature of :math:`H` along any vector.

            ``curv(s) -> float``

        returns :math:`s^{\mathsf{T}} H s`.
    xl : `numpy.ndarray`, shape (n,)
        Lower bounds :math:`l` as shown above.
    xu : `numpy.ndarray`, shape (n,)
        Upper bounds :math:`u` as shown above.
    delta : float
        Trust-region radius :math:`\Delta` as shown above.
    debug : bool
        Whether to make debugging tests during the execution.

    Returns
    -------
    `numpy.ndarray`, shape (n,)
        Approximate solution :math:`s`.

    Notes
    -----
    This function is described as the first alternative in Section 6.5 of [1]_.
    It is assumed that the origin is feasible with respect to the bound
    constraints and that `delta` is finite and positive.

    References
    ----------
    .. [1] T. M. Ragonneau. *Model-Based Derivative-Free Optimization Methods
       and Software*. PhD thesis, Department of Applied Mathematics, The Hong
       Kong Polytechnic University, Hong Kong, China, 2022. URL:
       https://theses.lib.polyu.edu.hk/handle/200/12294.
    """
    if debug:
        assert isinstance(const, float)
        assert isinstance(grad, np.ndarray) and grad.ndim == 1
        assert inspect.signature(curv).bind(grad)
        assert isinstance(xl, np.ndarray) and xl.shape == grad.shape
        assert isinstance(xu, np.ndarray) and xu.shape == grad.shape
        assert isinstance(delta, float)
        assert isinstance(debug, bool)
        tol = get_arrays_tol(xl, xu)
        assert np.all(xl <= tol)
        assert np.all(xu >= -tol)
        assert np.isfinite(delta) and delta > 0.0
    xl = np.minimum(xl, 0.0)
    xu = np.maximum(xu, 0.0)

    # To maximize the absolute value of a quadratic function, we maximize the
    # function itself or its negative, and we choose the solution that
    # provides the largest function value.
    step1, q_val1 = _cauchy_geom(const, grad, curv, xl, xu, delta, debug)
    step2, q_val2 = _cauchy_geom(
        -const,
        -grad,
        lambda x: -curv(x),
        xl,
        xu,
        delta,
        debug,
    )
    step = step1 if abs(q_val1) >= abs(q_val2) else step2

    if debug:
        assert np.all(xl <= step)
        assert np.all(step <= xu)
        assert np.linalg.norm(step) < 1.1 * delta
    return step
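
# Illustrative usage sketch (editor's note, not part of the upstream file;
# all values hypothetical): maximize |c + g's + s'Hs/2| with H = I over the
# box [-1, 1]^2 inside a trust region of radius 0.5.
#
#   s = cauchy_geometry(
#       1.0,                          # const
#       np.array([1.0, -2.0]),        # grad
#       lambda v: float(v @ v),       # curv(s) = s'Hs with H = I (assumed)
#       np.array([-1.0, -1.0]),       # xl
#       np.array([1.0, 1.0]),         # xu
#       0.5,                          # delta
#       False,                        # debug
#   )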


def spider_geometry(const, grad, curv, xpt, xl, xu, delta, debug):
    r"""
    Maximize approximately the absolute value of a quadratic function subject
    to bound constraints in a trust region.

    This function solves approximately

    .. math::

        \max_{s \in \mathbb{R}^n} \quad \bigg\lvert c + g^{\mathsf{T}} s +
        \frac{1}{2} s^{\mathsf{T}} H s \bigg\rvert \quad \text{s.t.} \quad
        \left\{ \begin{array}{l}
            l \le s \le u,\\
            \lVert s \rVert \le \Delta,
        \end{array} \right.

    by maximizing the objective function along given straight lines.

    Parameters
    ----------
    const : float
        Constant :math:`c` as shown above.
    grad : `numpy.ndarray`, shape (n,)
        Gradient :math:`g` as shown above.
    curv : callable
        Curvature of :math:`H` along any vector.

            ``curv(s) -> float``

        returns :math:`s^{\mathsf{T}} H s`.
    xpt : `numpy.ndarray`, shape (n, npt)
        Points defining the straight lines. The straight lines considered are
        the ones passing through the origin and the points in `xpt`.
    xl : `numpy.ndarray`, shape (n,)
        Lower bounds :math:`l` as shown above.
    xu : `numpy.ndarray`, shape (n,)
        Upper bounds :math:`u` as shown above.
    delta : float
        Trust-region radius :math:`\Delta` as shown above.
    debug : bool
        Whether to make debugging tests during the execution.

    Returns
    -------
    `numpy.ndarray`, shape (n,)
        Approximate solution :math:`s`.

    Notes
    -----
    This function is described as the second alternative in Section 6.5 of
    [1]_. It is assumed that the origin is feasible with respect to the bound
    constraints and that `delta` is finite and positive.

    References
    ----------
    .. [1] T. M. Ragonneau. *Model-Based Derivative-Free Optimization Methods
       and Software*. PhD thesis, Department of Applied Mathematics, The Hong
       Kong Polytechnic University, Hong Kong, China, 2022. URL:
       https://theses.lib.polyu.edu.hk/handle/200/12294.
    """
    if debug:
        assert isinstance(const, float)
        assert isinstance(grad, np.ndarray) and grad.ndim == 1
        assert inspect.signature(curv).bind(grad)
        assert (
            isinstance(xpt, np.ndarray)
            and xpt.ndim == 2
            and xpt.shape[0] == grad.size
        )
        assert isinstance(xl, np.ndarray) and xl.shape == grad.shape
        assert isinstance(xu, np.ndarray) and xu.shape == grad.shape
        assert isinstance(delta, float)
        assert isinstance(debug, bool)
        tol = get_arrays_tol(xl, xu)
        assert np.all(xl <= tol)
        assert np.all(xu >= -tol)
        assert np.isfinite(delta) and delta > 0.0
    xl = np.minimum(xl, 0.0)
    xu = np.maximum(xu, 0.0)

    # Iterate through the straight lines.
    step = np.zeros_like(grad)
    q_val = const
    s_norm = np.linalg.norm(xpt, axis=0)

    # Set alpha_xl to the step size for the lower-bound constraint and
    # alpha_xu to the step size for the upper-bound constraint.

    # xl.shape = (N,)
    # xpt.shape = (N, M)
    # i_xl_pos.shape = (M, N)
    i_xl_pos = (xl > -np.inf) & (xpt.T > -TINY * xl)
    i_xl_neg = (xl > -np.inf) & (xpt.T < TINY * xl)
    i_xu_pos = (xu < np.inf) & (xpt.T > TINY * xu)
    i_xu_neg = (xu < np.inf) & (xpt.T < -TINY * xu)

    # (M, N)
    alpha_xl_pos = np.atleast_2d(
        np.broadcast_to(xl, i_xl_pos.shape)[i_xl_pos] / xpt.T[i_xl_pos]
    )
    # (M,)
    alpha_xl_pos = np.max(alpha_xl_pos, axis=1, initial=-np.inf)
    # make sure it's (M,)
    alpha_xl_pos = np.broadcast_to(np.atleast_1d(alpha_xl_pos), xpt.shape[1])

    alpha_xl_neg = np.atleast_2d(
        np.broadcast_to(xl, i_xl_neg.shape)[i_xl_neg] / xpt.T[i_xl_neg]
    )
    alpha_xl_neg = np.max(alpha_xl_neg, axis=1, initial=np.inf)
    alpha_xl_neg = np.broadcast_to(np.atleast_1d(alpha_xl_neg), xpt.shape[1])

    alpha_xu_neg = np.atleast_2d(
        np.broadcast_to(xu, i_xu_neg.shape)[i_xu_neg] / xpt.T[i_xu_neg]
    )
    alpha_xu_neg = np.max(alpha_xu_neg, axis=1, initial=-np.inf)
    alpha_xu_neg = np.broadcast_to(np.atleast_1d(alpha_xu_neg), xpt.shape[1])

    alpha_xu_pos = np.atleast_2d(
        np.broadcast_to(xu, i_xu_pos.shape)[i_xu_pos] / xpt.T[i_xu_pos]
    )
    alpha_xu_pos = np.max(alpha_xu_pos, axis=1, initial=np.inf)
    alpha_xu_pos = np.broadcast_to(np.atleast_1d(alpha_xu_pos), xpt.shape[1])

    for k in range(xpt.shape[1]):
        # Set alpha_tr to the step size for the trust-region constraint.
        if s_norm[k] > TINY * delta:
            alpha_tr = max(delta / s_norm[k], 0.0)
        else:
            # The current straight line is basically zero.
            continue

        alpha_bd_pos = max(min(alpha_xu_pos[k], alpha_xl_neg[k]), 0.0)
        alpha_bd_neg = min(max(alpha_xl_pos[k], alpha_xu_neg[k]), 0.0)

        # Set alpha_quad_pos and alpha_quad_neg to the step size to the
        # extrema of the quadratic function along the positive and negative
        # directions.
        grad_step = grad @ xpt[:, k]
        curv_step = curv(xpt[:, k])
        if (
            grad_step >= 0.0
            and curv_step < -TINY * grad_step
            or grad_step <= 0.0
            and curv_step > -TINY * grad_step
        ):
            alpha_quad_pos = max(-grad_step / curv_step, 0.0)
        else:
            alpha_quad_pos = np.inf
        if (
            grad_step >= 0.0
            and curv_step > TINY * grad_step
            or grad_step <= 0.0
            and curv_step < TINY * grad_step
        ):
            alpha_quad_neg = min(-grad_step / curv_step, 0.0)
        else:
            alpha_quad_neg = -np.inf

        # Select the step that provides the largest value of the objective
        # function if it improves the current best. The best positive step is
        # either the one that reaches the constraints or the one that reaches
        # the extremum of the objective function along the current direction
        # (only possible if the resulting step is feasible). We test both,
        # and we perform similar calculations along the negative step.
        # N.B.: we select the largest possible step among all the ones that
        # maximize the objective function. This is to avoid returning the
        # zero step in some extreme cases.
        alpha_pos = min(alpha_tr, alpha_bd_pos)
        alpha_neg = max(-alpha_tr, alpha_bd_neg)
        q_val_pos = (
            const + alpha_pos * grad_step + 0.5 * alpha_pos**2.0 * curv_step
        )
        q_val_neg = (
            const + alpha_neg * grad_step + 0.5 * alpha_neg**2.0 * curv_step
        )
        if alpha_quad_pos < alpha_pos:
            q_val_quad_pos = (
                const
                + alpha_quad_pos * grad_step
                + 0.5 * alpha_quad_pos**2.0 * curv_step
            )
            if abs(q_val_quad_pos) > abs(q_val_pos):
                alpha_pos = alpha_quad_pos
                q_val_pos = q_val_quad_pos
        if alpha_quad_neg > alpha_neg:
            q_val_quad_neg = (
                const
                + alpha_quad_neg * grad_step
                + 0.5 * alpha_quad_neg**2.0 * curv_step
            )
            if abs(q_val_quad_neg) > abs(q_val_neg):
                alpha_neg = alpha_quad_neg
                q_val_neg = q_val_quad_neg
        if abs(q_val_pos) >= abs(q_val_neg) and abs(q_val_pos) > abs(q_val):
            step = np.clip(alpha_pos * xpt[:, k], xl, xu)
            q_val = q_val_pos
        elif abs(q_val_neg) > abs(q_val_pos) and abs(q_val_neg) > abs(q_val):
            step = np.clip(alpha_neg * xpt[:, k], xl, xu)
            q_val = q_val_neg

    if debug:
        assert np.all(xl <= step)
        assert np.all(step <= xu)
        assert np.linalg.norm(step) < 1.1 * delta
    return step


def _cauchy_geom(const, grad, curv, xl, xu, delta, debug):
    """
    Same as `bound_constrained_cauchy_step` without the absolute value.
    """
    # Calculate the initial active set.
    fixed_xl = (xl < 0.0) & (grad > 0.0)
    fixed_xu = (xu > 0.0) & (grad < 0.0)

    # Calculate the Cauchy step.
    cauchy_step = np.zeros_like(grad)
    cauchy_step[fixed_xl] = xl[fixed_xl]
    cauchy_step[fixed_xu] = xu[fixed_xu]
    if np.linalg.norm(cauchy_step) > delta:
        working = fixed_xl | fixed_xu
        while True:
            # Calculate the Cauchy step for the directions in the working
            # set.
            g_norm = np.linalg.norm(grad[working])
            delta_reduced = np.sqrt(
                delta**2.0 - cauchy_step[~working] @ cauchy_step[~working]
            )
            if g_norm > TINY * abs(delta_reduced):
                mu = max(delta_reduced / g_norm, 0.0)
            else:
                break
            cauchy_step[working] = mu * grad[working]

            # Update the working set.
            fixed_xl = working & (cauchy_step < xl)
            fixed_xu = working & (cauchy_step > xu)
            if not np.any(fixed_xl) and not np.any(fixed_xu):
                # Stop the calculations as the Cauchy step is now feasible.
                break
            cauchy_step[fixed_xl] = xl[fixed_xl]
            cauchy_step[fixed_xu] = xu[fixed_xu]
            working = working & ~(fixed_xl | fixed_xu)

    # Calculate the step that maximizes the quadratic along the Cauchy step.
    grad_step = grad @ cauchy_step
    if grad_step >= 0.0:
        # Set alpha_tr to the step size for the trust-region constraint.
        s_norm = np.linalg.norm(cauchy_step)
        if s_norm > TINY * delta:
            alpha_tr = max(delta / s_norm, 0.0)
        else:
            # The Cauchy step is basically zero.
            alpha_tr = 0.0

        # Set alpha_quad to the step size for the maximization problem.
        curv_step = curv(cauchy_step)
        if curv_step < -TINY * grad_step:
            alpha_quad = max(-grad_step / curv_step, 0.0)
        else:
            alpha_quad = np.inf

        # Set alpha_bd to the step size for the bound constraints.
        i_xl = (xl > -np.inf) & (cauchy_step < TINY * xl)
        i_xu = (xu < np.inf) & (cauchy_step > TINY * xu)
        alpha_xl = np.min(xl[i_xl] / cauchy_step[i_xl], initial=np.inf)
        alpha_xu = np.min(xu[i_xu] / cauchy_step[i_xu], initial=np.inf)
        alpha_bd = min(alpha_xl, alpha_xu)

        # Calculate the solution and the corresponding function value.
        alpha = min(alpha_tr, alpha_quad, alpha_bd)
        step = np.clip(alpha * cauchy_step, xl, xu)
        q_val = const + alpha * grad_step + 0.5 * alpha**2.0 * curv_step
    else:
        # This case is never reached in exact arithmetic. It prevents this
        # function from returning a step that decreases the objective
        # function.
        step = np.zeros_like(grad)
        q_val = const

    if debug:
        assert np.all(xl <= step)
        assert np.all(step <= xu)
        assert np.linalg.norm(step) < 1.1 * delta
    return step, q_val
File diff suppressed because it is too large
@ -0,0 +1,18 @@
from .exceptions import (
    MaxEvalError,
    TargetSuccess,
    CallbackSuccess,
    FeasibleSuccess,
)
from .math import get_arrays_tol, exact_1d_array
from .versions import show_versions

__all__ = [
    "MaxEvalError",
    "TargetSuccess",
    "CallbackSuccess",
    "FeasibleSuccess",
    "get_arrays_tol",
    "exact_1d_array",
    "show_versions",
]
@ -0,0 +1,22 @@
class MaxEvalError(Exception):
    """
    Exception raised when the maximum number of evaluations is reached.
    """


class TargetSuccess(Exception):
    """
    Exception raised when the target value is reached.
    """


class CallbackSuccess(StopIteration):
    """
    Exception raised when the callback function raises a ``StopIteration``.
    """


class FeasibleSuccess(Exception):
    """
    Exception raised when a feasible point of a feasible problem is found.
    """
@ -0,0 +1,77 @@
import numpy as np


EPS = np.finfo(float).eps


def get_arrays_tol(*arrays):
    """
    Get a relative tolerance for a set of arrays.

    Parameters
    ----------
    *arrays: tuple
        Set of `numpy.ndarray` to get the tolerance for.

    Returns
    -------
    float
        Relative tolerance for the set of arrays.

    Raises
    ------
    ValueError
        If no array is provided.
    """
    if len(arrays) == 0:
        raise ValueError("At least one array must be provided.")
    size = max(array.size for array in arrays)
    weight = max(
        np.max(np.abs(array[np.isfinite(array)]), initial=1.0)
        for array in arrays
    )
    return 10.0 * EPS * max(size, 1.0) * weight
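
# Illustrative example (editor's note, not part of the upstream file): for a
# single array of size 3 whose largest finite magnitude is 2.0, the
# tolerance is 10 * eps * 3 * 2.0, roughly 1.33e-14 in double precision.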


def exact_1d_array(x, message):
    """
    Preprocess a 1-dimensional array.

    Parameters
    ----------
    x : array_like
        Array to be preprocessed.
    message : str
        Error message if `x` cannot be interpreted as a 1-dimensional array.

    Returns
    -------
    `numpy.ndarray`
        Preprocessed array.
    """
    x = np.atleast_1d(np.squeeze(x)).astype(float)
    if x.ndim != 1:
        raise ValueError(message)
    return x


def exact_2d_array(x, message):
    """
    Preprocess a 2-dimensional array.

    Parameters
    ----------
    x : array_like
        Array to be preprocessed.
    message : str
        Error message if `x` cannot be interpreted as a 2-dimensional array.

    Returns
    -------
    `numpy.ndarray`
        Preprocessed array.
    """
    x = np.atleast_2d(x).astype(float)
    if x.ndim != 2:
        raise ValueError(message)
    return x
@ -0,0 +1,67 @@
import os
import platform
import sys
from importlib.metadata import PackageNotFoundError, version


def _get_sys_info():
    """
    Get useful system information.

    Returns
    -------
    dict
        Useful system information.
    """
    return {
        "python": sys.version.replace(os.linesep, " "),
        "executable": sys.executable,
        "machine": platform.platform(),
    }


def _get_deps_info():
    """
    Get the versions of the dependencies.

    Returns
    -------
    dict
        Versions of the dependencies.
    """
    deps = ["cobyqa", "numpy", "scipy", "setuptools", "pip"]
    deps_info = {}
    for module in deps:
        try:
            deps_info[module] = version(module)
        except PackageNotFoundError:
            deps_info[module] = None
    return deps_info


def show_versions():
    """
    Display useful system and dependencies information.

    When reporting issues, please include this information.
    """
    print("System settings")
    print("---------------")
    sys_info = _get_sys_info()
    print(
        "\n".join(
            f"{k:>{max(map(len, sys_info.keys())) + 1}}: {v}"
            for k, v in sys_info.items()
        )
    )

    print()
    print("Python dependencies")
    print("-------------------")
    deps_info = _get_deps_info()
    print(
        "\n".join(
            f"{k:>{max(map(len, deps_info.keys())) + 1}}: {v}"
            for k, v in deps_info.items()
        )
    )
399
venv/lib/python3.12/site-packages/scipy/_lib/decorator.py
Normal file
@ -0,0 +1,399 @@
# ######################### LICENSE ############################ #

# Copyright (c) 2005-2015, Michele Simionato
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:

#   Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
#   Redistributions in bytecode form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in
#   the documentation and/or other materials provided with the
#   distribution.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.

"""
Decorator module, see https://pypi.python.org/pypi/decorator
for the documentation.
"""
import re
import sys
import inspect
import operator
import itertools
import collections

from inspect import getfullargspec

__version__ = '4.0.5'


def get_init(cls):
    return cls.__init__


# getargspec has been deprecated in Python 3.5
ArgSpec = collections.namedtuple(
    'ArgSpec', 'args varargs varkw defaults')


def getargspec(f):
    """A replacement for inspect.getargspec"""
    spec = getfullargspec(f)
    return ArgSpec(spec.args, spec.varargs, spec.varkw, spec.defaults)


DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(')


# basic functionality
class FunctionMaker:
    """
    An object with the ability to create functions with a given signature.
    It has attributes name, doc, module, signature, defaults, dict, and
    methods update and make.
    """

    # Atomic get-and-increment provided by the GIL
    _compile_count = itertools.count()

    def __init__(self, func=None, name=None, signature=None,
                 defaults=None, doc=None, module=None, funcdict=None):
        self.shortsignature = signature
        if func:
            # func can be a class or a callable, but not an instance method
            self.name = func.__name__
            if self.name == '<lambda>':  # small hack for lambda functions
                self.name = '_lambda_'
            self.doc = func.__doc__
            self.module = func.__module__
            if inspect.isfunction(func):
                argspec = getfullargspec(func)
                self.annotations = getattr(func, '__annotations__', {})
                for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
                          'kwonlydefaults'):
                    setattr(self, a, getattr(argspec, a))
                for i, arg in enumerate(self.args):
                    setattr(self, 'arg%d' % i, arg)
                allargs = list(self.args)
                allshortargs = list(self.args)
                if self.varargs:
                    allargs.append('*' + self.varargs)
                    allshortargs.append('*' + self.varargs)
                elif self.kwonlyargs:
                    allargs.append('*')  # single star syntax
                for a in self.kwonlyargs:
                    allargs.append('%s=None' % a)
                    allshortargs.append(f'{a}={a}')
                if self.varkw:
                    allargs.append('**' + self.varkw)
                    allshortargs.append('**' + self.varkw)
                self.signature = ', '.join(allargs)
                self.shortsignature = ', '.join(allshortargs)
                self.dict = func.__dict__.copy()
        # func=None happens when decorating a caller
        if name:
            self.name = name
        if signature is not None:
            self.signature = signature
        if defaults:
            self.defaults = defaults
        if doc:
            self.doc = doc
        if module:
            self.module = module
        if funcdict:
            self.dict = funcdict
        # check the existence of required attributes
        assert hasattr(self, 'name')
        if not hasattr(self, 'signature'):
            raise TypeError('You are decorating a non-function: %s' % func)

    def update(self, func, **kw):
        "Update the signature of func with the data in self"
        func.__name__ = self.name
        func.__doc__ = getattr(self, 'doc', None)
        func.__dict__ = getattr(self, 'dict', {})
        func.__defaults__ = getattr(self, 'defaults', ())
        func.__kwdefaults__ = getattr(self, 'kwonlydefaults', None)
        func.__annotations__ = getattr(self, 'annotations', None)
        try:
            frame = sys._getframe(3)
        except AttributeError:  # for IronPython and similar implementations
            callermodule = '?'
        else:
            callermodule = frame.f_globals.get('__name__', '?')
        func.__module__ = getattr(self, 'module', callermodule)
        func.__dict__.update(kw)

    def make(self, src_templ, evaldict=None, addsource=False, **attrs):
        "Make a new function from a given template and update the signature"
        src = src_templ % vars(self)  # expand name and signature
        evaldict = evaldict or {}
        mo = DEF.match(src)
        if mo is None:
            raise SyntaxError('not a valid function template\n%s' % src)
        name = mo.group(1)  # extract the function name
        names = set([name] + [arg.strip(' *') for arg in
                              self.shortsignature.split(',')])
        for n in names:
            if n in ('_func_', '_call_'):
                raise NameError(f'{n} is overridden in\n{src}')
        if not src.endswith('\n'):  # add a newline just for safety
            src += '\n'  # this is needed in old versions of Python

        # Ensure each generated function has a unique filename for profilers
        # (such as cProfile) that depend on the tuple of (<filename>,
        # <definition line>, <function name>) being unique.
        filename = '<decorator-gen-%d>' % (next(self._compile_count),)
        try:
            code = compile(src, filename, 'single')
            exec(code, evaldict)
        except:  # noqa: E722
            print('Error in generated code:', file=sys.stderr)
            print(src, file=sys.stderr)
            raise
        func = evaldict[name]
        if addsource:
            attrs['__source__'] = src
        self.update(func, **attrs)
        return func

    @classmethod
    def create(cls, obj, body, evaldict, defaults=None,
               doc=None, module=None, addsource=True, **attrs):
        """
        Create a function from the strings name, signature, and body.
        evaldict is the evaluation dictionary. If addsource is true, an
        attribute __source__ is added to the result. The attributes attrs
        are added, if any.
        """
        if isinstance(obj, str):  # "name(signature)"
            name, rest = obj.strip().split('(', 1)
            signature = rest[:-1]  # strip a right parens
            func = None
        else:  # a function
            name = None
            signature = None
            func = obj
        self = cls(func, name, signature, defaults, doc, module)
        ibody = '\n'.join('    ' + line for line in body.splitlines())
        return self.make('def %(name)s(%(signature)s):\n' + ibody,
                         evaldict, addsource, **attrs)


def decorate(func, caller):
    """
    decorate(func, caller) decorates a function using a caller.
    """
    evaldict = func.__globals__.copy()
    evaldict['_call_'] = caller
    evaldict['_func_'] = func
    fun = FunctionMaker.create(
        func, "return _call_(_func_, %(shortsignature)s)",
        evaldict, __wrapped__=func)
    if hasattr(func, '__qualname__'):
        fun.__qualname__ = func.__qualname__
    return fun
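
# Illustrative usage sketch (editor's note, not part of the upstream file;
# the helper names are hypothetical):
#
#   def log_calls(func, *args, **kwargs):
#       print('calling', func.__name__)
#       return func(*args, **kwargs)
#
#   def add(x, y):
#       return x + y
#
#   add = decorate(add, log_calls)   # add keeps its (x, y) signature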


def decorator(caller, _func=None):
    """decorator(caller) converts a caller function into a decorator"""
    if _func is not None:  # return a decorated function
        # this is obsolete behavior; you should use decorate instead
        return decorate(_func, caller)
    # else return a decorator function
    if inspect.isclass(caller):
        name = caller.__name__.lower()
        callerfunc = get_init(caller)
        doc = (f'decorator({caller.__name__}) converts functions/generators into '
               f'factories of {caller.__name__} objects')
    elif inspect.isfunction(caller):
        if caller.__name__ == '<lambda>':
            name = '_lambda_'
        else:
            name = caller.__name__
        callerfunc = caller
        doc = caller.__doc__
    else:  # assume caller is an object with a __call__ method
        name = caller.__class__.__name__.lower()
        callerfunc = caller.__call__.__func__
        doc = caller.__call__.__doc__
    evaldict = callerfunc.__globals__.copy()
    evaldict['_call_'] = caller
    evaldict['_decorate_'] = decorate
    return FunctionMaker.create(
        '%s(func)' % name, 'return _decorate_(func, _call_)',
        evaldict, doc=doc, module=caller.__module__,
        __wrapped__=caller)


# ####################### contextmanager ####################### #

try:  # Python >= 3.2
    from contextlib import _GeneratorContextManager
except ImportError:  # Python >= 2.5
    from contextlib import GeneratorContextManager as _GeneratorContextManager


class ContextManager(_GeneratorContextManager):
    def __call__(self, func):
        """Context manager decorator"""
        return FunctionMaker.create(
            func, "with _self_: return _func_(%(shortsignature)s)",
            dict(_self_=self, _func_=func), __wrapped__=func)


init = getfullargspec(_GeneratorContextManager.__init__)
n_args = len(init.args)
if n_args == 2 and not init.varargs:  # (self, genobj) Python 2.7
    def __init__(self, g, *a, **k):
        return _GeneratorContextManager.__init__(self, g(*a, **k))
    ContextManager.__init__ = __init__
elif n_args == 2 and init.varargs:  # (self, gen, *a, **k) Python 3.4
    pass
elif n_args == 4:  # (self, gen, args, kwds) Python 3.5
    def __init__(self, g, *a, **k):
        return _GeneratorContextManager.__init__(self, g, a, k)
    ContextManager.__init__ = __init__

contextmanager = decorator(ContextManager)


# ############################ dispatch_on ############################ #

def append(a, vancestors):
    """
    Append ``a`` to the list of the virtual ancestors, unless it is already
    included.
    """
    add = True
    for j, va in enumerate(vancestors):
        if issubclass(va, a):
            add = False
            break
        if issubclass(a, va):
            vancestors[j] = a
            add = False
    if add:
        vancestors.append(a)
# inspired from simplegeneric by P.J. Eby and functools.singledispatch
|
||||
def dispatch_on(*dispatch_args):
|
||||
"""
|
||||
Factory of decorators turning a function into a generic function
|
||||
dispatching on the given arguments.
|
||||
"""
|
||||
assert dispatch_args, 'No dispatch args passed'
|
||||
dispatch_str = '(%s,)' % ', '.join(dispatch_args)
|
||||
|
||||
def check(arguments, wrong=operator.ne, msg=''):
|
||||
"""Make sure one passes the expected number of arguments"""
|
||||
if wrong(len(arguments), len(dispatch_args)):
|
||||
raise TypeError('Expected %d arguments, got %d%s' %
|
||||
(len(dispatch_args), len(arguments), msg))
|
||||
|
||||
def gen_func_dec(func):
|
||||
"""Decorator turning a function into a generic function"""
|
||||
|
||||
# first check the dispatch arguments
|
||||
argset = set(getfullargspec(func).args)
|
||||
if not set(dispatch_args) <= argset:
|
||||
raise NameError('Unknown dispatch arguments %s' % dispatch_str)
|
||||
|
||||
typemap = {}
|
||||
|
||||
def vancestors(*types):
|
||||
"""
|
||||
Get a list of sets of virtual ancestors for the given types
|
||||
"""
|
||||
check(types)
|
||||
ras = [[] for _ in range(len(dispatch_args))]
|
||||
for types_ in typemap:
|
||||
for t, type_, ra in zip(types, types_, ras):
|
||||
if issubclass(t, type_) and type_ not in t.__mro__:
|
||||
append(type_, ra)
|
||||
return [set(ra) for ra in ras]
|
||||
|
||||
def ancestors(*types):
|
||||
"""
|
||||
Get a list of virtual MROs, one for each type
|
||||
"""
|
||||
check(types)
|
||||
lists = []
|
||||
for t, vas in zip(types, vancestors(*types)):
|
||||
n_vas = len(vas)
|
||||
if n_vas > 1:
|
||||
raise RuntimeError(
|
||||
f'Ambiguous dispatch for {t}: {vas}')
|
||||
elif n_vas == 1:
|
||||
va, = vas
|
||||
mro = type('t', (t, va), {}).__mro__[1:]
|
||||
else:
|
||||
mro = t.__mro__
|
||||
lists.append(mro[:-1]) # discard t and object
|
||||
return lists
|
||||
|
||||
def register(*types):
|
||||
"""
|
||||
Decorator to register an implementation for the given types
|
||||
"""
|
||||
check(types)
|
||||
|
||||
def dec(f):
|
||||
check(getfullargspec(f).args, operator.lt, ' in ' + f.__name__)
|
||||
typemap[types] = f
|
||||
return f
|
||||
return dec
|
||||
|
||||
def dispatch_info(*types):
|
||||
"""
|
||||
An utility to introspect the dispatch algorithm
|
||||
"""
|
||||
check(types)
|
||||
lst = [tuple(a.__name__ for a in anc)
|
||||
for anc in itertools.product(*ancestors(*types))]
|
||||
return lst
|
||||
|
||||
def _dispatch(dispatch_args, *args, **kw):
|
||||
types = tuple(type(arg) for arg in dispatch_args)
|
||||
try: # fast path
|
||||
f = typemap[types]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return f(*args, **kw)
|
||||
combinations = itertools.product(*ancestors(*types))
|
||||
next(combinations) # the first one has been already tried
|
||||
for types_ in combinations:
|
||||
f = typemap.get(types_)
|
||||
if f is not None:
|
||||
return f(*args, **kw)
|
||||
|
||||
# else call the default implementation
|
||||
return func(*args, **kw)
|
||||
|
||||
return FunctionMaker.create(
|
||||
func, 'return _f_(%s, %%(shortsignature)s)' % dispatch_str,
|
||||
dict(_f_=_dispatch), register=register, default=func,
|
||||
typemap=typemap, vancestors=vancestors, ancestors=ancestors,
|
||||
dispatch_info=dispatch_info, __wrapped__=func)
|
||||
|
||||
gen_func_dec.__name__ = 'dispatch_on' + dispatch_str
|
||||
return gen_func_dec
|
||||
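
# --- Editor's note: illustrative sketch, not part of the vendored file. ---
# `dispatch_on` returns a decorator that turns a function into a generic
# function with a `.register` attribute; implementations are chosen by the
# runtime types of the named argument. The names `serialize`,
# `serialize_int`, and `serialize_list` below are hypothetical.
from scipy._lib.decorator import dispatch_on  # assumed vendored import path

@dispatch_on('obj')
def serialize(obj):
    # default implementation, used when no registered type matches
    raise NotImplementedError(f'cannot serialize {type(obj)}')

@serialize.register(int)
def serialize_int(obj):
    return str(obj)

@serialize.register(list)
def serialize_list(obj):
    # recursive dispatch: each element goes back through `serialize`
    return '[%s]' % ', '.join(serialize(x) for x in obj)

assert serialize([1, 2, 3]) == '[1, 2, 3]'
# --- end editor's note ---
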
239
venv/lib/python3.12/site-packages/scipy/_lib/deprecation.py
Normal file
@ -0,0 +1,239 @@
from inspect import Parameter, signature
import functools
import warnings
from importlib import import_module


__all__ = ["_deprecated"]


# Object to use as default value for arguments to be deprecated. This should
# be used over 'None' as the user could pass 'None' as a positional argument
_NoValue = object()


def _sub_module_deprecation(*, sub_package, module, private_modules, all,
                            attribute, correct_module=None):
    """Helper function for deprecating modules that are public but were
    intended to be private.

    Parameters
    ----------
    sub_package : str
        Subpackage the module belongs to, e.g. ``stats``
    module : str
        Public but intended private module to deprecate
    private_modules : list
        Private replacement(s) for `module`; should contain the
        content of ``all``, possibly spread over several modules.
    all : list
        ``__all__`` belonging to `module`
    attribute : str
        The attribute in `module` being accessed
    correct_module : str, optional
        Module in `sub_package` that `attribute` should be imported from.
        Default is that `attribute` should be imported from ``scipy.sub_package``.
    """
    if correct_module is not None:
        correct_import = f"scipy.{sub_package}.{correct_module}"
    else:
        correct_import = f"scipy.{sub_package}"

    if attribute not in all:
        raise AttributeError(
            f"`scipy.{sub_package}.{module}` has no attribute `{attribute}`; "
            f"furthermore, `scipy.{sub_package}.{module}` is deprecated "
            f"and will be removed in SciPy 2.0.0."
        )

    attr = getattr(import_module(correct_import), attribute, None)

    if attr is not None:
        message = (
            f"Please import `{attribute}` from the `{correct_import}` namespace; "
            f"the `scipy.{sub_package}.{module}` namespace is deprecated "
            f"and will be removed in SciPy 2.0.0."
        )
    else:
        message = (
            f"`scipy.{sub_package}.{module}.{attribute}` is deprecated along with "
            f"the `scipy.{sub_package}.{module}` namespace. "
            f"`scipy.{sub_package}.{module}.{attribute}` will be removed "
            f"in SciPy 1.14.0, and the `scipy.{sub_package}.{module}` namespace "
            f"will be removed in SciPy 2.0.0."
        )

    warnings.warn(message, category=DeprecationWarning, stacklevel=3)

    for module in private_modules:
        try:
            return getattr(import_module(f"scipy.{sub_package}.{module}"), attribute)
        except AttributeError as e:
            # still raise an error if the attribute isn't in any of the expected
            # private modules
            if module == private_modules[-1]:
                raise e
            continue
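
# --- Editor's note: illustrative sketch, not part of the vendored file. ---
# A deprecated-but-public SciPy module typically wires this helper into a
# PEP 562 module-level `__getattr__`. The subpackage/module names below
# mirror how, e.g., `scipy.stats.morestats` uses it, but are shown here
# only as an example (`__all__` is assumed to be defined in that module).
def __getattr__(name):
    return _sub_module_deprecation(sub_package="stats", module="morestats",
                                   private_modules=["_morestats"], all=__all__,
                                   attribute=name)
# --- end editor's note ---
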
def _deprecated(msg, stacklevel=2):
    """Deprecate a function by emitting a warning on use."""
    def wrap(fun):
        if isinstance(fun, type):
            warnings.warn(
                f"Trying to deprecate class {fun!r}",
                category=RuntimeWarning, stacklevel=2)
            return fun

        @functools.wraps(fun)
        def call(*args, **kwargs):
            warnings.warn(msg, category=DeprecationWarning,
                          stacklevel=stacklevel)
            return fun(*args, **kwargs)
        call.__doc__ = fun.__doc__
        return call

    return wrap
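
# --- Editor's note: illustrative sketch, not part of the vendored file. ---
# Using `_deprecated` as a decorator factory; `old_helper` is hypothetical.
@_deprecated("`old_helper` is deprecated, use `new_helper` instead")
def old_helper(x):
    return x + 1

old_helper(1)  # emits DeprecationWarning with the message above, returns 2
# --- end editor's note ---
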
class _DeprecationHelperStr:
    """
    Helper class used by deprecate_cython_api
    """
    def __init__(self, content, message):
        self._content = content
        self._message = message

    def __hash__(self):
        return hash(self._content)

    def __eq__(self, other):
        res = (self._content == other)
        if res:
            warnings.warn(self._message, category=DeprecationWarning,
                          stacklevel=2)
        return res
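
# --- Editor's note: illustrative sketch, not part of the vendored file. ---
# Why this helper works: a dict lookup first matches on hash, then falls
# back to `__eq__` against the stored key, so swapping a plain-string key
# for a `_DeprecationHelperStr` makes every successful lookup warn first.
d = {"plus1": object()}
d[_DeprecationHelperStr("plus1", "`plus1` is deprecated!")] = d.pop("plus1")
_ = d["plus1"]  # same value as before, but a DeprecationWarning is emitted
# --- end editor's note ---
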
def deprecate_cython_api(module, routine_name, new_name=None, message=None):
    """
    Deprecate an exported cdef function in a public Cython API module.

    Only functions can be deprecated; typedefs etc. cannot.

    Parameters
    ----------
    module : module
        Public Cython API module (e.g. scipy.linalg.cython_blas).
    routine_name : str
        Name of the routine to deprecate. May also be a fused-type
        routine (in which case all of its specializations are deprecated).
    new_name : str
        New name to include in the deprecation warning message
    message : str
        Additional text in the deprecation warning message

    Examples
    --------
    Usually, this function would be used in the top-level of the
    module ``.pyx`` file:

    >>> from scipy._lib.deprecation import deprecate_cython_api
    >>> import scipy.linalg.cython_blas as mod
    >>> deprecate_cython_api(mod, "dgemm", "dgemm_new",
    ...                      message="Deprecated in Scipy 1.5.0")
    >>> del deprecate_cython_api, mod

    After this, Cython modules that use the deprecated function emit a
    deprecation warning when they are imported.

    """
    old_name = f"{module.__name__}.{routine_name}"

    if new_name is None:
        depdoc = "`%s` is deprecated!" % old_name
    else:
        depdoc = f"`{old_name}` is deprecated, use `{new_name}` instead!"

    if message is not None:
        depdoc += "\n" + message

    d = module.__pyx_capi__

    # Check if the function is a fused-type function with a mangled name
    j = 0
    has_fused = False
    while True:
        fused_name = f"__pyx_fuse_{j}{routine_name}"
        if fused_name in d:
            has_fused = True
            d[_DeprecationHelperStr(fused_name, depdoc)] = d.pop(fused_name)
            j += 1
        else:
            break

    # If not, apply deprecation to the named routine
    if not has_fused:
        d[_DeprecationHelperStr(routine_name, depdoc)] = d.pop(routine_name)
# taken from scikit-learn, see
# https://github.com/scikit-learn/scikit-learn/blob/1.3.0/sklearn/utils/validation.py#L38
def _deprecate_positional_args(func=None, *, version=None):
    """Decorator for methods that issues a warning for positional arguments.

    Using the keyword-only argument syntax of PEP 3102, arguments after the
    ``*`` will issue a warning when passed as positional arguments.

    Parameters
    ----------
    func : callable, default=None
        Function to check arguments on.
    version : str, default=None
        The version in which passing these arguments positionally will
        become an error.
    """
    if version is None:
        msg = "Need to specify a version where signature will be changed"
        raise ValueError(msg)

    def _inner_deprecate_positional_args(f):
        sig = signature(f)
        kwonly_args = []
        all_args = []

        for name, param in sig.parameters.items():
            if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
                all_args.append(name)
            elif param.kind == Parameter.KEYWORD_ONLY:
                kwonly_args.append(name)

        @functools.wraps(f)
        def inner_f(*args, **kwargs):
            extra_args = len(args) - len(all_args)
            if extra_args <= 0:
                return f(*args, **kwargs)

            # extra_args > 0
            args_msg = [
                f"{name}={arg}"
                for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:])
            ]
            args_msg = ", ".join(args_msg)
            warnings.warn(
                (
                    f"You are passing {args_msg} as a positional argument. "
                    "Please change your invocation to use keyword arguments. "
                    f"From SciPy {version}, passing these as positional "
                    "arguments will result in an error."
                ),
                DeprecationWarning,
                stacklevel=2,
            )
            kwargs.update(zip(sig.parameters, args))
            return f(**kwargs)

        return inner_f

    if func is not None:
        return _inner_deprecate_positional_args(func)

    return _inner_deprecate_positional_args
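
# --- Editor's note: illustrative sketch, not part of the vendored file. ---
# `scale` is keyword-only, so passing it positionally warns until `version`;
# the function `resample` is hypothetical.
@_deprecate_positional_args(version="1.14.0")
def resample(x, *, scale=1.0):
    return [v * scale for v in x]

resample([1.0, 2.0], 2.0)        # DeprecationWarning: pass scale=2.0 by keyword
resample([1.0, 2.0], scale=2.0)  # no warning
# --- end editor's note ---
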
275
venv/lib/python3.12/site-packages/scipy/_lib/doccer.py
Normal file
@ -0,0 +1,275 @@
''' Utilities to allow inserting docstring fragments for common
parameters into function and method docstrings'''

import sys

__all__ = [
    'docformat', 'inherit_docstring_from', 'indentcount_lines',
    'filldoc', 'unindent_dict', 'unindent_string', 'extend_notes_in_docstring',
    'replace_notes_in_docstring', 'doc_replace'
]


def docformat(docstring, docdict=None):
    ''' Fill a function docstring from variables in dictionary

    Adapt the indent of the inserted docs

    Parameters
    ----------
    docstring : string
        docstring from function, possibly with dict formatting strings
    docdict : dict, optional
        dictionary with keys that match the dict formatting strings
        and values that are docstring fragments to be inserted. The
        indentation of the inserted docstrings is set to match the
        minimum indentation of the ``docstring`` by adding this
        indentation to all lines of the inserted string, except the
        first.

    Returns
    -------
    outstring : string
        string with requested ``docdict`` strings inserted

    Examples
    --------
    >>> docformat(' Test string with %(value)s', {'value':'inserted value'})
    ' Test string with inserted value'
    >>> docstring = 'First line\\n    Second line\\n    %(value)s'
    >>> inserted_string = "indented\\nstring"
    >>> docdict = {'value': inserted_string}
    >>> docformat(docstring, docdict)
    'First line\\n    Second line\\n    indented\\n    string'
    '''
    if not docstring:
        return docstring
    if docdict is None:
        docdict = {}
    if not docdict:
        return docstring
    lines = docstring.expandtabs().splitlines()
    # Find the minimum indent of the main docstring, after first line
    if len(lines) < 2:
        icount = 0
    else:
        icount = indentcount_lines(lines[1:])
    indent = ' ' * icount
    # Insert this indent to dictionary docstrings
    indented = {}
    for name, dstr in docdict.items():
        lines = dstr.expandtabs().splitlines()
        try:
            newlines = [lines[0]]
            for line in lines[1:]:
                newlines.append(indent+line)
            indented[name] = '\n'.join(newlines)
        except IndexError:
            indented[name] = dstr
    return docstring % indented
def inherit_docstring_from(cls):
    """
    This decorator modifies the decorated function's docstring by
    replacing occurrences of '%(super)s' with the docstring of the
    method of the same name from the class `cls`.

    If the decorated method has no docstring, it is simply given the
    docstring of ``cls``'s method.

    Parameters
    ----------
    cls : Python class or instance
        A class with a method with the same name as the decorated method.
        The docstring of the method in this class replaces '%(super)s' in the
        docstring of the decorated method.

    Returns
    -------
    f : function
        The decorator function that modifies the __doc__ attribute
        of its argument.

    Examples
    --------
    In the following, the docstring for Bar.func is created using the
    docstring of `Foo.func`.

    >>> class Foo:
    ...     def func(self):
    ...         '''Do something useful.'''
    ...         return
    ...
    >>> class Bar(Foo):
    ...     @inherit_docstring_from(Foo)
    ...     def func(self):
    ...         '''%(super)s
    ...         Do it fast.
    ...         '''
    ...         return
    ...
    >>> b = Bar()
    >>> b.func.__doc__
    'Do something useful.\n        Do it fast.\n        '

    """
    def _doc(func):
        cls_docstring = getattr(cls, func.__name__).__doc__
        func_docstring = func.__doc__
        if func_docstring is None:
            func.__doc__ = cls_docstring
        else:
            new_docstring = func_docstring % dict(super=cls_docstring)
            func.__doc__ = new_docstring
        return func
    return _doc
def extend_notes_in_docstring(cls, notes):
    """
    This decorator replaces the decorated function's docstring
    with the docstring from the corresponding method in `cls`.
    It extends the 'Notes' section of that docstring to include
    the given `notes`.
    """
    def _doc(func):
        cls_docstring = getattr(cls, func.__name__).__doc__
        # If python is called with -OO option,
        # there is no docstring
        if cls_docstring is None:
            return func
        end_of_notes = cls_docstring.find('        References\n')
        if end_of_notes == -1:
            end_of_notes = cls_docstring.find('        Examples\n')
            if end_of_notes == -1:
                end_of_notes = len(cls_docstring)
        func.__doc__ = (cls_docstring[:end_of_notes] + notes +
                        cls_docstring[end_of_notes:])
        return func
    return _doc


def replace_notes_in_docstring(cls, notes):
    """
    This decorator replaces the decorated function's docstring
    with the docstring from the corresponding method in `cls`.
    It replaces the 'Notes' section of that docstring with
    the given `notes`.
    """
    def _doc(func):
        cls_docstring = getattr(cls, func.__name__).__doc__
        notes_header = '        Notes\n        -----\n'
        # If python is called with -OO option,
        # there is no docstring
        if cls_docstring is None:
            return func
        start_of_notes = cls_docstring.find(notes_header)
        end_of_notes = cls_docstring.find('        References\n')
        if end_of_notes == -1:
            end_of_notes = cls_docstring.find('        Examples\n')
            if end_of_notes == -1:
                end_of_notes = len(cls_docstring)
        func.__doc__ = (cls_docstring[:start_of_notes + len(notes_header)] +
                        notes +
                        cls_docstring[end_of_notes:])
        return func
    return _doc
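
# --- Editor's note: illustrative sketch, not part of the vendored file. ---
# Borrow a parent method's docstring and append extra text after its Notes
# section; `Base` and `Special` are hypothetical classes.
class Base:
    def pdf(self, x):
        """Evaluate the density.

        Notes
        -----
        Generic implementation.
        """

class Special(Base):
    @extend_notes_in_docstring(Base, notes="        Uses a closed form.\n")
    def pdf(self, x):
        return 0.0

assert "closed form" in Special.pdf.__doc__
# --- end editor's note ---
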
def indentcount_lines(lines):
    ''' Minimum indent for all lines in line list

    >>> lines = [' one', '  two', '   three']
    >>> indentcount_lines(lines)
    1
    >>> lines = []
    >>> indentcount_lines(lines)
    0
    >>> lines = [' one']
    >>> indentcount_lines(lines)
    1
    >>> indentcount_lines(['    '])
    0
    '''
    indentno = sys.maxsize
    for line in lines:
        stripped = line.lstrip()
        if stripped:
            indentno = min(indentno, len(line) - len(stripped))
    if indentno == sys.maxsize:
        return 0
    return indentno


def filldoc(docdict, unindent_params=True):
    ''' Return docstring decorator using docdict variable dictionary

    Parameters
    ----------
    docdict : dictionary
        dictionary containing name, docstring fragment pairs
    unindent_params : {False, True}, boolean, optional
        If True, strip common indentation from all parameters in
        docdict

    Returns
    -------
    decfunc : function
        decorator that applies dictionary to input function docstring

    '''
    if unindent_params:
        docdict = unindent_dict(docdict)

    def decorate(f):
        f.__doc__ = docformat(f.__doc__, docdict)
        return f
    return decorate
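
# --- Editor's note: illustrative sketch, not part of the vendored file. ---
# The usual doccer workflow: keep one fragment in a docdict and stamp it
# into several docstrings; `_docdict` and `square` are hypothetical.
_docdict = {'x_param': """x : array_like
    Input values."""}
fill = filldoc(_docdict)

@fill
def square(x):
    """Square the input.

    Parameters
    ----------
    %(x_param)s
    """
    return x * x

assert 'array_like' in square.__doc__
# --- end editor's note ---
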
def unindent_dict(docdict):
    ''' Unindent all strings in a docdict '''
    can_dict = {}
    for name, dstr in docdict.items():
        can_dict[name] = unindent_string(dstr)
    return can_dict


def unindent_string(docstring):
    ''' Set docstring to minimum indent for all lines, including first

    >>> unindent_string('  two')
    'two'
    >>> unindent_string('  two\\n   three')
    'two\\n three'
    '''
    lines = docstring.expandtabs().splitlines()
    icount = indentcount_lines(lines)
    if icount == 0:
        return docstring
    return '\n'.join([line[icount:] for line in lines])
def doc_replace(obj, oldval, newval):
    """Decorator to take the docstring from obj, with oldval replaced by newval

    Equivalent to ``func.__doc__ = obj.__doc__.replace(oldval, newval)``

    Parameters
    ----------
    obj : object
        The object to take the docstring from.
    oldval : string
        The string to replace from the original docstring.
    newval : string
        The string to replace ``oldval`` with.
    """
    # __doc__ may be None for optimized Python (-OO)
    doc = (obj.__doc__ or '').replace(oldval, newval)

    def inner(func):
        func.__doc__ = doc
        return func

    return inner
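
# --- Editor's note: illustrative sketch, not part of the vendored file. ---
# Clone a docstring with one substring swapped; `ref`/`other` are hypothetical.
def ref(x):
    """Compute the mean of x."""

@doc_replace(ref, "mean", "median")
def other(x):
    return sorted(x)[len(x) // 2]

assert other.__doc__ == "Compute the median of x."
# --- end editor's note ---
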
Binary file not shown.
@ -0,0 +1,101 @@
""" Test for assert_deallocated context manager and gc utilities
"""
import gc

from scipy._lib._gcutils import (set_gc_state, gc_state, assert_deallocated,
                                 ReferenceError, IS_PYPY)

from numpy.testing import assert_equal

import pytest


def test_set_gc_state():
    gc_status = gc.isenabled()
    try:
        for state in (True, False):
            gc.enable()
            set_gc_state(state)
            assert_equal(gc.isenabled(), state)
            gc.disable()
            set_gc_state(state)
            assert_equal(gc.isenabled(), state)
    finally:
        if gc_status:
            gc.enable()


def test_gc_state():
    # Test gc_state context manager
    gc_status = gc.isenabled()
    try:
        for pre_state in (True, False):
            set_gc_state(pre_state)
            for with_state in (True, False):
                # Check the gc state is with_state in with block
                with gc_state(with_state):
                    assert_equal(gc.isenabled(), with_state)
                # And returns to previous state outside block
                assert_equal(gc.isenabled(), pre_state)
                # Even if the gc state is set explicitly within the block
                with gc_state(with_state):
                    assert_equal(gc.isenabled(), with_state)
                    set_gc_state(not with_state)
                assert_equal(gc.isenabled(), pre_state)
    finally:
        if gc_status:
            gc.enable()


@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated():
    # Ordinary use
    class C:
        def __init__(self, arg0, arg1, name='myname'):
            self.name = name
    for gc_current in (True, False):
        with gc_state(gc_current):
            # We are deleting from with-block context, so that's OK
            with assert_deallocated(C, 0, 2, 'another name') as c:
                assert_equal(c.name, 'another name')
                del c
            # Or not using the thing in with-block context, also OK
            with assert_deallocated(C, 0, 2, name='third name'):
                pass
            assert_equal(gc.isenabled(), gc_current)


@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_nodel():
    class C:
        pass
    with pytest.raises(ReferenceError):
        # Need to delete after using if in with-block context
        # Note: assert_deallocated(C) needs to be assigned for the test
        # to function correctly. It is assigned to _, but _ itself is
        # not referenced in the body of the with, it is only there for
        # the refcount.
        with assert_deallocated(C) as _:
            pass


@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_circular():
    class C:
        def __init__(self):
            self._circular = self
    with pytest.raises(ReferenceError):
        # Circular reference, no automatic garbage collection
        with assert_deallocated(C) as c:
            del c


@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_circular2():
    class C:
        def __init__(self):
            self._circular = self
    with pytest.raises(ReferenceError):
        # Still circular reference, no automatic garbage collection
        with assert_deallocated(C):
            pass
@ -0,0 +1,67 @@
from pytest import raises as assert_raises
from scipy._lib._pep440 import Version, parse


def test_main_versions():
    assert Version('1.8.0') == Version('1.8.0')
    for ver in ['1.9.0', '2.0.0', '1.8.1']:
        assert Version('1.8.0') < Version(ver)

    for ver in ['1.7.0', '1.7.1', '0.9.9']:
        assert Version('1.8.0') > Version(ver)


def test_version_1_point_10():
    # regression test for gh-2998.
    assert Version('1.9.0') < Version('1.10.0')
    assert Version('1.11.0') < Version('1.11.1')
    assert Version('1.11.0') == Version('1.11.0')
    assert Version('1.99.11') < Version('1.99.12')


def test_alpha_beta_rc():
    assert Version('1.8.0rc1') == Version('1.8.0rc1')
    for ver in ['1.8.0', '1.8.0rc2']:
        assert Version('1.8.0rc1') < Version(ver)

    for ver in ['1.8.0a2', '1.8.0b3', '1.7.2rc4']:
        assert Version('1.8.0rc1') > Version(ver)

    assert Version('1.8.0b1') > Version('1.8.0a2')


def test_dev_version():
    assert Version('1.9.0.dev+Unknown') < Version('1.9.0')
    for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev+ffffffff', '1.9.0.dev1']:
        assert Version('1.9.0.dev+f16acvda') < Version(ver)

    assert Version('1.9.0.dev+f16acvda') == Version('1.9.0.dev+f16acvda')


def test_dev_a_b_rc_mixed():
    assert Version('1.9.0a2.dev+f16acvda') == Version('1.9.0a2.dev+f16acvda')
    assert Version('1.9.0a2.dev+6acvda54') < Version('1.9.0a2')


def test_dev0_version():
    assert Version('1.9.0.dev0+Unknown') < Version('1.9.0')
    for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev0+ffffffff']:
        assert Version('1.9.0.dev0+f16acvda') < Version(ver)

    assert Version('1.9.0.dev0+f16acvda') == Version('1.9.0.dev0+f16acvda')


def test_dev0_a_b_rc_mixed():
    assert Version('1.9.0a2.dev0+f16acvda') == Version('1.9.0a2.dev0+f16acvda')
    assert Version('1.9.0a2.dev0+6acvda54') < Version('1.9.0a2')


def test_raises():
    for ver in ['1,9.0', '1.7.x']:
        assert_raises(ValueError, Version, ver)


def test_legacy_version():
    # Non-PEP-440 version identifiers always compare less. For NumPy this only
    # occurs on dev builds prior to 1.10.0 which are unsupported anyway.
    assert parse('invalid') < Version('0.0.0')
    assert parse('1.9.0-f16acvda') < Version('1.0.0')
@ -0,0 +1,32 @@
import sys
from scipy._lib._testutils import _parse_size, _get_mem_available
import pytest


def test__parse_size():
    expected = {
        '12': 12e6,
        '12 b': 12,
        '12k': 12e3,
        ' 12 M ': 12e6,
        ' 12 G ': 12e9,
        ' 12Tb ': 12e12,
        '12 Mib ': 12 * 1024.0**2,
        '12Tib': 12 * 1024.0**4,
    }

    for inp, outp in sorted(expected.items()):
        if outp is None:
            with pytest.raises(ValueError):
                _parse_size(inp)
        else:
            assert _parse_size(inp) == outp


def test__mem_available():
    # May return None on non-Linux platforms
    available = _get_mem_available()
    if sys.platform.startswith('linux'):
        assert available >= 0
    else:
        assert available is None or available >= 0
@ -0,0 +1,51 @@
import threading
import time
import traceback

from numpy.testing import assert_
from pytest import raises as assert_raises

from scipy._lib._threadsafety import ReentrancyLock, non_reentrant, ReentrancyError


def test_parallel_threads():
    # Check that ReentrancyLock serializes work in parallel threads.
    #
    # The test is not fully deterministic, and may succeed falsely if
    # the timings go wrong.

    lock = ReentrancyLock("failure")

    failflag = [False]
    exceptions_raised = []

    def worker(k):
        try:
            with lock:
                assert_(not failflag[0])
                failflag[0] = True
                time.sleep(0.1 * k)
                assert_(failflag[0])
                failflag[0] = False
        except Exception:
            exceptions_raised.append(traceback.format_exc(2))

    threads = [threading.Thread(target=lambda k=k: worker(k))
               for k in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    exceptions_raised = "\n".join(exceptions_raised)
    assert_(not exceptions_raised, exceptions_raised)


def test_reentering():
    # Check that ReentrancyLock prevents re-entering from the same thread.

    @non_reentrant()
    def func(x):
        return func(x)

    assert_raises(ReentrancyError, func, 0)
447
venv/lib/python3.12/site-packages/scipy/_lib/tests/test__util.py
Normal file
@ -0,0 +1,447 @@
from multiprocessing import Pool
from multiprocessing.pool import Pool as PWL
import re
import math
from fractions import Fraction

import numpy as np
from numpy.testing import assert_equal, assert_
import pytest
from pytest import raises as assert_raises
import hypothesis.extra.numpy as npst
from hypothesis import given, strategies, reproduce_failure  # noqa: F401
from scipy.conftest import array_api_compatible, skip_xp_invalid_arg

from scipy._lib._array_api import (xp_assert_equal, xp_assert_close, is_numpy,
                                   copy as xp_copy)
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
                              getfullargspec_no_self, FullArgSpec,
                              rng_integers, _validate_int, _rename_parameter,
                              _contains_nan, _rng_html_rewrite, _lazywhere)

skip_xp_backends = pytest.mark.skip_xp_backends


@pytest.mark.slow
def test__aligned_zeros():
    niter = 10

    def check(shape, dtype, order, align):
        err_msg = repr((shape, dtype, order, align))
        x = _aligned_zeros(shape, dtype, order, align=align)
        if align is None:
            align = np.dtype(dtype).alignment
        assert_equal(x.__array_interface__['data'][0] % align, 0)
        if hasattr(shape, '__len__'):
            assert_equal(x.shape, shape, err_msg)
        else:
            assert_equal(x.shape, (shape,), err_msg)
        assert_equal(x.dtype, dtype)
        if order == "C":
            assert_(x.flags.c_contiguous, err_msg)
        elif order == "F":
            if x.size > 0:
                # Size-0 arrays get invalid flags on NumPy 1.5
                assert_(x.flags.f_contiguous, err_msg)
        elif order is None:
            assert_(x.flags.c_contiguous, err_msg)
        else:
            raise ValueError()

    # try various alignments
    for align in [1, 2, 3, 4, 8, 16, 32, 64, None]:
        for n in [0, 1, 3, 11]:
            for order in ["C", "F", None]:
                for dtype in [np.uint8, np.float64]:
                    for shape in [n, (1, 2, 3, n)]:
                        for j in range(niter):
                            check(shape, dtype, order, align)


def test_check_random_state():
    # If seed is None, return the RandomState singleton used by np.random.
    # If seed is an int, return a new RandomState instance seeded with seed.
    # If seed is already a RandomState instance, return it.
    # Otherwise raise ValueError.
    rsi = check_random_state(1)
    assert_equal(type(rsi), np.random.RandomState)
    rsi = check_random_state(rsi)
    assert_equal(type(rsi), np.random.RandomState)
    rsi = check_random_state(None)
    assert_equal(type(rsi), np.random.RandomState)
    assert_raises(ValueError, check_random_state, 'a')
    rg = np.random.Generator(np.random.PCG64())
    rsi = check_random_state(rg)
    assert_equal(type(rsi), np.random.Generator)


def test_getfullargspec_no_self():
    p = MapWrapper(1)
    argspec = getfullargspec_no_self(p.__init__)
    assert_equal(argspec, FullArgSpec(['pool'], None, None, (1,), [],
                                      None, {}))
    argspec = getfullargspec_no_self(p.__call__)
    assert_equal(argspec, FullArgSpec(['func', 'iterable'], None, None, None,
                                      [], None, {}))

    class _rv_generic:
        def _rvs(self, a, b=2, c=3, *args, size=None, **kwargs):
            return None

    rv_obj = _rv_generic()
    argspec = getfullargspec_no_self(rv_obj._rvs)
    assert_equal(argspec, FullArgSpec(['a', 'b', 'c'], 'args', 'kwargs',
                                      (2, 3), ['size'], {'size': None}, {}))


def test_mapwrapper_serial():
    in_arg = np.arange(10.)
    out_arg = np.sin(in_arg)

    p = MapWrapper(1)
    assert_(p._mapfunc is map)
    assert_(p.pool is None)
    assert_(p._own_pool is False)
    out = list(p(np.sin, in_arg))
    assert_equal(out, out_arg)

    with assert_raises(RuntimeError):
        p = MapWrapper(0)


def test_pool():
    with Pool(2) as p:
        p.map(math.sin, [1, 2, 3, 4])


def test_mapwrapper_parallel():
    in_arg = np.arange(10.)
    out_arg = np.sin(in_arg)

    with MapWrapper(2) as p:
        out = p(np.sin, in_arg)
        assert_equal(list(out), out_arg)

        assert_(p._own_pool is True)
        assert_(isinstance(p.pool, PWL))
        assert_(p._mapfunc is not None)

    # the context manager should've closed the internal pool
    # check that it has by asking it to calculate again.
    with assert_raises(Exception) as excinfo:
        p(np.sin, in_arg)

    assert_(excinfo.type is ValueError)

    # can also set a PoolWrapper up with a map-like callable instance
    with Pool(2) as p:
        q = MapWrapper(p.map)

        assert_(q._own_pool is False)
        q.close()

        # closing the PoolWrapper shouldn't close the internal pool
        # because it didn't create it
        out = p.map(np.sin, in_arg)
        assert_equal(list(out), out_arg)


def test_rng_integers():
    rng = np.random.RandomState()

    # test that numbers are inclusive of high point
    arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
    assert np.max(arr) == 5
    assert np.min(arr) == 2
    assert arr.shape == (100, )

    # test that numbers are inclusive of high point
    arr = rng_integers(rng, low=5, size=100, endpoint=True)
    assert np.max(arr) == 5
    assert np.min(arr) == 0
    assert arr.shape == (100, )

    # test that numbers are exclusive of high point
    arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
    assert np.max(arr) == 4
    assert np.min(arr) == 2
    assert arr.shape == (100, )

    # test that numbers are exclusive of high point
    arr = rng_integers(rng, low=5, size=100, endpoint=False)
    assert np.max(arr) == 4
    assert np.min(arr) == 0
    assert arr.shape == (100, )

    # now try with np.random.Generator
    try:
        rng = np.random.default_rng()
    except AttributeError:
        return

    # test that numbers are inclusive of high point
    arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
    assert np.max(arr) == 5
    assert np.min(arr) == 2
    assert arr.shape == (100, )

    # test that numbers are inclusive of high point
    arr = rng_integers(rng, low=5, size=100, endpoint=True)
    assert np.max(arr) == 5
    assert np.min(arr) == 0
    assert arr.shape == (100, )

    # test that numbers are exclusive of high point
    arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
    assert np.max(arr) == 4
    assert np.min(arr) == 2
    assert arr.shape == (100, )

    # test that numbers are exclusive of high point
    arr = rng_integers(rng, low=5, size=100, endpoint=False)
    assert np.max(arr) == 4
    assert np.min(arr) == 0
    assert arr.shape == (100, )


class TestValidateInt:

    @pytest.mark.parametrize('n', [4, np.uint8(4), np.int16(4), np.array(4)])
    def test_validate_int(self, n):
        n = _validate_int(n, 'n')
        assert n == 4

    @pytest.mark.parametrize('n', [4.0, np.array([4]), Fraction(4, 1)])
    def test_validate_int_bad(self, n):
        with pytest.raises(TypeError, match='n must be an integer'):
            _validate_int(n, 'n')

    def test_validate_int_below_min(self):
        with pytest.raises(ValueError, match='n must be an integer not '
                                             'less than 0'):
            _validate_int(-1, 'n', 0)


class TestRenameParameter:
    # check that wrapper `_rename_parameter` for backward-compatible
    # keyword renaming works correctly

    # Example method/function that still accepts keyword `old`
    @_rename_parameter("old", "new")
    def old_keyword_still_accepted(self, new):
        return new

    # Example method/function for which keyword `old` is deprecated
    @_rename_parameter("old", "new", dep_version="1.9.0")
    def old_keyword_deprecated(self, new):
        return new

    def test_old_keyword_still_accepted(self):
        # positional argument and both keyword work identically
        res1 = self.old_keyword_still_accepted(10)
        res2 = self.old_keyword_still_accepted(new=10)
        res3 = self.old_keyword_still_accepted(old=10)
        assert res1 == res2 == res3 == 10

        # unexpected keyword raises an error
        message = re.escape("old_keyword_still_accepted() got an unexpected")
        with pytest.raises(TypeError, match=message):
            self.old_keyword_still_accepted(unexpected=10)

        # multiple values for the same parameter raises an error
        message = re.escape("old_keyword_still_accepted() got multiple")
        with pytest.raises(TypeError, match=message):
            self.old_keyword_still_accepted(10, new=10)
        with pytest.raises(TypeError, match=message):
            self.old_keyword_still_accepted(10, old=10)
        with pytest.raises(TypeError, match=message):
            self.old_keyword_still_accepted(new=10, old=10)

    def test_old_keyword_deprecated(self):
        # positional argument and both keyword work identically,
        # but use of old keyword results in DeprecationWarning
        dep_msg = "Use of keyword argument `old` is deprecated"
        res1 = self.old_keyword_deprecated(10)
        res2 = self.old_keyword_deprecated(new=10)
        with pytest.warns(DeprecationWarning, match=dep_msg):
            res3 = self.old_keyword_deprecated(old=10)
        assert res1 == res2 == res3 == 10

        # unexpected keyword raises an error
        message = re.escape("old_keyword_deprecated() got an unexpected")
        with pytest.raises(TypeError, match=message):
            self.old_keyword_deprecated(unexpected=10)

        # multiple values for the same parameter raises an error and,
        # if old keyword is used, results in DeprecationWarning
        message = re.escape("old_keyword_deprecated() got multiple")
        with pytest.raises(TypeError, match=message):
            self.old_keyword_deprecated(10, new=10)
        with pytest.raises(TypeError, match=message), \
                pytest.warns(DeprecationWarning, match=dep_msg):
            self.old_keyword_deprecated(10, old=10)
        with pytest.raises(TypeError, match=message), \
                pytest.warns(DeprecationWarning, match=dep_msg):
            self.old_keyword_deprecated(new=10, old=10)


class TestContainsNaNTest:

    def test_policy(self):
        data = np.array([1, 2, 3, np.nan])

        contains_nan, nan_policy = _contains_nan(data, nan_policy="propagate")
        assert contains_nan
        assert nan_policy == "propagate"

        contains_nan, nan_policy = _contains_nan(data, nan_policy="omit")
        assert contains_nan
        assert nan_policy == "omit"

        msg = "The input contains nan values"
        with pytest.raises(ValueError, match=msg):
            _contains_nan(data, nan_policy="raise")

        msg = "nan_policy must be one of"
        with pytest.raises(ValueError, match=msg):
            _contains_nan(data, nan_policy="nan")

    def test_contains_nan(self):
        data1 = np.array([1, 2, 3])
        assert not _contains_nan(data1)[0]

        data2 = np.array([1, 2, 3, np.nan])
        assert _contains_nan(data2)[0]

        data3 = np.array([np.nan, 2, 3, np.nan])
        assert _contains_nan(data3)[0]

        data4 = np.array([[1, 2], [3, 4]])
        assert not _contains_nan(data4)[0]

        data5 = np.array([[1, 2], [3, np.nan]])
        assert _contains_nan(data5)[0]

    @skip_xp_invalid_arg
    def test_contains_nan_with_strings(self):
        data1 = np.array([1, 2, "3", np.nan])  # converted to string "nan"
        assert not _contains_nan(data1)[0]

        data2 = np.array([1, 2, "3", np.nan], dtype='object')
        assert _contains_nan(data2)[0]

        data3 = np.array([["1", 2], [3, np.nan]])  # converted to string "nan"
        assert not _contains_nan(data3)[0]

        data4 = np.array([["1", 2], [3, np.nan]], dtype='object')
        assert _contains_nan(data4)[0]

    @skip_xp_backends('jax.numpy',
                      reasons=["JAX arrays do not support item assignment"])
    @pytest.mark.usefixtures("skip_xp_backends")
    @array_api_compatible
    @pytest.mark.parametrize("nan_policy", ['propagate', 'omit', 'raise'])
    def test_array_api(self, xp, nan_policy):
        rng = np.random.default_rng(932347235892482)
        x0 = rng.random(size=(2, 3, 4))
        x = xp.asarray(x0)
        x_nan = xp_copy(x, xp=xp)
        x_nan[1, 2, 1] = np.nan

        contains_nan, nan_policy_out = _contains_nan(x, nan_policy=nan_policy)
        assert not contains_nan
        assert nan_policy_out == nan_policy

        if nan_policy == 'raise':
            message = 'The input contains...'
            with pytest.raises(ValueError, match=message):
                _contains_nan(x_nan, nan_policy=nan_policy)
        elif nan_policy == 'omit' and not is_numpy(xp):
            message = "`nan_policy='omit' is incompatible..."
            with pytest.raises(ValueError, match=message):
                _contains_nan(x_nan, nan_policy=nan_policy)
        elif nan_policy == 'propagate':
            contains_nan, nan_policy_out = _contains_nan(
                x_nan, nan_policy=nan_policy)
            assert contains_nan
            assert nan_policy_out == nan_policy


def test__rng_html_rewrite():
    def mock_str():
        lines = [
            'np.random.default_rng(8989843)',
            'np.random.default_rng(seed)',
            'np.random.default_rng(0x9a71b21474694f919882289dc1559ca)',
            ' bob ',
        ]
        return lines

    res = _rng_html_rewrite(mock_str)()
    ref = [
        'np.random.default_rng()',
        'np.random.default_rng(seed)',
        'np.random.default_rng()',
        ' bob ',
    ]

    assert res == ref


class TestLazywhere:
    n_arrays = strategies.integers(min_value=1, max_value=3)
    rng_seed = strategies.integers(min_value=1000000000, max_value=9999999999)
    dtype = strategies.sampled_from((np.float32, np.float64))
    p = strategies.floats(min_value=0, max_value=1)
    data = strategies.data()

    @pytest.mark.fail_slow(5)
    @pytest.mark.filterwarnings('ignore::RuntimeWarning')  # overflows, etc.
    @skip_xp_backends('jax.numpy',
                      reasons=["JAX arrays do not support item assignment"])
    @pytest.mark.usefixtures("skip_xp_backends")
    @array_api_compatible
    @given(n_arrays=n_arrays, rng_seed=rng_seed, dtype=dtype, p=p, data=data)
    def test_basic(self, n_arrays, rng_seed, dtype, p, data, xp):
        mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays+1,
                                                 min_side=0)
        input_shapes, result_shape = data.draw(mbs)
        cond_shape, *shapes = input_shapes
        fillvalue = xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=tuple())))
        arrays = [xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=shape)))
                  for shape in shapes]

        def f(*args):
            return sum(arg for arg in args)

        def f2(*args):
            return sum(arg for arg in args) / 2

        rng = np.random.default_rng(rng_seed)
        cond = xp.asarray(rng.random(size=cond_shape) > p)

        res1 = _lazywhere(cond, arrays, f, fillvalue)
        res2 = _lazywhere(cond, arrays, f, f2=f2)

        # Ensure arrays are at least 1d to follow sane type promotion rules.
        if xp == np:
            cond, fillvalue, *arrays = np.atleast_1d(cond, fillvalue, *arrays)

        ref1 = xp.where(cond, f(*arrays), fillvalue)
        ref2 = xp.where(cond, f(*arrays), f2(*arrays))

        if xp == np:
            ref1 = ref1.reshape(result_shape)
            ref2 = ref2.reshape(result_shape)
            res1 = xp.asarray(res1)[()]
            res2 = xp.asarray(res2)[()]

        isinstance(res1, type(xp.asarray([])))
        xp_assert_close(res1, ref1, rtol=2e-16)
        assert_equal(res1.shape, ref1.shape)
        assert_equal(res1.dtype, ref1.dtype)

        isinstance(res2, type(xp.asarray([])))
        xp_assert_equal(res2, ref2)
        assert_equal(res2.shape, ref2.shape)
        assert_equal(res2.dtype, ref2.dtype)
@ -0,0 +1,114 @@
import numpy as np
import pytest

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import (
    _GLOBAL_CONFIG, array_namespace, _asarray, copy, xp_assert_equal, is_numpy
)
import scipy._lib.array_api_compat.numpy as np_compat

skip_xp_backends = pytest.mark.skip_xp_backends


@pytest.mark.skipif(not _GLOBAL_CONFIG["SCIPY_ARRAY_API"],
                    reason="Array API test; set environment variable SCIPY_ARRAY_API=1 to run it")
class TestArrayAPI:

    def test_array_namespace(self):
        x, y = np.array([0, 1, 2]), np.array([0, 1, 2])
        xp = array_namespace(x, y)
        assert 'array_api_compat.numpy' in xp.__name__

        _GLOBAL_CONFIG["SCIPY_ARRAY_API"] = False
        xp = array_namespace(x, y)
        assert 'array_api_compat.numpy' in xp.__name__
        _GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True

    @array_api_compatible
    def test_asarray(self, xp):
        x, y = _asarray([0, 1, 2], xp=xp), _asarray(np.arange(3), xp=xp)
        ref = xp.asarray([0, 1, 2])
        xp_assert_equal(x, ref)
        xp_assert_equal(y, ref)

    @pytest.mark.filterwarnings("ignore: the matrix subclass")
    def test_raises(self):
        msg = "of type `numpy.ma.MaskedArray` are not supported"
        with pytest.raises(TypeError, match=msg):
            array_namespace(np.ma.array(1), np.array(1))

        msg = "of type `numpy.matrix` are not supported"
        with pytest.raises(TypeError, match=msg):
            array_namespace(np.array(1), np.matrix(1))

        msg = "only boolean and numerical dtypes are supported"
        with pytest.raises(TypeError, match=msg):
            array_namespace([object()])
        with pytest.raises(TypeError, match=msg):
            array_namespace('abc')

    def test_array_likes(self):
        # should be no exceptions
        array_namespace([0, 1, 2])
        array_namespace(1, 2, 3)
        array_namespace(1)

    @skip_xp_backends('jax.numpy',
                      reasons=["JAX arrays do not support item assignment"])
    @pytest.mark.usefixtures("skip_xp_backends")
    @array_api_compatible
    def test_copy(self, xp):
        for _xp in [xp, None]:
            x = xp.asarray([1, 2, 3])
            y = copy(x, xp=_xp)
            # with numpy we'd want to use np.shared_memory, but that's not specified
            # in the array-api
            x[0] = 10
            x[1] = 11
            x[2] = 12

            assert x[0] != y[0]
            assert x[1] != y[1]
            assert x[2] != y[2]
            assert id(x) != id(y)

    @array_api_compatible
    @pytest.mark.parametrize('dtype', ['int32', 'int64', 'float32', 'float64'])
    @pytest.mark.parametrize('shape', [(), (3,)])
    def test_strict_checks(self, xp, dtype, shape):
        # Check that `_strict_check` behaves as expected
        dtype = getattr(xp, dtype)
        x = xp.broadcast_to(xp.asarray(1, dtype=dtype), shape)
        x = x if shape else x[()]
        y = np_compat.asarray(1)[()]

        options = dict(check_namespace=True, check_dtype=False, check_shape=False)
        if xp == np:
            xp_assert_equal(x, y, **options)
        else:
            with pytest.raises(AssertionError, match="Namespaces do not match."):
                xp_assert_equal(x, y, **options)

        options = dict(check_namespace=False, check_dtype=True, check_shape=False)
        if y.dtype.name in str(x.dtype):
            xp_assert_equal(x, y, **options)
        else:
            with pytest.raises(AssertionError, match="dtypes do not match."):
                xp_assert_equal(x, y, **options)

        options = dict(check_namespace=False, check_dtype=False, check_shape=True)
        if x.shape == y.shape:
            xp_assert_equal(x, y, **options)
        else:
            with pytest.raises(AssertionError, match="Shapes do not match."):
                xp_assert_equal(x, y, **options)

    @array_api_compatible
    def test_check_scalar(self, xp):
        if not is_numpy(xp):
            pytest.skip("Scalars only exist in NumPy")

        if is_numpy(xp):
            with pytest.raises(AssertionError, match="Types do not match."):
                xp_assert_equal(xp.asarray(0.), xp.float64(0))
            xp_assert_equal(xp.float64(0), xp.asarray(0.))
162
venv/lib/python3.12/site-packages/scipy/_lib/tests/test_bunch.py
Normal file
@ -0,0 +1,162 @@
import pytest
import pickle
from numpy.testing import assert_equal
from scipy._lib._bunch import _make_tuple_bunch


# `Result` is defined at the top level of the module so it can be
# used to test pickling.
Result = _make_tuple_bunch('Result', ['x', 'y', 'z'], ['w', 'beta'])


class TestMakeTupleBunch:

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # Tests with Result
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

    def setup_method(self):
        # Set up an instance of Result.
        self.result = Result(x=1, y=2, z=3, w=99, beta=0.5)

    def test_attribute_access(self):
        assert_equal(self.result.x, 1)
        assert_equal(self.result.y, 2)
        assert_equal(self.result.z, 3)
        assert_equal(self.result.w, 99)
        assert_equal(self.result.beta, 0.5)

    def test_indexing(self):
        assert_equal(self.result[0], 1)
        assert_equal(self.result[1], 2)
        assert_equal(self.result[2], 3)
        assert_equal(self.result[-1], 3)
        with pytest.raises(IndexError, match='index out of range'):
            self.result[3]

    def test_unpacking(self):
        x0, y0, z0 = self.result
        assert_equal((x0, y0, z0), (1, 2, 3))
        assert_equal(self.result, (1, 2, 3))

    def test_slice(self):
        assert_equal(self.result[1:], (2, 3))
        assert_equal(self.result[::2], (1, 3))
        assert_equal(self.result[::-1], (3, 2, 1))

    def test_len(self):
        assert_equal(len(self.result), 3)

    def test_repr(self):
        s = repr(self.result)
        assert_equal(s, 'Result(x=1, y=2, z=3, w=99, beta=0.5)')

    def test_hash(self):
        assert_equal(hash(self.result), hash((1, 2, 3)))

    def test_pickle(self):
        s = pickle.dumps(self.result)
        obj = pickle.loads(s)
        assert isinstance(obj, Result)
        assert_equal(obj.x, self.result.x)
        assert_equal(obj.y, self.result.y)
        assert_equal(obj.z, self.result.z)
        assert_equal(obj.w, self.result.w)
        assert_equal(obj.beta, self.result.beta)

    def test_read_only_existing(self):
        with pytest.raises(AttributeError, match="can't set attribute"):
            self.result.x = -1

    def test_read_only_new(self):
        self.result.plate_of_shrimp = "lattice of coincidence"
        assert self.result.plate_of_shrimp == "lattice of coincidence"

    def test_constructor_missing_parameter(self):
        with pytest.raises(TypeError, match='missing'):
            # `w` is missing.
            Result(x=1, y=2, z=3, beta=0.75)

    def test_constructor_incorrect_parameter(self):
        with pytest.raises(TypeError, match='unexpected'):
            # `foo` is not an existing field.
            Result(x=1, y=2, z=3, w=123, beta=0.75, foo=999)

    def test_module(self):
        m = 'scipy._lib.tests.test_bunch'
        assert_equal(Result.__module__, m)
        assert_equal(self.result.__module__, m)

    def test_extra_fields_per_instance(self):
        # This test exists to ensure that instances of the same class
        # store their own values for the extra fields. That is, the values
        # are stored per instance and not in the class.
        result1 = Result(x=1, y=2, z=3, w=-1, beta=0.0)
        result2 = Result(x=4, y=5, z=6, w=99, beta=1.0)
        assert_equal(result1.w, -1)
        assert_equal(result1.beta, 0.0)
        # The rest of these checks aren't essential, but let's check
        # them anyway.
        assert_equal(result1[:], (1, 2, 3))
        assert_equal(result2.w, 99)
        assert_equal(result2.beta, 1.0)
        assert_equal(result2[:], (4, 5, 6))

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # Other tests
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

    def test_extra_field_names_is_optional(self):
        Square = _make_tuple_bunch('Square', ['width', 'height'])
        sq = Square(width=1, height=2)
        assert_equal(sq.width, 1)
        assert_equal(sq.height, 2)
        s = repr(sq)
        assert_equal(s, 'Square(width=1, height=2)')

    def test_tuple_like(self):
        Tup = _make_tuple_bunch('Tup', ['a', 'b'])
        tu = Tup(a=1, b=2)
        assert isinstance(tu, tuple)
        assert isinstance(tu + (1,), tuple)

    def test_explicit_module(self):
        m = 'some.module.name'
        Foo = _make_tuple_bunch('Foo', ['x'], ['a', 'b'], module=m)
        foo = Foo(x=1, a=355, b=113)
        assert_equal(Foo.__module__, m)
        assert_equal(foo.__module__, m)

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    # Argument validation
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

    @pytest.mark.parametrize('args', [('123', ['a'], ['b']),
                                      ('Foo', ['-3'], ['x']),
                                      ('Foo', ['a'], ['+-*/'])])
    def test_identifiers_not_allowed(self, args):
        with pytest.raises(ValueError, match='identifiers'):
            _make_tuple_bunch(*args)

    @pytest.mark.parametrize('args', [('Foo', ['a', 'b', 'a'], ['x']),
                                      ('Foo', ['a', 'b'], ['b', 'x'])])
    def test_repeated_field_names(self, args):
        with pytest.raises(ValueError, match='Duplicate'):
            _make_tuple_bunch(*args)

    @pytest.mark.parametrize('args', [('Foo', ['_a'], ['x']),
                                      ('Foo', ['a'], ['_x'])])
    def test_leading_underscore_not_allowed(self, args):
        with pytest.raises(ValueError, match='underscore'):
            _make_tuple_bunch(*args)

    @pytest.mark.parametrize('args', [('Foo', ['def'], ['x']),
                                      ('Foo', ['a'], ['or']),
                                      ('and', ['a'], ['x'])])
    def test_keyword_not_allowed_in_fields(self, args):
        with pytest.raises(ValueError, match='keyword'):
            _make_tuple_bunch(*args)

    def test_at_least_one_field_name_required(self):
        with pytest.raises(ValueError, match='at least one name'):
            _make_tuple_bunch('Qwerty', [], ['a', 'b'])
@ -0,0 +1,204 @@
|
||||
from numpy.testing import assert_equal, assert_
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
import time
|
||||
import pytest
|
||||
import ctypes
|
||||
import threading
|
||||
from scipy._lib import _ccallback_c as _test_ccallback_cython
|
||||
from scipy._lib import _test_ccallback
|
||||
from scipy._lib._ccallback import LowLevelCallable
|
||||
|
||||
try:
|
||||
import cffi
|
||||
HAVE_CFFI = True
|
||||
except ImportError:
|
||||
HAVE_CFFI = False
|
||||
|
||||
|
||||
ERROR_VALUE = 2.0
|
||||
|
||||
|
||||
def callback_python(a, user_data=None):
|
||||
if a == ERROR_VALUE:
|
||||
raise ValueError("bad value")
|
||||
|
||||
if user_data is None:
|
||||
return a + 1
|
||||
else:
|
||||
return a + user_data
|
||||
|
||||
def _get_cffi_func(base, signature):
|
||||
if not HAVE_CFFI:
|
||||
pytest.skip("cffi not installed")
|
||||
|
||||
# Get function address
|
||||
voidp = ctypes.cast(base, ctypes.c_void_p)
|
||||
address = voidp.value
|
||||
|
||||
# Create corresponding cffi handle
|
||||
ffi = cffi.FFI()
|
||||
func = ffi.cast(signature, address)
|
||||
return func
|
||||
|
||||
|
||||
def _get_ctypes_data():
|
||||
value = ctypes.c_double(2.0)
|
||||
return ctypes.cast(ctypes.pointer(value), ctypes.c_voidp)
|
||||
|
||||
|
||||
def _get_cffi_data():
|
||||
if not HAVE_CFFI:
|
||||
pytest.skip("cffi not installed")
|
||||
ffi = cffi.FFI()
|
||||
return ffi.new('double *', 2.0)
|
||||
|
||||
|
||||
CALLERS = {
|
||||
'simple': _test_ccallback.test_call_simple,
|
||||
'nodata': _test_ccallback.test_call_nodata,
|
||||
'nonlocal': _test_ccallback.test_call_nonlocal,
|
||||
'cython': _test_ccallback_cython.test_call_cython,
|
||||
}
|
||||
|
||||
# These functions have signatures known to the callers
|
||||
FUNCS = {
|
||||
'python': lambda: callback_python,
|
||||
'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
|
||||
'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
|
||||
"plus1_cython"),
|
||||
'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
|
||||
'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
|
||||
'double (*)(double, int *, void *)'),
|
||||
'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
|
||||
'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
|
||||
"plus1b_cython"),
|
||||
'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
|
||||
'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
|
||||
'double (*)(double, double, int *, void *)'),
|
||||
}
|
||||
|
||||
# These functions have signatures the callers don't know
|
||||
BAD_FUNCS = {
|
||||
'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
|
||||
'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
|
||||
"plus1bc_cython"),
|
||||
'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
|
||||
'cffi_bc': lambda: _get_cffi_func(
|
||||
_test_ccallback_cython.plus1bc_ctypes,
|
||||
'double (*)(double, double, double, int *, void *)'
|
||||
),
|
||||
}
|
||||
|
||||
USER_DATAS = {
|
||||
'ctypes': _get_ctypes_data,
|
||||
'cffi': _get_cffi_data,
|
||||
'capsule': _test_ccallback.test_get_data_capsule,
|
||||
}
|
||||
|
||||
|
||||
def test_callbacks():
|
||||
def check(caller, func, user_data):
|
||||
caller = CALLERS[caller]
|
||||
func = FUNCS[func]()
|
||||
user_data = USER_DATAS[user_data]()
|
||||
|
||||
if func is callback_python:
|
||||
def func2(x):
|
||||
return func(x, 2.0)
|
||||
else:
|
||||
func2 = LowLevelCallable(func, user_data)
|
||||
func = LowLevelCallable(func)
|
||||
|
||||
# Test basic call
|
||||
assert_equal(caller(func, 1.0), 2.0)
|
||||
|
||||
# Test 'bad' value resulting to an error
|
||||
assert_raises(ValueError, caller, func, ERROR_VALUE)
|
||||
|
||||
# Test passing in user_data
|
||||
assert_equal(caller(func2, 1.0), 3.0)
|
||||
|
||||
for caller in sorted(CALLERS.keys()):
|
||||
for func in sorted(FUNCS.keys()):
|
||||
for user_data in sorted(USER_DATAS.keys()):
|
||||
check(caller, func, user_data)
|
||||
|
||||
|
||||
def test_bad_callbacks():
|
||||
def check(caller, func, user_data):
|
||||
caller = CALLERS[caller]
|
||||
user_data = USER_DATAS[user_data]()
|
||||
func = BAD_FUNCS[func]()
|
||||
|
||||
if func is callback_python:
|
||||
def func2(x):
|
||||
return func(x, 2.0)
|
||||
else:
|
||||
func2 = LowLevelCallable(func, user_data)
|
||||
func = LowLevelCallable(func)
|
||||
|
||||
# Test that basic call fails
|
||||
assert_raises(ValueError, caller, LowLevelCallable(func), 1.0)
|
||||
|
||||
# Test that passing in user_data also fails
|
||||
assert_raises(ValueError, caller, func2, 1.0)
|
||||
|
||||
# Test error message
|
||||
llfunc = LowLevelCallable(func)
|
||||
try:
|
||||
caller(llfunc, 1.0)
|
||||
except ValueError as err:
|
||||
msg = str(err)
|
||||
assert_(llfunc.signature in msg, msg)
|
||||
assert_('double (double, double, int *, void *)' in msg, msg)
|
||||
|
||||
for caller in sorted(CALLERS.keys()):
|
||||
for func in sorted(BAD_FUNCS.keys()):
|
||||
for user_data in sorted(USER_DATAS.keys()):
|
||||
check(caller, func, user_data)
|
||||
|
||||
|
||||
def test_signature_override():
|
||||
caller = _test_ccallback.test_call_simple
|
||||
func = _test_ccallback.test_get_plus1_capsule()
|
||||
|
||||
llcallable = LowLevelCallable(func, signature="bad signature")
|
||||
assert_equal(llcallable.signature, "bad signature")
|
||||
assert_raises(ValueError, caller, llcallable, 3)
|
||||
|
||||
llcallable = LowLevelCallable(func, signature="double (double, int *, void *)")
|
||||
assert_equal(llcallable.signature, "double (double, int *, void *)")
|
||||
assert_equal(caller(llcallable, 3), 4)
|
||||
|
||||
|
||||
def test_threadsafety():
|
||||
def callback(a, caller):
|
||||
if a <= 0:
|
||||
return 1
|
||||
else:
|
||||
res = caller(lambda x: callback(x, caller), a - 1)
|
||||
return 2*res
|
||||
|
||||
def check(caller):
|
||||
caller = CALLERS[caller]
|
||||
|
||||
results = []
|
||||
|
||||
count = 10
|
||||
|
||||
def run():
|
||||
time.sleep(0.01)
|
||||
r = caller(lambda x: callback(x, caller), count)
|
||||
results.append(r)
|
||||
|
||||
threads = [threading.Thread(target=run) for j in range(20)]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert_equal(results, [2.0**count]*len(threads))
|
||||
|
||||
for caller in CALLERS.keys():
|
||||
check(caller)
|
||||
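These tests exercise the same `LowLevelCallable` mechanism that SciPy exposes publicly, for example in `scipy.integrate.quad`. As a minimal sketch (not part of the diff), assuming a hypothetical shared library `testlib.so` exporting a C function `double f(int n, double *x, void *user_data)`:

import ctypes
from scipy import LowLevelCallable
from scipy.integrate import quad

# `testlib.so` and its exported `f` are assumptions for illustration
lib = ctypes.CDLL('./testlib.so')
lib.f.restype = ctypes.c_double
lib.f.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_double),
                  ctypes.c_void_p)

func = LowLevelCallable(lib.f)   # signature is taken from the ctypes metadata
result, abserr = quad(func, 0.0, 1.0)

Because the C function is called without the Python interpreter in the loop, such callbacks avoid per-call overhead and, as the thread-safety test above checks, can be invoked concurrently.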
@ -0,0 +1,10 @@
import pytest


def test_cython_api_deprecation():
    match = ("`scipy._lib._test_deprecation_def.foo_deprecated` "
             "is deprecated, use `foo` instead!\n"
             "Deprecated in Scipy 42.0.0")
    with pytest.warns(DeprecationWarning, match=match):
        from .. import _test_deprecation_call
    assert _test_deprecation_call.call() == (1, 1)
@ -0,0 +1,17 @@
import pytest
import sys
import subprocess

from .test_public_api import PUBLIC_MODULES

# Regression tests for gh-6793.
# Check that all modules are importable in a new Python process.
# This is not necessarily true if there are import cycles present.

@pytest.mark.fail_slow(20)
@pytest.mark.slow
def test_public_modules_importable():
    pids = [subprocess.Popen([sys.executable, '-c', f'import {module}'])
            for module in PUBLIC_MODULES]
    for i, pid in enumerate(pids):
        assert pid.wait() == 0, f'Failed to import {PUBLIC_MODULES[i]}'
@ -0,0 +1,496 @@
"""
This test script is adapted from:
    https://github.com/numpy/numpy/blob/main/numpy/tests/test_public_api.py
"""

import pkgutil
import types
import importlib
import warnings
from importlib import import_module

import pytest

import scipy

from scipy.conftest import xp_available_backends


def test_dir_testing():
    """Assert that the output of dir(scipy) has only one "testing/tester"
    attribute, without duplicates"""
    assert len(dir(scipy)) == len(set(dir(scipy)))


# Historically SciPy has not used leading underscores for private submodules
# much. This has resulted in lots of things that look like public modules
# (i.e. things that can be imported as `import scipy.somesubmodule.somefile`),
# but were never intended to be public. The PUBLIC_MODULES list contains
# modules that are either public because they were meant to be, or because they
# contain public functions/objects that aren't present in any other namespace
# for whatever reason and therefore should be treated as public.
PUBLIC_MODULES = ["scipy." + s for s in [
    "cluster",
    "cluster.vq",
    "cluster.hierarchy",
    "constants",
    "datasets",
    "fft",
    "fftpack",
    "integrate",
    "interpolate",
    "io",
    "io.arff",
    "io.matlab",
    "io.wavfile",
    "linalg",
    "linalg.blas",
    "linalg.cython_blas",
    "linalg.lapack",
    "linalg.cython_lapack",
    "linalg.interpolative",
    "misc",
    "ndimage",
    "odr",
    "optimize",
    "signal",
    "signal.windows",
    "sparse",
    "sparse.linalg",
    "sparse.csgraph",
    "spatial",
    "spatial.distance",
    "spatial.transform",
    "special",
    "stats",
    "stats.contingency",
    "stats.distributions",
    "stats.mstats",
    "stats.qmc",
    "stats.sampling"
]]

# The PRIVATE_BUT_PRESENT_MODULES list contains modules that lacked underscores
# in their name and hence looked public, but weren't meant to be. All these
# namespaces were deprecated in the 1.8.0 release - see "clear split between
# public and private API" in the 1.8.0 release notes.
# Support for these private modules will be removed in SciPy v2.0.0, as the
# deprecation messages emitted by each of these modules say.
PRIVATE_BUT_PRESENT_MODULES = [
    'scipy.constants.codata',
    'scipy.constants.constants',
    'scipy.fftpack.basic',
    'scipy.fftpack.convolve',
    'scipy.fftpack.helper',
    'scipy.fftpack.pseudo_diffs',
    'scipy.fftpack.realtransforms',
    'scipy.integrate.dop',
    'scipy.integrate.lsoda',
    'scipy.integrate.odepack',
    'scipy.integrate.quadpack',
    'scipy.integrate.vode',
    'scipy.interpolate.dfitpack',
    'scipy.interpolate.fitpack',
    'scipy.interpolate.fitpack2',
    'scipy.interpolate.interpnd',
    'scipy.interpolate.interpolate',
    'scipy.interpolate.ndgriddata',
    'scipy.interpolate.polyint',
    'scipy.interpolate.rbf',
    'scipy.io.arff.arffread',
    'scipy.io.harwell_boeing',
    'scipy.io.idl',
    'scipy.io.matlab.byteordercodes',
    'scipy.io.matlab.mio',
    'scipy.io.matlab.mio4',
    'scipy.io.matlab.mio5',
    'scipy.io.matlab.mio5_params',
    'scipy.io.matlab.mio5_utils',
    'scipy.io.matlab.mio_utils',
    'scipy.io.matlab.miobase',
    'scipy.io.matlab.streams',
    'scipy.io.mmio',
    'scipy.io.netcdf',
    'scipy.linalg.basic',
    'scipy.linalg.decomp',
    'scipy.linalg.decomp_cholesky',
    'scipy.linalg.decomp_lu',
    'scipy.linalg.decomp_qr',
    'scipy.linalg.decomp_schur',
    'scipy.linalg.decomp_svd',
    'scipy.linalg.matfuncs',
    'scipy.linalg.misc',
    'scipy.linalg.special_matrices',
    'scipy.misc.common',
    'scipy.misc.doccer',
    'scipy.ndimage.filters',
    'scipy.ndimage.fourier',
    'scipy.ndimage.interpolation',
    'scipy.ndimage.measurements',
    'scipy.ndimage.morphology',
    'scipy.odr.models',
    'scipy.odr.odrpack',
    'scipy.optimize.cobyla',
    'scipy.optimize.cython_optimize',
    'scipy.optimize.lbfgsb',
    'scipy.optimize.linesearch',
    'scipy.optimize.minpack',
    'scipy.optimize.minpack2',
    'scipy.optimize.moduleTNC',
    'scipy.optimize.nonlin',
    'scipy.optimize.optimize',
    'scipy.optimize.slsqp',
    'scipy.optimize.tnc',
    'scipy.optimize.zeros',
    'scipy.signal.bsplines',
    'scipy.signal.filter_design',
    'scipy.signal.fir_filter_design',
    'scipy.signal.lti_conversion',
    'scipy.signal.ltisys',
    'scipy.signal.signaltools',
    'scipy.signal.spectral',
    'scipy.signal.spline',
    'scipy.signal.waveforms',
    'scipy.signal.wavelets',
    'scipy.signal.windows.windows',
    'scipy.sparse.base',
    'scipy.sparse.bsr',
    'scipy.sparse.compressed',
    'scipy.sparse.construct',
    'scipy.sparse.coo',
    'scipy.sparse.csc',
    'scipy.sparse.csr',
    'scipy.sparse.data',
    'scipy.sparse.dia',
    'scipy.sparse.dok',
    'scipy.sparse.extract',
    'scipy.sparse.lil',
    'scipy.sparse.linalg.dsolve',
    'scipy.sparse.linalg.eigen',
    'scipy.sparse.linalg.interface',
    'scipy.sparse.linalg.isolve',
    'scipy.sparse.linalg.matfuncs',
    'scipy.sparse.sparsetools',
    'scipy.sparse.spfuncs',
    'scipy.sparse.sputils',
    'scipy.spatial.ckdtree',
    'scipy.spatial.kdtree',
    'scipy.spatial.qhull',
    'scipy.spatial.transform.rotation',
    'scipy.special.add_newdocs',
    'scipy.special.basic',
    'scipy.special.cython_special',
    'scipy.special.orthogonal',
    'scipy.special.sf_error',
    'scipy.special.specfun',
    'scipy.special.spfun_stats',
    'scipy.stats.biasedurn',
    'scipy.stats.kde',
    'scipy.stats.morestats',
    'scipy.stats.mstats_basic',
    'scipy.stats.mstats_extras',
    'scipy.stats.mvn',
    'scipy.stats.stats',
]


def is_unexpected(name):
    """Check if this needs to be considered."""
    if '._' in name or '.tests' in name or '.setup' in name:
        return False

    if name in PUBLIC_MODULES:
        return False

    if name in PRIVATE_BUT_PRESENT_MODULES:
        return False

    return True


SKIP_LIST = [
    'scipy.conftest',
    'scipy.version',
    'scipy.special.libsf_error_state'
]


# XXX: this test does more than it says on the tin - in using `pkgutil.walk_packages`,
# it will raise if it encounters any exceptions which are not handled by `ignore_errors`
# while attempting to import each discovered package.
# For now, `ignore_errors` only ignores what is necessary, but this could be expanded -
# for example, to all errors from private modules or git subpackages - if desired.
def test_all_modules_are_expected():
    """
    Test that we don't add anything that looks like a new public module by
    accident. Check is based on filenames.
    """

    def ignore_errors(name):
        # if versions of other array libraries are installed which are incompatible
        # with the installed NumPy version, there can be errors on importing
        # `array_api_compat`. This should only raise if SciPy is configured with
        # that library as an available backend.
        backends = {'cupy': 'cupy',
                    'pytorch': 'torch',
                    'dask.array': 'dask.array'}
        for backend, dir_name in backends.items():
            path = f'array_api_compat.{dir_name}'
            if path in name and backend not in xp_available_backends:
                return
        raise

    modnames = []

    for _, modname, _ in pkgutil.walk_packages(path=scipy.__path__,
                                               prefix=scipy.__name__ + '.',
                                               onerror=ignore_errors):
        if is_unexpected(modname) and modname not in SKIP_LIST:
            # We have a name that is new. If that's on purpose, add it to
            # PUBLIC_MODULES. We don't expect to have to add anything to
            # PRIVATE_BUT_PRESENT_MODULES. Use an underscore in the name!
            modnames.append(modname)

    if modnames:
        raise AssertionError(f'Found unexpected modules: {modnames}')


# Stuff that clearly shouldn't be in the API and is detected by the next test
# below
SKIP_LIST_2 = [
    'scipy.char',
    'scipy.rec',
    'scipy.emath',
    'scipy.math',
    'scipy.random',
    'scipy.ctypeslib',
    'scipy.ma'
]


def test_all_modules_are_expected_2():
    """
    Method checking all objects. The pkgutil-based method in
    `test_all_modules_are_expected` does not catch imports into a namespace,
    only filenames.
    """

    def find_unexpected_members(mod_name):
        members = []
        module = importlib.import_module(mod_name)
        if hasattr(module, '__all__'):
            objnames = module.__all__
        else:
            objnames = dir(module)

        for objname in objnames:
            if not objname.startswith('_'):
                fullobjname = mod_name + '.' + objname
                if isinstance(getattr(module, objname), types.ModuleType):
                    if is_unexpected(fullobjname) and fullobjname not in SKIP_LIST_2:
                        members.append(fullobjname)

        return members

    unexpected_members = find_unexpected_members("scipy")
    for modname in PUBLIC_MODULES:
        unexpected_members.extend(find_unexpected_members(modname))

    if unexpected_members:
        raise AssertionError("Found unexpected object(s) that look like "
                             f"modules: {unexpected_members}")


def test_api_importable():
    """
    Check that all submodules listed higher up in this file can be imported.

    Note that if a PRIVATE_BUT_PRESENT_MODULES entry goes missing, it may
    simply need to be removed from the list (deprecation may or may not be
    needed - apply common sense).
    """
    def check_importable(module_name):
        try:
            importlib.import_module(module_name)
        except (ImportError, AttributeError):
            return False

        return True

    module_names = []
    for module_name in PUBLIC_MODULES:
        if not check_importable(module_name):
            module_names.append(module_name)

    if module_names:
        raise AssertionError("Modules in the public API that cannot be "
                             f"imported: {module_names}")

    with warnings.catch_warnings(record=True):
        warnings.filterwarnings('always', category=DeprecationWarning)
        warnings.filterwarnings('always', category=ImportWarning)
        for module_name in PRIVATE_BUT_PRESENT_MODULES:
            if not check_importable(module_name):
                module_names.append(module_name)

    if module_names:
        raise AssertionError("Modules that are not really public but looked "
                             "public and cannot be imported: "
                             f"{module_names}")


@pytest.mark.parametrize(("module_name", "correct_module"),
                         [('scipy.constants.codata', None),
                          ('scipy.constants.constants', None),
                          ('scipy.fftpack.basic', None),
                          ('scipy.fftpack.helper', None),
                          ('scipy.fftpack.pseudo_diffs', None),
                          ('scipy.fftpack.realtransforms', None),
                          ('scipy.integrate.dop', None),
                          ('scipy.integrate.lsoda', None),
                          ('scipy.integrate.odepack', None),
                          ('scipy.integrate.quadpack', None),
                          ('scipy.integrate.vode', None),
                          ('scipy.interpolate.fitpack', None),
                          ('scipy.interpolate.fitpack2', None),
                          ('scipy.interpolate.interpolate', None),
                          ('scipy.interpolate.ndgriddata', None),
                          ('scipy.interpolate.polyint', None),
                          ('scipy.interpolate.rbf', None),
                          ('scipy.io.harwell_boeing', None),
                          ('scipy.io.idl', None),
                          ('scipy.io.mmio', None),
                          ('scipy.io.netcdf', None),
                          ('scipy.io.arff.arffread', 'arff'),
                          ('scipy.io.matlab.byteordercodes', 'matlab'),
                          ('scipy.io.matlab.mio_utils', 'matlab'),
                          ('scipy.io.matlab.mio', 'matlab'),
                          ('scipy.io.matlab.mio4', 'matlab'),
                          ('scipy.io.matlab.mio5_params', 'matlab'),
                          ('scipy.io.matlab.mio5_utils', 'matlab'),
                          ('scipy.io.matlab.mio5', 'matlab'),
                          ('scipy.io.matlab.miobase', 'matlab'),
                          ('scipy.io.matlab.streams', 'matlab'),
                          ('scipy.linalg.basic', None),
                          ('scipy.linalg.decomp', None),
                          ('scipy.linalg.decomp_cholesky', None),
                          ('scipy.linalg.decomp_lu', None),
                          ('scipy.linalg.decomp_qr', None),
                          ('scipy.linalg.decomp_schur', None),
                          ('scipy.linalg.decomp_svd', None),
                          ('scipy.linalg.matfuncs', None),
                          ('scipy.linalg.misc', None),
                          ('scipy.linalg.special_matrices', None),
                          ('scipy.misc.common', None),
                          ('scipy.ndimage.filters', None),
                          ('scipy.ndimage.fourier', None),
                          ('scipy.ndimage.interpolation', None),
                          ('scipy.ndimage.measurements', None),
                          ('scipy.ndimage.morphology', None),
                          ('scipy.odr.models', None),
                          ('scipy.odr.odrpack', None),
                          ('scipy.optimize.cobyla', None),
                          ('scipy.optimize.lbfgsb', None),
                          ('scipy.optimize.linesearch', None),
                          ('scipy.optimize.minpack', None),
                          ('scipy.optimize.minpack2', None),
                          ('scipy.optimize.moduleTNC', None),
                          ('scipy.optimize.nonlin', None),
                          ('scipy.optimize.optimize', None),
                          ('scipy.optimize.slsqp', None),
                          ('scipy.optimize.tnc', None),
                          ('scipy.optimize.zeros', None),
                          ('scipy.signal.bsplines', None),
                          ('scipy.signal.filter_design', None),
                          ('scipy.signal.fir_filter_design', None),
                          ('scipy.signal.lti_conversion', None),
                          ('scipy.signal.ltisys', None),
                          ('scipy.signal.signaltools', None),
                          ('scipy.signal.spectral', None),
                          ('scipy.signal.waveforms', None),
                          ('scipy.signal.wavelets', None),
                          ('scipy.signal.windows.windows', 'windows'),
                          ('scipy.sparse.lil', None),
                          ('scipy.sparse.linalg.dsolve', 'linalg'),
                          ('scipy.sparse.linalg.eigen', 'linalg'),
                          ('scipy.sparse.linalg.interface', 'linalg'),
                          ('scipy.sparse.linalg.isolve', 'linalg'),
                          ('scipy.sparse.linalg.matfuncs', 'linalg'),
                          ('scipy.sparse.sparsetools', None),
                          ('scipy.sparse.spfuncs', None),
                          ('scipy.sparse.sputils', None),
                          ('scipy.spatial.ckdtree', None),
                          ('scipy.spatial.kdtree', None),
                          ('scipy.spatial.qhull', None),
                          ('scipy.spatial.transform.rotation', 'transform'),
                          ('scipy.special.add_newdocs', None),
                          ('scipy.special.basic', None),
                          ('scipy.special.orthogonal', None),
                          ('scipy.special.sf_error', None),
                          ('scipy.special.specfun', None),
                          ('scipy.special.spfun_stats', None),
                          ('scipy.stats.biasedurn', None),
                          ('scipy.stats.kde', None),
                          ('scipy.stats.morestats', None),
                          ('scipy.stats.mstats_basic', 'mstats'),
                          ('scipy.stats.mstats_extras', 'mstats'),
                          ('scipy.stats.mvn', None),
                          ('scipy.stats.stats', None)])
def test_private_but_present_deprecation(module_name, correct_module):
    # gh-18279, gh-17572, gh-17771 noted that deprecation warnings for
    # imports from private modules were misleading. Check that this is
    # resolved.
    module = import_module(module_name)
    if correct_module is None:
        import_name = f'scipy.{module_name.split(".")[1]}'
    else:
        import_name = f'scipy.{module_name.split(".")[1]}.{correct_module}'

    correct_import = import_module(import_name)

    # Attributes that were formerly in `module_name` can still be imported from
    # `module_name`, albeit with a deprecation warning.
    for attr_name in module.__all__:
        if attr_name == "varmats_from_mat":
            # defer handling this case, see
            # https://github.com/scipy/scipy/issues/19223
            continue
        # ensure attribute is present where the warning is pointing
        assert getattr(correct_import, attr_name, None) is not None
        message = f"Please import `{attr_name}` from the `{import_name}`..."
        with pytest.deprecated_call(match=message):
            getattr(module, attr_name)

    # Attributes that were not in `module_name` get an error notifying the user
    # that the attribute is not in `module_name` and that `module_name` is deprecated.
    message = f"`{module_name}` is deprecated..."
    with pytest.raises(AttributeError, match=message):
        getattr(module, "ekki")


def test_misc_doccer_deprecation():
    # gh-18279, gh-17572, gh-17771 noted that deprecation warnings for
    # imports from private modules were misleading. Check that this is
    # resolved.
    # `test_private_but_present_deprecation` cannot be used since `correct_import`
    # is a different subpackage (`_lib` instead of `misc`).
    module = import_module('scipy.misc.doccer')
    correct_import = import_module('scipy._lib.doccer')

    # Attributes that were formerly in `scipy.misc.doccer` can still be imported from
    # `scipy.misc.doccer`, albeit with a deprecation warning. The specific message
    # depends on whether the attribute is in `scipy._lib.doccer` or not.
    for attr_name in module.__all__:
        attr = getattr(correct_import, attr_name, None)
        if attr is None:
            message = f"`scipy.misc.{attr_name}` is deprecated..."
        else:
            message = f"Please import `{attr_name}` from the `scipy._lib.doccer`..."
        with pytest.deprecated_call(match=message):
            getattr(module, attr_name)

    # Attributes that were not in `scipy.misc.doccer` get an error
    # notifying the user that the attribute is not in `scipy.misc.doccer`
    # and that `scipy.misc.doccer` is deprecated.
    message = "`scipy.misc.doccer` is deprecated..."
    with pytest.raises(AttributeError, match=message):
        getattr(module, "ekki")
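From a user's perspective, the behavior these deprecation tests pin down looks roughly like the sketch below (not part of the test file; it assumes a SciPy version where the legacy `scipy.ndimage.filters` namespace is still present):

import warnings

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    from scipy.ndimage import filters      # private-looking legacy namespace
    filters.gaussian_filter                # attribute access triggers the warning,
                                           # pointing at scipy.ndimage instead
assert any(issubclass(x.category, DeprecationWarning) for x in w)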
@ -0,0 +1,18 @@
import re

import scipy
from numpy.testing import assert_


def test_valid_scipy_version():
    # Verify that the SciPy version is a valid one (no .post suffix or other
    # nonsense). See NumPy issue gh-6431 for an issue caused by an invalid
    # version.
    version_pattern = r"^[0-9]+\.[0-9]+\.[0-9]+(|a[0-9]|b[0-9]|rc[0-9])"
    dev_suffix = r"(\.dev0\+.+([0-9a-f]{7}|Unknown))"
    if scipy.version.release:
        res = re.match(version_pattern, scipy.__version__)
    else:
        res = re.match(version_pattern + dev_suffix, scipy.__version__)

    assert_(res is not None, scipy.__version__)
@ -0,0 +1,42 @@
""" Test tmpdirs module """
from os import getcwd
from os.path import realpath, abspath, dirname, isfile, join as pjoin, exists

from scipy._lib._tmpdirs import tempdir, in_tempdir, in_dir

from numpy.testing import assert_, assert_equal

MY_PATH = abspath(__file__)
MY_DIR = dirname(MY_PATH)


def test_tempdir():
    with tempdir() as tmpdir:
        fname = pjoin(tmpdir, 'example_file.txt')
        with open(fname, "w") as fobj:
            fobj.write('a string\n')
    assert_(not exists(tmpdir))


def test_in_tempdir():
    my_cwd = getcwd()
    with in_tempdir() as tmpdir:
        with open('test.txt', "w") as f:
            f.write('some text')
        assert_(isfile('test.txt'))
        assert_(isfile(pjoin(tmpdir, 'test.txt')))
    assert_(not exists(tmpdir))
    assert_equal(getcwd(), my_cwd)


def test_given_directory():
    # Test InGivenDirectory
    cwd = getcwd()
    with in_dir() as tmpdir:
        assert_equal(tmpdir, abspath(cwd))
        assert_equal(tmpdir, abspath(getcwd()))
    with in_dir(MY_DIR) as tmpdir:
        assert_equal(tmpdir, MY_DIR)
        assert_equal(realpath(MY_DIR), realpath(abspath(getcwd())))
    # We were deleting the given directory! Check not so now.
    assert_(isfile(MY_PATH))
@ -0,0 +1,135 @@
"""
Tests which scan for certain occurrences in the code; they may not find
all of these occurrences but should catch almost all. This file was adapted
from NumPy.
"""


import os
from pathlib import Path
import ast
import tokenize

import scipy

import pytest


class ParseCall(ast.NodeVisitor):
    def __init__(self):
        self.ls = []

    def visit_Attribute(self, node):
        ast.NodeVisitor.generic_visit(self, node)
        self.ls.append(node.attr)

    def visit_Name(self, node):
        self.ls.append(node.id)


class FindFuncs(ast.NodeVisitor):
    def __init__(self, filename):
        super().__init__()
        self.__filename = filename
        self.bad_filters = []
        self.bad_stacklevels = []

    def visit_Call(self, node):
        p = ParseCall()
        p.visit(node.func)
        ast.NodeVisitor.generic_visit(self, node)

        if p.ls[-1] == 'simplefilter' or p.ls[-1] == 'filterwarnings':
            # get first argument of the `args` node of the filter call
            match node.args[0]:
                case ast.Constant() as c:
                    argtext = c.value
                case ast.JoinedStr() as js:
                    # if we get an f-string, discard the templated pieces, which
                    # are likely the type or specific message; we're interested
                    # in the action, which is less likely to use a template
                    argtext = "".join(
                        x.value for x in js.values if isinstance(x, ast.Constant)
                    )
                case _:
                    raise ValueError("unknown ast node type")
            # check if filter is set to ignore
            if argtext == "ignore":
                self.bad_filters.append(
                    f"{self.__filename}:{node.lineno}")

        if p.ls[-1] == 'warn' and (
                len(p.ls) == 1 or p.ls[-2] == 'warnings'):

            if self.__filename == "_lib/tests/test_warnings.py":
                # This file
                return

            # See if stacklevel exists:
            if len(node.args) == 3:
                return
            args = {kw.arg for kw in node.keywords}
            if "stacklevel" not in args:
                self.bad_stacklevels.append(
                    f"{self.__filename}:{node.lineno}")


@pytest.fixture(scope="session")
def warning_calls():
    # combined "ignore" and stacklevel error
    base = Path(scipy.__file__).parent

    bad_filters = []
    bad_stacklevels = []

    for path in base.rglob("*.py"):
        # use tokenize to auto-detect encoding on systems where no
        # default encoding is defined (e.g., LANG='C')
        with tokenize.open(str(path)) as file:
            tree = ast.parse(file.read(), filename=str(path))
            finder = FindFuncs(path.relative_to(base))
            finder.visit(tree)
            bad_filters.extend(finder.bad_filters)
            bad_stacklevels.extend(finder.bad_stacklevels)

    return bad_filters, bad_stacklevels


@pytest.mark.fail_slow(20)
@pytest.mark.slow
def test_warning_calls_filters(warning_calls):
    bad_filters, bad_stacklevels = warning_calls

    # We try not to add filters in the code base, because those filters aren't
    # thread-safe. We aim to only filter in tests with
    # np.testing.suppress_warnings. However, in some cases it may prove
    # necessary to filter out warnings, because we can't (easily) fix the root
    # cause for them and we don't want users to see some warnings when they use
    # SciPy correctly. So we list exceptions here. Add new entries only if
    # there's a good reason.
    allowed_filters = (
        os.path.join('datasets', '_fetchers.py'),
        os.path.join('datasets', '__init__.py'),
        os.path.join('optimize', '_optimize.py'),
        os.path.join('optimize', '_constraints.py'),
        os.path.join('optimize', '_nnls.py'),
        os.path.join('signal', '_ltisys.py'),
        os.path.join('sparse', '__init__.py'),  # np.matrix pending-deprecation
        os.path.join('stats', '_discrete_distns.py'),  # gh-14901
        os.path.join('stats', '_continuous_distns.py'),
        os.path.join('stats', '_binned_statistic.py'),  # gh-19345
        os.path.join('stats', 'tests', 'test_axis_nan_policy.py'),  # gh-20694
        os.path.join('_lib', '_util.py'),  # gh-19341
        os.path.join('sparse', 'linalg', '_dsolve', 'linsolve.py'),  # gh-17924
        "conftest.py",
    )
    bad_filters = [item for item in bad_filters if item.split(':')[0] not in
                   allowed_filters]

    if bad_filters:
        raise AssertionError(
            "warning ignore filter should not be used, instead, use\n"
            "numpy.testing.suppress_warnings (in tests only);\n"
            "found in:\n    {}".format(
                "\n    ".join(bad_filters)))
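The stacklevel check above encodes a convention worth stating explicitly: `warnings.warn` inside a library helper should pass `stacklevel=2` (or more) so the warning is attributed to the caller's line rather than to the library internals. A minimal sketch of the pattern the AST scan enforces:

import warnings

def _deprecated_helper():
    # stacklevel=2 makes the warning point at the code that called this
    # helper, which is what the check above requires across SciPy
    warnings.warn("use new_helper instead", DeprecationWarning, stacklevel=2)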
31
venv/lib/python3.12/site-packages/scipy/_lib/uarray.py
Normal file
@ -0,0 +1,31 @@
"""`uarray` provides functions for generating multimethods that dispatch to
multiple different backends

This should be imported, rather than `_uarray` so that an installed version could
be used instead, if available. This means that users can call
`uarray.set_backend` directly instead of going through SciPy.

"""


# Prefer an installed version of uarray, if available
try:
    import uarray as _uarray
except ImportError:
    _has_uarray = False
else:
    from scipy._lib._pep440 import Version as _Version

    _has_uarray = _Version(_uarray.__version__) >= _Version("0.8")
    del _uarray
    del _Version


if _has_uarray:
    from uarray import *  # noqa: F403
    from uarray import _Function
else:
    from ._uarray import *  # noqa: F403
    from ._uarray import _Function  # noqa: F401

del _has_uarray
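As a usage sketch (not part of the module above): `scipy.fft` is the main consumer of this multimethod machinery, and `'scipy'` is the name under which the default backend is registered, so backend selection looks like this:

import numpy as np
import scipy.fft

x = np.arange(8, dtype=float)
with scipy.fft.set_backend('scipy'):   # context-local uarray backend selection
    y = scipy.fft.fft(x)

The same mechanism lets third-party packages (CuPy, pyFFTW, etc.) register themselves as drop-in backends for the multimethods without any change to user code.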
31
venv/lib/python3.12/site-packages/scipy/cluster/__init__.py
Normal file
@ -0,0 +1,31 @@
"""
=========================================
Clustering package (:mod:`scipy.cluster`)
=========================================

.. currentmodule:: scipy.cluster

.. toctree::
   :hidden:

   cluster.vq
   cluster.hierarchy

Clustering algorithms are useful in information theory, target detection,
communications, compression, and other areas. The `vq` module only
supports vector quantization and the k-means algorithms.

The `hierarchy` module provides functions for hierarchical and
agglomerative clustering. Its features include generating hierarchical
clusters from distance matrices, calculating statistics on clusters,
cutting linkages to generate flat clusters, and visualizing clusters
with dendrograms.

"""
__all__ = ['vq', 'hierarchy']

from . import vq, hierarchy

from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester
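A short usage sketch of the `vq` workflow the docstring describes (whiten, then cluster, then assign codes); the data here is made up for illustration:

import numpy as np
from scipy.cluster.vq import whiten, kmeans, vq

rng = np.random.default_rng(0)
obs = rng.normal(size=(50, 2))        # made-up observations
w = whiten(obs)                       # scale each feature to unit variance
codebook, distortion = kmeans(w, 2)   # find 2 cluster centroids
labels, dists = vq(w, codebook)       # assign each observation to a centroid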
Binary file not shown.
Binary file not shown.
Binary file not shown.
4173
venv/lib/python3.12/site-packages/scipy/cluster/hierarchy.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,145 @@
from numpy import array


Q_X = array([[5.26563660e-01, 3.14160190e-01, 8.00656370e-02],
             [7.50205180e-01, 4.60299830e-01, 8.98696460e-01],
             [6.65461230e-01, 6.94011420e-01, 9.10465700e-01],
             [9.64047590e-01, 1.43082200e-03, 7.39874220e-01],
             [1.08159060e-01, 5.53028790e-01, 6.63804780e-02],
             [9.31359130e-01, 8.25424910e-01, 9.52315440e-01],
             [6.78086960e-01, 3.41903970e-01, 5.61481950e-01],
             [9.82730940e-01, 7.04605210e-01, 8.70978630e-02],
             [6.14691610e-01, 4.69989230e-02, 6.02406450e-01],
             [5.80161260e-01, 9.17354970e-01, 5.88163850e-01],
             [1.38246310e+00, 1.96358160e+00, 1.94437880e+00],
             [2.10675860e+00, 1.67148730e+00, 1.34854480e+00],
             [1.39880070e+00, 1.66142050e+00, 1.32224550e+00],
             [1.71410460e+00, 1.49176380e+00, 1.45432170e+00],
             [1.54102340e+00, 1.84374950e+00, 1.64658950e+00],
             [2.08512480e+00, 1.84524350e+00, 2.17340850e+00],
             [1.30748740e+00, 1.53801650e+00, 2.16007740e+00],
             [1.41447700e+00, 1.99329070e+00, 1.99107420e+00],
             [1.61943490e+00, 1.47703280e+00, 1.89788160e+00],
             [1.59880600e+00, 1.54988980e+00, 1.57563350e+00],
             [3.37247380e+00, 2.69635310e+00, 3.39981700e+00],
             [3.13705120e+00, 3.36528090e+00, 3.06089070e+00],
             [3.29413250e+00, 3.19619500e+00, 2.90700170e+00],
             [2.65510510e+00, 3.06785900e+00, 2.97198540e+00],
             [3.30941040e+00, 2.59283970e+00, 2.57714110e+00],
             [2.59557220e+00, 3.33477370e+00, 3.08793190e+00],
             [2.58206180e+00, 3.41615670e+00, 3.26441990e+00],
             [2.71127000e+00, 2.77032450e+00, 2.63466500e+00],
             [2.79617850e+00, 3.25473720e+00, 3.41801560e+00],
             [2.64741750e+00, 2.54538040e+00, 3.25354110e+00]])

ytdist = array([662., 877., 255., 412., 996., 295., 468., 268., 400., 754.,
                564., 138., 219., 869., 669.])

linkage_ytdist_single = array([[2., 5., 138., 2.],
                               [3., 4., 219., 2.],
                               [0., 7., 255., 3.],
                               [1., 8., 268., 4.],
                               [6., 9., 295., 6.]])

linkage_ytdist_complete = array([[2., 5., 138., 2.],
                                 [3., 4., 219., 2.],
                                 [1., 6., 400., 3.],
                                 [0., 7., 412., 3.],
                                 [8., 9., 996., 6.]])

linkage_ytdist_average = array([[2., 5., 138., 2.],
                                [3., 4., 219., 2.],
                                [0., 7., 333.5, 3.],
                                [1., 6., 347.5, 3.],
                                [8., 9., 680.77777778, 6.]])

linkage_ytdist_weighted = array([[2., 5., 138., 2.],
                                 [3., 4., 219., 2.],
                                 [0., 7., 333.5, 3.],
                                 [1., 6., 347.5, 3.],
                                 [8., 9., 670.125, 6.]])

# the optimal leaf ordering of linkage_ytdist_single
linkage_ytdist_single_olo = array([[5., 2., 138., 2.],
                                   [4., 3., 219., 2.],
                                   [7., 0., 255., 3.],
                                   [1., 8., 268., 4.],
                                   [6., 9., 295., 6.]])

X = array([[1.43054825, -7.5693489],
           [6.95887839, 6.82293382],
           [2.87137846, -9.68248579],
           [7.87974764, -6.05485803],
           [8.24018364, -6.09495602],
           [7.39020262, 8.54004355]])

linkage_X_centroid = array([[3., 4., 0.36265956, 2.],
                            [1., 5., 1.77045373, 2.],
                            [0., 2., 2.55760419, 2.],
                            [6., 8., 6.43614494, 4.],
                            [7., 9., 15.17363237, 6.]])

linkage_X_median = array([[3., 4., 0.36265956, 2.],
                          [1., 5., 1.77045373, 2.],
                          [0., 2., 2.55760419, 2.],
                          [6., 8., 6.43614494, 4.],
                          [7., 9., 15.17363237, 6.]])

linkage_X_ward = array([[3., 4., 0.36265956, 2.],
                        [1., 5., 1.77045373, 2.],
                        [0., 2., 2.55760419, 2.],
                        [6., 8., 9.10208346, 4.],
                        [7., 9., 24.7784379, 6.]])

# the optimal leaf ordering of linkage_X_ward
linkage_X_ward_olo = array([[4., 3., 0.36265956, 2.],
                            [5., 1., 1.77045373, 2.],
                            [2., 0., 2.55760419, 2.],
                            [6., 8., 9.10208346, 4.],
                            [7., 9., 24.7784379, 6.]])

inconsistent_ytdist = {
    1: array([[138., 0., 1., 0.],
              [219., 0., 1., 0.],
              [255., 0., 1., 0.],
              [268., 0., 1., 0.],
              [295., 0., 1., 0.]]),
    2: array([[138., 0., 1., 0.],
              [219., 0., 1., 0.],
              [237., 25.45584412, 2., 0.70710678],
              [261.5, 9.19238816, 2., 0.70710678],
              [233.66666667, 83.9424406, 3., 0.7306594]]),
    3: array([[138., 0., 1., 0.],
              [219., 0., 1., 0.],
              [237., 25.45584412, 2., 0.70710678],
              [247.33333333, 25.38372182, 3., 0.81417007],
              [239., 69.36377537, 4., 0.80733783]]),
    4: array([[138., 0., 1., 0.],
              [219., 0., 1., 0.],
              [237., 25.45584412, 2., 0.70710678],
              [247.33333333, 25.38372182, 3., 0.81417007],
              [235., 60.73302232, 5., 0.98793042]])}

fcluster_inconsistent = {
    0.8: array([6, 2, 2, 4, 6, 2, 3, 7, 3, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1]),
    1.0: array([6, 2, 2, 4, 6, 2, 3, 7, 3, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1]),
    2.0: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1])}

fcluster_distance = {
    0.6: array([4, 4, 4, 4, 4, 4, 4, 5, 4, 4, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 3,
                1, 1, 1, 2, 1, 1, 1, 1, 1]),
    1.0: array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1]),
    2.0: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1])}

fcluster_maxclust = {
    8.0: array([5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 7, 7, 7, 7, 7, 8, 7, 7, 7, 7, 4,
                1, 1, 1, 3, 1, 1, 1, 1, 2]),
    4.0: array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2,
                1, 1, 1, 1, 1, 1, 1, 1, 1]),
    1.0: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1])}
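For orientation, these fixtures are consumed by the hierarchy tests: the single-linkage table above is what `linkage` should return for the condensed distance vector `ytdist`. A sketch (not part of the fixture file):

from scipy.cluster.hierarchy import linkage

Z = linkage(ytdist, method='single')
# each row of Z is [cluster_i, cluster_j, merge distance, new cluster size];
# Z is expected to equal linkage_ytdist_single above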
@ -0,0 +1,202 @@
import pytest
from pytest import raises as assert_raises
import numpy as np
from scipy.cluster.hierarchy import DisjointSet
import string


def generate_random_token():
    k = len(string.ascii_letters)
    tokens = list(np.arange(k, dtype=int))
    tokens += list(np.arange(k, dtype=float))
    tokens += list(string.ascii_letters)
    tokens += [None for i in range(k)]
    tokens = np.array(tokens, dtype=object)
    rng = np.random.RandomState(seed=0)

    while 1:
        size = rng.randint(1, 3)
        element = rng.choice(tokens, size)
        if size == 1:
            yield element[0]
        else:
            yield tuple(element)


def get_elements(n):
    # dict is deterministic without difficulty of comparing numpy ints
    elements = {}
    for element in generate_random_token():
        if element not in elements:
            elements[element] = len(elements)
            if len(elements) >= n:
                break
    return list(elements.keys())


def test_init():
    n = 10
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert dis.n_subsets == n
    assert list(dis) == elements


def test_len():
    n = 10
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert len(dis) == n

    dis.add("dummy")
    assert len(dis) == n + 1


@pytest.mark.parametrize("n", [10, 100])
def test_contains(n):
    elements = get_elements(n)
    dis = DisjointSet(elements)
    for x in elements:
        assert x in dis

    assert "dummy" not in dis


@pytest.mark.parametrize("n", [10, 100])
def test_add(n):
    elements = get_elements(n)
    dis1 = DisjointSet(elements)

    dis2 = DisjointSet()
    for i, x in enumerate(elements):
        dis2.add(x)
        assert len(dis2) == i + 1

        # test idempotency by adding element again
        dis2.add(x)
        assert len(dis2) == i + 1

    assert list(dis1) == list(dis2)


def test_element_not_present():
    elements = get_elements(n=10)
    dis = DisjointSet(elements)

    with assert_raises(KeyError):
        dis["dummy"]

    with assert_raises(KeyError):
        dis.merge(elements[0], "dummy")

    with assert_raises(KeyError):
        dis.connected(elements[0], "dummy")


@pytest.mark.parametrize("direction", ["forwards", "backwards"])
@pytest.mark.parametrize("n", [10, 100])
def test_linear_union_sequence(n, direction):
    elements = get_elements(n)
    dis = DisjointSet(elements)
    assert elements == list(dis)

    indices = list(range(n - 1))
    if direction == "backwards":
        indices = indices[::-1]

    for it, i in enumerate(indices):
        assert not dis.connected(elements[i], elements[i + 1])
        assert dis.merge(elements[i], elements[i + 1])
        assert dis.connected(elements[i], elements[i + 1])
        assert dis.n_subsets == n - 1 - it

    roots = [dis[i] for i in elements]
    if direction == "forwards":
        assert all(elements[0] == r for r in roots)
    else:
        assert all(elements[-2] == r for r in roots)
    assert not dis.merge(elements[0], elements[-1])


@pytest.mark.parametrize("n", [10, 100])
def test_self_unions(n):
    elements = get_elements(n)
    dis = DisjointSet(elements)

    for x in elements:
        assert dis.connected(x, x)
        assert not dis.merge(x, x)
        assert dis.connected(x, x)
    assert dis.n_subsets == len(elements)

    assert elements == list(dis)
    roots = [dis[x] for x in elements]
    assert elements == roots


@pytest.mark.parametrize("order", ["ab", "ba"])
@pytest.mark.parametrize("n", [10, 100])
def test_equal_size_ordering(n, order):
    elements = get_elements(n)
    dis = DisjointSet(elements)

    rng = np.random.RandomState(seed=0)
    indices = np.arange(n)
    rng.shuffle(indices)

    for i in range(0, len(indices), 2):
        a, b = elements[indices[i]], elements[indices[i + 1]]
        if order == "ab":
            assert dis.merge(a, b)
        else:
            assert dis.merge(b, a)

        expected = elements[min(indices[i], indices[i + 1])]
        assert dis[a] == expected
        assert dis[b] == expected


@pytest.mark.parametrize("kmax", [5, 10])
def test_binary_tree(kmax):
    n = 2**kmax
    elements = get_elements(n)
    dis = DisjointSet(elements)
    rng = np.random.RandomState(seed=0)

    for k in 2**np.arange(kmax):
        for i in range(0, n, 2 * k):
            r1, r2 = rng.randint(0, k, size=2)
            a, b = elements[i + r1], elements[i + k + r2]
            assert not dis.connected(a, b)
            assert dis.merge(a, b)
            assert dis.connected(a, b)

        assert elements == list(dis)
        roots = [dis[i] for i in elements]
        expected_indices = np.arange(n) - np.arange(n) % (2 * k)
        expected = [elements[i] for i in expected_indices]
        assert roots == expected


@pytest.mark.parametrize("n", [10, 100])
def test_subsets(n):
    elements = get_elements(n)
    dis = DisjointSet(elements)

    rng = np.random.RandomState(seed=0)
    for i, j in rng.randint(0, n, (n, 2)):
        x = elements[i]
        y = elements[j]

        expected = {element for element in dis if {dis[element]} == {dis[x]}}
        assert dis.subset_size(x) == len(dis.subset(x))
        assert expected == dis.subset(x)

        expected = {dis[element]: set() for element in dis}
        for element in dis:
            expected[dis[element]].add(element)
        expected = list(expected.values())
        assert expected == dis.subsets()

        dis.merge(x, y)
        assert dis.subset(x) == dis.subset(y)
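For orientation, the public `DisjointSet` API exercised by these tests in a few lines (a sketch, not part of the test file):

from scipy.cluster.hierarchy import DisjointSet

ds = DisjointSet(['a', 'b', 'c', 'd'])
ds.merge('a', 'b')       # True: two subsets were joined
ds.merge('b', 'a')       # False: already connected
ds.connected('a', 'b')   # True
ds.subset('a')           # {'a', 'b'}
ds.n_subsets             # 3, since 'c' and 'd' are still singletons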
File diff suppressed because it is too large
Load Diff
435
venv/lib/python3.12/site-packages/scipy/cluster/tests/test_vq.py
Normal file
@ -0,0 +1,435 @@
|
||||
import warnings
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (
|
||||
assert_array_equal, assert_equal, assert_, suppress_warnings
|
||||
)
|
||||
import pytest
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
from scipy.cluster.vq import (kmeans, kmeans2, py_vq, vq, whiten,
|
||||
ClusterError, _krandinit)
|
||||
from scipy.cluster import _vq
|
||||
from scipy.conftest import array_api_compatible
|
||||
from scipy.sparse._sputils import matrix
|
||||
|
||||
from scipy._lib._array_api import (
|
||||
SCIPY_ARRAY_API, copy, cov, xp_assert_close, xp_assert_equal
|
||||
)
|
||||
|
||||
pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
|
||||
skip_xp_backends = pytest.mark.skip_xp_backends
|
||||
|
||||
TESTDATA_2D = np.array([
|
||||
-2.2, 1.17, -1.63, 1.69, -2.04, 4.38, -3.09, 0.95, -1.7, 4.79, -1.68, 0.68,
|
||||
-2.26, 3.34, -2.29, 2.55, -1.72, -0.72, -1.99, 2.34, -2.75, 3.43, -2.45,
|
||||
2.41, -4.26, 3.65, -1.57, 1.87, -1.96, 4.03, -3.01, 3.86, -2.53, 1.28,
|
||||
-4.0, 3.95, -1.62, 1.25, -3.42, 3.17, -1.17, 0.12, -3.03, -0.27, -2.07,
|
||||
-0.55, -1.17, 1.34, -2.82, 3.08, -2.44, 0.24, -1.71, 2.48, -5.23, 4.29,
|
||||
-2.08, 3.69, -1.89, 3.62, -2.09, 0.26, -0.92, 1.07, -2.25, 0.88, -2.25,
|
||||
2.02, -4.31, 3.86, -2.03, 3.42, -2.76, 0.3, -2.48, -0.29, -3.42, 3.21,
|
||||
-2.3, 1.73, -2.84, 0.69, -1.81, 2.48, -5.24, 4.52, -2.8, 1.31, -1.67,
|
||||
-2.34, -1.18, 2.17, -2.17, 2.82, -1.85, 2.25, -2.45, 1.86, -6.79, 3.94,
|
||||
-2.33, 1.89, -1.55, 2.08, -1.36, 0.93, -2.51, 2.74, -2.39, 3.92, -3.33,
|
||||
2.99, -2.06, -0.9, -2.83, 3.35, -2.59, 3.05, -2.36, 1.85, -1.69, 1.8,
|
||||
-1.39, 0.66, -2.06, 0.38, -1.47, 0.44, -4.68, 3.77, -5.58, 3.44, -2.29,
|
||||
2.24, -1.04, -0.38, -1.85, 4.23, -2.88, 0.73, -2.59, 1.39, -1.34, 1.75,
|
||||
-1.95, 1.3, -2.45, 3.09, -1.99, 3.41, -5.55, 5.21, -1.73, 2.52, -2.17,
|
||||
0.85, -2.06, 0.49, -2.54, 2.07, -2.03, 1.3, -3.23, 3.09, -1.55, 1.44,
|
||||
-0.81, 1.1, -2.99, 2.92, -1.59, 2.18, -2.45, -0.73, -3.12, -1.3, -2.83,
|
||||
0.2, -2.77, 3.24, -1.98, 1.6, -4.59, 3.39, -4.85, 3.75, -2.25, 1.71, -3.28,
|
||||
3.38, -1.74, 0.88, -2.41, 1.92, -2.24, 1.19, -2.48, 1.06, -1.68, -0.62,
|
||||
-1.3, 0.39, -1.78, 2.35, -3.54, 2.44, -1.32, 0.66, -2.38, 2.76, -2.35,
|
||||
3.95, -1.86, 4.32, -2.01, -1.23, -1.79, 2.76, -2.13, -0.13, -5.25, 3.84,
|
||||
-2.24, 1.59, -4.85, 2.96, -2.41, 0.01, -0.43, 0.13, -3.92, 2.91, -1.75,
|
||||
-0.53, -1.69, 1.69, -1.09, 0.15, -2.11, 2.17, -1.53, 1.22, -2.1, -0.86,
|
||||
-2.56, 2.28, -3.02, 3.33, -1.12, 3.86, -2.18, -1.19, -3.03, 0.79, -0.83,
|
||||
0.97, -3.19, 1.45, -1.34, 1.28, -2.52, 4.22, -4.53, 3.22, -1.97, 1.75,
|
||||
-2.36, 3.19, -0.83, 1.53, -1.59, 1.86, -2.17, 2.3, -1.63, 2.71, -2.03,
|
||||
3.75, -2.57, -0.6, -1.47, 1.33, -1.95, 0.7, -1.65, 1.27, -1.42, 1.09, -3.0,
|
||||
3.87, -2.51, 3.06, -2.6, 0.74, -1.08, -0.03, -2.44, 1.31, -2.65, 2.99,
|
||||
-1.84, 1.65, -4.76, 3.75, -2.07, 3.98, -2.4, 2.67, -2.21, 1.49, -1.21,
|
||||
1.22, -5.29, 2.38, -2.85, 2.28, -5.6, 3.78, -2.7, 0.8, -1.81, 3.5, -3.75,
|
||||
4.17, -1.29, 2.99, -5.92, 3.43, -1.83, 1.23, -1.24, -1.04, -2.56, 2.37,
|
||||
-3.26, 0.39, -4.63, 2.51, -4.52, 3.04, -1.7, 0.36, -1.41, 0.04, -2.1, 1.0,
|
||||
-1.87, 3.78, -4.32, 3.59, -2.24, 1.38, -1.99, -0.22, -1.87, 1.95, -0.84,
|
||||
2.17, -5.38, 3.56, -1.27, 2.9, -1.79, 3.31, -5.47, 3.85, -1.44, 3.69,
|
||||
-2.02, 0.37, -1.29, 0.33, -2.34, 2.56, -1.74, -1.27, -1.97, 1.22, -2.51,
|
||||
-0.16, -1.64, -0.96, -2.99, 1.4, -1.53, 3.31, -2.24, 0.45, -2.46, 1.71,
|
||||
-2.88, 1.56, -1.63, 1.46, -1.41, 0.68, -1.96, 2.76, -1.61,
|
||||
2.11]).reshape((200, 2))
|
||||
|
||||
|
||||
# Global data
|
||||
X = np.array([[3.0, 3], [4, 3], [4, 2],
|
||||
[9, 2], [5, 1], [6, 2], [9, 4],
|
||||
[5, 2], [5, 4], [7, 4], [6, 5]])
|
||||
|
||||
CODET1 = np.array([[3.0000, 3.0000],
|
||||
[6.2000, 4.0000],
|
||||
[5.8000, 1.8000]])
|
||||
|
||||
CODET2 = np.array([[11.0/3, 8.0/3],
|
||||
[6.7500, 4.2500],
|
||||
[6.2500, 1.7500]])
|
||||
|
||||
LABEL1 = np.array([0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1])
|
||||
|
||||
|
||||
class TestWhiten:
|
||||
|
||||
def test_whiten(self, xp):
|
||||
desired = xp.asarray([[5.08738849, 2.97091878],
|
||||
[3.19909255, 0.69660580],
|
||||
[4.51041982, 0.02640918],
|
||||
[4.38567074, 0.95120889],
|
||||
[2.32191480, 1.63195503]])
|
||||
|
||||
obs = xp.asarray([[0.98744510, 0.82766775],
|
||||
[0.62093317, 0.19406729],
|
||||
[0.87545741, 0.00735733],
|
||||
[0.85124403, 0.26499712],
|
||||
[0.45067590, 0.45464607]])
|
||||
xp_assert_close(whiten(obs), desired, rtol=1e-5)
|
||||
|
||||
@skip_xp_backends('jax.numpy',
|
||||
reasons=['jax arrays do not support item assignment'])
|
||||
def test_whiten_zero_std(self, xp):
|
||||
desired = xp.asarray([[0., 1.0, 2.86666544],
|
||||
[0., 1.0, 1.32460034],
|
||||
[0., 1.0, 3.74382172]])
|
||||
|
||||
obs = xp.asarray([[0., 1., 0.74109533],
|
||||
[0., 1., 0.34243798],
|
||||
[0., 1., 0.96785929]])
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter('always')
|
||||
|
||||
xp_assert_close(whiten(obs), desired, rtol=1e-5)
|
||||
|
||||
assert_equal(len(w), 1)
|
||||
assert_(issubclass(w[-1].category, RuntimeWarning))
|
||||
|
||||
def test_whiten_not_finite(self, xp):
|
||||
for bad_value in xp.nan, xp.inf, -xp.inf:
|
||||
obs = xp.asarray([[0.98744510, bad_value],
|
||||
[0.62093317, 0.19406729],
|
||||
[0.87545741, 0.00735733],
|
||||
[0.85124403, 0.26499712],
|
||||
[0.45067590, 0.45464607]])
|
||||
assert_raises(ValueError, whiten, obs)
|
||||
|
||||
@pytest.mark.skipif(SCIPY_ARRAY_API,
|
||||
reason='`np.matrix` unsupported in array API mode')
|
||||
def test_whiten_not_finite_matrix(self, xp):
|
||||
for bad_value in np.nan, np.inf, -np.inf:
|
||||
obs = matrix([[0.98744510, bad_value],
|
||||
[0.62093317, 0.19406729],
|
||||
[0.87545741, 0.00735733],
|
||||
[0.85124403, 0.26499712],
|
||||
[0.45067590, 0.45464607]])
|
||||
assert_raises(ValueError, whiten, obs)
|
||||
|
||||
|
||||
class TestVq:
|
||||
|
||||
@skip_xp_backends(cpu_only=True)
|
||||
def test_py_vq(self, xp):
|
||||
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
|
||||
# label1.dtype varies between int32 and int64 over platforms
|
||||
label1 = py_vq(xp.asarray(X), xp.asarray(initc))[0]
|
||||
xp_assert_equal(label1, xp.asarray(LABEL1, dtype=xp.int64),
|
||||
check_dtype=False)
|
||||
|
||||
@pytest.mark.skipif(SCIPY_ARRAY_API,
|
||||
reason='`np.matrix` unsupported in array API mode')
|
||||
def test_py_vq_matrix(self, xp):
|
||||
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
|
||||
# label1.dtype varies between int32 and int64 over platforms
|
||||
label1 = py_vq(matrix(X), matrix(initc))[0]
|
||||
assert_array_equal(label1, LABEL1)
|
||||
|
||||
@skip_xp_backends(np_only=True, reasons=['`_vq` only supports NumPy backend'])
|
||||
def test_vq(self, xp):
|
||||
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
|
||||
label1, _ = _vq.vq(xp.asarray(X), xp.asarray(initc))
|
||||
assert_array_equal(label1, LABEL1)
|
||||
_, _ = vq(xp.asarray(X), xp.asarray(initc))
|
||||
|
||||
@pytest.mark.skipif(SCIPY_ARRAY_API,
|
||||
reason='`np.matrix` unsupported in array API mode')
|
||||
def test_vq_matrix(self, xp):
|
||||
initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
|
||||
label1, _ = _vq.vq(matrix(X), matrix(initc))
|
||||
assert_array_equal(label1, LABEL1)
|
||||
_, _ = vq(matrix(X), matrix(initc))
|
||||
|
||||
@skip_xp_backends(cpu_only=True)
|
||||
def test_vq_1d(self, xp):
|
||||
# Test special rank 1 vq algo, python implementation.
|
||||
data = X[:, 0]
|
||||
initc = data[:3]
|
||||
a, b = _vq.vq(data, initc)
|
||||
data = xp.asarray(data)
|
||||
initc = xp.asarray(initc)
|
||||
ta, tb = py_vq(data[:, np.newaxis], initc[:, np.newaxis])
|
||||
# ta.dtype varies between int32 and int64 over platforms
|
||||
xp_assert_equal(ta, xp.asarray(a, dtype=xp.int64), check_dtype=False)
|
||||
xp_assert_equal(tb, xp.asarray(b))
|
||||
|
||||
@skip_xp_backends(np_only=True, reasons=['`_vq` only supports NumPy backend'])
|
||||
def test__vq_sametype(self, xp):
|
||||
a = xp.asarray([1.0, 2.0], dtype=xp.float64)
|
||||
b = a.astype(xp.float32)
|
||||
assert_raises(TypeError, _vq.vq, a, b)
|
||||
|
||||
@skip_xp_backends(np_only=True, reasons=['`_vq` only supports NumPy backend'])
|
||||
def test__vq_invalid_type(self, xp):
|
||||
a = xp.asarray([1, 2], dtype=int)
|
||||
assert_raises(TypeError, _vq.vq, a, a)
|
||||
|
||||
@skip_xp_backends(cpu_only=True)
|
||||
def test_vq_large_nfeat(self, xp):
|
||||
X = np.random.rand(20, 20)
|
||||
code_book = np.random.rand(3, 20)
|
||||
|
||||
codes0, dis0 = _vq.vq(X, code_book)
|
||||
codes1, dis1 = py_vq(
|
||||
xp.asarray(X), xp.asarray(code_book)
|
||||
)
|
||||
xp_assert_close(dis1, xp.asarray(dis0), rtol=1e-5)
|
||||
# codes1.dtype varies between int32 and int64 over platforms
|
||||
xp_assert_equal(codes1, xp.asarray(codes0, dtype=xp.int64), check_dtype=False)
|
||||
|
||||
X = X.astype(np.float32)
|
||||
code_book = code_book.astype(np.float32)
|
||||
|
||||
codes0, dis0 = _vq.vq(X, code_book)
|
||||
codes1, dis1 = py_vq(
|
||||
xp.asarray(X), xp.asarray(code_book)
|
||||
)
|
||||
xp_assert_close(dis1, xp.asarray(dis0, dtype=xp.float64), rtol=1e-5)
|
||||
# codes1.dtype varies between int32 and int64 over platforms
|
||||
xp_assert_equal(codes1, xp.asarray(codes0, dtype=xp.int64), check_dtype=False)
|
||||
|
||||
@skip_xp_backends(cpu_only=True)
|
||||
def test_vq_large_features(self, xp):
|
||||
X = np.random.rand(10, 5) * 1000000
|
||||
code_book = np.random.rand(2, 5) * 1000000
|
||||
|
||||
codes0, dis0 = _vq.vq(X, code_book)
|
||||
codes1, dis1 = py_vq(
|
||||
xp.asarray(X), xp.asarray(code_book)
|
||||
)
|
||||
xp_assert_close(dis1, xp.asarray(dis0), rtol=1e-5)
|
||||
# codes1.dtype varies between int32 and int64 over platforms
|
||||
xp_assert_equal(codes1, xp.asarray(codes0, dtype=xp.int64), check_dtype=False)
|
||||
|
||||
|
||||
# Whole class skipped on GPU for now;
|
||||
# once pdist/cdist are hooked up for CuPy, more tests will work
|
||||
@skip_xp_backends(cpu_only=True)
|
||||
class TestKMean:
|
||||
|
||||
def test_large_features(self, xp):
|
||||
# Generate a data set with large values, and run kmeans on it to
|
||||
# (regression for 1077).
|
||||
d = 300
|
||||
n = 100
|
||||
|
||||
m1 = np.random.randn(d)
|
||||
m2 = np.random.randn(d)
|
||||
x = 10000 * np.random.randn(n, d) - 20000 * m1
|
||||
y = 10000 * np.random.randn(n, d) + 20000 * m2
|
||||
|
||||
data = np.empty((x.shape[0] + y.shape[0], d), np.float64)
|
||||
data[:x.shape[0]] = x
|
||||
data[x.shape[0]:] = y
|
||||
|
||||
kmeans(xp.asarray(data), 2)
|
||||
|
||||
    def test_kmeans_simple(self, xp):
        np.random.seed(54321)
        initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
        code1 = kmeans(xp.asarray(X), xp.asarray(initc), iter=1)[0]
        xp_assert_close(code1, xp.asarray(CODET2))

    @pytest.mark.skipif(SCIPY_ARRAY_API,
                        reason='`np.matrix` unsupported in array API mode')
    def test_kmeans_simple_matrix(self, xp):
        np.random.seed(54321)
        initc = np.concatenate([[X[0]], [X[1]], [X[2]]])
        code1 = kmeans(matrix(X), matrix(initc), iter=1)[0]
        xp_assert_close(code1, CODET2)

    def test_kmeans_lost_cluster(self, xp):
        # This will cause kmeans to have a cluster with no points.
        data = xp.asarray(TESTDATA_2D)
        initk = xp.asarray([[-1.8127404, -0.67128041],
                            [2.04621601, 0.07401111],
                            [-2.31149087, -0.05160469]])

        kmeans(data, initk)
        with suppress_warnings() as sup:
            sup.filter(UserWarning,
                       "One of the clusters is empty. Re-run kmeans with a "
                       "different initialization")
            kmeans2(data, initk, missing='warn')

        assert_raises(ClusterError, kmeans2, data, initk, missing='raise')

    def test_kmeans2_simple(self, xp):
        np.random.seed(12345678)
        initc = xp.asarray(np.concatenate([[X[0]], [X[1]], [X[2]]]))
        arrays = [xp.asarray] if SCIPY_ARRAY_API else [np.asarray, matrix]
        for tp in arrays:
            code1 = kmeans2(tp(X), tp(initc), iter=1)[0]
            code2 = kmeans2(tp(X), tp(initc), iter=2)[0]

            xp_assert_close(code1, xp.asarray(CODET1))
            xp_assert_close(code2, xp.asarray(CODET2))

    @pytest.mark.skipif(SCIPY_ARRAY_API,
                        reason='`np.matrix` unsupported in array API mode')
    def test_kmeans2_simple_matrix(self, xp):
        np.random.seed(12345678)
        initc = xp.asarray(np.concatenate([[X[0]], [X[1]], [X[2]]]))
        code1 = kmeans2(matrix(X), matrix(initc), iter=1)[0]
        code2 = kmeans2(matrix(X), matrix(initc), iter=2)[0]

        xp_assert_close(code1, CODET1)
        xp_assert_close(code2, CODET2)

    def test_kmeans2_rank1(self, xp):
        data = xp.asarray(TESTDATA_2D)
        data1 = data[:, 0]

        initc = data1[:3]
        code = copy(initc, xp=xp)
        kmeans2(data1, code, iter=1)[0]
        kmeans2(data1, code, iter=2)[0]

    def test_kmeans2_rank1_2(self, xp):
        data = xp.asarray(TESTDATA_2D)
        data1 = data[:, 0]
        kmeans2(data1, 2, iter=1)

    def test_kmeans2_high_dim(self, xp):
        # test kmeans2 when the number of dimensions exceeds the number
        # of input points
        data = xp.asarray(TESTDATA_2D)
        data = xp.reshape(data, (20, 20))[:10, :]
        kmeans2(data, 2)

    @skip_xp_backends('jax.numpy',
                      reasons=['jax arrays do not support item assignment'],
                      cpu_only=True)
    def test_kmeans2_init(self, xp):
        np.random.seed(12345)
        data = xp.asarray(TESTDATA_2D)
        k = 3

        kmeans2(data, k, minit='points')
        kmeans2(data[:, 1], k, minit='points')  # special case (1-D)

        kmeans2(data, k, minit='++')
        kmeans2(data[:, 1], k, minit='++')  # special case (1-D)

        # minit='random' can give warnings, filter those
        with suppress_warnings() as sup:
            sup.filter(message="One of the clusters is empty. Re-run.")
            kmeans2(data, k, minit='random')
            kmeans2(data[:, 1], k, minit='random')  # special case (1-D)

    @pytest.mark.skipif(sys.platform == 'win32',
                        reason='Fails with MemoryError in Wine.')
    def test_krandinit(self, xp):
        data = xp.asarray(TESTDATA_2D)
        datas = [xp.reshape(data, (200, 2)),
                 xp.reshape(data, (20, 20))[:10, :]]
        k = int(1e6)
        for data in datas:
            rng = np.random.default_rng(1234)
            init = _krandinit(data, k, rng, xp)
            orig_cov = cov(data.T)
            init_cov = cov(init.T)
            xp_assert_close(orig_cov, init_cov, atol=1.1e-2)

    def test_kmeans2_empty(self, xp):
        # Regression test for gh-1032.
        assert_raises(ValueError, kmeans2, xp.asarray([]), 2)

    def test_kmeans_0k(self, xp):
        # Regression test for gh-1073: fail when k arg is 0.
        assert_raises(ValueError, kmeans, xp.asarray(X), 0)
        assert_raises(ValueError, kmeans2, xp.asarray(X), 0)
        assert_raises(ValueError, kmeans2, xp.asarray(X), xp.asarray([]))

    def test_kmeans_large_thres(self, xp):
        # Regression test for gh-1774
        x = xp.asarray([1, 2, 3, 4, 10], dtype=xp.float64)
        res = kmeans(x, 1, thresh=1e16)
        xp_assert_close(res[0], xp.asarray([4.], dtype=xp.float64))
        xp_assert_close(res[1], xp.asarray(2.3999999999999999, dtype=xp.float64)[()])

    @skip_xp_backends('jax.numpy',
                      reasons=['jax arrays do not support item assignment'],
                      cpu_only=True)
    def test_kmeans2_kpp_low_dim(self, xp):
        # Regression test for gh-11462
        prev_res = xp.asarray([[-1.95266667, 0.898],
                               [-3.153375, 3.3945]], dtype=xp.float64)
        np.random.seed(42)
        res, _ = kmeans2(xp.asarray(TESTDATA_2D), 2, minit='++')
        xp_assert_close(res, prev_res)

    @skip_xp_backends('jax.numpy',
                      reasons=['jax arrays do not support item assignment'],
                      cpu_only=True)
    def test_kmeans2_kpp_high_dim(self, xp):
        # Regression test for gh-11462
        n_dim = 100
        size = 10
        centers = np.vstack([5 * np.ones(n_dim),
                             -5 * np.ones(n_dim)])
        np.random.seed(42)
        data = np.vstack([
            np.random.multivariate_normal(centers[0], np.eye(n_dim), size=size),
            np.random.multivariate_normal(centers[1], np.eye(n_dim), size=size)
        ])

        data = xp.asarray(data)
        res, _ = kmeans2(data, 2, minit='++')
        xp_assert_equal(xp.sign(res), xp.sign(xp.asarray(centers)))

    def test_kmeans_diff_convergence(self, xp):
        # Regression test for gh-8727
        obs = xp.asarray([-3, -1, 0, 1, 1, 8], dtype=xp.float64)
        res = kmeans(obs, xp.asarray([-3., 0.99]))
        xp_assert_close(res[0], xp.asarray([-0.4, 8.], dtype=xp.float64))
        xp_assert_close(res[1], xp.asarray(1.0666666666666667, dtype=xp.float64)[()])

    @skip_xp_backends('jax.numpy',
                      reasons=['jax arrays do not support item assignment'],
                      cpu_only=True)
    def test_kmeans_and_kmeans2_random_seed(self, xp):

        seed_list = [
            1234, np.random.RandomState(1234), np.random.default_rng(1234)
        ]

        for seed in seed_list:
            seed1 = deepcopy(seed)
            seed2 = deepcopy(seed)
            data = xp.asarray(TESTDATA_2D)
            # test for kmeans
            res1, _ = kmeans(data, 2, seed=seed1)
            res2, _ = kmeans(data, 2, seed=seed2)
            xp_assert_close(res1, res2)  # should give the same results
            # test for kmeans2
            for minit in ["random", "points", "++"]:
                res1, _ = kmeans2(data, 2, minit=minit, seed=seed1)
                res2, _ = kmeans2(data, 2, minit=minit, seed=seed2)
                xp_assert_close(res1, res2)  # should give the same results

835
venv/lib/python3.12/site-packages/scipy/cluster/vq.py
Normal file
@ -0,0 +1,835 @@
"""
|
||||
K-means clustering and vector quantization (:mod:`scipy.cluster.vq`)
|
||||
====================================================================
|
||||
|
||||
Provides routines for k-means clustering, generating code books
|
||||
from k-means models and quantizing vectors by comparing them with
|
||||
centroids in a code book.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated/
|
||||
|
||||
whiten -- Normalize a group of observations so each feature has unit variance
|
||||
vq -- Calculate code book membership of a set of observation vectors
|
||||
kmeans -- Perform k-means on a set of observation vectors forming k clusters
|
||||
kmeans2 -- A different implementation of k-means with more methods
|
||||
-- for initializing centroids
|
||||
|
||||
Background information
|
||||
----------------------
|
||||
The k-means algorithm takes as input the number of clusters to
|
||||
generate, k, and a set of observation vectors to cluster. It
|
||||
returns a set of centroids, one for each of the k clusters. An
|
||||
observation vector is classified with the cluster number or
|
||||
centroid index of the centroid closest to it.
|
||||
|
||||
A vector v belongs to cluster i if it is closer to centroid i than
|
||||
any other centroid. If v belongs to i, we say centroid i is the
|
||||
dominating centroid of v. The k-means algorithm tries to
|
||||
minimize distortion, which is defined as the sum of the squared distances
|
||||
between each observation vector and its dominating centroid.
|
||||
The minimization is achieved by iteratively reclassifying
|
||||
the observations into clusters and recalculating the centroids until
|
||||
a configuration is reached in which the centroids are stable. One can
|
||||
also define a maximum number of iterations.
|
||||
|
||||
Since vector quantization is a natural application for k-means,
|
||||
information theory terminology is often used. The centroid index
|
||||
or cluster index is also referred to as a "code" and the table
|
||||
mapping codes to centroids and, vice versa, is often referred to as a
|
||||
"code book". The result of k-means, a set of centroids, can be
|
||||
used to quantize vectors. Quantization aims to find an encoding of
|
||||
vectors that reduces the expected distortion.
|
||||
|
||||
All routines expect obs to be an M by N array, where the rows are
|
||||
the observation vectors. The codebook is a k by N array, where the
|
||||
ith row is the centroid of code word i. The observation vectors
|
||||
and centroids have the same feature dimension.
|
||||
|
||||
As an example, suppose we wish to compress a 24-bit color image
|
||||
(each pixel is represented by one byte for red, one for blue, and
|
||||
one for green) before sending it over the web. By using a smaller
|
||||
8-bit encoding, we can reduce the amount of data by two
|
||||
thirds. Ideally, the colors for each of the 256 possible 8-bit
|
||||
encoding values should be chosen to minimize distortion of the
|
||||
color. Running k-means with k=256 generates a code book of 256
|
||||
codes, which fills up all possible 8-bit sequences. Instead of
|
||||
sending a 3-byte value for each pixel, the 8-bit centroid index
|
||||
(or code word) of the dominating centroid is transmitted. The code
|
||||
book is also sent over the wire so each 8-bit code can be
|
||||
translated back to a 24-bit pixel value representation. If the
|
||||
image of interest was of an ocean, we would expect many 24-bit
|
||||
blues to be represented by 8-bit codes. If it was an image of a
|
||||
human face, more flesh-tone colors would be represented in the
|
||||
code book.
|
||||
|
||||
"""
import warnings
import numpy as np
from collections import deque
from scipy._lib._array_api import (
    _asarray, array_namespace, size, atleast_nd, copy, cov
)
from scipy._lib._util import check_random_state, rng_integers
from scipy.spatial.distance import cdist

from . import _vq

__docformat__ = 'restructuredtext'

__all__ = ['whiten', 'vq', 'kmeans', 'kmeans2']


class ClusterError(Exception):
    pass


def whiten(obs, check_finite=True):
    """
    Normalize a group of observations on a per feature basis.

    Before running k-means, it is beneficial to rescale each feature
    dimension of the observation set by its standard deviation (i.e. "whiten"
    it - as in "white noise" where each frequency has equal power).
    Each feature is divided by its standard deviation across all observations
    to give it unit variance.

    Parameters
    ----------
    obs : ndarray
        Each row of the array is an observation. The
        columns are the features seen during each observation.

        >>> #         f0    f1    f2
        >>> obs = [[  1.,   1.,   1.],  #o0
        ...        [  2.,   2.,   2.],  #o1
        ...        [  3.,   3.,   3.],  #o2
        ...        [  4.,   4.,   4.]]  #o3

    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True

    Returns
    -------
    result : ndarray
        Contains the values in `obs` scaled by the standard deviation
        of each column.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.cluster.vq import whiten
    >>> features = np.array([[1.9, 2.3, 1.7],
    ...                      [1.5, 2.5, 2.2],
    ...                      [0.8, 0.6, 1.7,]])
    >>> whiten(features)
    array([[ 4.17944278,  2.69811351,  7.21248917],
           [ 3.29956009,  2.93273208,  9.33380951],
           [ 1.75976538,  0.7038557 ,  7.21248917]])

    """
    xp = array_namespace(obs)
    obs = _asarray(obs, check_finite=check_finite, xp=xp)
    std_dev = xp.std(obs, axis=0)
    zero_std_mask = std_dev == 0
    if xp.any(zero_std_mask):
        std_dev[zero_std_mask] = 1.0
        warnings.warn("Some columns have standard deviation zero. "
                      "The values of these columns will not change.",
                      RuntimeWarning, stacklevel=2)
    return obs / std_dev
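
# Hedged illustration of the zero-variance guard above (editor's sketch, not
# from the original source): a constant column's std is replaced by 1.0, so
# it passes through unchanged while a RuntimeWarning is emitted.
#
# >>> import warnings
# >>> obs = np.array([[1., 2.], [1., 4.]])   # first column is constant
# >>> with warnings.catch_warnings():
# ...     warnings.simplefilter("ignore", RuntimeWarning)
# ...     whiten(obs)
# array([[1., 2.],
#        [1., 4.]])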


def vq(obs, code_book, check_finite=True):
    """
    Assign codes from a code book to observations.

    Assigns a code from a code book to each observation. Each
    observation vector in the 'M' by 'N' `obs` array is compared with the
    centroids in the code book and assigned the code of the closest
    centroid.

    The features in `obs` should have unit variance, which can be
    achieved by passing them through the whiten function. The code
    book can be created with the k-means algorithm or a different
    encoding algorithm.

    Parameters
    ----------
    obs : ndarray
        Each row of the 'M' x 'N' array is an observation. The columns are
        the "features" seen during each observation. The features must be
        whitened first using the whiten function or something equivalent.
    code_book : ndarray
        The code book is usually generated using the k-means algorithm.
        Each row of the array holds a different code, and the columns are
        the features of the code.

        >>> #              f0    f1    f2   f3
        >>> code_book = [
        ...             [  1.,   2.,   3.,   4.],  #c0
        ...             [  1.,   2.,   3.,   4.],  #c1
        ...             [  1.,   2.,   3.,   4.]]  #c2

    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True

    Returns
    -------
    code : ndarray
        A length M array holding the code book index for each observation.
    dist : ndarray
        The distortion (distance) between the observation and its nearest
        code.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.cluster.vq import vq
    >>> code_book = np.array([[1., 1., 1.],
    ...                       [2., 2., 2.]])
    >>> features = np.array([[1.9, 2.3, 1.7],
    ...                      [1.5, 2.5, 2.2],
    ...                      [0.8, 0.6, 1.7]])
    >>> vq(features, code_book)
    (array([1, 1, 0], dtype=int32), array([0.43588989, 0.73484692, 0.83066239]))

    """
    xp = array_namespace(obs, code_book)
    obs = _asarray(obs, xp=xp, check_finite=check_finite)
    code_book = _asarray(code_book, xp=xp, check_finite=check_finite)
    ct = xp.result_type(obs, code_book)

    c_obs = xp.astype(obs, ct, copy=False)
    c_code_book = xp.astype(code_book, ct, copy=False)

    if xp.isdtype(ct, kind='real floating'):
        c_obs = np.asarray(c_obs)
        c_code_book = np.asarray(c_code_book)
        result = _vq.vq(c_obs, c_code_book)
        return xp.asarray(result[0]), xp.asarray(result[1])
    return py_vq(obs, code_book, check_finite=False)
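
# Note on the dispatch above (editor's comment): real floating result types
# take the compiled `_vq.vq` path; any other result type (e.g. two integer
# inputs) falls through to the pure-Python `py_vq`. A hedged check:
#
# >>> codes, dists = vq(np.array([[0], [3]]), np.array([[0], [4]]))
# >>> codes
# array([0, 1])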


def py_vq(obs, code_book, check_finite=True):
    """ Python version of vq algorithm.

    The algorithm computes the Euclidean distance between each
    observation and every frame in the code_book.

    Parameters
    ----------
    obs : ndarray
        Expects a rank 2 array. Each row is one observation.
    code_book : ndarray
        Code book to use. Same format as obs. Should have the same number
        of features (e.g., columns) as obs.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True

    Returns
    -------
    code : ndarray
        code[i] gives the label of the ith observation; its code is
        code_book[code[i]].
    min_dist : ndarray
        min_dist[i] gives the distance between the ith observation and its
        corresponding code.

    Notes
    -----
    This function is slower than the C version but works for
    all input types. If the inputs have the wrong types for the
    C versions of the function, this one is called as a last resort.

    It is about 20 times slower than the C version.

    """
    xp = array_namespace(obs, code_book)
    obs = _asarray(obs, xp=xp, check_finite=check_finite)
    code_book = _asarray(code_book, xp=xp, check_finite=check_finite)

    if obs.ndim != code_book.ndim:
        raise ValueError("Observation and code_book should have the same rank")

    if obs.ndim == 1:
        obs = obs[:, xp.newaxis]
        code_book = code_book[:, xp.newaxis]

    # Once `cdist` has array API support, this `xp.asarray` call can be removed
    dist = xp.asarray(cdist(obs, code_book))
    code = xp.argmin(dist, axis=1)
    min_dist = xp.min(dist, axis=1)
    return code, min_dist


def _kmeans(obs, guess, thresh=1e-5, xp=None):
    """ "raw" version of k-means.

    Returns
    -------
    code_book
        The lowest distortion codebook found.
    avg_dist
        The average distance an observation is from a code in the book.
        Lower means the code_book matches the data better.

    See Also
    --------
    kmeans : wrapper around k-means

    Examples
    --------
    Note: not whitened in this example.

    >>> import numpy as np
    >>> from scipy.cluster.vq import _kmeans
    >>> features = np.array([[ 1.9,2.3],
    ...                      [ 1.5,2.5],
    ...                      [ 0.8,0.6],
    ...                      [ 0.4,1.8],
    ...                      [ 1.0,1.0]])
    >>> book = np.array((features[0],features[2]))
    >>> _kmeans(features,book)
    (array([[ 1.7       ,  2.4       ],
           [ 0.73333333,  1.13333333]]), 0.40563916697728591)

    """
    xp = np if xp is None else xp
    code_book = guess
    diff = xp.inf
    prev_avg_dists = deque([diff], maxlen=2)
    while diff > thresh:
        # compute membership and distances between obs and code_book
        obs_code, distort = vq(obs, code_book, check_finite=False)
        prev_avg_dists.append(xp.mean(distort, axis=-1))
        # recalc code_book as centroids of associated obs
        obs = np.asarray(obs)
        obs_code = np.asarray(obs_code)
        code_book, has_members = _vq.update_cluster_means(obs, obs_code,
                                                          code_book.shape[0])
        obs = xp.asarray(obs)
        obs_code = xp.asarray(obs_code)
        code_book = xp.asarray(code_book)
        has_members = xp.asarray(has_members)
        code_book = code_book[has_members]
        diff = xp.abs(prev_avg_dists[0] - prev_avg_dists[1])

    return code_book, prev_avg_dists[1]
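
# Editor's note on the loop above: `prev_avg_dists` is a two-slot deque, so
# appending the current mean distortion evicts the value from two iterations
# ago, and `diff` is always |previous - current|. A hedged standalone
# illustration of that buffer behaviour:
#
# >>> from collections import deque
# >>> d = deque([float('inf')], maxlen=2)
# >>> d.append(0.9); d.append(0.7)   # inf is evicted once the deque is full
# >>> round(abs(d[0] - d[1]), 10)
# 0.2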


def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True,
           *, seed=None):
    """
    Performs k-means on a set of observation vectors forming k clusters.

    The k-means algorithm adjusts the classification of the observations
    into clusters and updates the cluster centroids until the position of
    the centroids is stable over successive iterations. In this
    implementation of the algorithm, the stability of the centroids is
    determined by comparing the absolute value of the change in the average
    Euclidean distance between the observations and their corresponding
    centroids against a threshold. This yields
    a code book mapping centroids to codes and vice versa.

    Parameters
    ----------
    obs : ndarray
        Each row of the M by N array is an observation vector. The
        columns are the features seen during each observation.
        The features must be whitened first with the `whiten` function.

    k_or_guess : int or ndarray
        The number of centroids to generate. A code is assigned to
        each centroid, which is also the row index of the centroid
        in the code_book matrix generated.

        The initial k centroids are chosen by randomly selecting
        observations from the observation matrix. Alternatively,
        passing a k by N array specifies the initial k centroids.

    iter : int, optional
        The number of times to run k-means, returning the codebook
        with the lowest distortion. This argument is ignored if
        initial centroids are specified with an array for the
        ``k_or_guess`` parameter. This parameter does not represent the
        number of iterations of the k-means algorithm.

    thresh : float, optional
        Terminates the k-means algorithm if the change in
        distortion since the last k-means iteration is less than
        or equal to threshold.

    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True

    seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
        Seed for initializing the pseudo-random number generator.
        If `seed` is None (or `numpy.random`), the `numpy.random.RandomState`
        singleton is used.
        If `seed` is an int, a new ``RandomState`` instance is used,
        seeded with `seed`.
        If `seed` is already a ``Generator`` or ``RandomState`` instance then
        that instance is used.
        The default is None.

    Returns
    -------
    codebook : ndarray
        A k by N array of k centroids. The ith centroid
        codebook[i] is represented with the code i. The centroids
        and codes generated represent the lowest distortion seen,
        not necessarily the globally minimal distortion.
        Note that the number of centroids is not necessarily the same as the
        ``k_or_guess`` parameter, because centroids assigned to no observations
        are removed during iterations.

    distortion : float
        The mean (non-squared) Euclidean distance between the observations
        passed and the centroids generated. Note the difference to the standard
        definition of distortion in the context of the k-means algorithm, which
        is the sum of the squared distances.

    See Also
    --------
    kmeans2 : a different implementation of k-means clustering
        with more methods for generating initial centroids but without
        using a distortion change threshold as a stopping criterion.

    whiten : must be called prior to passing an observation matrix
        to kmeans.

    Notes
    -----
    For more functionality or optimal performance, you can use
    `sklearn.cluster.KMeans <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html>`_.
    `This <https://hdbscan.readthedocs.io/en/latest/performance_and_scalability.html#comparison-of-high-performance-implementations>`_
    is a benchmark result of several implementations.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.cluster.vq import vq, kmeans, whiten
    >>> import matplotlib.pyplot as plt
    >>> features = np.array([[ 1.9,2.3],
    ...                      [ 1.5,2.5],
    ...                      [ 0.8,0.6],
    ...                      [ 0.4,1.8],
    ...                      [ 0.1,0.1],
    ...                      [ 0.2,1.8],
    ...                      [ 2.0,0.5],
    ...                      [ 0.3,1.5],
    ...                      [ 1.0,1.0]])
    >>> whitened = whiten(features)
    >>> book = np.array((whitened[0],whitened[2]))
    >>> kmeans(whitened,book)
    (array([[ 2.3110306 ,  2.86287398],    # random
           [ 0.93218041,  1.24398691]]), 0.85684700941625547)

    >>> codes = 3
    >>> kmeans(whitened,codes)
    (array([[ 2.3110306 ,  2.86287398],    # random
           [ 1.32544402,  0.65607529],
           [ 0.40782893,  2.02786907]]), 0.5196582527686241)

    >>> # Create 50 datapoints in two clusters a and b
    >>> pts = 50
    >>> rng = np.random.default_rng()
    >>> a = rng.multivariate_normal([0, 0], [[4, 1], [1, 4]], size=pts)
    >>> b = rng.multivariate_normal([30, 10],
    ...                             [[10, 2], [2, 1]],
    ...                             size=pts)
    >>> features = np.concatenate((a, b))
    >>> # Whiten data
    >>> whitened = whiten(features)
    >>> # Find 2 clusters in the data
    >>> codebook, distortion = kmeans(whitened, 2)
    >>> # Plot whitened data and cluster centers in red
    >>> plt.scatter(whitened[:, 0], whitened[:, 1])
    >>> plt.scatter(codebook[:, 0], codebook[:, 1], c='r')
    >>> plt.show()

    """
    if isinstance(k_or_guess, int):
        xp = array_namespace(obs)
    else:
        xp = array_namespace(obs, k_or_guess)
    obs = _asarray(obs, xp=xp, check_finite=check_finite)
    guess = _asarray(k_or_guess, xp=xp, check_finite=check_finite)
    if iter < 1:
        raise ValueError("iter must be at least 1, got %s" % iter)

    # Determine whether a count (scalar) or an initial guess (array) was passed.
    if size(guess) != 1:
        if size(guess) < 1:
            raise ValueError("Asked for 0 clusters. Initial book was %s" %
                             guess)
        return _kmeans(obs, guess, thresh=thresh, xp=xp)

    # k_or_guess is a scalar, now verify that it's an integer
    k = int(guess)
    if k != guess:
        raise ValueError("If k_or_guess is a scalar, it must be an integer.")
    if k < 1:
        raise ValueError("Asked for %d clusters." % k)

    rng = check_random_state(seed)

    # initialize best distance value to a large value
    best_dist = xp.inf
    for i in range(iter):
        # the initial code book is randomly selected from observations
        guess = _kpoints(obs, k, rng, xp)
        book, dist = _kmeans(obs, guess, thresh=thresh, xp=xp)
        if dist < best_dist:
            best_book = book
            best_dist = dist
    return best_book, best_dist


def _kpoints(data, k, rng, xp):
    """Pick k points at random in data (one row = one observation).

    Parameters
    ----------
    data : ndarray
        Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
        data, rank 2 multidimensional data, in which case one
        row is one observation.
    k : int
        Number of samples to generate.
    rng : `numpy.random.Generator` or `numpy.random.RandomState`
        Random number generator.

    Returns
    -------
    x : ndarray
        A 'k' by 'N' array containing the initial centroids

    """
    idx = rng.choice(data.shape[0], size=int(k), replace=False)
    # convert to array with default integer dtype (avoids numpy#25607)
    idx = xp.asarray(idx, dtype=xp.asarray([1]).dtype)
    return xp.take(data, idx, axis=0)
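
# Hedged illustration (editor's sketch, numpy backend): _kpoints is just
# row sampling without replacement.
#
# >>> rng = np.random.default_rng(0)
# >>> data = np.arange(12.).reshape(6, 2)
# >>> _kpoints(data, 3, rng, np).shape
# (3, 2)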


def _krandinit(data, k, rng, xp):
    """Returns k samples of a random variable whose parameters depend on data.

    More precisely, it returns k observations sampled from a Gaussian random
    variable whose mean and covariances are the ones estimated from the data.

    Parameters
    ----------
    data : ndarray
        Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
        data, rank 2 multidimensional data, in which case one
        row is one observation.
    k : int
        Number of samples to generate.
    rng : `numpy.random.Generator` or `numpy.random.RandomState`
        Random number generator.

    Returns
    -------
    x : ndarray
        A 'k' by 'N' array containing the initial centroids

    """
    mu = xp.mean(data, axis=0)
    k = np.asarray(k)

    if data.ndim == 1:
        _cov = cov(data)
        x = rng.standard_normal(size=k)
        x = xp.asarray(x)
        x *= xp.sqrt(_cov)
    elif data.shape[1] > data.shape[0]:
        # initialize when the covariance matrix is rank deficient
        _, s, vh = xp.linalg.svd(data - mu, full_matrices=False)
        x = rng.standard_normal(size=(k, size(s)))
        x = xp.asarray(x)
        sVh = s[:, None] * vh / xp.sqrt(data.shape[0] - xp.asarray(1.))
        x = x @ sVh
    else:
        _cov = atleast_nd(cov(data.T), ndim=2)

        # k rows, d cols (one row = one obs)
        # Generate k samples of a random variable ~ Gaussian(mu, cov)
        x = rng.standard_normal(size=(k, size(mu)))
        x = xp.asarray(x)
        x = x @ xp.linalg.cholesky(_cov).T

    x += mu
    return x
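
# Editor's sanity-check sketch for the sampling above (assumes the numpy
# backend, i.e. xp=np, and mirrors test_krandinit): with many samples the
# drawn centroids should roughly reproduce the data covariance.
#
# >>> rng = np.random.default_rng(1234)
# >>> data = rng.standard_normal((200, 2))
# >>> init = _krandinit(data, 10_000, rng, np)
# >>> np.allclose(np.cov(data.T), np.cov(init.T), atol=0.1)
# True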


def _kpp(data, k, rng, xp):
    """ Picks k points in the data based on the kmeans++ method.

    Parameters
    ----------
    data : ndarray
        Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
        data, rank 2 multidimensional data, in which case one
        row is one observation.
    k : int
        Number of samples to generate.
    rng : `numpy.random.Generator` or `numpy.random.RandomState`
        Random number generator.

    Returns
    -------
    init : ndarray
        A 'k' by 'N' array containing the initial centroids.

    References
    ----------
    .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
       careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
       on Discrete Algorithms, 2007.
    """

    ndim = len(data.shape)
    if ndim == 1:
        data = data[:, None]

    dims = data.shape[1]

    init = xp.empty((int(k), dims))

    for i in range(k):
        if i == 0:
            init[i, :] = data[rng_integers(rng, data.shape[0]), :]

        else:
            D2 = cdist(init[:i, :], data, metric='sqeuclidean').min(axis=0)
            probs = D2 / D2.sum()
            cumprobs = probs.cumsum()
            r = rng.uniform()
            cumprobs = np.asarray(cumprobs)
            init[i, :] = data[np.searchsorted(cumprobs, r), :]

    if ndim == 1:
        init = init[:, 0]
    return init
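
# The D**2-weighted draw above is inverse-CDF sampling; a hedged toy
# illustration (editor's numbers, not from the source):
#
# >>> D2 = np.array([0., 1., 3.])           # squared distances to nearest centroid
# >>> cumprobs = (D2 / D2.sum()).cumsum()   # array([0.  , 0.25, 1.  ])
# >>> int(np.searchsorted(cumprobs, 0.1))   # r = 0.1 lands in the second bin
# 1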


_valid_init_meth = {'random': _krandinit, 'points': _kpoints, '++': _kpp}


def _missing_warn():
    """Print a warning when called."""
    warnings.warn("One of the clusters is empty. "
                  "Re-run kmeans with a different initialization.",
                  stacklevel=3)


def _missing_raise():
    """Raise a ClusterError when called."""
    raise ClusterError("One of the clusters is empty. "
                       "Re-run kmeans with a different initialization.")


_valid_miss_meth = {'warn': _missing_warn, 'raise': _missing_raise}


def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
            missing='warn', check_finite=True, *, seed=None):
    """
    Classify a set of observations into k clusters using the k-means algorithm.

    The algorithm attempts to minimize the Euclidean distance between
    observations and centroids. Several initialization methods are
    included.

    Parameters
    ----------
    data : ndarray
        A 'M' by 'N' array of 'M' observations in 'N' dimensions or a length
        'M' array of 'M' 1-D observations.
    k : int or ndarray
        The number of clusters to form as well as the number of
        centroids to generate. If `minit` initialization string is
        'matrix', or if an ndarray is given instead, it is
        interpreted as the initial clusters to use.
    iter : int, optional
        Number of iterations of the k-means algorithm to run. Note
        that this differs in meaning from the iters parameter to
        the kmeans function.
    thresh : float, optional
        (not used yet)
    minit : str, optional
        Method for initialization. Available methods are 'random',
        'points', '++' and 'matrix':

        'random': generate k centroids from a Gaussian with mean and
        variance estimated from the data.

        'points': choose k observations (rows) at random from data for
        the initial centroids.

        '++': choose k observations according to the kmeans++ method
        (careful seeding)

        'matrix': interpret the k parameter as a k by M (or length k
        array for 1-D data) array of initial centroids.
    missing : str, optional
        Method to deal with empty clusters. Available methods are
        'warn' and 'raise':

        'warn': give a warning and continue.

        'raise': raise a ClusterError and terminate the algorithm.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
        Default: True
    seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
        Seed for initializing the pseudo-random number generator.
        If `seed` is None (or `numpy.random`), the `numpy.random.RandomState`
        singleton is used.
        If `seed` is an int, a new ``RandomState`` instance is used,
        seeded with `seed`.
        If `seed` is already a ``Generator`` or ``RandomState`` instance then
        that instance is used.
        The default is None.

    Returns
    -------
    centroid : ndarray
        A 'k' by 'N' array of centroids found at the last iteration of
        k-means.
    label : ndarray
        label[i] is the code or index of the centroid the
        ith observation is closest to.

    See Also
    --------
    kmeans

    References
    ----------
    .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
       careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
       on Discrete Algorithms, 2007.

    Examples
    --------
    >>> from scipy.cluster.vq import kmeans2
    >>> import matplotlib.pyplot as plt
    >>> import numpy as np

    Create z, an array with shape (100, 2) containing a mixture of samples
    from three multivariate normal distributions.

    >>> rng = np.random.default_rng()
    >>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45)
    >>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30)
    >>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25)
    >>> z = np.concatenate((a, b, c))
    >>> rng.shuffle(z)

    Compute three clusters.

    >>> centroid, label = kmeans2(z, 3, minit='points')
    >>> centroid
    array([[ 2.22274463, -0.61666946],  # may vary
           [ 0.54069047,  5.86541444],
           [ 6.73846769,  4.01991898]])

    How many points are in each cluster?

    >>> counts = np.bincount(label)
    >>> counts
    array([29, 51, 20])  # may vary

    Plot the clusters.

    >>> w0 = z[label == 0]
    >>> w1 = z[label == 1]
    >>> w2 = z[label == 2]
    >>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0')
    >>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1')
    >>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2')
    >>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids')
    >>> plt.axis('equal')
    >>> plt.legend(shadow=True)
    >>> plt.show()

    """
    if int(iter) < 1:
        raise ValueError("Invalid iter (%s), "
                         "must be a positive integer." % iter)
    try:
        miss_meth = _valid_miss_meth[missing]
    except KeyError as e:
        raise ValueError(f"Unknown missing method {missing!r}") from e

    if isinstance(k, int):
        xp = array_namespace(data)
    else:
        xp = array_namespace(data, k)
    data = _asarray(data, xp=xp, check_finite=check_finite)
    code_book = copy(k, xp=xp)
    if data.ndim == 1:
        d = 1
    elif data.ndim == 2:
        d = data.shape[1]
    else:
        raise ValueError("Input of rank > 2 is not supported.")

    if size(data) < 1 or size(code_book) < 1:
        raise ValueError("Empty input is not supported.")

    # If k is not a single value, it should be compatible with data's shape
    if minit == 'matrix' or size(code_book) > 1:
        if data.ndim != code_book.ndim:
            raise ValueError("k array doesn't match data rank")
        nc = code_book.shape[0]
        if data.ndim > 1 and code_book.shape[1] != d:
            raise ValueError("k array doesn't match data dimension")
    else:
        nc = int(code_book)

        if nc < 1:
            raise ValueError("Cannot ask kmeans2 for %d clusters"
                             " (k was %s)" % (nc, code_book))
        elif nc != code_book:
            warnings.warn("k was not an integer, was converted.", stacklevel=2)

        try:
            init_meth = _valid_init_meth[minit]
        except KeyError as e:
            raise ValueError(f"Unknown init method {minit!r}") from e
        else:
            rng = check_random_state(seed)
            code_book = init_meth(data, code_book, rng, xp)

    data = np.asarray(data)
    code_book = np.asarray(code_book)
    for i in range(iter):
        # Compute the nearest neighbor for each obs using the current code book
        label = vq(data, code_book, check_finite=check_finite)[0]
        # Update the code book by computing centroids
        new_code_book, has_members = _vq.update_cluster_means(data, label, nc)
        if not has_members.all():
            miss_meth()
            # Set the empty clusters to their previous positions
            new_code_book[~has_members] = code_book[~has_members]
        code_book = new_code_book

    return xp.asarray(code_book), xp.asarray(label)

413
venv/lib/python3.12/site-packages/scipy/conftest.py
Normal file
@ -0,0 +1,413 @@
# Pytest customization
import json
import os
import warnings
import tempfile
from contextlib import contextmanager

import numpy as np
import numpy.testing as npt
import pytest
import hypothesis

from scipy._lib._fpumode import get_fpu_mode
from scipy._lib._testutils import FPUModeChangeWarning
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE
from scipy._lib import _pep440

try:
    from scipy_doctest.conftest import dt_config
    HAVE_SCPDT = True
except ModuleNotFoundError:
    HAVE_SCPDT = False


def pytest_configure(config):
    config.addinivalue_line("markers",
                            "slow: Tests that are very slow.")
    config.addinivalue_line("markers",
                            "xslow: mark test as extremely slow (not run unless explicitly requested)")
    config.addinivalue_line("markers",
                            "xfail_on_32bit: mark test as failing on 32-bit platforms")
    try:
        import pytest_timeout  # noqa:F401
    except Exception:
        config.addinivalue_line(
            "markers", 'timeout: mark a test for a non-default timeout')
    try:
        # This is a more reliable test of whether pytest_fail_slow is installed
        # When I uninstalled it, `import pytest_fail_slow` didn't fail!
        from pytest_fail_slow import parse_duration  # type: ignore[import-not-found] # noqa:F401,E501
    except Exception:
        config.addinivalue_line(
            "markers", 'fail_slow: mark a test for a non-default timeout failure')
    config.addinivalue_line("markers",
                            "skip_xp_backends(*backends, reasons=None, np_only=False, cpu_only=False): "
                            "mark the desired skip configuration for the `skip_xp_backends` fixture.")


def pytest_runtest_setup(item):
    mark = item.get_closest_marker("xslow")
    if mark is not None:
        try:
            v = int(os.environ.get('SCIPY_XSLOW', '0'))
        except ValueError:
            v = False
        if not v:
            pytest.skip("very slow test; "
                        "set environment variable SCIPY_XSLOW=1 to run it")
    mark = item.get_closest_marker("xfail_on_32bit")
    if mark is not None and np.intp(0).itemsize < 8:
        pytest.xfail(f'Fails on our 32-bit test platform(s): {mark.args[0]}')

    # Older versions of threadpoolctl have an issue that may lead to this
    # warning being emitted, see gh-14441
    with npt.suppress_warnings() as sup:
        sup.filter(pytest.PytestUnraisableExceptionWarning)

        try:
            from threadpoolctl import threadpool_limits

            HAS_THREADPOOLCTL = True
        except Exception:  # observed in gh-14441: (ImportError, AttributeError)
            # Optional dependency only. All exceptions are caught, for robustness
            HAS_THREADPOOLCTL = False

        if HAS_THREADPOOLCTL:
            # Set the number of openmp threads based on the number of workers
            # xdist is using to prevent oversubscription. Simplified version of
            # what sklearn does (it can rely on threadpoolctl and its builtin
            # OpenMP helper functions)
            try:
                xdist_worker_count = int(os.environ['PYTEST_XDIST_WORKER_COUNT'])
            except KeyError:
                # raises when pytest-xdist is not installed
                return

            if not os.getenv('OMP_NUM_THREADS'):
                max_openmp_threads = os.cpu_count() // 2  # use nr of physical cores
                threads_per_worker = max(max_openmp_threads // xdist_worker_count, 1)
                try:
                    threadpool_limits(threads_per_worker, user_api='blas')
                except Exception:
                    # May raise AttributeError for older versions of OpenBLAS.
                    # Catch any error for robustness.
                    return


@pytest.fixture(scope="function", autouse=True)
def check_fpu_mode(request):
    """
    Check FPU mode was not changed during the test.
    """
    old_mode = get_fpu_mode()
    yield
    new_mode = get_fpu_mode()

    if old_mode != new_mode:
        warnings.warn(f"FPU mode changed from {old_mode:#x} to {new_mode:#x} during "
                      "the test",
                      category=FPUModeChangeWarning, stacklevel=0)

# Array API backend handling
xp_available_backends = {'numpy': np}

if SCIPY_ARRAY_API and isinstance(SCIPY_ARRAY_API, str):
    # fill the dict of backends with available libraries
    try:
        import array_api_strict
        xp_available_backends.update({'array_api_strict': array_api_strict})
        if _pep440.parse(array_api_strict.__version__) < _pep440.Version('2.0'):
            raise ImportError("array-api-strict must be >= version 2.0")
        array_api_strict.set_array_api_strict_flags(
            api_version='2023.12'
        )
    except ImportError:
        pass

    try:
        import torch  # type: ignore[import-not-found]
        xp_available_backends.update({'pytorch': torch})
        # can use `mps` or `cpu`
        torch.set_default_device(SCIPY_DEVICE)
    except ImportError:
        pass

    try:
        import cupy  # type: ignore[import-not-found]
        xp_available_backends.update({'cupy': cupy})
    except ImportError:
        pass

    try:
        import jax.numpy  # type: ignore[import-not-found]
        xp_available_backends.update({'jax.numpy': jax.numpy})
        jax.config.update("jax_enable_x64", True)
        jax.config.update("jax_default_device", jax.devices(SCIPY_DEVICE)[0])
    except ImportError:
        pass

    # by default, use all available backends
    if SCIPY_ARRAY_API.lower() not in ("1", "true"):
        SCIPY_ARRAY_API_ = json.loads(SCIPY_ARRAY_API)

        if 'all' in SCIPY_ARRAY_API_:
            pass  # same as True
        else:
            # only select a subset of backends by filtering the dict
            try:
                xp_available_backends = {
                    backend: xp_available_backends[backend]
                    for backend in SCIPY_ARRAY_API_
                }
            except KeyError:
                msg = f"'--array-api-backend' must be in {xp_available_backends.keys()}"
                raise ValueError(msg)

    if 'cupy' in xp_available_backends:
        SCIPY_DEVICE = 'cuda'

array_api_compatible = pytest.mark.parametrize("xp", xp_available_backends.values())
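
# Editor's note: given the parsing above, a subset of backends can be
# requested through the environment, e.g. (hedged example invocation)
#   SCIPY_ARRAY_API='["numpy", "pytorch"]' python -m pytest scipy/cluster
# while SCIPY_ARRAY_API=1 enables every available backend.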

skip_xp_invalid_arg = pytest.mark.skipif(
    SCIPY_ARRAY_API,
    reason=('Test involves masked arrays, object arrays, or other types '
            'that are not valid input when `SCIPY_ARRAY_API` is used.'))


@pytest.fixture
def skip_xp_backends(xp, request):
    """
    Skip based on the ``skip_xp_backends`` marker.

    Parameters
    ----------
    *backends : tuple
        Backends to skip, e.g. ``("array_api_strict", "torch")``.
        These are overridden when ``np_only`` is ``True``, and are not
        necessary to provide for non-CPU backends when ``cpu_only`` is ``True``.
    reasons : list, optional
        A list of reasons for each skip. When ``np_only`` is ``True``,
        this should be a singleton list. Otherwise, this should be a list
        of reasons, one for each corresponding backend in ``backends``.
        If not provided, default reasons are used. Note that it is not
        possible to specify a custom reason with ``cpu_only``.
        Default: ``None``.
    np_only : bool, optional
        When ``True``, the test is skipped for all backends other
        than the default NumPy backend. There is no need to provide
        any ``backends`` in this case. To specify a reason, pass a
        singleton list to ``reasons``. Default: ``False``.
    cpu_only : bool, optional
        When ``True``, the test is skipped on non-CPU devices.
        There is no need to provide any ``backends`` in this case,
        but any ``backends`` will also be skipped on the CPU.
        Default: ``False``.
    """
    if "skip_xp_backends" not in request.keywords:
        return
    backends = request.keywords["skip_xp_backends"].args
    kwargs = request.keywords["skip_xp_backends"].kwargs
    np_only = kwargs.get("np_only", False)
    cpu_only = kwargs.get("cpu_only", False)
    if np_only:
        reasons = kwargs.get("reasons", ["do not run with non-NumPy backends."])
        reason = reasons[0]
        if xp.__name__ != 'numpy':
            pytest.skip(reason=reason)
        return
    if cpu_only:
        reason = "do not run with `SCIPY_ARRAY_API` set and not on CPU"
        if SCIPY_ARRAY_API and SCIPY_DEVICE != 'cpu':
            if xp.__name__ == 'cupy':
                pytest.skip(reason=reason)
            elif xp.__name__ == 'torch':
                if 'cpu' not in xp.empty(0).device.type:
                    pytest.skip(reason=reason)
            elif xp.__name__ == 'jax.numpy':
                for d in xp.empty(0).devices():
                    if 'cpu' not in d.device_kind:
                        pytest.skip(reason=reason)

    if backends is not None:
        reasons = kwargs.get("reasons", False)
        for i, backend in enumerate(backends):
            if xp.__name__ == backend:
                if not reasons:
                    reason = f"do not run with array API backend: {backend}"
                else:
                    reason = reasons[i]
                pytest.skip(reason=reason)
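
# Editor's sketch of how the marker drives this fixture (mirrors the usage
# in scipy/cluster's tests; the test body below is hypothetical):
#
# @pytest.mark.skip_xp_backends('jax.numpy',
#                               reasons=['jax arrays do not support '
#                                        'item assignment'],
#                               cpu_only=True)
# def test_inplace_update(self, xp):
#     ...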


# Following the approach of NumPy's conftest.py...
# Use a known and persistent tmpdir for hypothesis' caches, which
# can be automatically cleared by the OS or user.
hypothesis.configuration.set_hypothesis_home_dir(
    os.path.join(tempfile.gettempdir(), ".hypothesis")
)

# We register two custom profiles for SciPy - for details see
# https://hypothesis.readthedocs.io/en/latest/settings.html
# The first is designed for our own CI runs; the latter also
# forces determinism and is designed for use via scipy.test()
hypothesis.settings.register_profile(
    name="nondeterministic", deadline=None, print_blob=True,
)
hypothesis.settings.register_profile(
    name="deterministic",
    deadline=None, print_blob=True, database=None, derandomize=True,
    suppress_health_check=list(hypothesis.HealthCheck),
)

# Profile is currently set by environment variable `SCIPY_HYPOTHESIS_PROFILE`
# In the future, it would be good to work the choice into dev.py.
SCIPY_HYPOTHESIS_PROFILE = os.environ.get("SCIPY_HYPOTHESIS_PROFILE",
                                          "deterministic")
hypothesis.settings.load_profile(SCIPY_HYPOTHESIS_PROFILE)


############################################################################
# doctesting stuff

if HAVE_SCPDT:

    # FIXME: populate the dict once
    @contextmanager
    def warnings_errors_and_rng(test=None):
        """Temporarily turn (almost) all warnings to errors.

        Filter out known warnings which we allow.
        """
        known_warnings = dict()

        # these functions are known to emit "divide by zero" RuntimeWarnings
        divide_by_zero = [
            'scipy.linalg.norm', 'scipy.ndimage.center_of_mass',
        ]
        for name in divide_by_zero:
            known_warnings[name] = dict(category=RuntimeWarning,
                                        message='divide by zero')

        # Deprecated stuff in scipy.signal and elsewhere
        deprecated = [
            'scipy.signal.cwt', 'scipy.signal.morlet', 'scipy.signal.morlet2',
            'scipy.signal.ricker',
            'scipy.integrate.simpson',
            'scipy.interpolate.interp2d',
        ]
        for name in deprecated:
            known_warnings[name] = dict(category=DeprecationWarning)

        from scipy import integrate
        # these functions are known to emit IntegrationWarnings
        integration_w = ['scipy.special.ellip_normal',
                         'scipy.special.ellip_harm_2',
                         ]
        for name in integration_w:
            known_warnings[name] = dict(category=integrate.IntegrationWarning,
                                        message='The occurrence of roundoff')

        # scipy.stats deliberately emits UserWarnings sometimes
        user_w = ['scipy.stats.anderson_ksamp', 'scipy.stats.kurtosistest',
                  'scipy.stats.normaltest', 'scipy.sparse.linalg.norm']
        for name in user_w:
            known_warnings[name] = dict(category=UserWarning)

        # additional one-off warnings to filter
        dct = {
            'scipy.sparse.linalg.norm':
                dict(category=UserWarning, message="Exited at iteration"),
            # tutorials
            'linalg.rst':
                dict(message='the matrix subclass is not',
                     category=PendingDeprecationWarning),
            'stats.rst':
                dict(message='The maximum number of subdivisions',
                     category=integrate.IntegrationWarning),
        }
        known_warnings.update(dct)

        # these legitimately emit warnings in examples
        legit = set('scipy.signal.normalize')

        # Now, the meat of the matter: filter warnings,
        # also control the random seed for each doctest.

        # XXX: this matches the refguide-check behavior, but is a tad strange:
        # makes sure that the seed for the old-fashioned np.random* methods is
        # *NOT* reproducible but the new-style `default_rng()` *IS*
        # reproducible. Should these two be either both repro or both not repro?

        from scipy._lib._util import _fixed_default_rng
        import numpy as np
        with _fixed_default_rng():
            np.random.seed(None)
            with warnings.catch_warnings():
                if test and test.name in known_warnings:
                    warnings.filterwarnings('ignore',
                                            **known_warnings[test.name])
                    yield
                elif test and test.name in legit:
                    yield
                else:
                    warnings.simplefilter('error', Warning)
                    yield

    dt_config.user_context_mgr = warnings_errors_and_rng
    dt_config.skiplist = set([
        'scipy.linalg.LinAlgError',     # comes from numpy
        'scipy.fftpack.fftshift',       # fftpack stuff is also from numpy
        'scipy.fftpack.ifftshift',
        'scipy.fftpack.fftfreq',
        'scipy.special.sinc',           # sinc is from numpy
        'scipy.optimize.show_options',  # does not have much to doctest
        'scipy.signal.normalize',       # manipulates warnings (XXX temp skip)
        'scipy.sparse.linalg.norm',     # XXX temp skip
    ])

    # these are affected by NumPy 2.0 scalar repr: rely on string comparison
    if np.__version__ < "2":
        dt_config.skiplist.update(set([
            'scipy.io.hb_read',
            'scipy.io.hb_write',
            'scipy.sparse.csgraph.connected_components',
            'scipy.sparse.csgraph.depth_first_order',
            'scipy.sparse.csgraph.shortest_path',
            'scipy.sparse.csgraph.floyd_warshall',
            'scipy.sparse.csgraph.dijkstra',
            'scipy.sparse.csgraph.bellman_ford',
            'scipy.sparse.csgraph.johnson',
            'scipy.sparse.csgraph.yen',
            'scipy.sparse.csgraph.breadth_first_order',
            'scipy.sparse.csgraph.reverse_cuthill_mckee',
            'scipy.sparse.csgraph.structural_rank',
            'scipy.sparse.csgraph.construct_dist_matrix',
            'scipy.sparse.csgraph.reconstruct_path',
            'scipy.ndimage.value_indices',
            'scipy.stats.mstats.describe',
        ]))

    # help pytest collection a bit: these names are either private
    # (distributions), or just do not need doctesting.
    dt_config.pytest_extra_ignore = [
        "scipy.stats.distributions",
        "scipy.optimize.cython_optimize",
        "scipy.test",
        "scipy.show_config",
    ]

    dt_config.pytest_extra_xfail = {
        # name: reason
        "io.rst": "",
        "ND_regular_grid.rst": "ReST parser limitation",
        "extrapolation_examples.rst": "ReST parser limitation",
        "sampling_pinv.rst": "__cinit__ unexpected argument",
        "sampling_srou.rst": "nan in scalar_power",
        "probability_distributions.rst": "integration warning",
    }

    # tutorials
    dt_config.pseudocode = set(['integrate.nquad(func,'])
    dt_config.local_resources = {'io.rst': ["octave_a.mat"]}
############################################################################

347
venv/lib/python3.12/site-packages/scipy/constants/__init__.py
Normal file
@ -0,0 +1,347 @@
r"""
|
||||
==================================
|
||||
Constants (:mod:`scipy.constants`)
|
||||
==================================
|
||||
|
||||
.. currentmodule:: scipy.constants
|
||||
|
||||
Physical and mathematical constants and units.
|
||||
|
||||
|
||||
Mathematical constants
|
||||
======================
|
||||
|
||||
================ =================================================================
|
||||
``pi`` Pi
|
||||
``golden`` Golden ratio
|
||||
``golden_ratio`` Golden ratio
|
||||
================ =================================================================
|
||||
|
||||
|
||||
Physical constants
|
||||
==================
|
||||
|
||||
=========================== =================================================================
|
||||
``c`` speed of light in vacuum
|
||||
``speed_of_light`` speed of light in vacuum
|
||||
``mu_0`` the magnetic constant :math:`\mu_0`
|
||||
``epsilon_0`` the electric constant (vacuum permittivity), :math:`\epsilon_0`
|
||||
``h`` the Planck constant :math:`h`
|
||||
``Planck`` the Planck constant :math:`h`
|
||||
``hbar`` :math:`\hbar = h/(2\pi)`
|
||||
``G`` Newtonian constant of gravitation
|
||||
``gravitational_constant`` Newtonian constant of gravitation
|
||||
``g`` standard acceleration of gravity
|
||||
``e`` elementary charge
|
||||
``elementary_charge`` elementary charge
|
||||
``R`` molar gas constant
|
||||
``gas_constant`` molar gas constant
|
||||
``alpha`` fine-structure constant
|
||||
``fine_structure`` fine-structure constant
|
||||
``N_A`` Avogadro constant
|
||||
``Avogadro`` Avogadro constant
|
||||
``k`` Boltzmann constant
|
||||
``Boltzmann`` Boltzmann constant
|
||||
``sigma`` Stefan-Boltzmann constant :math:`\sigma`
|
||||
``Stefan_Boltzmann`` Stefan-Boltzmann constant :math:`\sigma`
|
||||
``Wien`` Wien displacement law constant
|
||||
``Rydberg`` Rydberg constant
|
||||
``m_e`` electron mass
|
||||
``electron_mass`` electron mass
|
||||
``m_p`` proton mass
|
||||
``proton_mass`` proton mass
|
||||
``m_n`` neutron mass
|
||||
``neutron_mass`` neutron mass
|
||||
=========================== =================================================================
|
||||
|
||||
|
||||
Constants database
|
||||
------------------
|
||||
|
||||
In addition to the above variables, :mod:`scipy.constants` also contains the
|
||||
2018 CODATA recommended values [CODATA2018]_ database containing more physical
|
||||
constants.
|
||||
|
||||
.. autosummary::
|
||||
:toctree: generated/
|
||||
|
||||
value -- Value in physical_constants indexed by key
|
||||
unit -- Unit in physical_constants indexed by key
|
||||
precision -- Relative precision in physical_constants indexed by key
|
||||
find -- Return list of physical_constant keys with a given string
|
||||
ConstantWarning -- Constant sought not in newest CODATA data set
|
||||
|
||||
.. data:: physical_constants
|
||||
|
||||
Dictionary of physical constants, of the format
|
||||
``physical_constants[name] = (value, unit, uncertainty)``.
|
||||
|
||||
Available constants:
|
||||
|
||||
====================================================================== ====
|
||||
%(constant_names)s
|
||||
====================================================================== ====
|
||||
|
||||
|
||||
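A minimal usage sketch of the database accessors (the numbers shown assume
the bundled CODATA 2018 values):

>>> from scipy import constants
>>> constants.value('electron mass')
9.1093837015e-31
>>> constants.unit('electron mass')
'kg'
>>> constants.physical_constants['electron mass']
(9.1093837015e-31, 'kg', 2.8e-40)
>>> keys = constants.find('Boltzmann')  # all keys containing 'Boltzmann'
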
Units
=====

SI prefixes
-----------

============ =================================================================
``quetta``   :math:`10^{30}`
``ronna``    :math:`10^{27}`
``yotta``    :math:`10^{24}`
``zetta``    :math:`10^{21}`
``exa``      :math:`10^{18}`
``peta``     :math:`10^{15}`
``tera``     :math:`10^{12}`
``giga``     :math:`10^{9}`
``mega``     :math:`10^{6}`
``kilo``     :math:`10^{3}`
``hecto``    :math:`10^{2}`
``deka``     :math:`10^{1}`
``deci``     :math:`10^{-1}`
``centi``    :math:`10^{-2}`
``milli``    :math:`10^{-3}`
``micro``    :math:`10^{-6}`
``nano``     :math:`10^{-9}`
``pico``     :math:`10^{-12}`
``femto``    :math:`10^{-15}`
``atto``     :math:`10^{-18}`
``zepto``    :math:`10^{-21}`
``yocto``    :math:`10^{-24}`
``ronto``    :math:`10^{-27}`
``quecto``   :math:`10^{-30}`
============ =================================================================

Binary prefixes
---------------

============ =================================================================
``kibi``     :math:`2^{10}`
``mebi``     :math:`2^{20}`
``gibi``     :math:`2^{30}`
``tebi``     :math:`2^{40}`
``pebi``     :math:`2^{50}`
``exbi``     :math:`2^{60}`
``zebi``     :math:`2^{70}`
``yobi``     :math:`2^{80}`
============ =================================================================

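A minimal sketch of the prefixes in use; they are plain numeric multipliers
(SI prefixes are floats, binary prefixes exact integers):

>>> from scipy.constants import kilo, micro, mebi
>>> 5 * kilo        # 5 km expressed in m
5000.0
>>> micro
1e-06
>>> mebi
1048576
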
Mass
----

================= ============================================================
``gram``          :math:`10^{-3}` kg
``metric_ton``    :math:`10^{3}` kg
``grain``         one grain in kg
``lb``            one pound (avoirdupois) in kg
``pound``         one pound (avoirdupois) in kg
``blob``          one inch version of a slug in kg (added in 1.0.0)
``slinch``        one inch version of a slug in kg (added in 1.0.0)
``slug``          one slug in kg (added in 1.0.0)
``oz``            one ounce in kg
``ounce``         one ounce in kg
``stone``         one stone in kg
``long_ton``      one long ton in kg
``short_ton``     one short ton in kg
``troy_ounce``    one Troy ounce in kg
``troy_pound``    one Troy pound in kg
``carat``         one carat in kg
``m_u``           atomic mass constant (in kg)
``u``             atomic mass constant (in kg)
``atomic_mass``   atomic mass constant (in kg)
================= ============================================================

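A minimal sketch of a mass conversion; since ``pound`` is the avoirdupois
pound expressed in kg, multiplication converts to SI (rounded here to keep
the float repr stable):

>>> from scipy.constants import pound
>>> round(5 * pound, 6)   # 5 lb in kg
2.267962
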
Angle
-----

================= ============================================================
``degree``        degree in radians
``arcmin``        arc minute in radians
``arcminute``     arc minute in radians
``arcsec``        arc second in radians
``arcsecond``     arc second in radians
================= ============================================================


Time
----

================= ============================================================
``minute``        one minute in seconds
``hour``          one hour in seconds
``day``           one day in seconds
``week``          one week in seconds
``year``          one year (365 days) in seconds
``Julian_year``   one Julian year (365.25 days) in seconds
================= ============================================================

Length
------

===================== ============================================================
``inch``              one inch in meters
``foot``              one foot in meters
``yard``              one yard in meters
``mile``              one mile in meters
``mil``               one mil in meters
``pt``                one point in meters
``point``             one point in meters
``survey_foot``       one survey foot in meters
``survey_mile``       one survey mile in meters
``nautical_mile``     one nautical mile in meters
``fermi``             one Fermi in meters
``angstrom``          one Angstrom in meters
``micron``            one micron in meters
``au``                one astronomical unit in meters
``astronomical_unit`` one astronomical unit in meters
``light_year``        one light year in meters
``parsec``            one parsec in meters
===================== ============================================================

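A minimal sketch of length conversions (all lengths are in meters):

>>> from scipy.constants import mile, au
>>> round(26.2 * mile)    # a marathon, in whole meters
42165
>>> au                    # the astronomical unit is defined exactly
149597870700.0
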
Pressure
--------

================= ============================================================
``atm``           standard atmosphere in pascals
``atmosphere``    standard atmosphere in pascals
``bar``           one bar in pascals
``torr``          one torr (mmHg) in pascals
``mmHg``          one torr (mmHg) in pascals
``psi``           one psi in pascals
================= ============================================================

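A minimal sketch (pressures are in pascals; the standard atmosphere is
exactly 101325 Pa):

>>> from scipy.constants import atm, bar
>>> 2 * atm
202650.0
>>> round(bar / atm, 6)   # one bar in atmospheres
0.986923
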
Area
----

================= ============================================================
``hectare``       one hectare in square meters
``acre``          one acre in square meters
================= ============================================================


Volume
------

=================== ========================================================
``liter``           one liter in cubic meters
``litre``           one liter in cubic meters
``gallon``          one gallon (US) in cubic meters
``gallon_US``       one gallon (US) in cubic meters
``gallon_imp``      one gallon (UK) in cubic meters
``fluid_ounce``     one fluid ounce (US) in cubic meters
``fluid_ounce_US``  one fluid ounce (US) in cubic meters
``fluid_ounce_imp`` one fluid ounce (UK) in cubic meters
``bbl``             one barrel in cubic meters
``barrel``          one barrel in cubic meters
=================== ========================================================

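A minimal sketch; volumes are in cubic meters, so dividing by ``litre``
expresses a volume in liters (rounded to keep the float repr stable):

>>> from scipy.constants import gallon, litre
>>> round(gallon / litre, 9)   # US gallon in liters
3.785411784
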
Speed
-----

================== ==========================================================
``kmh``            kilometers per hour in meters per second
``mph``            miles per hour in meters per second
``mach``           one Mach (approx., at 15 C, 1 atm) in meters per second
``speed_of_sound`` one Mach (approx., at 15 C, 1 atm) in meters per second
``knot``           one knot in meters per second
================== ==========================================================

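A minimal sketch (speeds are in meters per second; rounded to keep the
float reprs stable):

>>> from scipy.constants import mph, knot
>>> round(65 * mph, 4)    # 65 mph in m/s
29.0576
>>> round(knot, 6)        # one knot in m/s
0.514444
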
Temperature
-----------

===================== =======================================================
``zero_Celsius``      zero of Celsius scale in Kelvin
``degree_Fahrenheit`` one Fahrenheit (only differences) in Kelvins
===================== =======================================================

.. autosummary::
   :toctree: generated/

   convert_temperature

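A minimal usage sketch for :func:`convert_temperature` (scale names may be
spelled out or abbreviated; rounding keeps the array repr stable):

>>> import numpy as np
>>> from scipy.constants import convert_temperature
>>> np.round(convert_temperature(np.array([-40.0, 100.0]), 'Celsius', 'Fahrenheit'), 2)
array([-40., 212.])
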
Energy
------

==================== =======================================================
``eV``               one electron volt in Joules
``electron_volt``    one electron volt in Joules
``calorie``          one calorie (thermochemical) in Joules
``calorie_th``       one calorie (thermochemical) in Joules
``calorie_IT``       one calorie (International Steam Table calorie, 1956) in Joules
``erg``              one erg in Joules
``Btu``              one British thermal unit (International Steam Table) in Joules
``Btu_IT``           one British thermal unit (International Steam Table) in Joules
``Btu_th``           one British thermal unit (thermochemical) in Joules
``ton_TNT``          one ton of TNT in Joules
==================== =======================================================

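A minimal sketch (energies are in joules; rounding keeps the float reprs
stable):

>>> from scipy.constants import calorie, Btu
>>> round(1000 * calorie, 3)   # one food Calorie (kcal) in J
4184.0
>>> round(Btu)                 # one BTU, in whole joules
1055
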
Power
-----

==================== =======================================================
``hp``               one horsepower in watts
``horsepower``       one horsepower in watts
==================== =======================================================

Force
-----

==================== =======================================================
``dyn``              one dyne in newtons
``dyne``             one dyne in newtons
``lbf``              one pound force in newtons
``pound_force``      one pound force in newtons
``kgf``              one kilogram force in newtons
``kilogram_force``   one kilogram force in newtons
==================== =======================================================

Optics
------

.. autosummary::
   :toctree: generated/

   lambda2nu
   nu2lambda

References
==========

.. [CODATA2018] CODATA Recommended Values of the Fundamental
   Physical Constants 2018.

   https://physics.nist.gov/cuu/Constants/

""" # noqa: E501
# Modules contributed by BasSw (wegwerp@gmail.com)
from ._codata import *
from ._constants import *
from ._codata import _obsolete_constants, physical_constants

# Deprecated namespaces, to be removed in v2.0.0
from . import codata, constants

_constant_names_list = [(_k.lower(), _k, _v)
                        for _k, _v in physical_constants.items()
                        if _k not in _obsolete_constants]
_constant_names = "\n".join(["``{}``{} {} {}".format(_x[1], " "*(66-len(_x[1])),
                                                     _x[2][0], _x[2][1])
                             for _x in sorted(_constant_names_list)])
if __doc__:
    __doc__ = __doc__ % dict(constant_names=_constant_names)

del _constant_names
del _constant_names_list

__all__ = [s for s in dir() if not s.startswith('_')]

from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester
1748
venv/lib/python3.12/site-packages/scipy/constants/_codata.py
Normal file
File diff suppressed because it is too large
368
venv/lib/python3.12/site-packages/scipy/constants/_constants.py
Normal file
@ -0,0 +1,368 @@
"""
|
||||
Collection of physical constants and conversion factors.
|
||||
|
||||
Most constants are in SI units, so you can do
|
||||
print '10 mile per minute is', 10*mile/minute, 'm/s or', 10*mile/(minute*knot), 'knots'
|
||||
|
||||
The list is not meant to be comprehensive, but just convenient for everyday use.
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import math as _math
from typing import TYPE_CHECKING, Any

from ._codata import value as _cd

if TYPE_CHECKING:
    import numpy.typing as npt

from scipy._lib._array_api import array_namespace, _asarray


"""
BasSw 2006
physical constants: imported from CODATA
unit conversion: see e.g., NIST special publication 811
Use at your own risk: double-check values before calculating your Mars orbit-insertion burn.
Some constants exist in a few variants, which are marked with suffixes.
The ones without any suffix should be the most common ones.
"""

__all__ = [
    'Avogadro', 'Boltzmann', 'Btu', 'Btu_IT', 'Btu_th', 'G',
    'Julian_year', 'N_A', 'Planck', 'R', 'Rydberg',
    'Stefan_Boltzmann', 'Wien', 'acre', 'alpha',
    'angstrom', 'arcmin', 'arcminute', 'arcsec',
    'arcsecond', 'astronomical_unit', 'atm',
    'atmosphere', 'atomic_mass', 'atto', 'au', 'bar',
    'barrel', 'bbl', 'blob', 'c', 'calorie',
    'calorie_IT', 'calorie_th', 'carat', 'centi',
    'convert_temperature', 'day', 'deci', 'degree',
    'degree_Fahrenheit', 'deka', 'dyn', 'dyne', 'e',
    'eV', 'electron_mass', 'electron_volt',
    'elementary_charge', 'epsilon_0', 'erg',
    'exa', 'exbi', 'femto', 'fermi', 'fine_structure',
    'fluid_ounce', 'fluid_ounce_US', 'fluid_ounce_imp',
    'foot', 'g', 'gallon', 'gallon_US', 'gallon_imp',
    'gas_constant', 'gibi', 'giga', 'golden', 'golden_ratio',
    'grain', 'gram', 'gravitational_constant', 'h', 'hbar',
    'hectare', 'hecto', 'horsepower', 'hour', 'hp',
    'inch', 'k', 'kgf', 'kibi', 'kilo', 'kilogram_force',
    'kmh', 'knot', 'lambda2nu', 'lb', 'lbf',
    'light_year', 'liter', 'litre', 'long_ton', 'm_e',
    'm_n', 'm_p', 'm_u', 'mach', 'mebi', 'mega',
    'metric_ton', 'micro', 'micron', 'mil', 'mile',
    'milli', 'minute', 'mmHg', 'mph', 'mu_0', 'nano',
    'nautical_mile', 'neutron_mass', 'nu2lambda',
    'ounce', 'oz', 'parsec', 'pebi', 'peta',
    'pi', 'pico', 'point', 'pound', 'pound_force',
    'proton_mass', 'psi', 'pt', 'quecto', 'quetta', 'ronna', 'ronto',
    'short_ton', 'sigma', 'slinch', 'slug', 'speed_of_light',
    'speed_of_sound', 'stone', 'survey_foot',
    'survey_mile', 'tebi', 'tera', 'ton_TNT',
    'torr', 'troy_ounce', 'troy_pound', 'u',
    'week', 'yard', 'year', 'yobi', 'yocto',
    'yotta', 'zebi', 'zepto', 'zero_Celsius', 'zetta'
]


# mathematical constants
pi = _math.pi
golden = golden_ratio = (1 + _math.sqrt(5)) / 2

# SI prefixes
quetta = 1e30
ronna = 1e27
yotta = 1e24
zetta = 1e21
exa = 1e18
peta = 1e15
tera = 1e12
giga = 1e9
mega = 1e6
kilo = 1e3
hecto = 1e2
deka = 1e1
deci = 1e-1
centi = 1e-2
milli = 1e-3
micro = 1e-6
nano = 1e-9
pico = 1e-12
femto = 1e-15
atto = 1e-18
zepto = 1e-21
yocto = 1e-24
ronto = 1e-27
quecto = 1e-30

# binary prefixes
kibi = 2**10
mebi = 2**20
gibi = 2**30
tebi = 2**40
pebi = 2**50
exbi = 2**60
zebi = 2**70
yobi = 2**80

# physical constants
c = speed_of_light = _cd('speed of light in vacuum')
mu_0 = _cd('vacuum mag. permeability')
epsilon_0 = _cd('vacuum electric permittivity')
h = Planck = _cd('Planck constant')
hbar = h / (2 * pi)
G = gravitational_constant = _cd('Newtonian constant of gravitation')
g = _cd('standard acceleration of gravity')
e = elementary_charge = _cd('elementary charge')
R = gas_constant = _cd('molar gas constant')
alpha = fine_structure = _cd('fine-structure constant')
N_A = Avogadro = _cd('Avogadro constant')
k = Boltzmann = _cd('Boltzmann constant')
sigma = Stefan_Boltzmann = _cd('Stefan-Boltzmann constant')
Wien = _cd('Wien wavelength displacement law constant')
Rydberg = _cd('Rydberg constant')

# mass in kg
gram = 1e-3
metric_ton = 1e3
grain = 64.79891e-6
lb = pound = 7000 * grain  # avoirdupois
blob = slinch = pound * g / 0.0254  # lbf*s**2/in (added in 1.0.0)
slug = blob / 12  # lbf*s**2/foot (added in 1.0.0)
oz = ounce = pound / 16
stone = 14 * pound
long_ton = 2240 * pound
short_ton = 2000 * pound

troy_ounce = 480 * grain  # only for metals / gems
troy_pound = 12 * troy_ounce
carat = 200e-6

m_e = electron_mass = _cd('electron mass')
m_p = proton_mass = _cd('proton mass')
m_n = neutron_mass = _cd('neutron mass')
m_u = u = atomic_mass = _cd('atomic mass constant')

# angle in rad
degree = pi / 180
arcmin = arcminute = degree / 60
arcsec = arcsecond = arcmin / 60

# time in second
minute = 60.0
hour = 60 * minute
day = 24 * hour
week = 7 * day
year = 365 * day
Julian_year = 365.25 * day

# length in meter
inch = 0.0254
foot = 12 * inch
yard = 3 * foot
mile = 1760 * yard
mil = inch / 1000
pt = point = inch / 72  # typography
survey_foot = 1200.0 / 3937
survey_mile = 5280 * survey_foot
nautical_mile = 1852.0
fermi = 1e-15
angstrom = 1e-10
micron = 1e-6
au = astronomical_unit = 149597870700.0
light_year = Julian_year * c
parsec = au / arcsec

# pressure in pascal
atm = atmosphere = _cd('standard atmosphere')
bar = 1e5
torr = mmHg = atm / 760
psi = pound * g / (inch * inch)

# area in meter**2
hectare = 1e4
acre = 43560 * foot**2

# volume in meter**3
litre = liter = 1e-3
gallon = gallon_US = 231 * inch**3  # US
# pint = gallon_US / 8
fluid_ounce = fluid_ounce_US = gallon_US / 128
bbl = barrel = 42 * gallon_US  # for oil

gallon_imp = 4.54609e-3  # UK
fluid_ounce_imp = gallon_imp / 160

# speed in meter per second
kmh = 1e3 / hour
mph = mile / hour
# approximate value of Mach 1 at 15 degrees Celsius, 1 atm
mach = speed_of_sound = 340.5
knot = nautical_mile / hour

# temperature in kelvin
zero_Celsius = 273.15
degree_Fahrenheit = 1/1.8  # only for differences

# energy in joule
eV = electron_volt = elementary_charge  # * 1 Volt
calorie = calorie_th = 4.184
calorie_IT = 4.1868
erg = 1e-7
Btu_th = pound * degree_Fahrenheit * calorie_th / gram
Btu = Btu_IT = pound * degree_Fahrenheit * calorie_IT / gram
ton_TNT = 1e9 * calorie_th
# Wh = watt_hour

# power in watt
hp = horsepower = 550 * foot * pound * g

# force in newton
dyn = dyne = 1e-5
lbf = pound_force = pound * g
kgf = kilogram_force = g  # * 1 kg

# functions for conversions that are not linear


def convert_temperature(
    val: npt.ArrayLike,
    old_scale: str,
    new_scale: str,
) -> Any:
    """
    Convert from a temperature scale to another one among Celsius, Kelvin,
    Fahrenheit, and Rankine scales.

    Parameters
    ----------
    val : array_like
        Value(s) of the temperature(s) to be converted expressed in the
        original scale.
    old_scale : str
        Specifies as a string the original scale from which the temperature
        value(s) will be converted. Supported scales are Celsius ('Celsius',
        'celsius', 'C' or 'c'), Kelvin ('Kelvin', 'kelvin', 'K', 'k'),
        Fahrenheit ('Fahrenheit', 'fahrenheit', 'F' or 'f'), and Rankine
        ('Rankine', 'rankine', 'R', 'r').
    new_scale : str
        Specifies as a string the new scale to which the temperature
        value(s) will be converted. Supported scales are Celsius ('Celsius',
        'celsius', 'C' or 'c'), Kelvin ('Kelvin', 'kelvin', 'K', 'k'),
        Fahrenheit ('Fahrenheit', 'fahrenheit', 'F' or 'f'), and Rankine
        ('Rankine', 'rankine', 'R', 'r').

    Returns
    -------
    res : float or array of floats
        Value(s) of the converted temperature(s) expressed in the new scale.

    Notes
    -----
    .. versionadded:: 0.18.0

    Examples
    --------
    >>> from scipy.constants import convert_temperature
    >>> import numpy as np
    >>> convert_temperature(np.array([-40, 40]), 'Celsius', 'Kelvin')
    array([ 233.15, 313.15])

    """
    xp = array_namespace(val)
    _val = _asarray(val, xp=xp, subok=True)
    # Convert from `old_scale` to Kelvin
    if old_scale.lower() in ['celsius', 'c']:
        tempo = _val + zero_Celsius
    elif old_scale.lower() in ['kelvin', 'k']:
        tempo = _val
    elif old_scale.lower() in ['fahrenheit', 'f']:
        tempo = (_val - 32) * 5 / 9 + zero_Celsius
    elif old_scale.lower() in ['rankine', 'r']:
        tempo = _val * 5 / 9
    else:
        raise NotImplementedError(f"{old_scale=} is unsupported: supported scales "
                                  "are Celsius, Kelvin, Fahrenheit, and "
                                  "Rankine")
    # and from Kelvin to `new_scale`.
    if new_scale.lower() in ['celsius', 'c']:
        res = tempo - zero_Celsius
    elif new_scale.lower() in ['kelvin', 'k']:
        res = tempo
    elif new_scale.lower() in ['fahrenheit', 'f']:
        res = (tempo - zero_Celsius) * 9 / 5 + 32
    elif new_scale.lower() in ['rankine', 'r']:
        res = tempo * 9 / 5
    else:
        raise NotImplementedError(f"{new_scale=} is unsupported: supported "
                                  "scales are 'Celsius', 'Kelvin', "
                                  "'Fahrenheit', and 'Rankine'")

    return res

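# A round-trip sanity sketch for convert_temperature (illustrative only;
# floating-point error means the round trip is only approximately exact):
#
#     f = 68.0                                 # room temperature in degF
#     r = convert_temperature(f, 'F', 'R')     # 527.67 degR
#     back = convert_temperature(r, 'R', 'F')  # ~68.0
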
# optics


def lambda2nu(lambda_: npt.ArrayLike) -> Any:
    """
    Convert wavelength to optical frequency.

    Parameters
    ----------
    lambda_ : array_like
        Wavelength(s) to be converted.

    Returns
    -------
    nu : float or array of floats
        Equivalent optical frequency.

    Notes
    -----
    Computes ``nu = c / lambda`` where c = 299792458.0, i.e., the
    (vacuum) speed of light in meters/second.

    Examples
    --------
    >>> from scipy.constants import lambda2nu, speed_of_light
    >>> import numpy as np
    >>> lambda2nu(np.array((1, speed_of_light)))
    array([ 2.99792458e+08, 1.00000000e+00])

    """
    xp = array_namespace(lambda_)
    return c / _asarray(lambda_, xp=xp, subok=True)


def nu2lambda(nu: npt.ArrayLike) -> Any:
    """
    Convert optical frequency to wavelength.

    Parameters
    ----------
    nu : array_like
        Optical frequency to be converted.

    Returns
    -------
    lambda : float or array of floats
        Equivalent wavelength(s).

    Notes
    -----
    Computes ``lambda = c / nu`` where c = 299792458.0, i.e., the
    (vacuum) speed of light in meters/second.

    Examples
    --------
    >>> from scipy.constants import nu2lambda, speed_of_light
    >>> import numpy as np
    >>> nu2lambda(np.array((1, speed_of_light)))
    array([ 2.99792458e+08, 1.00000000e+00])

    """
    xp = array_namespace(nu)
    return c / _asarray(nu, xp=xp, subok=True)
Some files were not shown because too many files have changed in this diff.