pserver.lua (5026B)
-- Parameter server
-- Author: Sixin Zhang (zsx@cims.nyu.edu)
-- Important change:
-- init the parameter from the first local worker (once and only once)
-- before other services get started. From the pclient side,
-- it's also the first local pclient (worker) who sends its parameter
-- to the pserver.
require 'mpiT'

local pServer = torch.class('pServer')

-- Build a parameter server.
--   conf.rank   : rank of this server (defaults to -1)
--   conf.cranks : table of client ranks served by this server (defaults to {})
--   conf.world  : communicator to use (defaults to mpiT.COMM_WORLD)
--   state       : optional shared state table; the fields on/io/iostop are
--                 (re)initialised here
function pServer:__init(conf,state)
   self.state = state or {}
   self.rank = conf.rank or -1
   self.cranks = conf.cranks or {} -- client ranks
   -- NOTE(review): buffers below are created with torch.Storage/torch.Tensor
   -- (the default tensor type) while messages are typed mpiT.FLOAT --
   -- presumably callers set the default tensor type to float; confirm.
   self.mtype = mpiT.FLOAT
   self.mworld = conf.world or mpiT.COMM_WORLD

   self.offset = -1 -- offset of param grad
   self.size = -1 -- size of param grad
   self.tensor = {} -- tensor from storage
   self.storage = {} -- param gradient storage
   self.emptys = torch.Storage() -- zero-length payload for header/tail messages

   self.state.on = false
   self.state.io = false
   self.state.iostop = 0 -- number of stop requests counted so far
   self.coq = Queue() -- coroutine queue

   self.conf = conf
end

-- Coroutine body: receive the (offset, size) pair from client `crank` and
-- allocate the buffers that depend on it. The first client to report fixes
-- offset/size and creates the shared parameter buffer; every later client
-- must report the same values. Each client gets its own gradient buffer.
local function pServer_recvinit(self,crank)
   coroutine.yield(mpiT.signal_INIT)
   -- get meta info
   local cinfo = torch.LongStorage(2)
   mpiT.aio_recv(cinfo,2,mpiT.LONG,
                 crank,mpiT.tag_ps_recv_init,self.mworld,self.state)
   -- init storage
   if self.offset == -1 then
      self.offset = cinfo[1]
      self.size = cinfo[2]
      self.storage.p = torch.Storage(self.size)
      self.storage.g = {}
      self.tensor.p = torch.Tensor(self.storage.p)
      self.tensor.g = {}
   else
      assert(self.offset == cinfo[1])
      assert(self.size == cinfo[2])
   end
   -- per-client gradient buffer and a tensor over that same storage
   self.storage.g[crank] = torch.Storage(self.size)
   self.tensor.g[crank] = torch.Tensor(self.storage.g[crank])
   --print('pServer:recvinit',self.rank,crank,self.offset,self.size)
   coroutine.yield(mpiT.signal_DONE)
end

-- Coroutine body: while the server is on, wait for an empty header message
-- from client `crank` and answer it with the current parameter buffer.
local function pServer_sendparam(self,crank)
   coroutine.yield(mpiT.signal_INIT)
   while (self.state.on) do
      --print('pServer_sendparam to recv',crank,self.size)
      mpiT.aio_recv(self.emptys,0,self.mtype,
                    crank,mpiT.tag_ps_recv_header,self.mworld,self.state)
      --print('ps ' .. self.rank .. ' send param to ' .. crank)
      -- state.io is re-checked because it may have been cleared while the
      -- receive above was pending
      if self.state.io then
         mpiT.aio_send(self.storage.p,self.size,self.mtype,
                       crank,mpiT.tag_ps_send_param,self.mworld,self.state)
      end
   end
   coroutine.yield(mpiT.signal_DONE)
end

