asd
This commit is contained in:
@ -0,0 +1,101 @@
|
||||
""" Test for assert_deallocated context manager and gc utilities
|
||||
"""
|
||||
import gc
|
||||
|
||||
from scipy._lib._gcutils import (set_gc_state, gc_state, assert_deallocated,
|
||||
ReferenceError, IS_PYPY)
|
||||
|
||||
from numpy.testing import assert_equal
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_set_gc_state():
|
||||
gc_status = gc.isenabled()
|
||||
try:
|
||||
for state in (True, False):
|
||||
gc.enable()
|
||||
set_gc_state(state)
|
||||
assert_equal(gc.isenabled(), state)
|
||||
gc.disable()
|
||||
set_gc_state(state)
|
||||
assert_equal(gc.isenabled(), state)
|
||||
finally:
|
||||
if gc_status:
|
||||
gc.enable()
|
||||
|
||||
|
||||
def test_gc_state():
|
||||
# Test gc_state context manager
|
||||
gc_status = gc.isenabled()
|
||||
try:
|
||||
for pre_state in (True, False):
|
||||
set_gc_state(pre_state)
|
||||
for with_state in (True, False):
|
||||
# Check the gc state is with_state in with block
|
||||
with gc_state(with_state):
|
||||
assert_equal(gc.isenabled(), with_state)
|
||||
# And returns to previous state outside block
|
||||
assert_equal(gc.isenabled(), pre_state)
|
||||
# Even if the gc state is set explicitly within the block
|
||||
with gc_state(with_state):
|
||||
assert_equal(gc.isenabled(), with_state)
|
||||
set_gc_state(not with_state)
|
||||
assert_equal(gc.isenabled(), pre_state)
|
||||
finally:
|
||||
if gc_status:
|
||||
gc.enable()
|
||||
|
||||
|
||||
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
|
||||
def test_assert_deallocated():
|
||||
# Ordinary use
|
||||
class C:
|
||||
def __init__(self, arg0, arg1, name='myname'):
|
||||
self.name = name
|
||||
for gc_current in (True, False):
|
||||
with gc_state(gc_current):
|
||||
# We are deleting from with-block context, so that's OK
|
||||
with assert_deallocated(C, 0, 2, 'another name') as c:
|
||||
assert_equal(c.name, 'another name')
|
||||
del c
|
||||
# Or not using the thing in with-block context, also OK
|
||||
with assert_deallocated(C, 0, 2, name='third name'):
|
||||
pass
|
||||
assert_equal(gc.isenabled(), gc_current)
|
||||
|
||||
|
||||
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
|
||||
def test_assert_deallocated_nodel():
|
||||
class C:
|
||||
pass
|
||||
with pytest.raises(ReferenceError):
|
||||
# Need to delete after using if in with-block context
|
||||
# Note: assert_deallocated(C) needs to be assigned for the test
|
||||
# to function correctly. It is assigned to _, but _ itself is
|
||||
# not referenced in the body of the with, it is only there for
|
||||
# the refcount.
|
||||
with assert_deallocated(C) as _:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
|
||||
def test_assert_deallocated_circular():
|
||||
class C:
|
||||
def __init__(self):
|
||||
self._circular = self
|
||||
with pytest.raises(ReferenceError):
|
||||
# Circular reference, no automatic garbage collection
|
||||
with assert_deallocated(C) as c:
|
||||
del c
|
||||
|
||||
|
||||
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
|
||||
def test_assert_deallocated_circular2():
|
||||
class C:
|
||||
def __init__(self):
|
||||
self._circular = self
|
||||
with pytest.raises(ReferenceError):
|
||||
# Still circular reference, no automatic garbage collection
|
||||
with assert_deallocated(C):
|
||||
pass
|
||||
@ -0,0 +1,67 @@
|
||||
from pytest import raises as assert_raises
|
||||
from scipy._lib._pep440 import Version, parse
|
||||
|
||||
|
||||
def test_main_versions():
|
||||
assert Version('1.8.0') == Version('1.8.0')
|
||||
for ver in ['1.9.0', '2.0.0', '1.8.1']:
|
||||
assert Version('1.8.0') < Version(ver)
|
||||
|
||||
for ver in ['1.7.0', '1.7.1', '0.9.9']:
|
||||
assert Version('1.8.0') > Version(ver)
|
||||
|
||||
|
||||
def test_version_1_point_10():
|
||||
# regression test for gh-2998.
|
||||
assert Version('1.9.0') < Version('1.10.0')
|
||||
assert Version('1.11.0') < Version('1.11.1')
|
||||
assert Version('1.11.0') == Version('1.11.0')
|
||||
assert Version('1.99.11') < Version('1.99.12')
|
||||
|
||||
|
||||
def test_alpha_beta_rc():
|
||||
assert Version('1.8.0rc1') == Version('1.8.0rc1')
|
||||
for ver in ['1.8.0', '1.8.0rc2']:
|
||||
assert Version('1.8.0rc1') < Version(ver)
|
||||
|
||||
for ver in ['1.8.0a2', '1.8.0b3', '1.7.2rc4']:
|
||||
assert Version('1.8.0rc1') > Version(ver)
|
||||
|
||||
assert Version('1.8.0b1') > Version('1.8.0a2')
|
||||
|
||||
|
||||
def test_dev_version():
|
||||
assert Version('1.9.0.dev+Unknown') < Version('1.9.0')
|
||||
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev+ffffffff', '1.9.0.dev1']:
|
||||
assert Version('1.9.0.dev+f16acvda') < Version(ver)
|
||||
|
||||
assert Version('1.9.0.dev+f16acvda') == Version('1.9.0.dev+f16acvda')
|
||||
|
||||
|
||||
def test_dev_a_b_rc_mixed():
|
||||
assert Version('1.9.0a2.dev+f16acvda') == Version('1.9.0a2.dev+f16acvda')
|
||||
assert Version('1.9.0a2.dev+6acvda54') < Version('1.9.0a2')
|
||||
|
||||
|
||||
def test_dev0_version():
|
||||
assert Version('1.9.0.dev0+Unknown') < Version('1.9.0')
|
||||
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev0+ffffffff']:
|
||||
assert Version('1.9.0.dev0+f16acvda') < Version(ver)
|
||||
|
||||
assert Version('1.9.0.dev0+f16acvda') == Version('1.9.0.dev0+f16acvda')
|
||||
|
||||
|
||||
def test_dev0_a_b_rc_mixed():
|
||||
assert Version('1.9.0a2.dev0+f16acvda') == Version('1.9.0a2.dev0+f16acvda')
|
||||
assert Version('1.9.0a2.dev0+6acvda54') < Version('1.9.0a2')
|
||||
|
||||
|
||||
def test_raises():
|
||||
for ver in ['1,9.0', '1.7.x']:
|
||||
assert_raises(ValueError, Version, ver)
|
||||
|
||||
def test_legacy_version():
|
||||
# Non-PEP-440 version identifiers always compare less. For NumPy this only
|
||||
# occurs on dev builds prior to 1.10.0 which are unsupported anyway.
|
||||
assert parse('invalid') < Version('0.0.0')
|
||||
assert parse('1.9.0-f16acvda') < Version('1.0.0')
|
||||
@ -0,0 +1,32 @@
|
||||
import sys
|
||||
from scipy._lib._testutils import _parse_size, _get_mem_available
|
||||
import pytest
|
||||
|
||||
|
||||
def test__parse_size():
|
||||
expected = {
|
||||
'12': 12e6,
|
||||
'12 b': 12,
|
||||
'12k': 12e3,
|
||||
' 12 M ': 12e6,
|
||||
' 12 G ': 12e9,
|
||||
' 12Tb ': 12e12,
|
||||
'12 Mib ': 12 * 1024.0**2,
|
||||
'12Tib': 12 * 1024.0**4,
|
||||
}
|
||||
|
||||
for inp, outp in sorted(expected.items()):
|
||||
if outp is None:
|
||||
with pytest.raises(ValueError):
|
||||
_parse_size(inp)
|
||||
else:
|
||||
assert _parse_size(inp) == outp
|
||||
|
||||
|
||||
def test__mem_available():
|
||||
# May return None on non-Linux platforms
|
||||
available = _get_mem_available()
|
||||
if sys.platform.startswith('linux'):
|
||||
assert available >= 0
|
||||
else:
|
||||
assert available is None or available >= 0
|
||||
@ -0,0 +1,51 @@
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from numpy.testing import assert_
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
from scipy._lib._threadsafety import ReentrancyLock, non_reentrant, ReentrancyError
|
||||
|
||||
|
||||
def test_parallel_threads():
|
||||
# Check that ReentrancyLock serializes work in parallel threads.
|
||||
#
|
||||
# The test is not fully deterministic, and may succeed falsely if
|
||||
# the timings go wrong.
|
||||
|
||||
lock = ReentrancyLock("failure")
|
||||
|
||||
failflag = [False]
|
||||
exceptions_raised = []
|
||||
|
||||
def worker(k):
|
||||
try:
|
||||
with lock:
|
||||
assert_(not failflag[0])
|
||||
failflag[0] = True
|
||||
time.sleep(0.1 * k)
|
||||
assert_(failflag[0])
|
||||
failflag[0] = False
|
||||
except Exception:
|
||||
exceptions_raised.append(traceback.format_exc(2))
|
||||
|
||||
threads = [threading.Thread(target=lambda k=k: worker(k))
|
||||
for k in range(3)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
exceptions_raised = "\n".join(exceptions_raised)
|
||||
assert_(not exceptions_raised, exceptions_raised)
|
||||
|
||||
|
||||
def test_reentering():
|
||||
# Check that ReentrancyLock prevents re-entering from the same thread.
|
||||
|
||||
@non_reentrant()
|
||||
def func(x):
|
||||
return func(x)
|
||||
|
||||
assert_raises(ReentrancyError, func, 0)
|
||||
447
venv/lib/python3.12/site-packages/scipy/_lib/tests/test__util.py
Normal file
447
venv/lib/python3.12/site-packages/scipy/_lib/tests/test__util.py
Normal file
@ -0,0 +1,447 @@
|
||||
from multiprocessing import Pool
|
||||
from multiprocessing.pool import Pool as PWL
|
||||
import re
|
||||
import math
|
||||
from fractions import Fraction
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_equal, assert_
|
||||
import pytest
|
||||
from pytest import raises as assert_raises
|
||||
import hypothesis.extra.numpy as npst
|
||||
from hypothesis import given, strategies, reproduce_failure # noqa: F401
|
||||
from scipy.conftest import array_api_compatible, skip_xp_invalid_arg
|
||||
|
||||
from scipy._lib._array_api import (xp_assert_equal, xp_assert_close, is_numpy,
|
||||
copy as xp_copy)
|
||||
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
|
||||
getfullargspec_no_self, FullArgSpec,
|
||||
rng_integers, _validate_int, _rename_parameter,
|
||||
_contains_nan, _rng_html_rewrite, _lazywhere)
|
||||
|
||||
skip_xp_backends = pytest.mark.skip_xp_backends
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test__aligned_zeros():
|
||||
niter = 10
|
||||
|
||||
def check(shape, dtype, order, align):
|
||||
err_msg = repr((shape, dtype, order, align))
|
||||
x = _aligned_zeros(shape, dtype, order, align=align)
|
||||
if align is None:
|
||||
align = np.dtype(dtype).alignment
|
||||
assert_equal(x.__array_interface__['data'][0] % align, 0)
|
||||
if hasattr(shape, '__len__'):
|
||||
assert_equal(x.shape, shape, err_msg)
|
||||
else:
|
||||
assert_equal(x.shape, (shape,), err_msg)
|
||||
assert_equal(x.dtype, dtype)
|
||||
if order == "C":
|
||||
assert_(x.flags.c_contiguous, err_msg)
|
||||
elif order == "F":
|
||||
if x.size > 0:
|
||||
# Size-0 arrays get invalid flags on NumPy 1.5
|
||||
assert_(x.flags.f_contiguous, err_msg)
|
||||
elif order is None:
|
||||
assert_(x.flags.c_contiguous, err_msg)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
# try various alignments
|
||||
for align in [1, 2, 3, 4, 8, 16, 32, 64, None]:
|
||||
for n in [0, 1, 3, 11]:
|
||||
for order in ["C", "F", None]:
|
||||
for dtype in [np.uint8, np.float64]:
|
||||
for shape in [n, (1, 2, 3, n)]:
|
||||
for j in range(niter):
|
||||
check(shape, dtype, order, align)
|
||||
|
||||
|
||||
def test_check_random_state():
|
||||
# If seed is None, return the RandomState singleton used by np.random.
|
||||
# If seed is an int, return a new RandomState instance seeded with seed.
|
||||
# If seed is already a RandomState instance, return it.
|
||||
# Otherwise raise ValueError.
|
||||
rsi = check_random_state(1)
|
||||
assert_equal(type(rsi), np.random.RandomState)
|
||||
rsi = check_random_state(rsi)
|
||||
assert_equal(type(rsi), np.random.RandomState)
|
||||
rsi = check_random_state(None)
|
||||
assert_equal(type(rsi), np.random.RandomState)
|
||||
assert_raises(ValueError, check_random_state, 'a')
|
||||
rg = np.random.Generator(np.random.PCG64())
|
||||
rsi = check_random_state(rg)
|
||||
assert_equal(type(rsi), np.random.Generator)
|
||||
|
||||
|
||||
def test_getfullargspec_no_self():
|
||||
p = MapWrapper(1)
|
||||
argspec = getfullargspec_no_self(p.__init__)
|
||||
assert_equal(argspec, FullArgSpec(['pool'], None, None, (1,), [],
|
||||
None, {}))
|
||||
argspec = getfullargspec_no_self(p.__call__)
|
||||
assert_equal(argspec, FullArgSpec(['func', 'iterable'], None, None, None,
|
||||
[], None, {}))
|
||||
|
||||
class _rv_generic:
|
||||
def _rvs(self, a, b=2, c=3, *args, size=None, **kwargs):
|
||||
return None
|
||||
|
||||
rv_obj = _rv_generic()
|
||||
argspec = getfullargspec_no_self(rv_obj._rvs)
|
||||
assert_equal(argspec, FullArgSpec(['a', 'b', 'c'], 'args', 'kwargs',
|
||||
(2, 3), ['size'], {'size': None}, {}))
|
||||
|
||||
|
||||
def test_mapwrapper_serial():
|
||||
in_arg = np.arange(10.)
|
||||
out_arg = np.sin(in_arg)
|
||||
|
||||
p = MapWrapper(1)
|
||||
assert_(p._mapfunc is map)
|
||||
assert_(p.pool is None)
|
||||
assert_(p._own_pool is False)
|
||||
out = list(p(np.sin, in_arg))
|
||||
assert_equal(out, out_arg)
|
||||
|
||||
with assert_raises(RuntimeError):
|
||||
p = MapWrapper(0)
|
||||
|
||||
|
||||
def test_pool():
|
||||
with Pool(2) as p:
|
||||
p.map(math.sin, [1, 2, 3, 4])
|
||||
|
||||
|
||||
def test_mapwrapper_parallel():
|
||||
in_arg = np.arange(10.)
|
||||
out_arg = np.sin(in_arg)
|
||||
|
||||
with MapWrapper(2) as p:
|
||||
out = p(np.sin, in_arg)
|
||||
assert_equal(list(out), out_arg)
|
||||
|
||||
assert_(p._own_pool is True)
|
||||
assert_(isinstance(p.pool, PWL))
|
||||
assert_(p._mapfunc is not None)
|
||||
|
||||
# the context manager should've closed the internal pool
|
||||
# check that it has by asking it to calculate again.
|
||||
with assert_raises(Exception) as excinfo:
|
||||
p(np.sin, in_arg)
|
||||
|
||||
assert_(excinfo.type is ValueError)
|
||||
|
||||
# can also set a PoolWrapper up with a map-like callable instance
|
||||
with Pool(2) as p:
|
||||
q = MapWrapper(p.map)
|
||||
|
||||
assert_(q._own_pool is False)
|
||||
q.close()
|
||||
|
||||
# closing the PoolWrapper shouldn't close the internal pool
|
||||
# because it didn't create it
|
||||
out = p.map(np.sin, in_arg)
|
||||
assert_equal(list(out), out_arg)
|
||||
|
||||
|
||||
def test_rng_integers():
|
||||
rng = np.random.RandomState()
|
||||
|
||||
# test that numbers are inclusive of high point
|
||||
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
|
||||
assert np.max(arr) == 5
|
||||
assert np.min(arr) == 2
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are inclusive of high point
|
||||
arr = rng_integers(rng, low=5, size=100, endpoint=True)
|
||||
assert np.max(arr) == 5
|
||||
assert np.min(arr) == 0
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are exclusive of high point
|
||||
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
|
||||
assert np.max(arr) == 4
|
||||
assert np.min(arr) == 2
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are exclusive of high point
|
||||
arr = rng_integers(rng, low=5, size=100, endpoint=False)
|
||||
assert np.max(arr) == 4
|
||||
assert np.min(arr) == 0
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# now try with np.random.Generator
|
||||
try:
|
||||
rng = np.random.default_rng()
|
||||
except AttributeError:
|
||||
return
|
||||
|
||||
# test that numbers are inclusive of high point
|
||||
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
|
||||
assert np.max(arr) == 5
|
||||
assert np.min(arr) == 2
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are inclusive of high point
|
||||
arr = rng_integers(rng, low=5, size=100, endpoint=True)
|
||||
assert np.max(arr) == 5
|
||||
assert np.min(arr) == 0
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are exclusive of high point
|
||||
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
|
||||
assert np.max(arr) == 4
|
||||
assert np.min(arr) == 2
|
||||
assert arr.shape == (100, )
|
||||
|
||||
# test that numbers are exclusive of high point
|
||||
arr = rng_integers(rng, low=5, size=100, endpoint=False)
|
||||
assert np.max(arr) == 4
|
||||
assert np.min(arr) == 0
|
||||
assert arr.shape == (100, )
|
||||
|
||||
|
||||
class TestValidateInt:
|
||||
|
||||
@pytest.mark.parametrize('n', [4, np.uint8(4), np.int16(4), np.array(4)])
|
||||
def test_validate_int(self, n):
|
||||
n = _validate_int(n, 'n')
|
||||
assert n == 4
|
||||
|
||||
@pytest.mark.parametrize('n', [4.0, np.array([4]), Fraction(4, 1)])
|
||||
def test_validate_int_bad(self, n):
|
||||
with pytest.raises(TypeError, match='n must be an integer'):
|
||||
_validate_int(n, 'n')
|
||||
|
||||
def test_validate_int_below_min(self):
|
||||
with pytest.raises(ValueError, match='n must be an integer not '
|
||||
'less than 0'):
|
||||
_validate_int(-1, 'n', 0)
|
||||
|
||||
|
||||
class TestRenameParameter:
|
||||
# check that wrapper `_rename_parameter` for backward-compatible
|
||||
# keyword renaming works correctly
|
||||
|
||||
# Example method/function that still accepts keyword `old`
|
||||
@_rename_parameter("old", "new")
|
||||
def old_keyword_still_accepted(self, new):
|
||||
return new
|
||||
|
||||
# Example method/function for which keyword `old` is deprecated
|
||||
@_rename_parameter("old", "new", dep_version="1.9.0")
|
||||
def old_keyword_deprecated(self, new):
|
||||
return new
|
||||
|
||||
def test_old_keyword_still_accepted(self):
|
||||
# positional argument and both keyword work identically
|
||||
res1 = self.old_keyword_still_accepted(10)
|
||||
res2 = self.old_keyword_still_accepted(new=10)
|
||||
res3 = self.old_keyword_still_accepted(old=10)
|
||||
assert res1 == res2 == res3 == 10
|
||||
|
||||
# unexpected keyword raises an error
|
||||
message = re.escape("old_keyword_still_accepted() got an unexpected")
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_still_accepted(unexpected=10)
|
||||
|
||||
# multiple values for the same parameter raises an error
|
||||
message = re.escape("old_keyword_still_accepted() got multiple")
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_still_accepted(10, new=10)
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_still_accepted(10, old=10)
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_still_accepted(new=10, old=10)
|
||||
|
||||
def test_old_keyword_deprecated(self):
|
||||
# positional argument and both keyword work identically,
|
||||
# but use of old keyword results in DeprecationWarning
|
||||
dep_msg = "Use of keyword argument `old` is deprecated"
|
||||
res1 = self.old_keyword_deprecated(10)
|
||||
res2 = self.old_keyword_deprecated(new=10)
|
||||
with pytest.warns(DeprecationWarning, match=dep_msg):
|
||||
res3 = self.old_keyword_deprecated(old=10)
|
||||
assert res1 == res2 == res3 == 10
|
||||
|
||||
# unexpected keyword raises an error
|
||||
message = re.escape("old_keyword_deprecated() got an unexpected")
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_deprecated(unexpected=10)
|
||||
|
||||
# multiple values for the same parameter raises an error and,
|
||||
# if old keyword is used, results in DeprecationWarning
|
||||
message = re.escape("old_keyword_deprecated() got multiple")
|
||||
with pytest.raises(TypeError, match=message):
|
||||
self.old_keyword_deprecated(10, new=10)
|
||||
with pytest.raises(TypeError, match=message), \
|
||||
pytest.warns(DeprecationWarning, match=dep_msg):
|
||||
self.old_keyword_deprecated(10, old=10)
|
||||
with pytest.raises(TypeError, match=message), \
|
||||
pytest.warns(DeprecationWarning, match=dep_msg):
|
||||
self.old_keyword_deprecated(new=10, old=10)
|
||||
|
||||
|
||||
class TestContainsNaNTest:
|
||||
|
||||
def test_policy(self):
|
||||
data = np.array([1, 2, 3, np.nan])
|
||||
|
||||
contains_nan, nan_policy = _contains_nan(data, nan_policy="propagate")
|
||||
assert contains_nan
|
||||
assert nan_policy == "propagate"
|
||||
|
||||
contains_nan, nan_policy = _contains_nan(data, nan_policy="omit")
|
||||
assert contains_nan
|
||||
assert nan_policy == "omit"
|
||||
|
||||
msg = "The input contains nan values"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
_contains_nan(data, nan_policy="raise")
|
||||
|
||||
msg = "nan_policy must be one of"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
_contains_nan(data, nan_policy="nan")
|
||||
|
||||
def test_contains_nan(self):
|
||||
data1 = np.array([1, 2, 3])
|
||||
assert not _contains_nan(data1)[0]
|
||||
|
||||
data2 = np.array([1, 2, 3, np.nan])
|
||||
assert _contains_nan(data2)[0]
|
||||
|
||||
data3 = np.array([np.nan, 2, 3, np.nan])
|
||||
assert _contains_nan(data3)[0]
|
||||
|
||||
data4 = np.array([[1, 2], [3, 4]])
|
||||
assert not _contains_nan(data4)[0]
|
||||
|
||||
data5 = np.array([[1, 2], [3, np.nan]])
|
||||
assert _contains_nan(data5)[0]
|
||||
|
||||
@skip_xp_invalid_arg
|
||||
def test_contains_nan_with_strings(self):
|
||||
data1 = np.array([1, 2, "3", np.nan]) # converted to string "nan"
|
||||
assert not _contains_nan(data1)[0]
|
||||
|
||||
data2 = np.array([1, 2, "3", np.nan], dtype='object')
|
||||
assert _contains_nan(data2)[0]
|
||||
|
||||
data3 = np.array([["1", 2], [3, np.nan]]) # converted to string "nan"
|
||||
assert not _contains_nan(data3)[0]
|
||||
|
||||
data4 = np.array([["1", 2], [3, np.nan]], dtype='object')
|
||||
assert _contains_nan(data4)[0]
|
||||
|
||||
@skip_xp_backends('jax.numpy',
|
||||
reasons=["JAX arrays do not support item assignment"])
|
||||
@pytest.mark.usefixtures("skip_xp_backends")
|
||||
@array_api_compatible
|
||||
@pytest.mark.parametrize("nan_policy", ['propagate', 'omit', 'raise'])
|
||||
def test_array_api(self, xp, nan_policy):
|
||||
rng = np.random.default_rng(932347235892482)
|
||||
x0 = rng.random(size=(2, 3, 4))
|
||||
x = xp.asarray(x0)
|
||||
x_nan = xp_copy(x, xp=xp)
|
||||
x_nan[1, 2, 1] = np.nan
|
||||
|
||||
contains_nan, nan_policy_out = _contains_nan(x, nan_policy=nan_policy)
|
||||
assert not contains_nan
|
||||
assert nan_policy_out == nan_policy
|
||||
|
||||
if nan_policy == 'raise':
|
||||
message = 'The input contains...'
|
||||
with pytest.raises(ValueError, match=message):
|
||||
_contains_nan(x_nan, nan_policy=nan_policy)
|
||||
elif nan_policy == 'omit' and not is_numpy(xp):
|
||||
message = "`nan_policy='omit' is incompatible..."
|
||||
with pytest.raises(ValueError, match=message):
|
||||
_contains_nan(x_nan, nan_policy=nan_policy)
|
||||
elif nan_policy == 'propagate':
|
||||
contains_nan, nan_policy_out = _contains_nan(
|
||||
x_nan, nan_policy=nan_policy)
|
||||
assert contains_nan
|
||||
assert nan_policy_out == nan_policy
|
||||
|
||||
|
||||
def test__rng_html_rewrite():
|
||||
def mock_str():
|
||||
lines = [
|
||||
'np.random.default_rng(8989843)',
|
||||
'np.random.default_rng(seed)',
|
||||
'np.random.default_rng(0x9a71b21474694f919882289dc1559ca)',
|
||||
' bob ',
|
||||
]
|
||||
return lines
|
||||
|
||||
res = _rng_html_rewrite(mock_str)()
|
||||
ref = [
|
||||
'np.random.default_rng()',
|
||||
'np.random.default_rng(seed)',
|
||||
'np.random.default_rng()',
|
||||
' bob ',
|
||||
]
|
||||
|
||||
assert res == ref
|
||||
|
||||
|
||||
class TestLazywhere:
|
||||
n_arrays = strategies.integers(min_value=1, max_value=3)
|
||||
rng_seed = strategies.integers(min_value=1000000000, max_value=9999999999)
|
||||
dtype = strategies.sampled_from((np.float32, np.float64))
|
||||
p = strategies.floats(min_value=0, max_value=1)
|
||||
data = strategies.data()
|
||||
|
||||
@pytest.mark.fail_slow(5)
|
||||
@pytest.mark.filterwarnings('ignore::RuntimeWarning') # overflows, etc.
|
||||
@skip_xp_backends('jax.numpy',
|
||||
reasons=["JAX arrays do not support item assignment"])
|
||||
@pytest.mark.usefixtures("skip_xp_backends")
|
||||
@array_api_compatible
|
||||
@given(n_arrays=n_arrays, rng_seed=rng_seed, dtype=dtype, p=p, data=data)
|
||||
def test_basic(self, n_arrays, rng_seed, dtype, p, data, xp):
|
||||
mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays+1,
|
||||
min_side=0)
|
||||
input_shapes, result_shape = data.draw(mbs)
|
||||
cond_shape, *shapes = input_shapes
|
||||
fillvalue = xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=tuple())))
|
||||
arrays = [xp.asarray(data.draw(npst.arrays(dtype=dtype, shape=shape)))
|
||||
for shape in shapes]
|
||||
|
||||
def f(*args):
|
||||
return sum(arg for arg in args)
|
||||
|
||||
def f2(*args):
|
||||
return sum(arg for arg in args) / 2
|
||||
|
||||
rng = np.random.default_rng(rng_seed)
|
||||
cond = xp.asarray(rng.random(size=cond_shape) > p)
|
||||
|
||||
res1 = _lazywhere(cond, arrays, f, fillvalue)
|
||||
res2 = _lazywhere(cond, arrays, f, f2=f2)
|
||||
|
||||
# Ensure arrays are at least 1d to follow sane type promotion rules.
|
||||
if xp == np:
|
||||
cond, fillvalue, *arrays = np.atleast_1d(cond, fillvalue, *arrays)
|
||||
|
||||
ref1 = xp.where(cond, f(*arrays), fillvalue)
|
||||
ref2 = xp.where(cond, f(*arrays), f2(*arrays))
|
||||
|
||||
if xp == np:
|
||||
ref1 = ref1.reshape(result_shape)
|
||||
ref2 = ref2.reshape(result_shape)
|
||||
res1 = xp.asarray(res1)[()]
|
||||
res2 = xp.asarray(res2)[()]
|
||||
|
||||
isinstance(res1, type(xp.asarray([])))
|
||||
xp_assert_close(res1, ref1, rtol=2e-16)
|
||||
assert_equal(res1.shape, ref1.shape)
|
||||
assert_equal(res1.dtype, ref1.dtype)
|
||||
|
||||
isinstance(res2, type(xp.asarray([])))
|
||||
xp_assert_equal(res2, ref2)
|
||||
assert_equal(res2.shape, ref2.shape)
|
||||
assert_equal(res2.dtype, ref2.dtype)
|
||||
@ -0,0 +1,114 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scipy.conftest import array_api_compatible
|
||||
from scipy._lib._array_api import (
|
||||
_GLOBAL_CONFIG, array_namespace, _asarray, copy, xp_assert_equal, is_numpy
|
||||
)
|
||||
import scipy._lib.array_api_compat.numpy as np_compat
|
||||
|
||||
skip_xp_backends = pytest.mark.skip_xp_backends
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _GLOBAL_CONFIG["SCIPY_ARRAY_API"],
|
||||
reason="Array API test; set environment variable SCIPY_ARRAY_API=1 to run it")
|
||||
class TestArrayAPI:
|
||||
|
||||
def test_array_namespace(self):
|
||||
x, y = np.array([0, 1, 2]), np.array([0, 1, 2])
|
||||
xp = array_namespace(x, y)
|
||||
assert 'array_api_compat.numpy' in xp.__name__
|
||||
|
||||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = False
|
||||
xp = array_namespace(x, y)
|
||||
assert 'array_api_compat.numpy' in xp.__name__
|
||||
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True
|
||||
|
||||
@array_api_compatible
|
||||
def test_asarray(self, xp):
|
||||
x, y = _asarray([0, 1, 2], xp=xp), _asarray(np.arange(3), xp=xp)
|
||||
ref = xp.asarray([0, 1, 2])
|
||||
xp_assert_equal(x, ref)
|
||||
xp_assert_equal(y, ref)
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: the matrix subclass")
|
||||
def test_raises(self):
|
||||
msg = "of type `numpy.ma.MaskedArray` are not supported"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
array_namespace(np.ma.array(1), np.array(1))
|
||||
|
||||
msg = "of type `numpy.matrix` are not supported"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
array_namespace(np.array(1), np.matrix(1))
|
||||
|
||||
msg = "only boolean and numerical dtypes are supported"
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
array_namespace([object()])
|
||||
with pytest.raises(TypeError, match=msg):
|
||||
array_namespace('abc')
|
||||
|
||||
def test_array_likes(self):
|
||||
# should be no exceptions
|
||||
array_namespace([0, 1, 2])
|
||||
array_namespace(1, 2, 3)
|
||||
array_namespace(1)
|
||||
|
||||
@skip_xp_backends('jax.numpy',
|
||||
reasons=["JAX arrays do not support item assignment"])
|
||||
@pytest.mark.usefixtures("skip_xp_backends")
|
||||
@array_api_compatible
|
||||
def test_copy(self, xp):
|
||||
for _xp in [xp, None]:
|
||||
x = xp.asarray([1, 2, 3])
|
||||
y = copy(x, xp=_xp)
|
||||
# with numpy we'd want to use np.shared_memory, but that's not specified
|
||||
# in the array-api
|
||||
x[0] = 10
|
||||
x[1] = 11
|
||||
x[2] = 12
|
||||
|
||||
assert x[0] != y[0]
|
||||
assert x[1] != y[1]
|
||||
assert x[2] != y[2]
|
||||
assert id(x) != id(y)
|
||||
|
||||
@array_api_compatible
|
||||
@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float32', 'float64'])
|
||||
@pytest.mark.parametrize('shape', [(), (3,)])
|
||||
def test_strict_checks(self, xp, dtype, shape):
|
||||
# Check that `_strict_check` behaves as expected
|
||||
dtype = getattr(xp, dtype)
|
||||
x = xp.broadcast_to(xp.asarray(1, dtype=dtype), shape)
|
||||
x = x if shape else x[()]
|
||||
y = np_compat.asarray(1)[()]
|
||||
|
||||
options = dict(check_namespace=True, check_dtype=False, check_shape=False)
|
||||
if xp == np:
|
||||
xp_assert_equal(x, y, **options)
|
||||
else:
|
||||
with pytest.raises(AssertionError, match="Namespaces do not match."):
|
||||
xp_assert_equal(x, y, **options)
|
||||
|
||||
options = dict(check_namespace=False, check_dtype=True, check_shape=False)
|
||||
if y.dtype.name in str(x.dtype):
|
||||
xp_assert_equal(x, y, **options)
|
||||
else:
|
||||
with pytest.raises(AssertionError, match="dtypes do not match."):
|
||||
xp_assert_equal(x, y, **options)
|
||||
|
||||
options = dict(check_namespace=False, check_dtype=False, check_shape=True)
|
||||
if x.shape == y.shape:
|
||||
xp_assert_equal(x, y, **options)
|
||||
else:
|
||||
with pytest.raises(AssertionError, match="Shapes do not match."):
|
||||
xp_assert_equal(x, y, **options)
|
||||
|
||||
@array_api_compatible
|
||||
def test_check_scalar(self, xp):
|
||||
if not is_numpy(xp):
|
||||
pytest.skip("Scalars only exist in NumPy")
|
||||
|
||||
if is_numpy(xp):
|
||||
with pytest.raises(AssertionError, match="Types do not match."):
|
||||
xp_assert_equal(xp.asarray(0.), xp.float64(0))
|
||||
xp_assert_equal(xp.float64(0), xp.asarray(0.))
|
||||
162
venv/lib/python3.12/site-packages/scipy/_lib/tests/test_bunch.py
Normal file
162
venv/lib/python3.12/site-packages/scipy/_lib/tests/test_bunch.py
Normal file
@ -0,0 +1,162 @@
|
||||
import pytest
|
||||
import pickle
|
||||
from numpy.testing import assert_equal
|
||||
from scipy._lib._bunch import _make_tuple_bunch
|
||||
|
||||
|
||||
# `Result` is defined at the top level of the module so it can be
|
||||
# used to test pickling.
|
||||
Result = _make_tuple_bunch('Result', ['x', 'y', 'z'], ['w', 'beta'])
|
||||
|
||||
|
||||
class TestMakeTupleBunch:
|
||||
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# Tests with Result
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
def setup_method(self):
|
||||
# Set up an instance of Result.
|
||||
self.result = Result(x=1, y=2, z=3, w=99, beta=0.5)
|
||||
|
||||
def test_attribute_access(self):
|
||||
assert_equal(self.result.x, 1)
|
||||
assert_equal(self.result.y, 2)
|
||||
assert_equal(self.result.z, 3)
|
||||
assert_equal(self.result.w, 99)
|
||||
assert_equal(self.result.beta, 0.5)
|
||||
|
||||
def test_indexing(self):
|
||||
assert_equal(self.result[0], 1)
|
||||
assert_equal(self.result[1], 2)
|
||||
assert_equal(self.result[2], 3)
|
||||
assert_equal(self.result[-1], 3)
|
||||
with pytest.raises(IndexError, match='index out of range'):
|
||||
self.result[3]
|
||||
|
||||
def test_unpacking(self):
|
||||
x0, y0, z0 = self.result
|
||||
assert_equal((x0, y0, z0), (1, 2, 3))
|
||||
assert_equal(self.result, (1, 2, 3))
|
||||
|
||||
def test_slice(self):
|
||||
assert_equal(self.result[1:], (2, 3))
|
||||
assert_equal(self.result[::2], (1, 3))
|
||||
assert_equal(self.result[::-1], (3, 2, 1))
|
||||
|
||||
def test_len(self):
|
||||
assert_equal(len(self.result), 3)
|
||||
|
||||
def test_repr(self):
|
||||
s = repr(self.result)
|
||||
assert_equal(s, 'Result(x=1, y=2, z=3, w=99, beta=0.5)')
|
||||
|
||||
def test_hash(self):
|
||||
assert_equal(hash(self.result), hash((1, 2, 3)))
|
||||
|
||||
def test_pickle(self):
|
||||
s = pickle.dumps(self.result)
|
||||
obj = pickle.loads(s)
|
||||
assert isinstance(obj, Result)
|
||||
assert_equal(obj.x, self.result.x)
|
||||
assert_equal(obj.y, self.result.y)
|
||||
assert_equal(obj.z, self.result.z)
|
||||
assert_equal(obj.w, self.result.w)
|
||||
assert_equal(obj.beta, self.result.beta)
|
||||
|
||||
def test_read_only_existing(self):
|
||||
with pytest.raises(AttributeError, match="can't set attribute"):
|
||||
self.result.x = -1
|
||||
|
||||
def test_read_only_new(self):
|
||||
self.result.plate_of_shrimp = "lattice of coincidence"
|
||||
assert self.result.plate_of_shrimp == "lattice of coincidence"
|
||||
|
||||
def test_constructor_missing_parameter(self):
|
||||
with pytest.raises(TypeError, match='missing'):
|
||||
# `w` is missing.
|
||||
Result(x=1, y=2, z=3, beta=0.75)
|
||||
|
||||
def test_constructor_incorrect_parameter(self):
|
||||
with pytest.raises(TypeError, match='unexpected'):
|
||||
# `foo` is not an existing field.
|
||||
Result(x=1, y=2, z=3, w=123, beta=0.75, foo=999)
|
||||
|
||||
def test_module(self):
|
||||
m = 'scipy._lib.tests.test_bunch'
|
||||
assert_equal(Result.__module__, m)
|
||||
assert_equal(self.result.__module__, m)
|
||||
|
||||
def test_extra_fields_per_instance(self):
|
||||
# This test exists to ensure that instances of the same class
|
||||
# store their own values for the extra fields. That is, the values
|
||||
# are stored per instance and not in the class.
|
||||
result1 = Result(x=1, y=2, z=3, w=-1, beta=0.0)
|
||||
result2 = Result(x=4, y=5, z=6, w=99, beta=1.0)
|
||||
assert_equal(result1.w, -1)
|
||||
assert_equal(result1.beta, 0.0)
|
||||
# The rest of these checks aren't essential, but let's check
|
||||
# them anyway.
|
||||
assert_equal(result1[:], (1, 2, 3))
|
||||
assert_equal(result2.w, 99)
|
||||
assert_equal(result2.beta, 1.0)
|
||||
assert_equal(result2[:], (4, 5, 6))
|
||||
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# Other tests
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
def test_extra_field_names_is_optional(self):
|
||||
Square = _make_tuple_bunch('Square', ['width', 'height'])
|
||||
sq = Square(width=1, height=2)
|
||||
assert_equal(sq.width, 1)
|
||||
assert_equal(sq.height, 2)
|
||||
s = repr(sq)
|
||||
assert_equal(s, 'Square(width=1, height=2)')
|
||||
|
||||
def test_tuple_like(self):
|
||||
Tup = _make_tuple_bunch('Tup', ['a', 'b'])
|
||||
tu = Tup(a=1, b=2)
|
||||
assert isinstance(tu, tuple)
|
||||
assert isinstance(tu + (1,), tuple)
|
||||
|
||||
def test_explicit_module(self):
|
||||
m = 'some.module.name'
|
||||
Foo = _make_tuple_bunch('Foo', ['x'], ['a', 'b'], module=m)
|
||||
foo = Foo(x=1, a=355, b=113)
|
||||
assert_equal(Foo.__module__, m)
|
||||
assert_equal(foo.__module__, m)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# Argument validation
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
@pytest.mark.parametrize('args', [('123', ['a'], ['b']),
|
||||
('Foo', ['-3'], ['x']),
|
||||
('Foo', ['a'], ['+-*/'])])
|
||||
def test_identifiers_not_allowed(self, args):
|
||||
with pytest.raises(ValueError, match='identifiers'):
|
||||
_make_tuple_bunch(*args)
|
||||
|
||||
@pytest.mark.parametrize('args', [('Foo', ['a', 'b', 'a'], ['x']),
|
||||
('Foo', ['a', 'b'], ['b', 'x'])])
|
||||
def test_repeated_field_names(self, args):
|
||||
with pytest.raises(ValueError, match='Duplicate'):
|
||||
_make_tuple_bunch(*args)
|
||||
|
||||
@pytest.mark.parametrize('args', [('Foo', ['_a'], ['x']),
|
||||
('Foo', ['a'], ['_x'])])
|
||||
def test_leading_underscore_not_allowed(self, args):
|
||||
with pytest.raises(ValueError, match='underscore'):
|
||||
_make_tuple_bunch(*args)
|
||||
|
||||
@pytest.mark.parametrize('args', [('Foo', ['def'], ['x']),
|
||||
('Foo', ['a'], ['or']),
|
||||
('and', ['a'], ['x'])])
|
||||
def test_keyword_not_allowed_in_fields(self, args):
|
||||
with pytest.raises(ValueError, match='keyword'):
|
||||
_make_tuple_bunch(*args)
|
||||
|
||||
def test_at_least_one_field_name_required(self):
|
||||
with pytest.raises(ValueError, match='at least one name'):
|
||||
_make_tuple_bunch('Qwerty', [], ['a', 'b'])
|
||||
@ -0,0 +1,204 @@
|
||||
from numpy.testing import assert_equal, assert_
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
import time
|
||||
import pytest
|
||||
import ctypes
|
||||
import threading
|
||||
from scipy._lib import _ccallback_c as _test_ccallback_cython
|
||||
from scipy._lib import _test_ccallback
|
||||
from scipy._lib._ccallback import LowLevelCallable
|
||||
|
||||
try:
|
||||
import cffi
|
||||
HAVE_CFFI = True
|
||||
except ImportError:
|
||||
HAVE_CFFI = False
|
||||
|
||||
|
||||
ERROR_VALUE = 2.0
|
||||
|
||||
|
||||
def callback_python(a, user_data=None):
|
||||
if a == ERROR_VALUE:
|
||||
raise ValueError("bad value")
|
||||
|
||||
if user_data is None:
|
||||
return a + 1
|
||||
else:
|
||||
return a + user_data
|
||||
|
||||
def _get_cffi_func(base, signature):
|
||||
if not HAVE_CFFI:
|
||||
pytest.skip("cffi not installed")
|
||||
|
||||
# Get function address
|
||||
voidp = ctypes.cast(base, ctypes.c_void_p)
|
||||
address = voidp.value
|
||||
|
||||
# Create corresponding cffi handle
|
||||
ffi = cffi.FFI()
|
||||
func = ffi.cast(signature, address)
|
||||
return func
|
||||
|
||||
|
||||
def _get_ctypes_data():
|
||||
value = ctypes.c_double(2.0)
|
||||
return ctypes.cast(ctypes.pointer(value), ctypes.c_voidp)
|
||||
|
||||
|
||||
def _get_cffi_data():
|
||||
if not HAVE_CFFI:
|
||||
pytest.skip("cffi not installed")
|
||||
ffi = cffi.FFI()
|
||||
return ffi.new('double *', 2.0)
|
||||
|
||||
|
||||
CALLERS = {
|
||||
'simple': _test_ccallback.test_call_simple,
|
||||
'nodata': _test_ccallback.test_call_nodata,
|
||||
'nonlocal': _test_ccallback.test_call_nonlocal,
|
||||
'cython': _test_ccallback_cython.test_call_cython,
|
||||
}
|
||||
|
||||
# These functions have signatures known to the callers
|
||||
FUNCS = {
|
||||
'python': lambda: callback_python,
|
||||
'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
|
||||
'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
|
||||
"plus1_cython"),
|
||||
'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
|
||||
'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
|
||||
'double (*)(double, int *, void *)'),
|
||||
'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
|
||||
'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
|
||||
"plus1b_cython"),
|
||||
'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
|
||||
'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
|
||||
'double (*)(double, double, int *, void *)'),
|
||||
}
|
||||
|
||||
# These functions have signatures the callers don't know
|
||||
BAD_FUNCS = {
|
||||
'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
|
||||
'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython,
|
||||
"plus1bc_cython"),
|
||||
'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
|
||||
'cffi_bc': lambda: _get_cffi_func(
|
||||
_test_ccallback_cython.plus1bc_ctypes,
|
||||
'double (*)(double, double, double, int *, void *)'
|
||||
),
|
||||
}
|
||||
|
||||
USER_DATAS = {
|
||||
'ctypes': _get_ctypes_data,
|
||||
'cffi': _get_cffi_data,
|
||||
'capsule': _test_ccallback.test_get_data_capsule,
|
||||
}
|
||||
|
||||
|
||||
def test_callbacks():
|
||||
def check(caller, func, user_data):
|
||||
caller = CALLERS[caller]
|
||||
func = FUNCS[func]()
|
||||
user_data = USER_DATAS[user_data]()
|
||||
|
||||
if func is callback_python:
|
||||
def func2(x):
|
||||
return func(x, 2.0)
|
||||
else:
|
||||
func2 = LowLevelCallable(func, user_data)
|
||||
func = LowLevelCallable(func)
|
||||
|
||||
# Test basic call
|
||||
assert_equal(caller(func, 1.0), 2.0)
|
||||
|
||||
# Test 'bad' value resulting to an error
|
||||
assert_raises(ValueError, caller, func, ERROR_VALUE)
|
||||
|
||||
# Test passing in user_data
|
||||
assert_equal(caller(func2, 1.0), 3.0)
|
||||
|
||||
for caller in sorted(CALLERS.keys()):
|
||||
for func in sorted(FUNCS.keys()):
|
||||
for user_data in sorted(USER_DATAS.keys()):
|
||||
check(caller, func, user_data)
|
||||
|
||||
|
||||
def test_bad_callbacks():
|
||||
def check(caller, func, user_data):
|
||||
caller = CALLERS[caller]
|
||||
user_data = USER_DATAS[user_data]()
|
||||
func = BAD_FUNCS[func]()
|
||||
|
||||
if func is callback_python:
|
||||
def func2(x):
|
||||
return func(x, 2.0)
|
||||
else:
|
||||
func2 = LowLevelCallable(func, user_data)
|
||||
func = LowLevelCallable(func)
|
||||
|
||||
# Test that basic call fails
|
||||
assert_raises(ValueError, caller, LowLevelCallable(func), 1.0)
|
||||
|
||||
# Test that passing in user_data also fails
|
||||
assert_raises(ValueError, caller, func2, 1.0)
|
||||
|
||||
# Test error message
|
||||
llfunc = LowLevelCallable(func)
|
||||
try:
|
||||
caller(llfunc, 1.0)
|
||||
except ValueError as err:
|
||||
msg = str(err)
|
||||
assert_(llfunc.signature in msg, msg)
|
||||
assert_('double (double, double, int *, void *)' in msg, msg)
|
||||
|
||||
for caller in sorted(CALLERS.keys()):
|
||||
for func in sorted(BAD_FUNCS.keys()):
|
||||
for user_data in sorted(USER_DATAS.keys()):
|
||||
check(caller, func, user_data)
|
||||
|
||||
|
||||
def test_signature_override():
|
||||
caller = _test_ccallback.test_call_simple
|
||||
func = _test_ccallback.test_get_plus1_capsule()
|
||||
|
||||
llcallable = LowLevelCallable(func, signature="bad signature")
|
||||
assert_equal(llcallable.signature, "bad signature")
|
||||
assert_raises(ValueError, caller, llcallable, 3)
|
||||
|
||||
llcallable = LowLevelCallable(func, signature="double (double, int *, void *)")
|
||||
assert_equal(llcallable.signature, "double (double, int *, void *)")
|
||||
assert_equal(caller(llcallable, 3), 4)
|
||||
|
||||
|
||||
def test_threadsafety():
|
||||
def callback(a, caller):
|
||||
if a <= 0:
|
||||
return 1
|
||||
else:
|
||||
res = caller(lambda x: callback(x, caller), a - 1)
|
||||
return 2*res
|
||||
|
||||
def check(caller):
|
||||
caller = CALLERS[caller]
|
||||
|
||||
results = []
|
||||
|
||||
count = 10
|
||||
|
||||
def run():
|
||||
time.sleep(0.01)
|
||||
r = caller(lambda x: callback(x, caller), count)
|
||||
results.append(r)
|
||||
|
||||
threads = [threading.Thread(target=run) for j in range(20)]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert_equal(results, [2.0**count]*len(threads))
|
||||
|
||||
for caller in CALLERS.keys():
|
||||
check(caller)
|
||||
@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
|
||||
|
||||
def test_cython_api_deprecation():
|
||||
match = ("`scipy._lib._test_deprecation_def.foo_deprecated` "
|
||||
"is deprecated, use `foo` instead!\n"
|
||||
"Deprecated in Scipy 42.0.0")
|
||||
with pytest.warns(DeprecationWarning, match=match):
|
||||
from .. import _test_deprecation_call
|
||||
assert _test_deprecation_call.call() == (1, 1)
|
||||
@ -0,0 +1,17 @@
|
||||
import pytest
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
from .test_public_api import PUBLIC_MODULES
|
||||
|
||||
# Regression tests for gh-6793.
|
||||
# Check that all modules are importable in a new Python process.
|
||||
# This is not necessarily true if there are import cycles present.
|
||||
|
||||
@pytest.mark.fail_slow(20)
|
||||
@pytest.mark.slow
|
||||
def test_public_modules_importable():
|
||||
pids = [subprocess.Popen([sys.executable, '-c', f'import {module}'])
|
||||
for module in PUBLIC_MODULES]
|
||||
for i, pid in enumerate(pids):
|
||||
assert pid.wait() == 0, f'Failed to import {PUBLIC_MODULES[i]}'
|
||||
@ -0,0 +1,496 @@
|
||||
"""
|
||||
This test script is adopted from:
|
||||
https://github.com/numpy/numpy/blob/main/numpy/tests/test_public_api.py
|
||||
"""
|
||||
|
||||
import pkgutil
|
||||
import types
|
||||
import importlib
|
||||
import warnings
|
||||
from importlib import import_module
|
||||
|
||||
import pytest
|
||||
|
||||
import scipy
|
||||
|
||||
from scipy.conftest import xp_available_backends
|
||||
|
||||
|
||||
def test_dir_testing():
|
||||
"""Assert that output of dir has only one "testing/tester"
|
||||
attribute without duplicate"""
|
||||
assert len(dir(scipy)) == len(set(dir(scipy)))
|
||||
|
||||
|
||||
# Historically SciPy has not used leading underscores for private submodules
|
||||
# much. This has resulted in lots of things that look like public modules
|
||||
# (i.e. things that can be imported as `import scipy.somesubmodule.somefile`),
|
||||
# but were never intended to be public. The PUBLIC_MODULES list contains
|
||||
# modules that are either public because they were meant to be, or because they
|
||||
# contain public functions/objects that aren't present in any other namespace
|
||||
# for whatever reason and therefore should be treated as public.
|
||||
PUBLIC_MODULES = ["scipy." + s for s in [
|
||||
"cluster",
|
||||
"cluster.vq",
|
||||
"cluster.hierarchy",
|
||||
"constants",
|
||||
"datasets",
|
||||
"fft",
|
||||
"fftpack",
|
||||
"integrate",
|
||||
"interpolate",
|
||||
"io",
|
||||
"io.arff",
|
||||
"io.matlab",
|
||||
"io.wavfile",
|
||||
"linalg",
|
||||
"linalg.blas",
|
||||
"linalg.cython_blas",
|
||||
"linalg.lapack",
|
||||
"linalg.cython_lapack",
|
||||
"linalg.interpolative",
|
||||
"misc",
|
||||
"ndimage",
|
||||
"odr",
|
||||
"optimize",
|
||||
"signal",
|
||||
"signal.windows",
|
||||
"sparse",
|
||||
"sparse.linalg",
|
||||
"sparse.csgraph",
|
||||
"spatial",
|
||||
"spatial.distance",
|
||||
"spatial.transform",
|
||||
"special",
|
||||
"stats",
|
||||
"stats.contingency",
|
||||
"stats.distributions",
|
||||
"stats.mstats",
|
||||
"stats.qmc",
|
||||
"stats.sampling"
|
||||
]]
|
||||
|
||||
# The PRIVATE_BUT_PRESENT_MODULES list contains modules that lacked underscores
|
||||
# in their name and hence looked public, but weren't meant to be. All these
|
||||
# namespace were deprecated in the 1.8.0 release - see "clear split between
|
||||
# public and private API" in the 1.8.0 release notes.
|
||||
# These private modules support will be removed in SciPy v2.0.0, as the
|
||||
# deprecation messages emitted by each of these modules say.
|
||||
PRIVATE_BUT_PRESENT_MODULES = [
|
||||
'scipy.constants.codata',
|
||||
'scipy.constants.constants',
|
||||
'scipy.fftpack.basic',
|
||||
'scipy.fftpack.convolve',
|
||||
'scipy.fftpack.helper',
|
||||
'scipy.fftpack.pseudo_diffs',
|
||||
'scipy.fftpack.realtransforms',
|
||||
'scipy.integrate.dop',
|
||||
'scipy.integrate.lsoda',
|
||||
'scipy.integrate.odepack',
|
||||
'scipy.integrate.quadpack',
|
||||
'scipy.integrate.vode',
|
||||
'scipy.interpolate.dfitpack',
|
||||
'scipy.interpolate.fitpack',
|
||||
'scipy.interpolate.fitpack2',
|
||||
'scipy.interpolate.interpnd',
|
||||
'scipy.interpolate.interpolate',
|
||||
'scipy.interpolate.ndgriddata',
|
||||
'scipy.interpolate.polyint',
|
||||
'scipy.interpolate.rbf',
|
||||
'scipy.io.arff.arffread',
|
||||
'scipy.io.harwell_boeing',
|
||||
'scipy.io.idl',
|
||||
'scipy.io.matlab.byteordercodes',
|
||||
'scipy.io.matlab.mio',
|
||||
'scipy.io.matlab.mio4',
|
||||
'scipy.io.matlab.mio5',
|
||||
'scipy.io.matlab.mio5_params',
|
||||
'scipy.io.matlab.mio5_utils',
|
||||
'scipy.io.matlab.mio_utils',
|
||||
'scipy.io.matlab.miobase',
|
||||
'scipy.io.matlab.streams',
|
||||
'scipy.io.mmio',
|
||||
'scipy.io.netcdf',
|
||||
'scipy.linalg.basic',
|
||||
'scipy.linalg.decomp',
|
||||
'scipy.linalg.decomp_cholesky',
|
||||
'scipy.linalg.decomp_lu',
|
||||
'scipy.linalg.decomp_qr',
|
||||
'scipy.linalg.decomp_schur',
|
||||
'scipy.linalg.decomp_svd',
|
||||
'scipy.linalg.matfuncs',
|
||||
'scipy.linalg.misc',
|
||||
'scipy.linalg.special_matrices',
|
||||
'scipy.misc.common',
|
||||
'scipy.misc.doccer',
|
||||
'scipy.ndimage.filters',
|
||||
'scipy.ndimage.fourier',
|
||||
'scipy.ndimage.interpolation',
|
||||
'scipy.ndimage.measurements',
|
||||
'scipy.ndimage.morphology',
|
||||
'scipy.odr.models',
|
||||
'scipy.odr.odrpack',
|
||||
'scipy.optimize.cobyla',
|
||||
'scipy.optimize.cython_optimize',
|
||||
'scipy.optimize.lbfgsb',
|
||||
'scipy.optimize.linesearch',
|
||||
'scipy.optimize.minpack',
|
||||
'scipy.optimize.minpack2',
|
||||
'scipy.optimize.moduleTNC',
|
||||
'scipy.optimize.nonlin',
|
||||
'scipy.optimize.optimize',
|
||||
'scipy.optimize.slsqp',
|
||||
'scipy.optimize.tnc',
|
||||
'scipy.optimize.zeros',
|
||||
'scipy.signal.bsplines',
|
||||
'scipy.signal.filter_design',
|
||||
'scipy.signal.fir_filter_design',
|
||||
'scipy.signal.lti_conversion',
|
||||
'scipy.signal.ltisys',
|
||||
'scipy.signal.signaltools',
|
||||
'scipy.signal.spectral',
|
||||
'scipy.signal.spline',
|
||||
'scipy.signal.waveforms',
|
||||
'scipy.signal.wavelets',
|
||||
'scipy.signal.windows.windows',
|
||||
'scipy.sparse.base',
|
||||
'scipy.sparse.bsr',
|
||||
'scipy.sparse.compressed',
|
||||
'scipy.sparse.construct',
|
||||
'scipy.sparse.coo',
|
||||
'scipy.sparse.csc',
|
||||
'scipy.sparse.csr',
|
||||
'scipy.sparse.data',
|
||||
'scipy.sparse.dia',
|
||||
'scipy.sparse.dok',
|
||||
'scipy.sparse.extract',
|
||||
'scipy.sparse.lil',
|
||||
'scipy.sparse.linalg.dsolve',
|
||||
'scipy.sparse.linalg.eigen',
|
||||
'scipy.sparse.linalg.interface',
|
||||
'scipy.sparse.linalg.isolve',
|
||||
'scipy.sparse.linalg.matfuncs',
|
||||
'scipy.sparse.sparsetools',
|
||||
'scipy.sparse.spfuncs',
|
||||
'scipy.sparse.sputils',
|
||||
'scipy.spatial.ckdtree',
|
||||
'scipy.spatial.kdtree',
|
||||
'scipy.spatial.qhull',
|
||||
'scipy.spatial.transform.rotation',
|
||||
'scipy.special.add_newdocs',
|
||||
'scipy.special.basic',
|
||||
'scipy.special.cython_special',
|
||||
'scipy.special.orthogonal',
|
||||
'scipy.special.sf_error',
|
||||
'scipy.special.specfun',
|
||||
'scipy.special.spfun_stats',
|
||||
'scipy.stats.biasedurn',
|
||||
'scipy.stats.kde',
|
||||
'scipy.stats.morestats',
|
||||
'scipy.stats.mstats_basic',
|
||||
'scipy.stats.mstats_extras',
|
||||
'scipy.stats.mvn',
|
||||
'scipy.stats.stats',
|
||||
]
|
||||
|
||||
|
||||
def is_unexpected(name):
|
||||
"""Check if this needs to be considered."""
|
||||
if '._' in name or '.tests' in name or '.setup' in name:
|
||||
return False
|
||||
|
||||
if name in PUBLIC_MODULES:
|
||||
return False
|
||||
|
||||
if name in PRIVATE_BUT_PRESENT_MODULES:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
SKIP_LIST = [
|
||||
'scipy.conftest',
|
||||
'scipy.version',
|
||||
'scipy.special.libsf_error_state'
|
||||
]
|
||||
|
||||
|
||||
# XXX: this test does more than it says on the tin - in using `pkgutil.walk_packages`,
|
||||
# it will raise if it encounters any exceptions which are not handled by `ignore_errors`
|
||||
# while attempting to import each discovered package.
|
||||
# For now, `ignore_errors` only ignores what is necessary, but this could be expanded -
|
||||
# for example, to all errors from private modules or git subpackages - if desired.
|
||||
def test_all_modules_are_expected():
|
||||
"""
|
||||
Test that we don't add anything that looks like a new public module by
|
||||
accident. Check is based on filenames.
|
||||
"""
|
||||
|
||||
def ignore_errors(name):
|
||||
# if versions of other array libraries are installed which are incompatible
|
||||
# with the installed NumPy version, there can be errors on importing
|
||||
# `array_api_compat`. This should only raise if SciPy is configured with
|
||||
# that library as an available backend.
|
||||
backends = {'cupy': 'cupy',
|
||||
'pytorch': 'torch',
|
||||
'dask.array': 'dask.array'}
|
||||
for backend, dir_name in backends.items():
|
||||
path = f'array_api_compat.{dir_name}'
|
||||
if path in name and backend not in xp_available_backends:
|
||||
return
|
||||
raise
|
||||
|
||||
modnames = []
|
||||
|
||||
for _, modname, _ in pkgutil.walk_packages(path=scipy.__path__,
|
||||
prefix=scipy.__name__ + '.',
|
||||
onerror=ignore_errors):
|
||||
if is_unexpected(modname) and modname not in SKIP_LIST:
|
||||
# We have a name that is new. If that's on purpose, add it to
|
||||
# PUBLIC_MODULES. We don't expect to have to add anything to
|
||||
# PRIVATE_BUT_PRESENT_MODULES. Use an underscore in the name!
|
||||
modnames.append(modname)
|
||||
|
||||
if modnames:
|
||||
raise AssertionError(f'Found unexpected modules: {modnames}')
|
||||
|
||||
|
||||
# Stuff that clearly shouldn't be in the API and is detected by the next test
|
||||
# below
|
||||
SKIP_LIST_2 = [
|
||||
'scipy.char',
|
||||
'scipy.rec',
|
||||
'scipy.emath',
|
||||
'scipy.math',
|
||||
'scipy.random',
|
||||
'scipy.ctypeslib',
|
||||
'scipy.ma'
|
||||
]
|
||||
|
||||
|
||||
def test_all_modules_are_expected_2():
|
||||
"""
|
||||
Method checking all objects. The pkgutil-based method in
|
||||
`test_all_modules_are_expected` does not catch imports into a namespace,
|
||||
only filenames.
|
||||
"""
|
||||
|
||||
def find_unexpected_members(mod_name):
|
||||
members = []
|
||||
module = importlib.import_module(mod_name)
|
||||
if hasattr(module, '__all__'):
|
||||
objnames = module.__all__
|
||||
else:
|
||||
objnames = dir(module)
|
||||
|
||||
for objname in objnames:
|
||||
if not objname.startswith('_'):
|
||||
fullobjname = mod_name + '.' + objname
|
||||
if isinstance(getattr(module, objname), types.ModuleType):
|
||||
if is_unexpected(fullobjname) and fullobjname not in SKIP_LIST_2:
|
||||
members.append(fullobjname)
|
||||
|
||||
return members
|
||||
|
||||
unexpected_members = find_unexpected_members("scipy")
|
||||
for modname in PUBLIC_MODULES:
|
||||
unexpected_members.extend(find_unexpected_members(modname))
|
||||
|
||||
if unexpected_members:
|
||||
raise AssertionError("Found unexpected object(s) that look like "
|
||||
f"modules: {unexpected_members}")
|
||||
|
||||
|
||||
def test_api_importable():
|
||||
"""
|
||||
Check that all submodules listed higher up in this file can be imported
|
||||
Note that if a PRIVATE_BUT_PRESENT_MODULES entry goes missing, it may
|
||||
simply need to be removed from the list (deprecation may or may not be
|
||||
needed - apply common sense).
|
||||
"""
|
||||
def check_importable(module_name):
|
||||
try:
|
||||
importlib.import_module(module_name)
|
||||
except (ImportError, AttributeError):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
module_names = []
|
||||
for module_name in PUBLIC_MODULES:
|
||||
if not check_importable(module_name):
|
||||
module_names.append(module_name)
|
||||
|
||||
if module_names:
|
||||
raise AssertionError("Modules in the public API that cannot be "
|
||||
f"imported: {module_names}")
|
||||
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.filterwarnings('always', category=DeprecationWarning)
|
||||
warnings.filterwarnings('always', category=ImportWarning)
|
||||
for module_name in PRIVATE_BUT_PRESENT_MODULES:
|
||||
if not check_importable(module_name):
|
||||
module_names.append(module_name)
|
||||
|
||||
if module_names:
|
||||
raise AssertionError("Modules that are not really public but looked "
|
||||
"public and can not be imported: "
|
||||
f"{module_names}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("module_name", "correct_module"),
|
||||
[('scipy.constants.codata', None),
|
||||
('scipy.constants.constants', None),
|
||||
('scipy.fftpack.basic', None),
|
||||
('scipy.fftpack.helper', None),
|
||||
('scipy.fftpack.pseudo_diffs', None),
|
||||
('scipy.fftpack.realtransforms', None),
|
||||
('scipy.integrate.dop', None),
|
||||
('scipy.integrate.lsoda', None),
|
||||
('scipy.integrate.odepack', None),
|
||||
('scipy.integrate.quadpack', None),
|
||||
('scipy.integrate.vode', None),
|
||||
('scipy.interpolate.fitpack', None),
|
||||
('scipy.interpolate.fitpack2', None),
|
||||
('scipy.interpolate.interpolate', None),
|
||||
('scipy.interpolate.ndgriddata', None),
|
||||
('scipy.interpolate.polyint', None),
|
||||
('scipy.interpolate.rbf', None),
|
||||
('scipy.io.harwell_boeing', None),
|
||||
('scipy.io.idl', None),
|
||||
('scipy.io.mmio', None),
|
||||
('scipy.io.netcdf', None),
|
||||
('scipy.io.arff.arffread', 'arff'),
|
||||
('scipy.io.matlab.byteordercodes', 'matlab'),
|
||||
('scipy.io.matlab.mio_utils', 'matlab'),
|
||||
('scipy.io.matlab.mio', 'matlab'),
|
||||
('scipy.io.matlab.mio4', 'matlab'),
|
||||
('scipy.io.matlab.mio5_params', 'matlab'),
|
||||
('scipy.io.matlab.mio5_utils', 'matlab'),
|
||||
('scipy.io.matlab.mio5', 'matlab'),
|
||||
('scipy.io.matlab.miobase', 'matlab'),
|
||||
('scipy.io.matlab.streams', 'matlab'),
|
||||
('scipy.linalg.basic', None),
|
||||
('scipy.linalg.decomp', None),
|
||||
('scipy.linalg.decomp_cholesky', None),
|
||||
('scipy.linalg.decomp_lu', None),
|
||||
('scipy.linalg.decomp_qr', None),
|
||||
('scipy.linalg.decomp_schur', None),
|
||||
('scipy.linalg.decomp_svd', None),
|
||||
('scipy.linalg.matfuncs', None),
|
||||
('scipy.linalg.misc', None),
|
||||
('scipy.linalg.special_matrices', None),
|
||||
('scipy.misc.common', None),
|
||||
('scipy.ndimage.filters', None),
|
||||
('scipy.ndimage.fourier', None),
|
||||
('scipy.ndimage.interpolation', None),
|
||||
('scipy.ndimage.measurements', None),
|
||||
('scipy.ndimage.morphology', None),
|
||||
('scipy.odr.models', None),
|
||||
('scipy.odr.odrpack', None),
|
||||
('scipy.optimize.cobyla', None),
|
||||
('scipy.optimize.lbfgsb', None),
|
||||
('scipy.optimize.linesearch', None),
|
||||
('scipy.optimize.minpack', None),
|
||||
('scipy.optimize.minpack2', None),
|
||||
('scipy.optimize.moduleTNC', None),
|
||||
('scipy.optimize.nonlin', None),
|
||||
('scipy.optimize.optimize', None),
|
||||
('scipy.optimize.slsqp', None),
|
||||
('scipy.optimize.tnc', None),
|
||||
('scipy.optimize.zeros', None),
|
||||
('scipy.signal.bsplines', None),
|
||||
('scipy.signal.filter_design', None),
|
||||
('scipy.signal.fir_filter_design', None),
|
||||
('scipy.signal.lti_conversion', None),
|
||||
('scipy.signal.ltisys', None),
|
||||
('scipy.signal.signaltools', None),
|
||||
('scipy.signal.spectral', None),
|
||||
('scipy.signal.waveforms', None),
|
||||
('scipy.signal.wavelets', None),
|
||||
('scipy.signal.windows.windows', 'windows'),
|
||||
('scipy.sparse.lil', None),
|
||||
('scipy.sparse.linalg.dsolve', 'linalg'),
|
||||
('scipy.sparse.linalg.eigen', 'linalg'),
|
||||
('scipy.sparse.linalg.interface', 'linalg'),
|
||||
('scipy.sparse.linalg.isolve', 'linalg'),
|
||||
('scipy.sparse.linalg.matfuncs', 'linalg'),
|
||||
('scipy.sparse.sparsetools', None),
|
||||
('scipy.sparse.spfuncs', None),
|
||||
('scipy.sparse.sputils', None),
|
||||
('scipy.spatial.ckdtree', None),
|
||||
('scipy.spatial.kdtree', None),
|
||||
('scipy.spatial.qhull', None),
|
||||
('scipy.spatial.transform.rotation', 'transform'),
|
||||
('scipy.special.add_newdocs', None),
|
||||
('scipy.special.basic', None),
|
||||
('scipy.special.orthogonal', None),
|
||||
('scipy.special.sf_error', None),
|
||||
('scipy.special.specfun', None),
|
||||
('scipy.special.spfun_stats', None),
|
||||
('scipy.stats.biasedurn', None),
|
||||
('scipy.stats.kde', None),
|
||||
('scipy.stats.morestats', None),
|
||||
('scipy.stats.mstats_basic', 'mstats'),
|
||||
('scipy.stats.mstats_extras', 'mstats'),
|
||||
('scipy.stats.mvn', None),
|
||||
('scipy.stats.stats', None)])
|
||||
def test_private_but_present_deprecation(module_name, correct_module):
|
||||
# gh-18279, gh-17572, gh-17771 noted that deprecation warnings
|
||||
# for imports from private modules
|
||||
# were misleading. Check that this is resolved.
|
||||
module = import_module(module_name)
|
||||
if correct_module is None:
|
||||
import_name = f'scipy.{module_name.split(".")[1]}'
|
||||
else:
|
||||
import_name = f'scipy.{module_name.split(".")[1]}.{correct_module}'
|
||||
|
||||
correct_import = import_module(import_name)
|
||||
|
||||
# Attributes that were formerly in `module_name` can still be imported from
|
||||
# `module_name`, albeit with a deprecation warning.
|
||||
for attr_name in module.__all__:
|
||||
if attr_name == "varmats_from_mat":
|
||||
# defer handling this case, see
|
||||
# https://github.com/scipy/scipy/issues/19223
|
||||
continue
|
||||
# ensure attribute is present where the warning is pointing
|
||||
assert getattr(correct_import, attr_name, None) is not None
|
||||
message = f"Please import `{attr_name}` from the `{import_name}`..."
|
||||
with pytest.deprecated_call(match=message):
|
||||
getattr(module, attr_name)
|
||||
|
||||
# Attributes that were not in `module_name` get an error notifying the user
|
||||
# that the attribute is not in `module_name` and that `module_name` is deprecated.
|
||||
message = f"`{module_name}` is deprecated..."
|
||||
with pytest.raises(AttributeError, match=message):
|
||||
getattr(module, "ekki")
|
||||
|
||||
|
||||
def test_misc_doccer_deprecation():
|
||||
# gh-18279, gh-17572, gh-17771 noted that deprecation warnings
|
||||
# for imports from private modules were misleading.
|
||||
# Check that this is resolved.
|
||||
# `test_private_but_present_deprecation` cannot be used since `correct_import`
|
||||
# is a different subpackage (`_lib` instead of `misc`).
|
||||
module = import_module('scipy.misc.doccer')
|
||||
correct_import = import_module('scipy._lib.doccer')
|
||||
|
||||
# Attributes that were formerly in `scipy.misc.doccer` can still be imported from
|
||||
# `scipy.misc.doccer`, albeit with a deprecation warning. The specific message
|
||||
# depends on whether the attribute is in `scipy._lib.doccer` or not.
|
||||
for attr_name in module.__all__:
|
||||
attr = getattr(correct_import, attr_name, None)
|
||||
if attr is None:
|
||||
message = f"`scipy.misc.{attr_name}` is deprecated..."
|
||||
else:
|
||||
message = f"Please import `{attr_name}` from the `scipy._lib.doccer`..."
|
||||
with pytest.deprecated_call(match=message):
|
||||
getattr(module, attr_name)
|
||||
|
||||
# Attributes that were not in `scipy.misc.doccer` get an error
|
||||
# notifying the user that the attribute is not in `scipy.misc.doccer`
|
||||
# and that `scipy.misc.doccer` is deprecated.
|
||||
message = "`scipy.misc.doccer` is deprecated..."
|
||||
with pytest.raises(AttributeError, match=message):
|
||||
getattr(module, "ekki")
|
||||
@ -0,0 +1,18 @@
|
||||
import re
|
||||
|
||||
import scipy
|
||||
from numpy.testing import assert_
|
||||
|
||||
|
||||
def test_valid_scipy_version():
|
||||
# Verify that the SciPy version is a valid one (no .post suffix or other
|
||||
# nonsense). See NumPy issue gh-6431 for an issue caused by an invalid
|
||||
# version.
|
||||
version_pattern = r"^[0-9]+\.[0-9]+\.[0-9]+(|a[0-9]|b[0-9]|rc[0-9])"
|
||||
dev_suffix = r"(\.dev0\+.+([0-9a-f]{7}|Unknown))"
|
||||
if scipy.version.release:
|
||||
res = re.match(version_pattern, scipy.__version__)
|
||||
else:
|
||||
res = re.match(version_pattern + dev_suffix, scipy.__version__)
|
||||
|
||||
assert_(res is not None, scipy.__version__)
|
||||
@ -0,0 +1,42 @@
|
||||
""" Test tmpdirs module """
|
||||
from os import getcwd
|
||||
from os.path import realpath, abspath, dirname, isfile, join as pjoin, exists
|
||||
|
||||
from scipy._lib._tmpdirs import tempdir, in_tempdir, in_dir
|
||||
|
||||
from numpy.testing import assert_, assert_equal
|
||||
|
||||
MY_PATH = abspath(__file__)
|
||||
MY_DIR = dirname(MY_PATH)
|
||||
|
||||
|
||||
def test_tempdir():
|
||||
with tempdir() as tmpdir:
|
||||
fname = pjoin(tmpdir, 'example_file.txt')
|
||||
with open(fname, "w") as fobj:
|
||||
fobj.write('a string\\n')
|
||||
assert_(not exists(tmpdir))
|
||||
|
||||
|
||||
def test_in_tempdir():
|
||||
my_cwd = getcwd()
|
||||
with in_tempdir() as tmpdir:
|
||||
with open('test.txt', "w") as f:
|
||||
f.write('some text')
|
||||
assert_(isfile('test.txt'))
|
||||
assert_(isfile(pjoin(tmpdir, 'test.txt')))
|
||||
assert_(not exists(tmpdir))
|
||||
assert_equal(getcwd(), my_cwd)
|
||||
|
||||
|
||||
def test_given_directory():
|
||||
# Test InGivenDirectory
|
||||
cwd = getcwd()
|
||||
with in_dir() as tmpdir:
|
||||
assert_equal(tmpdir, abspath(cwd))
|
||||
assert_equal(tmpdir, abspath(getcwd()))
|
||||
with in_dir(MY_DIR) as tmpdir:
|
||||
assert_equal(tmpdir, MY_DIR)
|
||||
assert_equal(realpath(MY_DIR), realpath(abspath(getcwd())))
|
||||
# We were deleting the given directory! Check not so now.
|
||||
assert_(isfile(MY_PATH))
|
||||
@ -0,0 +1,135 @@
|
||||
"""
|
||||
Tests which scan for certain occurrences in the code, they may not find
|
||||
all of these occurrences but should catch almost all. This file was adapted
|
||||
from NumPy.
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import ast
|
||||
import tokenize
|
||||
|
||||
import scipy
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class ParseCall(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.ls = []
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
self.ls.append(node.attr)
|
||||
|
||||
def visit_Name(self, node):
|
||||
self.ls.append(node.id)
|
||||
|
||||
|
||||
class FindFuncs(ast.NodeVisitor):
|
||||
def __init__(self, filename):
|
||||
super().__init__()
|
||||
self.__filename = filename
|
||||
self.bad_filters = []
|
||||
self.bad_stacklevels = []
|
||||
|
||||
def visit_Call(self, node):
|
||||
p = ParseCall()
|
||||
p.visit(node.func)
|
||||
ast.NodeVisitor.generic_visit(self, node)
|
||||
|
||||
if p.ls[-1] == 'simplefilter' or p.ls[-1] == 'filterwarnings':
|
||||
# get first argument of the `args` node of the filter call
|
||||
match node.args[0]:
|
||||
case ast.Constant() as c:
|
||||
argtext = c.value
|
||||
case ast.JoinedStr() as js:
|
||||
# if we get an f-string, discard the templated pieces, which
|
||||
# are likely the type or specific message; we're interested
|
||||
# in the action, which is less likely to use a template
|
||||
argtext = "".join(
|
||||
x.value for x in js.values if isinstance(x, ast.Constant)
|
||||
)
|
||||
case _:
|
||||
raise ValueError("unknown ast node type")
|
||||
# check if filter is set to ignore
|
||||
if argtext == "ignore":
|
||||
self.bad_filters.append(
|
||||
f"{self.__filename}:{node.lineno}")
|
||||
|
||||
if p.ls[-1] == 'warn' and (
|
||||
len(p.ls) == 1 or p.ls[-2] == 'warnings'):
|
||||
|
||||
if self.__filename == "_lib/tests/test_warnings.py":
|
||||
# This file
|
||||
return
|
||||
|
||||
# See if stacklevel exists:
|
||||
if len(node.args) == 3:
|
||||
return
|
||||
args = {kw.arg for kw in node.keywords}
|
||||
if "stacklevel" not in args:
|
||||
self.bad_stacklevels.append(
|
||||
f"{self.__filename}:{node.lineno}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def warning_calls():
|
||||
# combined "ignore" and stacklevel error
|
||||
base = Path(scipy.__file__).parent
|
||||
|
||||
bad_filters = []
|
||||
bad_stacklevels = []
|
||||
|
||||
for path in base.rglob("*.py"):
|
||||
# use tokenize to auto-detect encoding on systems where no
|
||||
# default encoding is defined (e.g., LANG='C')
|
||||
with tokenize.open(str(path)) as file:
|
||||
tree = ast.parse(file.read(), filename=str(path))
|
||||
finder = FindFuncs(path.relative_to(base))
|
||||
finder.visit(tree)
|
||||
bad_filters.extend(finder.bad_filters)
|
||||
bad_stacklevels.extend(finder.bad_stacklevels)
|
||||
|
||||
return bad_filters, bad_stacklevels
|
||||
|
||||
|
||||
@pytest.mark.fail_slow(20)
|
||||
@pytest.mark.slow
|
||||
def test_warning_calls_filters(warning_calls):
|
||||
bad_filters, bad_stacklevels = warning_calls
|
||||
|
||||
# We try not to add filters in the code base, because those filters aren't
|
||||
# thread-safe. We aim to only filter in tests with
|
||||
# np.testing.suppress_warnings. However, in some cases it may prove
|
||||
# necessary to filter out warnings, because we can't (easily) fix the root
|
||||
# cause for them and we don't want users to see some warnings when they use
|
||||
# SciPy correctly. So we list exceptions here. Add new entries only if
|
||||
# there's a good reason.
|
||||
allowed_filters = (
|
||||
os.path.join('datasets', '_fetchers.py'),
|
||||
os.path.join('datasets', '__init__.py'),
|
||||
os.path.join('optimize', '_optimize.py'),
|
||||
os.path.join('optimize', '_constraints.py'),
|
||||
os.path.join('optimize', '_nnls.py'),
|
||||
os.path.join('signal', '_ltisys.py'),
|
||||
os.path.join('sparse', '__init__.py'), # np.matrix pending-deprecation
|
||||
os.path.join('stats', '_discrete_distns.py'), # gh-14901
|
||||
os.path.join('stats', '_continuous_distns.py'),
|
||||
os.path.join('stats', '_binned_statistic.py'), # gh-19345
|
||||
os.path.join('stats', 'tests', 'test_axis_nan_policy.py'), # gh-20694
|
||||
os.path.join('_lib', '_util.py'), # gh-19341
|
||||
os.path.join('sparse', 'linalg', '_dsolve', 'linsolve.py'), # gh-17924
|
||||
"conftest.py",
|
||||
)
|
||||
bad_filters = [item for item in bad_filters if item.split(':')[0] not in
|
||||
allowed_filters]
|
||||
|
||||
if bad_filters:
|
||||
raise AssertionError(
|
||||
"warning ignore filter should not be used, instead, use\n"
|
||||
"numpy.testing.suppress_warnings (in tests only);\n"
|
||||
"found in:\n {}".format(
|
||||
"\n ".join(bad_filters)))
|
||||
|
||||
Reference in New Issue
Block a user