-- goot.lua (5314B)
-------------------------------------------------------------------------------
-- Author: Sixin Zhang (zsx@cims.nyu.edu)
-- Author: umhau (umhau@alum.gcc.edu)
-------------------------------------------------------------------------------

-- NOTES ----------------------------------------------------------------------
-- Every setting below is taken from the `opt` table when it provides one.
-- `opt` (and `state`) are looked up as globals, so a launcher can define them
-- before running this file; otherwise the default after `or` is used.

-- VARIABLES ------------------------------------------------------------------
local opt = opt or {}

-- Directory holding the training/test data files (train_32x32.th7 and
-- test_32x32.th7). Defaults to $HOME/data/torch7/mnist10.
-- os.getenv avoids spawning a shell via io.popen (whose handle was never
-- closed); an unset HOME yields '' exactly as `echo $HOME` did.
local data_root = opt.data_root or
   (os.getenv('HOME') or '') .. '/data/torch7/mnist10'


-- MPI SETTINGS ---------------------------------------------------------------
-- Most of these are also set by the launcher (mlaunch); the values after `or`
-- are only fallbacks for when `opt` does not provide them.

local state   = state or {}          -- shared training state (may be pre-populated by the caller)
local optname = opt.name
local lr      = opt.lr or 1e-1       -- learning rate, forwarded to the optimizer config
local mom     = opt.mom or 0         -- momentum, forwarded to the optimizer config
local mb      = opt.mb or 128        -- mini-batch size used by the training loop
local mva     = opt.mva or 0         -- forwarded to the optimizer config as `mva`
local su      = opt.su or 1          -- forwarded to the optimizer config as `su`
local maxep   = opt.maxepoch or 100  -- number of training epochs (set in mlaunch)
local gpuid   = opt.gpuid or -1      -- a value > 0 selects the GPU code path
local rank    = opt.rank or -1       -- this process's rank (checked against pclient.rank)
local pclient = opt.pc               -- parameter client; nil when running standalone


-- TIMING VARIABLES ----------------------------------------------------------
-- Later on, these are used to announce how long the process took.
require 'sys'
local tm = {
   feval = 0,  -- cumulative seconds spent inside feval
   sync = 0,   -- copied from state.optim.dusync at the end, when present
}


-- SET THE RANDOM SEED --------------------------------------------------------
require 'os'
local seed = opt.seed or os.time()
torch.manualSeed(seed) -- remember to set cutorch.manualSeed if needed


-- BUILD THE NEURAL NET -------------------------------------------------------
-- replace this with own net

require 'nn'
local model = nn.Sequential()
model:add(nn.Linear(32*32,10))
--model:add(nn.Threshold()) -- relu
--model:add(nn.Dropout())
--model:add(nn.Linear(100,10))
model:add(nn.LogSoftMax())
local criterion = nn.ClassNLLCriterion()
-- Give the model the same tensor type the data is converted to below
-- (FloatTensor on the CPU, CudaTensor when gpuid > 0); otherwise the forward
-- pass mixes tensor types. This MUST happen before getParameters(): converting
-- afterwards reallocates the weights and detaches them from the flattened
-- state.theta/state.grad views.
if gpuid > 0 then
   require 'cunn'
   model:cuda()
   criterion:cuda()
else
   model:float()
   criterion:float()
end
state.theta,state.grad = model:getParameters()


-- LOAD AND CONFIGURE DATA ----------------------------------------------------
-- replace this with own process

-- data can be downloaded from http://cs.nyu.edu/~zsx/mnist10/test_32x32.th7
-- and http://cs.nyu.edu/~zsx/mnist10/train_32x32.th7
-- remember to reset data_root
-- may use test_bin for fast debug
local test_bin = data_root .. '/test_32x32.th7'
local train_bin = data_root .. '/train_32x32.th7'
local train_data = torch.load(train_bin)
local test_data = torch.load(test_bin)
-- Flatten every sample to one row of `dim` values (assumes 4-D `data`
-- tensors, as the size(2..4) calls require).
local dim = train_data['data']:size(2)*
            train_data['data']:size(3)*
            train_data['data']:size(4)
