fixed pointer stuff
This commit is contained in:
parent
86ac3e855c
commit
01c599603f
2 changed files with 20 additions and 8 deletions
6
main.c
6
main.c
|
|
@ -8,7 +8,7 @@ int main() {
|
||||||
Image** images = import_images("../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte", NULL, 60000);
|
Image** images = import_images("../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte", NULL, 60000);
|
||||||
// img_visualize(images[4]);
|
// img_visualize(images[4]);
|
||||||
|
|
||||||
Neural_Network* nn = new_network(28*28, 16, 10, 0.5);
|
Neural_Network* nn = new_network(28*28, 8, 10, 0.25);
|
||||||
randomize_network(nn, 20);
|
randomize_network(nn, 20);
|
||||||
// save_network(nn);
|
// save_network(nn);
|
||||||
|
|
||||||
|
|
@ -19,8 +19,6 @@ int main() {
|
||||||
train_network(nn, images[i], images[i]->label);
|
train_network(nn, images[i], images[i]->label);
|
||||||
}
|
}
|
||||||
|
|
||||||
measure_network_accuracy(nn, images, 100);
|
printf("%lf\n", measure_network_accuracy(nn, images, 100));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -241,13 +241,27 @@ void train_network(Neural_Network* network, Image *image, int label) {
|
||||||
Matrix* weights_delta = scale(temp6, network->learning_rate);
|
Matrix* weights_delta = scale(temp6, network->learning_rate);
|
||||||
Matrix* bias_delta = scale(sigma1, network->learning_rate);
|
Matrix* bias_delta = scale(sigma1, network->learning_rate);
|
||||||
|
|
||||||
|
// Matrix* temp7 = add(weights_delta, network->weights_output);
|
||||||
|
// matrix_free(network->weights_output);
|
||||||
|
// network->weights_output = temp7;
|
||||||
|
//
|
||||||
|
// Matrix* temp8 = add(bias_delta, network->bias_output);
|
||||||
|
// matrix_free(network->bias_output);
|
||||||
|
// network->bias_output = temp8;
|
||||||
|
|
||||||
Matrix* temp7 = add(weights_delta, network->weights_output);
|
Matrix* temp7 = add(weights_delta, network->weights_output);
|
||||||
matrix_free(network->weights_output);
|
for (int i = 0; i < network->weights_output->rows; ++i) {
|
||||||
network->weights_output = temp7;
|
for (int j = 0; j < network->weights_output->columns; ++j) {
|
||||||
|
network->weights_output->numbers[i][j] = temp7->numbers[i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Matrix* temp8 = add(bias_delta, network->bias_output);
|
Matrix* temp8 = add(bias_delta, network->bias_output);
|
||||||
matrix_free(network->bias_output);
|
for (int i = 0; i < network->bias_output->rows; ++i) {
|
||||||
network->bias_output = temp8;
|
for (int j = 0; j < network->bias_output->columns; ++j) {
|
||||||
|
network->bias_output->numbers[i][j] = temp8->numbers[i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// other levels
|
// other levels
|
||||||
Matrix* sigma2 = backPropagation(network->learning_rate, network->weights_3, network->bias_3, h3_outputs, h2_outputs, sigma1);
|
Matrix* sigma2 = backPropagation(network->learning_rate, network->weights_3, network->bias_3, h3_outputs, h2_outputs, sigma1);
|
||||||
|
|
|
||||||
Reference in a new issue