Files
stm32-cnn/PORTING/CNN/cnn_model.c
2024-12-19 14:06:05 +08:00

412 lines
15 KiB
C
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "cnn_model.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Parameter slots of the CNN, one Model per tensor loaded from the SD card.
 * Each weight/bias slot's `array` starts out NULL (static storage) and is
 * allocated from the SRAMEX heap by modelmym_init(); model_write() fills it
 * from dataset/<name>.txt.
 */
Model conv1_bias;
Model conv1_weight;
Model conv2_bias;
Model conv2_weight;
Model conv3_bias;
Model conv3_weight;
Model fc1_bias;
Model fc1_weight;
Model fc2_bias;
Model fc2_weight;
/* Input sample slot; model_init() points its array at data_array below. */
Model data;
/*
 * Fixed backing store for `data`, placed at absolute address 0XC0009C40 via the
 * toolchain-specific `at` attribute (presumably Keil ARM Compiler - TODO confirm).
 * NOTE(review): assumed to lie in external SDRAM; confirm this region does not
 * overlap the SRAMEX heap that mymalloc() hands out for the other slots.
 */
float data_array[DATA_ARRSIZE] __attribute__((at(0XC0009C40)));
float* modelmym_init(char* model_name){
if(conv1_bias.array == NULL && strcmp(model_name, "conv1_bias") == 0)
return conv1_bias.array = (float*)mymalloc(SRAMEX, CONV1_BIAS_ARRSIZE * sizeof(float));
else if(conv1_weight.array == NULL && strcmp(model_name, "conv1_weight") == 0)
return conv1_weight.array = (float*)mymalloc(SRAMEX, CONV1_WEIGHT_ARRSIZE * sizeof(float));
else if(conv2_bias.array == NULL && strcmp(model_name, "conv2_bias") == 0)
return conv2_bias.array = (float*)mymalloc(SRAMEX, CONV2_BIAS_ARRSIZE * sizeof(float));
else if(conv2_weight.array == NULL && strcmp(model_name, "conv2_weight") == 0)
return conv2_weight.array = (float*)mymalloc(SRAMEX, CONV2_WEIGHT_ARRSIZE * sizeof(float));
else if(conv3_bias.array == NULL && strcmp(model_name, "conv3_bias") == 0)
return conv3_bias.array = (float*)mymalloc(SRAMEX, CONV3_BIAS_ARRSIZE * sizeof(float));
else if(conv3_weight.array == NULL && strcmp(model_name, "conv3_weight") == 0)
return conv3_weight.array = (float*)mymalloc(SRAMEX, CONV3_WEIGHT_ARRSIZE * sizeof(float));
else if(fc1_bias.array == NULL && strcmp(model_name, "fc1_bias") == 0)
return fc1_bias.array = (float*)mymalloc(SRAMEX, FC1_BIAS_ARRSIZE * sizeof(float));
else if(fc1_weight.array == NULL && strcmp(model_name, "fc1_weight") == 0)
return fc1_weight.array = (float*)mymalloc(SRAMEX, FC1_WEIGHT_ARRSIZE * sizeof(float));
else if(fc2_bias.array == NULL && strcmp(model_name, "fc2_bias") == 0)
return fc2_bias.array = (float*)mymalloc(SRAMEX, FC2_BIAS_ARRSIZE * sizeof(float));
else if(fc2_weight.array == NULL && strcmp(model_name, "fc2_weight") == 0)
return fc2_weight.array = (float*)mymalloc(SRAMEX, FC2_WEIGHT_ARRSIZE * sizeof(float));
// else if(data.array == NULL && strcmp(model_name, "data") == 0)
// return data.array = (float*)mymalloc(SRAMEX, DATA_ARRSIZE * sizeof(float));
else if(strcmp(model_name, "all") == 0){
if(conv1_bias.array == NULL)conv1_bias.array = (float*)mymalloc(SRAMEX, CONV1_BIAS_ARRSIZE * sizeof(float));
if(conv1_weight.array == NULL)conv1_weight.array = (float*)mymalloc(SRAMEX, CONV1_WEIGHT_ARRSIZE * sizeof(float));
if(conv2_bias.array == NULL)conv2_bias.array = (float*)mymalloc(SRAMEX, CONV2_BIAS_ARRSIZE * sizeof(float));
if(conv2_weight.array == NULL)conv2_weight.array = (float*)mymalloc(SRAMEX, CONV2_WEIGHT_ARRSIZE * sizeof(float));
if(conv3_bias.array == NULL)conv3_bias.array = (float*)mymalloc(SRAMEX, CONV3_BIAS_ARRSIZE * sizeof(float));
if(conv3_weight.array == NULL)conv3_weight.array = (float*)mymalloc(SRAMEX, CONV3_WEIGHT_ARRSIZE * sizeof(float));
if(fc1_bias.array == NULL)fc1_bias.array = (float*)mymalloc(SRAMEX, FC1_BIAS_ARRSIZE * sizeof(float));
if(fc1_weight.array == NULL)fc1_weight.array = (float*)mymalloc(SRAMEX, FC1_WEIGHT_ARRSIZE * sizeof(float));
if(fc2_bias.array == NULL)fc2_bias.array = (float*)mymalloc(SRAMEX, FC2_BIAS_ARRSIZE * sizeof(float));
if(fc2_weight.array == NULL)fc2_weight.array = (float*)mymalloc(SRAMEX, FC2_WEIGHT_ARRSIZE * sizeof(float));
// if(data.array == NULL)data.array = (float*)mymalloc(SRAMEX, DATA_ARRSIZE * sizeof(float));
}
return NULL;
}
u8 modelmym_free(char* model_name){
if(conv1_bias.array != NULL && strcmp(model_name, "conv1_bias") == 0){
myfree(SRAMEX,conv1_bias.array);
conv1_bias.array = NULL;
conv1_bias.realength = 0;
return 1;
}
else if(conv1_weight.array != NULL && strcmp(model_name, "conv1_weight") == 0){
myfree(SRAMEX,conv1_weight.array);
conv1_weight.array = NULL;
conv1_weight.realength = 0;
return 1;
}
else if(conv2_bias.array != NULL && strcmp(model_name, "conv2_bias") == 0){
myfree(SRAMEX,conv2_bias.array);
conv2_bias.array = NULL;
conv2_bias.realength = 0;
return 1;
}
else if(conv2_weight.array != NULL && strcmp(model_name, "conv2_weight") == 0){
myfree(SRAMEX,conv2_weight.array);
conv2_weight.array = NULL;
conv2_weight.realength = 0;
return 1;
}
else if(conv3_bias.array != NULL && strcmp(model_name, "conv3_bias") == 0){
myfree(SRAMEX,conv3_bias.array);
conv3_bias.array = NULL;
conv3_bias.realength = 0;
return 1;
}
else if(conv3_weight.array != NULL && strcmp(model_name, "conv3_weight") == 0){
myfree(SRAMEX,conv3_weight.array);
conv3_weight.array = NULL;
conv3_weight.realength = 0;
return 1;
}
else if(fc1_bias.array != NULL && strcmp(model_name, "fc1_bias") == 0){
myfree(SRAMEX,fc1_bias.array);
fc1_bias.array = NULL;
fc1_bias.realength = 0;
return 1;
}
else if(fc1_weight.array != NULL && strcmp(model_name, "fc1_weight") == 0){
myfree(SRAMEX,fc1_weight.array);
fc1_weight.array = NULL;
fc1_weight.realength = 0;
return 1;
}
else if(fc2_bias.array != NULL && strcmp(model_name, "fc2_bias") == 0){
myfree(SRAMEX,fc2_bias.array);
fc2_bias.array = NULL;
fc2_bias.realength = 0;
return 1;
}
else if(fc2_weight.array != NULL && strcmp(model_name, "fc2_weight") == 0){
myfree(SRAMEX,fc2_weight.array);
fc2_weight.array = NULL;
fc2_weight.realength = 0;
return 1;
}
else if(data.array != NULL && strcmp(model_name, "data") == 0){
// myfree(SRAMEX,data.array);
memset(data.array, 0 ,data.maxlength);
// data.realength = 0;
return 1;
}
else if(strcmp(model_name, "all") == 0){
modelmym_free("conv1_bias");
modelmym_free("conv1_weight");
modelmym_free("conv2_bias");
modelmym_free("conv2_weight");
modelmym_free("conv3_bias");
modelmym_free("conv3_weight");
modelmym_free("fc1_bias");
modelmym_free("fc1_weight");
modelmym_free("fc2_bias");
modelmym_free("fc2_weight");
modelmym_free("data");
return 2;
}
return 0;
}
u8 model_write(char* model_name)
{
if(strcmp(model_name, "all") == 0){
model_write("conv1_bias");
model_write("conv1_weight");
model_write("conv2_bias");
model_write("conv2_weight");
model_write("conv3_bias");
model_write("conv3_weight");
model_write("fc1_bias");
model_write("fc1_weight");
model_write("fc2_bias");
model_write("fc2_weight");
model_info("all");
SDRAM_USED();
}else{
u8 res=0;
u8 isneg=0;
u8 _times=0;
u32 _larr = 0;
u8 _len = strlen(model_name);
char _path[_len+1+7+35];
char _datapath[_len+1+7+35];
char _fstr[READLENGTH+1] = {0};
char _sstr[2];
int progress;
Model *_model = model(model_name);
if(_model == NULL || strcmp(model_name, "data") == 0){
sprintf(_path, "dataset/_data/%s.txt", model_name);
if(f_open(file, (const TCHAR *)_path, 1)){
DEBUG_PRINTF("\r\n输入了一个无效的模型或Data数据集的名字\r\n");
return 199;
}else{
_model = model("data");
_model -> dname = model_name;
}
}
if(_model -> dname == NULL){
sprintf(_path, "dataset/%s.txt", _model -> name);
if(f_open(file, (const TCHAR *)_path, 1)){
DEBUG_PRINTF("预设里没有这个模型:[%s]\r\n", _path);
return 4;
}
}
if(_model -> array == NULL && modelmym_init(_model -> name) == NULL){
DEBUG_PRINTF("无法创建模型参数[%s]的数组到SDRAM里\r\n", _model -> name);
return 200;
}
if(_model -> dname)sprintf(_datapath, "_data/%s", _model -> dname);
sprintf(_path, "dataset/%s.txt", _model -> dname ? _datapath : _model -> name);
if(f_open(file, (const TCHAR *)_path, 1)){
DEBUG_PRINTF("文件[%s]无法打开\r\n", _model -> dname ? _model -> dname : _model -> name);
return 199;
}
DEBUG_PRINTF("写入的模型参数名字是:%s\r\n", _model -> name);
if(_model -> dname)DEBUG_PRINTF("写入的Data数据集是%s\r\n", _model -> dname);
DEBUG_PRINTF("写入模型参数数组的最大长度为:%d\r\n", _model -> maxlength);
DEBUG_PRINTF("目前数组存活的元素数量为:%d", _model -> realength);
printf("\r\n正在写入模型参数'%s',请稍后......\r\n",_model -> dname ? _model -> dname : _model -> name);
while(1){
res = f_read(file, fatbuf, READLENGTH ,&br);
if(res){
DEBUG_PRINTF("读文件出错,错误码为:%d\r\n",res);
return res;
}else{
for(int i=0; i < br; i++){
if(fatbuf[i] == 0x0d){
float _fvalue = atof(_fstr);
if(isneg)_fvalue = -_fvalue;
_model -> array[_larr++] = _fvalue;
//DEBUG_PRINTF("回车[%d] 单行数据是[string]: %s\r\n回车[%d] 单行数据是[float]: %f\r\n",i,_fstr,_fvalue);
i++;
isneg=0;
*_fstr = NULL;
_model -> realength = _larr; //_larr最大值为模型最大长度
if(_larr >= _model -> maxlength)break;
}
else if(fatbuf[i] == 0x2d)isneg = 1;
else{
sprintf(_sstr, "%c", fatbuf[i]);
strcat(_fstr, _sstr);
//DEBUG_PRINTF("[%d]_fstr is[%s], _sstr is[%s], fatbuf is [%c]\r\n",i,_fstr,_sstr,fatbuf[i]);
}
}
if(_model -> maxlength >= 73728 && (_larr >= (_model -> maxlength/10)*_times)){
progress = _larr >= _model -> maxlength ? 100 : _times++ == 0 ? 0 : progress + 10;
DEBUG_PRINTF("\r\n[");
for(u16 j=0; j<50;j++){
if(j < progress/2) DEBUG_PRINTF("=");
else DEBUG_PRINTF(" ");
}
DEBUG_PRINTF("] %d%%", progress);
}
if(_larr >= _model -> maxlength)break;
}
}
DEBUG_PRINTF("\r\n模型参数[%s]已写入到内存中! 模型长度为 %d\r\n",_model -> dname ? _model -> dname : _model -> name,_model -> realength);
return 0;
}
return 1;
}
/*
 * Print slot values array[start .. end] with stride `gap` to the console.
 *
 * Indices past realength are clamped to the last stored element (kept from
 * the original iteration scheme). A gap of 0 is treated as 1 (it previously
 * divided by zero and looped forever). "all" prints every slot and returns 0.
 *
 * Returns 1 after printing. Returns 0 for an unknown name, end == 0,
 * start > end, or end > realength.
 */
