optim-msgd.lua (1068B)
-- MSGD
-- Nesterov's momentum, see e.g. Sutskever et al., ICML 2013
-- Author: Sixin Zhang (zsx@cims.nyu.edu)
require 'optim'

--- Perform one step of SGD with Nesterov-style momentum.
--
-- The velocity is first scaled by the momentum and added to `w` (look-ahead),
-- the gradient is evaluated at that look-ahead point, and the same gradient
-- step is then applied to both `w` and the velocity.
--
-- ARGS:
-- - `opfunc` : function called as `opfunc(w)`; must return `fx, dfdx`
--              (by the optim convention, the objective value and its gradient
--              with respect to `w`)
-- - `w`      : parameters, updated in place; must support `clone`, `zero`,
--              and `add` (presumably a torch Tensor -- confirm with callers)
-- - `config` : optional table of hyper-parameters (defaults to `{}`)
-- - `config.lr`       : base learning rate (default 0)
-- - `config.lrd`      : learning-rate decay factor (default 0)
-- - `config.lrp`      : learning-rate decay power (default 0); the decayed rate
--                       lr / (1 + pversion*lrd)^lrp is used only when both
--                       `lrd` and `lrp` are > 0
-- - `config.mom`      : momentum coefficient (default 0, i.e. momentum disabled)
-- - `config.mommax`   : upper bound of the scheduled momentum (default 1)
-- - `config.momdecay` : when > 0 (and `mom` > 0), `mom` is replaced by
--                       min(mommax, 1 - 0.5/(1 + pversion/momdecay))
-- - `config.l2wd`     : L2 weight-decay coefficient (default 0)
-- - `state`  : optional table carrying state between calls (defaults to `config`)
-- - `state.pversion`  : number of updates performed so far (starts at 0)
-- - `state.vt`        : velocity buffer, created on first use with momentum
--
-- RETURN:
-- - `w`    : the updated parameters (same object that was passed in)
-- - `{fx}` : a table holding the value returned by `opfunc` as `fx`
function optim.msgd(opfunc, w, config, state)
   -- Assign to the parameters directly rather than shadowing them with locals.
   config = config or {}
   state = state or config

   local lr = config.lr or 0
   local lrd = config.lrd or 0
   local lrp = config.lrp or 0
   local mom = config.mom or 0
   local mmax = config.mommax or 1
   local mlrd = config.momdecay or 0
   local l2wd = config.l2wd or 0

   state.pversion = state.pversion or 0

   if mom > 0 then
      if mlrd > 0 then
         -- Scheduled momentum: starts at 0.5 and grows towards 1, capped at mmax.
         mom = math.min(mmax, 1 - 0.5 / (1 + state.pversion / mlrd))
      end
      if not state.vt then
         state.vt = w:clone():zero()
      end
      -- Look-ahead: move w along the (scaled) velocity before evaluating opfunc.
      state.vt:mul(mom)
      w:add(state.vt)
   end

   local fx, dfdx = opfunc(w)

   -- Weight decay is added into dfdx itself (the object returned by opfunc is
   -- modified, assuming torch in-place `add` semantics) and uses the
   -- look-ahead value of w.
   if l2wd ~= 0 then
      dfdx:add(l2wd, w)
   end

   local clr = lr
   if lrd > 0 and lrp > 0 then
      -- `^` instead of math.pow: math.pow is deprecated in Lua 5.3 and absent
      -- from Lua 5.4, while `^` is available in every Lua version.
      clr = lr / (1 + state.pversion * lrd) ^ lrp
   end

   w:add(-clr, dfdx)
   -- Deliberately re-tests `mom`, which may have been rescheduled above.
   if mom > 0 then
      state.vt:add(-clr, dfdx)
   end

   state.pversion = state.pversion + 1
   return w, {fx}
end