Merge remote-tracking branch 'origin/Development' into Development

This commit is contained in:
Raphael Walcher 2023-09-22 09:54:17 +02:00
commit 2cfa7f708e
2 changed files with 20 additions and 8 deletions

View file

@ -241,13 +241,27 @@ void train_network(Neural_Network* network, Image *image, int label) {
Matrix* weights_delta = scale(temp6, network->learning_rate);
Matrix* bias_delta = scale(sigma1, network->learning_rate);
// Matrix* temp7 = add(weights_delta, network->weights_output);
// matrix_free(network->weights_output);
// network->weights_output = temp7;
//
// Matrix* temp8 = add(bias_delta, network->bias_output);
// matrix_free(network->bias_output);
// network->bias_output = temp8;
Matrix* temp7 = add(weights_delta, network->weights_output);
matrix_free(network->weights_output);
network->weights_output = temp7;
for (int i = 0; i < network->weights_output->rows; ++i) {
for (int j = 0; j < network->weights_output->columns; ++j) {
network->weights_output->numbers[i][j] = temp7->numbers[i][j];
}
}
Matrix* temp8 = add(bias_delta, network->bias_output);
matrix_free(network->bias_output);
network->bias_output = temp8;
for (int i = 0; i < network->bias_output->rows; ++i) {
for (int j = 0; j < network->bias_output->columns; ++j) {
network->bias_output->numbers[i][j] = temp8->numbers[i][j];
}
}
// other levels
Matrix* sigma2 = backPropagation(network->learning_rate, network->weights_3, network->bias_3, h3_outputs, h2_outputs, sigma1);