pclient.lua (5356B)
-- Parameter client
-- Author: Sixin Zhang (zsx@cims.nyu.edu)
require 'mpiT'

-- pClient: client side of a sharded parameter store.
-- The flat parameter storage (self.pstorage) and gradient storage
-- (self.gstorage) are cut into one contiguous slice per server rank; every
-- transfer below moves exactly the slice that belongs to the addressed server.
-- All communication is issued from coroutines that are pushed onto self.coq
-- and driven through mpiT.co_ping / mpiT.co_wait.
local pClient = torch.class('pClient')

-- Build a client from a configuration table.
-- conf fields read here (each with the default shown in the code):
--   rank, sranks, cranks, plong, pstorage, gstorage, world
-- state: optional table shared with the mpiT aio calls; its `on` and `io`
--        fields are overwritten with false here (the caller's table is
--        mutated when one is passed in).
function pClient:__init(conf,state)
   self.state = state or {}
   self.rank = conf.rank or -1
   self.sranks = conf.sranks or {} -- server ranks (expected to be a sequence)
   self.cranks = conf.cranks or {} -- client ranks
   self.plong = conf.plong or 0 -- size of whole parameter
   self.pstorage = conf.pstorage or torch.Storage()
   self.gstorage = conf.gstorage or torch.Storage()
   self.emptys = torch.Storage() -- buffer for the zero-length messages
   self.sinfo = {} -- srank -> {offset=,size=}, filled by pClient_sendinit
   self.mtype = mpiT.FLOAT
   self.mworld = conf.world or mpiT.COMM_WORLD
   self.coq = Queue() -- coroutine queue
   self.state.on = false
   self.state.io = false
   self.conf = conf
end

-- Return a storage covering only the slice of `storage` assigned to `srank`.
-- torch.Storage(storage,offset,size) is expected to be a view sharing memory
-- with `storage` (see the Torch7 Storage documentation), so data received
-- into the result lands in the original buffer.
-- Raises an error if no slice was registered for `srank`; slices are only
-- registered by pClient_sendinit, which runs from pClient:start.
local function pClient_slice(self,storage,srank)
   local info = self.sinfo[srank]
   if not info then
      error('pClient: no slice registered for server rank ' ..
            tostring(srank) .. ' (pClient:start must run first)')
   end
   return torch.Storage(storage,info.offset,info.size)
end

-- Coroutine body: record the (offset,size) slice owned by `srank` and send
-- the pair to that server as two LONG values.
local function pClient_sendinit(self,srank,offset,size)
   coroutine.yield(mpiT.signal_INIT)
   --print('pClient:sendinit',self.rank,srank,offset,size)
   self.sinfo[srank] = {offset = offset, size = size}
   local cinfo = torch.LongStorage(2)
   cinfo[1] = offset
   cinfo[2] = size
   mpiT.aio_send(cinfo,2,mpiT.LONG,
                 srank,mpiT.tag_ps_recv_init,self.mworld,self.state)
   --print('pClient:sendinit done')
   coroutine.yield(mpiT.signal_DONE)
end

-- Coroutine body: send a one-byte stop message to `srank`.
local function pClient_sendstop(self,srank)
   coroutine.yield(mpiT.signal_INIT)
   local tostop = torch.ByteStorage(1):fill(1)
   mpiT.aio_send(tostop,1,mpiT.BYTE,
                 srank,mpiT.tag_ps_recv_stop,self.mworld,self.state)
   coroutine.yield(mpiT.signal_DONE)
end

