asd
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,460 @@
|
||||
import collections
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
|
||||
from scipy._lib._util import MapWrapper
|
||||
|
||||
|
||||
class VertexBase(ABC):
|
||||
"""
|
||||
Base class for a vertex.
|
||||
"""
|
||||
def __init__(self, x, nn=None, index=None):
|
||||
"""
|
||||
Initiation of a vertex object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : tuple or vector
|
||||
The geometric location (domain).
|
||||
nn : list, optional
|
||||
Nearest neighbour list.
|
||||
index : int, optional
|
||||
Index of vertex.
|
||||
"""
|
||||
self.x = x
|
||||
self.hash = hash(self.x) # Save precomputed hash
|
||||
|
||||
if nn is not None:
|
||||
self.nn = set(nn) # can use .indexupdate to add a new list
|
||||
else:
|
||||
self.nn = set()
|
||||
|
||||
self.index = index
|
||||
|
||||
def __hash__(self):
|
||||
return self.hash
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item not in ['x_a']:
|
||||
raise AttributeError(f"{type(self)} object has no attribute "
|
||||
f"'{item}'")
|
||||
if item == 'x_a':
|
||||
self.x_a = np.array(self.x)
|
||||
return self.x_a
|
||||
|
||||
@abstractmethod
|
||||
def connect(self, v):
|
||||
raise NotImplementedError("This method is only implemented with an "
|
||||
"associated child of the base class.")
|
||||
|
||||
@abstractmethod
|
||||
def disconnect(self, v):
|
||||
raise NotImplementedError("This method is only implemented with an "
|
||||
"associated child of the base class.")
|
||||
|
||||
def star(self):
|
||||
"""Returns the star domain ``st(v)`` of the vertex.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
v :
|
||||
The vertex ``v`` in ``st(v)``
|
||||
|
||||
Returns
|
||||
-------
|
||||
st : set
|
||||
A set containing all the vertices in ``st(v)``
|
||||
"""
|
||||
self.st = self.nn
|
||||
self.st.add(self)
|
||||
return self.st
|
||||
|
||||
|
||||
class VertexScalarField(VertexBase):
|
||||
"""
|
||||
Add homology properties of a scalar field f: R^n --> R associated with
|
||||
the geometry built from the VertexBase class
|
||||
"""
|
||||
|
||||
def __init__(self, x, field=None, nn=None, index=None, field_args=(),
|
||||
g_cons=None, g_cons_args=()):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : tuple,
|
||||
vector of vertex coordinates
|
||||
field : callable, optional
|
||||
a scalar field f: R^n --> R associated with the geometry
|
||||
nn : list, optional
|
||||
list of nearest neighbours
|
||||
index : int, optional
|
||||
index of the vertex
|
||||
field_args : tuple, optional
|
||||
additional arguments to be passed to field
|
||||
g_cons : callable, optional
|
||||
constraints on the vertex
|
||||
g_cons_args : tuple, optional
|
||||
additional arguments to be passed to g_cons
|
||||
|
||||
"""
|
||||
super().__init__(x, nn=nn, index=index)
|
||||
|
||||
# Note Vertex is only initiated once for all x so only
|
||||
# evaluated once
|
||||
# self.feasible = None
|
||||
|
||||
# self.f is externally defined by the cache to allow parallel
|
||||
# processing
|
||||
# None type that will break arithmetic operations unless defined
|
||||
# self.f = None
|
||||
|
||||
self.check_min = True
|
||||
self.check_max = True
|
||||
|
||||
def connect(self, v):
|
||||
"""Connects self to another vertex object v.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
v : VertexBase or VertexScalarField object
|
||||
"""
|
||||
if v is not self and v not in self.nn:
|
||||
self.nn.add(v)
|
||||
v.nn.add(self)
|
||||
|
||||
# Flags for checking homology properties:
|
||||
self.check_min = True
|
||||
self.check_max = True
|
||||
v.check_min = True
|
||||
v.check_max = True
|
||||
|
||||
def disconnect(self, v):
|
||||
if v in self.nn:
|
||||
self.nn.remove(v)
|
||||
v.nn.remove(self)
|
||||
|
||||
# Flags for checking homology properties:
|
||||
self.check_min = True
|
||||
self.check_max = True
|
||||
v.check_min = True
|
||||
v.check_max = True
|
||||
|
||||
def minimiser(self):
|
||||
"""Check whether this vertex is strictly less than all its
|
||||
neighbours"""
|
||||
if self.check_min:
|
||||
self._min = all(self.f < v.f for v in self.nn)
|
||||
self.check_min = False
|
||||
|
||||
return self._min
|
||||
|
||||
def maximiser(self):
|
||||
"""
|
||||
Check whether this vertex is strictly greater than all its
|
||||
neighbours.
|
||||
"""
|
||||
if self.check_max:
|
||||
self._max = all(self.f > v.f for v in self.nn)
|
||||
self.check_max = False
|
||||
|
||||
return self._max
|
||||
|
||||
|
||||
class VertexVectorField(VertexBase):
|
||||
"""
|
||||
Add homology properties of a scalar field f: R^n --> R^m associated with
|
||||
the geometry built from the VertexBase class.
|
||||
"""
|
||||
|
||||
def __init__(self, x, sfield=None, vfield=None, field_args=(),
|
||||
vfield_args=(), g_cons=None,
|
||||
g_cons_args=(), nn=None, index=None):
|
||||
super().__init__(x, nn=nn, index=index)
|
||||
|
||||
raise NotImplementedError("This class is still a work in progress")
|
||||
|
||||
|
||||
class VertexCacheBase:
|
||||
"""Base class for a vertex cache for a simplicial complex."""
|
||||
def __init__(self):
|
||||
|
||||
self.cache = collections.OrderedDict()
|
||||
self.nfev = 0 # Feasible points
|
||||
self.index = -1
|
||||
|
||||
def __iter__(self):
|
||||
for v in self.cache:
|
||||
yield self.cache[v]
|
||||
return
|
||||
|
||||
def size(self):
|
||||
"""Returns the size of the vertex cache."""
|
||||
return self.index + 1
|
||||
|
||||
def print_out(self):
|
||||
headlen = len(f"Vertex cache of size: {len(self.cache)}:")
|
||||
print('=' * headlen)
|
||||
print(f"Vertex cache of size: {len(self.cache)}:")
|
||||
print('=' * headlen)
|
||||
for v in self.cache:
|
||||
self.cache[v].print_out()
|
||||
|
||||
|
||||
class VertexCube(VertexBase):
|
||||
"""Vertex class to be used for a pure simplicial complex with no associated
|
||||
differential geometry (single level domain that exists in R^n)"""
|
||||
def __init__(self, x, nn=None, index=None):
|
||||
super().__init__(x, nn=nn, index=index)
|
||||
|
||||
def connect(self, v):
|
||||
if v is not self and v not in self.nn:
|
||||
self.nn.add(v)
|
||||
v.nn.add(self)
|
||||
|
||||
def disconnect(self, v):
|
||||
if v in self.nn:
|
||||
self.nn.remove(v)
|
||||
v.nn.remove(self)
|
||||
|
||||
|
||||
class VertexCacheIndex(VertexCacheBase):
|
||||
def __init__(self):
|
||||
"""
|
||||
Class for a vertex cache for a simplicial complex without an associated
|
||||
field. Useful only for building and visualising a domain complex.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
"""
|
||||
super().__init__()
|
||||
self.Vertex = VertexCube
|
||||
|
||||
def __getitem__(self, x, nn=None):
|
||||
try:
|
||||
return self.cache[x]
|
||||
except KeyError:
|
||||
self.index += 1
|
||||
xval = self.Vertex(x, index=self.index)
|
||||
# logging.info("New generated vertex at x = {}".format(x))
|
||||
# NOTE: Surprisingly high performance increase if logging
|
||||
# is commented out
|
||||
self.cache[x] = xval
|
||||
return self.cache[x]
|
||||
|
||||
|
||||
class VertexCacheField(VertexCacheBase):
|
||||
def __init__(self, field=None, field_args=(), g_cons=None, g_cons_args=(),
|
||||
workers=1):
|
||||
"""
|
||||
Class for a vertex cache for a simplicial complex with an associated
|
||||
field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
field : callable
|
||||
Scalar or vector field callable.
|
||||
field_args : tuple, optional
|
||||
Any additional fixed parameters needed to completely specify the
|
||||
field function
|
||||
g_cons : dict or sequence of dict, optional
|
||||
Constraints definition.
|
||||
Function(s) ``R**n`` in the form::
|
||||
g_cons_args : tuple, optional
|
||||
Any additional fixed parameters needed to completely specify the
|
||||
constraint functions
|
||||
workers : int optional
|
||||
Uses `multiprocessing.Pool <multiprocessing>`) to compute the field
|
||||
functions in parallel.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.index = -1
|
||||
self.Vertex = VertexScalarField
|
||||
self.field = field
|
||||
self.field_args = field_args
|
||||
self.wfield = FieldWrapper(field, field_args) # if workers is not 1
|
||||
|
||||
self.g_cons = g_cons
|
||||
self.g_cons_args = g_cons_args
|
||||
self.wgcons = ConstraintWrapper(g_cons, g_cons_args)
|
||||
self.gpool = set() # A set of tuples to process for feasibility
|
||||
|
||||
# Field processing objects
|
||||
self.fpool = set() # A set of tuples to process for scalar function
|
||||
self.sfc_lock = False # True if self.fpool is non-Empty
|
||||
|
||||
self.workers = workers
|
||||
self._mapwrapper = MapWrapper(workers)
|
||||
|
||||
if workers == 1:
|
||||
self.process_gpool = self.proc_gpool
|
||||
if g_cons is None:
|
||||
self.process_fpool = self.proc_fpool_nog
|
||||
else:
|
||||
self.process_fpool = self.proc_fpool_g
|
||||
else:
|
||||
self.process_gpool = self.pproc_gpool
|
||||
if g_cons is None:
|
||||
self.process_fpool = self.pproc_fpool_nog
|
||||
else:
|
||||
self.process_fpool = self.pproc_fpool_g
|
||||
|
||||
def __getitem__(self, x, nn=None):
|
||||
try:
|
||||
return self.cache[x]
|
||||
except KeyError:
|
||||
self.index += 1
|
||||
xval = self.Vertex(x, field=self.field, nn=nn, index=self.index,
|
||||
field_args=self.field_args,
|
||||
g_cons=self.g_cons,
|
||||
g_cons_args=self.g_cons_args)
|
||||
|
||||
self.cache[x] = xval # Define in cache
|
||||
self.gpool.add(xval) # Add to pool for processing feasibility
|
||||
self.fpool.add(xval) # Add to pool for processing field values
|
||||
return self.cache[x]
|
||||
|
||||
def __getstate__(self):
|
||||
self_dict = self.__dict__.copy()
|
||||
del self_dict['pool']
|
||||
return self_dict
|
||||
|
||||
def process_pools(self):
|
||||
if self.g_cons is not None:
|
||||
self.process_gpool()
|
||||
self.process_fpool()
|
||||
self.proc_minimisers()
|
||||
|
||||
def feasibility_check(self, v):
|
||||
v.feasible = True
|
||||
for g, args in zip(self.g_cons, self.g_cons_args):
|
||||
# constraint may return more than 1 value.
|
||||
if np.any(g(v.x_a, *args) < 0.0):
|
||||
v.f = np.inf
|
||||
v.feasible = False
|
||||
break
|
||||
|
||||
def compute_sfield(self, v):
|
||||
"""Compute the scalar field values of a vertex object `v`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
v : VertexBase or VertexScalarField object
|
||||
"""
|
||||
try:
|
||||
v.f = self.field(v.x_a, *self.field_args)
|
||||
self.nfev += 1
|
||||
except AttributeError:
|
||||
v.f = np.inf
|
||||
# logging.warning(f"Field function not found at x = {self.x_a}")
|
||||
if np.isnan(v.f):
|
||||
v.f = np.inf
|
||||
|
||||
def proc_gpool(self):
|
||||
"""Process all constraints."""
|
||||
if self.g_cons is not None:
|
||||
for v in self.gpool:
|
||||
self.feasibility_check(v)
|
||||
# Clean the pool
|
||||
self.gpool = set()
|
||||
|
||||
def pproc_gpool(self):
|
||||
"""Process all constraints in parallel."""
|
||||
gpool_l = []
|
||||
for v in self.gpool:
|
||||
gpool_l.append(v.x_a)
|
||||
|
||||
G = self._mapwrapper(self.wgcons.gcons, gpool_l)
|
||||
for v, g in zip(self.gpool, G):
|
||||
v.feasible = g # set vertex object attribute v.feasible = g (bool)
|
||||
|
||||
def proc_fpool_g(self):
|
||||
"""Process all field functions with constraints supplied."""
|
||||
for v in self.fpool:
|
||||
if v.feasible:
|
||||
self.compute_sfield(v)
|
||||
# Clean the pool
|
||||
self.fpool = set()
|
||||
|
||||
def proc_fpool_nog(self):
|
||||
"""Process all field functions with no constraints supplied."""
|
||||
for v in self.fpool:
|
||||
self.compute_sfield(v)
|
||||
# Clean the pool
|
||||
self.fpool = set()
|
||||
|
||||
def pproc_fpool_g(self):
|
||||
"""
|
||||
Process all field functions with constraints supplied in parallel.
|
||||
"""
|
||||
self.wfield.func
|
||||
fpool_l = []
|
||||
for v in self.fpool:
|
||||
if v.feasible:
|
||||
fpool_l.append(v.x_a)
|
||||
else:
|
||||
v.f = np.inf
|
||||
F = self._mapwrapper(self.wfield.func, fpool_l)
|
||||
for va, f in zip(fpool_l, F):
|
||||
vt = tuple(va)
|
||||
self[vt].f = f # set vertex object attribute v.f = f
|
||||
self.nfev += 1
|
||||
# Clean the pool
|
||||
self.fpool = set()
|
||||
|
||||
def pproc_fpool_nog(self):
|
||||
"""
|
||||
Process all field functions with no constraints supplied in parallel.
|
||||
"""
|
||||
self.wfield.func
|
||||
fpool_l = []
|
||||
for v in self.fpool:
|
||||
fpool_l.append(v.x_a)
|
||||
F = self._mapwrapper(self.wfield.func, fpool_l)
|
||||
for va, f in zip(fpool_l, F):
|
||||
vt = tuple(va)
|
||||
self[vt].f = f # set vertex object attribute v.f = f
|
||||
self.nfev += 1
|
||||
# Clean the pool
|
||||
self.fpool = set()
|
||||
|
||||
def proc_minimisers(self):
|
||||
"""Check for minimisers."""
|
||||
for v in self:
|
||||
v.minimiser()
|
||||
v.maximiser()
|
||||
|
||||
|
||||
class ConstraintWrapper:
|
||||
"""Object to wrap constraints to pass to `multiprocessing.Pool`."""
|
||||
def __init__(self, g_cons, g_cons_args):
|
||||
self.g_cons = g_cons
|
||||
self.g_cons_args = g_cons_args
|
||||
|
||||
def gcons(self, v_x_a):
|
||||
vfeasible = True
|
||||
for g, args in zip(self.g_cons, self.g_cons_args):
|
||||
# constraint may return more than 1 value.
|
||||
if np.any(g(v_x_a, *args) < 0.0):
|
||||
vfeasible = False
|
||||
break
|
||||
return vfeasible
|
||||
|
||||
|
||||
class FieldWrapper:
|
||||
"""Object to wrap field to pass to `multiprocessing.Pool`."""
|
||||
def __init__(self, field, field_args):
|
||||
self.field = field
|
||||
self.field_args = field_args
|
||||
|
||||
def func(self, v_x_a):
|
||||
try:
|
||||
v_f = self.field(v_x_a, *self.field_args)
|
||||
except Exception:
|
||||
v_f = np.inf
|
||||
if np.isnan(v_f):
|
||||
v_f = np.inf
|
||||
|
||||
return v_f
|
||||
Reference in New Issue
Block a user