Understand FSDP2

After reading this blog, you will understand how FSDP2 orchestrates communication and computation across multiple CUDA streams. Many blogs, as well as the official RFC, mention that FSDP2 avoids record_stream, but none of them dig into the code to show how FSDP2 achieves this. We will, and if you are new to CUDA streams, you will also learn a lot about the CUDA stream programming model along the way.

In this blog, we use dist as an alias for torch.distributed.

This blog analyzes the FSDP2 implementation in PyTorch at commit 7e10855bb07cec6f4e40c982c9e0ffc7d90c8525 (commit date: 2026.03.09).

All-Gather Before Forward and Re-shard After Forward

Conceptually understand what should be done

How does FSDP work? It is a data-parallel strategy: each DP rank processes its own micro-batch of data. Instead of replicating the full model on every DP rank (as in DDP) and all-reducing the gradients, FSDP shards optimizer states, gradients and parameters across DP ranks. This costs more communication than DDP, because parameters must be gathered before they are used: FSDP trades extra communication for lower GPU memory. See the last section of this blog for its relationship with the ZeRO series of works.

FSDP overlaps computation and communication. For instance, the forward pass of a layer needs the full parameters of that layer, but FSDP shards them, so it all-gathers a layer's parameters before that layer's forward computation. In the spirit of pipelining, we can overlap the all-gather of layer i+1 with the forward computation of layer i. See the figure in the PyTorch FSDP tutorial and the first section of this blog for more details.

For each layer, the FSDP2 forward pass does three things: all-gather (gather the layer's parameters to prepare for forward computation), forward computation, and reshard (shard the gathered parameters across DP ranks again after forward computation).

Copy-in

Motivation

The all-gather step calls dist.all_gather_into_tensor, which writes the gathered result into a single pre-allocated output tensor (unlike dist.all_gather, which writes into a list of tensors, one per rank).

In FSDP2, both the input tensor and the output tensor are 1D flat tensors.

INPUT (each GPU holds one shard of the same shape):

┌─────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐
│  0..99  │  │ 100..199 │  │ 200..299 │  │ 300..399 │
└─────────┘  └──────────┘  └──────────┘  └──────────┘
   GPU 0        GPU 1         GPU 2         GPU 3

OUTPUT (identical on GPU 0-3):

┌─────────┬──────────┬──────────┬──────────┐
│  0..99  │ 100..199 │ 200..299 │ 300..399 │
└─────────┴──────────┴──────────┴──────────┘
              on each of GPU 0-3

Therefore, before the all-gather step, each DP rank must allocate an output tensor large enough to hold the whole all-gather result and place its own shard at the right position in it. This is what copy-in does.

Implementation

In implicit prefetch, copy-in runs on a dedicated CUDA stream. This design lets the all-gather of layer i overlap with the copy-in of layer i+1, which is essentially pipelining.

Let’s look at what the all_gather_copy_in_stream CUDA stream does. On each DP rank, it first allocates an empty output tensor to store the all-gather result (numel = all_gather_input_numel * world_size). Then it calls torch.ops.fsdp.all_gather_copy_in, which we look at next.

What all_gather_copy_in_stream does. Code: foreach_all_gather() - _fsdp_collectives.py
```python
# Phase 1: Copy-in (runs on all_gather_copy_in_stream)
with device_handle.stream(all_gather_copy_in_stream):
    param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
    # ... dtype/metadata handling (builds the flat list all_gather_inputs) ...
    inp_split_sizes = [t.numel() for t in all_gather_inputs]
    all_gather_input_numel = sum(inp_split_sizes)
    # Allocate the full AG output buffer (world_size * input_numel)
    all_gather_output = all_gather_comm.allocate(
        (all_gather_input_numel * world_size,), dtype=dtype, device=device
    )
    # Copy shards into the buffer; input is a *view into* output
    all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(
        all_gather_inputs, all_gather_output,
        inp_split_sizes, all_gather_input_numel, rank,
    )
    del param_all_gather_inputs
```

Note that the local all_gather_input tensor is just a view into the local all_gather_output tensor. As the torch.Tensor.narrow documentation says, the returned tensor and the input tensor share the same underlying storage. Therefore, copying the local shards into all_gather_input writes them directly into this rank’s slot of all_gather_output, and no separate input buffer is needed.

Here’s a visualization. With world_size=4, this is the outcome of the copy-in step:

Rank 0's all_gather_output: [p0_data | garbage | garbage | garbage]
Rank 1's all_gather_output: [garbage | p1_data | garbage | garbage]
Rank 2's all_gather_output: [garbage | garbage | p2_data | garbage]
Rank 3's all_gather_output: [garbage | garbage | garbage | p3_data]

Once copy-in has located this rank’s slot in the local output tensor, PyTorch copies the local shards into it. This is needed because the parameters of a layer have diverse shapes, and they must be flattened into a single 1D tensor (all_gather_input) so that one communication op can gather all of them.
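To make the view semantics concrete without a GPU, here is a minimal pure-Python sketch of copy-in for one rank. It is an analogy, not FSDP2 code: a memoryview over an array.array stands in for a tensor view (slicing it shares memory, like Tensor.narrow and torch.split), and all_gather_copy_in_sim and FILL are hypothetical names chosen for illustration.

```python
from array import array

FILL = -1.0  # placeholder for slots that other ranks will fill ("garbage")


def all_gather_copy_in_sim(inputs, world_size, rank):
    """Toy stand-in for torch.ops.fsdp.all_gather_copy_in (not the real op).

    inputs: this rank's flattened per-parameter shards (lists of floats).
    Returns (all_gather_input, all_gather_output); the first is a view
    into the second, as with Tensor.narrow.
    """
    inp_split_sizes = [len(t) for t in inputs]
    input_numel = sum(inp_split_sizes)
    # One flat buffer for the whole all-gather result.
    all_gather_output = array("d", [FILL]) * (input_numel * world_size)
    # Like narrow(0, input_numel * rank, input_numel): a view, no copy.
    all_gather_input = memoryview(all_gather_output)[
        input_numel * rank : input_numel * (rank + 1)
    ]
    # Like torch.split + _foreach_copy_: copy each shard into its sub-slot.
    offset = 0
    for shard, size in zip(inputs, inp_split_sizes):
        all_gather_input[offset : offset + size] = array("d", shard)
        offset += size
    return all_gather_input, all_gather_output


# Rank 1 of 4 holds two parameter shards with 2 and 1 elements.
inp, out = all_gather_copy_in_sim([[1.0, 2.0], [3.0]], world_size=4, rank=1)
print(inp.tolist())  # [1.0, 2.0, 3.0]
print(out.tolist())  # [-1.0, -1.0, -1.0, 1.0, 2.0, 3.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
```

Writing through the returned view (for example inp[0] = 9.0) also changes out, which mirrors how all_gather_input aliases this rank’s slot of all_gather_output.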

What torch.ops.fsdp.all_gather_copy_in does. Code: all_gather_copy_in_cuda() - _fsdp_collectives.py
```python
# Op signature. Argument notes are kept as Python comments here rather
# than inside the schema string:
#   all_gather_inputs:      list of per-param sharded tensors
#   all_gather_output:      pre-allocated buffer (input_numel * world_size)
#   inp_split_sizes:        numel of each input tensor
#   all_gather_input_numel: sum of inp_split_sizes
#   rank:                   this rank's index in the process group
#   returns:                (all_gather_input, all_gather_output)
lib.define(
    """
    all_gather_copy_in(
        Tensor[] all_gather_inputs,
        Tensor all_gather_output,
        SymInt[] inp_split_sizes,
        SymInt all_gather_input_numel,
        SymInt rank
    ) -> (Tensor, Tensor)
    """
)


# CUDA implementation
def all_gather_copy_in_cuda(
    all_gather_inputs, all_gather_output,
    inp_split_sizes, all_gather_input_numel, rank,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Step 1: View into this rank's slot within the output buffer.
    # The buffer is divided into `world_size` equal chunks; rank `r`
    # writes into the chunk at offset `r * input_numel`.
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel
    )
    # Step 2: Split this rank's slot into per-parameter-sized pieces
    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)

    # Step 3: Bulk-copy all sharded params into their slots.
    # foreach_copy_ is more efficient than individual copies:
    # it batches multiple small copies into fewer kernel launches.
    with torch.no_grad():
        torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)

    # Returns both:
    # - all_gather_input: this rank's slice (view into output), used as
    #   the `input_tensor` arg to dist.all_gather_into_tensor()
    # - all_gather_output: the full buffer, used as the `output_tensor`
    return all_gather_input, all_gather_output
```

All-Gather

Then, dist.all_gather_into_tensor is called (all_gather_comm is just a thin wrapper around it). This is the all-gather result:

Rank 0's all_gather_output: [p0_data | p1_data | p2_data | p3_data]
Rank 1's all_gather_output: [p0_data | p1_data | p2_data | p3_data]
Rank 2's all_gather_output: [p0_data | p1_data | p2_data | p3_data]
Rank 3's all_gather_output: [p0_data | p1_data | p2_data | p3_data]

What all_gather_stream does. Code: foreach_all_gather() - _fsdp_collectives.py
```python
with device_handle.stream(all_gather_stream):
    all_gather_work = all_gather_comm(
        output_tensor=all_gather_output,
        input_tensor=all_gather_input,
        group=group,
        async_op=async_op,
    )
    # Record event right after issuing the collective
    all_gather_event = all_gather_stream.record_event()
    return AllGatherResult(
        all_gather_output, all_gather_event, all_gather_work,
        param_all_gather_input_dtypes, param_all_gather_input_numels,
        inp_split_sizes,
    )
```

Copy-out

Remember two facts about the all_gather_output tensor:

  1. it is a 1-D flattened tensor
  2. it is owned by the all_gather_copy_in_stream (it is allocated in the context of copy-in stream), but the forward computation happens on the default stream

This is where record_stream could be used, and where FSDP2 avoids it (see here for what it means for a tensor to be owned by a stream). The constraint is that the computation on the default stream must run after all_gather_stream finishes the all-gather, and the memory of all_gather_output must not be reused before that computation is done. One way to manage a tensor’s lifetime across streams is record_stream, which tells the CUDA caching allocator that another stream also uses the tensor.

The problem with record_stream is that it can cause non-deterministic memory spikes from the user’s point of view, which makes memory management harder (see here for why). In short, when a tensor marked with record_stream is deleted, its memory is not returned for reuse until the pending work on the recorded stream completes, and we do not know when that will happen. Meanwhile, we keep allocating new tensors for the following layers, so memory can spike unpredictably: we think we have freed something, but its memory is not reusable yet.

Fact 1 indicates that we need to convert it into actual parameter shape for forward computation. Fact 2 indicates that we need to allocate new tensors on the default stream and copy the all_gather_output tensor to these new tensors.

PyTorch does both in copy-out. As you might guess, copy-out runs on the default stream (torch.cuda.current_stream()). It does two things: (1) it allocates new tensors on the default stream, so they are owned by the stream that uses them in forward and no record_stream is needed, and (2) it fills them by copying from all_gather_output.

+--------------------------------------------------+
|                all_gather_output                 |  <- freed after copy-out; owned by
+--------------------------------------------------+     all_gather_copy_in_stream
        |                 |                 |
        v                 v                 v
          split_with_sizes_copy (actual memcpy)
        |                 |                 |
        v                 v                 v
   +----------+      +----------+      +----------+
   |  out[0]  |      |  out[1]  |      |  out[2]  |  <- allocated on the default stream;
   | param_0  |      | param_1  |      | param_2  |     owned by it and used in forward
   +----------+      +----------+      +----------+
Copy-out code. Code: foreach_all_gather_copy_out() - _fsdp_collectives.py
```python
@torch.no_grad()
def foreach_all_gather_copy_out(
    all_gather_result: AllGatherResult,
    fsdp_params: list[FSDPParam],
    group: dist.ProcessGroup,
) -> None:
    (
        all_gather_output,
        all_gather_event,  # recorded on all_gather_stream after the collective
        all_gather_work,
        param_all_gather_input_dtypes,
        param_all_gather_input_numels,
        all_gather_input_split_sizes,
    ) = all_gather_result
    _dtype, device = all_gather_output.dtype, all_gather_output.device
    device_handle = _get_device_handle(device.type)

    # KEY SYNC: default stream waits for all-gather to complete
    if all_gather_event is not None:  # sync op (normal path)
        device_handle.current_stream().wait_event(all_gather_event)
    if isinstance(all_gather_work, dist.distributed_c10d.Work):  # async op
        all_gather_work.wait()

    world_size, device = group.size(), all_gather_output.device

    # Allocate per-parameter output tensors (on the default stream!)
    split_with_sizes_out: list[torch.Tensor] = []
    for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
        param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
    ):
        fsdp_param.init_all_gather_outputs(
            all_gather_input_numels, all_gather_input_dtypes, world_size, device,
        )
        fsdp_param.alloc_all_gather_outputs()
        # ... handle non-dim-0 sharding ...
        split_with_sizes_out.extend(fsdp_param.all_gather_outputs)

    # Copy from AG buffer → per-parameter tensors
    all_gather_output = all_gather_output.view(world_size, -1)
    out = [t.view(world_size, -1) for t in split_with_sizes_out]
    # ... split_with_sizes_copy ...
```
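The elided split_with_sizes_copy step can be mimicked in pure Python to show why both sides are viewed as (world_size, -1): each row is one rank’s slot, and the same column block of every row is one parameter’s shard on that rank. copy_out_sim is a hypothetical helper, not FSDP2 code, and it assumes dim-0 sharding without padding.

```python
def copy_out_sim(all_gather_output, world_size, inp_split_sizes):
    """Toy model of the split_with_sizes_copy step in copy-out.

    all_gather_output: flat list of length world_size * sum(inp_split_sizes),
    laid out as [rank 0 slot | rank 1 slot | ...], where each slot holds
    that rank's shards of every parameter back to back.
    Returns one flat, unsharded list per parameter.
    """
    input_numel = sum(inp_split_sizes)
    # Equivalent of all_gather_output.view(world_size, -1): one row per rank.
    rows = [
        all_gather_output[r * input_numel : (r + 1) * input_numel]
        for r in range(world_size)
    ]
    outs = []
    offset = 0
    for size in inp_split_sizes:
        # Column block [offset, offset + size) of row r is this parameter's
        # shard on rank r; concatenating the rows in rank order rebuilds it.
        full = []
        for row in rows:
            full.extend(row[offset : offset + size])
        outs.append(full)
        offset += size
    return outs


# 2 ranks; param A has 4 elements (2 per rank), param B has 2 (1 per rank).
gathered = [0, 1, 10, 2, 3, 11]  # [rank 0: A0 A1 B0 | rank 1: A2 A3 B1]
print(copy_out_sim(gathered, world_size=2, inp_split_sizes=[2, 1]))
# [[0, 1, 2, 3], [10, 11]]
```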

Reshard after forward

In the post-forward method, the copy-out result (the per-parameter unsharded tensors stored in FSDPParam.all_gather_outputs for every FSDPParam in the parameter group) is freed. Freeing releases only the underlying storage and keeps the tensor’s metadata (see more details here). The 1D flat shared all-gather buffer all_gather_output is also freed.

Recall: the former (1) is owned by the default stream, (2) is allocated in copy-out, (3) has normal parameter shape and is used in forward pass; the latter (1) is owned by the copy-in stream, (2) is allocated in copy-in, (3) is used to store the all-gather result.

Stream dependencies

The stream synchronization mechanisms of implicit and explicit prefetch differ. See here for the difference from the user’s side. Explicit prefetch lets the user specify what to prefetch so that communication overlaps with computation.

Implicit Prefetch

The call stack of the CUDA stream synchronization mechanism in implicit prefetch is shown in the figure below.

Based on that call stack, we can visualize the GPU stream synchronization in implicit prefetch, as in the figure below. The key conclusion is that peak memory for parameter storage is bounded by 3x, and FSDPParamGroup._wait_all_gather_stream_on_event() ensures this.

One important thing to understand is that CUDA operations are dispatched asynchronously from the CPU: unless there is an explicit synchronization, the CPU does not wait for the GPU to finish them. Therefore, although on the CPU free_storage(per-param 0) in post_forward(0) executes before alloc(AG buffer 1) in unshard(1), on the GPU these can happen in the reverse order. For example (in theory), if split_with_sizes_copy (copy-out) takes a long time while alloc(AG buffer) and copy-in are very fast, the copy-in stream can run far ahead of the default stream, so alloc(AG buffer 1) happens before free_storage(per-param 0), which is how the 3x bound can be reached.
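The reordering can be reproduced with a toy timeline model in pure Python. This is a sketch under loud assumptions, not FSDP2 or CUDA code: ToyStream, ToyEvent and simulate are hypothetical, the durations are made up, and, following the text, alloc and free_storage are treated as zero-length points on a stream’s timeline (the real CUDA caching allocator is not modeled). It does not claim which event FSDP2 waits on; it only shows that CPU enqueue order does not fix GPU order across streams unless a wait is inserted.

```python
from dataclasses import dataclass


@dataclass
class ToyEvent:
    time: float  # when all work enqueued before the record has finished


class ToyStream:
    """Toy CUDA stream: ops on one stream run in FIFO order, different
    streams run concurrently, and the CPU never blocks when enqueuing."""

    def __init__(self):
        self.ready = 0.0  # finish time of the last enqueued op

    def launch(self, op, duration, log):
        log[op] = self.ready  # GPU start time of this op
        self.ready += duration

    def record_event(self):
        return ToyEvent(self.ready)

    def wait_event(self, event):
        # Ops enqueued after this call cannot start before the event fires.
        self.ready = max(self.ready, event.time)


def simulate(copy_in_waits_for_free):
    default, copy_in = ToyStream(), ToyStream()
    log = {}
    # CPU enqueue order, as in post_forward(0) followed by unshard(1).
    default.launch("copy_out(0)", 10.0, log)  # made-up: a slow copy-out
    default.launch("forward(0)", 5.0, log)
    default.launch("free_storage(per-param 0)", 0.0, log)
    freed = default.record_event()
    if copy_in_waits_for_free:
        copy_in.wait_event(freed)
    copy_in.launch("alloc(AG buffer 1)", 0.0, log)
    copy_in.launch("copy_in(1)", 1.0, log)
    return log


for waits in (False, True):
    log = simulate(waits)
    print(waits, log["alloc(AG buffer 1)"], log["free_storage(per-param 0)"])
# False 0.0 15.0   (the allocation runs on the GPU before the free)
# True 15.0 15.0
```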

Explicit Prefetch

For explicit prefetch, more layers are unsharded (copy-in + all-gather) in advance. The code lives in FSDPState._pre_forward() in _fsdp_state.py.

```python
for fsdp_param_group in self._fsdp_param_groups:
    # `pre_forward` calls `unshard` + `wait_for_unshard`,
    # where unshard = copy-in + all-gather and wait_for_unshard = copy-out
    args, kwargs = fsdp_param_group.pre_forward(module, args, kwargs)
for fsdp_state in self._states_to_forward_prefetch:
    for target_param_group in fsdp_state._fsdp_param_groups:
        # `_prefetch_unshard` calls `unshard`
        FSDPParamGroup._prefetch_unshard(target_param_group, "forward")
```

Take the example in the PyTorch docs, where the user asks FSDP to prefetch two layers ahead. The call stack is illustrated in the figure below. Take layer 0 as an example: self._fsdp_param_groups = [layer_0_param], and self._states_to_forward_prefetch of layer 0 holds the FSDPStates of layers 1 and 2, so their param groups (target_param_group) go through _prefetch_unshard (which only calls unshard). Therefore, “prefetch” in forward means doing copy-in and all-gather.
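The resulting CPU dispatch order can be sketched in pure Python. This is a toy model, not FSDP2 code: forward_dispatch_order is a hypothetical helper, and it assumes, as described above, that each layer first runs its own pre_forward (unshard + wait_for_unshard, i.e. copy-out), then prefetches the next num_to_prefetch layers, and that calling unshard on a layer that is already being unsharded is a no-op.

```python
def forward_dispatch_order(num_layers, num_to_prefetch):
    """Toy model of CPU dispatch order in forward with explicit prefetch."""
    unsharded = set()
    order = []

    def unshard(i):  # copy-in + all-gather; assumed no-op if already issued
        if i not in unsharded:
            unsharded.add(i)
            order.append(f"unshard({i})")

    for i in range(num_layers):
        unshard(i)                       # pre_forward: unshard ...
        order.append(f"copy_out({i})")   # ... then wait_for_unshard
        for j in range(i + 1, min(i + 1 + num_to_prefetch, num_layers)):
            unshard(j)                   # _prefetch_unshard
        order.append(f"forward({i})")
        order.append(f"reshard({i})")    # post_forward
    return order


print(forward_dispatch_order(4, 2)[:8])
# ['unshard(0)', 'copy_out(0)', 'unshard(1)', 'unshard(2)', 'forward(0)', 'reshard(0)', 'copy_out(1)', 'unshard(3)']
```

Because layers 1 and 2 are already in flight when layer 1 runs, only unshard(3) is newly issued there, so the all-gathers run two layers ahead of the compute.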

One more thing

Functionally, wait_stream is equivalent to recording an event on stream A at the current point and then having stream B wait on that event; in fact, that is how PyTorch implements Stream.wait_stream. The difference is purely whether you need to reuse the event later. At the copy-in to all-gather boundary, where all_gather_stream waits on all_gather_copy_in_stream, the event would be consumed immediately by one stream and never referenced again. So wait_stream is simpler: there is no need to create an event object, name it, or pass it around.
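One subtlety is that wait_stream, like an event, is a snapshot: the waiting stream is ordered only after the work that was enqueued on the other stream before the call, not after work enqueued later. A toy pure-Python model illustrates this; ToyStream and its durations are hypothetical, not CUDA or PyTorch code.

```python
class ToyStream:
    """Toy stream: FIFO ops on one stream, concurrency across streams."""

    def __init__(self):
        self.ready = 0.0  # finish time of the last enqueued op

    def launch(self, duration):
        start = self.ready
        self.ready = start + duration
        return start  # GPU start time of the op

    def record_event(self):
        return self.ready  # an event captures the stream's progress so far

    def wait_event(self, event_time):
        self.ready = max(self.ready, event_time)

    def wait_stream(self, other):
        # Same shape as PyTorch's Stream.wait_stream: record, then wait.
        self.wait_event(other.record_event())


copy_in, all_gather = ToyStream(), ToyStream()
copy_in.launch(5.0)              # copy-in of layer i (made-up duration)
all_gather.wait_stream(copy_in)  # waits only for work enqueued so far
copy_in.launch(10.0)             # copy-in of layer i+1, enqueued later
start = all_gather.launch(3.0)   # all-gather of layer i
print(start)  # 5.0: it does not wait for the later 10.0 op
```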

Backward

Before going into the system, let’s think about what should be done in FSDP backward pass from the algorithm perspective.

  1. We need the full parameters of a layer to backpropagate through it. This requires an all-gather of the weights.
  2. Then the weights should be sharded again, and each gradient should be summed across all DP ranks (because each DP rank processes a different micro-batch) and sharded, which is effectively a reduce-scatter.

For the backward pass, we only analyze HSDP with implicit prefetch. HSDP backward is strictly more complicated than FSDP backward, so after reading this section you should understand FSDP backward easily. In short, HSDP does DDP across nodes and FSDP within each node: the model is replicated across nodes and sharded only within each node. This keeps the reduce-scatter intra-node and leaves only an all-reduce across nodes, that is, an intra-node reduce-scatter plus an inter-node all-reduce. The rationale is that inter-node communication is much slower, and we don’t want it to become the bottleneck.

It follows that HSDP needs one more step on top of the two above:

  3. Sum the gradients across nodes, because different nodes process different micro-batches. This is effectively an all-reduce rather than a reduce-scatter, because each node hosts a full replica of the model (the DDP setting): the GPU holding shard k on one node all-reduces it with the GPUs holding shard k on the other nodes.

How backward hook works

Pre-backward: all-gather the full weights

Before computing the backward pass of a parameter group, we need to unshard that parameter group. The unsharding logic is similar to that of pre-forward with small differences.

  1. Each parameter group still does unshard (copy-in and all-gather) and wait_for_unshard (copy-out).
  2. Difference (1): In wait_for_unshard, the next all-gather strictly waits for the last copy-out (see the check on self._training_state == TrainingState.FORWARD in FSDPParamGroup.wait_for_unshard; backward takes the other branch). Why? In the backward pass, GPU memory is scarcer, and this reduces the peak memory for parameter storage from 3x to 2x.
  3. Difference (2): After the normal unshard and wait_for_unshard, implicit prefetch also unshards the next layer in backward order (default prefetch). For instance, layer i prefetches layer i-1. Why? I think there is no big difference with or without default prefetch, because the CPU dispatches commands to streams asynchronously. Without default prefetch, the CPU dispatches unshard(n-1) -> wait_for_unshard(n-1) -> backprop(n-1) -> post-backprop(n-1) -> unshard(n-2); with default prefetch, it dispatches unshard(n-1) -> wait_for_unshard(n-1) -> unshard(n-2) -> backprop(n-1) -> post-backprop(n-1). The only difference is that unshard(n-2) is issued earlier on the CPU. However, as long as the CPU runs ahead of the GPU (which is typical here, since no CPU-GPU synchronization is performed), this command will not be delayed much on the copy-in stream.

Post-backward: reduce-scatter the grad tensor

It is also split into three stages: copy-in, reduce-scatter communication, and copy-out.

The copy-in step copies the gradients of all parameters in the parameter group into a single 1D flat reduce-scatter buffer called reduce_scatter_input. The main goal is to build the input format that the NCCL reduce-scatter operation expects.

The reduce-scatter communication calls the NCCL op on reduce_scatter_stream. This stream allocates, and thus owns, the reduce_output tensor. Visualization:

Rank 0's input: [A0 | A1 | A2 | A3] (`reduce_scatter_input` on rank 0)
Rank 1's input: [B0 | B1 | B2 | B3]
Rank 2's input: [C0 | C1 | C2 | C3]
Rank 3's input: [D0 | D1 | D2 | D3]

After reduce_scatter (op=SUM):

Rank 0's output: [A0+B0+C0+D0] (`reduce_output` on rank 0)
Rank 1's output: [A1+B1+C1+D1]
Rank 2's output: [A2+B2+C2+D2]
Rank 3's output: [A3+B3+C3+D3]
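The semantics of reduce-scatter with op=SUM can be checked with a toy pure-Python model. reduce_scatter_sim is a hypothetical helper, not the NCCL or torch.distributed API; each chunk is a single number to keep the example small, and the values standing in for A, B, C and D are made up.

```python
def reduce_scatter_sim(inputs):
    """Toy model of reduce-scatter with op=SUM.

    inputs[r] is rank r's reduce_scatter_input split into world_size chunks
    (chunk k is destined for rank k). Returns outputs[k], the sum over all
    ranks of chunk k, which is what rank k receives as reduce_output.
    """
    world_size = len(inputs)
    assert all(len(x) == world_size for x in inputs), "one chunk per rank"
    return [sum(x[k] for x in inputs) for k in range(world_size)]


inputs = [
    [1, 2, 3, 4],              # rank 0: A0..A3
    [10, 20, 30, 40],          # rank 1: B0..B3
    [100, 200, 300, 400],      # rank 2: C0..C3
    [1000, 2000, 3000, 4000],  # rank 3: D0..D3
]
print(reduce_scatter_sim(inputs))  # [1111, 2222, 3333, 4444]
```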

Copy-out creates a view (using torch.as_strided) into the local reduce_output for each sharded parameter and assigns it to sharded_param.grad. The reduce_scatter_input tensor is freed after the copy-out is done.

If HSDP is applied, a separate all_reduce_stream performs the all-reduce across nodes after the intra-node reduce-scatter. With pure FSDP (no HSDP), this step is skipped and all_reduce_stream is not used. See the HSDP branch for details.

HSDP branch in foreach_reduce() in _fsdp_collectives.py
```python
if all_reduce_group is not None:  # HSDP or DDP/replicate
    ...
    post_reduce_stream = all_reduce_stream
    if world_size > 1:
        # Wait for the intra-node reduce-scatter issued on reduce_scatter_stream
        all_reduce_stream.wait_stream(reduce_scatter_stream)
    else:
        # No reduce-scatter ran, so wait on the current stream instead
        all_reduce_stream.wait_stream(current_stream)

    # Do the all-reduce
    with device_handle.stream(all_reduce_stream):
        dist.all_reduce(
            reduce_output,
            group=all_reduce_group,
            op=all_reduce_op,
        )
        all_reduce_input = reduce_output
        all_reduce_event = all_reduce_stream.record_event()
```
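To check that an intra-node reduce-scatter followed by an inter-node all-reduce yields the same shards as one global reduce-scatter, here is a toy pure-Python model. hsdp_grad_sync_sim is hypothetical, not FSDP2 code; it sums rather than averages, and it assumes the flat gradient length is divisible by the number of GPUs per node.

```python
def hsdp_grad_sync_sim(grads, num_nodes, gpus_per_node):
    """Toy model of HSDP gradient sync.

    grads[n][g] is the full flat gradient computed on GPU g of node n.
    Returns shards[n][g], the reduced shard that GPU ends up holding.
    """
    chunk = len(grads[0][0]) // gpus_per_node
    # Step 1: intra-node reduce-scatter (SUM) within each shard group:
    # GPU g keeps chunk g, summed over the GPUs of its own node.
    rs = [
        [
            [sum(grads[n][src][i] for src in range(gpus_per_node))
             for i in range(g * chunk, (g + 1) * chunk)]
            for g in range(gpus_per_node)
        ]
        for n in range(num_nodes)
    ]
    # Step 2: inter-node all-reduce (SUM) within each replicate group,
    # i.e. among the GPUs that hold the same chunk g on different nodes.
    return [
        [
            [sum(rs[m][g][i] for m in range(num_nodes)) for i in range(chunk)]
            for g in range(gpus_per_node)
        ]
        for _ in range(num_nodes)
    ]


# 2 nodes x 2 GPUs, flat gradients of length 4 (made-up values).
grads = [
    [[1, 2, 3, 4], [10, 20, 30, 40]],                  # node 0
    [[100, 200, 300, 400], [1000, 2000, 3000, 4000]],  # node 1
]
shards = hsdp_grad_sync_sim(grads, num_nodes=2, gpus_per_node=2)
print(shards[1][0], shards[1][1])  # [1111, 2222] [3333, 4444]
```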

Cuda stream synchronization in backward pass

A summary of the cuda streams sync mechanism

The FSDPCommContext data structure holds all the CUDA streams used by FSDP2 (except the default stream, which can be obtained through torch.cuda.current_stream()) as well as the shared buffers.

```python
class FSDPCommContext:

    def lazy_init(self, device: torch.device):
        self.device_handle = _get_device_handle(device.type)
        # Setting the all-gather/reduce-scatter streams to be higher priority
        # can help avoid some issues where their copies in/out are delayed and
        # block computation (this is different from high-pri NCCL streams)
        high_priority = -1
        # All-gather state and copy-in stream allow overlapping the next
        # copy-in with the current all-gather in forward; copy-in overlaps with
        # reduce-scatter in backward without the separate copy-in stream
        self.all_gather_copy_in_stream = self.device_handle.Stream(
            priority=high_priority
        )
        # All-gather stream allows overlapping next all-gather with current
        # forward compute
        self.all_gather_stream = self.device_handle.Stream(priority=high_priority)
        # Reduce-scatter stream gives separate execution "thread" for post-
        # backward logic like pre/post-gradient division and reduce-scatter
        self.reduce_scatter_stream = self.device_handle.Stream(priority=high_priority)
        # Run the HSDP all-reduces concurrently with all-gather/reduce-scatter
        # since collectives use different network resources and can overlap
        # in the typical intra-node sharding / inter-node replication case
        self.all_reduce_stream = self.device_handle.Stream()
        # All-gather/reduce-scatter states keep references to collective
        # tensors produced in one stream and used in another and accompanying
        # CUDA events for synchronization
        self.all_gather_state: AllGatherState | None = None
        self.reduce_scatter_state: ReduceScatterState | None = None
        self.post_forward_order: list[FSDPParamGroup] = []  # will cause ref cycles
```

I would like to emphasize a few points that give some intuition for the whole picture:

  1. How is peak memory bounded? Record an event once memory can be freed, and make the CUDA stream that is about to allocate new memory wait on that event. For example, in wait_for_unshard during the forward pass, the copy-in and all-gather streams wait on the event recorded after the previous layer’s copy-out (which runs on the default stream); only then is the previous layer’s all-gather buffer (all_gather_output) released, so a later copy-in can reuse its memory when it allocates a new all_gather_output. This also makes peak memory deterministic from the user’s side, effectively solving the record_stream issue.
  2. Use “views” into a flat shared buffer to bridge nD-shaped parameters and the 1D input of the NCCL op. Use free_storage to free the underlying data without destroying the tensor metadata.
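Point 1 can be illustrated with a tiny pure-Python peak-memory calculator. peak_live is a hypothetical helper, sizes are in made-up units of one layer’s unsharded parameters, and the event lists are assumed GPU execution orders rather than traces of FSDP2. Making the allocating stream wait on an event recorded after the free forces the first order; without such a wait, the second order becomes possible.

```python
def peak_live(events):
    """Peak live memory over alloc/free events given in GPU execution order.

    events: ("alloc", name, size) or ("free", name).
    """
    live, peak, sizes = 0, 0, {}
    for ev in events:
        if ev[0] == "alloc":
            _, name, size = ev
            sizes[name] = size
            live += size
            peak = max(peak, live)
        else:
            live -= sizes.pop(ev[1])
    return peak


in_order = [
    ("alloc", "per-param 0", 1),
    ("free", "per-param 0"),
    ("alloc", "AG buffer 1", 1),  # waited for the free
    ("free", "AG buffer 1"),
]
reordered = [
    ("alloc", "per-param 0", 1),
    ("alloc", "AG buffer 1", 1),  # copy-in ran ahead of the free
    ("free", "per-param 0"),
    ("free", "AG buffer 1"),
]
print(peak_live(in_order), peak_live(reordered))  # 1 2
```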

FSDP and ZeRO

ZeRO Overview

FSDP implements the ideas from the ZeRO (Zero Redundancy Optimizer) paper by DeepSpeed. ZeRO has three stages, each progressively sharding more training state across data-parallel ranks:

| Stage | What is sharded |
|---|---|
| ZeRO-1 | Optimizer states (e.g., Adam’s m and v) |
| ZeRO-2 | Optimizer states + gradients |
| ZeRO-3 | Optimizer states + gradients + parameters |
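The table can be turned into rough per-GPU numbers. The sketch below uses the ZeRO paper’s accounting for mixed-precision Adam (2 bytes of fp16 parameters, 2 bytes of fp16 gradients and K = 12 bytes of optimizer state per parameter) and ignores activations, buffers and fragmentation. zero_bytes_per_gpu is a hypothetical helper, stage 0 stands for plain data parallelism, and the example plugs in the paper’s 7.5B-parameter, 64-GPU illustration.

```python
def zero_bytes_per_gpu(num_params, dp_size, stage, k=12):
    """Per-GPU bytes for model states under ZeRO-style sharding."""
    p, g, o = 2 * num_params, 2 * num_params, k * num_params
    if stage == 0:  # plain DP / DDP: everything replicated
        return p + g + o
    if stage == 1:  # shard optimizer states
        return p + g + o / dp_size
    if stage == 2:  # + shard gradients
        return p + (g + o) / dp_size
    if stage == 3:  # + shard parameters (FSDP)
        return (p + g + o) / dp_size
    raise ValueError(f"unknown ZeRO stage: {stage}")


for stage in range(4):
    gb = zero_bytes_per_gpu(7.5e9, 64, stage) / 1e9
    print(f"stage {stage}: {gb:.1f} GB")
# stage 0: 120.0 GB
# stage 1: 31.4 GB
# stage 2: 16.6 GB
# stage 3: 1.9 GB
```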

FSDP = ZeRO-3

FSDP implements ZeRO Stage 3: it shards optimizer states, gradients, and parameters across the data-parallel group. This is exactly what we have been looking at throughout this blog.

ZeRO++ and HSDP

ZeRO++ (paper) extends ZeRO-3 with three optimizations to reduce communication:

  1. Quantized weights: All-gather parameters in lower precision (e.g., INT8) to reduce all-gather volume.
  2. Hierarchical partitioning: Maintain a secondary full copy of parameters within each node, so all-gather only happens intra-node (faster NVLink) instead of inter-node (slower network).
  3. Quantized gradients: Reduce-scatter gradients in lower precision.
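To see why quantized collectives cut volume, here is a toy pure-Python sketch of symmetric per-block INT8 quantization. It is a generic scheme chosen for illustration, not ZeRO++’s actual implementation; quantize_int8, dequantize_int8 and comm_bytes are hypothetical, and it assumes an fp16 baseline with one 4-byte scale per block.

```python
def quantize_int8(values, block_size):
    """Toy symmetric per-block INT8 quantization (a generic scheme).
    Returns integer codes in [-127, 127] and one scale per block."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start : start + block_size]
        # Map the block's largest magnitude to 127; all-zero blocks get 1.0.
        scale = max(abs(v) for v in block) / 127 or 1.0
        scales.append(scale)
        codes.extend(max(-127, min(127, round(v / scale))) for v in block)
    return codes, scales


def dequantize_int8(codes, scales, block_size):
    return [c * scales[i // block_size] for i, c in enumerate(codes)]


def comm_bytes(numel, block_size, scale_bytes=4):
    """Bytes each rank sends: fp16 baseline vs INT8 codes plus scales."""
    num_blocks = -(-numel // block_size)  # ceiling division
    return 2 * numel, numel + scale_bytes * num_blocks


print(comm_bytes(1024, 256))  # (2048, 1040)
```

Smaller blocks track the data more closely but spend more bytes on scales, which is the basic trade-off of block-wise quantization.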

PyTorch’s HSDP (Hybrid Sharded Data Parallel) is related to the hierarchical partitioning idea. HSDP uses a 2D mesh: parameters are sharded within a “shard group” (e.g., intra-node), and gradients are all-reduced across “replicate groups” (e.g., inter-node).