Examples

Simple data parallel example

Let us assume we are in a typical supervised learning situation. We have plenty of data (xinput, youtput), and we search for unknown parameters that minimize some norm or, more generally, some functional, which we simply refer to as the loss function. Furthermore, we assume that the loss function is just a summation of losses per data point. Consider, e.g., the following squared error:

def lossfunction(params):
    # compute local loss
    localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))
    return localloss
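
In formulas, writing f for some_parametrized_function, p for params, and (x_i, y_i) for the individual data points, this is just the familiar sum of squared errors over the local data:

L(p) = \sum_i \left( y_i - f(x_i, p) \right)^2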

This function is usually fed into a gradient-based optimizer to find the optimal parameters. In the following we want to argue that parallelizing this code in a data-parallel way is often as easy as adding two calls to mpi4torch.MPI_Communicator.Allreduce():

def lossfunction(params):
    # average initial params to bring all ranks on the same page
    params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size

    # compute local loss
    localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))

    # sum up the loss among all ranks
    return comm.Allreduce(localloss, mpi4torch.MPI_SUM)

mpi4torch.MPI_Communicator.Allreduce() is used twice: once to compute the average of the incoming parameters and once to collect the total loss.

Embedded in a whole program this may look as follows (the code is also available in the examples folder of the git repository):

 1  import torch
 2  import mpi4torch
 3  import mpi4py.MPI
 4
 5  comm = mpi4torch.COMM_WORLD
 6
 7  torch.manual_seed(42)
 8
 9  num_points = 10000
10  chunk_size = num_points // comm.size
11  rest = num_points % comm.size
12  if comm.rank < rest:
13      chunk_size += 1
14      offset = chunk_size * comm.rank
15  else:
16      offset = chunk_size * comm.rank + rest
17
18  xinput = 2.0 * torch.rand([num_points], dtype=torch.double)[offset:offset+chunk_size]
19
20  def some_parametrized_function(inp, params):
21      return (params[2] * inp + params[1]) * inp + params[0]
22
23  gen_params = torch.tensor([0.1, 1.0, -2.0])
24
25  youtput = some_parametrized_function(xinput, gen_params)
26
27  def lossfunction(params):
28      # average initial params to bring all ranks on the same page
29      params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size
30
31      # compute local loss
32      localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))
33
34      # sum up the loss among all ranks
35      return comm.Allreduce(localloss, mpi4torch.MPI_SUM)
36
37  params = torch.arange(3, dtype=torch.double).requires_grad_()
38
39  # LBFGS only needs one outer iteration for a linear problem
40  # with so few parameters
41  num_iterations = 1
42  optimizer = torch.optim.LBFGS([params], 1)
43
44  for i in range(num_iterations):
45      def closure():
46          loss = lossfunction(params)
47          optimizer.zero_grad()
48          loss.backward()
49          if comm.rank == 0:
50              print("Params: ", params)
51              print("Loss  : ", loss)
52          return loss
53      optimizer.step(closure)
54
55  # only print output on rank 0
56  if comm.rank == 0:
57      print("Final parameters: ", params)

Note that although the averaging in line 29 might seem superfluous at first, since all ranks start off with the same initial set of parameters, it is essential: the adjoint of mpi4torch.MPI_Communicator.Allreduce() in the backward pass is again an Allreduce, which synchronizes the gradients across ranks and thereby ensures that all instances of the LBFGS optimizer perform the same update on all ranks.
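
This can be checked empirically. The following lines are a hedged sketch, not part of the repository example: appended to the closure after loss.backward(), they verify that params.grad coincides on all ranks by averaging it over the communicator and comparing with the local copy.

# sketch of a sanity check (hypothetical, not in the repository example):
# the gradient should be the same on every rank, so its average over the
# communicator must agree with the local copy up to rounding
avg_grad = comm.Allreduce(params.grad, mpi4torch.MPI_SUM) / comm.size
print("rank", comm.rank, "grad consistent:", torch.allclose(avg_grad, params.grad))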

For the second call to mpi4torch.MPI_Communicator.Allreduce() in line 35 it is actually the other way around: here the forward pass is crucial, while the backward pass merely adds up the seed gradients of one coming from the different ranks, which (surprise) results in a length-one tensor that just contains the communicator size.
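
The effect is easy to reproduce in isolation. Here is a minimal, hypothetical sketch (not part of the example) that backpropagates through a single sum-Allreduce; on every rank the resulting gradient simply contains the number of ranks:

x = torch.ones(1, dtype=torch.double, requires_grad=True)
y = comm.Allreduce(x, mpi4torch.MPI_SUM)   # forward: sum of x over all ranks
y.backward()                               # backward: sums the seed gradients of one
print(x.grad)                              # tensor([N.]) with N == comm.size, on every rank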

It is easy to see that the forward pass is independent of the number of ranks used to compute the result. That the parallelized backward pass also gives the same result may at first seem a bit surprising, since we have just seen that the gradient with respect to localloss simply stores the size of the MPI communicator. However, the corresponding backward code of the averaging in line 29 divides by comm.size again, such that in total the gradients from all ranks are simply added up. The final gradient as stored in params.grad is thus also independent of the number of processes.
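
One can also verify this independence numerically. The following is a hedged sketch under the assumption that the full, unsliced tensors (here called xinput_full and youtput_full, which the example above does not keep around) are still available; it compares the data-parallel gradient with a serial reference computed on the complete data set.

# hypothetical check: serial reference gradient on the full data set
params_ref = params.detach().clone().requires_grad_()
serial_loss = torch.sum(torch.square(youtput_full - some_parametrized_function(xinput_full, params_ref)))
serial_loss.backward()
if comm.rank == 0:
    print(torch.allclose(params.grad, params_ref.grad))  # expected: True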

Starting off with the same parameters on all ranks thus ensures that all local LBFGS instances see the same parameters, the same losses and the same gradients, and hence perform identical operations and arrive at the same result.