
192
|
第五章:打造深度網路
// 載入訓練資料:
RecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new File(filenameTrain)));
DataSetIterator trainIter = new RecordReaderDataSetIterator(rr,batchSize,0,2);
我們會在迭代器中告訴資料讀取器,我們的資料共有幾個欄位,其中哪個欄位代表的是
標籤。
確定網路的架構
我們需要的是一個基本的多層感知器,而且我們想要用
MultiLayerConfiguration
這個物件
來進行設定(它也可以用來設定任何 DL4J 網路架構),做法如下:
//log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.iterations(1)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.learningRate(learningRate)
.updater(Updater.NESTEROVS).momentum(0.9)
.list()
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes) ...