‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

commit 4f4198ba67f6ea53b20bf19baba0aebb4fb9fa58
parent bc6df9a30e64cce089c332b6ea3c5ef629425e81
Author: umhau <umhau@users.noreply.github.com>
Date:   Tue, 14 Feb 2017 16:21:05 -0500

better explanatory comments

Diffstat:
Masyncsgd/goot.lua | 100+++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
1 file changed, 60 insertions(+), 40 deletions(-)

diff --git a/asyncsgd/goot.lua b/asyncsgd/goot.lua @@ -1,30 +1,52 @@ -------------------------------------------------------------------- +------------------------------------------------------------------------------- -- Author: Sixin Zhang (zsx@cims.nyu.edu) -------------------------------------------------------------------- +-- Author: umhau (umhau@alum.gcc.edu) +------------------------------------------------------------------------------- + +-- NOTES ---------------------------------------------------------------------- + + +-- VARIABLES ------------------------------------------------------------------ local opt = opt or {} + +-- location of training data +local data_root = opt.data_root or + io.popen('echo $HOME'):read() .. '/data/torch7/mnist10' + + +-- MPI SETTINGS --------------------------------------------------------------- +-- most of these are set in the mlaunch file. These are mostly duplicates. + local state = state or {} -local optname = opt.name or 'sgd' +local optname = opt.name local lr = opt.lr or 1e-1 local mom = opt.mom or 0 local mb = opt.mb or 128 local mva = opt.mva or 0 local su = opt.su or 1 -local maxep = opt.maxepoch or 100 -local data_root = opt.data_root or - io.popen('echo $HOME'):read() .. '/data/torch7/mnist10' +local maxep = opt.maxepoch or 100 -- this is set in mlaunch local gpuid = opt.gpuid or -1 local rank = opt.rank or -1 local pclient = opt.pc or nil -------------------------------------------------------------------- + + +-- TIMING VARIABLES ---------------------------------------------------------- +-- Later on, these are used to announce how long the process took. require 'sys' local tm = {} tm.feval = 0 tm.sync = 0 -------------------------------------------------------------------- + + +-- SET THE RANDOM SEED -------------------------------------------------------- require 'os' local seed = opt.seed or os.time() torch.manualSeed(seed) -- remember to set cutorch.manualSeed if needed -------------------------------------------------------------------- + + +-- BUILD THE NEURAL NET ------------------------------------------------------- +-- replace this with own net + require 'nn' local model = nn.Sequential() model:add(nn.Linear(32*32,10)) @@ -34,10 +56,13 @@ model:add(nn.Linear(32*32,10)) model:add(nn.LogSoftMax()) criterion = nn.ClassNLLCriterion() state.theta,state.grad = model:getParameters() -------------------------------------------------------------------- --- data can be downloaded from, --- http://cs.nyu.edu/~zsx/mnist10/test_32x32.th7 --- http://cs.nyu.edu/~zsx/mnist10/train_32x32.th7 + + +-- LOAD AND CONFIGURE DATA ---------------------------------------------------- +-- replace this with own process + +-- data can be downloaded from http://cs.nyu.edu/~zsx/mnist10/test_32x32.th7 +-- and http://cs.nyu.edu/~zsx/mnist10/train_32x32.th7 -- remember to reset data_root -- may use test_bin for fast debug test_bin = data_root .. '/test_32x32.th7' @@ -61,40 +86,35 @@ if gpuid > 0 then test_data.data = test_data.data:cuda() test_data.labels = test_data.labels:cuda() end -------------------------------------------------------------------- + + +-- OPTIMIZER SETTINGS --------------------------------------------------------- +-- not even going to worry about alternatives. + require 'optim' -local opti -if optname == 'sgd' then - opti = optim.msgd - state.optim = { - lr = lr, - mommax = mom, - } -elseif optname == 'downpour' then - opti = optim.downpour - state.optim = { - lr = lr, - pclient = pclient, - su = su, - } -elseif optname == 'eamsgd' then - opti = optim.eamsgd - state.optim = { - lr = lr, - pclient = pclient, - su = su, - mva = mva, - mom = mom, - } -end -------------------------------------------------------------------- + +opti = optim.eamsgd +state.optim = { + lr = lr, + pclient = pclient, + su = su, + mva = mva, + mom = mom, +} + +-- LOAD DATA FOR PROCESSING --------------------------------------------------- +-- only if a client process. TODO: Figure out exactly what it is that's +-- sent through the 'start' function. + print('i am ' .. rank .. ' ready to run') if pclient then pclient:start(state.theta,state.grad) assert(rank == pclient.rank) print('pc ' .. rank .. ' started') end -------------------------------------------------------------------- + + +-- TRAINING AND TRAINING-RELATED FUNCTIONS ------------------------------------ local inputs = nil local targets = nil local avg_err = 0