Files
c-cnn/cnn_model.c
2024-11-11 11:04:38 +08:00

389 lines
14 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "cnn_model.h"
/*
 * File-scope Model objects, one per named parameter set plus `data`.
 * As objects with static storage duration they start zero-initialized, so
 * every `array` pointer is NULL until modelmym_init()/model_init() runs.
 * The functions in this file look these objects up by their string names.
 */
Model conv1_bias;
Model conv1_weight;
Model conv2_bias;
Model conv2_weight;
Model conv3_bias;
Model conv3_weight;
Model fc1_bias;
Model fc1_weight;
Model fc2_bias;
Model fc2_weight;
/* Buffer loaded from a dataset file; `data.dname` records which file was requested. */
Model data;
float* modelmym_init(char* model_name){
if(conv1_bias.array == NULL && strcmp(model_name, "conv1_bias") == 0)
return conv1_bias.array = (float*)malloc(CONV1_BIAS_ARRSIZE * sizeof(float));
else if(conv1_weight.array == NULL && strcmp(model_name, "conv1_weight") == 0)
return conv1_weight.array = (float*)malloc(CONV1_WEIGHT_ARRSIZE * sizeof(float));
else if(conv2_bias.array == NULL && strcmp(model_name, "conv2_bias") == 0)
return conv2_bias.array = (float*)malloc(CONV2_BIAS_ARRSIZE * sizeof(float));
else if(conv2_weight.array == NULL && strcmp(model_name, "conv2_weight") == 0)
return conv2_weight.array = (float*)malloc(CONV2_WEIGHT_ARRSIZE * sizeof(float));
else if(conv3_bias.array == NULL && strcmp(model_name, "conv3_bias") == 0)
return conv3_bias.array = (float*)malloc(CONV3_BIAS_ARRSIZE * sizeof(float));
else if(conv3_weight.array == NULL && strcmp(model_name, "conv3_weight") == 0)
return conv3_weight.array = (float*)malloc(CONV3_WEIGHT_ARRSIZE * sizeof(float));
else if(fc1_bias.array == NULL && strcmp(model_name, "fc1_bias") == 0)
return fc1_bias.array = (float*)malloc(FC1_BIAS_ARRSIZE * sizeof(float));
else if(fc1_weight.array == NULL && strcmp(model_name, "fc1_weight") == 0)
return fc1_weight.array = (float*)malloc(FC1_WEIGHT_ARRSIZE * sizeof(float));
else if(fc2_bias.array == NULL && strcmp(model_name, "fc2_bias") == 0)
return fc2_bias.array = (float*)malloc(FC2_BIAS_ARRSIZE * sizeof(float));
else if(fc2_weight.array == NULL && strcmp(model_name, "fc2_weight") == 0)
return fc2_weight.array = (float*)malloc(FC2_WEIGHT_ARRSIZE * sizeof(float));
else if(data.array == NULL && strcmp(model_name, "data") == 0)
return data.array = (float*)malloc(DATA_ARRSIZE * sizeof(float));
else if(strcmp(model_name, "all") == 0){
if(conv1_bias.array == NULL)conv1_bias.array = (float*)malloc(CONV1_BIAS_ARRSIZE * sizeof(float));
if(conv1_weight.array == NULL)conv1_weight.array = (float*)malloc(CONV1_WEIGHT_ARRSIZE * sizeof(float));
if(conv2_bias.array == NULL)conv2_bias.array = (float*)malloc(CONV2_BIAS_ARRSIZE * sizeof(float));
if(conv2_weight.array == NULL)conv2_weight.array = (float*)malloc(CONV2_WEIGHT_ARRSIZE * sizeof(float));
if(conv3_bias.array == NULL)conv3_bias.array = (float*)malloc(CONV3_BIAS_ARRSIZE * sizeof(float));
if(conv3_weight.array == NULL)conv3_weight.array = (float*)malloc(CONV3_WEIGHT_ARRSIZE * sizeof(float));
if(fc1_bias.array == NULL)fc1_bias.array = (float*)malloc(FC1_BIAS_ARRSIZE * sizeof(float));
if(fc1_weight.array == NULL)fc1_weight.array = (float*)malloc(FC1_WEIGHT_ARRSIZE * sizeof(float));
if(fc2_bias.array == NULL)fc2_bias.array = (float*)malloc(FC2_BIAS_ARRSIZE * sizeof(float));
if(fc2_weight.array == NULL)fc2_weight.array = (float*)malloc(FC2_WEIGHT_ARRSIZE * sizeof(float));
if(data.array == NULL)data.array = (float*)malloc(DATA_ARRSIZE * sizeof(float));
}
return NULL;
}
u8 modelmym_free(char* model_name){
if(conv1_bias.array != NULL && strcmp(model_name, "conv1_bias") == 0){
free(conv1_bias.array);
conv1_bias.array = NULL;
conv1_bias.realength = 0;
return 1;
}
else if(conv1_weight.array != NULL && strcmp(model_name, "conv1_weight") == 0){
free(conv1_weight.array);
conv1_weight.array = NULL;
conv1_weight.realength = 0;
return 1;
}
else if(conv2_bias.array != NULL && strcmp(model_name, "conv2_bias") == 0){
free(conv2_bias.array);
conv2_bias.array = NULL;
conv2_bias.realength = 0;
return 1;
}
else if(conv2_weight.array != NULL && strcmp(model_name, "conv2_weight") == 0){
free(conv2_weight.array);
conv2_weight.array = NULL;
conv2_weight.realength = 0;
return 1;
}
else if(conv3_bias.array != NULL && strcmp(model_name, "conv3_bias") == 0){
free(conv3_bias.array);
conv3_bias.array = NULL;
conv3_bias.realength = 0;
return 1;
}
else if(conv3_weight.array != NULL && strcmp(model_name, "conv3_weight") == 0){
free(conv3_weight.array);
conv3_weight.array = NULL;
conv3_weight.realength = 0;
return 1;
}
else if(fc1_bias.array != NULL && strcmp(model_name, "fc1_bias") == 0){
free(fc1_bias.array);
fc1_bias.array = NULL;
fc1_bias.realength = 0;
return 1;
}
else if(fc1_weight.array != NULL && strcmp(model_name, "fc1_weight") == 0){
free(fc1_weight.array);
fc1_weight.array = NULL;
fc1_weight.realength = 0;
return 1;
}
else if(fc2_bias.array != NULL && strcmp(model_name, "fc2_bias") == 0){
free(fc2_bias.array);
fc2_bias.array = NULL;
fc2_bias.realength = 0;
return 1;
}
else if(fc2_weight.array != NULL && strcmp(model_name, "fc2_weight") == 0){
free(fc2_weight.array);
fc2_weight.array = NULL;
fc2_weight.realength = 0;
return 1;
}
else if(data.array != NULL && strcmp(model_name, "data") == 0){
free(data.array);
data.array = NULL;
data.realength = 0;
return 1;
}
else if(strcmp(model_name, "all") == 0){
modelmym_free("conv1_bias");
modelmym_free("conv1_weight");
modelmym_free("conv2_bias");
modelmym_free("conv2_weight");
modelmym_free("conv3_bias");
modelmym_free("conv3_weight");
modelmym_free("fc1_bias");
modelmym_free("fc1_weight");
modelmym_free("fc2_bias");
modelmym_free("fc2_weight");
modelmym_free("data");
return 2;
}
return 0;
}
/**
 * Load float values, one per text line, from "./dataset/<name>.txt" into a
 * model buffer.
 *
 * @param model_name  "all", a model name known to model(), or any other
 *                    string, which is then treated as the name of a dataset
 *                    file to load into the `data` model.
 * @return 1   values were loaded (the count is stored in `realength`),
 *         0   "all" was requested (each parameter model is loaded in turn,
 *             then model_info("all") and SDRAM_USED() are called),
 *         4   the file for a known model could not be opened,
 *         199 the name is neither a known model nor an existing dataset file,
 *         200 the model buffer could not be allocated.
 *
 * Notes:
 *  - For a dataset, `data.dname` is set to `model_name` itself (the pointer
 *    is stored, not copied), so the caller's string must stay valid for as
 *    long as `dname` is used.
 *  - Loading always starts at index 0, and `realength` is reset to the
 *    number of values stored by this call.
 *  - Lines that do not start with a parsable float are reported on stderr and
 *    skipped. Loading stops at `maxlength` values or end of file.
 */
