实现cnn计算功能,修复switchdata的bug
This commit is contained in:
49
cnn.c
49
cnn.c
@@ -19,6 +19,8 @@ float* expand(const float* old_matrix, int old_matrix_length, int layer){
|
|||||||
return new_matrix;
|
return new_matrix;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//model 模型名字
|
//model 模型名字
|
||||||
//input_matrix 输入图像
|
//input_matrix 输入图像
|
||||||
//input_matrix_length 输入图像的边长:102
|
//input_matrix_length 输入图像的边长:102
|
||||||
@@ -81,6 +83,7 @@ float* convolution(Model model_w, Model model_b, const float* input_matrix, int
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//num_kernels 卷积核的个数:32
|
//num_kernels 卷积核的个数:32
|
||||||
//area 池化的面积:2*2
|
//area 池化的面积:2*2
|
||||||
//input_matrix 输入图像
|
//input_matrix 输入图像
|
||||||
@@ -121,6 +124,7 @@ float* pooling(Model model_w, const float* input_matrix, u8 input_matrix_length)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void print_rslt(float* rslt, u8 input_matrix_length, u32 length){
|
void print_rslt(float* rslt, u8 input_matrix_length, u32 length){
|
||||||
int _tmp = 0;
|
int _tmp = 0;
|
||||||
printf("[0:0]");
|
printf("[0:0]");
|
||||||
@@ -130,57 +134,27 @@ void print_rslt(float* rslt, u8 input_matrix_length, u32 length){
|
|||||||
printf("\n[%d:%d]",++_tmp,i+1);
|
printf("\n[%d:%d]",++_tmp,i+1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\r\n\r\n");
|
printf("\r\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void cnn_run(){
|
||||||
|
|
||||||
|
|
||||||
int main(){
|
|
||||||
model_init();
|
|
||||||
model_write("all");
|
|
||||||
model_switchdata("data");
|
|
||||||
|
|
||||||
//第一层:填充102 * 102
|
//第一层:填充102 * 102
|
||||||
float* expand_matrix_1 = expand(data.array, 100, 1);
|
float* expand_matrix_1 = expand(data.array, 100, 1);
|
||||||
// print_rslt(expand_matrix_1, 102, (1*10*102));
|
|
||||||
float* conv_rlst_1 = convolution(conv1_weight,conv1_bias,expand_matrix_1, 102);
|
float* conv_rlst_1 = convolution(conv1_weight,conv1_bias,expand_matrix_1, 102);
|
||||||
// print_rslt(conv_rlst_1, 100*0.01, (0.01*10*100));
|
|
||||||
float* pool_rslt_1 = pooling(conv1_weight, conv_rlst_1, 100);
|
float* pool_rslt_1 = pooling(conv1_weight, conv_rlst_1, 100);
|
||||||
// print_rslt(pool_rslt_1, 50, (1*50*50));
|
|
||||||
|
|
||||||
//第二层:填充32 * 52 * 52
|
//第二层:填充32 * 52 * 52
|
||||||
float* expand_matrix_2 = expand(pool_rslt_1, 50, 32);
|
float* expand_matrix_2 = expand(pool_rslt_1, 50, 32);
|
||||||
// print_rslt(expand_matrix_2, 52, (1*10*52));
|
|
||||||
float* conv_rlst_2 = convolution(conv2_weight,conv2_bias,expand_matrix_2, 52);
|
float* conv_rlst_2 = convolution(conv2_weight,conv2_bias,expand_matrix_2, 52);
|
||||||
// print_rslt(conv_rlst_2, 50, (64*50*50));
|
|
||||||
float* pool_rslt_2 = pooling(conv2_weight, conv_rlst_2, 50);
|
float* pool_rslt_2 = pooling(conv2_weight, conv_rlst_2, 50);
|
||||||
// print_rslt(pool_rslt_2, 25, (1*25*25));
|
|
||||||
|
|
||||||
//第三层:填充 64 * 27 * 27
|
//第三层:填充 64 * 27 * 27
|
||||||
float* expand_matrix_3 = expand(pool_rslt_2, 25, 64);
|
float* expand_matrix_3 = expand(pool_rslt_2, 25, 64);
|
||||||
// print_rslt(expand_matrix_2, 52, (1*52*52));
|
|
||||||
float* conv_rlst_3 = convolution(conv3_weight,conv3_bias,expand_matrix_3, 27);
|
float* conv_rlst_3 = convolution(conv3_weight,conv3_bias,expand_matrix_3, 27);
|
||||||
print_rslt(conv_rlst_3, 25, (1*25*25));
|
|
||||||
float* pool_rslt_3 = pooling(conv3_weight, conv_rlst_3, 25);
|
float* pool_rslt_3 = pooling(conv3_weight, conv_rlst_3, 25);
|
||||||
// print_rslt(pool_rslt_3, 12, (1*12*12));
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
{
|
||||||
// 隐藏层参数地址
|
// 隐藏层参数地址
|
||||||
float *affine1_rslt = (float *) malloc(sizeof(float)*128);
|
float *affine1_rslt = (float *) malloc(sizeof(float)*128);
|
||||||
memset(affine1_rslt, 0, sizeof(float)*128);
|
memset(affine1_rslt, 0, sizeof(float)*128);
|
||||||
@@ -207,7 +181,7 @@ int main(){
|
|||||||
affine1_rslt[n] = 0; // 否则存入0
|
affine1_rslt[n] = 0; // 否则存入0
|
||||||
}
|
}
|
||||||
|
|
||||||
print_rslt(affine1_rslt,1,128);
|
// print_rslt(affine1_rslt,1,128);
|
||||||
|
|
||||||
|
|
||||||
float affine2_temp; // 临时变量,用于存储输出层的中间结果
|
float affine2_temp; // 临时变量,用于存储输出层的中间结果
|
||||||
@@ -240,9 +214,8 @@ int main(){
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
print_rslt(affine2_rslt,1,7);
|
print_rslt(affine2_rslt,7,7);
|
||||||
|
|
||||||
printf("Label is:%d",predict_num+1);
|
printf("Label is:%d\r\n",predict_num+1);
|
||||||
|
}
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|||||||
12
cnn_model.c
12
cnn_model.c
@@ -194,8 +194,12 @@ u8 model_write(char* model_name)
|
|||||||
return 200;
|
return 200;
|
||||||
}
|
}
|
||||||
|
|
||||||
sprintf(_path, "./dataset/%s.txt", _model -> name);
|
sprintf(_path, "./dataset/%s.txt", _model -> dname ? _model -> dname : _model -> name);
|
||||||
FILE *file = fopen(_path, "r");
|
FILE *file = fopen(_path, "r");
|
||||||
|
if(file == NULL){
|
||||||
|
DEBUG_PRINTF("文件[%s]无法打开\r\n", _model -> dname ? _model -> dname : _model -> name);
|
||||||
|
return 199;
|
||||||
|
}
|
||||||
|
|
||||||
DEBUG_PRINTF("写入的模型参数名字是:%s\r\n", _model -> name);
|
DEBUG_PRINTF("写入的模型参数名字是:%s\r\n", _model -> name);
|
||||||
if(_model -> dname)DEBUG_PRINTF("写入的Data数据集是:%s\r\n", _model -> dname);
|
if(_model -> dname)DEBUG_PRINTF("写入的Data数据集是:%s\r\n", _model -> dname);
|
||||||
@@ -227,10 +231,10 @@ u8 model_write(char* model_name)
|
|||||||
if(_larr >= _model -> maxlength)break;
|
if(_larr >= _model -> maxlength)break;
|
||||||
}
|
}
|
||||||
DEBUG_PRINTF("\r\n模型参数[%s]已写入到内存中! 模型长度为 %d\r\n",_model -> dname ? _model -> dname : _model -> name,_model -> realength);
|
DEBUG_PRINTF("\r\n模型参数[%s]已写入到内存中! 模型长度为 %d\r\n",_model -> dname ? _model -> dname : _model -> name,_model -> realength);
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -273,7 +277,7 @@ u8 model_switchdata(char* data_name){
|
|||||||
return 0;
|
return 0;
|
||||||
}else{
|
}else{
|
||||||
u8 _res = model_write(data_name);
|
u8 _res = model_write(data_name);
|
||||||
if (_res == 0) {
|
if (_res) {
|
||||||
DEBUG_PRINTF("Data数据集[%s]切换失败!!\r\n",data_name);
|
DEBUG_PRINTF("Data数据集[%s]切换失败!!\r\n",data_name);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
2
debug.c
2
debug.c
@@ -1,7 +1,7 @@
|
|||||||
#include "debug.h"
|
#include "debug.h"
|
||||||
|
|
||||||
|
|
||||||
u8 _DEBUG = 1;
|
u8 _DEBUG = 0;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
104
main.c
104
main.c
@@ -1,20 +1,88 @@
|
|||||||
////
|
|
||||||
//// Created by Qi on 2024/11/9.
|
|
||||||
////
|
|
||||||
//
|
//
|
||||||
//#include "cnn_model.h"
|
// Created by Qi on 2024/11/9.
|
||||||
//
|
//
|
||||||
//void main(){
|
|
||||||
// u8 res;
|
#include "cnn_model.h"
|
||||||
// model_init();
|
#include "cnn.h"
|
||||||
// model_write("all");
|
|
||||||
// model_switchdata("C1autosave00095_right_new_2");
|
|
||||||
//
|
void run_dataset(){
|
||||||
//
|
char* modelusearr[] = {
|
||||||
//
|
"C1autosave00095_right_new_2",
|
||||||
//
|
"C1autosave00096_right_new_2",
|
||||||
//
|
"C1autosave00097_right_new_2",
|
||||||
//
|
"C1autosave00098_right_new_2",
|
||||||
// model_info("all");
|
"C1autosave00099_right_new_2",
|
||||||
// DEBUG_PRINTF("\r\nEnd结束");
|
"C1autosave00100_right_new_2",
|
||||||
//}
|
"C1autosave00101_right_new_2",
|
||||||
|
"C1autosave00102_right_new_2",
|
||||||
|
"C1autosave00103_right_new_2",
|
||||||
|
"C1autosave00104_right_new_2",
|
||||||
|
"C1autosave00105_right_new_2",
|
||||||
|
"C1autosave00106_right_new_2",
|
||||||
|
"C1autosave00107_right_new_2",
|
||||||
|
"C1autosave00108_right_new_2",
|
||||||
|
"C1autosave00109_right_new_2",
|
||||||
|
"C1autosave00110_right_new_2",
|
||||||
|
"C1autosave00111_right_new_2",
|
||||||
|
"C1autosave00112_right_new_2",
|
||||||
|
"C1autosave00113_right_new_2",
|
||||||
|
"C1autosave00114_right_new_2",
|
||||||
|
"C1autosave00115_right_new_2",
|
||||||
|
"C1autosave00116_right_new_2",
|
||||||
|
"C1autosave00117_right_new_2",
|
||||||
|
"C1autosave00118_right_new_2",
|
||||||
|
"C1autosave00119_right_new_2",
|
||||||
|
"C1autosave00120_right_new_2",
|
||||||
|
"C1autosave00121_right_new_2",
|
||||||
|
"C1autosave00122_right_new_2",
|
||||||
|
"C1autosave00123_right_new_2",
|
||||||
|
"C1autosave00124_right_new_2",
|
||||||
|
|
||||||
|
"filtered_C1autosave00011_right_new",
|
||||||
|
"filtered_C1autosave00015_right_new",
|
||||||
|
"filtered_C1autosave00043_right_new",
|
||||||
|
"filtered_C1autosave00067_right_new",
|
||||||
|
"filtered_C1autosave00090_right_new",
|
||||||
|
"filtered_C1autosave00106_right_new",
|
||||||
|
"filtered_C1autosave00118_right_new",
|
||||||
|
|
||||||
|
"filtered_C1autosave00007_right_new",
|
||||||
|
"filtered_C1autosave00035_right_new",
|
||||||
|
"filtered_C1autosave00036_right_new",
|
||||||
|
"filtered_C1autosave00040_right_new",
|
||||||
|
"filtered_C1autosave00053_right_new",
|
||||||
|
"filtered_C1autosave00061_right_new",
|
||||||
|
"filtered_C1autosave00074_right_new",
|
||||||
|
"filtered_C1autosave00077_right_new",
|
||||||
|
"filtered_C1autosave00080_right_new",
|
||||||
|
"filtered_C1autosave00085_right_new",
|
||||||
|
"filtered_C1autosave00098_right_new",
|
||||||
|
"filtered_C1autosave00100_right_new",
|
||||||
|
"filtered_C1autosave00104_right_new",
|
||||||
|
"filtered_C1autosave00122_right_new",
|
||||||
|
"filtered_C1autosave00124_right_new",
|
||||||
|
|
||||||
|
"filtered_C1autosave00108_right_new",
|
||||||
|
|
||||||
|
"filtered_C1autosave00004_right_new",
|
||||||
|
"filtered_C1autosave00039_right_new",
|
||||||
|
"filtered_C1autosave00062_right_new",
|
||||||
|
};
|
||||||
|
for(int a=0;a<(sizeof(modelusearr) / sizeof(modelusearr[0]));a++){
|
||||||
|
SDRAM_USED();
|
||||||
|
model_switchdata(modelusearr[a]);
|
||||||
|
cnn_run();
|
||||||
|
}
|
||||||
|
printf("\r\n运行完成\r\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int main(){
|
||||||
|
model_init();
|
||||||
|
model_write("all");
|
||||||
|
|
||||||
|
run_dataset();
|
||||||
|
|
||||||
|
DEBUG_PRINTF("\r\nEnd结束");
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user