diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index 0ab7dfef7..be5ae2fc0 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -10,6 +10,10 @@ "html_dir": ".asv/html", "show_commit_url": "https://github.com/$OWNER/$REPO/commit/", "build_command": [ - "python -m build --outdir {build_cache_dir} {build_dir}" + "PIP_NO_BUILD_ISOLATION=0 python -m build", + "PIP_NO_BUILD_ISOLATION=false python -m pip wheel --no-deps -w {build_cache_dir} {build_dir}" + ], + "install_command": [ + "in-dir={env_dir} python -m pip install {wheel_file}[bench] --force-reinstall" ] } diff --git a/asv_bench/benchmarks/inspect_to_sympy.py b/asv_bench/benchmarks/inspect_to_sympy.py new file mode 100644 index 000000000..8ec305d36 --- /dev/null +++ b/asv_bench/benchmarks/inspect_to_sympy.py @@ -0,0 +1,455 @@ +"""Utilities to convert a `dysts` dynamical system object's rhs to SymPy. + +This module inspects the source of an object's RHS method (by default +named ``rhs``), parses the function using ``ast``, and converts the +returned expression(s) into SymPy expressions. + +The conversion is intentionally conservative and aims to handle common +patterns used in simple rhs implementations, e.g. returning a tuple/list +of arithmetic expressions, using indexing into a state vector (``x[0]``), +and calls to common ``numpy``/``math`` functions (``np.sin``, ``math.exp``, ...). + +Limitations: +- It does not execute arbitrary code from the inspected function. +- Complex control flow, loops, or non-trivial Python constructs may not + be fully supported. 
+ +Example +------- +from dysts.flows import Lorenz +from inspect_to_sympy import object_to_sympy_rhs + +lor = Lorenz() +symbols, exprs, lambda_rhs = object_to_sympy_rhs(lor) +# `symbols` is a list of SymPy symbols for the state vector +# `exprs` is a list of SymPy expressions for the RHS +# `lambda_rhs` is a SymPy Lambda mapping state symbols -> rhs expressions +""" +from __future__ import annotations + +import ast +import inspect +import textwrap +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Tuple + +import numpy as np +import sympy as sp +from dysts.base import BaseDyn + + +def _is_name(node: ast.AST, name: str) -> bool: + return isinstance(node, ast.Name) and node.id == name + + +class _ASTToSympy(ast.NodeVisitor): + def __init__( + self, + state_name: str, + state_symbols: List[sp.Symbol], + locals_map: Dict[str, Any], + ): + self.state_name = state_name + self.state_symbols = state_symbols + self.locals = dict(locals_map) + + def generic_visit(self, node): + raise NotImplementedError(f"AST node not supported: {node!r}") + + def visit_Constant(self, node: ast.Constant): + return sp.sympify(node.value) + + def visit_Num(self, node: ast.Num): + return sp.sympify(node.n) + + def visit_Name(self, node: ast.Name): + if node.id in self.locals: + return self.locals[node.id] + return sp.Symbol(node.id) + + def visit_Tuple(self, node: ast.Tuple): + elems = [] + for elt in node.elts: + val = self.visit(elt) + if isinstance(val, (list, tuple)): + elems.extend(list(val)) + else: + elems.append(val) + return tuple(elems) + + def visit_List(self, node: ast.List): + elems = [] + for elt in node.elts: + val = self.visit(elt) + if isinstance(val, (list, tuple)): + elems.extend(list(val)) + else: + elems.append(val) + return elems + + def visit_Starred(self, node: ast.Starred): + # Handle starred expressions like `*x` in list/tuple literals. 
+ # If the starred value is the state vector name, expand to state symbols. + if isinstance(node.value, ast.Name) and node.value.id == self.state_name: + return tuple(self.state_symbols) + # Otherwise, evaluate the value and if it is a sequence, return its items + val = self.visit(node.value) + if isinstance(val, (list, tuple)): + return tuple(val) + raise NotImplementedError( + "Unsupported starred expression; cannot expand non-iterable" + ) + + def visit_BinOp(self, node: ast.BinOp): + left = self.visit(node.left) + right = self.visit(node.right) + if isinstance(node.op, ast.Add): + return left + right + if isinstance(node.op, ast.Sub): + return left - right + if isinstance(node.op, ast.Mult): + return left * right + if isinstance(node.op, ast.Div): + return left / right + if isinstance(node.op, ast.Pow): + return left**right + if isinstance(node.op, ast.Mod): + return left % right + raise NotImplementedError(f"Binary op not supported: {node.op!r}") + + def visit_UnaryOp(self, node: ast.UnaryOp): + operand = self.visit(node.operand) + if isinstance(node.op, ast.USub): + return -operand + if isinstance(node.op, ast.UAdd): + return +operand + raise NotImplementedError(f"Unary op not supported: {node.op!r}") + + def visit_Call(self, node: ast.Call): + # Determine function name + func = node.func + func_name = None + mod_name = None + + if isinstance(func, ast.Name): + func_name = func.id + elif isinstance(func, ast.Attribute): + # e.g. 
np.sin or math.exp + if isinstance(func.value, ast.Name): + mod_name = func.value.id + func_name = func.attr + else: + raise NotImplementedError( + f"Call to unsupported func node: {ast.dump(func)}" + ) + + # Map common numpy/math functions to sympy + func_map = { + "sin": sp.sin, + "cos": sp.cos, + "tan": sp.tan, + "exp": sp.exp, + "log": sp.log, + "sqrt": sp.sqrt, + "abs": sp.Abs, + "atan": sp.atan, + "asin": sp.asin, + "acos": sp.acos, + } + + args = [self.visit(a) for a in node.args] + + # Special-case array constructors: return underlying list/tuple + if func_name in ("array", "asarray") and mod_name in ("np", "numpy"): + # expect a single positional arg that's a list/tuple + if len(args) == 1: + return args[0] + + if func_name in func_map: + return func_map[func_name](*args) + + # Unknown function: create a Sympy Function + symf = sp.Function(func_name) + return symf(*args) + + def visit_Subscript(self, node: ast.Subscript): + # Support patterns like x[0] where x is the state vector name + value = node.value + # handle simple constant index + if _is_name(value, self.state_name): + # Python >=3.9: slice is directly the node.slice + idx_node = node.slice + if isinstance(idx_node, ast.Constant): + idx = idx_node.value + else: + raise NotImplementedError( + "Only constant indices into state vector supported" + ) + return self.state_symbols[idx] + + # If it's something else, try to evaluate generically + base = self.visit(value) + # slice may be constant + if isinstance(node.slice, ast.Constant): + key = node.slice.value + return base[key] + raise NotImplementedError("Unsupported subscript pattern") + + +def _numeric_consistency_check( + dysts_flow: BaseDyn, + rhsfunc: Callable, + arg_names: List[str], + state_names: List[str], + vector_mode: bool, + sys_dim: int, + lambda_rhs: sp.Lambda, +) -> None: + """Compare the original dysts rhs function to the SymPy-derived lambda. + + Raises a RuntimeError if they disagree. + """ + # default to nonnegative support (e.g. 
Lotka volterra) + random_state = np.random.standard_exponential(size=sys_dim) + + # Construct call arguments for the original function (bound method). + call_args = [] + for name in arg_names: + if name == "self": + continue + if name in state_names and not vector_mode: + idx = state_names.index(name) + call_args.append(random_state[idx]) + elif name in state_names and vector_mode: + call_args.append(np.asarray(random_state, dtype=float)) + elif name == "t": + call_args.append(float(np.random.standard_normal(size=()))) + else: + call_args.append(dysts_flow.params[name]) + + dysts_val = rhsfunc(*call_args) + orig_arr = np.asarray(dysts_val, dtype=float).ravel() + + sym_val = lambda_rhs(*tuple(random_state)) + sym_arr = np.asarray(sym_val, dtype=float).ravel() + + if orig_arr.shape != sym_arr.shape: + raise RuntimeError( + f"_rhs shape {orig_arr.shape} != sympy shape {sym_arr.shape}" + ) + + if not np.allclose(orig_arr, sym_arr, rtol=1e-6, atol=1e-9): + raise RuntimeError("Numeric mismatch between original and sympy conversion.") + + +def dynsys_to_sympy( + obj: Any, func_name: str = "_rhs" +) -> Tuple[List[sp.Symbol], List[sp.Expr], sp.Lambda]: + """Inspect ``obj`` for a method named ``func_name`` and return a SymPy + representation of its RHS. + + Returns: + a tuple ``(state_symbols, exprs, lambda_rhs)`` where ``state_symbols`` + is a list of SymPy symbols for the state vector, ``exprs`` is a list of + SymPy expressions for the RHS components, and ``lambda_rhs`` is a SymPy + Lambda mapping the state symbols to the RHS vector. 
+ + Example: + + >>> from dysts.flows import Lorenz + >>> from inspect_to_sympy import dynsys_to_sympy + >>> lor = Lorenz() + >>> symbols, exprs, lambda_rhs = dynsys_to_sympy(lor) + >>> print(lor._rhs(1, 2, 3, t=0.0, **lor.params)) + (10, 23, -6.0009999999999994) + + >>> print(tuple(lambda_rhs(1, 2, 3))) + (10, 23, -6.00100000000000) + + """ + + if not hasattr(obj, func_name): + raise AttributeError(f"Object has no attribute {func_name!r}") + + func = getattr(obj, func_name) + src = inspect.getsource(func) + src = textwrap.dedent(src) + + parsed = ast.parse(src) + + # Find first FunctionDef + fndef = None + for node in parsed.body: + if isinstance(node, ast.FunctionDef): + fndef = node + break + if fndef is None: + raise RuntimeError("No function definition found in source") + + # Determine state argument names. Common dysts signature: + # (self, *states, t, *parameters). Prefer obj.dimension when available. + arg_names = [a.arg for a in fndef.args.args] + if len(arg_names) == 0: + raise RuntimeError("Function has no arguments") + + start_idx = 0 + if arg_names[0] == "self": + start_idx = 1 + + vector_mode = False + state_args: List[str] + t_idx = None + if "t" in arg_names: + t_idx = arg_names.index("t") + + if hasattr(obj, "dimension") and isinstance(getattr(obj, "dimension"), int): + n_state = int(getattr(obj, "dimension")) + if t_idx is not None: + potential = arg_names[start_idx:t_idx] + if len(potential) >= n_state: + state_args = potential[:n_state] + else: + state_args = [arg_names[start_idx]] + vector_mode = True + else: + potential = arg_names[start_idx:] + if len(potential) >= n_state: + state_args = potential[:n_state] + else: + state_args = [arg_names[start_idx]] + vector_mode = True + else: + if t_idx is not None: + state_args = arg_names[start_idx:t_idx] + if len(state_args) == 0: + state_args = [arg_names[start_idx]] + vector_mode = True + elif len(state_args) == 1: + # single name could be vector or scalar; assume vector-mode + vector_mode = True + 
else: + state_args = [arg_names[start_idx]] + vector_mode = True + + # If vector_mode, inspect AST for subscript/index usage or tuple unpacking + if vector_mode: + state_name = state_args[0] + max_index = -1 + unpack_size = None + for node in ast.walk(fndef): + if ( + isinstance(node, ast.Subscript) + and isinstance(node.value, ast.Name) + and node.value.id == state_name + ): + sl = node.slice + if isinstance(sl, ast.Constant) and isinstance(sl.value, int): + if sl.value > max_index: + max_index = sl.value + if isinstance(node, ast.Assign): + if isinstance(node.value, ast.Name) and node.value.id == state_name: + targets = node.targets + if len(targets) == 1 and isinstance( + targets[0], (ast.Tuple, ast.List) + ): + unpack_size = len(targets[0].elts) + + if unpack_size is not None: + n_state = unpack_size + elif max_index >= 0: + n_state = max_index + 1 + else: + n_state = int(getattr(obj, "dimension", 3)) + + state_symbols = [sp.Symbol(f"x{i}") for i in range(n_state)] + primary_state_name = state_name + else: + # individual state args -> use their arg names as symbol names + state_symbols = [sp.Symbol(n) for n in state_args] + primary_state_name = state_args[0] if len(state_args) > 0 else "x" + + # Build locals mapping from known state arg names and parameters + locals_map: Dict[str, Any] = {} + for i, name in enumerate(state_args): + if i < len(state_symbols): + locals_map[name] = state_symbols[i] + + # map parameters (if present) to numeric values or symbols + if hasattr(obj, "parameters") and isinstance(getattr(obj, "parameters"), dict): + params = getattr(obj, "parameters") + if t_idx is not None: + param_arg_names = arg_names[t_idx + 1 :] + else: + param_arg_names = [] + for pname in param_arg_names: + if pname in params: + locals_map[pname] = sp.sympify(params[pname]) + else: + locals_map[pname] = sp.Symbol(pname) + + converter = _ASTToSympy(primary_state_name, state_symbols, locals_map) + + return_expr = None + # Walk through function body statements, 
handle Assign and Return + for stmt in fndef.body: + if isinstance(stmt, ast.Assign): + # only simple single-target assignments supported + if len(stmt.targets) != 1: + raise ValueError("Only single-target assignments supported") + target = stmt.targets[0] + if isinstance(target, ast.Name): + value_expr = converter.visit(stmt.value) + locals_map[target.id] = value_expr + elif ( + isinstance(target, (ast.Tuple, ast.List)) + and isinstance(stmt.value, ast.Name) + and stmt.value.id == state_name + ): + # unpacking like a,b,c = x -> map names to state symbols + for i, elt in enumerate(target.elts): + if isinstance(elt, ast.Name): + locals_map[elt.id] = state_symbols[i] + elif isinstance(stmt, ast.Return): + return_expr = stmt.value + + if return_expr is None: + # maybe last statement is an Expr with list construction; + # try to find a Return node deep + for node in ast.walk(fndef): + if isinstance(node, ast.Return): + return_expr = node.value + break + + if return_expr is None: + raise RuntimeError("No return expression found in function body") + + # Refresh converter with updated locals + converter = _ASTToSympy(primary_state_name, state_symbols, locals_map) + rhs_val = converter.visit(return_expr) + + if isinstance(rhs_val, (list, tuple)): + exprs = list(rhs_val) + else: + # single expression: treat as 1-dim RHS + exprs = [rhs_val] + + lambda_rhs = sp.Lambda(tuple(state_symbols), sp.Matrix(exprs)) + + # Run numeric consistency guard (raises on mismatch) + _numeric_consistency_check( + obj, + func, + arg_names, + state_args, + vector_mode, + len(state_symbols), + lambda_rhs, + ) + + return state_symbols, exprs, lambda_rhs + + +__all__ = ["dynsys_to_sympy"] diff --git a/asv_bench/benchmarks/joint.py b/asv_bench/benchmarks/joint.py new file mode 100644 index 000000000..eb7612dd3 --- /dev/null +++ b/asv_bench/benchmarks/joint.py @@ -0,0 +1,274 @@ +import re +from itertools import chain +from typing import cast +from typing import NamedTuple +from typing import Optional 
from typing import TypeVar
from warnings import warn

import jax.numpy as jnp
import numpy as np
import sklearn.metrics
from dysts.flows import Lorenz
from numpy.typing import NBitBase
from numpy.typing import NDArray

import pysindy as ps
from pysindy.sssindy import SSSINDy
from pysindy.sssindy.expressions import JaxPolyLib
from pysindy.sssindy.expressions import JointObjective
from pysindy.sssindy.interpolants import RKHSInterpolant
from pysindy.sssindy.interpolants.kernels import get_gaussianRBF
from pysindy.sssindy.opt import L2CholeskyLMRegularizer
from pysindy.sssindy.opt import LMSettings
from pysindy.sssindy.opt import LMSolver


class SSSINDyLorenzSparsity:
    """
    See that we do decently on the Lorenz system
    """

    # asv: hard limit (seconds) before the benchmark is aborted
    timeout = 360

    def setup(self):
        self.data = gen_lorenz(seed=124, dt=0.02, t_end=5)
        self.sss_model = SSSINDy(
            JointObjective(50, 1, JaxPolyLib(), RKHSInterpolant(get_gaussianRBF(0.05))),
            LMSolver(
                optimizer_settings=LMSettings(
                    callback_every=5, search_increase_ratio=2.0
                )
            ),
        )

    def time_experiment(self):
        eval_sss(self.data, self.sss_model)

    def track_experiment(self):
        results = eval_sss(self.data, self.sss_model)
        return results

    track_experiment.unit = "MAE"


def eval_sss(data: "ProbData", sss_model: SSSINDy) -> float:
    """Fit ``sss_model`` to ``data`` and return the coefficient MAE."""
    sss_model.feature_names = data.input_features
    n_coloc = len(data.t_train)
    n_traj = len(data.x_train)
    sss_metrics = []
    t_train = data.t_train
    x_train = [jnp.array(x) for x in data.x_train]
    t_coloc = jnp.linspace(t_train[0], t_train[-1], n_coloc)
    t_coloc = n_traj * [jnp.array(t_coloc)]
    t_train = n_traj * [jnp.array(t_train)]
    sss_model.fit(x_train, t=t_train, t_coloc=t_coloc)

    coeff_true, sss_coeffs, _ = unionize_coeff_matrices(
        sss_model, data.coeff_true, True
    )
    sss_model.print(flush=True)

    x_interp = jnp.hstack(sss_model.x_predict(data.t_train))
    x_true = jnp.hstack(data.x_train_true)
    interp_rel_err = jnp.linalg.norm(x_interp - x_true) / jnp.linalg.norm(x_true)
    sss_metrics.append(
        coeff_metrics(sss_coeffs, coeff_true) | {"interp_rel_err": interp_rel_err}
    )

    return sss_metrics[0]["coeff_mae"]


####################################
# Copied from gen_experiments to avoid circular imports
####################################


NpFlt = np.dtype[np.floating[NBitBase]]
Float1D = np.ndarray[tuple[int], NpFlt]
Float2D = np.ndarray[tuple[int, int], NpFlt]
Shape = TypeVar("Shape", bound=tuple[int, ...])
FloatND = np.ndarray[Shape, np.dtype[np.floating[NBitBase]]]


class ProbData(NamedTuple):
    """Zero-noise train/test data for one ODE identification problem."""

    dt: float
    t_train: Float1D
    x_train: list[FloatND]
    x_test: list[FloatND]
    x_dot_test: list[FloatND]
    x_train_true: list[FloatND]
    input_features: list[str]
    coeff_true: list[dict[str, float]]


def gen_lorenz(
    seed: Optional[int] = None,
    n_trajectories: int = 1,
    ic_stdev: float = 3,
    dt: float = 0.01,
    t_end: float = 10,
) -> ProbData:
    """Generate zero-noise training and test data for the Lorenz system.

    Note that test data has no noise (train and test reuse the same
    trajectories here).

    Arguments:
        seed (int): the random seed for number generation
        n_trajectories (int): number of trajectories of training data
        ic_stdev (float): standard deviation for generating initial conditions
        dt: time step for sample
        t_end: end time of simulation

    Returns:
        ProbData tuple of data and descriptive information
    """
    coeff_true = [
        {"x": -10, "y": 10},
        {"x": 28, "y": -1, "x z": -1},
        {"z": -8 / 3, "x y": 1},
    ]
    input_features = ["x", "y", "z"]
    x0_center = np.array([0, 0, 15])
    rng = np.random.default_rng(seed)
    nt = int(t_end // dt)
    x_train = []
    n_coord = 3
    x0_train = ic_stdev * rng.standard_normal((n_trajectories, n_coord)) + x0_center
    for x0 in x0_train:
        x_train.append(Lorenz().make_trajectory(nt, dt, x0))
    # NOTE(review): linspace spacing is t_end/(nt-1), slightly different from
    # the ``dt`` used for integration above -- confirm intended time grid.
    t_train = np.linspace(0, t_end, nt, dtype=np.float64)
    x_dot_train = [np.vstack([Lorenz().rhs(xij, 0) for xij in xi]) for xi in x_train]
    return ProbData(
        dt,
        t_train,
        x_train,
        x_train,
        x_dot_train,
        x_train,
        input_features,
        coeff_true,
    )


def unionize_coeff_matrices(
    model: ps._core._BaseSINDy,
    model_true: tuple[list[str], list[dict[str, float]]] | list[dict[str, float]],
    strict: bool = False,
) -> tuple[NDArray[np.float64], NDArray[np.float64], list[str]]:
    """Reformat true coefficients and coefficient matrix compatibly

    In order to calculate accuracy metrics between true and estimated
    coefficients, this function compares the names of true coefficients
    and the fitted model's features in order to create comparable
    (i.e. non-ragged) true and estimated coefficient matrices.  In
    a word, it stacks the correct coefficient matrix and the estimated
    coefficient matrix in a matrix that represents the union of true
    features and modeled features.

    Arguments:
        model: fitted model
        model_true: A tuple of (a) a list of input feature names, and
            (b) a list of dicts of format function_name: coefficient,
            one dict for each modeled coordinate/target.  The old format
            of passing only the list of dicts is deprecated.
        strict:
            whether to raise when input variable names differ, instead of
            attempting to translate the model's features into the input
            variable names in the true model.

    Returns:
        Tuple of true coefficient matrix, estimated coefficient matrix,
        and combined feature names

    Warning:
        Does not disambiguate between commutatively equivalent function
        names such as 'x z' and 'z x' or 'x^2' and 'x x'.

    Warning:
        In non-strict mode, when different input variables are detected in the
        SINDy model and in the true model, will attempt to translate the
        model's features into the true model's variable names, e.g.
        ``x0^2`` -> ``x^2``.  This is a text replacement, not a lexical
        replacement, so there are edge cases where translation fails.
        Input variables are sorted alphabetically.
    """
    inputs_model = cast(list[str], model.feature_names)
    if isinstance(model_true, list):
        warn(
            "Passing coeff_true as merely the list of functions is deprecated. "
            " It is now required to pass a tuple of system coordinate variables"
            " as well as the list of functions.",
            DeprecationWarning,
        )
        in_funcs = set.union(*[set(d.keys()) for d in model_true])

        def extract_vars(fname: str) -> set[str]:
            # split on ops like *,^, only accept x, x2, from x * x2 ^ 2, but need x'
            return {
                var for var in re.split(r"[^\w']", fname) if re.match(r"[^\d]", var)
            }

        inputs_set = set.union(*[extract_vars(fname) for fname in in_funcs])
        inputs_true = sorted(inputs_set)
        coeff_true = model_true
    else:
        inputs_true, coeff_true = model_true
    model_features = model.get_feature_names()
    true_features = [set(coeffs.keys()) for coeffs in coeff_true]
    if inputs_true != inputs_model:
        if strict:
            raise ValueError(
                "True model and fit model have different input variable names"
            )
        # Map each model input name to the corresponding true-model name.
        mapper = dict(zip(inputs_model, inputs_true, strict=True))
        translated_features: list[str] = []
        for feat in model_features:
            for k, v in mapper.items():
                feat = feat.replace(k, v)
            translated_features.append(feat)
        model_features = translated_features

    unmodeled_features = set(chain.from_iterable(true_features)) - set(model_features)
    model_features.extend(list(unmodeled_features))
+ est_coeff_mat = model.coefficients() + new_est_coeff = np.zeros((est_coeff_mat.shape[0], len(model_features))) + new_est_coeff[:, : est_coeff_mat.shape[1]] = est_coeff_mat + true_coeff_mat = np.zeros_like(new_est_coeff) + for row, terms in enumerate(coeff_true): + for term, coeff in terms.items(): + true_coeff_mat[row, model_features.index(term)] = coeff + + return true_coeff_mat, new_est_coeff, model_features + + +def coeff_metrics(coefficients, coeff_true): + metrics = {} + metrics["coeff_precision"] = sklearn.metrics.precision_score( + coeff_true.flatten() != 0, coefficients.flatten() != 0 + ) + metrics["coeff_recall"] = sklearn.metrics.recall_score( + coeff_true.flatten() != 0, coefficients.flatten() != 0 + ) + metrics["coeff_f1"] = sklearn.metrics.f1_score( + coeff_true.flatten() != 0, coefficients.flatten() != 0 + ) + metrics["coeff_mse"] = sklearn.metrics.mean_squared_error( + coeff_true.flatten(), coefficients.flatten() + ) + metrics["coeff_mae"] = sklearn.metrics.mean_absolute_error( + coeff_true.flatten(), coefficients.flatten() + ) + metrics["main"] = metrics["coeff_f1"] + return metrics + + +if __name__ == "__main__": + bm = SSSINDyLorenzSparsity() + bm.setup() + bm.time_experiment() diff --git a/pyproject.toml b/pyproject.toml index 4ff50ec16..9cee1dcd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,20 @@ dependencies = [ "derivative>=0.6.2", "scipy", "typing_extensions", + "cvxpy>=1.5", + "scs>=2.1, !=2.1.4", + "jax>=0.6.2", + "equinox", + "sympy", + "sympy2jax", + "tqdm", + "jaxopt", + "numpyro", + "gurobipy>=9.5.1,!=10.0.0" ] [project.optional-dependencies] +bench = ["dysts"] dev = [ "asv", "matplotlib", @@ -43,6 +54,7 @@ dev = [ "pytest>=6.2.4, <8.0.0", "black", "build", + "dysts", "pytest-cov", "pytest-lazy-fixture", "flake8-builtins-unleashed", @@ -63,17 +75,7 @@ docs = [ "sphinxcontrib-apidoc", "matplotlib" ] -miosr = [ - "gurobipy>=9.5.1,!=10.0.0" -] -cvxpy = [ - "cvxpy>=1.5", - "scs>=2.1, !=2.1.4" -] -sbr = [ - "numpyro", - 
"jax" -] + [tool.black] line-length = 88 diff --git a/pysindy/jsindy/__init__.py b/pysindy/jsindy/__init__.py new file mode 100644 index 000000000..d5d9691fb --- /dev/null +++ b/pysindy/jsindy/__init__.py @@ -0,0 +1,3 @@ +from .sindy_model import JSINDyModel + +__all__ = ["JSINDyModel"] diff --git a/pysindy/jsindy/dynamics_model.py b/pysindy/jsindy/dynamics_model.py new file mode 100644 index 000000000..54a96f497 --- /dev/null +++ b/pysindy/jsindy/dynamics_model.py @@ -0,0 +1,105 @@ +from abc import ABC + +import jax +import jax.numpy as jnp + +import pysindy as ps +from .util import l2reg_lstsq + + +class DynamicsModel(ABC): + def predict(self, x, theta): + pass + +class PolyLib(ps.PolynomialLibrary): + def fit(self, x: jax.Array): + #Using ps.PolynomialLibrary to get powers right now + super().fit(x) + self.jpowers_ = jnp.array(self.powers_) + + def transform(self, x: jax.Array): + if jnp.ndim(x)==2: + return jnp.prod(jax.vmap(jnp.pow,in_axes=(None,0))(x,self.jpowers_),axis=2).T + elif jnp.ndim(x)==1: + return jnp.prod(jax.vmap(jnp.pow,in_axes=(None,0))(x,self.jpowers_),axis=1) + else: + raise ValueError(f"Polynomial library cannot handle input shape, {x.shape}") + + def __call__(self, X): + return self.transform(X) + +class FeatureLinearModel(DynamicsModel): + def __init__( + self, + feature_map=PolyLib(degree=2), + reg_scaling = 1. 
+ ) -> None: + self.feature_map = feature_map + self.attached = False + self.reg_scaling = reg_scaling + + def __str__(self): + return "\n".join([f"{key}: {value}" for key, value in self.feature_map.get_params().items()]) + + def attach(self, x: jax.Array,input_orders = (0,)): + shaped_features = jnp.hstack([x]*len(input_orders)) + + self.feature_map.fit(shaped_features) + self.num_targets = x.shape[1] + + self.num_theta = ( + self.num_targets * self.feature_map.n_output_features_ + ) + self.num_features = self.feature_map.n_output_features_ + self.attached = True + self.regmat = self.reg_scaling*jnp.eye(self.num_theta) + + self.tot_params = self.num_features*self.num_targets + self.param_shape = (self.num_features, self.num_targets) + + def initialize(self,t,x,params,input_orders): + self.attach(x,input_orders = input_orders) + return params + + def initialize_partialobs(self,t,y,v,params,input_orders): + #Pretending that v is x gives all of the right shapes + self.attach(v,input_orders = input_orders) + return params + + + # somewhere in jsindy.fit a predict is used and needs to fixed + def predict(self, x, theta): + if jnp.ndim(x)==1: + return self.feature_map.transform(x) @ theta + elif jnp.ndim(x)==2: + return self.feature_map.transform(x) @ theta + else: + raise ValueError(f"x shape not compatible, x.shape = {x.shape}") + + def __call__(self, x,theta): + return self.predict(x,theta) + + def get_fitted_theta(self,x,xdot,lam = 1e-2): + A = self.feature_map.transform(x) + return l2reg_lstsq(A,xdot,reg = lam) + +# class FeatureLinearModel(): +# def __init__( +# self, +# feature_map, +# in_dim, +# out_dim, +# ): +# self.shape = ... +# self.feature_map = feature_map +# self.regularization_weights = ... 
+ +# def featurize(self,x): +# return self.feature_map(x) + +# def predict(self, x,theta): +# FX = self.featurize(x) +# return FX@theta + +# def __call__(self, x,theta): +# self.predict(x,theta) diff --git a/pysindy/jsindy/kernels/__init__.py b/pysindy/jsindy/kernels/__init__.py new file mode 100644 index 000000000..3fea81798 --- /dev/null +++ b/pysindy/jsindy/kernels/__init__.py @@ -0,0 +1,28 @@ +from .base_kernels import ConstantKernel +from .base_kernels import Kernel +from .base_kernels import softplus_inverse +from .fit_kernel import build_loocv +from .fit_kernel import build_neg_marglike +from .fit_kernel import fit_kernel +from .fit_kernel import fit_kernel_partialobs +from .kernels import GaussianRBFKernel +from .kernels import LinearKernel +from .kernels import PolynomialKernel +from .kernels import RationalQuadraticKernel +from .kernels import ScalarMaternKernel +from .kernels import SpectralMixtureKernel + +__all__ = [ + "Kernel", + "GaussianRBFKernel", + "ScalarMaternKernel", + "RationalQuadraticKernel", + "LinearKernel", + "PolynomialKernel", + "SpectralMixtureKernel", + "fit_kernel", + "build_loocv", + "build_neg_marglike", + "softplus_inverse", + "fit_kernel_partialobs" +] diff --git a/pysindy/jsindy/kernels/base_kernels.py b/pysindy/jsindy/kernels/base_kernels.py new file mode 100644 index 000000000..aebc1f321 --- /dev/null +++ b/pysindy/jsindy/kernels/base_kernels.py @@ -0,0 +1,196 @@ +from abc import abstractmethod + +import equinox as eqx +import jax +import jax.numpy as jnp +from jax.nn import softplus + +def softplus_inverse(y: jnp.ndarray) -> jnp.ndarray: + return y + jnp.log1p(-jnp.exp(-y)) + +class Kernel(eqx.Module): + """Abstract base class for kernels in JAX + Equinox.""" + + @abstractmethod + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Compute k(x, y). Must be overridden by subclasses.""" + pass + + def __add__(self, other: "Kernel"): + """ + Overload the '+' operator so we can do k1 + k2. 
+ Internally, we return a SumKernel object containing both. + Also handles the case if `other` is already a SumKernel, in + which case we combine everything into one big sum. + """ + if isinstance(other, SumKernel): + # Combine self with an existing SumKernel's list + return SumKernel(*( [self] + list(other.kernels) )) + elif isinstance(other, Kernel): + return SumKernel(self, other) + else: + return NotImplemented + + def __mul__(self,other:"Kernel"): + """ + Overload the '*' operator so we can do k1 * k2. + Internally, we return a ProductKernel object containing both. + Also handles the case if `other` is already a ProductKernel, in + which case we combine everything into one big sum. + """ + if isinstance(other, ProductKernel): + return ProductKernel(*( [self] + list(other.kernels) )) + elif isinstance(other, Kernel): + return ProductKernel(self, other) + else: + return NotImplemented + + def transform(self,f): + """ + Creates a transformed kernel, returning a kernel function + k_transformed(x,y) = k(f(x),f(y)) + """ + return TransformedKernel(self,f) + + def scale(self,c): + """ + returns a kernel rescaled by a constant factor c + really should be implemented better + but the abstract Kernel doesn't include the variances yet + Thus, we return a product kernel with the constant kernel, + abusing the __mul__ overloading + """ + kc = ConstantKernel(c) + return kc * self + + +class TransformedKernel(Kernel): + """ + Transformed kernel, representing the + composition of a kernel with another + fixed function + """ + kernel: Kernel + transform: callable = eqx.field(static=True) + + def __init__(self,kernel,transform): + self.kernel = kernel + self.transform = transform + + def __call__(self, x, y): + return self.kernel(self.transform(x),self.transform(y)) + + def __str__(self): + return f"Transformed({self.kernel.__str__()})" + + +class SumKernel(Kernel): + """ + Represents the sum of multiple kernels: + k_sum(x, y) = sum_{k in kernels} k(x, y) + """ + kernels: 
tuple[Kernel, ...] + + def __init__(self, *kernels: Kernel): + self.kernels = kernels + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + return sum(k(x, y) for k in self.kernels) + + def __add__(self, other: "Kernel"): + """ + If we do (k1 + k2) + k3, the left side is a SumKernel, so + we define its __add__ to merge again into one SumKernel. + """ + if isinstance(other, SumKernel): + return SumKernel(*(list(self.kernels) + list(other.kernels))) + elif isinstance(other, Kernel): + return SumKernel(*(list(self.kernels) + [other])) + else: + return NotImplemented + + def scale(self,c): + """ + Push scaling down a level + """ + return SumKernel(*[k.scale(c) for k in self.kernels]) + + def __str__(self): + component_str = [k.__str__() for k in self.kernels] + return f"{" + ".join(component_str)}" + +class ProductKernel(Kernel): + """ + Represents the sum of multiple kernels: + k_sum(x, y) = prod_{k in kernels} k(x, y) + """ + kernels: tuple[Kernel, ...] + + def __init__(self, *kernels: Kernel): + self.kernels = kernels + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + return jnp.prod(jnp.array([k(x, y) for k in self.kernels])) + + def __prod__(self, other: "Kernel"): + """ + If we do (k1*k2)*k3, the left side is a ProductKernel, so + we define its __prod__ to merge again into one ProductKernel. 
+ """ + if isinstance(other, SumKernel): + return ProductKernel(*(list(self.kernels) + list(other.kernels))) + elif isinstance(other, Kernel): + return ProductKernel(*(list(self.kernels) + [other])) + else: + return NotImplemented + + def scale(self,c): + """ + Scale the first kernel + """ + return ProductKernel(*([self.kernels[0].scale(c)] + [self.kernels[1:]])) + + def __str__(self): + component_str = ["(" + k.__str__() + ")" for k in self.kernels] + return f"{"*".join(component_str)}" + +class FrozenKernel(Kernel): + kernel:Kernel + def __init__(self,kernel): + self.kernel = kernel + + def __call__(self, x, y): + return jax.lax.stop_gradient(self.kernel)(x, y) + + def __str__(self): + return self.kernel.__str__() + +class ConstantKernel(Kernel): + """ + Constant kernel k(x, y) = c for all x, y. + + Params: + variance, variance of the constant shift + Internally stored as "raw_" after applying softplus_inverse. + """ + raw_variance: jnp.ndarray + + def __init__(self, variance: float = 1.0): + """ + :param variance: A positive float specifying the kernel's constant value. 
+ """ + if variance <= 0: + raise ValueError("ConstantKernel requires a strictly positive constant.") + # Store an unconstrained parameter via softplus-inverse + self.raw_variance = softplus_inverse(jnp.array(variance)) + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + v = softplus(self.raw_variance) + return v + + def scale(self,c): + return ConstantKernel(c*softplus(self.raw_variance)) + + def __str__(self): + v = softplus(self.raw_variance) + return f"{v:.3f}" diff --git a/pysindy/jsindy/kernels/fit_kernel.py b/pysindy/jsindy/kernels/fit_kernel.py new file mode 100644 index 000000000..39e63698e --- /dev/null +++ b/pysindy/jsindy/kernels/fit_kernel.py @@ -0,0 +1,188 @@ +import jax +import jax.numpy as jnp +from jax.nn import softplus +from jaxopt import LBFGS +from jsindy.kerneltools import vectorize_kfunc + +from .kernels import softplus_inverse +from .tree_opt import run_gradient_descent +from .tree_opt import run_jaxopt_solver + + +SIGMA2_FLOOR = 1e-6 +def build_neg_marglike(X,y): + if jnp.ndim(y)==1: + m = 1 + elif jnp.ndim(y)==2: + m = y.shape[1] + else: + raise ValueError("y must be either a 1 or two dimensional array") + + + def neg_marginal_likelihood(kernel,sigma2): + K = vectorize_kfunc(kernel)(X,X) + I = jnp.eye(len(X)) + + C = jax.scipy.linalg.cholesky(K + sigma2 * I,lower = True) + logdet = 2*jnp.sum(jnp.log(jnp.diag(C))) + yTKinvY = jnp.sum( + (jax.scipy.linalg.solve_triangular(C,y,lower = True))**2 + ) + return m * logdet + yTKinvY + + def loss(params): + k = params['kernel'] + sigma2 = softplus(params['transformed_sigma2']) + SIGMA2_FLOOR + return neg_marginal_likelihood(k,sigma2) + + return loss + +def build_neg_marglike_partialobs(t,y,v): + if jnp.ndim(y)==1: + m = 1 + elif jnp.ndim(y)==2: + m = y.shape[1] + else: + raise ValueError("y must be either a 1 or two dimensional array") + + def neg_marginal_likelihood(kernel,sigma2): + Kt = vectorize_kfunc(kernel)(t,t) + I = jnp.eye(len(t)) + VV = v@v.T + K = Kt*VV + + C = 
jax.scipy.linalg.cholesky(K + sigma2 * I,lower = True) + logdet = 2*jnp.sum(jnp.log(jnp.diag(C))) + yTKinvY = jnp.sum( + (jax.scipy.linalg.solve_triangular(C,y,lower = True))**2 + ) + return m * logdet + yTKinvY + + def loss(params): + k = params['kernel'] + sigma2 = softplus(params['transformed_sigma2']) + SIGMA2_FLOOR + return neg_marginal_likelihood(k,sigma2) + + return loss + + +def build_loocv(X,y): + def loocv(kernel,sigma2): + k = vectorize_kfunc(kernel) + K = k(X,X) + I = jnp.eye(len(X)) + P = jnp.linalg.inv(K + sigma2*I) + KP = K@P + loo_preds = K@P@y - (jnp.diag(KP)/jnp.diag(P))*(P@y) + mse_loo = jnp.mean((loo_preds - y)**2) + return mse_loo + + def loss(params): + k = params['kernel'] + sigma2 = softplus(params['transformed_sigma2']) + return loocv(k,sigma2) + return loss + +def build_random_split_obj(X, y, p=0.2, rng_key=None): + """ + p: proportion of data to use as validation set (between 0 and 1) + rng_key: optional JAX + """ + n = X.shape[0] + if rng_key is None: + rng_key = jax.random.key(1) + perm = jax.random.permutation(rng_key, n) + n_val = int(jnp.round(p * n)) + val_idx = perm[:n_val] + train_idx = perm[n_val:] + Xtrain = X[train_idx] + ytrain = y[train_idx] + Xval = X[val_idx] + yval = y[val_idx] + + def l2_cv(kernel, sigma2): + K = vectorize_kfunc(kernel)(Xtrain, Xtrain) + I = jnp.eye(len(ytrain)) + c = jnp.linalg.solve(K + sigma2 * I, ytrain) + ypred = vectorize_kfunc(kernel)(Xval, Xtrain) @ c + return jnp.mean((ypred - yval) ** 2) + + def loss(params): + k = params['kernel'] + sigma2 = softplus(params['transformed_sigma2']) + return l2_cv(k, sigma2) + return loss + +def build_every_other_obj(X,y): + Xtrain = X[::2] + ytrain = y[::2] + Xval = X[1::2] + yval = y[1::2] + def l2_cv(kernel,sigma2): + K = vectorize_kfunc(kernel)(Xtrain,Xtrain) + I = jnp.eye(len(ytrain)) + c = jnp.linalg.solve(K + sigma2*I,ytrain) + ypred = vectorize_kfunc(kernel)(Xval,Xtrain)@c + return jnp.mean((ypred - yval)**2) + + def loss(params): + k = params['kernel'] + 
sigma2 = softplus(params['transformed_sigma2']) + return l2_cv(k,sigma2) + return loss + +def fit_kernel( + init_kernel, + init_sigma2, + X, + y, + loss_builder = build_neg_marglike, + gd_tol = 1e-4, + lbfgs_tol = 1e-6, + max_gd_iter = 3000, + max_lbfgs_iter = 1000, + show_progress=True, + ): + loss = loss_builder(X,y) + init_params = {'kernel':init_kernel, + 'transformed_sigma2':jnp.array(softplus_inverse(init_sigma2)) + } + + params,conv_history_gd = run_gradient_descent( + loss,init_params,tol = gd_tol, + maxiter = max_gd_iter, + show_progress=show_progress, + init_stepsize=1e-4 + ) + solver = LBFGS(loss,maxiter = max_lbfgs_iter,tol = lbfgs_tol) + params,conv_history_bfgs,state = run_jaxopt_solver(solver,params, show_progress=show_progress) + conv_hist = [conv_history_gd,conv_history_bfgs] + + return params['kernel'],jax.nn.softplus(params['transformed_sigma2']) + SIGMA2_FLOOR,conv_hist + +def fit_kernel_partialobs( + init_kernel, + init_sigma2, + t,y,v, + gd_tol = 1e-4, + lbfgs_tol = 1e-6, + max_gd_iter = 3000, + max_lbfgs_iter = 1000, + show_progress=True, + ): + loss = build_neg_marglike_partialobs(t,y,v) + init_params = {'kernel':init_kernel, + 'transformed_sigma2':jnp.array(softplus_inverse(init_sigma2)) + } + + params,conv_history_gd = run_gradient_descent( + loss,init_params,tol = gd_tol, + maxiter = max_gd_iter, + show_progress=show_progress, + init_stepsize=1e-4 + ) + solver = LBFGS(loss,maxiter = max_lbfgs_iter,tol = lbfgs_tol) + params,conv_history_bfgs,state = run_jaxopt_solver(solver,params, show_progress=show_progress) + conv_hist = [conv_history_gd,conv_history_bfgs] + + return params['kernel'],jax.nn.softplus(params['transformed_sigma2']) + SIGMA2_FLOOR,conv_hist diff --git a/pysindy/jsindy/kernels/kernels.py b/pysindy/jsindy/kernels/kernels.py new file mode 100644 index 000000000..a1fbcb29d --- /dev/null +++ b/pysindy/jsindy/kernels/kernels.py @@ -0,0 +1,290 @@ +import equinox as eqx +import jax +import jax.numpy as jnp +from jax.nn import 
softplus + +from .base_kernels import Kernel +from .base_kernels import softplus_inverse +from .matern import build_matern_core + + +class TranslationInvariantKernel(Kernel): + """ + Not used for anything yet, but maybe unifies some of the other kernels + Kernels defined by k(x,y) = var * h( (x-y)/ls ) + """ + + core_func: callable + raw_variance: jax.Array + raw_lengthscale: jax.Array + + min_lengthscale: jax.Array = eqx.field(static=True) + fix_variance: bool = eqx.field(static=True) + fix_lengthscale: bool = eqx.field(static=True) + + def __init__( + self, + core_func, + lengthscale, + variance, + min_lengthscale, + fix_variance=False, + fix_lengthscale=False, + ): + self.raw_variance = softplus_inverse(jnp.array(variance)) + if lengthscale < min_lengthscale: + raise ValueError("Initial lengthscale below minimum") + self.raw_lengthscale = softplus_inverse( + jnp.array(lengthscale) - min_lengthscale + ) + self.min_lengthscale = min_lengthscale + self.fix_variance = fix_variance + self.fix_lengthscale = fix_lengthscale + self.core_func = core_func + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + var = softplus(self.raw_variance) + if self.fix_variance is True: + var = jax.lax.stop_gradient(var) + + ls = softplus(self.raw_lengthscale) + self.min_lengthscale + if self.fix_lengthscale is True: + ls = jax.lax.stop_gradient(ls) + + scaled_diff = (y - x) / ls + return var * self.core_func(scaled_diff) + + +class ScalarMaternKernel(Kernel): + """ + Scalar half-integer order matern kernel + order = p+(1/2) + + Parameters: + p: int + variance > 0 + lengthscale > 0 + Internally stored as "raw_" after applying softplus_inverse. 
+ """ + + core_matern: callable = eqx.field(static=True) + p_order: int = eqx.field(static=True) + raw_variance: jax.Array + raw_lengthscale: jax.Array + min_lengthscale: jax.Array = eqx.field(static=True) + + def __init__(self, p, lengthscale=1.0, variance=1.0, min_lengthscale=0.01): + self.raw_variance = softplus_inverse(jnp.array(variance)) + # if lengthscale jnp.ndarray: + var = softplus(self.raw_variance) + ls = softplus(self.raw_lengthscale) + self.min_lengthscale + scaled_diff = (y - x) / ls + return var * self.core_matern(scaled_diff) + + def scale(self, c): + new_raw_var = softplus_inverse(c * softplus(self.raw_variance)) + return eqx.tree_at(lambda x: x.raw_variance, self, new_raw_var) + + def __str__(self): + var = softplus(self.raw_variance) + ls = softplus(self.raw_lengthscale) + self.min_lengthscale + return f"{var:.2f}Matern({self.p_order},{ls:.2f})" + + +class GaussianRBFKernel(Kernel): + """ + RBF (squared exponential) kernel: + k(x, y) = variance * exp(-||x - y||^2 / (2*lengthscale^2)) + + Parameters: + variance > 0 + lengthscale > 0 + Internally stored as "raw_" after applying softplus_inverse. 
+ """ + + raw_variance: jax.Array + raw_lengthscale: jax.Array + min_lengthscale: jax.Array = eqx.field(static=True) + + def __init__(self, lengthscale=1.0, variance=1.0, min_lengthscale=0.01): + # Convert user-supplied positive parameters to unconstrained domain + if lengthscale < min_lengthscale: + raise ValueError("Initial lengthscale below minimum") + self.raw_variance = softplus_inverse(jnp.array(variance)) + self.raw_lengthscale = softplus_inverse( + jnp.array(lengthscale) - min_lengthscale + ) + self.min_lengthscale = min_lengthscale + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + var = softplus(self.raw_variance) + ls = softplus(self.raw_lengthscale) + self.min_lengthscale + sqdist = jnp.sum((x - y) ** 2) + return var * jnp.exp(-0.5 * sqdist / (ls**2)) + + def scale(self, c): + new_raw_var = softplus_inverse(c * softplus(self.raw_variance)) + return eqx.tree_at(lambda x: x.raw_variance, self, new_raw_var) + + def __str__(self): + var = softplus(self.raw_variance) + ls = softplus(self.raw_lengthscale) + self.min_lengthscale + return f"{var:.2f}GRBF({ls:.2f})" + + +class RationalQuadraticKernel(Kernel): + """ + Rational Quadratic kernel: + k(x, y) = variance * [1 + (||x - y||^2 / (2 * alpha * lengthscale^2))]^(-alpha) + + Parameters: + variance > 0 + lengthscale > 0 + alpha > 0 + Internally stored as "raw_" after applying softplus_inverse. 
+ """ + + raw_variance: jax.Array + raw_lengthscale: jax.Array + raw_alpha: jax.Array + min_lengthscale: jax.Array = eqx.field(static=True) + + def __init__(self, lengthscale=1.0, alpha=1.0, variance=1.0, min_lengthscale=0.01): + self.raw_variance = softplus_inverse(jnp.array(variance)) + self.raw_lengthscale = softplus_inverse(jnp.array(lengthscale)) + self.raw_alpha = softplus_inverse(jnp.array(alpha)) + self.min_lengthscale = min_lengthscale + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + var = softplus(self.raw_variance) + ls = softplus(self.raw_lengthscale) + self.min_lengthscale + a = softplus(self.raw_alpha) + + sqdist = jnp.sum((x - y) ** 2) + factor = 1.0 + (sqdist / (2.0 * a * ls**2)) + return var * jnp.power(factor, -a) + + def scale(self, c): + new_raw_var = softplus_inverse(c * softplus(self.raw_variance)) + return eqx.tree_at(lambda x: x.raw_variance, self, new_raw_var) + + def __str__(self): + var = softplus(self.raw_variance) + a = softplus(self.raw_alpha) + ls = softplus(self.raw_lengthscale) + self.min_lengthscale + return f"{var:.2f}RQ({a},{ls:.2f})" + + +class SpectralMixtureKernel(Kernel): + """ + Spectral Mixture kernel for scalar inputs: + k(x, y) = sum_{m=1..M} w_m * exp(-2 * (pi*sigma_m)^2 * (x-y)^2) * cos(2 pi (x-y) * periods_m) + where tau = x - y. + + Internally stored as "raw_" after applying softplus_inverse. 
+ """ + + raw_weights: jnp.ndarray + raw_freq_sigmas: jnp.ndarray + periods: jnp.ndarray + + def __init__(self, key, num_mixture=20, period_variance=10.0): + key1, key2, key3 = jax.random.split(key, 3) + self.raw_weights = jax.random.normal(key1, shape=(num_mixture,)) + self.raw_freq_sigmas = jax.random.normal(key2, shape=(num_mixture,)) + self.periods = jnp.sqrt(period_variance) * jax.random.normal( + key3, shape=(num_mixture,) + ) + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + tau = x - y + weights = softplus(self.raw_weights) + freq_sigmas = softplus(self.raw_freq_sigmas) + + kernel_components = jnp.exp( + -2.0 * (jnp.pi * freq_sigmas) ** 2 * tau**2 + ) * jnp.cos(2.0 * jnp.pi * tau * self.periods) + return jnp.sum(weights * kernel_components) + + def scale(self, c): + new_raw_weights = softplus_inverse(c * softplus(self.raw_weights)) + return eqx.tree_at(lambda x: x.raw_weights, self, new_raw_weights) + + def __print__(self): + weights = softplus(self.raw_weights) + return f"{jnp.sum(weights):.2f}SpecMix(n={len(self.periods)})" + + +class LinearKernel(Kernel): + """ + Linear Kernel k(x, y) = v* + + Params: + variance, variance + Internally stored as "raw_" after applying softplus_inverse. 
+ """ + + raw_variance: jnp.ndarray + + def __init__(self, variance: float = 1.0): + """ + :param constant: A positive float specifying the kernel's variance + """ + if variance <= 0: + raise ValueError("LinearKernel requires a strictly positive constant.") + # Store an unconstrained parameter via softplus-inverse + self.raw_variance = softplus_inverse(jnp.array(variance)) + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + v = softplus(self.raw_variance) # guaranteed positive + return v * jnp.dot(x, y) + + def scale(self, c): + new_raw_var = softplus_inverse(c * softplus(self.raw_variance)) + return eqx.tree_at(lambda x: x.raw_variance, self, new_raw_var) + + def __str__(self): + v = softplus(self.raw_variance) + return f"{v:.2f}Lin()" + + +class PolynomialKernel(Kernel): + """ + Polynomial Kernel k(x, y) = v * (+c)^p + + Params: + variance, variance + Internally stored as "raw_" after applying softplus_inverse. + """ + + raw_variance: jnp.ndarray + degree: int = eqx.field(static=True) + c: jnp.ndarray + + def __init__(self, variance: float = 1.0, c: float = 1.0, degree: int = 2): + if variance <= 0: + raise ValueError("LinearKernel requires a strictly positive constant.") + self.raw_variance = softplus_inverse(jnp.array(variance)) + self.c = jnp.array(c) + self.degree = degree + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + v = softplus(self.raw_variance) # guaranteed positive + return v * jnp.pow(jnp.dot(x, y) + self.c, self.degree) + + def scale(self, c): + new_raw_var = softplus_inverse(c * softplus(self.raw_variance)) + return eqx.tree_at(lambda x: x.raw_variance, self, new_raw_var) + + def __print__(self): + v = softplus(self.raw_variance) # guaranteed positive + return f"{v:.2f}Poly({self.c},{self.degree})" diff --git a/pysindy/jsindy/kernels/matern.py b/pysindy/jsindy/kernels/matern.py new file mode 100644 index 000000000..9ad48dac5 --- /dev/null +++ b/pysindy/jsindy/kernels/matern.py @@ -0,0 +1,86 @@ +import 
jax +import sympy as sym +import sympy2jax +from jax import custom_jvp +from sympy import factorial + +jax.config.update("jax_enable_x64", True) + + +def make_custom_jvp_function(f, fprime): + """Return a function with custom JVP defined by (fprime).""" + + @jax.custom_jvp + def f_wrapped(x): + return f(x) + + @f_wrapped.defjvp + def f_jvp(primals, tangents): + (x,) = primals + (x_dot,) = tangents + return f(x), fprime(x) * x_dot + + return f_wrapped + + +def make_sympy_callable(expr): + def inner(d): + return sympy2jax.SymbolicModule(expr)(d=d) + + return inner + + +def get_sympy_matern(p): + d2 = sym.symbols("d2", positive=True, real=True) + exp_multiplier = -sym.sqrt(2 * p + 1) + coefficients = [ + (factorial(p) / factorial(2 * p)) + * (factorial(p + i) / (factorial(i) * factorial(p - i))) + * (sym.sqrt(8 * p + 4)) ** (p - i) + for i in range(p + 1) + ] + powers = list(range(p, -1, -1)) + matern = sum( + [c * sym.sqrt((d2**power)) for c, power in zip(coefficients, powers)] + ) * sym.exp(exp_multiplier * sym.sqrt(d2)) + return d2, matern + + +def build_matern_core(p): + d2, matern = get_sympy_matern(p) + d = sym.var("d", pos=True, real=True) + + maternd = sym.powdenest(matern.subs(d2, d**2)) + subrule = { + d * sym.DiracDelta(d): 0, + sym.Abs(d) * sym.DiracDelta(d): 0, + sym.Abs(d) * sym.sign(d): d, + d * sym.sign(d): sym.Abs(d), + } + + def compute_next_derivative(expr): + return sym.powdenest(sym.expand(expr.diff(d).subs(subrule))).subs(subrule) + + derivatives = [compute_next_derivative(maternd)] + for k in range(2 * p - 1): + derivatives.append(compute_next_derivative(derivatives[-1])) + + jax_derivatives = [make_sympy_callable(f) for f in derivatives] + + wrapped_derivatives = [ + make_custom_jvp_function(f, fprime) + for f, fprime in zip(jax_derivatives[:-1], jax_derivatives[1:]) + ] + + matern_func_raw = sympy2jax.SymbolicModule(maternd) + core_matern = custom_jvp(lambda d: matern_func_raw(d=d)) + + @core_matern.defjvp + def core_matern_jvp(primals, 
tangents): + (x,) = primals + (x_dot,) = tangents + ans = core_matern(x) + ans_dot = wrapped_derivatives[0](x) * x_dot + return ans, ans_dot + + return core_matern diff --git a/pysindy/jsindy/kernels/tree_opt.py b/pysindy/jsindy/kernels/tree_opt.py new file mode 100644 index 000000000..72d3c085c --- /dev/null +++ b/pysindy/jsindy/kernels/tree_opt.py @@ -0,0 +1,145 @@ +from warnings import warn + +import jax +import jax.numpy as jnp +from jsindy.util import tree_add +from jsindy.util import tree_dot +from jsindy.util import tree_scale +from tqdm.auto import tqdm + + +def build_armijo_linesearch(f, decrease_ratio=0.5, slope=0.05, max_iter=25): + def armijo_linesearch(x, f_curr, d, g, t0=0.1): + """ + x: current parameters (pytree) + f_curr: f(x) + d: descent direction (pytree) + g: gradient at x (pytree) + t0: initial step size + a: Armijo constant + """ + candidate = tree_add(x, tree_scale(d, -t0)) + dec0 = f(candidate) - f_curr + pred_dec0 = -t0 * tree_dot(d, g) + + # The loop state: (iteration, t, current decrease, predicted decrease) + init_state = (0, t0, dec0, pred_dec0) + + def cond_fun(state): + i, t, dec, pred_dec = state + # Continue while we haven't satisfied the Armijo condition and haven't exceeded max_iter iterations. 
+ not_enough_decrease = dec >= slope * pred_dec + return jnp.logical_and(i < max_iter, not_enough_decrease) + + def body_fun(state): + i, t, dec, pred_dec = state + t_new = decrease_ratio * t + candidate_new = tree_add(x, tree_scale(d, -t_new)) + dec_new = f(candidate_new) - f_curr + pred_dec_new = -t_new * tree_dot(d, g) + return (i + 1, t_new, dec_new, pred_dec_new) + + # Run the while loop + i_final, t_final, dec_final, pred_dec_final = jax.lax.while_loop( + cond_fun, body_fun, init_state + ) + armijo_rat_final = dec_final / pred_dec_final + candidate_final = tree_add(x, tree_scale(d, -t_final)) + return candidate_final, t_final, armijo_rat_final + + return armijo_linesearch + + +def run_gradient_descent( + loss, + init_params, + init_stepsize=0.001, + maxiter=10000, + tol=1e-6, + show_progress=True, + **kwargs, +): + params = init_params + losses = [] + step_sizes = [] + gnorms = [] + + loss_valgrad = jax.value_and_grad(loss) + loss_fun = loss + armijo_linesearch = build_armijo_linesearch(loss_fun, **kwargs) + t = init_stepsize + + @jax.jit + def gd_update(params, t): + lossval, g = loss_valgrad(params) + new_params, new_t, armijo_rat = armijo_linesearch(params, lossval, g, g, t0=t) + gnorm = jnp.sqrt(tree_dot(g, g)) + return new_params, new_t, gnorm, lossval, armijo_rat + + if show_progress: + wrapper = tqdm + else: + wrapper = lambda x: x + + for i in wrapper(range(maxiter)): + params, t, gnorm, lossval, armijo_rat = gd_update(params, t) + if armijo_rat < 0.01: + warn("Line search failed") + if i > 0: + if lossval > losses[-1]: + print(lossval) + losses.append(lossval) + step_sizes.append(t) + gnorms.append(gnorm) + if gnorm < tol: + break + if armijo_rat > 0.5: + t = 1.2 * t + if armijo_rat < 0.1: + t = t / 2 + + conv_history = { + "values": jnp.array(losses), + "stepsizes": jnp.array(step_sizes), + "gradnorms": jnp.array(gnorms), + } + return params, conv_history + + +def run_jaxopt_solver(solver, x0, show_progress=True): + state = solver.init_state(x0) + 
sol = x0 + values, errors, stepsizes = [state.value], [state.error], [state.stepsize] + num_restarts = 0 + + @jax.jit + def update(sol, state): + return solver.update(sol, state) + + if show_progress: + wrapper = tqdm + else: + wrapper = lambda x: x + + for iter_num in wrapper(range(solver.maxiter)): + sol, state = update(sol, state) + values.append(state.value) + errors.append(state.error) + stepsizes.append(state.stepsize) + if solver.verbose > 0: + print("Gradient Norm: ", state.error) + print("Loss Value: ", state.value) + if state.error <= solver.tol: + break + if stepsizes[-1] == 0: + num_restarts = num_restarts + 1 + print(f"Restart {num_restarts}") + if num_restarts > 10: + break + state = solver.init_state(sol) + convergence_data = { + "values": jnp.array(values), + "gradnorms": jnp.array(errors), + "stepsizes": jnp.array(stepsizes), + } + return sol, convergence_data, state diff --git a/pysindy/jsindy/kerneltools.py b/pysindy/jsindy/kerneltools.py new file mode 100644 index 000000000..3040e8d95 --- /dev/null +++ b/pysindy/jsindy/kerneltools.py @@ -0,0 +1,94 @@ +from functools import partial +from types import ModuleType +from typing import Any +from typing import Callable + +import jax +import jax.numpy as jnp +from jax import grad + + +def diagpart(M): + return jnp.diag(jnp.diag(M)) + + +def vectorize_kfunc(k): + return jax.vmap(jax.vmap(k, in_axes=(None, 0)), in_axes=(0, None)) + + +def op_k_apply(k: Callable[[float, float], float], L_op, R_op): + return R_op(L_op(k, 0), 1) + + +def make_block(k, L_op, R_op): + return vectorize_kfunc(op_k_apply(k, L_op, R_op)) + + +def get_kernel_block_ops( + k, ops_left, ops_right, output_dim=1, type_pkg: ModuleType = jnp +): + def k_super(x, y): + I_mat = type_pkg.eye(output_dim) + blocks = [ + [ + type_pkg.kron(make_block(k, L_op, R_op)(x, y), I_mat) + for R_op in ops_right + ] + for L_op in ops_left + ] + return type_pkg.block(blocks) + + return k_super + + +def eval_k(k, index): + return k + + +def diff_k(k, 
index):
    # First derivative of the bivariate kernel w.r.t. positional arg `index`.
    return grad(k, index)


