optim-eamsgd.lua (2233B)
-- Async EASGD/EAMSGD
-- Author: Sixin Zhang (zsx@cims.nyu.edu)
-- Author: umhau (umhau@alum.gcc.edu)
-- when mom==0, it is the easgd
require 'optim'

-- Perform one step of asynchronous elastic-averaging SGD (with optional
-- momentum) on `w`.
--
-- ARGS:
-- - `opfunc` : function called as `opfunc(w)`; must return `fx, dfdx`
-- - `w`      : parameter tensor, updated in place
-- - `config` : option table (defaults to `{}`), fields read here:
--     `learningRate`           learning rate \eta (default 0)
--     `learningRateDecay`      learning rate decay (default 0)
--     `learningRateDecayPower` learning rate decay power (default 0)
--     `momentum`               momentum term \delta (default 0)
--     `l2wd`                   coefficient of `w` added to the gradient (default 0)
--     `pclient`                communication client (required)
--     `movingRateAlpha`        moving rate \alpha (required, > 0)
--     `communicationPeriod`    comm period \tau (default 1, must be > 0)
--   The communication buffers `config.suw` / `config.sug` are created here
--   on the first communication round.
-- - `state`  : state table (defaults to `config`); holds `pversion` (number
--   of local updates performed), `dusync` (accumulated seconds spent in
--   `pc:wait()` / `pc:ping()`) and `vt` (momentum buffer).
--
-- RETURN:
-- - `w`    : the same tensor that was passed in
-- - `{fx}` : one-element table with the first value returned by `opfunc`;
--            `fx` is nil when `learningRate` is 0 because `opfunc` is then
--            never called.
--
-- Raises an error when `pclient` is missing or when `communicationPeriod`
-- or `movingRateAlpha` is not positive.
function optim.eamsgd(opfunc, w, config, state)
   config = config or {}
   state = state or config

   local lr = config.learningRate or 0 -- learning rate \eta
   local lrd = config.learningRateDecay or 0 -- learning rate decay
   local lrp = config.learningRateDecayPower or 0 -- learning rate decay power
   local mom = config.momentum or 0 -- momentum term \delta
   local l2wd = config.l2wd or 0

   local pc = config.pclient
   local mva = config.movingRateAlpha or 0 -- moving rate \alpha
   local su = config.communicationPeriod or 1 -- comm period \tau

   state.pversion = state.pversion or 0
   state.dusync = state.dusync or 0

   -- Fail with a message that names the missing configuration instead of a
   -- bare "assertion failed!".
   if not (pc and su > 0 and mva > 0) then
      error('optim.eamsgd: config.pclient must be set, and ' ..
            'config.communicationPeriod and config.movingRateAlpha ' ..
            'must both be > 0', 2)
   end

   local fx, dfdx

   -- One local gradient step on w. Does nothing (and leaves fx nil) when the
   -- learning rate is zero.
   local function localupdate()
      if lr == 0 then
         return
      end
      if mom > 0 then
         if not state.vt then
            state.vt = w:clone():zero()
         end
         -- The decayed velocity is applied to w before the gradient is
         -- evaluated, so opfunc sees the shifted point.
         state.vt:mul(mom)
         w:add(state.vt)
      end
      fx, dfdx = opfunc(w)
      if l2wd ~= 0 then
         dfdx:add(l2wd, w)
      end
      local clr = lr
      if lrd ~= 0 and lrp > 0 then
         clr = lr / math.pow(1 + state.pversion * lrd, lrp)
      end
      w:add(-clr, dfdx)
      if mom > 0 then
         state.vt:add(-clr, dfdx)
      end
      state.pversion = state.pversion + 1
   end

   if state.pversion % su == 0 then
      -- Communication round (every `su` local updates).
      if not config.suw then -- need 2 copies
         config.suw = torch.Tensor():typeAs(w):resizeAs(w):fill(0)
         config.sug = torch.Tensor():typeAs(w):resizeAs(w):fill(0)
         pc:reset(config.suw, config.sug)
      end

      pc:async_recv_param() -- suw=w*
      local synctime = sys.clock()
      pc:wait() -- sug is sent and suw is recv
      state.dusync = state.dusync + (sys.clock() - synctime)

      config.sug:copy(w) -- sug=w
      config.sug:add(-1, config.suw) -- sug=w-w*
      config.sug:mul(mva) -- sug=mva*(w-w*)

      pc:async_send_grad() -- apply w*=w*+mva*(w-w*)
      synctime = sys.clock()
      pc:ping() -- overlap aio and computation
      state.dusync = state.dusync + (sys.clock() - synctime)

      localupdate()
      w:add(-1, config.sug) -- w=w+mva*(w*-w)
   else
      localupdate()
   end

   return w, {fx}
end