-- Coroutine body: send this server's slice of `grad`, then receive a
-- zero-length message tagged tag_ps_recv_grad_tail
-- (presumably the server's acknowledgement -- confirm against the server).
local function pClient_sendgrad(self,grad,srank)
   coroutine.yield(mpiT.signal_INIT)
   local sgrad = pClient_slice(self,grad,srank)
   mpiT.aio_send(sgrad,sgrad:size(),self.mtype,
                 srank,mpiT.tag_ps_recv_grad,self.mworld,self.state)
   mpiT.aio_recv(self.emptys,0,self.mtype,
                 srank,mpiT.tag_ps_recv_grad_tail,self.mworld,self.state)
   coroutine.yield(mpiT.signal_DONE)
end

-- Coroutine body: send this server's slice of `param`, then receive a
-- zero-length message tagged tag_ps_recv_param_tail
-- (presumably the server's acknowledgement -- confirm against the server).
local function pClient_sendparam(self,param,srank)
   coroutine.yield(mpiT.signal_INIT)
   local sparam = pClient_slice(self,param,srank)
   mpiT.aio_send(sparam,sparam:size(),self.mtype,
                 srank,mpiT.tag_ps_recv_param,self.mworld,self.state)
   mpiT.aio_recv(self.emptys,0,self.mtype,
                 srank,mpiT.tag_ps_recv_param_tail,self.mworld,self.state)
   coroutine.yield(mpiT.signal_DONE)
end

-- Coroutine body: send a zero-length header to `srank`, then receive the
-- server's data into this server's slice of `param`.
local function pClient_recvparam(self,param,srank)
   coroutine.yield(mpiT.signal_INIT)
   mpiT.aio_send(self.emptys,0,self.mtype,
                 srank,mpiT.tag_ps_recv_header,self.mworld,self.state)
   -- the slice is built after the header is sent, as in the original order
   local sparam = pClient_slice(self,param,srank)
   mpiT.aio_recv(sparam,sparam:size(),self.mtype,
                 srank,mpiT.tag_ps_send_param,self.mworld,self.state)
   coroutine.yield(mpiT.signal_DONE)
end

-- Queue one receive-parameter coroutine per server. Nothing is waited for
-- here; use pClient:ping or pClient:wait to drive the queue.
function pClient:async_recv_param()
   local param = self.pstorage
   for _,srank in ipairs(self.sranks) do
      --print('pc ' .. self.rank .. ' recv param from ' .. srank)
      local co = mpiT.co_execute(pClient_recvparam,{self,param,srank})
      self.coq:push(co)
   end
end

-- Queue one send-gradient coroutine per server (source: self.gstorage).
function pClient:async_send_grad()
   local grad = self.gstorage
   for _,srank in ipairs(self.sranks) do
      --print('pc ' .. self.rank .. ' send grad to ' .. srank)
      local co = mpiT.co_execute(pClient_sendgrad,{self,grad,srank})
      self.coq:push(co)
   end
end

-- Queue one send-parameter coroutine per server (source: self.pstorage).
function pClient:async_send_param()
   local param = self.pstorage
   for _,srank in ipairs(self.sranks) do
      --print('pc send param to ' .. srank)
      local co = mpiT.co_execute(pClient_sendparam,{self,param,srank})
      self.coq:push(co)
   end
end

-- Assign each server its slice of the parameter vector, tell the servers,
-- and let the first client upload the initial parameters.
local function pClient_init(self)
   -- set offset size for each piece of parameter server
   local sranks = self.sranks
   local nservers = #sranks
   local chunk = 0
   if nservers > 0 then
      chunk = math.floor(self.plong/nservers)
   end
   local offset = 1 -- storage offsets are 1-based
   -- Walk the servers in index order: the offsets accumulate, so the
   -- traversal order decides which slice each server gets. pairs() gives
   -- no ordering guarantee, hence the numeric loop.
   for i = 1,nservers do
      local size = chunk
      if i == nservers then
         -- the last server also takes the remainder of the division
         size = self.plong - offset + 1
      end
      local co = mpiT.co_execute(pClient_sendinit,{self,sranks[i],offset,size})
      self.coq:push(co)
      offset = offset + size
   end
   mpiT.co_wait(self.coq)
   -- init pserver param: only the first client rank uploads its parameters
   if self.rank == self.cranks[1] then
      self:async_send_param()
   end
   mpiT.co_wait(self.coq)
end

-- Call mpiT.co_ping on the coroutine queue `nb` times
-- (default: the queue length at the time of the call).
function pClient:ping(nb)
   local count = nb or self.coq:len()
   for _ = 1,count do
      mpiT.co_ping(self.coq)
   end
end

-- Rebind the client to new buffers.
-- param: optional object whose :storage() becomes the parameter storage;
--        plong is updated to its size.
-- grad:  optional, only used when `param` is given; its :storage() becomes
--        the gradient storage and must have the same size as the parameters.
-- Does nothing when `param` is nil. Slices already registered in self.sinfo
-- are not recomputed here.
function pClient:reset(param,grad)
   if not param then
      return
   end
   self.pstorage = param:storage()
   self.plong = self.pstorage:size()
   if grad then
      self.gstorage = grad:storage()
      assert(self.plong == self.gstorage:size(),
             'pClient: parameter and gradient storages differ in size')
   end
end

-- Run mpiT.co_wait on the coroutine queue.
function pClient:wait()
   mpiT.co_wait(self.coq)
end

-- Finish outstanding work, send a stop message to every server, wait for
-- those sends, then clear the state flags.
function pClient:stop()
   self:wait()
   -- stop servers
   for _,srank in ipairs(self.sranks) do
      -- print('to stop server', srank)
      local co = mpiT.co_execute(pClient_sendstop,{self,srank})
      self.coq:push(co)
   end
   self:wait()
   self.state.io = false
   self.state.on = false
end

-- Set the state flags, optionally rebind the buffers (same rules as
-- pClient:reset), then run the initial handshake with the servers.
function pClient:start(param,grad)
   self.state.on = true
   self.state.io = true
   self:reset(param,grad)
   -- print('i am pc',self.rank,'p',self.pstorage:size(),'g',self.gstorage:size())
   pClient_init(self)
end