From c57a8cb7863f448cb88b7e4dcd7f5cf86d425d05 Mon Sep 17 00:00:00 2001 From: Qiea <1310371422@qq.com> Date: Wed, 13 Nov 2024 10:06:35 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 根据model_w.num_kernels,设置卷积核个数 --- cnn.c | 33 +++++++++++++++++---------------- cnn_model.c | 1 + cnn_model.h | 2 ++ main.c | 6 +++--- 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/cnn.c b/cnn.c index ae7ceda..1833785 100644 --- a/cnn.c +++ b/cnn.c @@ -165,13 +165,14 @@ float* hidden(const float* input_matrix){ return affine1_rslt; } -float* output(const float* input_matrix){ +float* output(Model model_w, const float* input_matrix){ + u8 num = model_w.num_kernels; float affine2_temp; // 临时变量,用于存储输出层的中间结果 - float *affine2_rslt = (float *) malloc(sizeof(float)*7); - memset(affine2_rslt, 0, sizeof(float)*7); + float *affine2_rslt = (float *) malloc(sizeof(float)*num); + memset(affine2_rslt, 0, sizeof(float)*num); // 遍历10个输出神经元(假设有10个类别) - for(int n=0; n<7; n++) + for(int n=0; n