‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

pclient.lua (5356B)


      1 -- Parameter client
      2 -- Author: Sixin Zhang (zsx@cims.nyu.edu)
      3 require 'mpiT'
      4 
      5 local pClient = torch.class('pClient')
      6 
      7 function pClient:__init(conf,state)
      8    self.state = state or {}
      9    self.rank = conf.rank or -1
     10    self.sranks = conf.sranks or {} -- server ranks
     11    self.cranks = conf.cranks or {} -- client ranks   
     12    self.plong = conf.plong or 0 -- size of whole parameter
     13    self.pstorage = conf.pstorage or torch.Storage()
     14    self.gstorage = conf.gstorage or torch.Storage()
     15    self.emptys = torch.Storage()
     16    self.sinfo = {}
     17    self.mtype = mpiT.FLOAT
     18    self.mworld = conf.world or mpiT.COMM_WORLD
     19    self.coq = Queue() -- coroutine queue
     20    self.state.on = false
     21    self.state.io = false
     22    self.conf = conf
     23 end
     24 
     25 local function pClient_sendinit(self,srank,offset,size)
     26    coroutine.yield(mpiT.signal_INIT)   
     27    --print('pClient:sendinit',self.rank,srank,offset,size)
     28    self.sinfo[srank] = {}
     29    self.sinfo[srank].offset = offset
     30    self.sinfo[srank].size = size
     31    local cinfo = torch.LongStorage(2)
     32    cinfo[1] = offset
     33    cinfo[2] = size
     34    mpiT.aio_send(cinfo,2,mpiT.LONG,
     35 		 srank,mpiT.tag_ps_recv_init,self.mworld,self.state)
     36    --print('pClient:sendinit done')
     37    coroutine.yield(mpiT.signal_DONE)
     38 end
     39 
     40 local function pClient_sendstop(self,srank) 
     41    coroutine.yield(mpiT.signal_INIT)
     42    local tostop = torch.ByteStorage(1):fill(1)
     43    mpiT.aio_send(tostop,1,mpiT.BYTE,
     44 		 srank,mpiT.tag_ps_recv_stop,self.mworld,self.state)
     45    coroutine.yield(mpiT.signal_DONE)   
     46 end
     47 
     48 local function pClient_sendgrad(self,grad,srank)
     49    coroutine.yield(mpiT.signal_INIT)
     50    local sgrad = torch.Storage(grad,
     51 			       self.sinfo[srank].offset,
     52 			       self.sinfo[srank].size)
     53    mpiT.aio_send(sgrad,sgrad:size(),self.mtype,
     54 		 srank,mpiT.tag_ps_recv_grad,self.mworld,self.state)
     55    mpiT.aio_recv(self.emptys,0,self.mtype,
     56                  srank,mpiT.tag_ps_recv_grad_tail,self.mworld,self.state)
     57    coroutine.yield(mpiT.signal_DONE)
     58 end
     59 
     60 local function pClient_sendparam(self,param,srank)
     61    coroutine.yield(mpiT.signal_INIT)
     62    local sparam = torch.Storage(param,
     63 				self.sinfo[srank].offset,
     64 				self.sinfo[srank].size)
     65    mpiT.aio_send(sparam,sparam:size(),self.mtype,
     66 		 srank,mpiT.tag_ps_recv_param,self.mworld,self.state)
     67    mpiT.aio_recv(self.emptys,0,self.mtype,
     68                  srank,mpiT.tag_ps_recv_param_tail,self.mworld,self.state)
     69    coroutine.yield(mpiT.signal_DONE)
     70 end
     71 
     72 local function pClient_recvparam(self,param,srank)
     73    coroutine.yield(mpiT.signal_INIT)
     74    mpiT.aio_send(self.emptys,0,self.mtype,
     75 		 srank,mpiT.tag_ps_recv_header,self.mworld,self.state)
     76    local sparam = torch.Storage(param,
     77 				self.sinfo[srank].offset,
     78 				self.sinfo[srank].size)
     79    mpiT.aio_recv(sparam,sparam:size(),self.mtype,
     80 		 srank,mpiT.tag_ps_send_param,self.mworld,self.state)
     81    coroutine.yield(mpiT.signal_DONE)   
     82 end
     83 
     84 function pClient:async_recv_param()
     85    local param = self.pstorage
     86    for i,srank in pairs(self.sranks) do
     87       --print('pc ' .. self.rank .. ' recv param from ' .. srank)
     88       local co = mpiT.co_execute(pClient_recvparam,{self,param,srank})
     89       self.coq:push(co)
     90    end
     91 end
     92 
     93 function pClient:async_send_grad()
     94    local grad = self.gstorage
     95    for i,srank in pairs(self.sranks) do
     96       --print('pc ' .. self.rank .. ' send grad to ' .. srank)
     97       local co = mpiT.co_execute(pClient_sendgrad,{self,grad,srank})
     98       self.coq:push(co)
     99    end
    100 end
    101 
    102 function pClient:async_send_param()
    103    local param = self.pstorage
    104    for i,srank in pairs(self.sranks) do
    105       --print('pc send param to ' .. srank)
    106       local co = mpiT.co_execute(pClient_sendparam,{self,param,srank})
    107       self.coq:push(co)
    108    end
    109 end
    110 
    111 local function pClient_init(self)
    112    -- set offset size for each piece of parameter server
    113    local offset = 1
    114    local size = math.floor(self.plong/#self.sranks)
    115    for i,srank in pairs(self.sranks) do
    116       if i == #self.sranks then
    117 	 size = self.plong - offset + 1
    118       end
    119       local co = mpiT.co_execute(pClient_sendinit,{self,srank,offset,size})
    120       self.coq:push(co)
    121       offset = offset + size
    122    end
    123    mpiT.co_wait(self.coq)
    124    -- init pserver param
    125    if self.rank == self.cranks[1] then
    126       self:async_send_param(self.pstorage)
    127    end
    128    mpiT.co_wait(self.coq)
    129 end
    130 
    131 function pClient:ping(nb)
    132    local nb = nb or self.coq:len()
    133    for n=1,nb do
    134       mpiT.co_ping(self.coq)
    135    end
    136 end
    137 
    138 function pClient:reset(param,grad)
    139    if param then
    140       self.pstorage = param:storage()
    141       self.plong = self.pstorage:size()
    142       if grad then
    143 	 self.gstorage = grad:storage()
    144 	 assert(self.plong == self.gstorage:size())
    145       end
    146    end  
    147 end
    148 
    149 function pClient:wait()
    150    mpiT.co_wait(self.coq)
    151 end
    152 
    153 function pClient:stop()
    154    self:wait()
    155    -- stop servers
    156    for i,srank in pairs(self.sranks) do
    157       -- print('to stop server', srank)
    158       local co0 = mpiT.co_execute(pClient_sendstop,{self,srank})
    159       self.coq:push(co0)
    160    end
    161    self:wait()
    162    self.state.io = false
    163    self.state.on = false
    164 end
    165 
    166 function pClient:start(param,grad)
    167    self.state.on = true
    168    self.state.io = true
    169    if param then
    170       self.pstorage = param:storage()
    171       self.plong = self.pstorage:size()
    172       if grad then
    173 	 self.gstorage = grad:storage()
    174 	 assert(self.plong == self.gstorage:size())
    175       end
    176    end
    177    -- print('i am pc',self.rank,'p',self.pstorage:size(),'g',self.gstorage:size())
    178    pClient_init(self)
    179 end