From d9b6e342c007781f0161aa2e4b23df463764581f Mon Sep 17 00:00:00 2001 From: Thomas Schleicher Date: Thu, 21 Sep 2023 09:56:10 +0200 Subject: [PATCH] load network --- .gitignore | 2 ++ main.c | 20 ++++---------------- matrix.c | 17 ++++++++--------- matrix.h | 2 +- neuronal_network.c | 10 +++++++++- 5 files changed, 24 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index 01adc20..a412b2a 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,5 @@ Mkfile.old dkms.conf /.idea/.name /.idea/misc.xml +/.idea/shelf/Uncommitted_changes_before_Update_at_21_09_23,_09_38_[Changes]/shelved.patch +/.idea/shelf/Uncommitted_changes_before_Update_at_21_09_23,_09_38_[Changes]/shelved.patch diff --git a/main.c b/main.c index 0407567..8ddb221 100644 --- a/main.c +++ b/main.c @@ -5,25 +5,13 @@ #include "neuronal_network.h" int main() { -// Image** images = import_images("../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte", NULL, 2); -// img_visualize(images[1]); +// Image** images = import_images("../data/train-images.idx3-ubyte", "../data/train-labels.idx1-ubyte", NULL, 60000); +// img_visualize(images[4]); // Neural_Network* nn = new_network(4, 2, 3, 0.5); -// -// int n = 20; -// -// matrix_randomize(nn->bias_1, n); -// matrix_randomize(nn->bias_2, n); -// matrix_randomize(nn->bias_3, n); -// -// matrix_randomize(nn->weights_1, n); -// matrix_randomize(nn->weights_2, n); -// matrix_randomize(nn->weights_3, n); -// -// matrix_randomize(nn->weights_output, n); -// +// randomize_network(nn, 20); // save_network(nn); - Neural_Network* nn = load_network("../networks/test1.txt"); +// Neural_Network* nn = load_network("../networks/test1.txt"); } \ No newline at end of file diff --git a/matrix.c b/matrix.c index be1a530..9cbd386 100644 --- a/matrix.c +++ b/matrix.c @@ -246,7 +246,6 @@ Matrix* transpose(Matrix* matrix) { } -//file operations void matrix_save(Matrix* matrix, char* file_string){ // open the file in append mode @@ -274,38 +273,40 @@ void matrix_save(Matrix* matrix, char* file_string){ } Matrix* matrix_load(char* file_string){ + FILE *fptr = fopen(file_string, "r"); + if(!fptr){ printf("Could not open \"%s\"", file_string); exit(1); } + Matrix * m = load_next_matrix(fptr); + fclose(fptr); return m; - } -Matrix * load_next_matrix(FILE *fptr){ +Matrix* load_next_matrix(FILE *save_file){ char buffer[MAX_BYTES]; - fgets(buffer, MAX_BYTES, fptr); + fgets(buffer, MAX_BYTES, save_file); int rows = (int)strtol(buffer, NULL, 10); - fgets(buffer, MAX_BYTES, fptr); + fgets(buffer, MAX_BYTES, save_file); int cols = (int)strtol(buffer, NULL, 10); Matrix *matrix = matrix_create(rows, cols); for(int i = 0; i < rows; i++){ for(int j = 0; j < cols; j++){ - fgets(buffer, MAX_BYTES, fptr); + fgets(buffer, MAX_BYTES, save_file); matrix->numbers[i][j] = strtod(buffer, NULL); } } return matrix; } - Matrix* matrix_flatten(Matrix* matrix, int axis) { // Axis = 0 -> Column Vector, Axis = 1 -> Row Vector Matrix* result_matrix; @@ -329,8 +330,6 @@ Matrix* matrix_flatten(Matrix* matrix, int axis) { return result_matrix; } - - int matrix_argmax(Matrix* matrix) { // Expects a Mx1 matrix if (matrix->columns != 1){ diff --git a/matrix.h b/matrix.h index cc80043..429d5bf 100644 --- a/matrix.h +++ b/matrix.h @@ -16,7 +16,7 @@ void matrix_print(Matrix *matrix); Matrix* matrix_copy(Matrix *matrix); void matrix_save(Matrix* matrix, char* file_string); Matrix* matrix_load(char* file_string); -Matrix* load_next_matrix(FILE * fptr); +Matrix* load_next_matrix(FILE * save_file); void matrix_randomize(Matrix* matrix, int n); // don't understand the usage of the n int matrix_argmax(Matrix* matrix); diff --git a/neuronal_network.c b/neuronal_network.c index 4682233..f64f66b 100644 --- a/neuronal_network.c +++ b/neuronal_network.c @@ -112,8 +112,16 @@ Neural_Network* load_network(char* file) { // create a new network to fill with the saved data Neural_Network* saved_network = new_network(input_size, hidden_size, output_size, 0); + // load matrices from file into struct + saved_network->bias_1 = load_next_matrix(save_file); + saved_network->weights_1 = load_next_matrix(save_file); + saved_network->bias_2 = load_next_matrix(save_file); + saved_network->weights_2 = load_next_matrix(save_file); + saved_network->bias_3 = load_next_matrix(save_file); + saved_network->weights_3 = load_next_matrix(save_file); + saved_network->weights_output = load_next_matrix(save_file); - + // return saved network return saved_network; }