添加功能

根据model_w.num_kernels,设置卷积核个数
This commit is contained in:
Qiea
2024-11-13 10:06:35 +08:00
parent 9db7b625fe
commit c57a8cb786
4 changed files with 23 additions and 19 deletions

33
cnn.c
View File

@@ -165,13 +165,14 @@ float* hidden(const float* input_matrix){
return affine1_rslt; return affine1_rslt;
} }
float* output(const float* input_matrix){ float* output(Model model_w, const float* input_matrix){
u8 num = model_w.num_kernels;
float affine2_temp; // 临时变量,用于存储输出层的中间结果 float affine2_temp; // 临时变量,用于存储输出层的中间结果
float *affine2_rslt = (float *) malloc(sizeof(float)*7); float *affine2_rslt = (float *) malloc(sizeof(float)*num);
memset(affine2_rslt, 0, sizeof(float)*7); memset(affine2_rslt, 0, sizeof(float)*num);
// 遍历10个输出神经元假设有10个类别 // 遍历10个输出神经元假设有10个类别
for(int n=0; n<7; n++) for(int n=0; n<num; n++)
{ {
affine2_temp = 0; // 当前神经元的输出初始化为0 affine2_temp = 0; // 当前神经元的输出初始化为0
@@ -252,10 +253,10 @@ float* generateMatrix(Model model, const float* value)
return CNN_data; return CNN_data;
} }
float calculate_probabilities(float *input_array) float calculate_probabilities(Model model_w, float *input_array)
{ {
float sum = 0; float sum = 0;
u8 input_num = 7; u8 input_num = model_w.num_kernels;
float *result = (float *) malloc(sizeof(float)*input_num); float *result = (float *) malloc(sizeof(float)*input_num);
memset(result, 0, sizeof(float)*input_num); memset(result, 0, sizeof(float)*input_num);
@@ -293,8 +294,8 @@ float calculate_probabilities(float *input_array)
} }
u8 calculate_layer(float *input_array){ u8 calculate_layer(Model model_w, float *input_array){
u8 input_num = 7; u8 input_num = model_w.num_kernels;
u8 predict_num = 0; u8 predict_num = 0;
float max_temp = -100; float max_temp = -100;
for(int n=0; n<input_num; n++) for(int n=0; n<input_num; n++)
@@ -305,8 +306,8 @@ u8 calculate_layer(float *input_array){
predict_num = n; // 记录最大值对应的类别索引 predict_num = n; // 记录最大值对应的类别索引
} }
} }
//print_rslt(input_array,7,7); print_rslt(input_array,7,7);
return predict_num+1; return predict_num+0;
} }
@@ -354,10 +355,10 @@ void cnn_run(){
conv_rlst_3 = NULL; conv_rlst_3 = NULL;
float* affine1_rslt = hidden(pool_rslt_3); float* affine1_rslt = hidden(pool_rslt_3);
float* affine2_rslt = output(affine1_rslt); float* affine2_rslt = output(fc2_weight, affine1_rslt);
printf("概率:%f\r\n",calculate_probabilities(affine2_rslt)); printf("概率:%f\r\n",calculate_probabilities(fc2_weight, affine2_rslt));
printf("Label is:%d\r\n",calculate_layer(affine2_rslt)); printf("Label is:%d\r\n",calculate_layer(fc2_weight, affine2_rslt));
free(pool_rslt_3); free(pool_rslt_3);
pool_rslt_3 = NULL; pool_rslt_3 = NULL;
@@ -404,10 +405,10 @@ void cnn_run(){
conv_rlst_3 = NULL; conv_rlst_3 = NULL;
float* affine1_rslt = hidden(pool_rslt_3); float* affine1_rslt = hidden(pool_rslt_3);
float* affine2_rslt = output(affine1_rslt); float* affine2_rslt = output(fc2_weight, affine1_rslt);
printf("概率:%f\r\n",calculate_probabilities(affine2_rslt)); printf("概率:%f\r\n",calculate_probabilities(fc2_weight, affine2_rslt));
printf("Label is:%d\r\n",calculate_layer(affine2_rslt)); printf("Label is:%d\r\n",calculate_layer(fc2_weight, affine2_rslt));
free(pool_rslt_3); free(pool_rslt_3);
pool_rslt_3 = NULL; pool_rslt_3 = NULL;

View File

@@ -383,6 +383,7 @@ void model_init(){
fc2_weight.name = "fc2_weight"; fc2_weight.name = "fc2_weight";
fc2_weight.array = modelmym_init(fc2_weight.name); fc2_weight.array = modelmym_init(fc2_weight.name);
fc2_weight.maxlength = FC2_WEIGHT_ARRSIZE; fc2_weight.maxlength = FC2_WEIGHT_ARRSIZE;
fc2_weight.num_kernels = FC2_WEIGHT_KERNELS;
data.name = "data"; data.name = "data";
data.array = modelmym_init(data.name); data.array = modelmym_init(data.name);

View File

@@ -35,6 +35,8 @@ typedef struct {
#define FC2_BIAS_ARRSIZE (7) #define FC2_BIAS_ARRSIZE (7)
#define FC2_WEIGHT_ARRSIZE (7*128) //896 #define FC2_WEIGHT_ARRSIZE (7*128) //896
#define FC2_WEIGHT_KERNELS 4 //4个卷积核
#define is1250000 1 #define is1250000 1
#if is1250000 #if is1250000
#define DATA_ARRSIZE (1250000) #define DATA_ARRSIZE (1250000)

6
main.c
View File

@@ -99,10 +99,10 @@ void run_dataset(){
int main(){ int main(){
model_init(); model_init();
model_write("all"); model_write("all");
// run_dataset(); run_dataset();
model_switchdata("C1autosave00095_right_new_2"); // model_switchdata("7E29181C 2024-11-12 19-13-55");
cnn_run(); // cnn_run();
DEBUG_PRINTF("\r\nEnd结束"); DEBUG_PRINTF("\r\nEnd结束");
} }