This commit is contained in:
Ghost_Element 2023-09-21 17:43:28 +02:00
parent b2e59c9ad7
commit 1c17a2e6c4

View file

@ -188,6 +188,88 @@ void train_network(Neural_Network* network, Image *image, int label) {
Matrix* hidden3_outputs = apply(sigmoid, add(dot(network->weights_3, hidden2_outputs), network->bias_3));
Matrix* final_outputs = apply(sigmoid, add(dot(network->weights_output, hidden3_outputs), network->bias_output));
//begin backpropagation:
Matrix* sigma = matrix_create(final_outputs->rows, 1);
matrix_fill(sigma, 1);
Matrix* temp_1 = subtract(sigma, final_outputs);
Matrix* temp_2 = multiply(temp_1, final_outputs); // * soll-ist
Matrix* temp_3 = matrix_create(final_outputs->rows, final_outputs->columns);
matrix_fill(temp_3,0);
temp_3->numbers[label][0] = 1;
Matrix* temp_4 = subtract(temp_3, final_outputs);
sigma = multiply(temp_2, temp_4);
matrix_free(temp_3);
matrix_free(temp_4);
//sigma done
Matrix* temp1 = transpose(hidden3_outputs);
Matrix* temp2 = dot(sigma, temp1);
Matrix* weights_delta = scale(temp2, network->learning_rate);
Matrix* bias_delta = scale(sigma, network->learning_rate);
Matrix* temp = add(weights_delta, network->weights_output);
matrix_free(network->weights_output);
network->weights_output = temp;
temp = add(bias_delta, network->bias_output);
matrix_free(network->bias_output);
network->bias_output = temp;
matrix_free(weights_delta);
matrix_free(bias_delta);
// other levels
Matrix* sigma_current = matrix_create(hidden3_outputs->rows, 1);
matrix_fill(sigma_current, 1);
temp_1 = subtract(sigma_current, hidden3_outputs);
temp_2 = multiply(temp_1, hidden3_outputs); // *sum(delta*weights)
for(int j=0;j<hidden3_outputs->rows;j++) {
double sum = 0;
for (int i = 0; i < sigma->rows; i++) {
sum += hidden3_outputs->numbers[j][i]*sigma->numbers[i][0];
}
temp_1->numbers[j][0]=sum;
}
sigma_current = multiply(temp_2, temp_1);
// sigma done
temp1 = transpose(hidden2_outputs);
temp2 = dot(sigma_current, temp1);
weights_delta = scale(temp2, network->learning_rate);
bias_delta = scale(sigma_current, network->learning_rate);
temp = add(weights_delta, network->weights_3);
matrix_free(network->weights_3);
network->weights_3 = temp;
temp = add(bias_delta, network->bias_3);
matrix_free(network->bias_3);
network->bias_3 = temp;
matrix_free(weights_delta);
matrix_free(bias_delta);
matrix_free(temp_1);
matrix_free(temp_2);
matrix_free(temp1);
matrix_free(temp2);
//matrix_print(sigma);
}
//void batch_train_network(Neural_Network* network, Image** images, int size);