‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

pserver.lua (5026B)


      1 -- Parameter server
      2 -- Author: Sixin Zhang (zsx@cims.nyu.edu)
-- Important change:
-- the parameters are initialized from the first local worker (once and only
-- once) before any other service is started. On the pclient side, it is
-- likewise the first local pclient (worker) that sends its parameters
-- to the pserver.
      8 require 'mpiT'
      9 
     10 local pServer = torch.class('pServer')
     11 
     12 function pServer:__init(conf,state)
     13    self.state = state or {}
     14    self.rank = conf.rank or -1
     15    self.cranks = conf.cranks or {} -- client ranks
     16    self.mtype = mpiT.FLOAT
     17    self.mworld = conf.world or mpiT.COMM_WORLD
     18 
     19    self.offset = -1 -- offset of param grad
     20    self.size = -1 -- size of param grad
     21    self.tensor = {} -- tensor from storage
     22    self.storage = {} -- param gradient storage
     23    self.emptys = torch.Storage()
     24 
     25    self.state.on = false
     26    self.state.io = false
     27    self.state.iostop = 0
     28    self.coq = Queue() -- coroutine queue
     29 
     30    self.conf = conf
     31 end
     32 
     33 local function pServer_recvinit(self,crank)
     34    coroutine.yield(mpiT.signal_INIT)
     35    -- get meta info
     36    local cinfo = torch.LongStorage(2)
     37    mpiT.aio_recv(cinfo,2,mpiT.LONG,
     38 		 crank,mpiT.tag_ps_recv_init,self.mworld,self.state)
     39    -- init storage 
     40    if self.offset == -1 then
     41       self.offset = cinfo[1]
     42       self.size = cinfo[2]
     43       self.storage.p = torch.Storage(self.size)
     44       self.storage.g = {}
     45       self.storage.g[crank] = torch.Storage(self.size)
     46       self.tensor.p = torch.Tensor(self.storage.p)
     47       self.tensor.g = {}
     48       self.tensor.g[crank] = torch.Tensor(self.storage.g[crank])
     49    else
     50       assert(self.offset == cinfo[1])
     51       assert(self.size == cinfo[2])
     52       self.storage.g[crank] = torch.Storage(self.size)
     53       self.tensor.g[crank] = torch.Tensor(self.storage.g[crank])
     54    end
     55    --print('pServer:recvinit',self.rank,crank,self.offset,self.size)
     56    coroutine.yield(mpiT.signal_DONE)
     57 end
     58 
     59 local function pServer_sendparam(self,crank)
     60    coroutine.yield(mpiT.signal_INIT)
     61    while (self.state.on) do
     62       --print('pServer_sendparam to recv',crank,self.size)
     63       mpiT.aio_recv(self.emptys,0,self.mtype,
     64 		    crank,mpiT.tag_ps_recv_header,self.mworld,self.state)
     65       --print('ps ' .. self.rank .. ' send param to ' .. crank)
     66       if self.state.io then
     67       	 mpiT.aio_send(self.storage.p,self.size,self.mtype,
     68 		       crank,mpiT.tag_ps_send_param,self.mworld,self.state)
     69       end
     70    end
     71    coroutine.yield(mpiT.signal_DONE)
     72 end
     73 
     74 -- Warning: no lock on self.tensor.p during this recvgrad, expect inconsistent read
     75 local function pServer_recvgrad(self,crank)
     76    coroutine.yield(mpiT.signal_INIT)
     77    while (self.state.on) do      
     78       -- recv
     79       mpiT.aio_recv(self.storage.g[crank],self.size,self.mtype,
     80 		    crank,mpiT.tag_ps_recv_grad,self.mworld,self.state)
     81       --print('ps ' .. self.rank .. ' recv grad from ' .. crank)
     82       -- apply
     83       self.tensor.p:add(self.tensor.g[crank])
     84       if self.state.on then
     85          mpiT.aio_send(self.emptys,0,self.mtype,
     86 	  	       crank,mpiT.tag_ps_recv_grad_tail,self.mworld,self.state)
     87       end
     88    end
     89    coroutine.yield(mpiT.signal_DONE)
     90 end
     91 
     92 local function pServer_recvparam(self,crank)
     93    coroutine.yield(mpiT.signal_INIT)
     94    mpiT.aio_recv(self.storage.p,self.size,self.mtype,
     95 		 crank,mpiT.tag_ps_recv_param,self.mworld,self.state)
     96    if self.state.on then
     97       mpiT.aio_send(self.emptys,0,self.mtype,
     98                     crank,mpiT.tag_ps_recv_param_tail,self.mworld,self.state)
     99    end
    100    --print('ps ' .. self.rank .. ' recv param from ' .. crank)
    101    coroutine.yield(mpiT.signal_DONE)
    102 end
    103 
    104 -- stop
    105 function table.len(tbl)
    106    local l=0
    107    if tbl then
    108       for k,v in pairs(tbl) do
    109 	 l=l+1
    110       end
    111    end
    112    return l
    113 end
    114 
    115 local function pServer_recvstop(self,crank)
    116    coroutine.yield(mpiT.signal_INIT)
    117    local tostop = torch.ByteStorage(1)
    118    mpiT.aio_recv(tostop,1,mpiT.BYTE,
    119 		 crank,mpiT.tag_ps_recv_stop,self.mworld,self.state)
    120    if tostop[1] then
    121       self.state.iostop = self.state.iostop + 1
    122       if self.state.iostop == table.len(self.cranks) then
    123 	 self.state.on = false
    124 	 self.state.io = false
    125 	 -- print('server', self.rank, 'finally stoped by', crank)
    126       end
    127    end
    128    coroutine.yield(mpiT.signal_DONE)
    129 end
    130 
    131 function pServer:start()
    132    self.state.on = true
    133    self.state.io = true
    134    -- init
    135    self.coq:clear()
    136    for i,crank in pairs(self.cranks) do
    137       local co = mpiT.co_execute(pServer_recvinit,{self,crank})
    138       self.coq:push(co)
    139    end
    140    mpiT.co_wait(self.coq)
    141    for i,crank in pairs(self.cranks) do      
    142       if i == 1 then 
    143          -- init the parameter from the first local worker
    144          local co3 = mpiT.co_execute(pServer_recvparam,{self,crank})      
    145          self.coq:push(co3)
    146          mpiT.co_wait(self.coq)
    147       end
    148       -- on request
    149       local co0 = mpiT.co_execute(pServer_recvstop,{self,crank})
    150       local co1 = mpiT.co_execute(pServer_recvgrad,{self,crank})
    151       local co2 = mpiT.co_execute(pServer_sendparam,{self,crank})
    152       self.coq:push(co0)
    153       self.coq:push(co1)
    154       self.coq:push(co2)
    155    end
    156    mpiT.co_wait(self.coq)
    157 end