FSDP2 Small Tricks
Here are some small tricks and lessons I learnt while reading the FSDP2 code.
Free A Tensor Without Destroying Metadata
In FSDP2, there are two ways to free a Tensor.
- Free a tensor with `tensor = None`. For example, this is how the shared all-gather buffer is freed.
- Free a tensor with `free_storage` and recover the tensor with `alloc_storage`:
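```python
import torch

def free_storage(tensor: torch.Tensor) -> None:
    # Shrink the underlying storage to 0 bytes; shape/stride/dtype are untouched.
    if (storage := tensor.untyped_storage()).size() != 0:
        storage.resize_(0)

def alloc_storage(tensor: torch.Tensor) -> None:
    # This only reads the metadata (numel, itemsize), never the data,
    # so it is safe to call while the storage is freed.
    size = tensor.numel() * tensor.itemsize
    if (storage := tensor.untyped_storage()).size() != size:
        storage.resize_(size)
```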
The second way frees the underlying storage but keeps the tensor's metadata (shape, stride, dtype, storage offset) intact.
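If the tensors are on GPU, both methods hand the memory block back to PyTorch's CUDA caching allocator rather than to the driver, so a later allocation of a similar size on the same stream can reuse the block and avoid a `cudaFree`/`cudaMalloc` round trip. (With `tensor = None`, this only happens once no other tensor still references the same storage.)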
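To see the caching-allocator effect directly, here is a sketch that frees and re-allocates a CUDA tensor's storage and records `torch.cuda.memory_allocated()` and `torch.cuda.memory_reserved()` along the way. It assumes the default caching allocator (caching not disabled, e.g. via `PYTORCH_NO_CUDA_MEMORY_CACHING`) and nothing else allocating in between; `caching_allocator_demo` is a hypothetical name, and it returns `None` on machines without a GPU.

```python
import torch


def free_storage(tensor: torch.Tensor) -> None:
    if (storage := tensor.untyped_storage()).size() != 0:
        storage.resize_(0)


def alloc_storage(tensor: torch.Tensor) -> None:
    size = tensor.numel() * tensor.itemsize
    if (storage := tensor.untyped_storage()).size() != size:
        storage.resize_(size)


def caching_allocator_demo(numel: int = 1 << 20):
    """Return allocator stats around a free/alloc cycle, or None without CUDA."""
    if not torch.cuda.is_available():
        return None
    x = torch.empty(numel, device="cuda")  # 4 MiB of float32 by default
    stats = {
        "allocated_before": torch.cuda.memory_allocated(),
        "reserved_before": torch.cuda.memory_reserved(),
    }
    free_storage(x)  # block goes back to the allocator's cache, not to the driver
    stats["allocated_after_free"] = torch.cuda.memory_allocated()
    stats["reserved_after_free"] = torch.cuda.memory_reserved()
    alloc_storage(x)  # same size, same (current) stream: served from the cache
    stats["allocated_after_realloc"] = torch.cuda.memory_allocated()
    stats["reserved_after_realloc"] = torch.cuda.memory_reserved()
    return stats


print(caching_allocator_demo())
```

If the block is reused, allocated memory drops after the free and comes back after the re-allocation, while reserved memory stays flat, i.e. nothing was returned to the driver.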
Here is a minimal CPU script that reproduces both behaviors:
```python
import torch

# Simulate all_gather_outputs[0]: a flat 1D tensor with 12 elements
# (e.g., a [3, 4] param sharded across world_size=3, 4 elements per rank)
all_gather_output = torch.arange(12, dtype=torch.float32)
print("all_gather_output:", all_gather_output)
# all_gather_output: tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])

# Simulate init_unsharded_param - 1st iteration
orig_size = torch.Size([3, 4])  # original param shape before sharding
contiguous_stride = (4, 1)  # stride for a contiguous [3, 4] tensor
unsharded_param = torch.as_strided(
    all_gather_output, orig_size, contiguous_stride, storage_offset=0
)
param = torch.nn.Parameter(unsharded_param)
print("param:\n", param)
print(
    "same storage?",
    param.untyped_storage().data_ptr()
    == all_gather_output.untyped_storage().data_ptr(),
)
# True - param is a view of all_gather_output

# =============================
# == Free with free_storage ===
# =============================
all_gather_output.untyped_storage().resize_(0)
print("storage size after free:", all_gather_output.untyped_storage().size())  # 0
# param still exists as a Python object, but no data may be read now:
# print(all_gather_output)  # reads freed memory -> segmentation fault
# Re-allocate. Untyped storage sizes are in BYTES (12 floats * 4 bytes = 48),
# and the new memory is uninitialized until it is refilled (e.g., by the next all-gather).
all_gather_output.untyped_storage().resize_(
    all_gather_output.numel() * all_gather_output.itemsize
)

# =============================
# ===== Free with "=None" =====
# =============================
all_gather_output = None
print(param)  # still prints: param holds a reference, so the storage is not actually freed
all_gather_output = torch.randn(12)
print(param)  # does not reflect the new tensor - data inconsistency
```
This script shows the advantage of freeing with `free_storage`. If a tensor `x` has many views into it, for instance tensors created with `x.narrow(...)` or `torch.as_strided(x, ...)`, all of them share `x`'s storage object. `free_storage` and `alloc_storage` resize that shared storage in place, so once `x` is re-inflated with `alloc_storage` and refilled, every view sees the new data. With `x = None` followed by `x = new_tensor`, by contrast, `x` is rebound to a brand-new storage: the old views still point at the old storage, which stays alive, so they neither see the new data nor release the old memory.
The catch is that the code maintainer must make sure nothing reads the data between `free_storage` and `alloc_storage`; otherwise the process can crash with a segmentation fault. This is a bit like an `unsafe` block in Rust: correctness is on you, not on the library.
Why does the segfault happen? After `resize_(0)`, the storage has 0 bytes, but the tensor's metadata (shape, stride) still describes a [3, 4] tensor with 12 elements. When you call `print()`, Python calls into PyTorch's C++ tensor code, which tries to read those 12 elements from storage that no longer exists. It reads invalid memory and the process segfaults.
Between the free and the alloc, never access the underlying data; only access the metadata.