local trsize = train_data['data']:size(1)
local ttsize = test_data['data']:size(1)
train_data.data:resize(trsize,dim)
test_data.data:resize(ttsize,dim)
-- scale pixel values into [0,1]
train_data.data = train_data.data:float():div(255)
test_data.data = test_data.data:float():div(255)
train_data.labels = train_data.labels:float()
test_data.labels = test_data.labels:float()
if gpuid > 0 then
   train_data.data = train_data.data:cuda()
   train_data.labels = train_data.labels:cuda()
   test_data.data = test_data.data:cuda()
   test_data.labels = test_data.labels:cuda()
end


-- OPTIMIZER SETTINGS ---------------------------------------------------------
-- Only optim.eamsgd is supported here; its configuration lives in state.optim.

require 'optim'

local opti = optim.eamsgd
state.optim = {
   lr = lr,
   pclient = pclient,
   su = su,
   mva = mva,
   mom = mom,
}

-- START THE PARAMETER CLIENT -------------------------------------------------
-- Only when running under a parameter client: it is handed the flattened
-- parameter and gradient tensors (state.theta, state.grad).
-- NOTE(review): what the client does with them is defined by the pclient
-- implementation, not here.

print('i am ' .. rank .. ' ready to run')
if pclient then
   pclient:start(state.theta,state.grad)
   assert(rank == pclient.rank)
   print('pc ' .. rank .. ' started')
end


-- TRAINING AND TRAINING-RELATED FUNCTIONS ------------------------------------
local inputs = nil   -- current mini-batch inputs, set by the training loop
local targets = nil  -- current mini-batch labels, set by the training loop
local avg_err = 0    -- sum of mini-batch losses in the current epoch

-- Objective handed to the optimizer. Given parameters x, evaluates the loss
-- on the current `inputs`/`targets` and returns (loss, state.grad).
-- Side effects: adds the loss to avg_err and the elapsed time to tm.feval.
local feval =
   function(x)
      local time_feval = sys.clock()
      -- get new parameters; the loop passes state.theta itself, so this copy
      -- only happens if the optimizer hands back a different tensor
      if x ~= state.theta then
         print('copy theta!!')
         state.theta:copy(x)
      end
      -- reset gradients
      state.grad:zero()
      -- forward
      local outputs = model:forward(inputs)
      local err = criterion:forward(outputs, targets)
      -- estimate df/dW
      local dE_do = criterion:backward(outputs, targets)
      model:backward(inputs, dE_do)
      local er
      if type(err) == 'number' then
         er = err -- for cpu
      else
         er = err[1] -- for gpu
      end
      avg_err = avg_err + er
      tm.feval = tm.feval + (sys.clock() - time_feval)
      return er,state.grad
   end

-- Resolve the host name once (the original spawned a shell every epoch and
-- never closed the pipe).
local hostname
do
   local pipe = io.popen('hostname -s')
   if pipe then
      hostname = pipe:read()
      pipe:close()
   end
end

-- train
sys.tic()
for epoch = 1,maxep do
   -- per-epoch statistics, so the printed value is the average for THIS epoch
   avg_err = 0
   local nbatch = 0
   for t = 1,trsize,mb do
      -- prepare mini batch
      local mbs = math.min(trsize-t+1,mb)
      -- there's no shuffling in this cycling, just for illustration
      inputs = train_data.data:narrow(1,t,mbs)
      targets = train_data.labels:narrow(1,t,mbs)
      -- optimize on current mini-batch
      opti(feval, state.theta, state.optim)
      nbatch = nbatch + 1
   end
   print(hostname,sys.toc(),rank,
         'avg_err at epoch ' .. epoch .. ' is ' .. avg_err / nbatch)
end

if pclient then
   pclient:stop()
end

print(rank,'total training time is', sys.toc())
print(rank,'total function eval time is', tm.feval)
if state.optim.dusync then
   tm.sync = state.optim.dusync
end
print(rank,'total sync time is', tm.sync)