import numpy as np


def check_arguments(fun, y0, support_complex):
    """Helper function for checking arguments common to all solvers.

    Parameters
    ----------
    fun : callable
        Right-hand side ``fun(t, y)`` of the system.
    y0 : array_like
        Initial state. Must be 1-D with all components finite.
    support_complex : bool
        Whether the calling solver supports a complex-valued state.

    Returns
    -------
    fun_wrapped : callable
        ``fun`` wrapped so that its output is always converted to an ndarray
        of the state dtype (``complex`` if `y0` is complex, ``float``
        otherwise).
    y0 : ndarray, shape (n,)
        `y0` converted to that dtype.

    Raises
    ------
    ValueError
        If `y0` is complex but `support_complex` is False, if `y0` is not
        1-dimensional, or if any component of `y0` is not finite.
    """
    y0 = np.asarray(y0)
    if np.issubdtype(y0.dtype, np.complexfloating):
        if not support_complex:
            raise ValueError("`y0` is complex, but the chosen solver does "
                             "not support integration in a complex domain.")
        dtype = complex
    else:
        # Anything non-complex (including integer input) is integrated as float.
        dtype = float
    y0 = y0.astype(dtype, copy=False)

    if y0.ndim != 1:
        raise ValueError("`y0` must be 1-dimensional.")

    if not np.isfinite(y0).all():
        raise ValueError("All components of the initial state `y0` "
                         "must be finite.")

    def fun_wrapped(t, y):
        # Force the user's rhs output to the state dtype so solvers can rely
        # on a consistent array type.
        return np.asarray(fun(t, y), dtype=dtype)

    return fun_wrapped, y0


class OdeSolver:
    """Base class for ODE solvers.

    In order to implement a new solver you need to follow the guidelines:

        1. A constructor must accept parameters presented in the base class
           (listed below) along with any other parameters specific to a
           solver.
        2. A constructor must accept arbitrary extraneous arguments
           ``**extraneous``, but warn that these arguments are irrelevant
           using `common.warn_extraneous` function. Do not pass these
           arguments to the base class.
        3. A solver must implement a private method `_step_impl(self)` which
           propagates a solver one step further. It must return tuple
           ``(success, message)``, where ``success`` is a boolean indicating
           whether a step was successful, and ``message`` is a string
           containing description of a failure if a step failed or None
           otherwise.
        4. A solver must implement a private method `_dense_output_impl(self)`,
           which returns a `DenseOutput` object covering the last successful
           step.
        5. A solver must have attributes listed below in Attributes section.
           Note that ``t_old`` and ``step_size`` are updated automatically.
        6. Use `fun(self, t, y)` method for the system rhs evaluation, this
           way the number of function evaluations (`nfev`) will be tracked
           automatically.
        7. For convenience, a base class provides `fun_single(self, t, y)` and
           `fun_vectorized(self, t, y)` for evaluating the rhs in
           non-vectorized and vectorized fashions respectively (regardless of
           how `fun` from the constructor is implemented). These calls don't
           increment `nfev`.
        8. If a solver uses a Jacobian matrix and LU decompositions, it should
           track the number of Jacobian evaluations (`njev`) and the number of
           LU decompositions (`nlu`).
        9. By convention, the function evaluations used to compute a finite
           difference approximation of the Jacobian should not be counted in
           `nfev`, thus use `fun_single(self, t, y)` or
           `fun_vectorized(self, t, y)` when computing a finite difference
           approximation of the Jacobian.

    Parameters
    ----------
    fun : callable
        Right-hand side of the system: the time derivative of the state ``y``
        at time ``t``. The calling signature is ``fun(t, y)``, where ``t`` is a
        scalar and ``y`` is an ndarray with ``len(y) = len(y0)``. ``fun`` must
        return an array of the same shape as ``y``. See `vectorized` for more
        information.
    t0 : float
        Initial time.
    y0 : array_like, shape (n,)
        Initial state.
    t_bound : float
        Boundary time --- the integration won't continue beyond it. It also
        determines the direction of the integration.
    vectorized : bool
        Whether `fun` can be called in a vectorized fashion. Default is False.

        If ``vectorized`` is False, `fun` will always be called with ``y`` of
        shape ``(n,)``, where ``n = len(y0)``.

        If ``vectorized`` is True, `fun` may be called with ``y`` of shape
        ``(n, k)``, where ``k`` is an integer. In this case, `fun` must behave
        such that ``fun(t, y)[:, i] == fun(t, y[:, i])`` (i.e. each column of
        the returned array is the time derivative of the state corresponding
        with a column of ``y``).

        Setting ``vectorized=True`` allows for faster finite difference
        approximation of the Jacobian by methods 'Radau' and 'BDF', but
        will result in slower execution for other methods. It can also
        result in slower overall execution for 'Radau' and 'BDF' in some
        circumstances (e.g. small ``len(y0)``).
    support_complex : bool, optional
        Whether integration in a complex domain should be supported.
        Generally determined by a derived solver class capabilities.
        Default is False.

    Attributes
    ----------
    n : int
        Number of equations.
    status : string
        Current status of the solver: 'running', 'finished' or 'failed'.
    t_bound : float
        Boundary time.
    direction : float
        Integration direction: +1 or -1.
    t : float
        Current time.
    y : ndarray
        Current state.
    t_old : float
        Previous time. None if no steps were made yet.
    step_size : float
        Size of the last successful step. None if no steps were made yet.
    nfev : int
        Number of the system's rhs evaluations.
    njev : int
        Number of the Jacobian evaluations.
    nlu : int
        Number of LU decompositions.
    """
    TOO_SMALL_STEP = "Required step size is less than spacing between numbers."

    def __init__(self, fun, t0, y0, t_bound, vectorized,
                 support_complex=False):
        self.t_old = None
        self.t = t0
        self._fun, self.y = check_arguments(fun, y0, support_complex)
        self.t_bound = t_bound
        self.vectorized = vectorized

        # Build both calling conventions from whichever one the user supplied.
        if vectorized:
            def fun_single(t, y):
                # Evaluate a single state as a one-column batch.
                return self._fun(t, y[:, None]).ravel()
            fun_vectorized = self._fun
        else:
            fun_single = self._fun

            def fun_vectorized(t, y):
                # Evaluate each column of ``y`` separately.
                f = np.empty_like(y)
                for i, yi in enumerate(y.T):
                    f[:, i] = self._fun(t, yi)
                return f

        def fun(t, y):
            # Counted evaluation; fun_single/fun_vectorized are not counted.
            self.nfev += 1
            return self.fun_single(t, y)

        self.fun = fun
        self.fun_single = fun_single
        self.fun_vectorized = fun_vectorized

        # With t_bound == t0 the direction is arbitrary; +1 is used.
        self.direction = np.sign(t_bound - t0) if t_bound != t0 else 1
        self.n = self.y.size
        self.status = 'running'

        self.nfev = 0
        self.njev = 0
        self.nlu = 0

    @property
    def step_size(self):
        """Absolute size of the last successful step, or None before any step."""
        if self.t_old is None:
            return None
        else:
            return np.abs(self.t - self.t_old)

    def step(self):
        """Perform one integration step.

        Returns
        -------
        message : string or None
            Report from the solver. Typically a reason for a failure if
            `self.status` is 'failed' after the step was taken or None
            otherwise.

        Raises
        ------
        RuntimeError
            If the solver is no longer 'running'.
        """
        if self.status != 'running':
            raise RuntimeError("Attempt to step on a failed or finished "
                               "solver.")

        if self.n == 0 or self.t == self.t_bound:
            # Handle corner cases of empty solver or no integration.
            self.t_old = self.t
            self.t = self.t_bound
            message = None
            self.status = 'finished'
        else:
            t = self.t
            success, message = self._step_impl()

            if not success:
                self.status = 'failed'
            else:
                self.t_old = t
                # Finished once t has reached or passed t_bound in the
                # direction of integration.
                if self.direction * (self.t - self.t_bound) >= 0:
                    self.status = 'finished'

        return message

    def dense_output(self):
        """Compute a local interpolant over the last successful step.

        Returns
        -------
        sol : `DenseOutput`
            Local interpolant over the last successful step.

        Raises
        ------
        RuntimeError
            If no step has been made yet.
        """
        if self.t_old is None:
            raise RuntimeError("Dense output is available after a successful "
                               "step was made.")

        if self.n == 0 or self.t == self.t_old:
            # Handle corner cases of empty solver and no integration.
            return ConstantDenseOutput(self.t_old, self.t, self.y)
        else:
            return self._dense_output_impl()

    def _step_impl(self):
        raise NotImplementedError

    def _dense_output_impl(self):
        raise NotImplementedError