u8 model_read(char* model_name, u32 start, u32 end, u32 gap){
    if (strcmp(model_name, "all") == 0) {
        model_read("conv1_bias", start, end, gap);
        model_read("conv1_weight", start, end, gap);
        model_read("conv2_bias", start, end, gap);
        model_read("conv2_weight", start, end, gap);
        model_read("conv3_bias", start, end, gap);
        model_read("conv3_weight", start, end, gap);
        model_read("fc1_bias", start, end, gap);
        model_read("fc1_weight", start, end, gap);
        model_read("fc2_bias", start, end, gap);
        model_read("fc2_weight", start, end, gap);
        model_read("data", start, end, gap);
        return 0;
    }

    Model *slot = model(model_name);
    /* start > end would underflow end - start into a near-endless loop. */
    if (slot == NULL || end == 0 || start > end || end > slot->realength) {
        return 0;
    }
    if (gap == 0) {
        gap = 1;
    }

    /* end <= realength here, so this matches the original bound expression. */
    u32 span = end - start;
    u32 limit = span + (span % gap ? 2 : 1);
    for (u32 i = 0; i < limit; i += gap) {
        u32 idx = (i + start) < slot->realength ? i + start : slot->realength - 1;
        printf("\r\n%s_floatArray[%d]: %f", slot->name, idx, slot->array[idx]);
    }
    printf("\r\n");
    return 1;
}
/*
 * Replace the contents of the `data` slot with dataset/_data/<data_name>.txt.
 *
 * The file's existence is checked (and the probe handle closed) before the
 * current Data set is cleared, so a typo no longer wipes the loaded data.
 * Returns 1 on success, 0 if the file is missing or loading fails.
 *
 * NOTE(review): model_write() dispatches on the name, so a Data file named
 * like a parameter slot (e.g. "conv1_bias") would reload that slot instead.
 */