u8 model_write(char* model_name)
{
    if(strcmp(model_name, "all") == 0){
        model_write("conv1_bias");
        model_write("conv1_weight");
        model_write("conv2_bias");
        model_write("conv2_weight");
        model_write("conv3_bias");
        model_write("conv3_weight");
        model_write("fc1_bias");
        model_write("fc1_weight");
        model_write("fc2_bias");
        model_write("fc2_weight");
        model_info("all");
        SDRAM_USED();
        return 0;
    }

    u8 _times = 0;
    u32 _larr = 0;
    int progress = 0;
    /* size_t (not u8) so that long names cannot wrap the buffer size. */
    size_t _len = strlen(model_name);
    char _path[_len + sizeof("./dataset/.txt")];
    char _fstr[READLENGTH+1] = {0};
    FILE *file = NULL;
    Model *_model = model(model_name);

    if(_model == NULL || strcmp(model_name, "data") == 0){
        /* Not a parameter model: interpret the name as a dataset file. */
        snprintf(_path, sizeof _path, "./dataset/%s.txt", model_name);
        file = fopen(_path, "r");
        if(file == NULL){
            DEBUG_PRINTF("\r\n输入了一个无效的模型或Data数据集的名字\r\n");
            return 199;
        }
        _model = model("data");
        _model -> dname = model_name;
    }else{
        snprintf(_path, sizeof _path, "./dataset/%s.txt", _model -> name);
        file = fopen(_path, "r");
        if(file == NULL){
            DEBUG_PRINTF("预设里没有这个模型:[%s]\r\n", _path);
            return 4;
        }
    }

    if(_model -> array == NULL && modelmym_init(_model -> name) == NULL){
        DEBUG_PRINTF("无法创建模型参数[%s]的数组到SDRAM里\r\n", _model -> name);
        fclose(file);
        return 200;
    }

    DEBUG_PRINTF("写入的模型参数名字是:%s\r\n", _model -> name);
    if(_model -> dname)DEBUG_PRINTF("写入的Data数据集是%s\r\n", _model -> dname);
    DEBUG_PRINTF("写入模型参数数组的最大长度为:%d\r\n", _model -> maxlength);
    DEBUG_PRINTF("目前数组存活的元素数量为:%d", _model -> realength);
    printf("\r\n正在写入模型参数'%s',请稍后......\r\n",_model -> dname ? _model -> dname : _model -> name);

    /* The buffer is refilled from index 0, so the live count restarts too. */
    _model -> realength = 0;

    while(_larr < _model->maxlength && fgets(_fstr, sizeof(_fstr), file) != NULL){
        char *endptr;
        float value = strtof(_fstr, &endptr);
        /* Check that at least one character was converted. */
        if (endptr == _fstr) {
            fprintf(stderr, "第 %d 行不是有效的浮点数: %s", _larr + 1, _fstr);
            continue; /* skip the invalid line */
        }
        _model->array[_larr++] = value;
        _model->realength = _larr;
        /* Progress bar, only for large buffers, advanced in 10% steps. */
        if(_model -> maxlength >= 73728 && (_larr >= (_model -> maxlength/10)*_times)){
            progress = _larr >= _model -> maxlength ? 100 : _times++ == 0 ? 0 : progress + 10;
            DEBUG_PRINTF("\r\n[");
            for(u16 j=0; j<50;j++){
                if(j < progress/2) DEBUG_PRINTF("=");
                else DEBUG_PRINTF(" ");
            }
            DEBUG_PRINTF("] %d%%", progress);
        }
    }
    fclose(file);

    DEBUG_PRINTF("\r\n模型参数[%s]已写入到内存中! 模型长度为 %d\r\n",_model -> dname ? _model -> dname : _model -> name,_model -> realength);
    return 1;
}
/**
 * Print selected elements of a model buffer to stdout.
 *
 * @param model_name  A model name known to model(), or "all" to print the
 *                    same range for every model.
 * @param start       Index of the first element to print.
 * @param end         Last index of the range; must be non-zero, at least
 *                    `start`, and not greater than the model's `realength`.
 * @param gap         Step between printed indices; must be non-zero.
 * @return 1 if elements of a single model were printed; 0 for "all", for an
 *         unknown model, for an empty model, or for an invalid range/step.
 *
 * Indices start, start+gap, ... are printed up to and including `end`. When
 * (end - start) is not a multiple of `gap`, one additional step past `end`
 * is printed. Any index beyond the stored data is clamped to the last stored
 * element (realength - 1).
 */
