实现cnn计算功能,修复switchdata的bug

This commit is contained in:
Qiea
2024-11-11 12:28:33 +08:00
parent fe062deba5
commit 0af7f203c7
5 changed files with 150 additions and 111 deletions

49
cnn.c
View File

@@ -19,6 +19,8 @@ float* expand(const float* old_matrix, int old_matrix_length, int layer){
return new_matrix; return new_matrix;
} }
//model 模型名字 //model 模型名字
//input_matrix 输入图像 //input_matrix 输入图像
//input_matrix_length 输入图像的边长102 //input_matrix_length 输入图像的边长102
@@ -81,6 +83,7 @@ float* convolution(Model model_w, Model model_b, const float* input_matrix, int
} }
//num_kernels 卷积核的个数32 //num_kernels 卷积核的个数32
//area 池化的面积2*2 //area 池化的面积2*2
//input_matrix 输入图像 //input_matrix 输入图像
@@ -121,6 +124,7 @@ float* pooling(Model model_w, const float* input_matrix, u8 input_matrix_length)
} }
void print_rslt(float* rslt, u8 input_matrix_length, u32 length){ void print_rslt(float* rslt, u8 input_matrix_length, u32 length){
int _tmp = 0; int _tmp = 0;
printf("[0:0]"); printf("[0:0]");
@@ -130,57 +134,27 @@ void print_rslt(float* rslt, u8 input_matrix_length, u32 length){
printf("\n[%d:%d]",++_tmp,i+1); printf("\n[%d:%d]",++_tmp,i+1);
} }
} }
printf("\r\n\r\n"); printf("\r\n");
} }
void cnn_run(){
int main(){
model_init();
model_write("all");
model_switchdata("data");
//第一层填充102 * 102 //第一层填充102 * 102
float* expand_matrix_1 = expand(data.array, 100, 1); float* expand_matrix_1 = expand(data.array, 100, 1);
// print_rslt(expand_matrix_1, 102, (1*10*102));
float* conv_rlst_1 = convolution(conv1_weight,conv1_bias,expand_matrix_1, 102); float* conv_rlst_1 = convolution(conv1_weight,conv1_bias,expand_matrix_1, 102);
// print_rslt(conv_rlst_1, 100*0.01, (0.01*10*100));
float* pool_rslt_1 = pooling(conv1_weight, conv_rlst_1, 100); float* pool_rslt_1 = pooling(conv1_weight, conv_rlst_1, 100);
// print_rslt(pool_rslt_1, 50, (1*50*50));
//第二层填充32 * 52 * 52 //第二层填充32 * 52 * 52
float* expand_matrix_2 = expand(pool_rslt_1, 50, 32); float* expand_matrix_2 = expand(pool_rslt_1, 50, 32);
// print_rslt(expand_matrix_2, 52, (1*10*52));
float* conv_rlst_2 = convolution(conv2_weight,conv2_bias,expand_matrix_2, 52); float* conv_rlst_2 = convolution(conv2_weight,conv2_bias,expand_matrix_2, 52);
// print_rslt(conv_rlst_2, 50, (64*50*50));
float* pool_rslt_2 = pooling(conv2_weight, conv_rlst_2, 50); float* pool_rslt_2 = pooling(conv2_weight, conv_rlst_2, 50);
// print_rslt(pool_rslt_2, 25, (1*25*25));
//第三层:填充 64 * 27 * 27 //第三层:填充 64 * 27 * 27
float* expand_matrix_3 = expand(pool_rslt_2, 25, 64); float* expand_matrix_3 = expand(pool_rslt_2, 25, 64);
// print_rslt(expand_matrix_2, 52, (1*52*52));
float* conv_rlst_3 = convolution(conv3_weight,conv3_bias,expand_matrix_3, 27); float* conv_rlst_3 = convolution(conv3_weight,conv3_bias,expand_matrix_3, 27);
print_rslt(conv_rlst_3, 25, (1*25*25));
float* pool_rslt_3 = pooling(conv3_weight, conv_rlst_3, 25); float* pool_rslt_3 = pooling(conv3_weight, conv_rlst_3, 25);
// print_rslt(pool_rslt_3, 12, (1*12*12));
{
// 隐藏层参数地址 // 隐藏层参数地址
float *affine1_rslt = (float *) malloc(sizeof(float)*128); float *affine1_rslt = (float *) malloc(sizeof(float)*128);
memset(affine1_rslt, 0, sizeof(float)*128); memset(affine1_rslt, 0, sizeof(float)*128);
@@ -207,7 +181,7 @@ int main(){
affine1_rslt[n] = 0; // 否则存入0 affine1_rslt[n] = 0; // 否则存入0
} }
print_rslt(affine1_rslt,1,128); // print_rslt(affine1_rslt,1,128);
float affine2_temp; // 临时变量,用于存储输出层的中间结果 float affine2_temp; // 临时变量,用于存储输出层的中间结果
@@ -240,9 +214,8 @@ int main(){
} }
} }
print_rslt(affine2_rslt,1,7); print_rslt(affine2_rslt,7,7);
printf("Label is:%d",predict_num+1); printf("Label is:%d\r\n",predict_num+1);
}
return 0;
} }

8
cnn.h
View File

@@ -6,13 +6,7 @@
void cnn_run(void);

View File

