small changes
This commit is contained in:
parent
9ca185c30f
commit
2e7b572069
1 changed files with 61 additions and 41 deletions
|
|
@ -205,65 +205,85 @@ double cost_function(Matrix* calculated, int expected){
|
|||
|
||||
void train_network(Neural_Network* network, Image *image, int label) {
|
||||
|
||||
// Flatten the image into matrix
|
||||
Matrix* input = matrix_flatten(image->pixel_values, 0);
|
||||
|
||||
Matrix* hidden1_outputs = apply(sigmoid, add(dot(network->weights_1, input), network->bias_1));
|
||||
Matrix* hidden2_outputs = apply(sigmoid, add(dot(network->weights_2, hidden1_outputs), network->bias_2));
|
||||
Matrix* hidden3_outputs = apply(sigmoid, add(dot(network->weights_3, hidden2_outputs), network->bias_3));
|
||||
Matrix* final_outputs = apply(sigmoid, add(dot(network->weights_output, hidden3_outputs), network->bias_output));
|
||||
// Perform forward propagation
|
||||
Matrix* h1_dot = dot(network->weights_1, input);
|
||||
Matrix* h1_add = add(h1_dot, network->bias_1);
|
||||
Matrix* h1_outputs = apply(sigmoid, h1_add);
|
||||
|
||||
Matrix* h2_dot = dot(network->weights_2, h1_outputs);
|
||||
Matrix* h2_add = add(h2_dot, network->bias_2);
|
||||
Matrix* h2_outputs = apply(sigmoid, h2_add);
|
||||
|
||||
//begin backpropagation:
|
||||
Matrix* h3_dot = dot(network->weights_3, h2_outputs);
|
||||
Matrix* h3_add = add(h3_dot, network->bias_3);
|
||||
Matrix* h3_outputs = apply(sigmoid, h3_add);
|
||||
|
||||
Matrix* final_dot = dot(network->weights_output, h3_outputs);
|
||||
Matrix* final_add = add(final_dot, network->bias_output);
|
||||
Matrix* final_outputs = apply(sigmoid, final_add);
|
||||
|
||||
// begin backpropagation
|
||||
Matrix* sigma = matrix_create(final_outputs->rows, 1);
|
||||
matrix_fill(sigma, 1);
|
||||
Matrix* temp_1 = subtract(sigma, final_outputs);
|
||||
Matrix* temp_2 = multiply(temp_1, final_outputs); // * soll-ist
|
||||
Matrix* temp_3 = matrix_create(final_outputs->rows, final_outputs->columns);
|
||||
matrix_fill(temp_3,0);
|
||||
temp_3->numbers[label][0] = 1;
|
||||
Matrix* temp_4 = subtract(temp_3, final_outputs);
|
||||
sigma = multiply(temp_2, temp_4);
|
||||
Matrix* temp1 = subtract(sigma, final_outputs);
|
||||
Matrix* temp2 = multiply(temp1, final_outputs); // * soll-ist
|
||||
Matrix* temp3 = matrix_create(final_outputs->rows, final_outputs->columns);
|
||||
matrix_fill(temp3, 0);
|
||||
temp3->numbers[label][0] = 1;
|
||||
Matrix* temp4 = subtract(temp3, final_outputs);
|
||||
sigma = multiply(temp2, temp4);
|
||||
|
||||
|
||||
matrix_free(temp_3);
|
||||
matrix_free(temp_4);
|
||||
//sigma done
|
||||
|
||||
Matrix* temp1 = transpose(hidden3_outputs);
|
||||
Matrix* temp2 = dot(sigma, temp1);
|
||||
Matrix* weights_delta = scale(temp2, network->learning_rate);
|
||||
Matrix* temp5 = transpose(h3_outputs);
|
||||
Matrix* temp6 = dot(sigma, temp5);
|
||||
Matrix* weights_delta = scale(temp6, network->learning_rate);
|
||||
Matrix* bias_delta = scale(sigma, network->learning_rate);
|
||||
|
||||
|
||||
|
||||
Matrix* temp = add(weights_delta, network->weights_output);
|
||||
Matrix* temp7 = add(weights_delta, network->weights_output);
|
||||
matrix_free(network->weights_output);
|
||||
network->weights_output = temp;
|
||||
temp = add(bias_delta, network->bias_output);
|
||||
network->weights_output = temp7;
|
||||
|
||||
Matrix* temp8 = add(bias_delta, network->bias_output);
|
||||
matrix_free(network->bias_output);
|
||||
network->bias_output = temp;
|
||||
network->bias_output = temp8;
|
||||
|
||||
// other levels
|
||||
backPropagation(network->learning_rate, network->weights_3, network->bias_3, h3_outputs, h2_outputs, sigma);
|
||||
backPropagation(network->learning_rate, network->weights_2, network->bias_2, h2_outputs, h1_outputs, sigma);
|
||||
backPropagation(network->learning_rate, network->weights_1, network->bias_1, h1_outputs, input, sigma);
|
||||
|
||||
matrix_free(input);
|
||||
|
||||
matrix_free(h1_dot);
|
||||
matrix_free(h1_add);
|
||||
matrix_free(h1_outputs);
|
||||
|
||||
matrix_free(h2_dot);
|
||||
matrix_free(h2_add);
|
||||
matrix_free(h2_outputs);
|
||||
|
||||
matrix_free(h3_dot);
|
||||
matrix_free(h3_add);
|
||||
matrix_free(h3_outputs);
|
||||
|
||||
matrix_free(final_dot);
|
||||
matrix_free(final_add);
|
||||
matrix_free(final_outputs);
|
||||
|
||||
matrix_free(weights_delta);
|
||||
matrix_free(bias_delta);
|
||||
|
||||
// other levels
|
||||
|
||||
backPropagation(network->learning_rate, network->weights_3, network->bias_3, hidden3_outputs, hidden2_outputs, sigma);
|
||||
backPropagation(network->learning_rate, network->weights_2, network->bias_2, hidden2_outputs, hidden1_outputs, sigma);
|
||||
backPropagation(network->learning_rate, network->weights_1, network->bias_1, hidden1_outputs, input, sigma);
|
||||
|
||||
|
||||
matrix_free(temp);
|
||||
matrix_free(sigma);
|
||||
matrix_free(temp_1);
|
||||
matrix_free(temp_2);
|
||||
matrix_free(temp1);
|
||||
matrix_free(temp2);
|
||||
matrix_free(input);
|
||||
matrix_free(hidden1_outputs);
|
||||
matrix_free(hidden2_outputs);
|
||||
matrix_free(hidden3_outputs);
|
||||
matrix_free(final_outputs);
|
||||
matrix_free(temp3);
|
||||
matrix_free(temp4);
|
||||
matrix_free(temp5);
|
||||
matrix_free(temp6);
|
||||
matrix_free(temp7);
|
||||
matrix_free(temp8);
|
||||
}
|
||||
|
||||
void backPropagation(double learning_rate, Matrix* weights, Matrix* biases, Matrix* current_layer_activation, Matrix* previous_layer_activation, Matrix* sigma_old) {
|
||||
|
|
|
|||
Reference in a new issue