‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

ptest.lua (1743B)


      -- mpirun -n 2 luajit ptest.lua
--
-- Bandwidth benchmark for the pServer/pClient pair. Must be launched with a
-- positive, even number of MPI ranks: the lower half of the ranks act as
-- servers and the upper half act as clients.

-- Number of elements in each of the client's theta and grad tensors.
local ssize = 10*4096*4096
-- When true, clients allocate their tensors on the GPU via cutorch.
local usecuda = false

require 'mpiT'
-- NOTE(review): presumably provides pServer/pClient used below -- confirm in init.lua.
dofile('init.lua')
mpiT.Init()
local world = mpiT.COMM_WORLD
local rank = mpiT.get_rank(world)
local size = mpiT.get_size(world)

-- Servers and clients are paired half/half, so an odd rank count cannot work.
assert(size > 0 and size % 2 == 0,
   'ptest.lua needs a positive, even number of MPI ranks (got ' .. tostring(size) .. ')')

-- Ranks [0, size/2) are servers; ranks [size/2, size) are clients.
local conf = {
   rank = rank,
   world = world,
   sranks = {},
   cranks = {},
}
for i = 0, size-1 do
   if i < size/2 then
      conf.sranks[#conf.sranks + 1] = i -- as server
   else
      conf.cranks[#conf.cranks + 1] = i -- as client
   end
end
     27 
     -- Role dispatch: ranks [0, size/2) run servers, ranks [size/2, size) run clients.

--- Run this rank as a server.
-- Blocks inside ps:start(); presumably it returns once the clients stop
-- (TODO confirm against pServer).
local function run_server()
   print('rank ' .. rank .. ' is server.')
   -- require 'cutorch' -- if usecuda is true and mpirun does not exit, try uncommenting this line.
   torch.setdefaulttensortype('torch.FloatTensor')
   print(rank,'use cpu')
   local ps = pServer(conf)
   ps:start()
end

--- Choose the client's default tensor type and, with usecuda, its GPU.
-- Client ranks are spread round-robin over the visible CUDA devices.
local function setup_client_device()
   if usecuda then
      require 'cutorch'
      torch.setdefaulttensortype('torch.CudaTensor')
      local gpus = cutorch.getDeviceCount()
      -- Avoid `x % 0` (NaN device index) when no GPU is visible.
      assert(gpus > 0, 'usecuda is true but cutorch reports no CUDA devices')
      local gpu = (rank % (size/2)) % gpus + 1
      cutorch.setDevice(gpu)
      print(rank,'use gpu',gpu)
   else
      torch.setdefaulttensortype('torch.FloatTensor')
      print(rank,'use cpu')
   end
end

--- Run this rank as a benchmark client.
-- Performs T iterations of async_recv_param + async_send_grad + wait, then
-- prints the measured bi-directional bandwidth and stops the client.
local function run_client()
   print('rank ' .. rank .. ' is client.')
   setup_client_device()

   local theta = torch.Tensor(ssize)
   local grad = torch.Tensor(ssize)
   local pc = pClient(conf)
   pc:start(theta,grad)

   -- torch.Timer measures wall-clock time with sub-second resolution.
   -- os.time() ticks once per second, which skews the result and divides by
   -- zero when the loop completes within a single second.
   local timer = torch.Timer()
   local T = 100
   for t=1,T do
      print('t=' .. t .. ':rank' .. rank .. 'pingpong')
      pc:async_recv_param()
      pc:async_send_grad()
      pc:wait()
   end
   local elapsed = timer:time().real
   -- Each iteration moves one ssize-element tensor in each direction.
   local bytes = 2 * T * ssize * theta:elementSize()
   print('rank ' .. rank .. ' bandwidth(bi-direction) is ' .. (bytes/elapsed/1024/1024) .. ' MBytes/sec')
   pc:stop()
   print('pc stopped')
end

if rank < size/2 then
   run_server()
else
   run_client()
end

mpiT.Finalize()