Vector Targets
A 'vector target' is a set of several scalar outputs that are interdependent because they are computed from the same inputs and share the internal structure of the modelled system.
As an example, we use the areas of the four faces of random tetrahedra. Each tetrahedron is defined by the coordinates
of its four vertices, drawn independently from the uniform distribution on [0, 100].
The training dataset has 500,000 records and the validation dataset has 50,000. Each record has 12 features (4 vertices with 3 coordinates each) and 4 targets (the face areas).
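For readers who prefer C++, the mapping from the 12 features to the 4 targets can be sketched as below. This is only an illustration of the data definition, not the code behind the link: the names `Vec3`, `triangle_area`, `face_areas` and `random_record` are hypothetical, and the vertex ordering is assumed to match the MATLAB script (three consecutive features per vertex).

```cpp
#include <array>
#include <cmath>
#include <random>

using Vec3 = std::array<double, 3>;

// Area of triangle ABC: half the norm of the cross product (B - A) x (C - A).
double triangle_area(const Vec3& a, const Vec3& b, const Vec3& c) {
    const Vec3 u{b[0] - a[0], b[1] - a[1], b[2] - a[2]};
    const Vec3 v{c[0] - a[0], c[1] - a[1], c[2] - a[2]};
    const double cx = u[1] * v[2] - u[2] * v[1];
    const double cy = u[2] * v[0] - u[0] * v[2];
    const double cz = u[0] * v[1] - u[1] * v[0];
    return 0.5 * std::sqrt(cx * cx + cy * cy + cz * cz);
}

// Maps 12 features (4 vertices x 3 coordinates) to the 4 face areas.
// Assumption: features 3*i .. 3*i+2 are the x, y, z of vertex i.
std::array<double, 4> face_areas(const std::array<double, 12>& x) {
    std::array<Vec3, 4> v{};
    for (int i = 0; i < 4; ++i) {
        v[i] = Vec3{x[3 * i], x[3 * i + 1], x[3 * i + 2]};
    }
    // The four vertex triples that form the faces (zero-based indices).
    static const int faces[4][3] = {{0, 1, 2}, {0, 1, 3}, {0, 2, 3}, {1, 2, 3}};
    std::array<double, 4> areas{};
    for (int j = 0; j < 4; ++j) {
        areas[j] = triangle_area(v[faces[j][0]], v[faces[j][1]], v[faces[j][2]]);
    }
    return areas;
}

// Draws one record: 12 coordinates uniformly distributed on [0, 100).
std::array<double, 12> random_record(std::mt19937& gen) {
    std::uniform_real_distribution<double> dist(0.0, 100.0);
    std::array<double, 12> x{};
    for (double& xi : x) {
        xi = dist(gen);
    }
    return x;
}
```

Calling `face_areas(random_record(gen))` with a seeded `std::mt19937 gen` yields one feature/target pair. All four targets depend on the same 12 inputs, and any two faces share two vertices, which is the interdependence referred to above.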
The baseline MATLAB script is below:
function tetrahedron_nn()
    % Parameters
    num_train = 500000;
    num_val = 50000;
    input_dim = 12; % 4 vertices * 3 coordinates
    output_dim = 4; % 4 face areas

    % Generate data
    [X_train, Y_train] = generate_data(num_train);
    [X_val, Y_val] = generate_data(num_val);

    % Define neural network
    layers = [
        featureInputLayer(input_dim)
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(64)
        reluLayer
        fullyConnectedLayer(output_dim)
        regressionLayer
    ];

    % Training options
    options = trainingOptions('adam', ...
        'MaxEpochs', 50, ...
        'MiniBatchSize', 1024, ...
        'Shuffle', 'every-epoch', ...
        'ValidationData', {X_val, Y_val}, ...
        'ValidationFrequency', 50, ...
        'Plots', 'training-progress', ...
        'Verbose', true);

    % Train the network
    net = trainNetwork(X_train, Y_train, layers, options);

    % Predict and evaluate
    Y_pred = predict(net, X_val);

    % Compute Pearson correlations
    for i = 1:output_dim
        corr_val = corr(Y_pred(:, i), Y_val(:, i));
        fprintf('Face %d Pearson correlation: %.4f\n', i, corr_val);
    end
end

function [X, Y] = generate_data(n)
    X = 100 * rand(n, 12); % Random tetrahedron vertices in [0, 100]
    Y = zeros(n, 4);       % Areas of four faces
    faces = [1 2 3; 1 2 4; 1 3 4; 2 3 4]; % Vertex indices of each face
    for i = 1:n
        V = reshape(X(i, :), [3, 4])'; % 4 vertices (3D points), one per row
        for j = 1:4
            A = V(faces(j, 1), :);
            B = V(faces(j, 2), :);
            C = V(faces(j, 3), :);
            Y(i, j) = 0.5 * norm(cross(B - A, C - A)); % Triangle area
        end
    end
end
Training took 16 min 52 sec. Accuracy was measured by the Pearson correlation coefficient between predicted and actual
areas on the validation data, computed separately for each face:
Face 1 Pearson correlation: 0.9768
Face 2 Pearson correlation: 0.9775
Face 3 Pearson correlation: 0.9777
Face 4 Pearson correlation: 0.9777
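The per-face metric is the standard Pearson correlation coefficient between the predicted and the actual column. A minimal C++ sketch of that formula is given below for reference. The function name `pearson` is hypothetical, and this is not necessarily how the linked C++ program computes it.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Pearson correlation coefficient between two equally long samples.
// Returns NaN if the sizes differ or the samples are empty; a constant
// sample also gives NaN because its variance is zero.
double pearson(const std::vector<double>& x, const std::vector<double>& y) {
    const std::size_t n = x.size();
    if (n == 0 || y.size() != n) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    // Sample means.
    double mean_x = 0.0, mean_y = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
        mean_x += x[i];
        mean_y += y[i];
    }
    mean_x /= static_cast<double>(n);
    mean_y /= static_cast<double>(n);
    // Centered sums of products and squares.
    double sxy = 0.0, sxx = 0.0, syy = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
        const double dx = x[i] - mean_x;
        const double dy = y[i] - mean_y;
        sxy += dx * dy;
        sxx += dx * dx;
        syy += dy * dy;
    }
    return sxy / std::sqrt(sxx * syy);
}
```

For a vector target, the function would be called once per output column, passing the predicted and the actual areas of one face.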
The KAN C++ counter test is available via the link at the top. Its printout is shown below:
Epoch 49, current relative error 0.023376
Time for training 349.875 sec.
Relative RMSE error for unseen data 0.02386
Pearsons for individual targets: 0.9833 0.9847 0.9840 0.9845
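The printout does not say how the relative RMSE is normalized. The sketch below assumes one common convention, the root mean square error over all target values divided by the range (maximum minus minimum) of the actual targets. That convention, and the name `relative_rmse`, are assumptions made for illustration, and the linked program may define the quantity differently.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Relative RMSE under an ASSUMED normalization: RMSE / (max - min) of the
// actual targets. Both vectors hold all target values flattened into one
// sequence. Returns NaN for empty or mismatched input or a zero range.
double relative_rmse(const std::vector<double>& predicted,
                     const std::vector<double>& actual) {
    const std::size_t n = actual.size();
    if (n == 0 || predicted.size() != n) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    double sum_sq = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
        const double e = predicted[i] - actual[i];
        sum_sq += e * e;
    }
    const double rmse = std::sqrt(sum_sq / static_cast<double>(n));
    const auto mm = std::minmax_element(actual.begin(), actual.end());
    const double range = *mm.second - *mm.first;
    if (range == 0.0) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    return rmse / range;
}
```

Under this convention a value of about 0.024 would mean that the typical error is about 2.4% of the spread of the face areas.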
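KAN is both quicker and more accurate: training took about 350 sec instead of 16 min 52 sec (1,012 sec), roughly 2.9 times faster, and the Pearson correlations are 0.983 to 0.985 instead of 0.977 to 0.978. Accuracy and performance vary with the configuration of both
the script and the C++ code, but in all cases KAN is several times quicker and more accurate.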