u8 model_read(char* model_name, u32 start, u32 end, u32 gap){
    if(strcmp(model_name, "all") == 0){
        model_read("conv1_bias", start, end, gap);
        model_read("conv1_weight", start, end, gap);
        model_read("conv2_bias", start, end, gap);
        model_read("conv2_weight", start, end, gap);
        model_read("conv3_bias", start, end, gap);
        model_read("conv3_weight", start, end, gap);
        model_read("fc1_bias", start, end, gap);
        model_read("fc1_weight", start, end, gap);
        model_read("fc2_bias", start, end, gap);
        model_read("fc2_weight", start, end, gap);
        model_read("data", start, end, gap);
        return 0;
    }

    Model *_model = model(model_name);
    /*
     * gap == 0 would divide by zero below and never advance the loop;
     * start > end would make the unsigned (end - start) wrap around.
     * end != 0 together with end <= realength also rules out an empty model,
     * and start <= end <= realength makes a separate start check unnecessary.
     */
    if(_model == NULL || gap == 0 || end == 0 || start > end || end > _model -> realength)
        return 0;

    const u32 last = _model -> realength - 1;
    const u32 span = end - start;
    const u32 limit = span + (span % gap ? 2 : 1);

    for(u32 i = 0; i < limit; i += gap){
        /* Clamp to the last stored element when stepping past the data. */
        const u32 idx = (i + start) < _model -> realength ? i + start : last;
        printf("\r\n%s_floatArray[%d]: %f", _model->name, idx, _model->array[idx]);
    }
    printf("\r\n");
    return 1;
}
/**
 * Replace the contents of the `data` model with the dataset stored in
 * "./dataset/<data_name>.txt".
 *
 * @param data_name  Dataset file name without directory or ".txt" suffix.
 * @return 1 if the dataset was loaded, 0 if the file does not exist or
 *         loading failed.
 *
 * The currently loaded data is released only after the new file has been
 * found, so a misspelled name leaves the existing dataset intact.
 */
