How does PyTorch implement DataParallel?

PyTorch can send batches and models to different GPUs automatically with DataParallel(model). How is that possible? I assume you know that PyTorch uses a dynamic computational graph and that Python has a GIL. The PyTorch version discussed here is v1.0.1.

This is a complicated question, so I asked it on the PyTorch forum and got a reply from Sebastian Raschka.

DataParallel’s process at a high level

TL;DR:

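PyTorch tries hard to avoid copying. DataParallel splits each input tensor along the scatter dimension (dim=0, the batch dimension, by default) into chunks that are views of the original, so data is only copied when each chunk is sent to its GPU.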

DataParallel interface

Like every nn.Module, DataParallel defines a constructor and a forward method. Let’s focus on forward.

def forward(self, *inputs, **kwargs):
    # ...
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) # Step 1, 2
    if len(self.device_ids) == 1:
        return self.module(*inputs[0], **kwargs[0]) # Single GPU: no replication needed
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) # Step 3
    outputs = self.parallel_apply(replicas, inputs, kwargs) # Step 4
    return self.gather(outputs, self.output_device) # Step 5

Four of the steps are factored out into instance methods: scatter, replicate, parallel_apply and gather. They perform steps 1 and 2, step 3, step 4 and step 5 respectively.

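Keep in mind that PyTorch uses a dynamic computational graph, so there is no graph to copy ahead of time. Each replica records its own autograd graph while its forward runs in step 4, and the autograd-aware Functions used in steps 1, 3 and 5 (Scatter, Broadcast and Gather) connect those per-device graphs back to the original inputs and parameters. The call self.module(*inputs[0], **kwargs[0]) is not a graph-building step: in the full source it is the single-device shortcut, returned directly when len(self.device_ids) == 1.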
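To see the overall control flow in one place, here is a pure-Python sketch that mirrors the structure of DataParallel.forward in v1.0.1 as I understand it. Plain lists stand in for tensors and devices are just labels; the names data_parallel_sketch and double are hypothetical. Steps 3 and 4 are simplified: the real code copies the model to each GPU and runs the replicas in threads.

```python
def data_parallel_sketch(module, batch, device_ids):
    """Control flow of DataParallel.forward, on plain lists.

    `module` maps a list of rows to a list of outputs; devices are labels only.
    """
    if not device_ids:                       # no GPUs: run the module as-is
        return module(batch)
    # Steps 1-2: split the batch along dim 0 into ceil(len / n)-sized chunks.
    step = max(1, -(-len(batch) // len(device_ids)))
    inputs = [batch[i:i + step] for i in range(0, len(batch), step)]
    if len(device_ids) == 1:                 # single device: no replication
        return module(inputs[0])
    # Step 3: one replica per chunk actually produced (device_ids[:len(inputs)]).
    replicas = [module for _ in device_ids[:len(inputs)]]
    # Step 4: apply each replica to its own chunk (threads in the real code).
    outputs = [replica(chunk) for replica, chunk in zip(replicas, inputs)]
    # Step 5: gather, i.e. concatenate the per-device outputs along dim 0.
    return [row for out in outputs for row in out]


double = lambda rows: [2 * r for r in rows]
print(data_parallel_sketch(double, [1, 2, 3, 4, 5], [0, 1]))  # [2, 4, 6, 8, 10]
```

Note that with a batch of 3 and 4 devices only three chunks exist, so only three replicas are created.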

How Function works

Before we look into the different steps, let’s look at Function. We are going to see it many times.

class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)):
    @staticmethod
    def forward(ctx, *args, **kwargs):
        r"""...
        It must accept a context ctx as the first argument, followed by any
        number of arguments (tensors or other types).

        The context can be used to store tensors that can be then retrieved
        during the backward pass.
        """
        raise NotImplementedError

    @staticmethod
    def backward(ctx, *grad_outputs):
        raise NotImplementedError

Several helpers below (Scatter, Gather, Broadcast and ReduceAddCoalesced) use Function as their base class. Moreover, since Function.apply appears a lot in the following, we should see how it works. (The metaclass part is tricky.)


def with_metaclass(meta, *bases):
    """Create a base class with a metaclass."""
    # This requires a bit of explanation: the basic idea is to make a dummy
    # metaclass for one level of class instantiation that replaces itself with
    # the actual metaclass.
    class metaclass(meta):
        def __new__(cls, name, this_bases, d):
            return meta(name, bases, d)
    return type.__new__(metaclass, 'temporary_class', (), {})

This is tricky, so let’s start from Python’s object model. In Python, classes are objects too, and by default they are created by type(name, bases, dict): the three arguments become the class’s __name__, __bases__ and __dict__. So type is the dynamic way to build a class, and a metaclass is a subclass of type that customizes this process. with_metaclass is a helper that works on both Python 2 and Python 3 (which declare metaclasses with different syntax): it returns a temporary base class whose dummy metaclass, when the real class statement runs, throws the temporary base away and calls meta(name, bases, d) instead. Namely, with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin) reads: FunctionMeta is the metaclass, and the remaining arguments are the bases.
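To check the trick in isolation, the snippet below reuses with_metaclass as quoted above and applies it to a hypothetical RecordingMeta metaclass and two hypothetical mixins. It is essentially the same idea as six.with_metaclass: the class produced by the class statement ends up with the real metaclass and the real bases, and the temporary class disappears from the hierarchy.

```python
def with_metaclass(meta, *bases):
    """Create a base class with a metaclass (as quoted from PyTorch)."""
    class metaclass(meta):
        def __new__(cls, name, this_bases, d):
            # Discard the temporary base; build the real class with `meta` and `bases`.
            return meta(name, bases, d)
    return type.__new__(metaclass, 'temporary_class', (), {})


class RecordingMeta(type):
    """Hypothetical metaclass that records the names of the classes it builds."""
    created = []

    def __init__(cls, name, bases, attrs):
        RecordingMeta.created.append(name)
        super(RecordingMeta, cls).__init__(name, bases, attrs)


class MixinA(object):
    pass


class MixinB(object):
    pass


class Demo(with_metaclass(RecordingMeta, MixinA, MixinB)):
    pass


print(type(Demo).__name__)                   # RecordingMeta
print([b.__name__ for b in Demo.__bases__])  # ['MixinA', 'MixinB']
print(RecordingMeta.created)                 # ['Demo']
```

The temporary class is created with type.__new__, which skips RecordingMeta.__init__, so only Demo is recorded.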

With that, we are not far from seeing how it works.

class FunctionMeta(type):
    def __init__(cls, name, bases, attrs):
        # ...
        backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
        setattr(cls, '_backward_cls', backward_fn)

        return super(FunctionMeta, cls).__init__(name, bases, attrs)

In other words, for every Function subclass, FunctionMeta builds a companion class named <Name>Backward and stores it in _backward_cls. That class’s base, BackwardCFunction, is simple:

class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin):
    # ...
    def apply(self, *args):
        return self._forward_cls.backward(self, *args)

Its apply simply calls the forward class’s static backward method. The more interesting base is _C._FunctionBase, which is the C++ type THPFunctionType; the THPFunction struct defines its data layout, and THPFunction_methods lists the methods it exposes:

static struct PyMethodDef THPFunction_methods[] = {
  {(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
  {(char*)"_do_forward", (PyCFunction)THPFunction_do_forward, METH_VARARGS, nullptr},
  {(char*)"_do_backward", (PyCFunction)THPFunction_do_backward, METH_VARARGS, nullptr},
  {(char*)"_register_hook_dict", (PyCFunction)THPFunction__register_hook_dict, METH_O, nullptr},
  {(char*)"register_hook", (PyCFunction)THPFunction_register_hook, METH_O, nullptr},
  {nullptr}
};

It defines the apply method, which is the one we are interested in:

PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
{
  // ...
  THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
  if (!backward_cls) return nullptr;
  THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
  if (!ctx_obj) return nullptr;
  THPFunction* ctx = (THPFunction*)ctx_obj.get();

  // Prepare inputs and allocate context (grad fn)
  auto info_pair = unpack_input<false>(inputs);
  UnpackedInput& unpacked_input = info_pair.first;
  InputFlags& input_info = info_pair.second;

  // Record input nodes if tracing
  // ...

  // Initialize backward function (and ctx)
  bool is_executable = input_info.is_executable;
  ctx->cdata.set_next_edges(std::move(input_info.next_edges));
  ctx->needs_input_grad = input_info.needs_input_grad.release();
  ctx->is_variable_input = std::move(input_info.is_variable_input);

  // Prepend ctx to input_tuple, in preparation for static method call
  auto num_args = PyTuple_GET_SIZE(inputs);
  THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, ctx_obj.release());
  // ...

  // Call forward
  THPObjectPtr tensor_outputs;
  {
    AutoGradMode grad_mode(false);
    THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
    if (!forward_fn) return nullptr;
    tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple);
    if (!tensor_outputs) return nullptr;
  }

  return process_outputs(cls, ctx, unpacked_input, inputs, std::move(tensor_outputs),
                         is_executable, node);
  END_HANDLE_TH_ERRORS
}

So cls.apply looks up _backward_cls and instantiates it as ctx (this object is also the grad_fn node of the outputs), records the autograd edges of the inputs, prepends ctx to the Python arguments and calls cls.forward with gradient recording disabled. It receives the class itself plus all arguments passed from Python. Notably, the ctx object carries information from the forward pass to the backward pass.
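The mechanics can be mimicked in pure Python. The sketch below is not PyTorch code: FunctionMetaSketch plays the role of FunctionMeta (attaching a <Name>Backward class), BackwardNode plays BackwardCFunction, and FunctionSketch.apply follows the same sequence as THPFunction_apply: create ctx from _backward_cls, prepend it to the arguments, call forward, and remember ctx as the output’s grad_fn. Value and MulConst are hypothetical, and edge recording and the disabled grad mode are left out.

```python
class BackwardNode(object):
    """Stand-in for BackwardCFunction: apply routes to the user's backward."""
    def apply(self, *grad_outputs):
        return self._forward_cls.backward(self, *grad_outputs)


class FunctionMetaSketch(type):
    """Stand-in for FunctionMeta: attach a <Name>Backward class to each Function."""
    def __init__(cls, name, bases, attrs):
        cls._backward_cls = type(name + 'Backward', (BackwardNode,),
                                 {'_forward_cls': cls})
        super(FunctionMetaSketch, cls).__init__(name, bases, attrs)


class Value(object):
    """Hypothetical scalar wrapper; grad_fn points at the ctx node."""
    def __init__(self, data, grad_fn=None):
        self.data = data
        self.grad_fn = grad_fn


class FunctionSketch(metaclass=FunctionMetaSketch):
    @classmethod
    def apply(cls, *inputs):
        ctx = cls._backward_cls()              # ctx doubles as the grad_fn node
        raw = [x.data if isinstance(x, Value) else x for x in inputs]  # unpack
        out = cls.forward(ctx, *raw)           # forward(ctx, *args)
        return Value(out, grad_fn=ctx)         # the output remembers its node


class MulConst(FunctionSketch):
    @staticmethod
    def forward(ctx, x, k):
        ctx.k = k                              # stash what backward will need
        return x * k

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.k, None          # no gradient for the constant k


y = MulConst.apply(Value(3.0), 2.0)
print(type(y.grad_fn).__name__)  # MulConstBackward
print(y.data)                    # 6.0
print(y.grad_fn.apply(1.0))      # (2.0, None)
```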

Now we can explore how DataParallel works.

Step 1 and Step 2: split the minibatch and move the chunks to the GPUs

DataParallel.scatter scatters the input arguments and returns a tuple of inputs and kwargs. The actual definition is here:

def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    # ...
    return inputs, kwargs

The scatter function uses a recursive closure to unwrap nested tuples, lists and dicts down to the individual tensors. Its core is the Scatter class, a Function that scatters the input in forward and gathers the gradients in backward.
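The recursion pattern can be sketched in pure Python. In this hypothetical version, Python lists play the role of tensors (so, unlike the real function, lists are leaves rather than containers), and split_leaf stands in for Scatter.apply; tuples and dicts are unwrapped into one copy per device, and any other object is simply repeated for every device.

```python
def split_leaf(seq, n):
    """Stand-in for Scatter.apply: split a list into at most n ceil-sized chunks."""
    size = max(1, -(-len(seq) // n))
    return [seq[i:i + size] for i in range(0, len(seq), size)]


def scatter(obj, n):
    """Recursively scatter `obj` into per-device pieces (sketch)."""
    def scatter_map(obj):
        if isinstance(obj, list):                     # a "tensor" in this sketch
            return split_leaf(obj, n)
        if isinstance(obj, tuple) and len(obj) > 0:   # one tuple per device
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, dict) and len(obj) > 0:    # one dict per device
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for _ in range(n)]                # non-tensors are repeated
    return scatter_map(obj)


print(scatter(([1, 2, 3, 4], 'tag'), 2))  # [([1, 2], 'tag'), ([3, 4], 'tag')]
print(scatter({'x': [1, 2, 3, 4]}, 2))    # [{'x': [1, 2]}, {'x': [3, 4]}]
```

Because zip stops at the shortest sequence, a leaf that yields fewer chunks than devices also limits how many per-device argument tuples are produced.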

class Scatter(Function):
    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        # ...
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)

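Let’s focus on the forward path. The backward path is simpler: it just gathers the gradients back to the input device with Gather.apply. In the forward path, comm.scatter splits the input and copies each chunk to its target GPU. When the copies run on separate copy streams (streams is not None), they must be synchronized: each device’s current stream waits for its copy stream (wait_stream), and record_stream marks the output as used on that stream so the caching allocator does not reuse its memory too early.

The actual work is done in C++.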

std::vector<at::Tensor> scatter(
    const at::Tensor& tensor,
    at::IntList devices,
    const c10::optional<std::vector<int64_t>>& chunk_sizes, // optional parameter is bad
    int64_t dim,
    const c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>& streams) {
  std::vector<at::Tensor> chunks;
  if (chunk_sizes) {
    // ...
    // split along `dim` using the given chunk sizes. The default uses the other path
    int64_t chunk_start = 0;
    for (size_t chunk = 0; chunk < chunk_sizes->size(); ++chunk) {
      const int64_t chunk_size = (*chunk_sizes)[chunk];
      AT_CHECK(chunk_size > 0, "Chunk size must be positive");
      chunks.push_back(tensor.narrow(dim, chunk_start, chunk_size));
      chunk_start += chunk_size;
    }
  } else {
    // by default, split into devices.size() chunks along `dim` (the last may be smaller)
    chunks = tensor.chunk(/*chunks=*/devices.size(), /*dim=*/dim);
  }
  at::cuda::OptionalCUDAStreamGuard cuda_guard;
  for (size_t chunk = 0; chunk < chunks.size(); ++chunk) {
    // ...
    // dispatch to different devices
    chunks[chunk] = chunks[chunk].contiguous().to(
        {at::DeviceType::CUDA, device_index}, /*non_blocking=*/true);
  }
  return chunks;
}

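The split itself is done by at::narrow in the end (tensor.chunk is also built on narrow). narrow only changes the offset, sizes and strides, so each chunk is a view that reuses the original memory; data is copied only when .contiguous().to(...) sends the chunk to its device. PyTorch takes zero-copy seriously at every level. An important detail is that the split happens along dim (0, the batch dimension, by default), and tensor.chunk produces pieces of ceil(size / n) along that dimension. Every input tensor is split independently, so all batched inputs should have the same size along that dimension for the i-th chunks to hold the same samples.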
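The chunk arithmetic and the view semantics can both be shown without a GPU. chunk_bounds is a hypothetical helper that computes the (start, length) pairs that tensor.chunk(chunks, dim) narrows to, assuming the ceil-sized split described above. As an analogy for narrow, memoryview slices share memory with the underlying buffer, so a write to the original is visible through the "chunks".

```python
def chunk_bounds(size, chunks):
    """(start, length) pairs that tensor.chunk(chunks, dim) narrows to along dim."""
    if size == 0:
        return []
    step = -(-size // chunks)                  # ceil(size / chunks)
    return [(start, min(step, size - start)) for start in range(0, size, step)]


print(chunk_bounds(10, 4))  # [(0, 3), (3, 3), (6, 3), (9, 1)]
print(chunk_bounds(2, 4))   # [(0, 1), (1, 1)]: only two chunks, so two replicas

# A narrow is a view: memoryview slices share memory with the original buffer,
# much like tensor.narrow shares storage with the tensor it came from.
batch = bytearray(b'abcdefghij')
views = [memoryview(batch)[s:s + n] for s, n in chunk_bounds(len(batch), 4)]
batch[0] = ord('z')
print(bytes(views[0]))      # b'zbc'
```

The second example is why forward replicates onto self.device_ids[:len(inputs)]: a batch smaller than the number of GPUs leaves some GPUs idle.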

Step 3: copy models to GPU

replicate sounds easy, but it is not. The module to copy is passed in as network:

def replicate(network, devices, detach=False):
    from ._functions import Broadcast

    # ...
    num_replicas = len(devices)

    # copy model parameters
    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = Broadcast.apply(devices, *params)
    if len(params) > 0:
        param_copies = [param_copies[i:i + len(params)]
                        for i in range(0, len(param_copies), len(params))]

    # copy model buffers
    buffers = list(network.buffers())
    buffer_indices = {buf: idx for idx, buf in enumerate(buffers)}
    buffer_copies = comm.broadcast_coalesced(buffers, devices)

    # copy model modules
    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    # construct modules
    # ...

    # copy every submodule
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            # ...
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        # and their parameters
        for key, param in module._parameters.items():
            # ...
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx].detach() \
                        if detach else param_copies[j][param_idx]
        # also buffers
        for key, buf in module._buffers.items():
            # ...
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    return [module_copies[j][0] for j in range(num_replicas)]

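The actual copying to the GPUs is easy to miss: it happens in Broadcast.apply(devices, *params) for parameters and in comm.broadcast_coalesced(buffers, devices) for buffers. Parameters go through Broadcast, a Function, so gradients can flow back to the originals; buffers are copied directly without autograd. The rest of the function rebuilds the module tree on each device and points each replica’s _modules, _parameters and _buffers at the copies. At this point, the model lives on every GPU.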
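The rewiring is hard to follow through the elisions, so here is a simplified pure-Python sketch of the same idea: build index maps, copy each parameter once per device, create an empty replica per module per device, and point each replica’s _modules and _parameters at the copies. Param and Node are hypothetical stand-ins for tensors and nn.Module, buffers are omitted, and id() is used for the dictionary keys (PyTorch tensors hash by identity, so the original can use the tensors themselves).

```python
class Param(object):
    """Hypothetical parameter: a value plus the device it lives on."""
    def __init__(self, value, device):
        self.value, self.device = value, device


class Node(object):
    """Hypothetical minimal module with named parameters and children."""
    def __init__(self, params=None, children=None):
        self._parameters = dict(params or {})
        self._modules = dict(children or {})

    def modules(self):
        yield self                                   # pre-order, root first
        for child in self._modules.values():
            for m in child.modules():
                yield m

    def parameters(self):
        for m in self.modules():
            for p in m._parameters.values():
                yield p


def replicate(network, devices):
    modules = list(network.modules())
    module_indices = {id(m): i for i, m in enumerate(modules)}
    params = list(network.parameters())
    param_indices = {id(p): i for i, p in enumerate(params)}
    # Stand-in for Broadcast.apply: one copy of every parameter per device.
    param_copies = [[Param(p.value, d) for p in params] for d in devices]
    # One empty replica per (device, module), rewired below.
    module_copies = [[Node() for _ in modules] for _ in devices]
    for i, module in enumerate(modules):
        for j in range(len(devices)):
            replica = module_copies[j][i]
            for key, child in module._modules.items():
                replica._modules[key] = module_copies[j][module_indices[id(child)]]
            for key, param in module._parameters.items():
                replica._parameters[key] = param_copies[j][param_indices[id(param)]]
    return [module_copies[j][0] for j in range(len(devices))]  # the roots


net = Node({'w': Param(1.0, 0)}, {'fc': Node({'b': Param(2.0, 0)})})
replicas = replicate(net, [0, 1])
print([r._modules['fc']._parameters['b'].device for r in replicas])  # [0, 1]
```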

Here is Broadcast, invoked as Broadcast.apply(devices, *params):

class Broadcast(Function):
    @staticmethod
    def forward(ctx, target_gpus, *inputs):
        # ...
        ctx.target_gpus = target_gpus
        if len(inputs) == 0:
            return tuple()
        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
        # ...
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None,) + ReduceAddCoalesced.apply(ctx.input_device, ctx.num_inputs, *grad_outputs)

Its backward counterpart, ReduceAddCoalesced, looks similar:

class ReduceAddCoalesced(Function):

    @staticmethod
    def forward(ctx, destination, num_inputs, *grads):
        # ...
        grads = [grads[i:i + num_inputs]
                 for i in range(0, len(grads), num_inputs)]
        return comm.reduce_add_coalesced(grads, destination)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None,) + Broadcast.apply(ctx.target_gpus, *grad_outputs)

So the core parts are comm.broadcast_coalesced and comm.reduce_add_coalesced.
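Why is the backward of Broadcast a reduce-add? If a parameter w is copied to N devices and copy i only affects the partial loss L_i, then by the chain rule dL/dw = sum over i of dL_i/dw_i, so the gradients of the copies must be summed back onto the source device. The sketch below checks this numerically for a scalar w; total_loss and the shard values are made up for illustration.

```python
def total_loss(w, shards):
    """Each replica holds its own copy of w (the broadcast) and computes
    a partial loss w * sum(shard); the partial losses are summed."""
    copies = [w for _ in shards]                    # "Broadcast" forward
    return sum(c * sum(s) for c, s in zip(copies, shards))


shards = [[1.0, 2.0], [3.0], [4.0, 5.0]]
per_replica_grads = [sum(s) for s in shards]        # d(partial loss)/d(copy)
reduced = sum(per_replica_grads)                    # "ReduceAdd" backward
eps = 1e-6
numeric = (total_loss(1.0 + eps, shards) - total_loss(1.0 - eps, shards)) / (2 * eps)
print(reduced)  # 15.0
```

The central-difference estimate numeric agrees with reduced, which is exactly the sum that ReduceAddCoalesced computes per parameter.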

comm.broadcast_coalesced is implemented in C++. Its inputs here are the model’s parameter (or buffer) tensors.

tensor_list2d broadcast_coalesced(TensorList tensors, IntArrayRef devices, size_t buffer_size) {
  if (!std::all_of(tensors.begin(), tensors.end(),
                   [&](const at::Tensor& t) { return t.get_device() == devices[0]; })) {
    throw std::runtime_error("all tensors must be on devices[0]");
  }
  // ...

  tensor_list2d outputs(devices.size());
  outputs[0] = tensors.vec();
  // ...

  for (auto & chunk : utils::take_tensors(tensors, buffer_size)) {
    // ...
    std::vector<at::Tensor> results;
    if (chunk.type().is_sparse()) {
      auto flat_tuple = utils::flatten_sparse_tensors(chunk.tensors);
      std::vector<at::Tensor> broadcast_indices = broadcast(flat_tuple.first, devices);
      std::vector<at::Tensor> broadcast_values = broadcast(flat_tuple.second, devices);
      // ...
    } else {
      std::vector<Tensor> results = broadcast(utils::flatten_dense_tensors(chunk.tensors),
                                              devices);
      // ...
    }
  }

  // ...
  return outputs;
}

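The details are not essential for our purpose. In effect, the tensors are grouped into buckets of roughly buffer_size bytes (utils::take_tensors), each bucket is flattened into one contiguous buffer, that buffer is broadcast to all devices (using NCCL when it is available), and it is then split back into tensors of the original shapes. Coalescing turns many small transfers into a few large ones. reduce_add_coalesced does the reverse: it flattens the per-device gradients, sums them onto the destination device and splits the result back.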
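Coalescing is easier to see without the CUDA details. The sketch below is pure Python with hypothetical helper names: it counts elements instead of bytes and does not group tensors by type. It buckets the tensors, flattens each bucket into one buffer, transfers that buffer once per target device and splits it back into the original shapes. As in the C++ code, devices[0] simply keeps the original tensors.

```python
def bucketize(tensors, max_elems):
    """Group tensors (lists of floats) into buckets of about max_elems elements."""
    buckets, current, count = [], [], 0
    for t in tensors:
        if current and count + len(t) > max_elems:
            buckets.append(current)
            current, count = [], 0
        current.append(t)
        count += len(t)
    if current:
        buckets.append(current)
    return buckets


def flatten(bucket):
    return [x for t in bucket for x in t], [len(t) for t in bucket]


def unflatten(flat, lengths):
    out, start = [], 0
    for n in lengths:
        out.append(flat[start:start + n])
        start += n
    return out


def broadcast_coalesced(tensors, devices, max_elems):
    outputs = [list(tensors)] + [[] for _ in devices[1:]]  # devices[0] keeps originals
    for bucket in bucketize(tensors, max_elems):
        flat, lengths = flatten(bucket)          # one contiguous buffer per bucket
        for j in range(1, len(devices)):         # one transfer per bucket and device
            outputs[j].extend(unflatten(list(flat), lengths))
    return outputs


params = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
print(len(bucketize(params, 3)))  # 2 buckets instead of 3 separate tensors
```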

Step 4: Forward pass

This step is straightforward, but pay attention to the structure of the outputs, because the next step gathers them.

def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    # ...

    def _worker(i, module, input, kwargs, device=None):
        # ...
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    # ...
    return outputs

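results collects the output of every device; if a worker raised, the exception is stored instead and re-raised in the main thread once all threads have joined. In most cases, each output is a scalar loss tensor.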

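Python’s threading module runs the replicas on the different GPUs concurrently, one thread per replica. Because of the GIL, the threads only make progress in parallel while PyTorch is inside C++ code that releases the GIL, so this is a bit more expensive than I thought, due to the threading overhead.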
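The thread-and-lock pattern of parallel_apply, including the exception handling, can be reproduced with the standard library alone. In this sketch, parallel_apply_sketch is a hypothetical name, plain callables stand in for the replicas, and there is no CUDA device context.

```python
import threading


def parallel_apply_sketch(fns, inputs):
    """Run fns[i](inputs[i]) on one thread each, like parallel_apply's _worker."""
    lock = threading.Lock()
    results = {}

    def _worker(i, fn, inp):
        try:
            output = fn(inp)
            with lock:
                results[i] = output
        except Exception as e:          # store the error instead of losing it
            with lock:
                results[i] = e

    threads = [threading.Thread(target=_worker, args=(i, fn, inp))
               for i, (fn, inp) in enumerate(zip(fns, inputs))]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output                # re-raise in the calling thread
        outputs.append(output)
    return outputs


print(parallel_apply_sketch([sum, sum], [[1, 2], [3, 4]]))  # [3, 7]
```

Storing exceptions in results matters because an exception raised inside a threading.Thread target would otherwise not propagate to the caller.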

Step 5: Gather the outputs

gather has a similar design to scatter. Its core is Gather.apply(target_device, dim, *outputs), whose forward in turn calls comm.gather(inputs, ctx.dim, ctx.target_device):

at::Tensor gather(
    at::TensorList tensors,
    int64_t dim,
    c10::optional<int32_t> destination_index) {
  // ...
  AT_CHECK(!tensors.empty(), "Expected at least one tensor to gather from");
  at::Tensor result;
  int64_t total_size = 0;
  auto& first = tensors.front();
  const auto first_size = first.sizes();
  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
  for (const auto& tensor : tensors) {
    // ...
    expected_size[dim] = tensor.size(dim);
    // ...
    total_size += tensor.size(dim);
  }
  expected_size[dim] = total_size;
  at::Device device(at::DeviceType::CPU);
  if (!destination_index || *destination_index != -1) {
    device = at::Device(at::DeviceType::CUDA, destination_index ? *destination_index : -1);
  }
  result = at::empty(expected_size, first.options().device(device));

  int64_t chunk_start = 0;
  for (const auto& tensor : tensors) {
    result.narrow(dim, chunk_start, tensor.size(dim))
        .copy_(tensor, /*non_blocking=*/true);
    chunk_start += tensor.size(dim);
  }
  return result;
}

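This composes the final tensor from the per-device outputs: it allocates a result whose size along dim is the sum of the inputs’ sizes, then copies each output into its narrowed slice with non-blocking copies. destination_index is the index of the output device, and -1 means the CPU. It comes from DataParallel’s output_device parameter, which defaults to device_ids[0], so GPU:0 by default (the functional data_parallel accepts output_device as well).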
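The same allocate-then-fill pattern in pure Python: gather_rows (a hypothetical name) preallocates a list sized as the sum of the chunk lengths, like at::empty(expected_size, ...), and copies each chunk into its slice, like result.narrow(dim, start, n).copy_(chunk). Only dim 0 is modeled, and the size checks on the other dimensions are omitted.

```python
def gather_rows(chunks):
    """Concatenate per-device outputs along dim 0 into one preallocated result."""
    total = sum(len(c) for c in chunks)        # expected_size[dim] = total_size
    result = [None] * total                    # like at::empty(expected_size, ...)
    start = 0
    for c in chunks:
        result[start:start + len(c)] = c       # like narrow(dim, start, n).copy_(c)
        start += len(c)
    return result


print(gather_rows([[1, 2, 3], [4, 5, 6], [7]]))  # [1, 2, 3, 4, 5, 6, 7]
```

Preallocating once and copying into slices avoids building intermediate concatenations, which mirrors the single at::empty allocation in the C++ code.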

Step 6: Loss value

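This is the standard training code you would write, assuming the wrapped model returns its loss:

loss = model(x)        # one loss per GPU, gathered onto output_device
loss = loss.sum()      # reduce the per-GPU losses to a scalar
optimizer.zero_grad()
loss.backward()        # backprop through Gather, each replica's graph and Broadcast
optimizer.step()       # update the original parameters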

For the next minibatch, the same process (scatter, replicate, parallel_apply and gather) runs again, including copying the model to every GPU.

Discussion

PyTorch focuses on abstraction and zero-copy. It is more intuitive than TensorFlow 1.x, and it is indeed an engineering success. Despite its simple interfaces, we need to understand its Tensor abstraction and those convenient helpers. Since CUDA uses a different programming model, the work to hide CUDA details is enormous. PyTorch does this by mixing C++ and Python, through both the CPython C API (as in THPFunction above) and pybind11. And it strives to keep a consistent vocabulary across a large hierarchy that aligns Tensor, Function and Module. That’s remarkable.

Still, programming in mixed languages is difficult. Type checking is always complicated, and the GIL must always be kept in mind. With CUDA, synchronization, memory management and communication widen the scope of the engineering work even further.
