‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

optim-downpour.lua (1603B)


      1 -- DOWNPOUR
      2 -- Author: Sixin Zhang (zsx@cims.nyu.edu)
      3 -- When su == 1, every step is synchronized with the server, i.e. asynchronous SGD (Hogwild-style).
      4 require 'optim'
      5 
      6 function optim.downpour(opfunc, w, config, state)
      7    local config = config or {}   
      8    local state = state or config
      9    
     10    local lr = config.lr or 0 -- learning rate
     11    local lrd = config.lrd or 0 -- learning rate decay
     12    local l2wd = config.l2wd or 0
     13 
     14    local pc = config.pclient or nil
     15    local su = config.su or 0 -- sync updates (grad and param)
     16 
     17    state.pversion = state.pversion or 0
     18    state.dusync = state.dusync or 0   
     19    
     20    if lrd ~= 0 then 
     21       lr = lr / (1 + state.pversion*lrd)
     22    end
     23    local fx,dfdx = opfunc(w)
     24    if l2wd ~= 0 then dfdx:add(l2wd, w) end
     25 
     26    if pc and su>1 then
     27       -- apply lr
     28       dfdx:mul(-lr)
     29       -- accumulate grad
     30       if not config.dfdx then -- need one copy to accumulate
     31 	 config.dfdx = torch.Tensor():typeAs(dfdx):resizeAs(dfdx):fill(0)
     32 	 pc:reset(w,config.dfdx)
     33       end
     34       config.dfdx:add(dfdx)
     35       -- send grads and get new param
     36       if state.pversion%su==0 then
     37 	 pc:async_send_grad()
     38 	 pc:async_recv_param()
     39 	 local synctime = sys.clock()
     40 	 pc:wait()
     41 	 state.dusync = state.dusync + sys.clock()-synctime
     42 	 config.dfdx:fill(0)
     43       else
     44 	 w:add(dfdx) -- move locally
     45       end
     46    elseif pc and su==1 then
     47       -- apply lr
     48       dfdx:mul(-lr)
     49       -- send
     50       pc:async_send_grad()
     51       pc:async_recv_param()
     52       local synctime = sys.clock()
     53       pc:wait()
     54       state.dusync = state.dusync + sys.clock()-synctime
     55    else
     56       assert(false)
     57    end
     58    state.pversion = state.pversion + 1      
     59    return w,{fx}
     60 end