搜索
bottom↓
回复: 13

分享一个简单BP神经网络的C源码

[复制链接]

出0入0汤圆

发表于 2016-5-20 20:32:28 | 显示全部楼层 |阅读模式
#include "BpNet.h"  
  
#ifdef _DEBUG  
#define new DEBUG_NEW  
#undef THIS_FILE  
static char THIS_FILE[] = __FILE__;  
#endif  
  
/////////////////////////////////////////////////////////////////////////////  
// CBpNet  
IMPLEMENT_SERIAL( CBpNet, CObject, 1 )  
  
CBpNet::CBpNet()  
{initM(MATCOM_VERSION);//启用矩阵运算库  
}  
  
CBpNet::~CBpNet()  
{exitM();  
delete this;  
}  
  
  
  
  
/////////////////////////////////////////////////////////////////////////////  
// CBpNet message handlers  
//创建新网络  
void CBpNet::Create(Mm mInputData, Mm mTarget, int iInput, int iHidden, int iOutput)  
{ int i,j;  
  mSampleInput=zeros(mInput.rows(),mInput.cols());  
  mSampleTarget=zeros(mTarget.rows(),mTarget.cols());   
  mSampleInput=mInputData;  
  mSampleTarget=mTarget;  
  this->iInput=iInput;  
  this->iHidden=iHidden;  
  this->iOutput=iOutput;  
  //创建计算用的单个样本矩阵  
  mInput=zeros(1,this->iInput);  
  mHidden=zeros(1,this->iHidden);  
  mOutput=zeros(1,this->iOutput);  
  //创建权重矩阵,并赋初值  
  mWeighti=zeros(this->iInput,this->iHidden);  
  mWeighto=zeros(this->iHidden,this->iOutput);   
  //赋初值  
  for(i=1;iiInput;i++)  
      for(j=1;jiHidden;j++)  
          mWeighti.r(i,j)=randab(-1.0,1.0);  
  for(i=1;iiHidden;i++)  
      for(j=1;jiOutput;j++)  
          mWeighto.r(i,j)=randab(-1.0,1.0);  
   
  //创建阙值矩阵,并赋值  
  mThresholdi=zeros(1,this->iHidden);  
  for(i=1;iiHidden;i++)  
      mThresholdi.r(i)=randab(-1.0,1.0);  
  mThresholdo=zeros(1,this->iOutput);  
  for(i=1;iiOutput;i++)  
      mThresholdo.r(i)=randab(-1.0,1.0);  
  //创建权重变化矩阵  
  mChangei=zeros(this->iInput,this->iHidden);  
  mChangeo=zeros(this->iHidden,this->iOutput);  
   
  mInputNormFactor=zeros(iInput,2);  
  mTargetNormFactor=zeros(iOutput,2);  
  //误差矩阵  
   mOutputDeltas=zeros(iOutput);  
   mHiddenDeltas=zeros(iHidden);  
  //学习速率赋值  
  dblLearnRate1=0.5;  
  dblLearnRate2=0.5;  
  dblMomentumFactor=0.95;  
   
  m_isOK=false;  
  m_IsStop=false;  
  dblMse=1.0e-6;//误差限  
  dblError=1.0;  
  lEpochs=0;  
  
}  
//根据已有的网络进行预测  
Mm CBpNet::simulate(Mm mData)  
{int i,j;  
Mm mResult;  
Mm data=zeros(mData.rows(),mData.cols());  
data=mData;  
if(mData.cols()!=iInput)  
{::MessageBox(NULL,"输入数据变量个数错误!","输入数据变量个数错误!",MB_OK);  
  return mResult;  
}  
mResult=zeros(data.rows(),iOutput);   
//正规化数据  
for(i=1;i                       for(j=1;j                           data.r(i,j)=(data.r(i,j)-mInputNormFactor.r(j,1))/(mInputNormFactor.r(j,2)-mInputNormFactor.r(j,1));  
//计算  
     int iSample;  
     Mm mInputdata,mHiddendata,mOutputdata;  
     mInputdata=zeros(1,iInput);  
     mHiddendata=zeros(1,iHidden);  
     mOutputdata=zeros(1,iOutput);  
     double sum=0.0;  
   for(iSample=1;iSample                         //输入层数据  
        for(i=1;i                             mInputdata.r(i)=data.r(iSample,i);  
     //隐层数据   
        for(j=1;j                             sum=0.0;  
         for(i=1;i                              sum+=mInputdata.r(i)*mWeighti.r(i,j);  
         sum-=mThresholdi.r(j);   
         mHiddendata.r(j)=1.0/(1.0+exp(-sum));  
        }  
      
  //输出数据  
    for(j=1;j                           sum=0.0;  
        for(i=1;i                               sum+=mHiddendata.r(i)*mWeighto.r(i,j);  
        sum-=mThresholdo.r(j);   
        mOutputdata.r(j)=1.0/(1.0+exp(-sum));  
    }  
      
    //转换  
    for(j=1;j                           mResult.r(iSample,j)=mOutputdata.r(j)*(mTargetNormFactor.r(j,2)-mTargetNormFactor.r(j,1))+mTargetNormFactor.r(j,1);  
}  
   
return (mResult);  
}  
  
void CBpNet::LoadBpNet(CString &strNetName)  
{CFile file;  
   
if(file.Open(strNetName,CFile::modeRead)==0)     
{MessageBox(NULL,"无法打开文件!","错误",MB_OK);  
  return;  
}  
else{  
     CArchive myar(&file,CArchive::load);  
     Serialize(myar);   
     myar.Close();  
}  
file.Close();   
}  
  
bool CBpNet::SaveBpNet(CString &strNetName)  
{CFile file;  
if(strNetName.GetLength()==0)  
     return(false);  
if(file.Open(strNetName,CFile::modeCreate|CFile::modeWrite)==0)     
{MessageBox(NULL,"无法创建文件!","错误",MB_OK);  
  return(false);  
}  
else{  
     CArchive myar(&file,CArchive::store);  
     Serialize(myar);   
     myar.Close();  
}  
file.Close();   
return(true);  
}  
//网络学习  
void CBpNet::learn()  
{ int iSample=1;  
  double dblTotal;  
  MSG msg;  
  if(m_IsStop)  
      m_IsStop=false;  
  //数据正规化处理  
  normalize();  
  
  while(dblError>dblMse&&!m_IsStop){  
   dblTotal=0.0;  
   for(iSample=1;iSample                        forward(iSample);  
    backward(iSample);  
    dblTotal+=dblErr;//总误差  
  }  
   if(dblTotal/dblError>1.04){//动态改变学习速率  
       dblLearnRate1*=0.7;  
       dblLearnRate2*=0.7;  
   }  
   else{  
       dblLearnRate1*=1.05;  
       dblLearnRate2*=1.05;  
   }  
   lEpochs++;  
   dblError=dblTotal;  
  
   ::PeekMessage(&msg,NULL,0,0,PM_REMOVE);  
   ::DispatchMessage(&msg);  
     msg.message=-1;  
   ::DispatchMessage(&msg);//这样可以消除屏闪和假死机  
  }  
if(dblError                  m_isOK=true;  
else  
m_isOK=false;  
  
}  
  
void CBpNet::stop()  
{  
m_IsStop=true;  
}  
  
double CBpNet::randab(double a, double b)  
{ //注意,如果应用矩阵库,头文件matlib.h对rand()函数重新定义,只产生(0,1)  
  //之间的随机数  
  return((b-a)*rand()+a);  
}  
  
//将数据转化到(0,1)区间  
void CBpNet::normalize()  
{  
   
int i,j;  
  //输入数据范围  
  mInputNormFactor=scope(mSampleInput);  
  //目标数据范围  
  mTargetNormFactor=scope(mSampleTarget);  
  
for(i=1;i                       for(j=1;j                           mSampleInput.r(i,j)=(mSampleInput.r(i,j)-mInputNormFactor.r(j,1))/(mInputNormFactor.r(j,2)-mInputNormFactor.r(j,1));  
   
for(i=1;i                       for(j=1;j                           mSampleTarget.r(i,j)=(mSampleTarget.r(i,j)-mTargetNormFactor.r(j,1))/(mTargetNormFactor.r(j,2)-mTargetNormFactor.r(j,1));  
   
}  
  
//前向计算   
void CBpNet::forward(int iSample)  
{//根据第iSample个样本,前向计算  
    if(iSamplemSampleInput.rows()){  
        MessageBox(NULL,"无此样本数据:索引出界!","无此样本数据:索引出界!",MB_OK);  
        return;  
    }  
    int i,j;  
    double sum=0.0;  
      
    //输入层数据  
    for(i=1;i                           mInput.r(i)=mSampleInput.r(iSample,i);  
      
    //隐层数据  
      
    for(j=1;j                           sum=0.0;  
        for(i=1;i                               sum+=mInput.r(i)*mWeighti.r(i,j);  
         
        sum-=mThresholdi.r(j);   
        mHidden.r(j)=1.0/(1.0+exp(-sum));  
    }  
         
    //输出数据  
    for(j=1;j                           sum=0.0;  
        for(i=1;i                               sum+=mHidden.r(i)*mWeighto.r(i,j);  
        sum-=mThresholdo.r(j);   
        mOutput.r(j)=1.0/(1.0+exp(-sum));  
    }  
      
}  
  
//后向反馈  
void CBpNet::backward(int iSample)  
{  
    if(iSamplemSampleInput.rows()){  
        MessageBox(NULL,"无此样本数据:索引出界!","无此样本数据:索引出界!",MB_OK);  
        return;  
    }  
    int i,j;  
      
    //输出误差  
    for(i=1;i                           mOutputDeltas.r(i)=mOutput.r(i)*(1-mOutput.r(i))*(mSampleTarget.r(iSample,i)-mOutput.r(i));  
      
    //隐层误差  
    double sum=0.0;  
    for(j=1;j                           sum=0.0;  
        for(i=1;i                               sum+=mOutputDeltas.r(i)*mWeighto.r(j,i);  
        mHiddenDeltas.r(j)=mHidden.r(j)*(1-mHidden.r(j))*sum;  
    }  
    //更新隐层-输出权重  
      
    double dblChange;  
    for(j=1;j                           for(i=1;i                               dblChange=mOutputDeltas.r(i)*mHidden.r(j);  
            mWeighto.r(j,i)=mWeighto.r(j,i)+dblLearnRate2*dblChange+dblMomentumFactor*mChangeo.r(j,i);  
            mChangeo.r(j,i)=dblChange;  
        }  
      
    //更新输入-隐层权重  
    for(i=1;i                           for(j=1;j                               dblChange=mHiddenDeltas.r(j)*mInput.r(i);  
            mWeighti.r(i,j)=mWeighti.r(i,j)+dblLearnRate1*dblChange+dblMomentumFactor*mChangei.r(i,j);   
            mChangei.r(i,j)=dblChange;  
        }  
    //修改阙值  
    for(j=1;j                           mThresholdo.r(j)-=dblLearnRate2*mOutputDeltas.r(j);   
    for(i=1;i                           mThresholdi.r(i)-=dblLearnRate1*mHiddenDeltas.r(i);   
    //计算误差  
    dblErr=0.0;  
    for(i=1;i                           dblErr+=0.5*(mSampleTarget.r(iSample,i)-mOutput.r(i))*(mSampleTarget.r(iSample,i)-mOutput.r(i));  
      
     
}  
  
//求数据列的范围  
Mm CBpNet::scope(Mm mData)  
{Mm mScope;  
mScope=zeros(mData.cols(),2);  
double  min,max;  
for(int i=1;i                       min=max=mData.r(1,i);   
     for(int j=1;j                           if(mData.r(j,i)>=max)  
             max=mData.r(j,i);  
         if(mData.r(j,i)                                 min=mData.r(j,i);  
     }  
     if(min==max)  
         min=0.0;  
     mScope.r(i,1)=min;  
     mScope.r(i,2)=max;  
}  
return(mScope);  
  
}  
  
//显示矩阵数据,方便调试  
void CBpNet::display(Mm data)  
{CString strData,strTemp;  
int i=1,j=1;  
for(i=1;i                       for(j=1;j                       strTemp.Format("%.3f ",data.r(i,j));  
     strData+=strTemp;  
     }  
     strData=strData+"\r\n";  
}  
::MessageBox(NULL,strData,"",MB_OK);  
  
}  
  
void CBpNet::Serialize(CArchive &ar)  
{CObject::Serialize(ar);  
/////////////////////////////////////   
if(ar.IsStoring()){  
     int i,j;  
     double dblData;  
     CString strTemp="Bp";  
     ar                     //纪录神经元个数  
     ar                     //纪录权值  
     for(i=1;i                           for(j=1;j                               dblData=mWeighti.r(i,j);  
             ar                          }  
     for(i=1;i                           for(j=1;j                               dblData=mWeighto.r(i,j);  
             ar                          }  
      
      //记录权值变化  
     for(j=1;j                          for(i=1;i                               ar                        
    //输入-隐层权重变化  
     for(i=1;i                          for(j=1;j                               ar                           
      //纪录阙值  
     for(i=1;i                                dblData=mThresholdi.r(i);  
              ar                                  }  
     for(i=1;i                                dblData=mThresholdo.r(i);  
              ar                                  }  
     //纪录输入输出的极值  
     for(i=1;i                            dblData=mInputNormFactor.r(i,1);  
          ar                              dblData=mInputNormFactor.r(i,2);  
          ar                             }  
     for(i=1;i                           {dblData=mTargetNormFactor.r(i,1);  
          ar                              dblData=mTargetNormFactor.r(i,2);  
          ar                             }  
     //误差范围  
     ar                      //学习速率  
     ar                        
}  
           
else{  
     int i,j;  
     CString strTemp="";  
     double dblTemp;  
     ar>>strTemp;//读入标志  
    //读入神经元个数  
     ar>>iInput>>iHidden>>iOutput;  
     mChangei=zeros(iInput,iHidden);  
     mChangeo=zeros(iHidden,iOutput);  
     mWeighti=zeros(iInput,iHidden);  
     mWeighto=zeros(iHidden,iOutput);  
    //读入权值  
     for(i=1;i                            for(j=1;j                           { ar>>dblTemp;  
            mWeighti.r(i,j)=dblTemp;  
          }  
            
     for(i=1;i                           for(j=1;j                           { ar>>dblTemp;  
           mWeighto.r(i,j)=dblTemp;  
         }  
      
     //读入权值变化  
     for(j=1;j                          for(i=1;i                               ar>>mChangeo.r(j,i);  
      
     //输入-隐层权重  
     for(i=1;i                          for(j=1;j                               ar>>mChangei.r(i,j);  
      
     //读入阙值  
     mThresholdi=zeros(1,iHidden);  
     for(i=1;i                       {ar>>dblTemp;  
      mThresholdi.r(i)=dblTemp;  
     }  
     mThresholdo=zeros(1,iOutput);  
     for(i=1;i                       {ar>>dblTemp;  
      mThresholdo.r(i)=dblTemp;  
     }  
     //读入输入输出的极值  
     mInputNormFactor=zeros(iInput,2);  
     for(i=1;i                            ar>>dblTemp;  
          mInputNormFactor.r(i,1)=dblTemp; //极小值  
          ar>>dblTemp;  
          mInputNormFactor.r(i,2)=dblTemp; //极大值  
         }  
     mTargetNormFactor=zeros(iOutput,2);  
     for(i=1;i                           {ar>>dblTemp;  
          mTargetNormFactor.r(i,1)=dblTemp; //输出数据极小值  
          ar>>dblTemp;  
          mTargetNormFactor.r(i,2)=dblTemp;   
         }  
     //读入误差范围  
     ar>>dblMse;   
     //读入学习速率  
     ar>>dblLearnRate1>>dblLearnRate2;  
      
  
     //创建计算用的单个样本矩阵  
     mInput=zeros(1,iInput);  
     mHidden=zeros(1,iHidden);  
     mOutput=zeros(1,iOutput);  
     //误差矩阵  
     mOutputDeltas=zeros(iOutput);  
     mHiddenDeltas=zeros(iHidden);  
      
}  
  
}  
  
//如果不是新网络,比如从文件恢复的网络,调用此函数构建学习样本  
void CBpNet::LoadPattern(Mm mIn, Mm mOut)  
{ if(mIn.cols()!=iInput||mOut.cols()!=iOutput){  
    ::MessageBox( NULL,"学习样本格式错误!","错误",MB_OK);  
    return;  
}  
  mSampleInput=zeros(mIn.rows(),mIn.cols());  
  mSampleTarget=zeros(mOut.rows(),mOut.cols());   
  mSampleInput=mIn;  
  mSampleTarget=mOut;  
      
  m_isOK=false;  
  m_IsStop=false;  
  lEpochs=0;  
  dblMomentumFactor=0.95;  
  dblError=1.0;  
}  

阿莫论坛20周年了!感谢大家的支持与爱护!!

如果天空是黑暗的,那就摸黑生存;
如果发出声音是危险的,那就保持沉默;
如果自觉无力发光,那就蜷伏于牆角。
但是,不要习惯了黑暗就为黑暗辩护;
也不要为自己的苟且而得意;
不要嘲讽那些比自己更勇敢的人。
我们可以卑微如尘土,但不可扭曲如蛆虫。

出0入17汤圆

发表于 2016-5-21 17:32:49 | 显示全部楼层
亲爱的楼主,给大伙科普一下呗

出675入8汤圆

发表于 2016-11-28 22:54:47 | 显示全部楼层
感兴趣,有空研究一下

出0入53汤圆

发表于 2017-9-19 09:08:41 | 显示全部楼层
lz给大家科普一下吧

出0入0汤圆

发表于 2017-9-19 10:01:45 | 显示全部楼层
这明显是C++嘛

出0入0汤圆

发表于 2017-9-19 11:49:23 来自手机 | 显示全部楼层
没有支持gpu没有用处

出0入0汤圆

发表于 2017-9-21 16:52:11 | 显示全部楼层
做个CUDA的吧

出0入0汤圆

发表于 2018-8-25 22:27:33 来自手机 | 显示全部楼层
太高端了

出0入0汤圆

发表于 2018-8-26 09:10:54 来自手机 | 显示全部楼层
哥这是c++

出0入0汤圆

发表于 2018-8-26 09:17:30 来自手机 | 显示全部楼层
需要用cuda才能快速计算。

出0入0汤圆

发表于 2018-8-26 11:35:31 来自手机 | 显示全部楼层
太复杂了,我就看看好了

出0入0汤圆

发表于 2018-12-11 15:20:10 | 显示全部楼层
这个太暴力了

出0入0汤圆

发表于 2022-7-28 18:51:42 | 显示全部楼层
楼主能不能上个注解版的
回帖提示: 反政府言论将被立即封锁ID 在按“提交”前,请自问一下:我这样表达会给举报吗,会给自己惹麻烦吗? 另外:尽量不要使用Mark、顶等没有意义的回复。不得大量使用大字体和彩色字。【本论坛不允许直接上传手机拍摄图片,浪费大家下载带宽和论坛服务器空间,请压缩后(图片小于1兆)才上传。压缩方法可以在微信里面发给自己(不要勾选“原图),然后下载,就能得到压缩后的图片。注意:要连续压缩2次才能满足要求!!】。另外,手机版只能上传图片,要上传附件需要切换到电脑版(不需要使用电脑,手机上切换到电脑版就行,页面底部)。
您需要登录后才可以回帖 登录 | 注册

本版积分规则

手机版|Archiver|amobbs.com 阿莫电子技术论坛 ( 粤ICP备2022115958号, 版权所有:东莞阿莫电子贸易商行 创办于2004年 (公安交互式论坛备案:44190002001997 ) )

GMT+8, 2024-9-27 06:58

© Since 2004 www.amobbs.com, 原www.ourdev.cn, 原www.ouravr.com

快速回复 返回顶部 返回列表