-- Coroutine body: while the server is on, receive a gradient from client
-- `crank`, add it to the parameters and acknowledge with an empty tail
-- message.
-- Warning: no lock on self.tensor.p during this recvgrad, expect inconsistent read
local function pServer_recvgrad(self,crank)
   coroutine.yield(mpiT.signal_INIT)
   while (self.state.on) do
      -- recv
      mpiT.aio_recv(self.storage.g[crank],self.size,self.mtype,
                    crank,mpiT.tag_ps_recv_grad,self.mworld,self.state)
      --print('ps ' .. self.rank .. ' recv grad from ' .. crank)
      -- apply
      -- NOTE(review): the add is unconditional, whereas pServer_sendparam
      -- re-checks the state after its receive. If aio_recv can return without
      -- data once the server is stopped, a stale gradient is applied once
      -- more here -- confirm against mpiT.aio_recv.
      self.tensor.p:add(self.tensor.g[crank])
      if self.state.on then
         mpiT.aio_send(self.emptys,0,self.mtype,
                       crank,mpiT.tag_ps_recv_grad_tail,self.mworld,self.state)
      end
   end
   coroutine.yield(mpiT.signal_DONE)
end

-- Coroutine body: receive the full parameter buffer from client `crank`
-- once, then acknowledge with an empty tail message.
local function pServer_recvparam(self,crank)
   coroutine.yield(mpiT.signal_INIT)
   mpiT.aio_recv(self.storage.p,self.size,self.mtype,
                 crank,mpiT.tag_ps_recv_param,self.mworld,self.state)
   if self.state.on then
      mpiT.aio_send(self.emptys,0,self.mtype,
                    crank,mpiT.tag_ps_recv_param_tail,self.mworld,self.state)
   end
   --print('ps ' .. self.rank .. ' recv param from ' .. crank)
   coroutine.yield(mpiT.signal_DONE)
end

-- stop

-- Number of key/value pairs in `tbl` (0 when `tbl` is nil or false).
-- Unlike the # operator this counts every key, not only the array part.
function table.len(tbl)
   local l=0
   if tbl then
      for _ in pairs(tbl) do
         l=l+1
      end
   end
   return l
end

-- Coroutine body: receive the one-byte stop flag from client `crank`. Once
-- every client listed in self.cranks has asked to stop, the server turns
-- itself off.
local function pServer_recvstop(self,crank)
   coroutine.yield(mpiT.signal_INIT)
   local tostop = torch.ByteStorage(1)
   mpiT.aio_recv(tostop,1,mpiT.BYTE,
                 crank,mpiT.tag_ps_recv_stop,self.mworld,self.state)
   -- Compare against 0 explicitly: in Lua the number 0 is truthy, so a bare
   -- `if tostop[1]` would accept any byte value.
   if tostop[1] ~= 0 then
      self.state.iostop = self.state.iostop + 1
      if self.state.iostop == table.len(self.cranks) then
         self.state.on = false
         self.state.io = false
         -- print('server', self.rank, 'finally stopped by', crank)
      end
   end
   coroutine.yield(mpiT.signal_DONE)
end

-- Run the server until every client has requested a stop.
--   1. receive the init meta info from every client;
--   2. receive the initial parameters from the first client (self.cranks[1]);
--   3. launch the stop / gradient / parameter services for every client and
--      wait for them to finish.
function pServer:start()
   self.state.on = true
   self.state.io = true
   -- init
   self.coq:clear()
   for _,crank in pairs(self.cranks) do
      local co = mpiT.co_execute(pServer_recvinit,{self,crank})
      self.coq:push(co)
   end
   mpiT.co_wait(self.coq)
   -- init the parameter from the first local worker. This is done before the
   -- service loop rather than inside it: pairs() gives no ordering guarantee,
   -- and the services of other workers must not be queued yet when we wait
   -- for this single transfer.
   local first = self.cranks[1]
   if first ~= nil then
      local co3 = mpiT.co_execute(pServer_recvparam,{self,first})
      self.coq:push(co3)
      mpiT.co_wait(self.coq)
   end
   -- on request
   for _,crank in pairs(self.cranks) do
      local co0 = mpiT.co_execute(pServer_recvstop,{self,crank})
      local co1 = mpiT.co_execute(pServer_recvgrad,{self,crank})
      local co2 = mpiT.co_execute(pServer_sendparam,{self,crank})
      self.coq:push(co0)
      self.coq:push(co1)
      self.coq:push(co2)
   end
   mpiT.co_wait(self.coq)
end