u8 model_switchdata(char* data_name){
    size_t _len = strlen(data_name);
    /* sizeof("./dataset/.txt") covers prefix, suffix and the terminating NUL. */
    char _path[_len + sizeof("./dataset/.txt")];

    snprintf(_path, sizeof _path, "./dataset/%s.txt", data_name);
    FILE *file = fopen(_path, "r");
    if(file == NULL){
        DEBUG_PRINTF("\r\nData数据集[%s]不存在\r\n",data_name);
        return 0;
    }
    /* The handle was only an existence probe; model_write() reopens the file. */
    fclose(file);

    if(data.array != NULL)modelmym_free("data");

    u8 _res = model_write(data_name);
    /* model_write() reports success as 1; 0, 4, 199 and 200 are all failures. */
    if (_res != 1) {
        DEBUG_PRINTF("Data数据集[%s]切换失败!!\r\n",data_name);
        return 0;
    }
    printf("Data数据集[%s]切换成功!\r\n",data_name);
    DEBUG_PRINTF("data_name的长度为%d\r\n_path的长度为%d\r\n_path为%s\r\n",(int)_len,(int)sizeof(_path),_path);
    return 1;
}
/**
 * Print the dataset name recorded in `data.dname` to stdout.
 *
 * `data.dname` is NULL until model_init() or model_write() sets it; passing
 * NULL to a "%s" conversion is undefined, so "(null)" is printed instead.
 */
void model_dataset(){
    printf("\r\ndataset is: %s\r\n", data.dname ? data.dname : "(null)");
}
/**
 * Print name, buffer address, capacity and element count of a model.
 *
 * @param model_name  A model name known to model(), or "all" to print the
 *                    information for every model.
 * @return 1 if information was printed ("all" always returns 1),
 *         0 if the name is unknown.
 */