@@ -194,8 +194,12 @@ u8 model_write(char* model_name)
return 200; return 200;
} }
sprintf(_path, "./dataset/%s.txt", _model -> name); sprintf(_path, "./dataset/%s.txt", _model -> dname ? _model -> dname : _model -> name);
FILE *file = fopen(_path, "r"); FILE *file = fopen(_path, "r");
if(file == NULL){
DEBUG_PRINTF("文件[%s]无法打开\r\n", _model -> dname ? _model -> dname : _model -> name);
return 199;
}
DEBUG_PRINTF("写入的模型参数名字是:%s\r\n", _model -> name); DEBUG_PRINTF("写入的模型参数名字是:%s\r\n", _model -> name);
if(_model -> dname)DEBUG_PRINTF("写入的Data数据集是%s\r\n", _model -> dname); if(_model -> dname)DEBUG_PRINTF("写入的Data数据集是%s\r\n", _model -> dname);
@@ -227,10 +231,10 @@ u8 model_write(char* model_name)
if(_larr >= _model -> maxlength)break; if(_larr >= _model -> maxlength)break;
} }
DEBUG_PRINTF("\r\n模型参数[%s]已写入到内存中! 模型长度为 %d\r\n",_model -> dname ? _model -> dname : _model -> name,_model -> realength); DEBUG_PRINTF("\r\n模型参数[%s]已写入到内存中! 模型长度为 %d\r\n",_model -> dname ? _model -> dname : _model -> name,_model -> realength);
return 1;
}
return 0; return 0;
} }
return 1;
}
@@ -273,7 +277,7 @@ u8 model_switchdata(char* data_name){
return 0; return 0;
}else{ }else{
u8 _res = model_write(data_name); u8 _res = model_write(data_name);
if (_res == 0) { if (_res) {
DEBUG_PRINTF("Data数据集[%s]切换失败!!\r\n",data_name); DEBUG_PRINTF("Data数据集[%s]切换失败!!\r\n",data_name);
return 0; return 0;
} }

View File

@@ -1,7 +1,7 @@
#include "debug.h" #include "debug.h"
u8 _DEBUG = 1; u8 _DEBUG = 0;

104
main.c
View File

@@ -1,20 +1,88 @@
////
//// Created by Qi on 2024/11/9.
////
// //
//#include "cnn_model.h" // Created by Qi on 2024/11/9.
// //
//void main(){
// u8 res; #include "cnn_model.h"
// model_init(); #include "cnn.h"
// model_write("all");
// model_switchdata("C1autosave00095_right_new_2");
// void run_dataset(){
// char* modelusearr[] = {
// "C1autosave00095_right_new_2",
// "C1autosave00096_right_new_2",
// "C1autosave00097_right_new_2",
// "C1autosave00098_right_new_2",
// model_info("all"); "C1autosave00099_right_new_2",
// DEBUG_PRINTF("\r\nEnd结束"); "C1autosave00100_right_new_2",
//} "C1autosave00101_right_new_2",
"C1autosave00102_right_new_2",
"C1autosave00103_right_new_2",
"C1autosave00104_right_new_2",
"C1autosave00105_right_new_2",
"C1autosave00106_right_new_2",
"C1autosave00107_right_new_2",
"C1autosave00108_right_new_2",
"C1autosave00109_right_new_2",
"C1autosave00110_right_new_2",
"C1autosave00111_right_new_2",
"C1autosave00112_right_new_2",
"C1autosave00113_right_new_2",
"C1autosave00114_right_new_2",
"C1autosave00115_right_new_2",
"C1autosave00116_right_new_2",
"C1autosave00117_right_new_2",
"C1autosave00118_right_new_2",
"C1autosave00119_right_new_2",
"C1autosave00120_right_new_2",
"C1autosave00121_right_new_2",
"C1autosave00122_right_new_2",
"C1autosave00123_right_new_2",
"C1autosave00124_right_new_2",
"filtered_C1autosave00011_right_new",
"filtered_C1autosave00015_right_new",
"filtered_C1autosave00043_right_new",
"filtered_C1autosave00067_right_new",
"filtered_C1autosave00090_right_new",
"filtered_C1autosave00106_right_new",
"filtered_C1autosave00118_right_new",
"filtered_C1autosave00007_right_new",
"filtered_C1autosave00035_right_new",
"filtered_C1autosave00036_right_new",
"filtered_C1autosave00040_right_new",
"filtered_C1autosave00053_right_new",
"filtered_C1autosave00061_right_new",
"filtered_C1autosave00074_right_new",
"filtered_C1autosave00077_right_new",
"filtered_C1autosave00080_right_new",
"filtered_C1autosave00085_right_new",
"filtered_C1autosave00098_right_new",
"filtered_C1autosave00100_right_new",
"filtered_C1autosave00104_right_new",
"filtered_C1autosave00122_right_new",
"filtered_C1autosave00124_right_new",
"filtered_C1autosave00108_right_new",
"filtered_C1autosave00004_right_new",
"filtered_C1autosave00039_right_new",
"filtered_C1autosave00062_right_new",
};
for(int a=0;a<(sizeof(modelusearr) / sizeof(modelusearr[0]));a++){
SDRAM_USED();
model_switchdata(modelusearr[a]);
cnn_run();
}
printf("\r\n运行完成\r\n");
}
int main(){
model_init();
model_write("all");
run_dataset();
DEBUG_PRINTF("\r\nEnd结束");
}