def diff2_k(k, index):
    # Second derivative w.r.t. positional arg `index` (grad applied twice).
    return grad(grad(k, index), index)


def get_selected_grad(k, index, selected_index):
    """Return a scalar function: component `selected_index` of the gradient
    of `k` taken w.r.t. positional argument `index`.

    Assumes that argument `index` of `k` is array-like, so its gradient can
    be indexed — TODO confirm against callers.
    """
    gradf = grad(k, index)

    def selgrad(*args):
        # Pick one component of the gradient vector.
        return gradf(*args)[selected_index]

    return selgrad


def dx_k(k, index):
    # Derivative w.r.t. component 1 of argument `index`
    # (presumably the spatial coordinate x; verify against callers).
    return get_selected_grad(k, index, 1)


def dxx_k(k, index):
    # Second derivative w.r.t. component 1 of argument `index`.
    return get_selected_grad(get_selected_grad(k, index, 1), index, 1)


def dt_k(k, index):
    # Derivative w.r.t. component 0 of argument `index`
    # (presumably the time coordinate t; verify against callers).
    return get_selected_grad(k, index, 0)


def nth_derivative_1d(k: Callable, index: int, n: int) -> Callable:
    """
    Computes derivative of order n of k with respect to index and returns the resulting
    function as a callable
    """
    result = k
    # Apply jax.grad n times; n == 0 returns k unchanged.
    for _ in range(n):
        result = jax.grad(result, argnums=index)
    return result


