Jiale Chen, Dingling Yao, Adeel Pervez, Dan Alistarh, Francesco Locatello
Abstract
We propose the Scalable Mechanistic Neural Network (S-MNN), an enhanced neural network framework designed for scientific machine learning applications involving long temporal sequences. By reformulating the original Mechanistic Neural Network (MNN) (Pervez et al., 2024), we reduce the computational time and space complexities from cubic and quadratic in the sequence length, respectively, to linear. This significant improvement enables efficient modeling of long-term dynamics without sacrificing accuracy or interpretability. Extensive experiments demonstrate that S-MNN matches the original MNN in precision while substantially reducing computational resources. Consequently, S-MNN can serve as a drop-in replacement for the original MNN in existing applications, providing a practical and efficient tool for integrating mechanistic bottlenecks into neural network models of complex dynamical systems.
This paper introduces the Scalable Mechanistic Neural Network (S-MNN), an enhanced version of the original Mechanistic Neural Network (MNN) that significantly improves computational efficiency while maintaining accuracy. The key innovation is reducing the time complexity from cubic and the space complexity from quadratic to linear in the sequence length.

Key Contributions:
Complexity Reduction
- Reformulated the original MNN's underlying linear system by eliminating slack variables and central-difference constraints
- Reduced the quadratic program to a least-squares regression
- Achieved linear time and space complexity through efficient banded matrix structures
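The least-squares reformulation pays off because local (stencil-style) constraints make the normal-equations matrix banded, so it can be solved in linear time. Below is an illustrative sketch of this idea using a generic tridiagonal operator; it is not the paper's actual system matrix, just a minimal demonstration of why banded structure yields O(n) solves.

```python
import numpy as np
from scipy.linalg import solveh_banded
from scipy.sparse import diags

# Illustrative banded least-squares problem min_x ||A x - b||^2, where A is a
# local (finite-difference-style) operator. This is NOT the paper's exact
# system; it only demonstrates the linear-time banded solve.
n = 1000
# A: a tridiagonal second-difference stencil (bandwidth 1)
A = diags([np.full(n - 1, -1.0), np.full(n, 2.0), np.full(n - 1, -1.0)],
          offsets=[-1, 0, 1]).tocsc()
b = np.random.default_rng(0).standard_normal(n)

# Normal equations: (A^T A) x = A^T b. Since A has bandwidth 1, A^T A is
# pentadiagonal (bandwidth 2) and SPD, so a banded Cholesky-based solve
# runs in O(n) time and O(n) memory instead of O(n^3) / O(n^2) dense.
AtA = (A.T @ A).todia()
rhs = A.T @ b

# Pack the upper bands into solveh_banded's format:
# ab[u - k, k:] holds the k-th superdiagonal, with u = bandwidth.
bandwidth = 2
ab = np.zeros((bandwidth + 1, n))
for k in range(bandwidth + 1):
    ab[bandwidth - k, k:] = AtA.diagonal(k)

x = solveh_banded(ab, rhs)  # O(n) banded Cholesky solve
```

The same pattern generalizes to any fixed bandwidth: storage for the packed bands is `(bandwidth + 1) * n`, i.e., linear in the sequence length.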
Solver Design
- Developed an efficient solver that exploits the system's inherent sparsity patterns
- Optimized for GPU execution, exploiting the available parallelism
- Improved numerical stability relative to iterative methods
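The core of such a banded solver is a Cholesky factorization that only touches entries within the band. The minimal sketch below (tridiagonal case, bandwidth 1; a generic textbook recurrence, not the paper's exact algorithm) shows both why the cost is O(n) and where the sequential dependency noted in the limitations comes from: step i needs the result of step i-1.

```python
import numpy as np

def tridiag_cholesky(d, e):
    """Factor an SPD tridiagonal K (diagonal d, off-diagonal e) as L L^T.

    Returns the diagonal and sub-diagonal of the lower-bidiagonal factor L.
    Each step does O(1) work, so the total cost is O(n); the loop itself is
    inherently sequential, which limits parallelism on GPUs.
    """
    n = len(d)
    ld = np.empty(n)       # diagonal of L
    le = np.empty(n - 1)   # sub-diagonal of L
    ld[0] = np.sqrt(d[0])
    for i in range(1, n):  # step i depends on step i-1
        le[i - 1] = e[i - 1] / ld[i - 1]
        ld[i] = np.sqrt(d[i] - le[i - 1] ** 2)
    return ld, le

# Quick check against the dense factorization on a small SPD matrix
n = 6
d = np.full(n, 2.0)
e = np.full(n - 1, -1.0)
ld, le = tridiag_cholesky(d, e)
K = np.diag(d) + np.diag(e, -1) + np.diag(e, 1)
L = np.diag(ld) + np.diag(le, -1)
```

A direct factorization like this also avoids the convergence and conditioning issues of iterative solvers, which is the stability advantage the solver design aims for.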
[Figure 3: convergence-rate comparison between S-MNN and the original MNN]
Practical Applications
The authors demonstrate S-MNN's effectiveness across multiple scientific applications:
- Governing-equation discovery for the Lorenz system
- Solving the Korteweg-de Vries (KdV) partial differential equation
- Long-term sea surface temperature (SST) prediction
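The Lorenz benchmark illustrates why linear scaling matters: equation discovery operates on long simulated trajectories. The sketch below generates such a trajectory with the standard chaotic parameter values (sigma=10, rho=28, beta=8/3); these are the conventional settings, not necessarily the exact experimental configuration from the paper.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Lorenz system in the standard chaotic regime; the resulting long
# trajectory is the kind of sequence a governing-equation-discovery
# model would be trained on.
def lorenz(t, state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    x, y, z = state
    return [sigma * (y - x), x * (rho - z), x * y - beta * z]

t_eval = np.linspace(0.0, 10.0, 5001)
sol = solve_ivp(lorenz, (0.0, 10.0), [1.0, 1.0, 1.0],
                t_eval=t_eval, rtol=1e-8, atol=1e-8)
trajectory = sol.y.T  # shape (5001, 3): one long temporal sequence
```

With cubic-time solvers, sequences of this length quickly become the bottleneck; linear complexity is what makes fitting them practical.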
[Figure 1: long-term real-world sea surface temperature (SST) prediction]
Performance Improvements
The empirical results show that S-MNN:
- Matches or exceeds the accuracy of the original MNN
- Runs roughly 5x faster than the dense MNN solver
- Reduces memory usage by about 50%
- Handles sequence lengths that were previously infeasible
[Figure 6: scalability improvements of S-MNN over the original MNN]

Technical Implementation:
The paper provides detailed mathematical formulations and algorithms for:
- Matrix decomposition and solving techniques
- Forward- and backward-pass computations
- Gradient calculations
- Numerical stability considerations

Limitations and Future Work:
- Sequential operations in the Cholesky decomposition still limit some parallelism
- Small batch sizes can leave the GPU underutilized
- Future work aims to develop fully parallel algorithms while maintaining linear complexity

The paper represents a significant advancement in scientific machine learning, making mechanistic neural networks practical for long-sequence applications such as climate modeling. The results demonstrate that S-MNN can effectively replace the original MNN while providing substantial computational benefits.

The most significant figures are Figure 1 (real-world SST prediction), Figure 3 (convergence-rate comparison), and Figure 6 (scalability improvements). Figure 1 makes the best thumbnail, as it clearly illustrates the practical impact of the method on a real-world problem.