"""
Version 2 of the ODESolver class hierarchy. This class works for
systems of ODEs and for a single (scalar) ODE.
"""

import numpy as np


class ODESolver:
    """Base class for explicit one-step ODE solvers.

    Typical use: construct a subclass with the right-hand side ``f(u, t)``,
    call ``set_initial_condition``, then call ``solve`` with the time points.
    Subclasses supply the numerical scheme by implementing ``advance``.
    """

    def __init__(self, f):
        """Store the right-hand side function ``f(u, t)``."""
        # Wrap user's f in a new function that always
        # converts list/tuple to array (or let array be array)
        self.f = lambda u, t: np.asarray(f(u, t), float)

    def advance(self):
        """Return the solution at ``self.t[self.n + 1]``.

        Must be implemented by subclasses; ``solve`` calls it once per
        time step after setting ``self.n``.
        """
        raise NotImplementedError(
            f'{self.__class__.__name__} must implement advance()')

    def set_initial_condition(self, U0):
        """Store the initial value ``U0`` and the number of equations.

        A Python float or int is treated as a scalar ODE; anything else
        is converted to an array and treated as a system.
        """
        if isinstance(U0, (float, int)):  # scalar ODE
            self.neq = 1                  # no of equations
            U0 = float(U0)
        else:                             # system of ODEs
            U0 = np.asarray(U0)
            self.neq = U0.size            # no of equations
        self.U0 = U0

    def solve(self, time_points):
        """Compute the solution at every entry of ``time_points``.

        ``set_initial_condition`` must have been called first. Returns the
        tuple ``(u, t)`` where ``t`` is ``time_points`` as an array and
        ``u`` has one row (or one entry, for a scalar ODE) per time point.
        """
        self.t = np.asarray(time_points)
        N = len(self.t)
        if self.neq == 1:  # scalar ODEs
            self.u = np.zeros(N)
        else:              # systems of ODEs
            self.u = np.zeros((N, self.neq))

        # Assume that self.t[0] corresponds to self.U0
        self.u[0] = self.U0

        # Time loop: advance() reads self.n to know which step to take
        for n in range(N - 1):
            self.n = n
            self.u[n + 1] = self.advance()
        return self.u, self.t


class ForwardEuler(ODESolver):
    """Forward Euler scheme: one evaluation of f per step."""

    def advance(self):
        """Return u at the next time point using a Forward Euler step."""
        u, f, n, t = self.u, self.f, self.n, self.t

        dt = t[n + 1] - t[n]
        unew = u[n] + dt * f(u[n], t[n])
        return unew


class ExplicitMidpoint(ODESolver):
    """Explicit midpoint scheme: two evaluations of f per step."""

    def advance(self):
        """Return u at the next time point using an explicit midpoint step."""
        u, f, n, t = self.u, self.f, self.n, self.t

        dt = t[n + 1] - t[n]
        dt2 = dt / 2.0
        k1 = f(u[n], t[n])
        k2 = f(u[n] + dt2 * k1, t[n] + dt2)
        unew = u[n] + dt * k2
        return unew


class RungeKutta4(ODESolver):
    """Four-stage Runge-Kutta scheme: four evaluations of f per step."""

    def advance(self):
        """Return u at the next time point using a four-stage RK step."""
        u, f, n, t = self.u, self.f, self.n, self.t

        dt = t[n + 1] - t[n]
        dt2 = dt / 2.0
        k1 = f(u[n], t[n])
        k2 = f(u[n] + dt2 * k1, t[n] + dt2)
        k3 = f(u[n] + dt2 * k2, t[n] + dt2)
        k4 = f(u[n] + dt * k3, t[n] + dt)
        unew = u[n] + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        return unew


registered_solver_classes = [
    ForwardEuler, ExplicitMidpoint, RungeKutta4]


def test_exact_numerical_solution():
    """
    Test the different methods for a problem
    where the analytical solution is known and linear.
    All the methods should be exact to machine precision
    for this choice.
    """
    a = 0.2
    b = 3

    def f(u, t):
        return a + (u - u_exact(t))**5

    def u_exact(t):
        """Exact u(t) corresponding to f above."""
        return a * t + b

    U0 = u_exact(0)
    T = 8
    N = 10
    # Rounding errors of either sign are counted below, and the spacing of
    # doubles near u ~ 4.6 is already ~9E-16, so allow a few ulps of drift.
    tol = 1E-14
    t_points = np.linspace(0, T, N)
    for solver_class in registered_solver_classes:
        solver = solver_class(f)
        solver.set_initial_condition(U0)
        u, t = solver.solve(t_points)
        u_e = u_exact(t)
        # Use the absolute error: a signed maximum is 0 (the error at t[0])
        # whenever the numerical solution overshoots, hiding the failure.
        max_error = np.abs(u_e - u).max()
        msg = f'{solver.__class__.__name__} failed with max_error={max_error}'
        assert max_error < tol, msg


if __name__ == '__main__':
    test_exact_numerical_solution()