u8 model_switchdata(char* data_name){
    size_t len = strlen(data_name);
    char path[len + 1 + 7 + 35];

    snprintf(path, sizeof path, "dataset/_data/%s.txt", data_name);
    if (f_open(file, (const TCHAR*)path, 1)) {
        DEBUG_PRINTF("\r\nData数据集[%s]不存在\r\n", data_name);
        return 0;
    }
    f_close(file);    /* model_write() reopens the same FIL object itself */

    if (data.array != NULL) {
        modelmym_free("data");
    }
    if (model_write(data_name)) {
        DEBUG_PRINTF("Data数据集[%s]切换失败!!\r\n", data_name);
        return 0;
    }
    DEBUG_PRINTF("Data数据集[%s]切换成功!\r\n", data_name);
    DEBUG_PRINTF("data_name的长度为%d\r\n_path的长度为%d\r\n_path为%s\r\n", (int)len, (int)sizeof(path), path);
    return 1;
}
/*
 * Print the name of the Data set currently held in the `data` slot
 * ("data" after model_init, or the last name loaded by model_write).
 */
void model_dataset(){
    const char *current = data.dname;
    printf("\r\ndataset is: %s\r\n", current);
}
/*
 * Print name, buffer address, capacity and fill level of one slot, or of
 * every slot for "all". The Data set name is printed by model_dataset().
 * Returns 1 when something was printed, 0 for an unknown name.
 */
