/* trainmeanfieldrbmlogistic.cpp - train rbmlogistic as a deterministic network Copyright (C) 2010 Rui Rodrigues This software is released under the terms of the GNU General Public License (http://www.gnu.org/copyleft/gpl.html). */ #include #include #include #include #include using namespace std; #include "netrbm.h" #include "netdimsandfilenames.h" #include #include #include #ifdef _OPENMP #include #endif void checkfstream(ofstream& file_io,const char* filename); //in checkfstream.cpp void checkfstream(ifstream& file_io,const char* filename); //in checkfstream.cpp void read_datafile(ifstream&in,netdimsandfilenames& A); //in netdimsandfilenames.cpp void logistic (gsl_matrix * m); //in down in this file void writegslmatriz(const char* filename,gsl_matrix*m); void readgslmatriz(const char* filename,gsl_matrix*&m); void readgslvector(const char* filename,gsl_vector*m); void writegslvector(const char* filename,gsl_vector*m); //in iogslvectormatrix.cpp void useblacklist(gsl_matrix * &inputdata,const char*blacklistfile); //in blacklist.cpp // ---------------------------------------------------------------------------------------- //-----------------------CONFIGURE----------------------------------------------------- const size_t batchsize=500; const size_t numepochs=100; const double epsilonweights=0.1; const double epsilonbias=0.1; const double momentum=0.2; const double weightscost=0.002; //----------------------------------------------------------------------------------- double compute_error_rbmlogistic(netrbm&, gsl_matrix * data); //down in this file void meanfieldtrainrbmlogistic(unsigned numepochs,gsl_vector *vectorweightsandbias, gsl_matrix *tdata,unsigned ninputs,unsigned nhidden); //down in this file const string start="start"; const string cont="cont"; const string blacklist_use="useblacklist"; int main(int argc, char ** argv){ try{ if(argc<3){ cout<<" must be called with argument signal1 and after folder name. Optionaly thereis an extra argument: useblacklist !"<size2!=ninputs){ cout<<"inputdata is not compatible with ninputs!"<size1; //gsl random number generator gsl_rng *r = gsl_rng_alloc(gsl_rng_taus2); unsigned long seed=time (NULL) * getpid(); gsl_rng_set(r,seed); //load weights (first matrix with pure weights then visible bias finally //hidden bias gsl_vector * vectorweightsandbias=gsl_vector_calloc (ninputs*nhidden+ninputs+nhidden); //load weights from file readgslvector(A1.netlogisticweights.c_str(),vectorweightsandbias); #ifdef _OPENMP int maxnumthreads = omp_get_max_threads(); #else /* _OPENMP */ int maxnumthreads = 4; #endif /* _OPENMP */ unsigned blocksize=npatches/maxnumthreads;//just fo computing error //computer error rate before training { netrbm net1(ninputs,nhidden, npatches, maxnumthreads, blocksize, npatches, vectorweightsandbias); double error=compute_error_rbmlogistic(net1,data); cout<<"error rate by patch before training is "<