‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

goot.lua (5314B)


      1 -------------------------------------------------------------------------------
      2 -- Author: Sixin Zhang (zsx@cims.nyu.edu)
      3 -- Author: umhau (umhau@alum.gcc.edu)
      4 -------------------------------------------------------------------------------
      5 
      6 -- NOTES ----------------------------------------------------------------------
      7 
      8 
      9 -- VARIABLES ------------------------------------------------------------------
     10 local opt = opt or {}
     11 
     12 -- location of training data
     13 -- looks like it's designed to continue processing data after an interruption.
     14 local data_root = opt.data_root or
     15    io.popen('echo $HOME'):read() .. '/data/torch7/mnist10'
     16 
     17 
     18 -- MPI SETTINGS ---------------------------------------------------------------
     19 -- most of these are set in the mlaunch file.  These are mostly duplicates.
     20 
     21 local state = state or {}
     22 local optname = opt.name
     23 local lr = opt.lr or 1e-1
     24 local mom = opt.mom or 0
     25 local mb = opt.mb or 128
     26 local mva = opt.mva or 0
     27 local su = opt.su or 1
     28 local maxep = opt.maxepoch or 100 -- this is set in mlaunch
     29 local gpuid = opt.gpuid or -1
     30 local rank = opt.rank or -1
     31 local pclient = opt.pc or nil
     32 
     33 
     34 -- TIMING VARIABLES  ----------------------------------------------------------
     35 -- Later on, these are used to announce how long the process took.
     36 require 'sys'
     37 local tm = {}
     38 tm.feval = 0
     39 tm.sync = 0
     40 
     41 
     42 -- SET THE RANDOM SEED --------------------------------------------------------
     43 require 'os'
     44 local seed = opt.seed or os.time()
     45 torch.manualSeed(seed) -- remember to set cutorch.manualSeed if needed
     46 
     47 
     48 -- BUILD THE NEURAL NET -------------------------------------------------------
     49 -- replace this with own net
     50 
     51 require 'nn'
     52 local model = nn.Sequential()
     53 model:add(nn.Linear(32*32,10))
     54 --model:add(nn.Threshold()) -- relu
     55 --model:add(nn.Dropout())
     56 --model:add(nn.Linear(100,10))
     57 model:add(nn.LogSoftMax())
     58 criterion = nn.ClassNLLCriterion()
     59 state.theta,state.grad = model:getParameters()
     60 
     61 
     62 -- LOAD AND CONFIGURE DATA ----------------------------------------------------
     63 -- replace this with own process
     64 
     65 -- data can be downloaded from http://cs.nyu.edu/~zsx/mnist10/test_32x32.th7
     66 -- and http://cs.nyu.edu/~zsx/mnist10/train_32x32.th7
     67 -- remember to reset data_root
     68 -- may use test_bin for fast debug
     69 test_bin = data_root .. '/test_32x32.th7'
     70 train_bin = data_root .. '/train_32x32.th7'
     71 train_data = torch.load(train_bin)
     72 test_data = torch.load(test_bin)
     73 local dim = train_data['data']:size(2)*
     74             train_data['data']:size(3)*
     75 	    train_data['data']:size(4)
     76 local trsize = train_data['data']:size(1)
     77 local ttsize = test_data['data']:size(1)
     78 train_data.data:resize(trsize,dim)
     79 test_data.data:resize(ttsize,dim)
     80 train_data.data = train_data.data:float():div(255)
     81 test_data.data = test_data.data:float():div(255)
     82 train_data.labels = train_data.labels:float()
     83 test_data.labels = test_data.labels:float()
     84 if gpuid > 0 then
     85    train_data.data = train_data.data:cuda()
     86    train_data.labels = train_data.labels:cuda()
     87    test_data.data = test_data.data:cuda()
     88    test_data.labels = test_data.labels:cuda()
     89 end
     90 
     91 
     92 -- OPTIMIZER SETTINGS ---------------------------------------------------------
     93 -- not even going to worry about alternatives. 
     94 
     95 require 'optim'
     96 
     97 opti = optim.eamsgd
     98 state.optim = {
     99     lr = lr,
    100     pclient = pclient,
    101     su = su,
    102     mva = mva,
    103     mom = mom,
    104 }
    105 
    106 -- LOAD DATA FOR PROCESSING ---------------------------------------------------
    107 -- only if a client process.  TODO: Figure out exactly what it is that's 
    108 -- sent through the 'start' function.
    109 
    110 print('i am ' .. rank .. ' ready to run')
    111 if pclient then
    112    pclient:start(state.theta,state.grad)
    113    assert(rank == pclient.rank)
    114    print('pc ' .. rank .. ' started')
    115 end
    116 
    117 
    118 -- TRAINING AND TRAINING-RELATED FUNCTIONS ------------------------------------
    119 local inputs = nil
    120 local targets = nil
    121 local avg_err = 0
    122 local feval = 
    123 function(x)
    124    local time_feval = sys.clock()
    125    -- get new parameters    
    126    if x ~= state.theta then
    127       print('copy theta!!')
    128       state.theta:copy(x)
    129    end
    130    -- reset gradients
    131    state.grad:zero()
    132    -- forward
    133    local outputs = model:forward(inputs)
    134    local err = criterion:forward(outputs, targets)
    135    -- estimate df/dW
    136    local dE_do = criterion:backward(outputs, targets)
    137    model:backward(inputs, dE_do)
    138    local er
    139    if type(err) == 'number' then
    140       er = err -- for cpu
    141    else
    142       er = err[1]  -- for gpu
    143    end
    144    avg_err = avg_err + er
    145    tm.feval = tm.feval + (sys.clock() - time_feval)
    146    return er,state.grad
    147 end
    148 
    149 -- train
    150 sys.tic()
    151 local iter = 0
    152 for epoch = 1,maxep do
    153    for t = 1,trsize,mb do
    154       -- prepare mini batch
    155       local mbs = math.min(trsize-t+1,mb)
    156       -- there's no shuffling in this cycling, just for illustration
    157       inputs = train_data.data:narrow(1,t,mbs)
    158       targets = train_data.labels:narrow(1,t,mbs)
    159       -- optimize on current mini-batch
    160       local x,fx
    161       x,fx = opti(feval, state.theta, state.optim)     
    162       -- increase iteration count
    163       iter = iter + 1
    164    end
    165    print(io.popen('hostname -s'):read(),sys.toc(),rank,
    166 	 'avg_err at epoch ' .. epoch .. ' is ' .. avg_err / iter)
    167 end
    168 
    169 if pclient then
    170    pclient:stop()
    171 end
    172 
    173 print(rank,'total training time is', sys.toc())
    174 print(rank,'total function eval time is', tm.feval)
    175 if state.optim.dusync then
    176    tm.sync = state.optim.dusync
    177 end
    178 print(rank,'total sync time is', tm.sync)