fixed pointer stuff

This commit is contained in:
Thomas 2023-09-22 09:46:41 +02:00
parent 86ac3e855c
commit 01c599603f
2 changed files with 20 additions and 8 deletions

6
main.c
View file

@ -8,7 +8,7 @@ int main() {
Image** images = import_images("../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte", NULL, 60000); Image** images = import_images("../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte", NULL, 60000);
// img_visualize(images[4]); // img_visualize(images[4]);
Neural_Network* nn = new_network(28*28, 16, 10, 0.5); Neural_Network* nn = new_network(28*28, 8, 10, 0.25);
randomize_network(nn, 20); randomize_network(nn, 20);
// save_network(nn); // save_network(nn);
@ -19,8 +19,6 @@ int main() {
train_network(nn, images[i], images[i]->label); train_network(nn, images[i], images[i]->label);
} }
measure_network_accuracy(nn, images, 100); printf("%lf\n", measure_network_accuracy(nn, images, 100));
} }

View file

@ -241,13 +241,27 @@ void train_network(Neural_Network* network, Image *image, int label) {
Matrix* weights_delta = scale(temp6, network->learning_rate); Matrix* weights_delta = scale(temp6, network->learning_rate);
Matrix* bias_delta = scale(sigma1, network->learning_rate); Matrix* bias_delta = scale(sigma1, network->learning_rate);
// Matrix* temp7 = add(weights_delta, network->weights_output);
// matrix_free(network->weights_output);
// network->weights_output = temp7;
//
// Matrix* temp8 = add(bias_delta, network->bias_output);
// matrix_free(network->bias_output);
// network->bias_output = temp8;
Matrix* temp7 = add(weights_delta, network->weights_output); Matrix* temp7 = add(weights_delta, network->weights_output);
matrix_free(network->weights_output); for (int i = 0; i < network->weights_output->rows; ++i) {
network->weights_output = temp7; for (int j = 0; j < network->weights_output->columns; ++j) {
network->weights_output->numbers[i][j] = temp7->numbers[i][j];
}
}
Matrix* temp8 = add(bias_delta, network->bias_output); Matrix* temp8 = add(bias_delta, network->bias_output);
matrix_free(network->bias_output); for (int i = 0; i < network->bias_output->rows; ++i) {
network->bias_output = temp8; for (int j = 0; j < network->bias_output->columns; ++j) {
network->bias_output->numbers[i][j] = temp8->numbers[i][j];
}
}
// other levels // other levels
Matrix* sigma2 = backPropagation(network->learning_rate, network->weights_3, network->bias_3, h3_outputs, h2_outputs, sigma1); Matrix* sigma2 = backPropagation(network->learning_rate, network->weights_3, network->bias_3, h3_outputs, h2_outputs, sigma1);