commit 4f4198ba67f6ea53b20bf19baba0aebb4fb9fa58
parent bc6df9a30e64cce089c332b6ea3c5ef629425e81
Author: umhau <umhau@users.noreply.github.com>
Date: Tue, 14 Feb 2017 16:21:05 -0500
better explanatory comments
Diffstat:
| M | asyncsgd/goot.lua | | | 100 | +++++++++++++++++++++++++++++++++++++++++++++++-------------------------------- |
1 file changed, 60 insertions(+), 40 deletions(-)
diff --git a/asyncsgd/goot.lua b/asyncsgd/goot.lua
@@ -1,30 +1,52 @@
--------------------------------------------------------------------
+-------------------------------------------------------------------------------
-- Author: Sixin Zhang (zsx@cims.nyu.edu)
--------------------------------------------------------------------
+-- Author: umhau (umhau@alum.gcc.edu)
+-------------------------------------------------------------------------------
+
+-- NOTES ----------------------------------------------------------------------
+
+
+-- VARIABLES ------------------------------------------------------------------
local opt = opt or {}
+
+-- location of training data
+local data_root = opt.data_root or
+ io.popen('echo $HOME'):read() .. '/data/torch7/mnist10'
+
+
+-- MPI SETTINGS ---------------------------------------------------------------
+-- most of these are set in the mlaunch file. These are mostly duplicates.
+
local state = state or {}
-local optname = opt.name or 'sgd'
+local optname = opt.name
local lr = opt.lr or 1e-1
local mom = opt.mom or 0
local mb = opt.mb or 128
local mva = opt.mva or 0
local su = opt.su or 1
-local maxep = opt.maxepoch or 100
-local data_root = opt.data_root or
- io.popen('echo $HOME'):read() .. '/data/torch7/mnist10'
+local maxep = opt.maxepoch or 100 -- this is set in mlaunch
local gpuid = opt.gpuid or -1
local rank = opt.rank or -1
local pclient = opt.pc or nil
--------------------------------------------------------------------
+
+
+-- TIMING VARIABLES ----------------------------------------------------------
+-- Later on, these are used to announce how long the process took.
require 'sys'
local tm = {}
tm.feval = 0
tm.sync = 0
--------------------------------------------------------------------
+
+
+-- SET THE RANDOM SEED --------------------------------------------------------
require 'os'
local seed = opt.seed or os.time()
torch.manualSeed(seed) -- remember to set cutorch.manualSeed if needed
--------------------------------------------------------------------
+
+
+-- BUILD THE NEURAL NET -------------------------------------------------------
+-- replace this with own net
+
require 'nn'
local model = nn.Sequential()
model:add(nn.Linear(32*32,10))
@@ -34,10 +56,13 @@ model:add(nn.Linear(32*32,10))
model:add(nn.LogSoftMax())
criterion = nn.ClassNLLCriterion()
state.theta,state.grad = model:getParameters()
--------------------------------------------------------------------
--- data can be downloaded from,
--- http://cs.nyu.edu/~zsx/mnist10/test_32x32.th7
--- http://cs.nyu.edu/~zsx/mnist10/train_32x32.th7
+
+
+-- LOAD AND CONFIGURE DATA ----------------------------------------------------
+-- replace this with own process
+
+-- data can be downloaded from http://cs.nyu.edu/~zsx/mnist10/test_32x32.th7
+-- and http://cs.nyu.edu/~zsx/mnist10/train_32x32.th7
-- remember to reset data_root
-- may use test_bin for fast debug
test_bin = data_root .. '/test_32x32.th7'
@@ -61,40 +86,35 @@ if gpuid > 0 then
test_data.data = test_data.data:cuda()
test_data.labels = test_data.labels:cuda()
end
--------------------------------------------------------------------
+
+
+-- OPTIMIZER SETTINGS ---------------------------------------------------------
+-- not even going to worry about alternatives.
+
require 'optim'
-local opti
-if optname == 'sgd' then
- opti = optim.msgd
- state.optim = {
- lr = lr,
- mommax = mom,
- }
-elseif optname == 'downpour' then
- opti = optim.downpour
- state.optim = {
- lr = lr,
- pclient = pclient,
- su = su,
- }
-elseif optname == 'eamsgd' then
- opti = optim.eamsgd
- state.optim = {
- lr = lr,
- pclient = pclient,
- su = su,
- mva = mva,
- mom = mom,
- }
-end
--------------------------------------------------------------------
+
+opti = optim.eamsgd
+state.optim = {
+ lr = lr,
+ pclient = pclient,
+ su = su,
+ mva = mva,
+ mom = mom,
+}
+
+-- LOAD DATA FOR PROCESSING ---------------------------------------------------
+-- only if a client process. TODO: Figure out exactly what it is that's
+-- sent through the 'start' function.
+
print('i am ' .. rank .. ' ready to run')
if pclient then
pclient:start(state.theta,state.grad)
assert(rank == pclient.rank)
print('pc ' .. rank .. ' started')
end
--------------------------------------------------------------------
+
+
+-- TRAINING AND TRAINING-RELATED FUNCTIONS ------------------------------------
local inputs = nil
local targets = nil
local avg_err = 0