u8 model_info(char* model_name){
    if (strcmp(model_name, "all") == 0) {
        model_info("conv1_bias");
        model_info("conv1_weight");
        model_info("conv2_bias");
        model_info("conv2_weight");
        model_info("conv3_bias");
        model_info("conv3_weight");
        model_info("fc1_bias");
        model_info("fc1_weight");
        model_info("fc2_bias");
        model_info("fc2_weight");
        model_info("data");
        return 1;
    }

    Model *slot = model(model_name);
    if (slot == NULL) {
        return 0;
    }
    printf("\r\nmodel.name is: %s\r\n", slot->name);
    /* %X needs an unsigned int; passing the pointer itself was undefined behaviour. */
    printf("model.array.address is: 0X%X\r\n", (unsigned int)(uintptr_t)slot->array);
    printf("model.maxlength is: %d\r\n", slot->maxlength);
    printf("model.realength is: %d\r\n", slot->realength);
    return 1;
}
void* model(char* model_name){
if(strcmp(model_name, "conv1_bias") == 0)return &conv1_bias;
else if(strcmp(model_name, "conv1_weight") == 0)return &conv1_weight;
else if(strcmp(model_name, "conv2_bias") == 0)return &conv2_bias;
else if(strcmp(model_name, "conv2_weight") == 0)return &conv2_weight;
else if(strcmp(model_name, "conv3_bias") == 0)return &conv3_bias;
else if(strcmp(model_name, "conv3_weight") == 0)return &conv3_weight;
else if(strcmp(model_name, "fc1_bias") == 0)return &fc1_bias;
else if(strcmp(model_name, "fc1_weight") == 0)return &fc1_weight;
else if(strcmp(model_name, "fc2_bias") == 0)return &fc2_bias;
else if(strcmp(model_name, "fc2_weight") == 0)return &fc2_weight;
else if(strcmp(model_name, "data") == 0)return &data;
return NULL;
}
/*
 * Set up every parameter slot: its name, its capacity in floats and, for the
 * convolution/FC weights, the channel / kernel-count fields (presumably read
 * by the inference code - confirm). Weight and bias buffers are allocated from
 * SRAMEX by modelmym_init(), which stores the buffer in the slot itself.
 * The Data slot uses the fixed data_array and is zero-filled.
 *
 * The return value of modelmym_init() is deliberately not assigned back to
 * `.array`. It can be NULL for an already-allocated slot, which would leak
 * the buffer and drop loaded parameters if model_init() ran twice.
 */
