From 643eca68a36aadb3cb246e129c193407478e5b15 Mon Sep 17 00:00:00 2001 From: Qiea <1310371422@qq.com> Date: Sun, 10 Nov 2024 17:25:51 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=B1=A0=E5=8C=96=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cnn.c | 118 ++++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 73 insertions(+), 45 deletions(-) diff --git a/cnn.c b/cnn.c index 53d774d..5f9b4f4 100644 --- a/cnn.c +++ b/cnn.c @@ -14,49 +14,97 @@ float* expand(const float* old_matrix, u8 old_matrix_num){ return new_matrix; } -//卷积核的个数:32 -//卷积的面积:3*3 -//输入图像 -//输入图像的边长:102 +//num_kernels 卷积核的个数:32 +//area 卷积的面积:3*3 +//input_matrix 输入图像 +//input_matrix_length 输入图像的边长:102 //输出图像的边长:100 //返回卷积的结果 float* convolution(u8 num_kernels, u8 area, const float* input_matrix, u8 input_matrix_length){ // 初始化卷积层参数 float conv_temp; // 临时变量,用于存储卷积计算的中间结果 - float* conv_rlst = (float *) malloc(sizeof (float)*32*100*100); - memset(conv_rlst, 0, sizeof (float)*32*100*100); + float* conv_rlst = (float *) malloc(sizeof (float)*num_kernels*(input_matrix_length-2)*(input_matrix_length-2)); + memset(conv_rlst, 0, sizeof (float)*num_kernels*(input_matrix_length-2)*(input_matrix_length-2)); // 遍历30个卷积核(假设有30个通道) - for(int n=0; n<32; n++) + for(u8 n=0; n 0) - conv_rlst[row*100+col+n*100*100] = conv_temp; // 如果卷积结果大于0,存入结果数组 + conv_rlst[row*(input_matrix_length-2)+col+n*(input_matrix_length-2)*(input_matrix_length-2)] = conv_temp; // 如果卷积结果大于0,存入结果数组 else - conv_rlst[row*100+col+n*100*100] = 0; // 否则存入0 + conv_rlst[row*(input_matrix_length-2)+col+n*(input_matrix_length-2)*(input_matrix_length-2)] = 0; // 否则存入0 } } } return conv_rlst; } + +//num_kernels 卷积核的个数:32 +//area 池化的面积:2*2 +//input_matrix 输入图像 +//input_matrix_length 输入图像的边长:100 +//输出图像的边长:50 +//返回池化的结果 +float* pooling(u8 num_kernels, u8 area, const float* input_matrix, u8 input_matrix_length){ + + float pool_temp = 0; // 临时变量,用于存储池化操作的最大值 + float* pool_rslt = (float *) malloc(sizeof (float)*num_kernels*input_matrix_length*input_matrix_length); + memset(pool_rslt, 0, sizeof (float)*num_kernels*input_matrix_length*input_matrix_length); + // 遍历30个通道(与卷积核数量相同) + for(u8 n=0; n