Source code for colder.simulation.numerical


from typing import Callable, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.interpolate



class interpolator1D:
    """Cubic-spline interpolator for one-dimensional sample data.

    Wraps `scipy.interpolate.CubicSpline` and exposes the interpolated
    function and its derivatives as plain callables.

    Attributes:
        x_data: Sample x values used to build the spline.
        y_data: Sample y values used to build the spline.
        x_range: Tuple `(min(x_data), max(x_data))`.
        n_interp_points: Number of samples in `x_data`.
        spline: The underlying `scipy.interpolate.CubicSpline` object.
    """

    def __init__(self, x_data: np.ndarray, y_data: np.ndarray, bc_type: str = 'not-a-knot'):
        """
        Initialize a 1D spline interpolator from input sample data.

        Args:
            x_data (np.ndarray): Sample data, x.
            y_data (np.ndarray): Sample data, y.
            bc_type (str, optional): Passed to `scipy.interpolate.CubicSpline`.
                Defaults to 'not-a-knot'.

        Raises:
            AssertionError: If `x_data` and `y_data` differ in length.
        """
        # NOTE bc_type = 'clamped' should be optimal for driving terms
        self.x_data = x_data
        self.y_data = y_data

        # Explicit raise (instead of `assert`) so the check survives `python -O`;
        # AssertionError is kept so existing callers see the same exception type.
        if len(self.x_data) != len(self.y_data):
            raise AssertionError('input data must be of same len')

        self.x_range = (min(self.x_data), max(self.x_data))
        self.n_interp_points = len(self.x_data)

        self.spline = scipy.interpolate.CubicSpline(self.x_data, self.y_data, bc_type=bc_type)

    @classmethod
    def from_function(cls, function: Callable, x_range: tuple, n_interp_points: int = 50,
                      fargs: Optional[dict] = None, *args, **kwargs) -> "interpolator1D":
        """
        Initialize a 1D spline interpolator from input function and n points
        sampled uniformly in range `x_range`.

        Args:
            function (Callable): Function to interpolate. It is called as
                `function(x_data, **fargs)` with the full array of sample points.
            x_range (tuple): Range of x data to be sampled; only the first two
                entries (start, stop) are used.
            n_interp_points (int, optional): How many x data points to sample.
                Defaults to 50.
            fargs (dict, optional): Keyword dictionary passed to `function`.
                Defaults to None, meaning no extra keyword arguments.
            *args: Extra positional arguments forwarded to the constructor.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            interpolator1D: Interpolator built on the sampled points.
        """
        # None sentinel instead of a shared mutable `{}` default.
        if fargs is None:
            fargs = {}

        x_data = np.linspace(x_range[0], x_range[1], n_interp_points)
        y_data = function(x_data, **fargs)
        return cls(x_data, y_data, *args, **kwargs)

    def get_function(self) -> Callable:
        """
        Returns a lambda function calling the spline function. The function
        first and only argument will be the x point.

        Returns:
            Callable: Lambda function wrapping the spline (interpolated) function.
        """
        return lambda x: self.spline(x)

    def get_derivative(self, ord: int = 1) -> Callable:
        """
        Returns a lambda function calling the spline function for derivative
        of order `ord`.

        Args:
            ord (int, optional): Order of the derivative. Defaults to 1.

        Returns:
            Callable: Lambda function evaluating the `ord`-th derivative of the
            spline at its single argument x.

        Raises:
            AssertionError: If `ord` is not strictly positive.
        """
        # Explicit raise so the guard is not stripped under `python -O`; without
        # it, ord=0 would silently return the function value itself.
        if not ord > 0:
            raise AssertionError('must take derivative of order greater than zero')
        return lambda x: self.spline(x, ord)

    def plot(self, plot_derivative: bool = True, plot_data: bool = False,
             labels: Optional[List[str]] = None, plt_obj=None):
        """
        Plot the interpolated functions.

        The spline is evaluated on `20 * n_interp_points` evenly spaced points
        spanning `x_range`.

        Args:
            plot_derivative (bool, optional): If true, the first derivative is
                plotted as well. Defaults to True.
            plot_data (bool, optional): If true, plots the original points
                before interpolation. Defaults to False.
            labels (List[str], optional): Labels for the curve and its
                derivative. Defaults to ["f", "df"].
            plt_obj (optional): Object exposing a `plot` method to draw on,
                used instead of `matplotlib.pyplot`. Defaults to None.

        Returns:
            The object that was drawn on: `plt_obj` if given, otherwise the
            `matplotlib.pyplot` module.
        """
        # None sentinel instead of a shared mutable list default.
        if labels is None:
            labels = ["f", "df"]
        if plt_obj is None:
            plt_obj = plt

        xs = np.linspace(self.x_range[0], self.x_range[1], self.n_interp_points * 20)

        plt_obj.plot(xs, self.spline(xs), label=labels[0])
        if plot_derivative:
            plt_obj.plot(xs, self.spline(xs, 1), label=labels[1])
        if plot_data:
            plt_obj.plot(self.x_data, self.y_data, 'x', label='spline data')

        # Return the object actually drawn on, not unconditionally the global pyplot.
        return plt_obj