u8 model_info(char* model_name){
    if(strcmp(model_name, "all") == 0){
        model_info("conv1_bias");
        model_info("conv1_weight");
        model_info("conv2_bias");
        model_info("conv2_weight");
        model_info("conv3_bias");
        model_info("conv3_weight");
        model_info("fc1_bias");
        model_info("fc1_weight");
        model_info("fc2_bias");
        model_info("fc2_weight");
        model_info("data");
        return 1;
    }

    Model *_model = model(model_name);
    if(_model == NULL)
        return 0;

    printf("\r\nmodel.name is: %s\r\n",_model -> name);
    /* %p is the only portable conversion for a pointer; %X expects unsigned int. */
    printf("model.array.address is: %p\r\n",(void *)_model -> array);
    printf("model.maxlength is: %d\r\n",_model -> maxlength);
    printf("model.realength is: %d\r\n",_model -> realength);
    /*
     * The dataset name is deliberately not printed here: the original author
     * noted that doing so in this function was buggy. Use model_dataset().
     */
    return 1;
}
void* model(char* model_name){
if(strcmp(model_name, "conv1_bias") == 0)return &conv1_bias;
else if(strcmp(model_name, "conv1_weight") == 0)return &conv1_weight;
else if(strcmp(model_name, "conv2_bias") == 0)return &conv2_bias;
else if(strcmp(model_name, "conv2_weight") == 0)return &conv2_weight;
else if(strcmp(model_name, "conv3_bias") == 0)return &conv3_bias;
else if(strcmp(model_name, "conv3_weight") == 0)return &conv3_weight;
else if(strcmp(model_name, "fc1_bias") == 0)return &fc1_bias;
else if(strcmp(model_name, "fc1_weight") == 0)return &fc1_weight;
else if(strcmp(model_name, "fc2_bias") == 0)return &fc2_bias;
else if(strcmp(model_name, "fc2_weight") == 0)return &fc2_weight;
else if(strcmp(model_name, "data") == 0)return &data;
return NULL;
}
/**
 * Set up every file-scope Model: assign its name and capacity, allocate its
 * buffer, and fill in channel / kernel counts for the convolution weights.
 *
 * modelmym_init() stores the allocated pointer in the model's `array` field
 * itself. Its return value is intentionally not assigned here: assigning it
 * could replace an already-allocated buffer with NULL if this function runs
 * more than once. Allocation failures are not reported by this function;
 * callers can check the individual `array` fields for NULL.
 */
void model_init(){
    /* Convolution layer 1 */
    conv1_bias.name = "conv1_bias";
    conv1_bias.maxlength = CONV1_BIAS_ARRSIZE;
    modelmym_init(conv1_bias.name);

    conv1_weight.name = "conv1_weight";
    conv1_weight.maxlength = CONV1_WEIGHT_ARRSIZE;
    conv1_weight.channel = 1;
    conv1_weight.num_kernels = 32;
    modelmym_init(conv1_weight.name);

    /* Convolution layer 2 */
    conv2_bias.name = "conv2_bias";
    conv2_bias.maxlength = CONV2_BIAS_ARRSIZE;
    modelmym_init(conv2_bias.name);

    conv2_weight.name = "conv2_weight";
    conv2_weight.maxlength = CONV2_WEIGHT_ARRSIZE;
    conv2_weight.channel = 32;
    conv2_weight.num_kernels = 64;
    modelmym_init(conv2_weight.name);

    /* Convolution layer 3 */
    conv3_bias.name = "conv3_bias";
    conv3_bias.maxlength = CONV3_BIAS_ARRSIZE;
    modelmym_init(conv3_bias.name);

    conv3_weight.name = "conv3_weight";
    conv3_weight.maxlength = CONV3_WEIGHT_ARRSIZE;
    conv3_weight.channel = 64;
    conv3_weight.num_kernels = 128;
    modelmym_init(conv3_weight.name);

    /* Fully connected layers */
    fc1_bias.name = "fc1_bias";
    fc1_bias.maxlength = FC1_BIAS_ARRSIZE;
    modelmym_init(fc1_bias.name);

    fc1_weight.name = "fc1_weight";
    fc1_weight.maxlength = FC1_WEIGHT_ARRSIZE;
    modelmym_init(fc1_weight.name);

    fc2_bias.name = "fc2_bias";
    fc2_bias.maxlength = FC2_BIAS_ARRSIZE;
    modelmym_init(fc2_bias.name);

    fc2_weight.name = "fc2_weight";
    fc2_weight.maxlength = FC2_WEIGHT_ARRSIZE;
    modelmym_init(fc2_weight.name);

    /* Dataset buffer; the dataset name defaults to "data". */
    data.name = "data";
    data.maxlength = DATA_ARRSIZE;
    data.dname = "data";
    modelmym_init(data.name);
}