diff --git a/cnn.c b/cnn.c index 4a3f872..94de085 100644 --- a/cnn.c +++ b/cnn.c @@ -26,43 +26,37 @@ float* expand(const float* old_matrix, u8 old_matrix_length, u8 layer){ //返回卷积的结果 float* convolution(Model model_w, Model model_b, const float* input_matrix, u8 input_matrix_length){ // 初始化卷积层参数 - u8 c_rl = input_matrix_length - 2; + u8 im_l = input_matrix_length; + u8 cr_l = input_matrix_length - 2; float conv_temp; // 临时变量,用于存储卷积计算的中间结果 - float* conv_rlst = (float *) malloc(sizeof (float) * model_w.num_kernels * model_w.layer * (c_rl*c_rl)); - memset(conv_rlst, 0, sizeof (float) * model_w.num_kernels * model_w.layer * (c_rl*c_rl)); + float* conv_rlst = (float *) malloc(sizeof (float) * model_w.num_kernels * (cr_l * cr_l)); + memset(conv_rlst, 0, sizeof (float) * model_w.num_kernels * (cr_l * cr_l)); // 遍历30个卷积核(假设有30个通道) + + for(u8 l=0;l 0) - conv_rlst[row * (input_matrix_length - 2) + col + - n * (input_matrix_length - 2) * (input_matrix_length - 2) + - 0 - ] = conv_temp; // 如果卷积结果大于0,存入结果数组 - else - conv_rlst[row * (input_matrix_length - 2) + col + - n * (input_matrix_length - 2) * (input_matrix_length - 2) + - 0 - ] = 0; // 否则存入0 } + // 加上对应卷积核的偏置 + conv_temp += model_b.array[n]; + // 激活函数:ReLU(将小于0的值设为0) + if (conv_temp > 0) + conv_rlst[(n*(cr_l*cr_l)) + (row*cr_l) + (col)] = conv_temp; // 如果卷积结果大于0,存入结果数组 + else + conv_rlst[(n*(cr_l*cr_l)) + (row*cr_l) + (col)] = 0; // 否则存入0 } - } + } } + } return conv_rlst; } @@ -73,33 +67,33 @@ float* convolution(Model model_w, Model model_b, const float* input_matrix, u8 i //input_matrix_length 输入图像的边长:100 //输出图像的边长:50 //返回池化的结果 -float* pooling(u8 num_kernels, u8 area, const float* input_matrix, u8 input_matrix_length){ - +float* pooling(Model model_w, const float* input_matrix, u8 input_matrix_length){ + u8 im_l = input_matrix_length; float pool_temp = 0; // 临时变量,用于存储池化操作的最大值 - float* pool_rslt = (float *) malloc(sizeof (float)*num_kernels*input_matrix_length*input_matrix_length); - memset(pool_rslt, 0, sizeof (float)*num_kernels*input_matrix_length*input_matrix_length); + float* pool_rslt = (float *) malloc(sizeof (float)*model_w.num_kernels*im_l*im_l); + memset(pool_rslt, 0, sizeof (float)*model_w.num_kernels*im_l*im_l); // 遍历30个通道(与卷积核数量相同) - for(u8 n=0; n