


KAN's Core

There are multiple training methods for the Kolmogorov-Arnold network. This article explains only the Kaczmarz method, published by the authors in 2021. Other methods, such as Broyden, Adam, and stochastic gradient descent, can be found in the literature.

Introduction

The Kolmogorov-Arnold model is a set of continuous functions $\Phi_q, f_{q,p}$ converting input vectors $X$, shown in the formula by their elements $x_p$, called features, into modeled scalars $m$, called targets $$ m = \sum_{q=0}^{2n} \Phi_q\left(\sum_{p=1}^{n} f_{q,p}(x_{p})\right). $$ Training the model means identifying these functions, given a dataset as a list of corresponding input vectors $X^k$ and targets $m^k$. This model has recently acquired the name Kolmogorov-Arnold network (KAN) and is regarded as an alternative to the multilayer perceptron (MLP).

We use subscripts for indexes in formulas and superscripts to denote particular values. For example, when we say that $x_i = x^k$, that means that the variable $x$ sitting at the indexed position $i$ takes the numerical value $x^k$. The equality ${x_i}^{k+1} = {x_i}^{k} + \Delta$ means modification of the element $x_i$ by adding $\Delta$ in the $k$-th computational step.

One function only

Let us start from one function only and then move on to KAN. The 1D dataset is given by points with coordinates $x^k, y^k$, and the goal is to adjust an arbitrary piecewise linear function, given by its ordinates $f_i$ arranged with even abscissa intervals $D$ over the domain of definition $[x_{min}, x_{max}]$, assuming the given points are spread near a certain curve, which makes this approximation reasonable.



For every abscissa $x$ we can identify the corresponding linear segment by computing the index $$ i = \lfloor (x - x_{min}) / D \rfloor, $$ (where $\lfloor \cdot \rfloor$ is the "FLOOR" function), and the relative offset of the abscissa from the left node $$L = (x - x_{min}) / D - i = R / D,$$ where $R$ is the absolute offset from the left node, $L \in [0, 1.0)$, and the relative offset from the right node is $1.0 - L$.
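As a minimal illustration (the function and variable names below are not taken from the published code), the index and the relative offset can be computed as follows:

#include <cmath>

// Locate the linear segment and the relative offset for abscissa x.
// xmin is the left edge of the domain and D is the abscissa step.
void LocateSegment(double x, double xmin, double D, int& i, double& L) {
    double t = (x - xmin) / D;
    i = (int)std::floor(t);   // index of the left ordinate f_i
    L = t - i;                // relative offset from the left, L in [0, 1)
}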

Having that, we can specify the desired equality as the linear interpolation $\mathbf{v}^T \mathbf{f} = y$, where $\mathbf{f}$ is the vector of ordered ordinates and $\mathbf{v}^T$ is a vector with only two non-zero elements matching the positions of the corresponding ordinates $$\mathbf{f}^T = [f_0\: f_1\: f_2\: ... f_i\: f_{i+1} ...],$$ $$\mathbf{v}^T = [0,\: 0,\: 0,\: ... 1-L,\: L, ...].$$ Please note that the position of $1-L$ matches the left edge of the linear segment and $L$ matches the right edge: the weight of each node grows as $x$ approaches it.
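A small worked example: with $x_{min} = 0$, $D = 1$ and five ordinates $f_0, ..., f_4$, the point $x = 2.3$ falls into segment $i = 2$ with $L = 0.3$, so $$\mathbf{v}^T = [0,\: 0,\: 0.7,\: 0.3,\: 0]$$ and the interpolated value is $\mathbf{v}^T \mathbf{f} = 0.7 f_2 + 0.3 f_3$.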

Collecting a system of linear algebraic equations for this case is not even necessary. The training is actually a much simpler procedure known as projection descent, introduced by Stefan Kaczmarz in 1937. It only requires walking through the dataset point by point, identifying the particular linear segment for each point, and updating its left and right ordinates by the modifications $${f_i}^{k+1} = {f_i}^{k} + \frac{(y - \hat y) \cdot (1 - L) }{(1 - L)^2 + L^2}$$ $${f_{i + 1}}^{k+1} = {f_{i + 1}}^{k} + \frac{(y - \hat y) \cdot L }{(1-L)^2 + L^2},$$ where $y$ and $\hat y = {f_i}^{k} \cdot (1 - L) + {f_{i + 1}}^{k} \cdot L = {f_i}^{k} + ({f_{i + 1}}^{k} - {f_i}^{k}) \cdot L$ are the provided and computed targets. The squared norm $(1-L)^2 + L^2$ always lies in the range $[0.5, 1]$ and can be replaced by a regularization parameter $\alpha$ for computational stability.
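A minimal sketch of this update for a single function, assuming the ordinates are held in a std::vector<double> (the names are illustrative, not the published code):

#include <cmath>
#include <vector>

// One projection descent (Kaczmarz) step for a single piecewise linear function.
// f holds the ordinates, xmin and D define the even abscissa grid, (x, y) is one data point.
void KaczmarzStep(std::vector<double>& f, double xmin, double D, double x, double y) {
    double t = (x - xmin) / D;
    int i = (int)std::floor(t);
    double L = t - i;                                // offset from the left node
    double yhat = f[i] + (f[i + 1] - f[i]) * L;      // computed target
    double norm = (1.0 - L) * (1.0 - L) + L * L;     // squared norm, in [0.5, 1]
    double step = (y - yhat) / norm;                 // the norm may be replaced by a regularization alpha
    f[i]     += step * (1.0 - L);
    f[i + 1] += step * L;
}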

The essence of projection descent is navigating an arbitrary point through the hyperplanes of the linear equations by projecting it from one hyperplane onto another. The convergence is extremely fast when all hyperplanes are nearly pairwise orthogonal, which is true in this case. The point being navigated here is the vector $\mathbf{f}^T = [f_0\: f_1\: f_2\: ... f_i\: f_{i+1} ...]$. The training may be arranged in a so-called online manner, that is, while reading the dataset. The experimental points or observations of real systems usually do not fit any model exactly, which means the training must be stopped after a certain accuracy is achieved. Regularization provides computational stability to the training process. Usually multiple runs through the dataset (epochs) are needed.

That was the hardest part of training KAN; the remaining part is even simpler.

Generalized additive models (GAMs) or discrete Urysohn operators (DUOs)

Now we consider a model with a sum of multiple functions. Such models are called GAMs or DUOs $$ z = \sum_{p=1}^{n} f_p(x_p).$$ All we need to do is reuse the previous section. The difference between the given and estimated targets $z - \hat z$ is now evenly split between all functions of the model, and everything else is performed as already explained. It is still the same projection descent introduced by Stefan Kaczmarz, only with block vectors $$[f_{0,0}\: f_{0,1}\: f_{0,2}\: ... f_{0,i}\: f_{0,i+1} ... f_{1,0}\: f_{1,1}\: f_{1,2}\: ... f_{1,i}\: f_{1,i+1} ...],$$ $$[0,\: 0,\: 0,\: ... (1-L_0),\: L_0, ... 0,\: 0,\: 0,\: ... (1-L_1),\: L_1, ...].$$
The name Generalized Additive Model does not need an explanation. The name Discrete Urysohn Operator needs a short clarification. There is a well-known integral equation named after its researcher, Pavel Urysohn. It may take several forms; one of them looks as follows:

$$ z = \int_0^T U(x(s), s)ds. $$ In this case $x(s)$ is a function defined on $[0,T]$ and $U$ is a function of two arguments, called the kernel. These two formulas, the sum and the integral, can be regarded as discrete and continuous forms of the same model. The values $x_p$ can be interpreted as sequential points of the function $x(s)$ and the functions $f_p$ as slices of $U$, so the sum is a finite difference approximation to the Urysohn integral equation. The term Discrete Urysohn operator is used in the mathematical literature; it is not new terminology.
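The correspondence can be made explicit by a finite difference approximation of the integral: $$ z = \int_0^T U(x(s), s)\,ds \approx \sum_{p=1}^{n} U(x_p, s_p)\,\Delta s = \sum_{p=1}^{n} f_p(x_p), \quad \text{where } f_p(\cdot) = U(\cdot, s_p)\,\Delta s. $$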

Training a model goes record by record. For each record we compute the estimate $\hat z$, find the difference $\Delta z = z - \hat z$, divide it evenly between all functions as $\Delta z / n$, multiply it by the regularization parameter, obtaining $\alpha \cdot \Delta z / n$, and update the corresponding single linear segment of every $f_p$ by adding $\alpha \cdot \Delta z/n \cdot (1-L)$ to the left ordinate and $\alpha \cdot \Delta z/n \cdot L$ to the right ordinate, as sketched below. The convergence is extremely fast, because if the training process were expressed as an iterative solution of the corresponding system of linear algebraic equations, that system would be very sparse.
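A sketch of one such online update under the same assumptions as before (vectors of ordinates per function; the names are illustrative, not the published code):

#include <cmath>
#include <vector>

// One online update of a GAM / discrete Urysohn operator for a single record (x, z).
// f[p] holds the ordinates of function f_p; xmin[p] and D[p] define its abscissa grid.
void UpdateDUO(std::vector<std::vector<double>>& f,
               const std::vector<double>& xmin, const std::vector<double>& D,
               const std::vector<double>& x, double z, double alpha) {
    int n = (int)f.size();
    double zhat = 0.0;
    std::vector<int> idx(n);
    std::vector<double> off(n);
    for (int p = 0; p < n; ++p) {                        // compute the estimate z_hat
        double t = (x[p] - xmin[p]) / D[p];
        idx[p] = (int)std::floor(t);
        off[p] = t - idx[p];
        zhat += f[p][idx[p]] + (f[p][idx[p] + 1] - f[p][idx[p]]) * off[p];
    }
    double delta = alpha * (z - zhat) / n;               // split the residual evenly
    for (int p = 0; p < n; ++p) {                        // update one linear segment per function
        f[p][idx[p]]     += delta * (1.0 - off[p]);
        f[p][idx[p] + 1] += delta * off[p];
    }
}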

Kolmogorov-Arnold representation is a tree of GAMs or DUOs

The fact that the Kolmogorov-Arnold representation is a tree of GAMs or DUOs was noticed only in 2021, although the representation itself was introduced in 1957. If we build $2n + 1$ different GAMs or DUOs for the same inputs $x_p$ and consider their outputs $z_q$ as latent or unobserved variables (similar to the hidden layer in an MLP), then we can express the final model output also as a GAM or DUO $$ m = \sum_{q=0}^{2n} \Phi_q(\hat z_q),$$ where $\hat z_q$ are $$ \hat z_q = \sum_{p=1}^{n} f_{q,p}(x_p), \:\:\:q \in [0, 2n].$$ The training is an update of the inner and outer operators for each new record. Inner operators are those which take the features as inputs, and the outer operator computes the target. The outer operator is updated as already explained, by reducing the difference $\Delta m = m - \hat m$ over the computed hidden-layer values $\hat z_q$. The inputs of the outer operator become available once the model value $\hat m$ is obtained, and the difference $\alpha (m - \hat m)$ is used to update it.

The update of each inner operator must also reduce the difference $\Delta m = m - \hat m$, so we can use $\alpha (m - \hat m)$ as the output difference, accounting for the sign of the derivative of the corresponding outer function, since the output of the inner operator is an argument of the outer one. Thus we use $$\alpha \cdot (m - \hat m) \cdot sign \left(\frac{d\Phi_q}{dz_q} \right)$$ for the $q$-th inner block, where $sign$ is either $+1$ or $-1$. It was found experimentally that convergence and accuracy are slightly higher when the actual derivatives are used instead of the $sign$ $$\alpha \cdot (m - \hat m) \cdot \frac{d\Phi_q}{dz_q}.$$ Later a theoretical proof for this change was found.

Following the Kaczmarz method we would have to divide the difference $\Delta m = m - \hat m$ by $(2n + 1)$, that is, evenly between the inner blocks, and within each inner block divide it further by $n$. But since these divisors are constant for the model, the regularization parameter may implicitly include this scaling, so we can drop the unnecessary divisions and apply the above difference directly to each function of each inner block.
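Putting the pieces together, one training step for a single record may look like the sketch below. This is an outline under the assumptions above, not the published implementation; the class and function names are chosen for illustration.

#include <cmath>
#include <vector>

// A piecewise linear function on an even abscissa grid.
struct PLFunction {
    std::vector<double> f;   // ordinates
    double xmin;             // left edge of the domain
    double D;                // abscissa step

    void Locate(double x, int& i, double& L) const {
        double t = (x - xmin) / D;
        i = (int)std::floor(t);
        L = t - i;
    }
    double Value(double x) const {
        int i; double L; Locate(x, i, L);
        return f[i] + (f[i + 1] - f[i]) * L;
    }
    double Derivative(double x) const {                  // slope of the current segment
        int i; double L; Locate(x, i, L);
        return (f[i + 1] - f[i]) / D;
    }
    void Update(double x, double delta) {                // split delta between the two nodes
        int i; double L; Locate(x, i, L);
        f[i]     += delta * (1.0 - L);
        f[i + 1] += delta * L;
    }
};

// One online update of the Kolmogorov-Arnold model for a record (x, m).
// inner[q][p] are the f_{q,p}, outer[q] are the Phi_q, alpha is the regularization parameter.
void UpdateKAN(std::vector<std::vector<PLFunction>>& inner,
               std::vector<PLFunction>& outer,
               const std::vector<double>& x, double m, double alpha) {
    int Q = (int)outer.size();                           // Q = 2n + 1 addends
    int n = (int)x.size();
    std::vector<double> zhat(Q, 0.0);
    double mhat = 0.0;
    for (int q = 0; q < Q; ++q) {                        // forward pass
        for (int p = 0; p < n; ++p) zhat[q] += inner[q][p].Value(x[p]);
        mhat += outer[q].Value(zhat[q]);
    }
    double delta = alpha * (m - mhat);                   // scaling by (2n + 1) and n folded into alpha
    for (int q = 0; q < Q; ++q) {
        double dPhi = outer[q].Derivative(zhat[q]);      // or its sign
        for (int p = 0; p < n; ++p) inner[q][p].Update(x[p], delta * dPhi);
        outer[q].Update(zhat[q], delta);                 // outer update on the hidden value
    }
}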

Rescaling of intermediate arguments

The limits of definition of the intermediate arguments $\hat z_q$ may change during the training. This can be easily fixed by moving $x_{min}, x_{max}$ and recalculating $D$. These limits can be chosen freely at the start of the training and updated when necessary. Below is an example of rescaling. It is executed each time the function FitDefinition(x) is called, that is, before processing the argument $x$.
// Widen the domain of definition when a new argument falls outside of it.
void UnivariatePL::FitDefinition(double x) {
    if (x < _xmin) {
        _xmin = x;
        SetLimits();
    }
    if (x > _xmax) {
        _xmax = x;
        SetLimits();
    }
}

// Add a small margin on both sides and recompute the abscissa step.
void UnivariatePL::SetLimits() {
    double range = _xmax - _xmin;
    _xmin -= 0.01 * range;
    _xmax += 0.01 * range;
    _deltax = (_xmax - _xmin) / (_points - 1);
}

Using splines does not improve accuracy

It may look like the piecewise linear model is merely an approximation to a smooth-function model and for that reason is less accurate. That is WRONG. There are multiple coding examples showing that the accuracy is the same. The reason is clear. The target is the sum of $2n + 1$ piecewise linear functions, and each of their arguments is the sum of another $n$ piecewise linear functions with freely chosen segment sizes. The basis functions can be replaced during training with preservation of the already achieved accuracy. There are coding examples on this site where training starts with a piecewise linear model and ends with a spline model.

Computational complexity

A simple way to evaluate computational complexity is to count the number of long operations, such as divisions and multiplications, during training. The other operations, such as allocating memory for arrays and objects and navigating through them, should not be a big burden, because these lists and sizes are tiny in the general case.

We estimate it for the piecewise linear case, since splines can simply be added to the piecewise linear model after training is completed, preserving the already acquired accuracy. Also, not all published samples are optimized for fast execution, and even those that are still have room for further improvement. However, we can draw some conclusions based on KANPro2.

The main time is spent in the following operations of the training loop:
           // Forward pass: accumulate the model output over all addends.
           for (int j = 0; j < nModels; ++j) {
               model += addends[j]->ComputeUsingInput(dataHolder->inputs[i]);
           }
           // Residual for the current record.
           double residual = dataHolder->target[i] - model;
           // Update every addend using the values memorized in the forward pass.
           for (int j = 0; j < nModels; ++j) {
               addends[j]->UpdateUsingMemory(residual * muPL);
           }
The 'addend' object here is one term of the Kolmogorov-Arnold representation $$ \Phi_q\left(\sum_{p=1}^{n} f_{q,p}(x_{p})\right). $$ The function 'ComputeUsingInput' computes the value of each function of the addend, that is, the number of inputs plus one. The function 'UpdateUsingMemory' updates each of these functions, reusing a few already computed parameters. These parts can be found in 'UnivariatePL::GetFunctionUsingInput' and 'UnivariatePL::UpdateUsingMemory'. You can see that they have only 3 long operations altogether, and there is one more multiplication per record, 'residual * muPL', which is the regularization.
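As an illustration only (this is not the actual KANPro2 source), a compute/update pair can be organized with three multiplications per function if the segment index and offset are memorized during the forward pass and the reciprocal of the abscissa step is precomputed:

#include <cmath>
#include <vector>

// Illustrative sketch of a three-multiplication compute/update pair per function.
struct UnivariatePLSketch {
    std::vector<double> f;       // ordinates
    double xmin = 0.0;           // left edge of the domain
    double invD = 1.0;           // precomputed 1 / deltax
    int i = 0;                   // "memory": segment index from the last compute
    double L = 0.0;              // "memory": relative offset from the last compute

    double GetFunctionUsingInput(double x) {
        double t = (x - xmin) * invD;            // multiplication 1
        i = (int)std::floor(t);
        L = t - i;
        return f[i] + (f[i + 1] - f[i]) * L;     // multiplication 2
    }
    void UpdateUsingMemory(double delta) {
        double right = delta * L;                // multiplication 3
        f[i]     += delta - right;               // delta * (1 - L) without an extra multiplication
        f[i + 1] += right;
    }
};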

So in total it is 3 times the number of functions in the model plus 1 long operation per record per epoch. The number of required epochs is in general similar to neural networks, 20 to 100. This is how quick processing is achieved. For example, 5 inputs and 10 000 records may take from 0.1 to 0.4 seconds on a single CPU.
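For a rough, back-of-the-envelope estimate (assuming the canonical $2n + 1$ addends of the representation above): with $n = 5$ inputs there are $11$ addends, each containing $5$ inner functions and $1$ outer function, so one record costs about $$ 3 \cdot 11 \cdot (5 + 1) + 1 = 199 $$ long operations, which is on the order of $2 \cdot 10^8$ operations for 10 000 records and 100 epochs.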

If we need to compare the training time to an MLP, we need another example with more records and inputs, because comparing fractions of a second does not give a reliable result: other operations, such as loading the program, take longer.

So, the quick big-picture evaluation is 3 multiplications per function of the model in one model-update step, for each record of the training set, per epoch.

Proof of the claimed efficiency

All statements of this essay can be verified by downloading and testing the code provided on this site. The performance and accuracy are the same as or better than those of an MLP. This concept has been published in highly rated journals and has been publicly available since 2021.

References

Video lecture
Code download