def nth_derivative_operator_1d(n):
    """
    Computes the operator associated to the nth derivative, which maps functions to
    functions. These now match the format of the operators defined above, like diff_k,
    diff2_k.
    """
    # Partially apply n so the result has the (k, index) operator signature.
    return partial(nth_derivative_1d, n=n)
diff --git a/pysindy/jsindy/optim/__init__.py b/pysindy/jsindy/optim/__init__.py
new file mode 100644
index 000000000..c934dc5cd
--- /dev/null
+++ b/pysindy/jsindy/optim/__init__.py
@@ -0,0 +1,11 @@
from .optimizers import AlternatingActiveSetLMSolver
from .optimizers import AnnealedAlternatingActiveSetLMSolver
from .optimizers import LMSettings
from .optimizers import LMSolver

__all__ = [
    "AlternatingActiveSetLMSolver",
    "LMSolver",
    "LMSettings",
    "AnnealedAlternatingActiveSetLMSolver",
]
diff --git a/pysindy/jsindy/optim/optimizers.py b/pysindy/jsindy/optim/optimizers.py
new file mode 100644
index 000000000..c15a80714
--- /dev/null
+++ b/pysindy/jsindy/optim/optimizers.py
@@ -0,0 +1,376 @@
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag
from jsindy.optim.solvers.alt_active_set_lm_solver import AlternatingActiveSolve
from 
jsindy.optim.solvers.lm_solver import CholeskyLM +from jsindy.optim.solvers.lm_solver import LMSettings +from jsindy.trajectory_model import TrajectoryModel +from jsindy.util import full_data_initialize +from jsindy.util import partial_obs_initialize + + +class LMSolver: + def __init__(self, beta_reg=1.0, solver_settings=LMSettings()): + self.solver_settings = solver_settings + self.beta_reg = beta_reg + + def run(self, model, params): + # init_params = params["init_params"] + params["data_weight"] = 1 / (params["sigma2_est"] + 0.01) + params["colloc_weight"] = 10 + + if model.is_partially_observed is False: + z0, theta0 = full_data_initialize( + model.t, + model.x, + model.traj_model, + model.dynamics_model, + sigma2_est=params["sigma2_est"] + 0.01, + input_orders=model.input_orders, + ode_order=model.ode_order, + ) + else: + z0, theta0 = partial_obs_initialize( + model.t, + model.y, + model.v, + model.traj_model, + model.dynamics_model, + sigma2_est=params["sigma2_est"] + 0.01, + input_orders=model.input_orders, + ode_order=model.ode_order, + ) + z_theta_init = jnp.hstack([z0, theta0.flatten()]) + + def resid_func(z_theta): + z = z_theta[: model.traj_model.tot_params] + theta = z_theta[model.traj_model.tot_params :].reshape( + model.dynamics_model.param_shape + ) + return model.residuals.residual( + z, theta, params["data_weight"], params["colloc_weight"] + ) + + jac_func = jax.jacrev(resid_func) + damping_matrix = block_diag( + model.traj_model.regmat, model.dynamics_model.regmat + ) + + lm_prob = LMProblem(resid_func, jac_func, damping_matrix) + z_theta, opt_results = CholeskyLM( + z_theta_init, lm_prob, self.beta_reg, self.solver_settings + ) + z = z_theta[: model.traj_model.tot_params] + theta = z_theta[model.traj_model.tot_params :].reshape( + model.dynamics_model.param_shape + ) + + return z, theta, opt_results, params + + +class LMProblem: + def __init__(self, resid_func, jac_func, damping_matrix): + self.resid_func = resid_func + self.jac_func = jac_func 
+ self.damping_matrix = damping_matrix + + +class AlternatingActiveSetLMSolver: + def __init__( + self, + beta_reg=1.0, + colloc_weight_scale=100.0, + fixed_colloc_weight=None, + fixed_data_weight=None, + solver_settings=LMSettings(), + max_inner_iterations=200, + sparsifier=None, + ): + self.solver_settings = solver_settings + self.beta_reg = beta_reg + self.colloc_weight_scale = colloc_weight_scale + self.fixed_colloc_weight = fixed_colloc_weight + self.fixed_data_weight = fixed_data_weight + self.max_inner_iterations = max_inner_iterations + self.sparsifier = sparsifier + self.params = {} + + def __str__(self): + return f""" + Alternating Active Set Optimizer + beta_reg: {self.beta_reg}, + sparsifier: {self.sparsifier.__str__()} + data_weight: {self.params['data_weight']} + colloc_weight: {self.params['colloc_weight']} + """ + + def run(self, model, params): + if self.fixed_data_weight is not None: + params["data_weight"] = self.fixed_data_weight + else: + params["data_weight"] = 1 / (params["sigma2_est"] + 0.001) + if self.fixed_colloc_weight is None: + params["colloc_weight"] = self.colloc_weight_scale * params["data_weight"] + else: + params["colloc_weight"] = self.fixed_colloc_weight + print(params) + + if model.is_partially_observed is False: + z0, theta0 = full_data_initialize( + model.t, + model.x, + model.traj_model, + model.dynamics_model, + sigma2_est=params["sigma2_est"] + 0.01, + input_orders=model.input_orders, + ode_order=model.ode_order, + ) + else: + z0, theta0 = partial_obs_initialize( + model.t, + model.y, + model.v, + model.traj_model, + model.dynamics_model, + sigma2_est=params["sigma2_est"] + 0.01, + input_orders=model.input_orders, + ode_order=model.ode_order, + ) + z_theta_init = jnp.hstack([z0, theta0.flatten()]) + + def resid_func(z_theta): + z = z_theta[: model.traj_model.tot_params] + theta = z_theta[model.traj_model.tot_params :].reshape( + model.dynamics_model.param_shape + ) + return model.residuals.residual( + z, theta, 
params["data_weight"], params["colloc_weight"] + ) + + jac_func = jax.jacrev(resid_func) + damping_matrix = block_diag( + model.traj_model.regmat, model.dynamics_model.regmat + ) + + lm_prob = LMProblem(resid_func, jac_func, damping_matrix) + if self.solver_settings.show_progress: + print("Warm Start") + + z_theta, lm_opt_results = CholeskyLM( + z_theta_init, lm_prob, self.beta_reg, self.solver_settings + ) + z = z_theta[: model.traj_model.tot_params] + theta = z_theta[model.traj_model.tot_params :].reshape( + model.dynamics_model.param_shape + ) + + if self.solver_settings.show_progress: + print("Model after smooth warm start") + model.print(theta=theta) + print("Alternating Activeset Sparsifier") + + def F_split(z, theta): + data_weight = params["data_weight"] + colloc_weight = params["colloc_weight"] + return model.residuals.residual(z, theta, data_weight, colloc_weight) + + # fix this later + aaslm_prob = AASLMProblem( + system_dim=model.traj_model.system_dim, + num_features=model.dynamics_model.num_features, + F_split=F_split, + t_colloc=model.t_colloc, + interpolant=model.traj_model, + state_param_regmat=model.traj_model.regmat, + model_param_regmat=model.dynamics_model.regmat, + feature_library=model.dynamics_model.feature_map, + ) + + z, theta, aas_lm_opt_results = AlternatingActiveSolve( + z0=z, + theta0=theta, + residual_objective=aaslm_prob, + beta=self.beta_reg, + show_progress=self.solver_settings.show_progress, + max_inner_iter=self.max_inner_iterations, + sparsifier=self.sparsifier, + input_orders=model.input_orders, + ode_order=model.ode_order, + ) + theta = theta.reshape(model.dynamics_model.param_shape) + self.params = params + + return z, theta, [lm_opt_results, aas_lm_opt_results], params + + +class AnnealedAlternatingActiveSetLMSolver: + def __init__( + self, + beta_reg=1.0, + colloc_weight_scale=100.0, + fixed_colloc_weight=None, + fixed_data_weight=None, + solver_settings=LMSettings(), + max_inner_iterations=200, + sparsifier=None, + 
num_annealing_steps=4, + anneal_colloc_mult=5.0, + anneal_beta_mult=2.0, + ): + self.solver_settings = solver_settings + self.beta_reg = beta_reg + self.colloc_weight_scale = colloc_weight_scale + self.fixed_colloc_weight = fixed_colloc_weight + self.fixed_data_weight = fixed_data_weight + self.max_inner_iterations = max_inner_iterations + self.sparsifier = sparsifier + + self.num_annealing_steps = num_annealing_steps + self.anneal_colloc_mult = anneal_colloc_mult + self.anneal_beta_mult = anneal_beta_mult + + def __str__(self): + return f""" + Annealed Alternating Active Set Optimizer + beta_reg: {self.beta_reg}, + sparsifier: {self.sparsifier.__str__()} + data_weight: {self.fixed_data_weight} + colloc_weight: {self.fixed_colloc_weight} + annealing_steps: {self.anneal_colloc_mult} + anneal_colloc_mult: {self.anneal_colloc_mult} + anneal_beta_mult: {self.anneal_beta_mult} + """ + + def run(self, model, params): + sigma2est = params.get("sigma2_est", 0) + if sigma2est is None: + # If not using data-adapted interpolant + sigma2est = 0.0 + if self.fixed_data_weight is not None: + params["data_weight"] = self.fixed_data_weight + else: + params["data_weight"] = 1 / (sigma2est + 0.001) + if self.fixed_colloc_weight is None: + params["colloc_weight"] = self.colloc_weight_scale * params["data_weight"] + else: + params["colloc_weight"] = self.fixed_colloc_weight + print(params) + if model.is_partially_observed is False: + z0, theta0 = full_data_initialize( + model.t, + model.x, + model.traj_model, + model.dynamics_model, + sigma2_est=sigma2est + 0.01, + input_orders=model.input_orders, + ode_order=model.ode_order, + ) + else: + z0, theta0 = partial_obs_initialize( + model.t, + model.y, + model.v, + model.traj_model, + model.dynamics_model, + sigma2_est=sigma2est + 0.01, + input_orders=model.input_orders, + ode_order=model.ode_order, + ) + z_theta_init = jnp.hstack([z0, theta0.flatten()]) + + num_steps = self.num_annealing_steps + dataweight_vals = [params["data_weight"]] * 
num_steps + colloc_weight_vals = [ + params["colloc_weight"] * (self.anneal_colloc_mult ** (i + 1 - num_steps)) + for i in range(num_steps) + ] + beta_reg_vals = [ + self.beta_reg * (self.anneal_beta_mult ** (num_steps - i - 1)) + for i in range(num_steps) + ] + parameter_sequence = zip(dataweight_vals, colloc_weight_vals, beta_reg_vals) + + def resid_func(z_theta, data_weight, colloc_weight): + z = z_theta[: model.traj_model.tot_params] + theta = z_theta[model.traj_model.tot_params :].reshape( + model.dynamics_model.param_shape + ) + return model.residuals.residual(z, theta, data_weight, colloc_weight) + + full_lm_opt_results = [] + + for data_weight, colloc_weight, beta_reg in parameter_sequence: + residual_function = partial( + resid_func, data_weight=data_weight, colloc_weight=colloc_weight + ) + jac_func = jax.jacrev(residual_function) + damping_matrix = block_diag( + model.traj_model.regmat, model.dynamics_model.regmat + ) + + lm_prob = LMProblem(residual_function, jac_func, damping_matrix) + if self.solver_settings.show_progress: + print( + f"Solving for data_weight = {data_weight}, colloc_weight = {colloc_weight} beta_reg = {beta_reg}" + ) + z_theta, lm_opt_results = CholeskyLM( + z_theta_init, lm_prob, self.beta_reg, self.solver_settings + ) + z_theta_init = z_theta + full_lm_opt_results.append(lm_opt_results) + z = z_theta[: model.traj_model.tot_params] + theta = z_theta[model.traj_model.tot_params :].reshape( + model.dynamics_model.param_shape + ) + + if self.solver_settings.show_progress: + print("Model after smooth warm start") + model.print(theta=theta) + print("Alternating Activeset Sparsifier") + + def F_split(z, theta): + data_weight = params["data_weight"] + colloc_weight = params["colloc_weight"] + return model.residuals.residual(z, theta, data_weight, colloc_weight) + + # fix this later + aaslm_prob = AASLMProblem( + system_dim=model.traj_model.system_dim, + num_features=model.dynamics_model.num_features, + F_split=F_split, + 
t_colloc=model.t_colloc, + interpolant=model.traj_model, + state_param_regmat=model.traj_model.regmat, + model_param_regmat=model.dynamics_model.regmat, + feature_library=model.dynamics_model.feature_map, + ) + + z, theta, aas_lm_opt_results = AlternatingActiveSolve( + z0=z, + theta0=theta, + residual_objective=aaslm_prob, + beta=self.beta_reg, + show_progress=self.solver_settings.show_progress, + max_inner_iter=self.max_inner_iterations, + sparsifier=self.sparsifier, + input_orders=model.input_orders, + ode_order=model.ode_order, + ) + theta = theta.reshape(model.dynamics_model.param_shape) + + return z, theta, [full_lm_opt_results, aas_lm_opt_results], params + + +@dataclass +class AASLMProblem: + system_dim: int + num_features: int + F_split: callable + t_colloc: jax.Array + interpolant: TrajectoryModel + state_param_regmat: jax.Array + model_param_regmat: jax.Array + feature_library: callable diff --git a/pysindy/jsindy/optim/solvers/__init__.py b/pysindy/jsindy/optim/solvers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pysindy/jsindy/optim/solvers/alt_active_set_lm_solver.py b/pysindy/jsindy/optim/solvers/alt_active_set_lm_solver.py new file mode 100644 index 000000000..f2f6e60db --- /dev/null +++ b/pysindy/jsindy/optim/solvers/alt_active_set_lm_solver.py @@ -0,0 +1,210 @@ +import time +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +import numpy as np + +from pysindy import STLSQ + + +def norm2(x): + if len(x) == 0: + return 0.0 + else: + return jnp.sum(x**2) + + +def maxnorm(x): + if len(x) == 0: + return 0.0 + else: + return jnp.max(jnp.abs(x)) + + +@dataclass +class ConvHistory: + history: dict + convergence_tag: str + + +class pySindySparsifier: + def __init__(self, pysindy_optimizer=None): + if pysindy_optimizer is None: + pysindy_optimizer = STLSQ(threshold=0.25, alpha=0.01) + self.optimizer = pysindy_optimizer + + def __str__(self): + return self.optimizer.__str__() + + def __call__(self, feat_X, 
Xdot): + self.optimizer.fit(feat_X, Xdot) + theta = jnp.array(self.optimizer.coef_) + return theta + + +def AlternatingActiveSolve( + z0, + theta0, + residual_objective, + beta, + sparsifier: pySindySparsifier = None, + show_progress: bool = True, + max_inner_iter=50, + input_orders=(0,), + ode_order=1, +): + start_time = time.time() + if sparsifier is None: + sparsifier = pySindySparsifier() + ##Initialize with LMSolve results and + # Set up objective + + z = z0 + theta = theta0 + + # init_shape = (residual_objective.system_dim,residual_objective.num_features) + theta_shape = (residual_objective.num_features, residual_objective.system_dim) + + # theta = theta.flatten() + + # resfunc = jax.jit(lambda z,theta:residual_objective.F_split(z,theta.reshape(init_shape).T)) + resfunc = jax.jit( + lambda z, theta: residual_objective.F_split(z, theta.reshape(theta_shape)) + ) + + jac_z = jax.jit(jax.jacrev(resfunc, argnums=0)) + jac_theta = jax.jit(jax.jacrev(resfunc, argnums=1)) + + t_colloc = residual_objective.t_colloc + interp = residual_objective.interpolant + + tol = 1e-6 + + @jax.jit + def loss(z, theta): + return (1 / 2) * jnp.sum(resfunc(z, theta) ** 2) + ( + beta / 2 + ) * z.T @ residual_objective.state_param_regmat @ z + + def update_coefs(z): + X_inputs = jnp.hstack([interp.derivative(t_colloc, z, k) for k in input_orders]) + feat_X = residual_objective.feature_library(X_inputs) + + Xdot = interp.derivative(t_colloc, z, ode_order) + theta = sparsifier(feat_X, Xdot).T + return theta.flatten() + + def update_params(z, theta, prox_reg_init=1.0): + active_set = jnp.where(jnp.abs(theta) > 1e-7)[0] + m = len(active_set) + prox_reg = prox_reg_init + K = residual_objective.state_param_regmat + I = residual_objective.model_param_regmat[active_set][:, active_set] + H = jax.scipy.linalg.block_diag(K, I) + K0 = jax.scipy.linalg.block_diag(K, 0 * I) + obj_val = loss(z, theta) + + Jz = jac_z(z, theta) + Jtheta = jac_theta(z, theta)[:, active_set] + J = jnp.hstack([Jz, Jtheta]) 
+ F = resfunc(z, theta) + JtJ = J.T @ J + rhs = J.T @ F + beta * jnp.hstack([K @ z, jnp.zeros(m)]) + + loss_vals = [obj_val] + gnorms = [maxnorm(rhs)] + + max_line_search = 20 + for i in range(max_inner_iter): + for k in range(max_line_search): + succeeded = False + + M = JtJ + beta * K0 + prox_reg * H + step = jnp.linalg.solve(M, rhs) + + dz = step[:-m] + dtheta = step[-m:] + + cand_z = z - dz + cand_theta = theta.at[active_set].set(theta[active_set] - dtheta) + + new_obj_val = loss(cand_z, cand_theta) + + predicted_decrease = obj_val - ( + 0.5 * norm2(F - J @ step) + 0.5 * beta * cand_z.T @ K @ cand_z + ) + + true_decrease = obj_val - new_obj_val + + rho = true_decrease / predicted_decrease + if rho > 0.001 and true_decrease > 0: + succeeded = True + obj_val = new_obj_val + z = cand_z + theta = cand_theta + break + else: + prox_reg = 2 * prox_reg + + if succeeded is False: + print("Line search Failed") + break + Jz = jac_z(z, theta) + Jtheta = jac_theta(z, theta)[:, active_set] + J = jnp.hstack([Jz, Jtheta]) + F = resfunc(z, theta) + JtJ = J.T @ J + rhs = J.T @ F + beta * jnp.hstack([K @ z, jnp.zeros(m)]) + prox_reg = jnp.maximum(1 / 3, 1 - (2 * rho - 1) ** 3) * prox_reg + gnorm = maxnorm(rhs) + gnorms.append(gnorm) + loss_vals.append(obj_val) + if gnorm < tol: + break + + return z, theta, gnorms, loss_vals, prox_reg + + all_gnorms = [] + all_objval = [] + cum_time = [] + prox_reg = 1.0 + finished = False + support = jnp.where(jnp.abs(theta) > 1e-7)[0] + for i in range(20): + theta = update_coefs(z) + new_support = jnp.where(jnp.abs(theta) > 1e-7)[0] + z_new, theta, gnorms, objvals, prox_reg = update_params(z, theta, 2 * prox_reg) + all_gnorms.append(gnorms) + all_objval.append(objvals) + dz = jnp.linalg.norm(z - z_new) + z = z_new + if finished == True: + cum_time.append(time.time() - start_time) + break + + if set(np.array(support)) == set(np.array(new_support)): + # Run 1 more iteration just to be sure + if show_progress: + print("Active set stabilized") + 
convergence_tag = "stable-active-set" + finished = True + else: + sym_diff = set(np.array(support)).symmetric_difference( + set(np.array(new_support)) + ) + if show_progress: + print(f"{len(sym_diff)} active coeffs changed") + support = new_support + + cum_time.append(time.time() - start_time) + convergence_tag = "maximum-iterations" + conv_hist = ConvHistory( + history={ + "gnorms": all_gnorms, + "objval": all_objval, + "cumlative_time": cum_time, + }, + convergence_tag=convergence_tag, + ) + return z, theta, conv_hist diff --git a/pysindy/jsindy/optim/solvers/lm_solver.py b/pysindy/jsindy/optim/solvers/lm_solver.py new file mode 100644 index 000000000..466421309 --- /dev/null +++ b/pysindy/jsindy/optim/solvers/lm_solver.py @@ -0,0 +1,372 @@ +import time +from dataclasses import dataclass +from dataclasses import field +from typing import Callable +from typing import Union + +import jax +import jax.numpy as jnp +from jax.scipy.linalg import cho_factor +from jax.scipy.linalg import cho_solve +from tqdm.auto import tqdm + + +@dataclass +class LMSettings: + """ + max_iter : int, optional + by default 501 + atol_gradnorm : float, optional + Gradient norm stopping condition absolute tolerance + atol_gn_decrement: float, optional + Gauss-Newton decrement stopping condition absolute tolerance + cmin : float, optional + Minimum armijo ratio to accept step, by default 0.05 + line_search_increase_ratio : float, optional + constant to increase reg strength by in backtracking line search, by default 1.5 + max_line_search_iterations : int, optional + by default 20 + min_alpha : float, optional + min damping strength, by default 1e-9 + max_alpha : float, optional + max damping strength, by default 50. + init_alpha : float, optional + initial damping strength, by default 3. 
+ step_adapt_multipler : float, optional + value to use for adapting alpha, by default 1.2 + callback : callable, optional + function called to print another loss each iteration, by default None + print_every : int, optional + How often to print convergence data, by default 100 + """ + + max_iter: int = 501 + atol_gradnorm: float = 1e-8 + cmin: float = 0.05 + line_search_increase_ratio: float = 1.5 + max_line_search_iterations: int = 20 + min_alpha: float = 1e-12 + max_alpha: float = 100.0 + init_alpha: float = 3.0 + step_adapt_multiplier: float = 1.2 + callback: Union[Callable, None] = None + print_every: int = 200 + track_iterates: bool = False + show_progress: bool = True + use_jit: bool = True + no_tqdm: bool = False + + +@dataclass +class ConvergenceHistory: + track_iterates: bool = False + loss_vals: list = field(default_factory=list) + gradnorm: list = field(default_factory=list) + iterate_history: list = field(default_factory=list) + improvement_ratios: list = field(default_factory=list) + alpha_vals: list = field(default_factory=list) + cumulative_time: list = field(default_factory=list) + linear_system_rel_residual: list = field(default_factory=list) + regularization_loss_contribution: list = field(default_factory=list) + convergence_tag: str = "not-yet-run" + + def update( + self, + loss, + gradnorm, + iterate, + armijo_ratio, + alpha, + cumulative_time, + linear_system_rel_residual, + regularization_loss_contribution=0.0, + ): + # Append the new values to the corresponding lists + self.loss_vals.append(loss) + self.gradnorm.append(gradnorm) + self.improvement_ratios.append(armijo_ratio) + self.alpha_vals.append(alpha) + self.cumulative_time.append(cumulative_time) + self.linear_system_rel_residual.append(linear_system_rel_residual) + self.regularization_loss_contribution.append(regularization_loss_contribution) + + # Conditionally track iterates if enabled + if self.track_iterates: + self.iterate_history.append(iterate) + + def finish(self, 
convergence_tag="finished"): + # Convert lists to JAX arrays + self.loss_vals = jnp.array(self.loss_vals) + self.gradnorm = jnp.array(self.gradnorm) + self.improvement_ratios = jnp.array(self.improvement_ratios) + self.alpha_vals = jnp.array(self.alpha_vals) + self.cumulative_time = jnp.array(self.cumulative_time) + self.linear_system_rel_residual = jnp.array(self.linear_system_rel_residual) + self.regularization_loss_contribution = jnp.array( + self.regularization_loss_contribution + ) + if self.track_iterates: + self.iterate_history = jnp.array(self.iterate_history) + self.convergence_tag = convergence_tag + + +def print_progress( + i, + loss, + gradnorm, + alpha, + improvement_ratio, +): + print( + f"Iteration {i}, loss = {loss:.4}," + f" gradnorm = {gradnorm:.4}, alpha = {alpha:.4}," + f" improvement_ratio = {improvement_ratio:.4}" + ) + + +def CholeskyLM(init_params, model, beta, optSettings: LMSettings = LMSettings()): + """Adaptively regularized Levenberg Marquardt optimizer + Parameters + ---------- + init_params : jax array + initial guess + model : + Object that contains model.F, and model.jac, and model.damping_matrix + beta : float + (global) regularization strength + optSettings: LMParams + optimizer settings + + Returns + ------- + solution + approximate minimizer + convergence_history + ConvergenceHistory tracker + """ + conv_history = ConvergenceHistory(optSettings.track_iterates) + start_time = time.time() + params = init_params.copy() + J = model.jac_func(params) + residuals = model.resid_func(params) + damping_matrix = model.damping_matrix + alpha = optSettings.init_alpha + if optSettings.show_progress and optSettings.no_tqdm is False: + loop_wrapper = tqdm + else: + loop_wrapper = lambda x: x + + regularization_contribution = (1 / 2) * beta * params.T @ damping_matrix @ params + conv_history.update( + loss=(1 / 2) * jnp.sum(residuals**2) + regularization_contribution, + gradnorm=jnp.linalg.norm(J.T @ residuals + beta * damping_matrix @ params), 
+ iterate=params, + armijo_ratio=1.0, + alpha=alpha, + cumulative_time=time.time() - start_time, + linear_system_rel_residual=0.0, + regularization_loss_contribution=regularization_contribution, + ) + + def evaluate_objective(params): + """ + Queries the objective, computing jacobian and residuals at + current parameters to build a subproblem + """ + J = model.jac_func(params) + residuals = model.resid_func(params) + damping_matrix = model.damping_matrix + loss = (1 / 2) * jnp.sum(residuals**2) + ( + 1 / 2 + ) * beta * params.T @ damping_matrix @ params + JtJ = J.T @ J + rhs = J.T @ residuals + beta * damping_matrix @ params + return J, residuals, damping_matrix, loss, JtJ, rhs + + if optSettings.use_jit is True: + evaluate_objective = jax.jit(evaluate_objective) + + @jax.jit + def compute_step( + params, alpha, J, JtJ, residuals, rhs, previous_loss, damping_matrix + ): + """ + Solves subproblem constructed by evaluate_objective + """ + # Form and solve linear system for step + M = JtJ + (alpha + beta) * damping_matrix + # Add small nugget + M = M + 1e-12 * jnp.diag(jnp.diag(M)) + Mchol = cho_factor(M) + step = cho_solve(Mchol, rhs) + Jstep = J @ step + + # Apply 1 step of iterative refinement + linear_residual = ( + J.T @ (Jstep - residuals) + + (alpha + beta) * damping_matrix @ step + - beta * damping_matrix @ params + ) + step = step - cho_solve(Mchol, linear_residual) + + # Track the linear system residual + linear_residual = ( + J.T @ (Jstep - residuals) + + (alpha + beta) * damping_matrix @ step + - beta * damping_matrix @ params + ) + linear_system_rel_residual = jnp.linalg.norm(linear_residual) / jnp.linalg.norm( + rhs + ) + + # Compute step and if we decreased loss + new_params = params - step + new_reg_piece = (1 / 2) * beta * new_params.T @ damping_matrix @ new_params + new_loss = (1 / 2) * jnp.sum(model.resid_func(new_params) ** 2) + new_reg_piece + predicted_loss = (1 / 2) * jnp.sum((Jstep - residuals) ** 2) + new_reg_piece + improvement_ratio = 
(previous_loss - new_loss) / ( + previous_loss - predicted_loss + ) + + return ( + step, + new_params, + new_loss, + improvement_ratio, + linear_system_rel_residual, + new_reg_piece, + ) + + def LevenbergMarquadtUpdate(params, alpha): + r"""Minimizes the local quadratic approximation to a function + and performs a line search on the proximal regularization alpha + to ensure sufficient decrease. + + Solves for the negative optimal update, using proximal regularization + to control how close (in an L2 sense) the update is to zero. + Optimization variable is :math:`u`, the negative update between + previous iterate :math:`x^-` and next iterate :math:`x^+` + .. math:: + \min_u \|Ju + r\|^2 + + \text{reg_weight} \|x^--u\|^2_K + + \alpha \|step\|^2_K + where :math:`r` is the residual vector, and the damping matrix + :math:`damping_matrix` adjusts the L-2 regularization to be an elliptical norm + (e.g. an RKHS norm) + Args: + params: Current parametrization value of function to approximate + alpha: damping strength. Larger values shrink the step size. 
+ """ + J, residuals, damping_matrix, loss, JtJ, rhs = evaluate_objective(params) + alpha = jnp.clip(alpha, optSettings.min_alpha, optSettings.max_alpha) + for i in range(optSettings.max_line_search_iterations): + ( + step, + new_params, + new_loss, + improvement_ratio, + linear_system_rel_residual, + new_reg_piece, + ) = compute_step( + params, alpha, J, JtJ, residuals, rhs, loss, damping_matrix + ) + + if improvement_ratio >= optSettings.cmin: + # Check if we get at least some proportion of predicted improvement + succeeded = True + return ( + new_params, + new_loss, + rhs, + improvement_ratio, + alpha, + linear_system_rel_residual, + new_reg_piece, + succeeded, + ) + else: + alpha = optSettings.line_search_increase_ratio * alpha + succeeded = False + return ( + new_params, + new_loss, + rhs, + improvement_ratio, + alpha, + linear_system_rel_residual, + new_reg_piece, + succeeded, + ) + + for i in loop_wrapper(range(optSettings.max_iter)): + ( + params, + loss, + rhs, + improvement_ratio, + alpha, + linear_system_rel_residual, + reg_piece, + succeeded, + ) = LevenbergMarquadtUpdate(params, alpha) + + # Get new value for alpha + multiplier = optSettings.step_adapt_multiplier + if improvement_ratio <= 0.2: + alpha = multiplier * alpha + if improvement_ratio >= 0.8: + alpha = alpha / multiplier + + if not succeeded: + print("Line Search Failed!") + print("Final Iteration Results") + if optSettings.show_progress is True: + print_progress( + i, loss, conv_history.gradnorm[-1], alpha, improvement_ratio + ) + conv_history.finish(convergence_tag="failed-line-search") + return params, conv_history + model_decrease = (conv_history.loss_vals[-1] - loss) / improvement_ratio + conv_history.update( + loss=loss, + gradnorm=jnp.linalg.norm(rhs), + iterate=params, + armijo_ratio=improvement_ratio, + alpha=alpha, + cumulative_time=time.time() - start_time, + linear_system_rel_residual=linear_system_rel_residual, + regularization_loss_contribution=reg_piece, + ) + + if 
conv_history.gradnorm[-1] <= optSettings.atol_gradnorm: + conv_history.finish(convergence_tag="atol-gradient-norm") + if optSettings.show_progress is True: + print_progress( + i, loss, conv_history.gradnorm[-1], alpha, improvement_ratio + ) + return params, conv_history + + if i > 50: + gradnorm_stagnate = ( + conv_history.gradnorm[-1] >= 0.99 * conv_history.gradnorm[-25] + ) + fval_stagnate = ( + conv_history.loss_vals[-1] >= conv_history.loss_vals[-25] - 1e-9 + ) + if gradnorm_stagnate and fval_stagnate: + conv_history.finish(convergence_tag="stagnation") + if optSettings.show_progress is True: + print_progress( + i, loss, conv_history.gradnorm[-1], alpha, improvement_ratio + ) + return params, conv_history + + if i % optSettings.print_every == 0 or i <= 5 or i == optSettings.max_iter - 1: + if optSettings.show_progress is True: + print_progress( + i, loss, conv_history.gradnorm[-1], alpha, improvement_ratio + ) + if optSettings.callback: + optSettings.callback(params) + conv_history.finish(convergence_tag="maximum-iterations") + return params, conv_history diff --git a/pysindy/jsindy/residual_functions.py b/pysindy/jsindy/residual_functions.py new file mode 100644 index 000000000..e6122412c --- /dev/null +++ b/pysindy/jsindy/residual_functions.py @@ -0,0 +1,100 @@ +import jax +import jax.numpy as jnp +from jsindy.dynamics_model import FeatureLinearModel +from jsindy.trajectory_model import TrajectoryModel + + +class FullDataTerm: + def __init__(self, t, x, trajectory_model: TrajectoryModel): + self.t = t + self.x = x + self.trajectory_model = trajectory_model + self.system_dim = x.shape[1] + self.num_obs = len(t) + self.total_size = self.num_obs * self.system_dim + + def residual(self, z): + # TODO: Code optimization, directly adapt trajectoy_model + # To the observation locations + return self.x - self.trajectory_model(self.t, z) + + def residual_flat(self, z): + return self.residual(z).flatten() + + +class PartialDataTerm: + def __init__(self, t, y, v, 
trajectory_model: TrajectoryModel): + self.t = t + self.y = y + self.v = v + self.trajectory_model = trajectory_model + self.system_dim = v.shape[1] + self.num_obs = len(t) + self.total_size = len(t) + + def residual(self, z): + pred_y = jnp.sum(self.trajectory_model(self.t, z) * self.v, axis=1) + return self.y - pred_y + + def residual_flat(self, z): + return self.residual(z) + + +class CollocationTerm: + def __init__( + self, + t_colloc, + w_colloc, + trajectory_model: TrajectoryModel, + dynamics_model: FeatureLinearModel, + input_orders=(0,), + ode_order=1, + ): + self.t_colloc = t_colloc + self.w_colloc = w_colloc + assert len(t_colloc) == len(w_colloc) + self.num_colloc = len(t_colloc) + self.system_dim = trajectory_model.system_dim + self.trajectory_model = trajectory_model + self.dynamics_model = dynamics_model + self.input_orders = input_orders + self.ode_order = ode_order + + def residual(self, z, theta): + X_inputs = jnp.hstack( + [ + self.trajectory_model.derivative(self.t_colloc, z, k) + for k in self.input_orders + ] + ) + + Xdot_pred = self.dynamics_model(X_inputs, theta) + Xdot_true = self.trajectory_model.derivative( + self.t_colloc, z, diff_order=self.ode_order + ) + return jnp.sqrt(self.w_colloc[:, None]) * (Xdot_true - Xdot_pred) + + def residual_flat(self, z, theta): + return self.residual(z, theta).flatten() + + +class JointResidual: + def __init__( + self, data_term: FullDataTerm | PartialDataTerm, colloc_term: CollocationTerm + ): + self.data_term = data_term + self.colloc_term = colloc_term + + def data_residual(self, z): + return self.data_term.residual_flat(z) + + def colloc_residual(self, z, theta): + return self.colloc_term.residual_flat(z, theta) + + def residual(self, z, theta, data_weight, colloc_weight): + return jnp.hstack( + [ + jnp.sqrt(data_weight) * self.data_residual(z), + jnp.sqrt(colloc_weight) * self.colloc_residual(z, theta), + ] + ) diff --git a/pysindy/jsindy/sindy_model.py b/pysindy/jsindy/sindy_model.py new file mode 
100644 index 000000000..91c573726 --- /dev/null +++ b/pysindy/jsindy/sindy_model.py @@ -0,0 +1,209 @@ +import jax +jax.config.update('jax_enable_x64',True) +import jax.numpy as jnp +from jsindy.util import check_is_partial_data,get_collocation_points_weights,get_equations +from jsindy.trajectory_model import TrajectoryModel +from jsindy.dynamics_model import FeatureLinearModel +from jsindy.residual_functions import ( + FullDataTerm,PartialDataTerm,CollocationTerm, + JointResidual) +from jsindy.optim import LMSolver +from textwrap import dedent + +class JSINDyModel(): + def __init__( + self, + trajectory_model:TrajectoryModel, + dynamics_model:FeatureLinearModel, + optimizer:LMSolver = LMSolver(), + feature_names: list[str] = None, + input_orders: tuple[int, ...] = (0,), + ode_order: int = 1, + ): + self.traj_model = trajectory_model + self.dynamics_model = dynamics_model + self.optimizer = optimizer + input_orders = tuple(sorted(input_orders)) + assert input_orders[0] == 0 + self.input_orders = input_orders + self.ode_order = ode_order + self.variable_names = feature_names.copy() + + if self.input_orders ==(0,): + self.feature_names = feature_names + else: + self.feature_names = ( + feature_names + + sum([ + [f"({name}{"'"*k})" for name in feature_names] for k in self.input_orders[1:] + ],[]) + ) + + def __str__(self): + traj_model_str = (self.traj_model.__str__()) + dynamics_model_str = (self.dynamics_model.__str__()) + optimizer_str = (self.optimizer.__str__()) + model_string = ( + f""" + --------Trajectory Model-------- + {traj_model_str} + + --------Feature Library--------- + {dynamics_model_str} + + --------Optimizer Setup-------- + {optimizer_str} + """ + ) + return '\n'.join(map(lambda x:x.lstrip(),model_string.__str__().split('\n'))) + + def initialize_fit_full_obs( + self, + t, + x, + t_colloc = None, + w_colloc = None, + params = None, + ): + if params is None: + params = dict() + t_colloc,w_colloc = _setup_colloc(t,t_colloc,w_colloc) + + self.t_colloc = 
t_colloc + self.w_colloc = w_colloc + self.t = t + self.x = x + + params = self.traj_model.initialize( + self.t,self.x,t_colloc,params + ) + + params = self.dynamics_model.initialize( + self.t,self.x,params,self.input_orders + ) + + self.data_term = FullDataTerm( + self.t,self.x,self.traj_model + ) + self.colloc_term = CollocationTerm( + self.t_colloc,self.w_colloc, + self.traj_model,self.dynamics_model, + input_orders = self.input_orders,ode_order = self.ode_order + ) + self.residuals = JointResidual(self.data_term,self.colloc_term) + return params + + def initialize_fit_partial_obs( + self, + t, + y, + v, + t_colloc = None, + w_colloc = None, + params= None + ): + if params is None: + params = dict() + t_colloc,w_colloc = _setup_colloc(t,t_colloc,w_colloc) + + self.t_colloc = t_colloc + self.w_colloc = w_colloc + self.t = t + self.y = y + self.v = v + + params = self.traj_model.initialize_partialobs( + self.t,self.y,self.v,t_colloc,params + ) + + params = self.dynamics_model.initialize_partialobs( + self.t,self.y,self.v,params,self.input_orders + ) + + self.data_term = PartialDataTerm( + self.t,self.y,self.v,self.traj_model + ) + self.colloc_term = CollocationTerm( + self.t_colloc,self.w_colloc, + self.traj_model,self.dynamics_model, + input_orders = self.input_orders,ode_order = self.ode_order + ) + self.residuals = JointResidual(self.data_term,self.colloc_term) + return params + + + def fit( + self, + t, + x = None, + t_colloc = None, + w_colloc = None, + params = None, + partialobs_y = None, + partialobs_v = None, + ): + #TODO: Add a logs dictionary that's carried around in the same way that params is + + if params is None: + params = dict() + params["show_progress"] = self.optimizer.solver_settings.show_progress + + is_partially_observed = check_is_partial_data(t,x,partialobs_y,partialobs_v) + self.is_partially_observed = is_partially_observed + if is_partially_observed is True: + params = self.initialize_fit_partial_obs( + t,partialobs_y,partialobs_v, + 
t_colloc,w_colloc,params + ) + else: + params = self.initialize_fit_full_obs( + t,x,t_colloc, + w_colloc, params + ) + + z,theta,opt_result,params = self.optimizer.run(self,params) + self.z = z + self.theta = theta + self.opt_result = opt_result + self.params = params + + def print(self,theta=None, precision: int = 3, **kwargs) -> None: + """Print the SINDy model equations. + precision: int, optional (default 3) + Precision to be used when printing out model coefficients. + **kwargs: Additional keyword arguments passed to the builtin print function + """ + if theta is None: + theta = self.theta + eqns = get_equations( + coef = theta.T, + feature_names = self.feature_names, + feature_library = self.dynamics_model.feature_map, + precision = precision + ) + if self.feature_names is None: + feature_names = [f"x{i}" for i in range(len(eqns))] + else: + feature_names = self.variable_names + + for name, eqn in zip(feature_names, eqns, strict=True): + lhs = f"({name}){"'"*self.ode_order}" + print(f"{lhs} = {eqn}", **kwargs) + + def predict(self,x,theta = None): + if theta is None: + theta = self.theta + return self.dynamics_model.predict(x,theta) + + def predict_state(self,t,z = None): + if z is None: + z = self.z + return self.traj_model.predict(t,z) + +def _setup_colloc(t,t_colloc,w_colloc): + if t_colloc is not None and w_colloc is None: + w_colloc = 1/len(t_colloc) * jnp.ones_like(t_colloc) + + if t_colloc is None: + t_colloc,w_colloc = get_collocation_points_weights(t) + return t_colloc,w_colloc diff --git a/pysindy/jsindy/trajectory_model.py b/pysindy/jsindy/trajectory_model.py new file mode 100644 index 000000000..ea2af7b41 --- /dev/null +++ b/pysindy/jsindy/trajectory_model.py @@ -0,0 +1,434 @@ +from abc import ABC +from typing import Any + +import jax +import jax.numpy as jnp +from jax.scipy.linalg import cholesky +from jax.scipy.linalg import solve_triangular +from jsindy.kernels import ConstantKernel +from jsindy.kernels import fit_kernel +from jsindy.kernels 
import fit_kernel_partialobs +from jsindy.kernels import Kernel +from jsindy.kernels import ScalarMaternKernel +from jsindy.kernels import softplus_inverse +from jsindy.kerneltools import diagpart +from jsindy.kerneltools import eval_k +from jsindy.kerneltools import get_kernel_block_ops +from jsindy.kerneltools import nth_derivative_operator_1d +from jsindy.util import l2reg_lstsq +from jsindy.util import row_block_diag + + +class TrajectoryModel(ABC): + system_dim: int + + def __call__(self, t, z): + pass + + def initalize_fit(self, t, x): + pass + + def derivative(self, t, z, diff_order=1): + pass + + +class RKHSInterpolant(TrajectoryModel): + """ + Args: + dimension: Dimension of the system + time_points: time points that we include from basis from canonical feature map + derivative_orders: Orders of derivatives that we wish to model and include in + the basis. + """ + + kernel: Kernel + # dimension: int + # time_points: jax.Array + # derivative_orders: tuple[int, ...] + # num_params: int + + def __init__( + self, + kernel=None, + derivative_orders: tuple[int, ...] 
= (0, 1), + nugget=1e-5, + ) -> None: + if kernel is None: + kernel = ConstantKernel(variance=5.0) + ScalarMaternKernel( + p=5, variance=10.0 + ) + self.kernel = kernel + self.is_attached = False + self.derivative_orders = derivative_orders + self.nugget = nugget + + def __str__(self): + return f""" + RKHS Trajectory Model + kernel: {self.kernel.__str__()} + derivative_orders: {self.derivative_orders} + nugget: {self.nugget} + """ + + def initialize( + self, + t, + x, + t_colloc, + params, + sigma2_est=None, + ): + params["sigma2_est"] = sigma2_est + self.attach(t_obs=t, x_obs=x, basis_time_points=t_colloc) + self.system_dim = x.shape[1] + self.num_basis = len(self.derivative_orders) * len(t_colloc) + self.tot_params = self.system_dim * self.num_basis + return params + + def initialize_partialobs( + self, + t, + y, + v, + t_colloc, + params, + sigma2_est=None, + ): + params["sigma2_est"] = sigma2_est + self.attach_partialobs(t_obs=t, y=y, v=v, basis_time_points=t_colloc) + self.system_dim = v.shape[1] + self.num_basis = len(self.derivative_orders) * len(t_colloc) + self.tot_params = self.system_dim * self.num_basis + return params + + def attach(self, t_obs, x_obs, basis_time_points): + self.dimension = x_obs.shape[1] + self.time_points = basis_time_points + self.basis_operators = tuple( + nth_derivative_operator_1d(n) for n in self.derivative_orders + ) + self.num_params = ( + len(self.derivative_orders) * len(basis_time_points) * self.dimension + ) + self.evaluation_kmat = get_kernel_block_ops( + self.kernel, (eval_k,), self.basis_operators, output_dim=self.dimension + ) + self.RKHS_mat = get_kernel_block_ops( + self.kernel, self.basis_operators, self.basis_operators, self.dimension + )(self.time_points, self.time_points) + self.RKHS_mat = self.RKHS_mat + self.nugget * diagpart(self.RKHS_mat) + self.regmat = self.RKHS_mat + self.is_attached = True + + def attach_partialobs(self, t_obs, y, v, basis_time_points): + self.dimension = v.shape[1] + self.time_points = 
basis_time_points + self.basis_operators = tuple( + nth_derivative_operator_1d(n) for n in self.derivative_orders + ) + self.num_params = ( + len(self.derivative_orders) * len(basis_time_points) * self.dimension + ) + self.evaluation_kmat = get_kernel_block_ops( + self.kernel, (eval_k,), self.basis_operators, output_dim=self.dimension + ) + self.RKHS_mat = get_kernel_block_ops( + self.kernel, self.basis_operators, self.basis_operators, self.dimension + )(self.time_points, self.time_points) + self.RKHS_mat = self.RKHS_mat + self.nugget * diagpart(self.RKHS_mat) + self.regmat = self.RKHS_mat + self.is_attached = True + + def _evaluate_operator(self, t, z, operator): + evaluation_matrix = get_kernel_block_ops( + k=self.kernel, + ops_left=(operator,), + ops_right=self.basis_operators, + output_dim=self.dimension, + )(t, self.time_points) + return evaluation_matrix @ z + + def __call__(self, t, z) -> Any: + return self.predict(t, z) + + def predict(self, t, z): + return self._evaluate_operator(t, z, eval_k).reshape(t.shape[0], self.dimension) + + def derivative(self, t, z, diff_order=1) -> Any: + return self._evaluate_operator( + t, z, nth_derivative_operator_1d(diff_order) + ).reshape(t.shape[0], self.dimension) + + def get_fitted_params(self, t, obs, lam=1e-4): + A = get_kernel_block_ops( + k=self.kernel, + ops_left=(eval_k,), + ops_right=self.basis_operators, + output_dim=self.dimension, + )(t, self.time_points) + + K = self.regmat + + M = A.T @ A + lam * K + M = M + 1e-7 * diagpart(M) + return jnp.linalg.solve(M, A.T @ obs.flatten()) + + def get_partialobs_fitted_params(self, t, y, v, lam=1e-4): + A = get_kernel_block_ops( + k=self.kernel, + ops_left=(eval_k,), + ops_right=self.basis_operators, + output_dim=self.dimension, + )(t, self.time_points) + V = row_block_diag(v) + A = V @ A + + K = self.regmat + + M = A.T @ A + lam * K + M = M + 1e-7 * diagpart(M) + return jnp.linalg.solve(M, A.T @ y) + + +class CholRKHSInterpolant(TrajectoryModel): + """ + Args:
dimension: Dimension of the system + time_points: time points that we include from basis from canonical feature map + derivative_orders: Orders of derivatives that we wish to model and include in + the basis. + """ + + kernel: Kernel + # dimension: int + # time_points: jax.Array + # derivative_orders: tuple[int, ...] + # num_params: int + + def __init__( + self, + kernel=None, + derivative_orders: tuple[int, ...] = (0, 1), + nugget=1e-8, + ) -> None: + if kernel is None: + kernel = ConstantKernel(variance=5.0) + ScalarMaternKernel( + p=5, variance=10.0 + ) + self.kernel = kernel + self.is_attached = False + self.derivative_orders = derivative_orders + self.nugget = nugget + + def __repr__(self): + return f""" + Cholesky Parametrized RKHS Trajectory Model + kernel: {self.kernel.__str__()} + derivative_orders: {self.derivative_orders} + nugget: {self.nugget} + """ + + def initialize( + self, + t, + x, + t_colloc, + params, + sigma2_est=None, + ): + params["sigma2_est"] = sigma2_est + self.attach(t_obs=t, x_obs=x, basis_time_points=t_colloc) + self.system_dim = x.shape[1] + self.num_basis = len(self.derivative_orders) * len(t_colloc) + self.tot_params = self.system_dim * self.num_basis + return params + + def initialize_partialobs( + self, + t, + y, + v, + t_colloc, + params, + sigma2_est=None, + ): + params["sigma2_est"] = sigma2_est + self.attach_partialobs(t_obs=t, y=y, v=v, basis_time_points=t_colloc) + self.system_dim = v.shape[1] + self.num_basis = len(self.derivative_orders) * len(t_colloc) + self.tot_params = self.system_dim * self.num_basis + return params + + def attach_partialobs(self, t_obs, y, v, basis_time_points): + self.dimension = v.shape[1] + self.time_points = basis_time_points + self.basis_operators = tuple( + nth_derivative_operator_1d(n) for n in self.derivative_orders + ) + self.num_params = ( + len(self.derivative_orders) * len(basis_time_points) * self.dimension + ) + self.evaluation_kmat = get_kernel_block_ops( + self.kernel, (eval_k,), 
self.basis_operators, output_dim=self.dimension + ) + self.RKHS_mat = get_kernel_block_ops( + self.kernel, self.basis_operators, self.basis_operators, self.dimension + )(self.time_points, self.time_points) + self.RKHS_mat = self.RKHS_mat + self.nugget * diagpart(self.RKHS_mat) + self.cholT = cholesky( + self.RKHS_mat + self.nugget * diagpart(self.RKHS_mat), lower=False + ) + self.regmat = jnp.eye(len(self.RKHS_mat)) + self.is_attached = True + + def attach(self, t_obs, x_obs, basis_time_points): + self.dimension = x_obs.shape[1] + self.time_points = basis_time_points + self.basis_operators = tuple( + nth_derivative_operator_1d(n) for n in self.derivative_orders + ) + self.num_params = ( + len(self.derivative_orders) * len(basis_time_points) * self.dimension + ) + self.evaluation_kmat = get_kernel_block_ops( + self.kernel, (eval_k,), self.basis_operators, output_dim=self.dimension + ) + self.RKHS_mat = get_kernel_block_ops( + self.kernel, self.basis_operators, self.basis_operators, self.dimension + )(self.time_points, self.time_points) + self.RKHS_mat = self.RKHS_mat + self.nugget * diagpart(self.RKHS_mat) + self.cholT = cholesky( + self.RKHS_mat + self.nugget * diagpart(self.RKHS_mat), lower=False + ) + self.regmat = jnp.eye(len(self.RKHS_mat)) + self.is_attached = True + + def _evaluate_operator(self, t, z, operator): + evaluation_matrix = get_kernel_block_ops( + k=self.kernel, + ops_left=(operator,), + ops_right=self.basis_operators, + output_dim=self.dimension, + )(t, self.time_points) + return evaluation_matrix @ solve_triangular(self.cholT, z) + + def predict(self, t, z): + return self._evaluate_operator(t, z, eval_k).reshape(t.shape[0], self.dimension) + + def __call__(self, t, z) -> Any: + return self.predict(t, z) + + def derivative(self, t, z, diff_order=1) -> Any: + return self._evaluate_operator( + t, z, nth_derivative_operator_1d(diff_order) + ).reshape(t.shape[0], self.dimension) + + def get_fitted_params(self, t, obs, lam=1e-4): + K = 
self.evaluation_kmat(t, self.time_points) + M = solve_triangular(self.cholT.T, K.T, lower=True).T + return l2reg_lstsq(M, obs.flatten(), reg=lam) + + def get_partialobs_fitted_params(self, t, y, v, lam=1e-4): + K = self.evaluation_kmat(t, self.time_points) + V = row_block_diag(v) + K = V @ K + M = solve_triangular(self.cholT.T, K.T, lower=True).T + return l2reg_lstsq(M, y, reg=lam) + + +class DataAdaptedRKHSInterpolant(RKHSInterpolant): + """ + Args: + dimension: Dimension of the system + time_points: time points that we include from basis from canonical feature map + derivative_orders: Orders of derivatives that we wish to model and include in + the basis. + """ + + def __repr__(self): + return f""" + MLE Adapted RKHS Trajectory Model + kernel: {self.kernel.__str__()} + derivative_orders: {self.derivative_orders} + nugget: {self.nugget} + """ + + def initialize(self, t, x, t_colloc, params): + fitted_kernel, sigma2_est, conv = fit_kernel( + init_kernel=self.kernel, + init_sigma2=jnp.var(x) / 20, + X=t, + y=x, + lbfgs_tol=1e-8, + show_progress=params["show_progress"], + ) + self.kernel = fitted_kernel + params["sigma2_est"] = sigma2_est + self.attach(t_obs=t, x_obs=x, basis_time_points=t_colloc) + self.system_dim = x.shape[1] + self.num_basis = len(self.derivative_orders) * len(t_colloc) + self.tot_params = self.system_dim * self.num_basis + return params + + def initialize_partialobs(self, t, y, v, t_colloc, params): + fitted_kernel, sigma2_est, conv = fit_kernel_partialobs( + init_kernel=self.kernel, + init_sigma2=jnp.var(y) / 20, + t=t, + y=y, + v=v, + lbfgs_tol=1e-8, + show_progress=params["show_progress"], + ) + self.kernel = fitted_kernel + params["sigma2_est"] = sigma2_est + self.attach_partialobs(t_obs=t, y=y, v=v, basis_time_points=t_colloc) + self.system_dim = v.shape[1] + self.num_basis = len(self.derivative_orders) * len(t_colloc) + self.tot_params = self.system_dim * self.num_basis + return params + + +class 
CholDataAdaptedRKHSInterpolant(CholRKHSInterpolant): + def __repr__(self): + return f""" + MLE Adapted Cholesky Parametrized RKHS Trajectory Model + kernel: {self.kernel.__str__()} + derivative_orders: {self.derivative_orders} + nugget: {self.nugget} + """ + + def initialize(self, t, x, t_colloc, params): + fitted_kernel, sigma2_est, conv = fit_kernel( + init_kernel=self.kernel, + init_sigma2=jnp.var(x) / 20, + X=t, + y=x, + lbfgs_tol=1e-8, + show_progress=False, # params["show_progress"] + ) + self.kernel = fitted_kernel + params["sigma2_est"] = sigma2_est + self.attach(t_obs=t, x_obs=x, basis_time_points=t_colloc) + self.system_dim = x.shape[1] + self.num_basis = len(self.derivative_orders) * len(t_colloc) + self.tot_params = self.system_dim * self.num_basis + return params + + def initialize_partialobs(self, t, y, v, t_colloc, params): + fitted_kernel, sigma2_est, conv = fit_kernel_partialobs( + init_kernel=self.kernel, + init_sigma2=jnp.var(y) / 20, + t=t, + y=y, + v=v, + lbfgs_tol=1e-8, + show_progress=False, # params["show_progress"] + ) + self.kernel = fitted_kernel + params["sigma2_est"] = sigma2_est + self.attach_partialobs(t_obs=t, y=y, v=v, basis_time_points=t_colloc) + self.system_dim = v.shape[1] + self.num_basis = len(self.derivative_orders) * len(t_colloc) + self.tot_params = self.system_dim * self.num_basis + return params diff --git a/pysindy/jsindy/util.py b/pysindy/jsindy/util.py new file mode 100644 index 000000000..795710eaf --- /dev/null +++ b/pysindy/jsindy/util.py @@ -0,0 +1,158 @@ +import jax +import jax.numpy as jnp + + +def validate_data_inputs(t, x, y, v): + if x is None: + assert y is not None + assert v is not None + assert len(t) == len(v) + assert len(t) == len(y) + if y is None: + assert x is not None + assert len(t) == len(x) + if x is not None: + assert y is None + assert v is None + + +def check_is_partial_data(t, x, y, v): + validate_data_inputs(t, x, y, v) + if v is None: + return False + else: + return True + + +def 
get_collocation_points_weights(t, num_colloc=500, bleedout_nodes=1.0): + min_t = jnp.min(t) + max_t = jnp.max(t) + span = max_t - min_t + lower = min_t - bleedout_nodes * span / num_colloc + upper = max_t + bleedout_nodes * span / num_colloc + col_points = jnp.linspace(lower, upper, num_colloc) + # Scale so that it's consistent to integral, rather than sum to 1. + col_weights = (upper - lower) / num_colloc * jnp.ones_like(col_points) + return col_points, col_weights + + +@jax.jit +def l2reg_lstsq(A, y, reg=1e-10): + U, sigma, Vt = jnp.linalg.svd(A, full_matrices=False) + if jnp.ndim(y) == 2: + return Vt.T @ ((sigma / (sigma**2 + reg))[:, None] * (U.T @ y)) + else: + return Vt.T @ ((sigma / (sigma**2 + reg)) * (U.T @ y)) + + +def tree_dot(tree, other): + # Multiply corresponding leaves and sum each product over all its elements. + vdots = jax.tree.map(lambda x, y: jnp.sum(x * y), tree, other) + return jax.tree.reduce(lambda x, y: x + y, vdots, initializer=0.0) + + +def tree_add(tree, other): + return jax.tree.map(lambda x, y: x + y, tree, other) + + +def tree_scale(tree, scalar): + return jax.tree.map(lambda x: scalar * x, tree) + + +def get_equations( + coef, feature_names, feature_library, precision: int = 3 +) -> list[str]: + """ + Get the right hand sides of the SINDy model equations. + + Parameters + ---------- + precision: int, optional (default 3) + Number of decimal points to include for each coefficient in the + equation. + + Returns + ------- + equations: list of strings + List of strings representing the SINDy model equations for each + input feature. 
+ """ + feat_names = feature_library.get_feature_names(feature_names) + + def term(c, name): + rounded_coef = jnp.round(c, precision) + if rounded_coef == 0: + return "" + else: + return f"{c:.{precision}f} {name}" + + equations = [] + for coef_row in coef: + components = [term(c, i) for c, i in zip(coef_row, feat_names)] + eq = " + ".join(filter(bool, components)) + if not eq: + eq = f"{0:.{precision}f}" + equations.append(eq) + + return equations + + +def full_data_initialize( + t, + x, + traj_model, + dynamics_model, + sigma2_est=0.1, + theta_reg=0.001, + input_orders=(0,), + ode_order=1, +): + t_grid = jnp.linspace(jnp.min(t), jnp.max(t), 500) + z = traj_model.get_fitted_params(t, x, lam=sigma2_est) + + X_inputs = jnp.hstack([traj_model.derivative(t_grid, z, k) for k in input_orders]) + Xdot_pred = traj_model.derivative(t_grid, z, ode_order) + + theta = dynamics_model.get_fitted_theta(X_inputs, Xdot_pred, lam=theta_reg) + return z, theta + + +def partial_obs_initialize( + t, + y, + v, + traj_model, + dynamics_model, + sigma2_est=0.1, + theta_reg=0.001, + input_orders=(0,), + ode_order=1, +): + t_grid = jnp.linspace(jnp.min(t), jnp.max(t), 500) + z = traj_model.get_partialobs_fitted_params(t, y, v, lam=sigma2_est) + + X_inputs = jnp.hstack([traj_model.derivative(t_grid, z, k) for k in input_orders]) + Xdot_pred = traj_model.derivative(t_grid, z, ode_order) + + theta = dynamics_model.get_fitted_theta(X_inputs, Xdot_pred, lam=theta_reg) + return z, theta + + +def legendre_nodes_weights(n, a, b): + from numpy.polynomial.legendre import leggauss + + nodes, weights = leggauss(n) + nodes = jnp.array(nodes) + weights = jnp.array(weights) + width = b - a + nodes = (width) / 2 * nodes + (a + b) / 2 + weights = (width / 2) * weights + return nodes, weights + + +def row_block_diag(V): + n, d = V.shape + eye = jnp.eye(n)[:, :, None] # Shape (n, n, 1) + V_exp = V[None, :, :] # Shape (1, n, d) + blocks = eye * V_exp # Shape (n, n, d) + return blocks.reshape(n, n * d) # Shape 
(n, n*d) diff --git a/pysindy/sssindy/__init__.py b/pysindy/sssindy/__init__.py new file mode 100644 index 000000000..4f903719f --- /dev/null +++ b/pysindy/sssindy/__init__.py @@ -0,0 +1,21 @@ +from ._typing import KernelFunc +from .expressions import JaxPolyLib +from .expressions import JointObjective +from .interpolants import InterpolantDifferentiation +from .interpolants import RKHSInterpolant +from .opt import L2CholeskyLMRegularizer +from .opt import LMSolver +from .opt import SINDyAlternatingLMReg +from .sssindy import SSSINDy + +__all__ = [ + "InterpolantDifferentiation", + "JaxPolyLib", + "KernelFunc", + "JointObjective", + "L2CholeskyLMRegularizer", + "LMSolver", + "RKHSInterpolant", + "SINDyAlternatingLMReg", + "SSSINDy", +] diff --git a/pysindy/sssindy/_skjax.py b/pysindy/sssindy/_skjax.py new file mode 100644 index 000000000..a7a540459 --- /dev/null +++ b/pysindy/sssindy/_skjax.py @@ -0,0 +1,68 @@ +from jax import tree_util + + +def register_scikit_pytree( + cls: type, + data_fields: list[str], + data_fit_fields: list[str], + meta_fields: list[str], + meta_fit_fields: list[str], +) -> type: + """Register sklearn.BaseEstimator-like classes as pytrees + + Args: + cls: class to decorate + data_fields: initialization attributes that are compilable jax types + (e.g. float, jax.Array, pytree) + data_fit_fields: data-dependent attributes that are set with a call to fit(). + These must also be compilable jax types. + meta_fields: initialization non-jax attributes, which must be hashable + in order to serve as a JIT compiler cache key + meta_fit_fields: data-dependent non-jax attributes. + + Adapted from https://github.com/jax-ml/jax/issues/25760 + """ + expected_fields = set(data_fields + meta_fields) + total_fields = expected_fields.union(set(data_fit_fields + meta_fit_fields)) + + def flatten_with_keys(obj): + try: + actual_fields = obj.__dict__.keys() + except AttributeError: + # All Python objects without __dict__ have __slots__. 
+ # __slots__ may be a str or iterable of strings: + # https://docs.python.org/3/reference/datamodel.html#slots + slots = obj.__slots__ + actual_fields = {slots} if isinstance(slots, str) else set(slots) + + if actual_fields != expected_fields and actual_fields != total_fields: + raise TypeError( + "unexpected attributes on object: " + f"got {sorted(actual_fields)}, expected {sorted(expected_fields)}" + f" or {sorted(total_fields)}" + ) + + children_with_keys = [ + (tree_util.GetAttrKey(k), getattr(obj, k)) for k in data_fields + ] + if data_fit_fields and hasattr(obj, data_fit_fields[0]): + children_with_keys += [ + (tree_util.GetAttrKey(k), getattr(obj, k)) for k in data_fit_fields + ] + aux_data = tuple((k, getattr(obj, k)) for k in meta_fields) + if meta_fit_fields and hasattr(obj, meta_fit_fields[0]): + aux_data = aux_data + tuple((k, getattr(obj, k)) for k in meta_fit_fields) + return children_with_keys, aux_data + + def unflatten_func(aux_data, children): + result = object.__new__(cls) + # zip will truncate to shortest, so if fit fields are not present, + # those keys are ignored. 
+ for k, v in zip(data_fields + data_fit_fields, children): + object.__setattr__(result, k, v) + for k, v in aux_data: + object.__setattr__(result, k, v) + return result + + tree_util.register_pytree_with_keys(cls, flatten_with_keys, unflatten_func) + return cls diff --git a/pysindy/sssindy/_typing.py b/pysindy/sssindy/_typing.py new file mode 100644 index 000000000..af15809eb --- /dev/null +++ b/pysindy/sssindy/_typing.py @@ -0,0 +1,15 @@ +from typing import Callable +from typing import TypeAlias +from typing import TypeVar + +import jax +import numpy as np +from numpy.typing import NBitBase + + +Float1D = np.ndarray[tuple[int], np.dtype[np.floating[NBitBase]]] +Float2D = np.ndarray[tuple[int, int], np.dtype[np.floating[NBitBase]]] +ArrayType = TypeVar("ArrayType", np.ndarray, jax.Array, covariant=True) +AnyArray = np.ndarray | jax.Array +KernelFunc: TypeAlias = Callable[[jax.Array, jax.Array], jax.Array] +TrajOrList = TypeVar("TrajOrList", list[jax.Array], jax.Array) diff --git a/pysindy/sssindy/expressions.py b/pysindy/sssindy/expressions.py new file mode 100644 index 000000000..c3aba4a4a --- /dev/null +++ b/pysindy/sssindy/expressions.py @@ -0,0 +1,313 @@ +from collections.abc import Sequence +from copy import copy +from dataclasses import dataclass +from functools import partial +from typing import Any +from typing import Callable +from typing import Optional + +import jax +import jax.numpy as jnp +from jax.scipy.linalg import block_diag +from jax.tree_util import register_dataclass +from sklearn.base import BaseEstimator +from sklearn.base import check_is_fitted +from sklearn.base import TransformerMixin +from typing_extensions import Self + +import pysindy as ps +from ._typing import Float1D +from ._typing import Float2D +from .interpolants.base import TrajectoryInterpolant +from pysindy.feature_library.base import x_sequence_or_item + + +@partial( + register_dataclass, + data_fields=["model_param_regmat", "state_param_regmat"], + meta_fields=[ + 
"data_residual_func", + "dynamics_residual_func", + "n_meas", + "full_n_process", + "full_n_theta", + "system_dim", + "num_features", + "traj_coef_slices", + ], +) +@dataclass +class ObjectiveResidual: + """ + Arguments returned when calling KernelObjective.transform(). + + Warning: This dataclass generates a hash based upon container identity, + assuming immutability. + + Args: + ----- + data_residual_func: residual function for only the data loss. + dynamics_residual_func: residual function for only the dynamics loss. + n_meas: Total number of measurements taken (across all trajectories). + full_n_process: total number of coefficients on kernel Ansatz for approximation + to trajectories. + full_n_theta: Total number of coefficients on dynamics, or SINDy, approximation + to true governing dynamics. + system_dim: Dimension of governing ode system. + num_features: Number of features in the feature library. + traj_coef_slices: list of slices to access the coefficients for each trajectory. + Each indexes an array of shape ``(full_n_process + full_n_theta,)``. + + Attributes: + resid_func (JitWrapper): univariate residual function for the full loss. + jac_func (JitWrapper): univariate jacobian of the residual function. + damping_matrix (jax.Array): Matrix reflecting the natural parameter metric + for inner products and norms. 
+ """ + + data_residual_func: Callable[[list[jax.Array]], Any] + dynamics_residual_func: Callable[[list[jax.Array], jax.Array], Any] + model_param_regmat: jax.Array + state_param_regmat: jax.Array + n_meas: int + full_n_process: int + full_n_theta: int + system_dim: int + num_features: int + traj_coef_slices: list[slice] + + def __post_init__(self): + self.resid_func = self.F_stacked + self.jac_func = jax.jacrev(self.F_stacked) + self.damping_matrix = block_diag( + self.state_param_regmat, self.model_param_regmat + ) + + def extract_state_params(self, stacked_flattened_params): + state_params = [ + stacked_flattened_params[traj_slice] for traj_slice in self.traj_coef_slices + ] + return state_params + + def extract_model_params(self, stacked_flattened_params): + theta_model = stacked_flattened_params[self.full_n_process :].reshape( + self.num_features, self.system_dim + ) + return theta_model + + def F_split(self, state_params, model_params): + return jnp.hstack( + [ + self.data_residual_func(state_params), + self.dynamics_residual_func(state_params, model_params), + ] + ) + + def F_stacked(self, stacked_flattened_params): + """ + This stacks the input variables (state_params + model_params) and returns + the stacked + """ + state_params = self.extract_state_params(stacked_flattened_params) + theta_model = self.extract_model_params(stacked_flattened_params) + return self.F_split(state_params=state_params, model_params=theta_model) + + def __hash__(self): + return super().__hash__() + + def __eq__(self, other): + return self is other + + +def _make_data_residual_f( + interpolant: TrajectoryInterpolant, + measurement_times: Float1D, + z_measurements: Float2D, +): + normalization = jnp.sqrt( + jnp.prod(jnp.array(z_measurements.shape)) + ) # Normalization factor + + def residual_func(state_params): + return ( + jnp.array(z_measurements) + - interpolant(measurement_times, params=state_params) + ).flatten() / normalization + + return residual_func + + +class 
JaxPolyLib(ps.PolynomialLibrary): + @x_sequence_or_item + def fit(self, x: list[jax.Array]): + super().fit(x) + + @x_sequence_or_item + def transform(self, x: list[jax.Array]): + xforms = [] + for dataset in x: + terms = [ + jnp.prod(dataset**exps, axis=1, keepdims=True) + for exps in self.powers_ + ] + xforms.append(jnp.concatenate(terms, axis=-1)) + return xforms + + def __call__(self, X): + return self.transform(X) + + +def _make_dynamics_residual_f( + interpolant: TrajectoryInterpolant, t_coloc: Float1D, feature_lib: JaxPolyLib +): + def residual_func(state_params, theta_model): + state_estimates = interpolant(t_coloc, state_params) + features = feature_lib(state_estimates) + + derivative_estimates = interpolant.derivative( + t_coloc, state_params, diff_order=1 + ) + dynamic_residuals = (features @ theta_model - derivative_estimates).flatten() + return dynamic_residuals / jnp.sqrt(len(dynamic_residuals)) + + return residual_func + + +@dataclass +class JointObjective(BaseEstimator, TransformerMixin): + """Single-step SINDy loss expression, specialized for a kernel basis + + TODO: Eventually we'll want to eliminate the discrepancy + between this and residual objective, or explicitly make this + a residual-factory + + An expression conceptualizes the combination of data and process misfit. + Upon instantiation, it concretizes all elements of the loss function except + the data and the optimization variables. 
+ + When fit, it also identifies the shape of the optimization variables + + Attributes: + + traj_coef_slices_: list[slice] + indexes to each trajectory's coefficients in the process variable + """ + + data_weight: float + dynamics_weight: float + lib: JaxPolyLib + interp_template: TrajectoryInterpolant + + def fit( + self, + z_meas: list[jax.Array], + t_meas: list[jax.Array], + *, + t_coloc: Optional[list[jax.Array]] = None, + ) -> Self: + """Determine the residual functions for data and colocation + + Arguments: + z_meas: Measurements of trajectories, axes following pysindy convention + t_meas: time of these measurements + t_coloc: time points for each trajectory to encourage fit of + SINDy coefficients + + """ + if t_coloc is None: + t_coloc = t_meas + self.t_meas_ = t_meas + self.t_coloc_ = t_coloc + self.lib.fit(z_meas) + self.n_meas_ = [t.shape[0] for t in t_meas] + self.n_coloc_ = [t.shape[0] for t in t_coloc] + self.system_dimension = z_meas[0].shape[-1] + self.n_features_ = self.lib.n_output_features_ + # Ike/Alex, we should see if we can keep variables in their native shapes as + # long as possible, rather than pre-flattening them here and in _make_*_residual + self.full_n_theta_ = self.n_features_ * self.system_dimension + + self.n_trajectories_ = len(z_meas) + self.traj_interps = [ + copy(self.interp_template).fit_time( + dimension=self.system_dimension, + time_points=t, + ) + for t in self.t_coloc_ + ] # Later generalize to higher orders + + traj_coef_start = 0 + self.traj_coef_slices_ = [] + for traj in self.traj_interps: + traj_coef_end = traj_coef_start + traj.num_params + self.traj_coef_slices_.append(slice(traj_coef_start, traj_coef_end)) + + self.full_n_process = sum(interp.num_params for interp in self.traj_interps) + self.data_resid_funcs_ = [ + _make_data_residual_f(interp, t, z) # type: ignore + for interp, z, t in zip(self.traj_interps, z_meas, t_meas, strict=True) + ] + self.dyna_resid_funcs_ = [ + _make_dynamics_residual_f(interp, t, 
self.lib) # type: ignore + for interp, t in zip(self.traj_interps, t_coloc, strict=True) + ] + return self + + @x_sequence_or_item + def transform(self, *args, **kwargs) -> ObjectiveResidual: + """Convert the data into a residual function. + + Arguments are retained only for sklearn compatibility + + Returns: + A tuple of residual function for the loss, a matrix defining the + RKHS norm (and by extension, the trajectory basis functions and number + of basis coefficients), and the number of total measurements, . + The residual function accepts stacked trajectory estimates and + sindy coefficient estimates; the matrix defining the RKHS also + defines the basis for the trajectory estimates + """ + check_is_fitted(self) + + def stacked_data_residual_fun(state_params: list[jax.Array]): + return jnp.hstack( + [ + jnp.sqrt(self.data_weight) * res(coef) + for coef, res in zip(state_params, self.data_resid_funcs_) + ] + ) + + def stacked_dyna_residual_fun( + state_params: list[jax.Array], theta_model: jax.Array + ): + return jnp.hstack( + [ + jnp.sqrt(self.dynamics_weight) * dyna_resid(coef, theta_model) + for coef, dyna_resid in zip(state_params, self.dyna_resid_funcs_) + ] + ) + + # This one should be instantiated by the feature library + model_param_regmat = jnp.eye(self.full_n_theta_) + + # State + state_param_regmat = block_diag(*[traj.gram_mat for traj in self.traj_interps]) + + residual_objective = ObjectiveResidual( + data_residual_func=stacked_data_residual_fun, + dynamics_residual_func=stacked_dyna_residual_fun, + model_param_regmat=model_param_regmat, + state_param_regmat=state_param_regmat, + n_meas=sum(self.n_meas_) * self.system_dimension, + full_n_process=self.full_n_process, + full_n_theta=self.full_n_theta_, + system_dim=self.system_dimension, + traj_coef_slices=self.traj_coef_slices_, + num_features=self.lib.n_output_features_, + ) + + # convert to ObjectiveResidual and update where appropriate + return residual_objective + + def get_feature_names(self, 
    @abstractmethod
    def fit_obs(self, t: jax.Array, x: jax.Array, noise_var: float) -> jax.Array:
        """Discover coefficients of internal model given observations data.

        Args:
            t: time points of the observations (paired with rows of ``x``)
            x: observation data, in shape (n_time, system_dimension)
            noise_var: the variance in the measurement noise

        Returns:
            The parameters used to evaluate the interpolant and its derivatives.
        """
        pass
Say all derivatives are zero""" + + def __init__(self): + pass + + def fit_time(self, dimension, time_points): + self.time_points = time_points + self.dimension = dimension + + self.num_params = self.dimension * len(self.time_points) + self.gram_mat = jnp.diag(jnp.ones((self.num_params))) + return self + + def __call__(self, t, params) -> Any: + return params.reshape(t.shape[0], self.dimension) + + def derivative(self, t, params, diff_order=1) -> Any: + return jnp.zeros((t.shape[0], self.dimension)) + + def fit_obs(self, t: jax.Array, x: jax.Array, noise_var: float) -> jax.Array: + if len(t) != len(x): + raise ValueError("I'm a mock interpolant, I don't do any interpolating") + return x + + def interpolate( + self, x: jax.Array, t: jax.Array, t_colloc: jax.Array, diff_order=0 + ) -> jax.Array: + if diff_order == 0: + return x + else: + return jnp.zeros_like(x) diff --git a/pysindy/sssindy/interpolants/compat.py b/pysindy/sssindy/interpolants/compat.py new file mode 100644 index 000000000..588a6cba7 --- /dev/null +++ b/pysindy/sssindy/interpolants/compat.py @@ -0,0 +1,26 @@ +import pysindy as ps +from .base import TrajectoryInterpolant + + +class InterpolantDifferentiation(ps.BaseDifferentiation): + """Use the new interpolation methods for differentiation in classic SINDy. + + Args: + interpolant: The interpolant to use for differentiation. + d: The order of the derivative to compute. 
def build_neg_marglike(X, y):
    """Build a negative log marginal likelihood loss for kernel hyperparameters.

    Args:
        X: input locations, one row per observation
        y: targets; either shape (n,) or (n, m) for m outputs sharing a kernel

    Returns:
        loss(params): ``params`` is a dict with keys "kernel" (a callable
        kernel) and "transformed_sigma2" (noise variance in the
        softplus-unconstrained domain).
    """
    if jnp.ndim(y) == 1:
        m = 1
    elif jnp.ndim(y) == 2:
        m = y.shape[1]
    else:
        raise ValueError("y must be either a 1 or two dimensional array")

    def neg_marginal_likelihood(kernel, sigma2):
        K = vectorize_kfunc(kernel)(X, X)
        identity = jnp.eye(len(X))

        # Cholesky of K + sigma2*I gives both the log-determinant (via the
        # diagonal) and the quadratic form y^T (K + sigma2 I)^{-1} y via a
        # single triangular solve.
        C = jax.scipy.linalg.cholesky(K + sigma2 * identity, lower=True)
        logdet = 2 * jnp.sum(jnp.log(jnp.diag(C)))
        yTKinvY = jnp.sum((jax.scipy.linalg.solve_triangular(C, y, lower=True)) ** 2)
        # Additive constants are dropped; they do not affect the argmin.
        return m * logdet + yTKinvY

    def loss(params):
        k = params["kernel"]
        sigma2 = softplus(params["transformed_sigma2"])
        return neg_marginal_likelihood(k, sigma2)

    return loss
def softplus_inverse(y: jnp.ndarray) -> jnp.ndarray:
    """Inverse of ``jax.nn.softplus``.

    Returns x such that softplus(x) == y, using the numerically stable
    identity log(exp(y) - 1) = y + log1p(-exp(-y)).
    """
    stable_log_term = jnp.log1p(-jnp.exp(-y))
    return y + stable_log_term
+ + All internal scaling returned to user parameter space.""" + pass + + def __add__(self, other: "Kernel"): + """ + Overload the '+' operator so we can do k1 + k2. + Internally, we return a SumKernel object containing both. + Also handles the case if `other` is already a SumKernel, in + which case we combine everything into one big sum. + """ + if isinstance(other, SumKernel): + # Combine self with an existing SumKernel's list + return SumKernel(*([self] + list(other.kernels))) + elif isinstance(other, Kernel): + return SumKernel(self, other) + else: + return NotImplemented + + def __prod__(self, other: "Kernel"): + """ + Overload the '*' operator so we can do k1 * k2. + Internally, we return a ProductKernel object containing both. + Also handles the case if `other` is already a ProductKernel, in + which case we combine everything into one big sum. + """ + if isinstance(other, ProductKernel): + return ProductKernel(*([self] + list(other.kernels))) + elif isinstance(other, Kernel): + return ProductKernel(self, other) + else: + return NotImplemented + + def transform(f): + """ + Creates a transformed kernel, returning a kernel function + k_transformed(x,y) = k(f(x),f(y)) + """ + + +class TransformedKernel(Kernel): + """ + Transformed kernel, representing the + composition of a kernel with another + fixed function + """ + + kernel: Kernel + transform: Callable = eqx.field(static=True) + + def __init__(self, kernel, transform): + self.kernel = kernel + self.transform = transform + + def __call__(self, x, y): + return self.kernel(self.transform(x), self.transform(y)) + + def pformat(self): + return ( + f"TransformedKernel(transform={self.transform.__name__}\n" + f"\tapplied to {self.kernel.pformat()})" + ) + + +class SumKernel(Kernel): + """ + Represents the sum of multiple kernels: + k_sum(x, y) = sum_{k in kernels} k(x, y) + """ + + kernels: tuple[Kernel, ...] 
class ProductKernel(Kernel):
    """
    Represents the product of multiple kernels:
        k_prod(x, y) = prod_{k in kernels} k(x, y)
    """

    kernels: tuple[Kernel, ...]

    def __init__(self, *kernels: Kernel):
        self.kernels = kernels

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        return jnp.prod(jnp.array([k(x, y) for k in self.kernels]))

    def __prod__(self, other: "Kernel"):
        """
        If we do (k1*k2)*k3, the left side is a ProductKernel, so
        we merge everything into one ProductKernel.
        """
        # Only absorb another *product's* factor list.  Matching SumKernel
        # here (as before) would multiply a sum's summands as if they were
        # factors, which is mathematically wrong.
        if isinstance(other, ProductKernel):
            return ProductKernel(*(list(self.kernels) + list(other.kernels)))
        elif isinstance(other, Kernel):
            return ProductKernel(*(list(self.kernels) + [other]))
        else:
            return NotImplemented

    # Python's `*` operator dispatches through __mul__, not __prod__; alias
    # it so `k1 * k2` actually works while keeping __prod__ for callers.
    __mul__ = __prod__

    def pformat(self):
        return f"Product of ({[kernel.pformat() for kernel in self.kernels]})"
    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Evaluate k(x, y) = variance * matern_core((y - x) / lengthscale)."""
        # Map unconstrained parameters back to the positive domain; the
        # lengthscale floor keeps the kernel away from degenerate widths.
        var = softplus(self.raw_variance)
        ls = softplus(self.raw_lengthscale) + self.min_lengthscale
        scaled_diff = (y - x) / ls
        return var * self.core_matern(scaled_diff)
+ """ + + raw_variance: jax.Array + raw_lengthscale: jax.Array + min_lengthscale: jax.Array = eqx.field(static=True) + + def __init__(self, lengthscale=1.0, variance=1.0, min_lengthscale=0.01): + # Convert user-supplied positive parameters to unconstrained domain + if lengthscale < min_lengthscale: + raise ValueError("Initial lengthscale below minimum") + self.raw_variance = softplus_inverse(jnp.array(variance)) + self.raw_lengthscale = softplus_inverse(jnp.array(lengthscale)) + self.min_lengthscale = min_lengthscale + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + var = softplus(self.raw_variance) + ls = softplus(self.raw_lengthscale) + self.min_lengthscale + sqdist = jnp.sum((x - y) ** 2) + return var * jnp.exp(-0.5 * sqdist / (ls**2)) + + def pformat(self): + return ( + f"GaussianRBF kernel: variance={softplus(self.raw_variance)}, " + f"lengthscale={softplus(self.raw_lengthscale) + self.min_lengthscale}" + ) + + +class RationalQuadraticKernel(Kernel): + """ + Rational Quadratic kernel: + k(x, y) = variance * [1 + (||x - y||^2 / (2 * alpha * lengthscale^2))]^(-alpha) + + Parameters: + variance > 0 + lengthscale > 0 + alpha > 0 + Internally stored as "raw_" after applying softplus_inverse. 
+ """ + + raw_variance: jax.Array + raw_lengthscale: jax.Array + raw_alpha: jax.Array + min_lengthscale: jax.Array = eqx.field(static=True) + + def __init__(self, lengthscale=1.0, alpha=1.0, variance=1.0, min_lengthscale=0.01): + self.raw_variance = softplus_inverse(jnp.array(variance)) + self.raw_lengthscale = softplus_inverse(jnp.array(lengthscale)) + self.raw_alpha = softplus_inverse(jnp.array(alpha)) + self.min_lengthscale = min_lengthscale + + def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + var = softplus(self.raw_variance) + ls = softplus(self.raw_lengthscale) + self.min_lengthscale + a = softplus(self.raw_alpha) + + sqdist = jnp.sum((x - y) ** 2) + factor = 1.0 + (sqdist / (2.0 * a * ls**2)) + return var * jnp.power(factor, -a) + + def pformat(self): + return ( + f"RationalQuadratic kernel: variance={softplus(self.raw_variance)}, " + f"lengthscale={softplus(self.raw_lengthscale) + self.min_lengthscale}, " + f"alpha={softplus(self.raw_alpha)}" + ) + + +class SpectralMixtureKernel(Kernel): + r""" + Spectral Mixture kernel for scalar inputs: + + .. math:: + k(\tau) = \sum_{m=1}^M w_m * \exp(-2 * (\pi*\sigma_m)^2 + * (x-y)^2) * \cos(2 \pi (x-y) * \text{periods_m}) + + where :math:`\tau = x - y`. + + Internally stored as "raw_" after applying softplus_inverse. 
class ConstantKernel(Kernel):
    """
    Constant kernel k(x, y) = c for all x, y.

    Params:
        c, variance of the constant shift
    Internally stored as "raw_" after applying softplus_inverse.
    """

    raw_constant: jnp.ndarray

    def __init__(self, variance: float = 1.0):
        """
        :param variance: A positive float specifying the kernel's constant value.
        """
        if variance <= 0:
            raise ValueError("ConstantKernel requires a strictly positive constant.")
        # Positivity is enforced by storing the softplus-inverse and mapping
        # back through softplus at evaluation time.
        self.raw_constant = softplus_inverse(jnp.array(variance))

    def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        # The inputs are ignored: the kernel is constant everywhere.
        return softplus(self.raw_constant)

    def pformat(self):
        constant = softplus(self.raw_constant)
        return f"Constant kernel: constant={constant}, "
def get_kernel_block_ops(
    k, ops_left, ops_right, output_dim=1, type_pkg: ModuleType = jnp
):
    """Build a block kernel function from operator pairs applied to ``k``.

    Args:
        k: scalar kernel k(s, t)
        ops_left: operators applied to k's first argument, one per block row
        ops_right: operators applied to k's second argument, one per block column
        output_dim: dimension of the vector-valued output; each scalar block
            is Kronecker-expanded with an identity of this size
        type_pkg: array module providing eye/kron/block (``jnp`` by default)

    Returns:
        k_super(x, y): evaluates the full block kernel matrix.
    """

    def k_super(x, y):
        """
        Returns:
            Kernel matrix with the layout:
            (n_ops_left, output_dim, n_ops_right, output_dim),
            reshaped to (n_ops_left * output_dim, n_ops_right * output_dim)
        """
        I_mat = type_pkg.eye(output_dim)
        blocks = [
            [
                type_pkg.kron(make_block(k, L_op, R_op)(x, y), I_mat)
                for R_op in ops_right
            ]
            for L_op in ops_left
        ]
        return type_pkg.block(blocks)

    return k_super
def nth_derivative_1d(k: Callable, index: int, n: int) -> Callable:
    """Return the order-``n`` derivative of ``k`` in its ``index``-th argument.

    Applies ``jax.grad`` ``n`` times; ``n == 0`` returns ``k`` unchanged.
    """
    fn = k
    for _ in range(n):
        fn = grad(fn, argnums=index)
    return fn
def build_matern_core(p):
    """Build a JAX-callable Matern core (order p + 1/2) with symbolic derivatives.

    Naive autodiff through the |d| factors of a half-integer Matern kernel
    produces DiracDelta/sign artifacts at d = 0.  This builds the derivative
    chain symbolically in SymPy, cleans up those artifacts, and wires the
    results in as custom JVPs so the returned callable differentiates
    correctly through repeated ``jax.grad``.
    """
    d2, matern = get_sympy_matern(p)
    d = sym.var("d", pos=True, real=True)

    # Rewrite the kernel in terms of the (positive, real) distance d.
    maternd = sym.powdenest(matern.subs(d2, d**2))
    # Rules removing distributional artifacts (DiracDelta, sign) that arise
    # when differentiating expressions involving Abs(d).
    subrule = {
        d * sym.DiracDelta(d): 0,
        sym.Abs(d) * sym.DiracDelta(d): 0,
        sym.Abs(d) * sym.sign(d): d,
        d * sym.sign(d): sym.Abs(d),
    }

    def compute_next_derivative(expr):
        # One more d/dd, applying the cleanup rules around expansion.
        return sym.powdenest(sym.expand(expr.diff(d).subs(subrule))).subs(subrule)

    # Symbolic derivatives of orders 1 .. 2*p.
    derivatives = [compute_next_derivative(maternd)]
    for k in range(2 * p - 1):
        derivatives.append(compute_next_derivative(derivatives[-1]))

    jax_derivatives = [make_sympy_callable(f) for f in derivatives]

    # Chain the derivatives: each order's JVP is supplied by the next order.
    wrapped_derivatives = [
        make_custom_jvp_function(f, fprime)
        for f, fprime in zip(jax_derivatives[:-1], jax_derivatives[1:])
    ]

    matern_func_raw = sympy2jax.SymbolicModule(maternd)
    core_matern = custom_jvp(lambda d: matern_func_raw(d=d))

    @core_matern.defjvp
    def core_matern_jvp(primals, tangents):
        (x,) = primals
        (x_dot,) = tangents
        ans = core_matern(x)
        # First derivative (itself carrying a custom JVP) drives the tangent.
        ans_dot = wrapped_derivatives[0](x) * x_dot
        return ans, ans_dot

    return core_matern
+class RKHSInterpolant(LSQInterpolant): + """ + RKHS function in from R1 to Rd, modeling a d-dimensional trajectory as a function + of time. Uses a fixed basis, requires the time points that objective depends on + upon instantiation to build basis based on representer theorem. + """ + + kernel: KernelFunc | Kernel + derivative_orders: tuple[int, ...] + nugget: float + + def __init__( + self, + kernel: KernelFunc | Kernel, + derivative_orders: tuple[int, ...] = (0, 1), + nugget=1e-5, + ) -> None: + """ + dimension: Dimension of the system + time_points: time points that we include from basis from canonical feature map + derivative_orders: Orders of derivatives that we wish to model and include in + the basis. + """ + self.kernel = kernel + self.derivative_orders = derivative_orders + self.nugget = nugget + self.basis_operators = tuple( + nth_derivative_operator_1d(n) for n in self.derivative_orders + ) + + def fit_time(self, dimension: int, time_points: jax.Array) -> Self: + self.dimension = dimension + self.time_points = time_points + self.num_params = len(self.derivative_orders) * len(time_points) * dimension + + self.evaluation_kmat = get_kernel_block_ops( + self.kernel, (eval_k,), self.basis_operators, output_dim=self.dimension + ) + RKHS_mat = get_kernel_block_ops( + self.kernel, self.basis_operators, self.basis_operators, self.dimension + )(self.time_points, self.time_points) + self.gram_mat = RKHS_mat + self.nugget * diagpart(RKHS_mat) + + self.cholT = cholesky(self.gram_mat, lower=False) + + return self + + def fit_obs(self, t: jax.Array, obs: jax.Array, noise_var: float) -> jax.Array: + """Only works for fitting observations of the system, not derivatives.""" + if not hasattr(self, "gram_mat"): + raise ValueError( + "You must call fit_time before calling fit_obs. " + "fit_obs requires the gram matrix to be set up first." 
+ ) + K_obs = get_kernel_block_ops( + k=self.kernel, + ops_left=(eval_k,), + ops_right=self.basis_operators, + output_dim=self.dimension, + )(t, self.time_points) + if noise_var == 0.0 and ( + jnp.any(t != self.time_points) + or len(self.basis_operators) != 1 + or self.basis_operators[0].func != nth_derivative_operator_1d(0).func # type: ignore # noqa: E501 + or self.basis_operators[0].args != nth_derivative_operator_1d(0).args # type: ignore # noqa: E501 + ): + raise ValueError( + "Cannot exactly interpolate unless if observation times" + "match basis times and no derivative operators are present." + ) + M = solve_triangular(self.cholT.T, K_obs.T, lower=True).T + params_chol_basis = l2reg_lstsq(M, obs.flatten(), reg=noise_var) + return solve_triangular(self.cholT, params_chol_basis, lower=False) + + def interpolate( + self, + x: jax.Array, + t: jax.Array, + t_colloc: jax.Array, + diff_order: int = 0, + noise_var: float = 0, + ) -> jax.Array: + """Fit a copy of this interpolant to observations x at time t + + This does not modify the original interpolant. 
    def _evaluate_operator(self, t, params, operator):
        """Apply ``operator`` to the interpolant at times ``t``.

        Builds the cross kernel matrix between (operator at ``t``) and the
        basis operators at the fitted time points, then contracts it with
        the basis coefficients ``params``.
        """
        evaluation_matrix = get_kernel_block_ops(
            k=self.kernel,
            ops_left=(operator,),
            ops_right=self.basis_operators,
            output_dim=self.dimension,
        )(t, self.time_points)
        return evaluation_matrix @ params
def build_armijo_linesearch(f, decrease_ratio=0.5, slope=0.05, max_iter=25):
    """Build a jittable backtracking (Armijo) line search for loss ``f``.

    Args:
        f: scalar loss taking a pytree of parameters
        decrease_ratio: factor shrinking the step size at each backtrack
        slope: Armijo constant; a step is accepted once the actual decrease
            is at least ``slope`` times the linear-model predicted decrease
        max_iter: maximum number of backtracking iterations

    Returns:
        armijo_linesearch(x, f_curr, d, g, t0) -> (candidate, step,
        achieved/predicted decrease ratio); steps are taken as x - t*d.
    """

    def armijo_linesearch(x, f_curr, d, g, t0=0.1):
        """
        x: current parameters (pytree)
        f_curr: f(x)
        d: descent direction (pytree)
        g: gradient at x (pytree)
        t0: initial step size
        a: Armijo constant
        """
        candidate = tree_add(x, tree_scale(d, -t0))
        dec0 = f(candidate) - f_curr
        pred_dec0 = -t0 * tree_dot(d, g)

        # The loop state: (iteration, t, current decrease, predicted decrease)
        init_state = (0, t0, dec0, pred_dec0)

        def cond_fun(state):
            i, t, dec, pred_dec = state
            # Continue while we haven't satisfied the Armijo condition and haven't
            # exceeded max_iter iterations.  Both decreases are negative for a
            # descent step, so dec >= slope * pred_dec means "not enough drop".
            not_enough_decrease = dec >= slope * pred_dec
            return jnp.logical_and(i < max_iter, not_enough_decrease)

        def body_fun(state):
            i, t, dec, pred_dec = state
            t_new = decrease_ratio * t
            candidate_new = tree_add(x, tree_scale(d, -t_new))
            dec_new = f(candidate_new) - f_curr
            pred_dec_new = -t_new * tree_dot(d, g)
            return (i + 1, t_new, dec_new, pred_dec_new)

        # Run the while loop
        i_final, t_final, dec_final, pred_dec_final = jax.lax.while_loop(
            cond_fun, body_fun, init_state
        )
        armijo_rat_final = dec_final / pred_dec_final
        candidate_final = tree_add(x, tree_scale(d, -t_final))
        return candidate_final, t_final, armijo_rat_final

    return armijo_linesearch
def run_jaxopt_solver(solver, x0):
    """Run a jaxopt iterative solver with progress tracking and restarts.

    Args:
        solver: a jaxopt solver exposing init_state/update/maxiter/tol/verbose
        x0: initial parameters (pytree)

    Returns:
        Tuple of (solution, convergence_data, final_state), where
        convergence_data maps "values"/"gradnorms"/"stepsizes" to
        per-iteration arrays.
    """
    state = solver.init_state(x0)
    sol = x0
    values, errors, stepsizes = [state.value], [state.error], [state.stepsize]
    num_restarts = 0

    @jax.jit
    def update(sol, state):
        return solver.update(sol, state)

    for iter_num in tqdm(range(solver.maxiter)):
        sol, state = update(sol, state)
        values.append(state.value)
        errors.append(state.error)
        stepsizes.append(state.stepsize)
        if solver.verbose > 0:
            # Use the module logger (lazy %-args) instead of bare print so
            # output integrates with the host application's logging config.
            logger.info("Gradient Norm: %s", state.error)
            logger.info("Loss Value: %s", state.value)
        if state.error <= solver.tol:
            break
        # A zero step size signals a stalled line search; re-initialize the
        # solver state at the current iterate, giving up after 10 restarts.
        if stepsizes[-1] == 0:
            num_restarts = num_restarts + 1
            logger.info("Restart %d", num_restarts)
            if num_restarts > 10:
                break
            state = solver.init_state(sol)
    convergence_data = {
        "values": jnp.array(values),
        "gradnorms": jnp.array(errors),
        "stepsizes": jnp.array(stepsizes),
    }
    return sol, convergence_data, state
@partial(
    register_dataclass,
    data_fields=[
        "params",
        "loss",
        "gradnorm",
        "improvement_ratio",
        "step_damping",
        "regularization_loss",
        "lin_sys_rel_resid",
    ],
    meta_fields=[],
)
@dataclass
class OptimizerStepResult:
    """Contains information about a single optimization step.

    Registered as a JAX pytree (all fields are data leaves) so instances
    can flow through jitted code.
    """

    # Parameter values after this step.
    params: jax.Array
    # Loss value at ``params``.
    loss: float
    # Norm of the gradient at ``params``.
    gradnorm: float
    # Achieved-vs-predicted improvement ratio for the accepted step.
    improvement_ratio: float
    # Step damping weight (printed as "alpha" by print_progress).
    step_damping: float
    # Contribution of regularization terms to the loss.
    regularization_loss: float
    # Relative residual of the step's linear-system solve (per the field
    # name; the producing solver code is not shown here — confirm there).
    lin_sys_rel_resid: float
{step_result.loss: .4}, " + f" gradnorm = {step_result.gradnorm: .4}, " + f"alpha = {step_result.step_damping: .4}, " + f" improvement_ratio = {step_result.improvement_ratio: .4}" + ) + + +@dataclass +class LMSettings: + """Parameters controlling the behavior of LevenbergMarquardt optimization. + + max_iter: Max number of inner iterations. Must be positive, by default 501 + atol_gradnorm: Gradient norm stopping condition absolute tolerance + atol_gn_decrement: Gauss-Newton decrement stopping condition absolute tolerance + min_improvement: Minimum improvement ratio to accept in backtracking line search, + as used in the Armijo condition or Armijo-Goldstein condition. + By default 0.05. + search_increase_ratio: constant to increase reg strength by in backtracking + (proximal) search, by default 1.5. + max_search_iterations: maximum number of backtracking search iterations, by + default 20 + min_step_damping: minimum weight for penalty term for deviating from + previous iterates. Must be nonnegative, by default 1e-9 + max_step_damping : maximum weight for penalty term for deviating from + previous iterates. Must be nonnegative, by default 50. + init_step_damping: starting weight for penalty term for deviating + from previous iterates. Must be between ``min_step_damping`` and + ``max_step_damping`` by default 3. + step_adapt_multipler: value to use for adapting damping, by default 1.2 + callbacks: Functions that the iteration number and parameters as + argument and are called every iteration (by default, printing). + callback_every: How often to use callbacks, by default every 100 iterations. + track_iterates: Whether to save the optimizer value at each iteration + in the ConvergenceHistory. Increases memory usage + use_jit: Whether to jit functions inside optimization. 
+ """ + + max_iter: int = 501 + atol_gradnorm: float = 1e-6 + atol_gn_decrement: float = 1e-12 + min_improvement: float = 0.05 + search_increase_ratio: float = 2.0 + max_search_iterations: int = 20 + min_step_damping: float = 1e-12 + max_step_damping: float = 100.0 + init_damp: float = 3.0 + step_adapt_multiplier: float = 1.2 + callbacks: tuple[Callable[[int, OptimizerStepResult], object], ...] = field( + default=(print_progress,) + ) + callback_every: int = 100 + track_iterates: bool = False + use_jit: bool = True + + +@dataclass +class STLSQLMSettings(LMSettings): + """ + stlsq_max_iter: Number of iterations for sequentially thresholded Ridge regression. + prox_reg: proximal regularizer strength for trajectory and dynamic params. + ridge_reg: regularizer strength for dynamics ridge regression. + threshold: truncation cutoff for dynamic params. + """ + + stlsq_max_iter: int = 50 + prox_reg: float = 1.0 + ridge_reg: float = 1e-2 + threshold: float = 0.5 + + +@dataclass +class ConvergenceHistory: + track_iterates: bool = False + loss_vals: list[float] = field(default_factory=list) + gradnorm: list[float] = field(default_factory=list) + iterate_history: list = field(default_factory=list) + improvement_ratios: list[float] = field(default_factory=list) + damping_vals: list[float] = field(default_factory=list) + cumulative_time: list = field(default_factory=list) + linear_system_rel_residual: list = field(default_factory=list) + regularization_loss_contribution: list = field(default_factory=list) + convergence_tag: str = "not-yet-run" + + def update( + self, + step_results: OptimizerStepResult, + cumulative_time: float, + ): + # Append the new values to the corresponding lists + self.loss_vals.append(step_results.loss) + self.gradnorm.append(step_results.gradnorm) + self.improvement_ratios.append(step_results.improvement_ratio) + self.damping_vals.append(step_results.step_damping) + self.cumulative_time.append(cumulative_time) + 
self.linear_system_rel_residual.append(step_results.lin_sys_rel_resid) + self.regularization_loss_contribution.append(step_results.regularization_loss) + + # Conditionally track iterates if enabled + if self.track_iterates: + self.iterate_history.append(step_results.params) + + def finish(self, convergence_tag="finished"): + self.convergence_tag = convergence_tag + + +@partial( + register_dataclass, + meta_fields=[], + data_fields=["J", "residuals", "loss", "loss_hess_appx", "loss_const_grad"], +) +@dataclass +class ObjEvaluation: + """The evaluation of an optimization problem at a particular point""" + + J: jax.Array + residuals: jax.Array + loss: float + loss_hess_appx: jax.Array + loss_const_grad: jax.Array + + def add_regularization(self, other: Self) -> Self: + """Add evaluation of a regularizer (which does not affect residuals)""" + return self.__class__( + self.J, + self.residuals, + self.loss + other.loss, + self.loss_hess_appx + other.loss_hess_appx, + self.loss_const_grad + other.loss_const_grad, + ) + + +def _evaluate_objective(params: jax.Array, problem: ObjectiveResidual) -> ObjEvaluation: + r"""Create a linear least-squares problem from a nonlinear objective. + + It evaluates quantities in the problem: + + .. math:: + \min_u 1/2\cdot\|F(x+u)\|^2 + + where :math:`F` is the residual function + + Args: + params: the current value of the optimization variable, :math:`x` + problem: The nonlinear objective + + Note: + The Jacobian is the most expensive part of this computation. If merely + the residuals or loss are required, it makes sense to call those + independently. + + Returns: + An ObjEvaluation with the Jacobian matrix :math:`df(x) / dx`, residuals + :math:`F(x)`, the loss value :math: `1/2\cdot\|F(x)\|^2`, the approximate hessian + of the LSQ loss with respect to the parameters, :math:`J^TJ`, and the gradient + of the total loss (incl regularizer) with respect to a parameter step (at origin). 
+ """ + J = problem.jac_func(params) + residuals = problem.resid_func(params) + loss = cast(float, 0.5 * jnp.sum(residuals**2)) + JtJ = J.T @ J + rhs = J.T @ residuals + return ObjEvaluation(J, residuals, loss, JtJ, rhs) + + +@partial( + register_scikit_pytree, + data_fields=[], + data_fit_fields=["prob", "mat_weight"], + meta_fields=[], + meta_fit_fields=[], +) +@dataclass +class _LMRegularizer(ABC): + """A global regularizer for Levenberg-Marquardt optimization. + + Note that child classes need to be re-registered with jax to be jittable + + Args: + prox_reg: The scalar regularization strength + + Attributes: + mat_weight: The weight matrix for the regularizer. E.g. Gram matrix, + elliptic norm matrix. Must be positive definite. + """ + + def fit(self, problem: ObjectiveResidual) -> Self: + """Assign the parts of the regularizer that do not change during iteration.""" + self.mat_weight = problem.damping_matrix + self.prob = problem + return self + + @abstractmethod + def eval(self, params: jax.Array) -> ObjEvaluation: + """Evaluate the regularizer at a particular point + + Part of ``ObjEvaluation`` is the Jacobian and residual. A regularizer + should supply float zero for the Jacobian and residual components. + """ + ... + + @abstractmethod + def step( + self, params: jax.Array, curr_vals: ObjEvaluation + ) -> tuple[jax.Array, jax.Array, float]: + """Calculate the a step of the optimization problem using this regularizer + + Args: + params: The current value of the optimization variable + curr_vals: The evaluation of the LSQ part of the optimization problem + at ``params`` + Returns: + tuple of the negative parameter step, the estimated negative residual step, + and the linear system residual + """ + ... 
@partial(
    register_scikit_pytree,
    data_fields=["prox_reg"],
    data_fit_fields=["prob", "mat_weight"],
    meta_fields=[],
    meta_fit_fields=[],
)
@dataclass
class L2CholeskyLMRegularizer(_LMRegularizer):
    """Quadratic (elliptic-norm) regularizer solved via Cholesky factorization."""

    prox_reg: float

    @jax.jit
    def eval(self, params: jax.Array) -> ObjEvaluation:
        """Evaluate 0.5 * prox_reg * params^T W params and its derivatives."""
        loss = 0.5 * cast(float, self.prox_reg * params.T @ self.mat_weight @ params)
        loss_hess = self.prox_reg * self.mat_weight
        loss_grad = loss_hess @ params
        # Jacobian/residual slots are zero: a regularizer adds no residuals.
        return ObjEvaluation(0, 0, loss, loss_hess, loss_grad)  # type: ignore

    @jax.jit
    def step(
        self, params: jax.Array, curr_vals: ObjEvaluation
    ) -> tuple[jax.Array, jax.Array, float]:
        """Solve the damped normal equations for the (negative) step."""
        Mchol = cho_factor(curr_vals.loss_hess_appx)
        step = cho_solve(Mchol, curr_vals.loss_const_grad)
        resid_step = curr_vals.J @ step

        # Relative residual of the linear solve, as a sanity diagnostic.
        linear_residual = (
            curr_vals.J.T @ (resid_step - curr_vals.residuals)
            + curr_vals.loss_hess_appx @ step
            - curr_vals.loss_const_grad
        )
        linear_residual = jnp.linalg.norm(linear_residual) / jnp.linalg.norm(
            curr_vals.loss_const_grad
        )

        return step, resid_step, linear_residual


@partial(
    register_scikit_pytree,
    data_fields=["prox_reg"],
    data_fit_fields=["prob", "mat_weight"],
    meta_fields=["theta_optimizer"],
    meta_fit_fields=["n_proc", "n_theta"],
)
@dataclass
class SINDyAlternatingLMReg(L2CholeskyLMRegularizer):
    """Alternating step: Cholesky solve for process terms, then a SINDy
    optimizer (e.g. STLSQ) for the dynamics coefficients."""

    theta_optimizer: BaseOptimizer

    def fit(self, problem: ObjectiveResidual) -> Self:
        """Build a block-diagonal weight matrix regularizing only process terms."""
        self.prob = problem
        self.n_proc = self.prob.full_n_process
        self.n_theta = self.prob.full_n_theta
        proc_block = problem.damping_matrix[: self.n_proc, : self.n_proc]
        offdiag_block = jnp.zeros((self.n_proc, self.n_theta))
        theta_block = jnp.zeros((self.n_theta, self.n_theta))
        self.mat_weight = jnp.block(
            [[proc_block, offdiag_block], [offdiag_block.T, theta_block]]
        )
        return self

    def step(
        self, params: jax.Array, curr_vals: ObjEvaluation
    ) -> tuple[jax.Array, jax.Array, float]:
        sys_dim = self.prob.system_dim
        n_meas = self.prob.n_meas

        # jittable?  The slicing below uses lax primitives so this inner helper
        # could in principle be jitted; left as plain Python for now.
        def _inner_wrap(self, params, curr_vals):
            n_resid, var_len = curr_vals.J.shape
            proc = lax.dynamic_slice(params, (0,), (self.n_proc,))
            theta = lax.dynamic_slice(params, (self.n_proc,), (self.n_theta,))

            loss_hess_appx = lax.dynamic_slice(
                curr_vals.loss_hess_appx, (0, 0), (self.n_proc, self.n_proc)
            )
            Mchol = cho_factor(loss_hess_appx)
            loss_const_grad = lax.dynamic_slice(
                curr_vals.loss_const_grad, (0,), (self.n_proc,)
            )
            proc_step = cho_solve(Mchol, loss_const_grad)
            # ``proc`` is already the same slice of ``params``; reuse it
            # instead of re-slicing.
            new_proc = proc - proc_step

            J_theta = lax.slice(curr_vals.J, (0, self.n_proc), (n_resid, var_len))
            J_proc = lax.dynamic_slice(curr_vals.J, (0, 0), (n_resid, self.n_proc))
            # NOTE(review): hard-coded stride (3, 3) — looks like it should be
            # (sys_dim, sys_dim); confirm before changing.
            feat_evals_appx = lax.slice(
                J_theta, (n_meas, 0), (n_resid, self.n_theta), (3, 3)
            )
            derivative_appx = lax.slice(
                J_theta @ theta - J_proc @ (new_proc - proc) - curr_vals.residuals,
                (n_meas,),
                (n_resid,),
            ).reshape((-1, sys_dim))
            return theta, proc_step, feat_evals_appx, derivative_appx

        theta, proc_step, feat_evals_appx, derivative_appx = _inner_wrap(
            self, params, curr_vals
        )
        if jnp.isnan(proc_step).any():
            # Propagate NaNs so the caller's line search can reject the step.
            return (
                jnp.nan * jnp.ones_like(params),
                jnp.nan * jnp.ones_like(curr_vals.residuals),
                jnp.nan,
            )
        feat_evals_appx = np.array(feat_evals_appx)
        derivative_appx = np.array(derivative_appx)
        theta_new = self.theta_optimizer.fit(
            feat_evals_appx, derivative_appx
        ).coef_.T.flatten()
        theta_step = theta - theta_new

        step = jnp.hstack((proc_step, theta_step))
        resid_step = curr_vals.J @ step

        linear_residual = (
            curr_vals.J.T @ (resid_step - curr_vals.residuals)
            + curr_vals.loss_hess_appx @ step
            - curr_vals.loss_const_grad
        )
        linear_residual = jnp.linalg.norm(linear_residual) / jnp.linalg.norm(
            curr_vals.loss_const_grad
        )

        return step, resid_step, linear_residual


def LevenbergMarquardt(
    init_params: jax.Array,
    problem: ObjectiveResidual,
    regularizer: _LMRegularizer,
    opt_settings: Optional[LMSettings] = None,
) -> tuple[jax.Array, ConvergenceHistory]:
    """Adaptively regularized Levenberg Marquardt optimizer

    Parameters
    ----------
    init_params : initial guess
    problem :
        Objective supplying ``resid_func``, ``jac_func``, and ``damping_matrix``
    regularizer :
        Global regularizer; also computes the damped Gauss-Newton step
    opt_settings : optimizer settings (``LMSettings()`` if None; a None
        default avoids sharing one mutable settings instance across calls)

    Returns
    -------
    A tuple of the solution and convergence information
    """
    if opt_settings is None:
        opt_settings = LMSettings()
    conv_history = ConvergenceHistory(opt_settings.track_iterates)
    params = init_params.copy()
    step_damping = opt_settings.init_damp
    regularizer.fit(problem)

    # Zeroth step: record the starting point before iterating. Honor the
    # ``use_jit`` switch (the original tested the always-truthy settings
    # object itself, which made the flag dead).
    if opt_settings.use_jit:
        evaluate = jax.jit(_evaluate_objective)
    else:
        evaluate = _evaluate_objective
    result = cast(ObjEvaluation, evaluate(params, problem))
    reg_result = regularizer.eval(params)
    result = result.add_regularization(reg_result)
    gradnorm = jnp.linalg.norm(result.loss_const_grad)
    zeroth_step = OptimizerStepResult(
        params, result.loss, gradnorm, 1.0, opt_settings.init_damp, reg_result.loss, 0.0
    )
    conv_history.update(zeroth_step, cumulative_time=0.0)

    start_time = time.time()
    for i in tqdm(range(opt_settings.max_iter), leave=False):
        step_result, succeeded = _LevenbergMarquardtUpdate(
            params, step_damping, problem, regularizer, opt_settings
        )

        # Standard LM damping adaptation from the improvement ratio.
        step_damping = (
            jnp.maximum(1 / 3, 1 - (2 * step_result.improvement_ratio - 1) ** 3)
            * step_result.step_damping
        )
        if not succeeded:
            warn(f"Search Failed on iteration {i}! Final Iteration Results:")
            conv_history.finish(convergence_tag="failed-line-search")
            return params, conv_history

        params = step_result.params
        model_decrease = (
            conv_history.loss_vals[-1] - step_result.loss
        ) / step_result.improvement_ratio
        conv_history.update(step_result, time.time() - start_time)

        if step_result.gradnorm <= opt_settings.atol_gradnorm:
            conv_history.finish("atol-gradient-norm")
            break
        elif model_decrease * (1 + step_damping) <= opt_settings.atol_gn_decrement:
            conv_history.finish("atol-gauss-newton-decrement")
            break
        if i % opt_settings.callback_every == 0 or i <= 5:
            for callback in opt_settings.callbacks:
                callback(i, step_result)
    else:
        conv_history.finish(convergence_tag="maximum-iterations")

    for callback in opt_settings.callbacks:
        callback(i, step_result)
    return params, conv_history


def _take_prox_step(
    params: jax.Array,
    step_damp: float,
    curr_objdata: ObjEvaluation,
    problem: ObjectiveResidual,
    global_reg: _LMRegularizer,
    gradnorm: float,
) -> OptimizerStepResult:
    """Take one damped (proximal) LM step and evaluate its quality."""
    local_reg = L2CholeskyLMRegularizer(step_damp)
    local_reg.fit(problem)
    # zeros_like reflects that the variable is the step, not params
    local_vals = curr_objdata.add_regularization(local_reg.eval(jnp.zeros_like(params)))
    step, resid_step, linear_residual = global_reg.step(params, local_vals)
    # step is negative, because curr_vals.loss_grad gets subtracted to get rhs
    new_params = params - step
    reg_loss = global_reg.eval(new_params).loss
    new_loss = cast(
        float, 0.5 * jnp.sum(problem.resid_func(new_params) ** 2) + reg_loss
    )
    pred_loss = cast(
        float, 0.5 * jnp.sum((resid_step - curr_objdata.residuals) ** 2) + reg_loss
    )
    # Ratio of actual to model-predicted loss decrease (trust-region gauge).
    improvement_ratio = (curr_objdata.loss - new_loss) / (curr_objdata.loss - pred_loss)
    return OptimizerStepResult(
        new_params,
        new_loss,
        gradnorm,
        improvement_ratio,
        step_damp,
        reg_loss,
        linear_residual,
    )


def _LevenbergMarquardtUpdate(
    params: jax.Array,
    init_damp: float,
    problem: ObjectiveResidual,
    global_reg: _LMRegularizer,
    opt_settings: LMSettings,
) -> tuple[OptimizerStepResult, bool]:
    r"""Regularizes and minimizes the local quadratic approximation of a problem

    .. math::
        \min_x 1/2\cdot \|\widetilde F(x)\|^2 + R(x)

    where :math:`\widetilde F` is the linear approximation of the residual. This
    function enforces locality with an additional damping term based upon distance
    from the previous iterate. If the new iterate does not improve the loss
    sufficiently, the damping term is increased.

    Args:
        params: Current parametrization value of function to approximate
        init_damp: Starting damping strength. Larger values shrink the step size.
        problem: The optimization problem to solve,
        global_reg: The regularizer, which should know how to evaluate itself
            at a point and how to iterate when combined with a least-squares
            ObjEvaluation.
        opt_settings: damping adaption and termination criteria.

    Returns:
        The final step result and whether the backtracking search succeeded.
    """
    # Values that don't change during proximity search. Honor ``use_jit``
    # (the original tested the always-truthy settings object itself).
    if opt_settings.use_jit:
        evaluate = jax.jit(_evaluate_objective)
    else:
        evaluate = _evaluate_objective
    curr_objdata = cast(ObjEvaluation, evaluate(params, problem))
    curr_objdata = curr_objdata.add_regularization(global_reg.eval(params))
    gradnorm = jnp.linalg.norm(curr_objdata.loss_const_grad)
    step_damp = cast(
        float,
        jnp.clip(
            init_damp,
            opt_settings.min_step_damping,
            opt_settings.max_step_damping,
        ),
    )

    for i in range(opt_settings.max_search_iterations):
        step_result = _take_prox_step(
            params, step_damp, curr_objdata, problem, global_reg, gradnorm
        )

        if step_result.improvement_ratio >= opt_settings.min_improvement:
            return step_result, True
        # Insufficient improvement: tighten the proximal penalty and retry.
        step_damp = opt_settings.search_increase_ratio * step_damp
    return step_result, False


def STLSQ_solve(u0, theta0, residual_objective, beta, optSettings):
    """Alternating proximal solve: linearized ridge update for the process
    coefficients ``u``, then sequentially thresholded least squares (STLSQ)
    for the dynamics coefficients ``theta``.

    Args:
        u0: initial process coefficients
        theta0: initial (flattened) dynamics coefficients
        residual_objective: objective exposing ``F_split`` and metadata
        beta: global process regularization strength
        optSettings: an ``STLSQLMSettings`` instance

    Returns:
        Tuple of final stacked parameters and the convergence history.
    """
    conv_history = ConvergenceHistory(track_iterates=optSettings.track_iterates)

    @jax.jit
    def F_split(u, theta):
        return residual_objective.F_split(
            [u],
            theta.reshape(
                residual_objective.num_features, residual_objective.system_dim
            ),
        )

    def phi(u, theta):
        return 0.5 * jnp.sum(F_split(u, theta) ** 2)

    @jax.jit
    def evaluate_objective(u, theta):
        Fval = F_split(u, theta)
        Ju = jax.jacrev(F_split, argnums=0)(u, theta)
        Jtheta = jax.jacrev(F_split, argnums=1)(u, theta)
        return Fval, Ju, Jtheta

    loop_wrapper = tqdm

    max_iter = optSettings.stlsq_max_iter
    rho = optSettings.prox_reg
    alpha = optSettings.ridge_reg
    lam = optSettings.threshold

    u, theta = u0, theta0
    K = residual_objective.state_param_regmat[: len(u), : len(u)]

    for k in loop_wrapper(range(max_iter)):
        u_old = u

        Fval, Ju, Jtheta = evaluate_objective(u, theta)
        # Proximal ridge update for the process coefficients.
        rhs_u = rho * K @ u_old - Ju.T @ (Fval - Ju @ u)

        u = jax.scipy.linalg.solve(Ju.T @ Ju + (rho + beta) * K, rhs_u, assume_a="pos")

        # STLSQ on the linearized dynamics residual for theta.
        rhs_ridge = -Fval + Jtheta @ theta - Ju @ (u - u_old)
        stlsq_opt = STLSQ(threshold=lam, alpha=alpha)
        stlsq_opt.ind_ = np.ones_like(np.array(theta), dtype=int)
        stlsq_opt.fit(x_=np.array(Jtheta), y=np.array(rhs_ridge))
        theta = jnp.array(stlsq_opt.coef_)[0]

        phik = phi(u, theta)
        loss = phik + 0.5 * beta * u.T @ K @ u
        params = jnp.hstack((u, theta))
        result = OptimizerStepResult(params, loss, 0.0, 0.0, 0.0, 0.0, 0.0)
        conv_history.update(result, 0.0)

    return params, conv_history


class _BaseSSOptimizer(_BaseOptimizer, ABC):
    """Base class for single-step SINDy optimizers."""

    # Fitted process/trajectory coefficients (set by ``fit``).
    process_: jax.Array

    @abstractmethod
    def fit(self, x, y, init_params: Optional[jax.Array] = None) -> Self:
        # Original abstract signature omitted ``self``; restored for
        # consistency with the concrete subclasses.
        ...
@dataclass
class STLSQLMSolver(_BaseSSOptimizer):
    """Two-phase solver: Levenberg-Marquardt warm start, then alternating
    STLSQ refinement of the dynamics coefficients.

    Attributes:
        beta: global process regularization strength
        optimizer_settings: settings for both the LM and STLSQ phases
    """

    beta: float = 1e-12
    optimizer_settings: STLSQLMSettings = field(default_factory=STLSQLMSettings)

    def fit(
        self,
        residual_objective: ObjectiveResidual,
        init_params: Optional[jax.Array] = None,
        *args,
        **kwargs,
    ):
        """
        Arguments:
            residual_objective: A tuple of data residual and dynamics residual
            init_params: optional warm start; zeros if not given
        """
        full_n_kernel = residual_objective.full_n_process
        full_n_theta = residual_objective.full_n_theta
        system_dimension = residual_objective.system_dim

        # Only default to zeros when the caller supplied nothing; the original
        # unconditionally overwrote ``init_params``, silently ignoring it.
        if init_params is None:
            init_params = jnp.zeros(full_n_kernel + full_n_theta)

        regularizer = L2CholeskyLMRegularizer(self.beta)

        params, history = LevenbergMarquardt(
            init_params,
            residual_objective,
            regularizer,
            self.optimizer_settings,
        )

        u0 = params[:full_n_kernel]
        theta0 = params[full_n_kernel:]

        params, stlsq_history = STLSQ_solve(
            u0=u0,
            theta0=theta0,
            residual_objective=residual_objective,
            beta=self.beta,
            optSettings=self.optimizer_settings,
        )

        self.all_params_ = params
        self.coef_ = np.array(params[full_n_kernel:].reshape(-1, system_dimension).T)
        self.process_ = params[:full_n_kernel]
        self.history_ = history
        self.stlsq_ = stlsq_history

        return self


@dataclass
class LMSolver(_BaseSSOptimizer):
    """A Levenberg-Marquardt solver for single-step SINDy problems.

    It iterates by linearizing the data and dynamics loss around the previous
    iteration's values, adding the process and SINDy regularizers, and keeping
    the next iteration "close" to the previous iteration. This results in a
    local quadratic approximation to the loss.

    The "closeness" allowed is defined adaptively by a penalty term: when an
    iteration does not improve the true minimum nearly as much as the
    approximator minimum, it tightens the penalty for moving.

    Currently applies an RKHS norm on the process terms and an L-2 norm on the
    SINDy terms.

    Attributes:
        regularizer: overall regularization (coefficient inside the default)
        optimizer_settings: Settings for optimizer. See LMSettings for more details
    """

    regularizer: _LMRegularizer = field(
        default_factory=lambda: L2CholeskyLMRegularizer(1e-12)
    )
    optimizer_settings: LMSettings = field(default_factory=LMSettings)

    def fit(
        self,
        residual_objective: ObjectiveResidual,
        init_params: Optional[jax.Array] = None,
    ):
        """
        Arguments:
            residual_objective: A tuple of data residual and dynamics residual
            init_params: optional warm start; zeros if not given
        """
        full_n_process = residual_objective.full_n_process
        full_n_theta = residual_objective.full_n_theta
        system_dimension = residual_objective.system_dim

        if init_params is None:
            init_params = jnp.zeros(full_n_process + full_n_theta)

        params, history = LevenbergMarquardt(
            init_params,
            residual_objective,
            self.regularizer,
            opt_settings=self.optimizer_settings,
        )

        self.all_params_ = params
        self.coef_ = np.array(params[full_n_process:].reshape(-1, system_dimension).T)
        self.process_ = params[:full_n_process]
        # Record history before raising so diagnostics survive a failed fit.
        self.history_ = history
        if np.isnan(params).any():
            # ``process_`` is a slice of ``params``, so a separate NaN check on
            # it after this raise was unreachable dead code and was removed.
            self.coef_ = np.zeros_like(self.coef_)
            raise ValueError("Optimization resulted in Nans, fit is unreliable")

        return self


# --- new file: pysindy/sssindy/opt_attic.py ---
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
from tqdm.auto import tqdm


def run_jax_solver(solver, x0):
    """Drive a jaxopt-style solver, recording per-iteration diagnostics."""
    state = solver.init_state(x0)
    sol = x0
    values, errors, stepsizes = [state.value], [state.error], [state.stepsize]

    def update(sol, state):
        return solver.update(sol, state)

    jitted_update = jax.jit(update)
    for iter_num in tqdm(range(solver.maxiter)):
        sol, state = jitted_update(sol, state)
        values.append(state.value)
errors.append(state.error) + stepsizes.append(state.stepsize) + if solver.verbose > 0: + print("Gradient Norm: ", state.error) + print("Loss Value: ", state.value) + if state.error <= solver.tol: + break + if stepsizes[-1] == 0: + print("Restart") + state = solver.init_state(sol) + convergence_data = { + "values": np.array(values), + "gradnorms": np.array(errors), + "stepsizes": np.array(stepsizes), + } + return sol, convergence_data, state + + +@jit +def l2reg_lstsq(A, y, reg=1e-10): + U, sigma, Vt = jnp.linalg.svd(A, full_matrices=False) + return Vt.T @ ((sigma / (sigma**2 + reg)) * (U.T @ y)) + + +def refine_solution( + params, equation_model, reg_sequence=10 ** (jnp.arange(-4.0, -18, -0.5)) +): + """Refines solution with almost pure gauss newton through SVD""" + refinement_losses = [] + refined_params = params.copy() + for reg in tqdm(reg_sequence): + J = equation_model.jac(refined_params) + F = equation_model.F(refined_params) + refined_params = refined_params - l2reg_lstsq(J, F, reg) + refinement_losses += [equation_model.loss(refined_params)] + return refined_params, jnp.array(refinement_losses) + + +def adaptive_refine_solution( + params, equation_model, initial_reg=1e-4, num_iter=100, mult=0.7 +): + refinement_losses = [equation_model.loss(params)] + refined_params = params.copy() + reg_vals = [initial_reg] + reg = initial_reg + for i in tqdm(range(num_iter)): + J = equation_model.jac(refined_params) + F = equation_model.F(refined_params) + U, sigma, Vt = jnp.linalg.svd(J, full_matrices=False) + + candidate_regs = [mult * reg, reg, reg / mult] + candidate_steps = [ + Vt.T @ ((sigma / (sigma**2 + S)) * (U.T @ F)) for S in candidate_regs + ] + + loss_vals = jnp.array( + [equation_model.loss(refined_params - step) for step in candidate_steps] + ) + choice = jnp.argmin(loss_vals) + reg = candidate_regs[choice] + step = candidate_steps[choice] + refined_params = refined_params - step + refinement_losses.append(loss_vals[choice]) + reg_vals.append(reg) + return 
refined_params, jnp.array(refinement_losses), jnp.array(reg_vals) diff --git a/pysindy/sssindy/sssindy.py b/pysindy/sssindy/sssindy.py new file mode 100644 index 000000000..45fee92d0 --- /dev/null +++ b/pysindy/sssindy/sssindy.py @@ -0,0 +1,227 @@ +from typing import Callable +from typing import cast +from typing import Optional + +import jax +import jax.numpy as jnp +import numpy as np +from jax._src.prng import PRNGKeyArray +from jax.random import normal +from scipy.integrate import solve_ivp +from sklearn.metrics import r2_score +from sklearn.pipeline import Pipeline + +from .expressions import JaxPolyLib +from .expressions import JointObjective +from .expressions import ObjectiveResidual +from .interpolants import get_gaussianRBF +from .interpolants import RKHSInterpolant +from .opt import _BaseSSOptimizer +from .opt import LMSolver +from pysindy._core import _adapt_to_multiple_trajectories +from pysindy._core import _BaseSINDy +from pysindy._core import _check_multiple_trajectories +from pysindy._core import TrajectoryType +from pysindy.optimizers import BaseOptimizer + +StrategySpec = ( + str | jax.Array | np.ndarray | PRNGKeyArray | BaseOptimizer | _BaseSSOptimizer +) + + +class SSSINDy(_BaseSINDy): + + optimizer: _BaseSSOptimizer + expression: JointObjective + feature_library: JaxPolyLib + feature_names: Optional[list[str]] + init_strategy: StrategySpec + + def __init__( + self, + expression: JointObjective = JointObjective( + data_weight=3.0, + dynamics_weight=1.0, + lib=JaxPolyLib(2), + interp_template=RKHSInterpolant(get_gaussianRBF(0.2)), + ), + optimizer: _BaseSSOptimizer = LMSolver(), + init_strategy: StrategySpec = "zeros", + feature_names: Optional[list[str]] = None, + ): + super().__init__() + self.expression = expression + self.optimizer = optimizer + self.feature_library = expression.lib + self.feature_names = feature_names + self.init_strategy = init_strategy + + def fit( + self, + x: list[jax.Array], + t: list[jax.Array], + t_coloc: 
Optional[list[jax.Array]] = None, + ): + if t_coloc is None: + t_coloc = t + if not _check_multiple_trajectories(x, None, None): + x, t, t_coloc, u = _adapt_to_multiple_trajectories(x, t, t_coloc, None) + t_coloc = cast(list[jax.Array], t_coloc) + + self.n_control_features_ = 0 # cannot yet fit control features + self.model = Pipeline( + [("expression", self.expression), ("optimizer", self.optimizer)] + ) + + objective = self.expression.fit(x, t, t_coloc=t_coloc).transform(x, t, t_coloc) + + init_params = _initialize_params( + self.init_strategy, self.expression, objective, x, t, t_coloc + ) + + self.optimizer.fit(objective, init_params=init_params) + self.fitted = True + self.n_control_features_ = 0 + self._fit_shape() + + def x_predict(self, t): + return [ + interp(t, params=self.optimizer.process_[slise]) + for interp, slise in zip( + self.expression.traj_interps, self.expression.traj_coef_slices_ + ) + ] + + def predict(self, x: jax.Array) -> jax.Array: + r"""Predict the time derivative \dot{x} = f(x). 
Later include time dependence""" + feats = self.feature_library.transform( + x.reshape(-1, self.feature_library.n_features_in_) # type: ignore + ) + return feats @ self.optimizer.coef_.T # type: ignore + + def coefficients(self): + if hasattr(self, "fitted") and self.fitted: + return self.optimizer.coef_ + else: + raise ValueError("Must run fit() first.") + + def simulate(self, x0: jax.Array, t: jax.Array, **integrator_kws) -> jax.Array: + def rhs(t, x): + return self.predict(x[np.newaxis, :])[0] + + return ((solve_ivp(rhs, (t[0], t[-1]), x0, t_eval=t, **integrator_kws)).y).T + + def score( + self, + x: TrajectoryType, + t: TrajectoryType, + x_dot: TrajectoryType, + metric: Callable[[jax.Array, jax.Array], float] = r2_score, # type: ignore + ) -> float: + if not isinstance(x, list): + x = [x] # type: ignore + if not isinstance(t, list): + t = [t] # type: ignore + if not isinstance(x_dot, list): + x_dot = [x_dot] # type: ignore + if x_dot is None: + x_dot = jnp.vstack([self.predict(x_i) for x_i in x]) + x_dot_predict = jnp.vstack(x_dot) + x_dot_interp = jnp.vstack( + [ + interp.derivative(t_i, self.optimizer.process_[slise]) + for t_i, interp, slise in zip( + t, self.expression.traj_interps, self.expression.traj_coef_slices_ + ) + ] + ) + + return metric(x_dot_interp, x_dot_predict) + + +def _initialize_params( + init_strategy: StrategySpec, + expression: JointObjective, + objective: ObjectiveResidual, + x: Optional[list[jax.Array]] = None, + t: Optional[list[jax.Array]] = None, + t_colloc: Optional[list[jax.Array]] = None, +) -> jax.Array: + """Initialize the optimization parameter using a variety of strategies. + + Args: + init_strategy: The strategy to use for initialization. It can be a: + - string: "zeros" or "ones" + - jax.Array: A jax array of initial values. + - np.ndarray: A numpy array of initial values. + - PRNGKeyArray: A jax random key for generating normally + distributed random values. 
+ - BaseOptimizer: When sending a BaseOptimizer, the data residual + is first used to fit the process coefficients, then the BaseOptimizer + is used to fit the SINDy dynamics coefficients. + - _BaseSSOptimizer: Fit using an initial joint SINDy optimizer. + expression: The expression object. Depending on the strategy, this + may be safely passed as None. + objective: The objective residual object. Depending on the strategy, this + may be safely passed as None. Otherwise, it needs to be consistent with + the expression. + t_coloc: The collocation points. + Returns: + Initial values for the optimization parameter. + """ + if init_strategy == "zeros": + init_params = jnp.zeros(objective.full_n_process + objective.full_n_theta) + elif init_strategy == "ones": + init_params = jnp.ones(objective.full_n_process + objective.full_n_theta) + elif isinstance(init_strategy, PRNGKeyArray): + # Match order matters here: PRNGKeyArray is a jax.Array, so it would match + # the next condition if we don't check for it first. + init_params = normal( + init_strategy, shape=(objective.full_n_process + objective.full_n_theta,) + ) + elif isinstance(init_strategy, jax.Array): + init_params = init_strategy + elif isinstance(init_strategy, np.ndarray): + init_params = jnp.array(init_strategy) + elif isinstance(init_strategy, BaseOptimizer): + + if not ( + isinstance(x, list) and isinstance(t, list) and isinstance(t_colloc, list) + ): + raise ValueError( + "If init_strategy is a BaseOptimizer, x, t, and t_coloc must be " + "provided as lists." 
+ ) + init_proc = [] + init_colloc = [] + init_colloc_d = [] + for traj, traj_x, traj_t, traj_tc in zip( + expression.traj_interps, x, t, t_colloc + ): + traj.fit_time(expression.system_dimension, traj_t) + tparams = traj.fit_obs(traj_t, traj_x, noise_var=1e-5) + init_proc.append(tparams.flatten()) + init_colloc.append(traj(traj_tc, tparams)) + init_colloc_d.append(traj.derivative(traj_tc, tparams)) + + init_proc = jnp.hstack(init_proc) + x_ = np.vstack(expression.lib.transform(init_colloc)) + y_ = np.vstack(init_colloc_d) + init_theta = init_strategy.fit(x_, y_).coef_ + init_params = jnp.hstack((init_proc, init_theta.T.flatten())) + elif isinstance(init_strategy, list): + params = jnp.zeros(objective.full_n_process + objective.full_n_theta) + if not init_strategy: + raise ValueError("If init_strategy is a list, it must not be empty. ") + for strat in init_strategy: + if not isinstance(strat, _BaseSSOptimizer): + raise ValueError( + f"If init_strategy is a list, all elements must be " + f"of type _BaseSSOptimizer. Got {type(strat)}." + ) + strat.fit(objective) + params = jnp.hstack((strat.process_, strat.coef_.flatten())) + init_params = params + else: + raise ValueError(f"init_strategy: {init_strategy} not understood. 
") + return init_params diff --git a/test/test_inspect_to_sympy.py b/test/test_inspect_to_sympy.py new file mode 100644 index 000000000..e24bc1072 --- /dev/null +++ b/test/test_inspect_to_sympy.py @@ -0,0 +1,20 @@ +import pytest + +try: + from dysts.flows import Lorenz +except Exception: # pragma: no cover - skip if dysts not installed + pytest.skip( + "dysts not available; skipping inspect_to_sympy tests", allow_module_level=True + ) + +from asv_bench.benchmarks.inspect_to_sympy import dynsys_to_sympy + + +def test_lorenz_to_sympy(): + lor = Lorenz() + symbols, exprs, lambda_rhs = dynsys_to_sympy(lor, func_name="_rhs") + assert len(symbols) == lor.dimension + # evaluate lambda with simple numeric values + vals = tuple(float(i + 1) for i in range(lor.dimension)) + mat = lambda_rhs(*vals) + assert mat.shape[0] == lor.dimension diff --git a/test/test_sss.py b/test/test_sss.py new file mode 100644 index 000000000..7b35c4f42 --- /dev/null +++ b/test/test_sss.py @@ -0,0 +1,191 @@ +import jax +import jax.numpy as jnp +import jax.random as jrp +import numpy as np +import pytest +from scipy.integrate import solve_ivp +from sklearn.kernel_ridge import KernelRidge + +from pysindy import STLSQ +from pysindy.sssindy import JaxPolyLib +from pysindy.sssindy import JointObjective +from pysindy.sssindy import LMSolver +from pysindy.sssindy import SSSINDy +from pysindy.sssindy.interpolants import GaussianRBFKernel +from pysindy.sssindy.interpolants import get_gaussianRBF +from pysindy.sssindy.interpolants import RKHSInterpolant +from pysindy.sssindy.interpolants.base import MockInterpolant +from pysindy.sssindy.opt import _evaluate_objective +from pysindy.sssindy.opt import _LMRegularizer +from pysindy.sssindy.opt import L2CholeskyLMRegularizer +from pysindy.sssindy.opt import LMSettings +from pysindy.sssindy.opt import SINDyAlternatingLMReg +from pysindy.sssindy.sssindy import _initialize_params +from pysindy.utils.odes import lorenz + +jax.config.update("jax_enable_x64", True) 
+jax.config.update("jax_default_device", jax.devices()[0]) + + +# Todo: when merging with pysindy, most of the data fixtures are clones +# from that repo's conftest.py +@pytest.fixture(scope="session") +def data_1d(): + t = np.linspace(0, 1, 12) + x = 0.2 * t.reshape(-1, 1) + return x, t + + +@pytest.fixture(scope="session") +def data_lorenz(): + t = np.linspace(0, 1, 12) + x0 = [8, 27, -7] + x = solve_ivp(lorenz, (t[0], t[-1]), x0, t_eval=t).y.T + + return x, t + + +def test_print(data_1d, capsys): + x = jnp.array(data_1d[0]) + t = jnp.array(data_1d[1]) + model = SSSINDy(optimizer=LMSolver(optimizer_settings=LMSettings(max_iter=30))) + model.fit(x, t) + model.print() + out, _ = capsys.readouterr() + model.predict(x) + model.score(x, t, x) + model.simulate(x[0], t) + + assert len(out) > 0 + assert " = " in out + + +@pytest.mark.parametrize( + "reg", + [ + (L2CholeskyLMRegularizer(1e-12)), + (SINDyAlternatingLMReg(1e-12, theta_optimizer=STLSQ())), + ], + ids=type, +) +def test_lm_regularizers(data_lorenz, reg: _LMRegularizer): + interp = RKHSInterpolant(get_gaussianRBF(0.2), (0,), 1e-5) + exp = JointObjective(1, 1, JaxPolyLib(), interp) + x = [jnp.array(data_lorenz[0])] + t = [jnp.array(data_lorenz[1])] + objective = exp.fit(x, t).transform(x, t) + reg.fit(objective) + interp.fit_time(x[0].shape[-1], t[0]) + ez_coeff = jnp.linalg.inv(interp.evaluation_kmat(t[0], t[0])) @ x[0].reshape( + (-1, 1) + ) + params = jnp.hstack((ez_coeff.flatten(), jnp.zeros(objective.full_n_theta))) + + origin_val = _evaluate_objective(params, objective) + local_val = L2CholeskyLMRegularizer(1).fit(objective).eval(params) + curr_val = origin_val.add_regularization(local_val) + reg.eval(params) + result = reg.step(params, curr_val) + assert not jnp.isnan(result[0]).any() + + +def test_expression(data_lorenz): + exp = JointObjective(1, 1, JaxPolyLib(), MockInterpolant()) + x = [jnp.array(data_lorenz[0])] + t = [jnp.array(data_lorenz[1])] + objective = exp.fit(x, t).transform(x, t) + 
vector_len = objective.full_n_process + objective.full_n_theta + params = jnp.zeros(vector_len) + _evaluate_objective(params, objective) + jax.jit(_evaluate_objective)(params, objective) + + +def test_multiple_trajectories(data_lorenz): + x = [jnp.array(data_lorenz[0]), jnp.array(data_lorenz[0])] + t = [jnp.array(data_lorenz[1]), jnp.array(data_lorenz[1])] + model = SSSINDy(optimizer=LMSolver(optimizer_settings=LMSettings(max_iter=30))) + model.fit(x, t) + model.x_predict(t[0]) + model.predict(x[0][0]) + model.score(x, t, x) + + +@pytest.mark.parametrize("data", ["sin_data", "data_lorenz"], ids=["sin", "lorenz"]) +@pytest.mark.parametrize( + "init_strategy", + [ + jrp.key(5), + "zeros", + "ones", + STLSQ(), + [LMSolver(optimizer_settings=LMSettings(max_iter=2, use_jit=False))], + ], + ids=type, +) +def test_init_params(data, init_strategy, request): + exp = JointObjective(1, 1, JaxPolyLib(), RKHSInterpolant(get_gaussianRBF(0.2))) + x, t = request.getfixturevalue(data) + t = t.flatten() + obj = exp.fit([x], [t], t_coloc=[t]).transform(x, t, t) + params = _initialize_params(init_strategy, exp, obj, [x], [t], [t]) + assert len(params) == obj.full_n_theta + obj.full_n_process + + +@pytest.fixture() +def sin_data(): + t = jnp.arange(0, 6, 1, dtype=float) + x = jnp.sin(t).reshape((-1, 1)) + return x, t + + +@pytest.fixture() +def twod_sin_data(): + t = jnp.arange(0, 6, 1, dtype=float) + x = jnp.sin(t).reshape((-1, 1)) + return jnp.hstack((x, -x)), t + + +@pytest.fixture() +def threed_sin_data(): + t = jnp.arange(0, 6, 1, dtype=float) + x = jnp.sin(t).reshape((-1, 1)) + return jnp.hstack((x, -x, 3 * x)), t + + +@pytest.mark.parametrize( + "k_constructor", [get_gaussianRBF, GaussianRBFKernel], ids=["old", "new"] +) +@pytest.mark.parametrize( + "data", + ["sin_data", "twod_sin_data", "threed_sin_data"], + ids=["1d", "2d", "3d"], +) +def test_rbf_kernel(request, data, k_constructor): + x, t_obs = request.getfixturevalue(data) + gamma = 1 + dt = t_obs[1] - t_obs[0] + t_pred = 
jnp.arange(t_obs.min(), t_obs.max(), dt / 2, dtype=float).reshape((-1, 1)) + + sk_kernel = KernelRidge(alpha=0, kernel="rbf", gamma=gamma) + sk_kernel.fit(t_obs.reshape((-1, 1)), x) + sk_pred = sk_kernel.predict(t_pred) + + our_gamma = jnp.sqrt(1 / (2 * gamma)) + our_interp = RKHSInterpolant( + nugget=0, + kernel=k_constructor(our_gamma), + derivative_orders=(0,), + ) + our_interp.fit_time(x.shape[-1], t_obs.flatten()) + params = our_interp.fit_obs(t_obs.flatten(), x, noise_var=0) + our_pred = our_interp.__call__(t_pred, params) + + # Because interpolants can cross the axis at slightly different times, + # np/jnp.allclose will raise false positives + rel_error = jnp.linalg.norm((sk_pred - our_pred)) / jnp.linalg.norm(sk_pred) + assert rel_error < 0.01 + vec_align = jnp.sum(sk_pred * our_pred) / jnp.linalg.norm(sk_pred) ** 2 + assert 0.99 < vec_align < 1.01 + + our_pred_again = our_interp.interpolate(x, t_obs.flatten(), t_pred, 0) + assert jnp.allclose(our_pred_again, our_pred, atol=1e-8)