‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

optim-eamsgd.lua (2233B)


      1 -- Async EASGD/EAMSGD
      2 -- Author: Sixin Zhang (zsx@cims.nyu.edu)
      3 -- Author: umhau (umhau@alum.gcc.edu)
      4 -- when mom==0 (no momentum), this reduces to plain EASGD
      5 require 'optim'
      6 
      7 function optim.eamsgd(opfunc, w, config, state)
      8    local config = config or {}   
      9    local state = state or config
     10   
     11    local lr = config.learningRate or 0   -- learning rate \eta
     12    local lrd = config.learningRateDecay or 0 -- learning rate decay
     13    local lrp = config.learningRateDecayPower or 0 -- learning rate decay power
     14    local mom = config.momentum or 0 -- momentum term \delta
     15    local l2wd = config.l2wd or 0
     16 
     17    local pc = config.pclient or nil
     18    local mva = config.movingRateAlpha or 0 -- moving rate \alpha
     19    local su = config.communicationPeriod or 1   -- comm period \tau
     20 
     21    state.pversion = state.pversion or 0
     22    state.dusync = state.dusync or 0
     23 
     24    local fx,dfdx
     25    local function localupdate()
     26       if lr ~= 0 then
     27 	 if mom > 0 then
     28 	    if not state.vt then
     29 	       state.vt = w:clone():zero()
     30 	    end
     31 	    state.vt:mul(mom)
     32 	    w:add(state.vt)
     33 	 end	 	 
     34 	 fx,dfdx = opfunc(w)
     35 	 if l2wd ~= 0 then dfdx:add(l2wd, w) end	 
     36 	 local clr = lr
     37 	 if lrd ~= 0 and lrp > 0 then 
     38 	    clr = lr / math.pow(1+state.pversion*lrd,lrp)
     39 	 end
     40 	 w:add(-clr,dfdx)
     41 	 if mom > 0 then
     42 	    state.vt:add(-clr,dfdx)
     43 	 end
     44 	 state.pversion = state.pversion + 1	         
     45       end
     46    end
     47    
     48    if (pc and su>0 and mva>0) then
     49       if (state.pversion%su == 0) then
     50 	 if not config.suw then -- need 2 copies
     51 	    config.suw = torch.Tensor():typeAs(w):resizeAs(w):fill(0)
     52 	    config.sug = torch.Tensor():typeAs(w):resizeAs(w):fill(0)
     53 	    pc:reset(config.suw,config.sug)
     54 	 end
     55 	 pc:async_recv_param() -- suw=w*
     56 	 local synctime = sys.clock()
     57 	 pc:wait() -- sug is sent and suw is recv
     58 	 state.dusync = state.dusync + sys.clock()-synctime
     59 	 config.sug:copy(w) -- sug=w
     60 	 config.sug:add(-1,config.suw) -- sug=w-w*
     61 	 config.sug:mul(mva) -- sug=mva*(w-w*)
     62 	 pc:async_send_grad() -- apply w*=w*+mva*(w-w*)
     63 	 local synctime = sys.clock()
     64 	 pc:ping() -- overlap aio and computation
     65 	 state.dusync = state.dusync + sys.clock()-synctime
     66 	 localupdate()
     67 	 w:add(-1,config.sug) -- w=w+mva*(w*-w)
     68       else
     69 	 localupdate()
     70       end
     71    else
     72       assert(false)
     73    end
     74    return w,{fx}
     75 end