‹ projects

cluster-rnn

a distributed Torch7 RNN cluster over MPI
Log | Files | Refs | README

testreduceall.lua (1945B)


      1 -- mpirun -n 2 luajit testreduceall.lua
      2 
      3 local ssize = 3*4096*4096
      4 local usecuda = false
      5 
      6 require 'mpiT'
      7 dofile('init.lua')
      8 mpiT.Init()
      9 local world = mpiT.COMM_WORLD
     10 local rank = mpiT.get_rank(world)
     11 local size = mpiT.get_size(world)
     12 
     13 assert((size>0) and (size%2==0))
     14 
     15 local conf = {}
     16 conf.rank = rank
     17 conf.world = world
     18 conf.sranks = {}
     19 conf.cranks = {}
     20 for i = 0,size-1 do
     21    if i < size/2 then
     22       table.insert(conf.sranks,i) --as server
     23    else
     24       table.insert(conf.cranks,i) --as client
     25    end
     26 end
     27 
     28 if rank < size/2 then
     29    print('rank ' .. rank .. ' is server.')
     30    -- require 'cutorch' -- in case usecuda==true and your mpirun does not stop, try uncomment this out.
     31    torch.setdefaulttensortype('torch.FloatTensor')
     32    print(rank,'use cpu')
     33    -- server   
     34    local ps = pServer(conf)
     35    ps:start()
     36 else
     37    print('rank ' .. rank .. ' is client.')
     38 
     39    -- use gpu?
     40    if usecuda then
     41       require 'cutorch'
     42       torch.setdefaulttensortype('torch.CudaTensor')
     43       local gpus = cutorch.getDeviceCount()
     44       local gpu =(rank%(size/2)) % gpus + 1
     45       cutorch.setDevice(gpu)
     46       print(rank,'use gpu',gpu)
     47    else
     48       torch.setdefaulttensortype('torch.FloatTensor')
     49       print(rank,'use cpu')
     50    end
     51 
     52    -- client
     53    local theta = torch.Tensor(ssize)
     54    local grad = torch.Tensor(ssize)
     55    local pc = pClient(conf)
     56    pc:start(theta,grad)
     57 
     58    local begin = os.time()
     59    for t=1,3 do
     60       print('rank' .. rank .. 'pingpong')
     61       pc:async_recv_param()
     62       pc:async_send_grad()
     63       pc:wait()
     64    end
     65    local now = os.time()
     66    print('rank ' .. rank .. ' bandwidth(bi-direction) is ' .. (2*ssize*4/(now-begin)/1024/1024) .. ' MBytes/sec')
     67    
     68    pc:stop()
     69    print('pc stopped')
     70 end
     71 
     72    torch.manualSeed(123)
     73    a = torch.FloatTensor(3,2):uniform()
     74    b = torch.FloatTensor(3,2)
     75    print('a', a)
     76    mpiT.Allreduce(a:storage(), b:storage(), 6, mpiT.FLOAT, mpiT.SUM, mpiT.COMM_WORLD)
     77    print('b after a', b)
     78 
     79 mpiT.Finalize()