void model_init(){
    conv1_bias.name = "conv1_bias";
    conv1_bias.maxlength = CONV1_BIAS_ARRSIZE;
    modelmym_init(conv1_bias.name);

    conv1_weight.name = "conv1_weight";
    conv1_weight.maxlength = CONV1_WEIGHT_ARRSIZE;
    conv1_weight.channel = 1;
    conv1_weight.num_kernels = CONV1_BIAS_ARRSIZE;
    modelmym_init(conv1_weight.name);

    conv2_bias.name = "conv2_bias";
    conv2_bias.maxlength = CONV2_BIAS_ARRSIZE;
    modelmym_init(conv2_bias.name);

    conv2_weight.name = "conv2_weight";
    conv2_weight.maxlength = CONV2_WEIGHT_ARRSIZE;
    conv2_weight.channel = 32;
    conv2_weight.num_kernels = CONV2_BIAS_ARRSIZE;
    modelmym_init(conv2_weight.name);

    conv3_bias.name = "conv3_bias";
    conv3_bias.maxlength = CONV3_BIAS_ARRSIZE;
    modelmym_init(conv3_bias.name);

    conv3_weight.name = "conv3_weight";
    conv3_weight.maxlength = CONV3_WEIGHT_ARRSIZE;
    conv3_weight.channel = 64;
    conv3_weight.num_kernels = CONV3_BIAS_ARRSIZE;
    modelmym_init(conv3_weight.name);

    fc1_bias.name = "fc1_bias";
    fc1_bias.maxlength = FC1_BIAS_ARRSIZE;
    modelmym_init(fc1_bias.name);

    fc1_weight.name = "fc1_weight";
    fc1_weight.maxlength = FC1_WEIGHT_ARRSIZE;
    modelmym_init(fc1_weight.name);

    fc2_bias.name = "fc2_bias";
    fc2_bias.maxlength = FC2_BIAS_ARRSIZE;
    modelmym_init(fc2_bias.name);

    fc2_weight.name = "fc2_weight";
    fc2_weight.maxlength = FC2_WEIGHT_ARRSIZE;
    fc2_weight.num_kernels = FC2_BIAS_ARRSIZE;
    modelmym_init(fc2_weight.name);

    /* The Data slot is statically backed; it is never heap-allocated. */
    data.name = "data";
    data.array = data_array;
    data.maxlength = DATA_ARRSIZE;
    data.realength = DATA_ARRSIZE;
    data.dname = "data";
    memset(data.array, 0, sizeof(float) * DATA_ARRSIZE);
}