Examples
Simple data parallel example
Let us assume we are in a typical supervised learning situation. We have plenty of data (xinput, youtput),
and we search for the unknown parameters that minimize some norm or general functional, simply referred to
as the loss function. Furthermore, we assume that the loss function is just a sum of the losses per data
point. For example, consider the following squared error:
    def lossfunction(params):
        # compute local loss
        localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))
        return localloss
This function is usually fed into a gradient-based optimizer to find the optimal parameters.
We want to argue in the following that parallelizing this code in a data-parallel way is often as easy as
adding two calls to mpi4torch.MPI_Communicator.Allreduce():
    def lossfunction(params):
        # average initial params to bring all ranks on the same page
        params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size
        # compute local loss
        localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))
        # sum up the loss among all ranks
        return comm.Allreduce(localloss, mpi4torch.MPI_SUM)
mpi4torch.MPI_Communicator.Allreduce() is used once to compute the average of the incoming parameters
and once to collect the total loss.
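The assumption that the loss is a plain sum over data points is what makes the final Allreduce sufficient. The following is a minimal single-process sketch of that decomposition, not mpi4torch code: the chunks returned by torch.tensor_split stand in for the data held by the individual ranks, Python's sum stands in for the Allreduce with MPI_SUM, and the helper names local_loss and emulated_total_loss are hypothetical.

```python
import torch

torch.manual_seed(0)

def some_parametrized_function(inp, params):
    # quadratic model in inp, as used in the example program
    return (params[2] * inp + params[1]) * inp + params[0]

xinput = 2.0 * torch.rand(1000, dtype=torch.double)
gen_params = torch.tensor([0.1, 1.0, -2.0], dtype=torch.double)
youtput = some_parametrized_function(xinput, gen_params)
params = torch.arange(3, dtype=torch.double)

def local_loss(x, y):
    # squared error on the data points of one (emulated) rank
    return torch.sum(torch.square(y - some_parametrized_function(x, params)))

def emulated_total_loss(num_ranks):
    # each chunk plays the role of one rank's share of the data;
    # summing the local losses stands in for Allreduce with MPI_SUM
    x_chunks = torch.tensor_split(xinput, num_ranks)
    y_chunks = torch.tensor_split(youtput, num_ranks)
    return sum(local_loss(x, y) for x, y in zip(x_chunks, y_chunks))

serial_loss = local_loss(xinput, youtput)
for num_ranks in (1, 2, 3, 7):
    # the total should agree with the serial loss up to rounding
    print(num_ranks, float(emulated_total_loss(num_ranks)), float(serial_loss))
```

Up to floating-point rounding, the printed totals should not depend on how many chunks the data is split into.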
Embedded in a whole program, the parallelized loss function may look as follows (the code is also available in the examples folder of the git repository):
     1  import torch
     2  import mpi4torch
     3  import mpi4py.MPI
     4
     5  comm = mpi4torch.COMM_WORLD
     6
     7  torch.manual_seed(42)
     8
     9  num_points = 10000
    10  chunk_size = num_points // comm.size
    11  rest = num_points % comm.size
    12  if comm.rank < rest:
    13      chunk_size += 1
    14      offset = chunk_size * comm.rank
    15  else:
    16      offset = chunk_size * comm.rank + rest
    17
    18  xinput = 2.0 * torch.rand([num_points],dtype=torch.double)[offset:offset+chunk_size]
    19
    20  def some_parametrized_function(inp, params):
    21      return (params[2] * inp + params[1]) * inp + params[0]
    22
    23  gen_params = torch.tensor([0.1, 1.0, -2.0])
    24
    25  youtput = some_parametrized_function(xinput, gen_params)
    26
    27  def lossfunction(params):
    28      # average initial params to bring all ranks on the same page
    29      params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size
    30
    31      # compute local loss
    32      localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))
    33
    34      # sum up the loss among all ranks
    35      return comm.Allreduce(localloss, mpi4torch.MPI_SUM)
    36
    37  params = torch.arange(3, dtype=torch.double).requires_grad_()
    38
    39  # LBFGS only needs one outer iteration for a linear problem
    40  # with so few parameters
    41  num_iterations = 1
    42  optimizer = torch.optim.LBFGS([params], 1)
    43
    44  for i in range(num_iterations):
    45      def closure():
    46          loss = lossfunction(params)
    47          optimizer.zero_grad()
    48          loss.backward()
    49          if comm.rank == 0:
    50              print("Params: ", params)
    51              print("Loss  : ", loss)
    52          return loss
    53      optimizer.step(closure)
    54
    55  # only print output on rank 0
    56  if comm.rank == 0:
    57      print("Final parameters: ", params)
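The partitioning of the data points among the ranks (lines 9 to 16 of the listing) can be checked in isolation. The sketch below wraps the same arithmetic in a hypothetical helper, partition, which takes the communicator size and rank as plain integers instead of reading them from comm, and then verifies that the chunks of all ranks tile the index range without gaps or overlaps.

```python
def partition(num_points, size, rank):
    # same arithmetic as in the listing, with size and rank passed in explicitly
    chunk_size = num_points // size
    rest = num_points % size
    if rank < rest:
        # the first `rest` ranks take one extra point each
        chunk_size += 1
        offset = chunk_size * rank
    else:
        # skip the `rest` extra points handed to the lower ranks
        offset = chunk_size * rank + rest
    return offset, chunk_size

def covered_indices(num_points, size):
    # concatenate the index ranges of all ranks in rank order
    indices = []
    for rank in range(size):
        offset, chunk_size = partition(num_points, size, rank)
        indices.extend(range(offset, offset + chunk_size))
    return indices

# every index appears exactly once, also when size does not divide num_points
print(covered_indices(10000, 7) == list(range(10000)))
```

Because every rank draws the same random numbers from the same seed and then slices out its own chunk, this tiling means that the ranks together hold exactly the data set a serial run would use.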
Note that the averaging in line 29 might seem superfluous at first, since all ranks start off with the
same initial set of parameters. However, having the adjoint of mpi4torch.MPI_Communicator.Allreduce()
in the backward pass is essential for all instances of the LBFGS optimizer to perform the same update
on all ranks.
For the second call to mpi4torch.MPI_Communicator.Allreduce() in line 35 it is actually the other way
around: here the forward pass is crucial, whereas the backward pass merely adds up the ones with which
loss.backward() seeds the gradient on each rank. The result is a gradient with a single entry that just
contains the communicator size.
It is easy to see that the forward pass is independent of the number of ranks used to compute the result.
That the parallelized backward pass also gives the same result may at first seem a bit surprising,
as we already saw that the gradient with respect to localloss just stores the size of the MPI
communicator. However, the corresponding backward code of the averaging in line 29 divides by comm.size
again, such that in total the gradients from all ranks are simply added up. The final gradient as stored
in params.grad is thus also independent of the number of processes.
Starting off with the same parameters on all ranks, it is thus ensured that all local LBFGS instances see the same parameters, the same losses and the same gradients, and thus perform the identical operations and give the same result.
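The gradient argument can be replayed in a single process. The sketch below does not use mpi4torch. It assumes, as described above, that the adjoint of each Allreduce with MPI_SUM is again a sum over ranks, and spells out the two backward steps by hand with torch.autograd.grad. The names serial_grad and emulated_parallel_grad are hypothetical.

```python
import torch

torch.manual_seed(0)

def model(inp, params):
    # quadratic model in inp, as in the example program
    return (params[2] * inp + params[1]) * inp + params[0]

x = 2.0 * torch.rand(1000, dtype=torch.double)
y = model(x, torch.tensor([0.1, 1.0, -2.0], dtype=torch.double))
params0 = torch.arange(3, dtype=torch.double)

def serial_grad():
    # reference: gradient of the full loss without any parallelization
    p = params0.clone().requires_grad_()
    loss = torch.sum(torch.square(y - model(x, p)))
    (grad,) = torch.autograd.grad(loss, p)
    return grad

def emulated_parallel_grad(size):
    local_grads = []
    for x_r, y_r in zip(torch.tensor_split(x, size), torch.tensor_split(y, size)):
        p = params0.clone().requires_grad_()
        localloss = torch.sum(torch.square(y_r - model(x_r, p)))
        # adjoint of the final Allreduce: each rank seeds a one,
        # and summing over the ranks yields the communicator size
        seed = torch.tensor(float(size), dtype=torch.double)
        (grad,) = torch.autograd.grad(localloss, p, grad_outputs=seed)
        local_grads.append(grad)
    # adjoint of the first Allreduce sums over the ranks;
    # the backward of the division by comm.size rescales the result
    return torch.stack(local_grads).sum(dim=0) / size

print(serial_grad())
print(emulated_parallel_grad(4))
```

Under these assumptions the two printed gradients should agree up to floating-point rounding, for any number of emulated ranks.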