class DenseOutput:
    """Base class for local interpolant over step made by an ODE solver.

    It interpolates between `t_min` and `t_max` (see Attributes below).
    Evaluation outside this interval is not forbidden, but the accuracy is not
    guaranteed.

    Attributes
    ----------
    t_min, t_max : float
        Time range of the interpolation.
    """
    def __init__(self, t_old, t):
        self.t_old = t_old
        self.t = t
        self.t_min = min(t, t_old)
        self.t_max = max(t, t_old)

    def __call__(self, t):
        """Evaluate the interpolant.

        Parameters
        ----------
        t : float or array_like with shape (n_points,)
            Points to evaluate the solution at.

        Returns
        -------
        y : ndarray, shape (n,) or (n, n_points)
            Computed values. Shape depends on whether `t` was a scalar or a
            1-D array.

        Raises
        ------
        ValueError
            If `t` has more than one dimension.
        """
        t = np.asarray(t)
        if t.ndim > 1:
            raise ValueError("`t` must be a float or a 1-D array.")
        return self._call_impl(t)

    def _call_impl(self, t):
        raise NotImplementedError


class ConstantDenseOutput(DenseOutput):
    """Constant value interpolator.

    This class is used for degenerate integration cases: equal integration
    limits or a system with 0 equations.
    """
    def __init__(self, t_old, t, value):
        super().__init__(t_old, t)
        self.value = value

    def _call_impl(self, t):
        if t.ndim == 0:
            return self.value
        else:
            # Promote to at least float64 (as before) but keep complex
            # values complex; a plain float64 buffer would silently drop
            # the imaginary part of a complex state.
            dtype = np.result_type(self.value.dtype, np.float64)
            ret = np.empty((self.value.shape[0], t.shape[0]), dtype=dtype)
            ret[:] = self.value[:, None]
            return ret