From 909d053f72bbc20fade6e5ab64a98bb229c9dc39 Mon Sep 17 00:00:00 2001 From: Geet Sethi Date: Thu, 8 Aug 2024 19:30:40 -0700 Subject: [PATCH] distributed init updates (addition of non-meta-tensor/rank0-broadcast path) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: add non-meta tensor initialization path to distributed init. in this path, the model is materialized by all processes on the cpu (duplicated—resulting in proportionally increased cpu memory requirements) instead of just by rank 0. afterwards, all modules/buffers/params are iterated over and sharded as previously; however, instead of rank0 first broadcasting its tensor data to all other meta-tensor ranks, all ranks now simply copy their own cpu-init'd data onto their corresponding gpu and continue. Reviewed By: xiecong Differential Revision: D57022318 --- d2go/modeling/api.py | 19 ++++--- d2go/trainer/fsdp.py | 118 ++++++++++++++++++++++++++++--------------- 2 files changed, 89 insertions(+), 48 deletions(-) diff --git a/d2go/modeling/api.py b/d2go/modeling/api.py index 1d7e1c03..8124fe5e 100644 --- a/d2go/modeling/api.py +++ b/d2go/modeling/api.py @@ -63,7 +63,7 @@ def build_d2go_model( and "FSDPModelingHook" in cfg.MODEL.MODELING_HOOKS and hasattr(cfg, "FSDP") and hasattr(cfg.FSDP, "DISTRIBUTED_INIT") - and cfg.FSDP.DISTRIBUTED_INIT + and cfg.FSDP.DISTRIBUTED_INIT.ENABLED ): logger.info("Using distributed initialization path.") import torch.distributed as dist @@ -72,13 +72,18 @@ def build_d2go_model( from d2go.trainer.fsdp import CpuOverrideMode from torch._subclasses import FakeTensorMode - # NOTE (global) rank 0 will build the whole model on cpu - # other ranks will build the model on fake tensors - if dist.get_rank() == 0: - with CpuOverrideMode(): - model = build_meta_arch(cfg) + if cfg.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST: + # NOTE (global) rank 0 will build the whole model on cpu + # other ranks will build the model on fake tensors + if dist.get_rank() == 0: + with CpuOverrideMode(): + model = build_meta_arch(cfg) + else: + with FakeTensorMode(allow_non_fake_inputs=True): + model = build_meta_arch(cfg) else: - with FakeTensorMode(allow_non_fake_inputs=True): + # all ranks will build the model on cpu first + with CpuOverrideMode(): model = build_meta_arch(cfg) else: raise RuntimeError( diff --git a/d2go/trainer/fsdp.py b/d2go/trainer/fsdp.py index d3a46e8e..e311b0b1 100644 --- a/d2go/trainer/fsdp.py +++ b/d2go/trainer/fsdp.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional, Set, Tuple import torch +import torch.distributed as dist import torch.nn as nn from d2go.config import CfgNode as CN from d2go.modeling.modeling_hook import ModelingHook @@ -71,7 +72,12 @@ def add_fsdp_configs(_C: CN): # if False, this allows the CPU thread to schedule all-gathers without any extra synchronization _C.FSDP.LIMIT_ALL_GATHERS = False # flag for distributed FSDP model initialization - _C.FSDP.DISTRIBUTED_INIT = False + _C.FSDP.DISTRIBUTED_INIT = CN() + _C.FSDP.DISTRIBUTED_INIT.ENABLED = False + # whether to build full model on rank 0 cpu only + # and meta model on all other ranks + _C.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST = False + _C.FSDP.DISTRIBUTED_INIT.VERBOSE = False class ShardingAlgorithm(str, Enum): @@ -111,9 +117,9 @@ def get_grad_scaler(cfg): return ShardedGradScaler() if is_fsdp_enabled(cfg) else GradScaler() -def bottom_up_nested_fsdp(root_module, fsdp_kwargs: Dict[str, Any]): - import torch.distributed as dist - +def bottom_up_nested_fsdp( + root_module, fsdp_kwargs: Dict[str, Any], rank0_broadcast: bool, verbose: bool +): modules_to_fsdp: Tuple = tuple(fsdp_kwargs["auto_wrap_policy"]._module_classes) del fsdp_kwargs["auto_wrap_policy"] modules_not_to_fsdp: List = fsdp_kwargs["ignored_modules"] @@ -129,6 +135,8 @@ def postorder_fsdp_wrap( fqn: str, parent_module: Optional[nn.Module], ignore_branch: bool, + rank0_broadcast: bool, + verbose: bool, ): rank = dist.get_rank() @@ -144,58 +152,71 @@ def postorder_fsdp_wrap( f"{fqn}.{child_name}", module, ignore_branch, + rank0_broadcast, + verbose, ) - logger.info( - f"(Distributed FSDP init) Rank {rank} Beginning processing module: {fqn}" - ) + if verbose: + logger.info( + f"(Distributed FSDP init) Rank {rank} Beginning processing module: {fqn}" + ) # regardless of wrapping, we need to transfer all # module params and buffers to device, and if not rank 0, # need to retreive data from rank 0 with torch.no_grad(): - if rank != 0: - with no_dispatch(): - for name, param in module.named_parameters(recurse=False): - setattr( - module, - name, - torch.nn.Parameter( - torch.empty_like(param, device=cuda_device), - requires_grad=param.requires_grad, - ), - ) - for name, buffer in module.named_buffers(recurse=False): - setattr( - module, - name, - torch.empty_like( - buffer, - device=cuda_device, - requires_grad=buffer.requires_grad, - ), - ) + if rank0_broadcast: + if rank != 0: + with no_dispatch(): + for name, param in module.named_parameters(recurse=False): + setattr( + module, + name, + torch.nn.Parameter( + torch.empty_like(param, device=cuda_device), + requires_grad=param.requires_grad, + ), + ) + for name, buffer in module.named_buffers(recurse=False): + setattr( + module, + name, + torch.empty_like( + buffer, + device=cuda_device, + requires_grad=buffer.requires_grad, + ), + ) + else: + for _, param in module.named_parameters(recurse=False): + param.data = param.to(cuda_device) + for _, buffer in module.named_buffers(recurse=False): + buffer.data = buffer.to(cuda_device) + for _, param in module.named_parameters(recurse=False): + dist.broadcast(param, 0) + for _, buffer in module.named_buffers(recurse=False): + dist.broadcast(buffer, 0) else: for _, param in module.named_parameters(recurse=False): param.data = param.to(cuda_device) for _, buffer in module.named_buffers(recurse=False): buffer.data = buffer.to(cuda_device) - for _, param in module.named_parameters(recurse=False): - dist.broadcast(param, 0) - for _, buffer in module.named_buffers(recurse=False): - dist.broadcast(buffer, 0) # if module is marked for FSDP, wrap it # AND if not in ignored branch if not ignore_branch and isinstance(module, modules_to_fsdp): - logger.info( - f"(Distributed FSDP init) Rank {rank} FSDP Wrapping module: {fqn}" - ) + if verbose: + logger.info( + f"(Distributed FSDP init) Rank {rank} FSDP Wrapping module: {fqn}" + ) setattr(parent_module, module_name, FSDP(module, **fsdp_kwargs)) - logger.info( - f"(Distributed FSDP init) Rank {rank} Finished processing module: {fqn}" - ) + if verbose: + logger.info( + f"(Distributed FSDP init) Rank {rank} Finished processing module: {fqn}" + ) - postorder_fsdp_wrap(root_module, "root", "root", None, False) + postorder_fsdp_wrap( + root_module, "root", "root", None, False, rank0_broadcast, verbose + ) class FSDPWrapper(FSDP): @@ -208,6 +229,8 @@ def __init__( state_dict_cpu_offload: bool = True, state_dict_rank0_only: bool = True, distributed_init: bool = False, + distributed_init_rank0_broadcast: bool = False, + distributed_init_verbose: bool = False, **fsdp_kwargs, ): self.precision = amp_autocast_dtype @@ -220,7 +243,14 @@ def __init__( if self.distributed_init: # NOTE traverse and apply all non-root level FSDP # and then wrap root level FSDP - bottom_up_nested_fsdp(model, fsdp_kwargs) + logger.info(f"(Distributed FSDP init) Rank {dist.get_rank()} Beginning") + bottom_up_nested_fsdp( + model, + fsdp_kwargs, + rank0_broadcast=distributed_init_rank0_broadcast, + verbose=distributed_init_verbose, + ) + logger.info(f"(Distributed FSDP init) Rank {dist.get_rank()} Finished") super().__init__(model, **fsdp_kwargs) logger.info(f"FSDP Wrapped model architecture: {self}") @@ -289,6 +319,8 @@ def build_fsdp( device_id: Optional[int] = None, limit_all_gathers: bool = False, distributed_init: bool = False, + distributed_init_rank0_broadcast: bool = False, + distributed_init_verbose: bool = False, ): if sharding_algorithm == ShardingAlgorithm.SHARD_GRAD_OP: sharding_strategy = ShardingStrategy.SHARD_GRAD_OP @@ -365,6 +397,8 @@ def build_fsdp( "state_dict_cpu_offload": state_dict_cpu_offload, "state_dict_rank0_only": state_dict_rank0_only, "distributed_init": distributed_init, + "distributed_init_rank0_broadcast": distributed_init_rank0_broadcast, + "distributed_init_verbose": distributed_init_verbose, } return FSDPWrapper(model, **wrapper_kwargs, **fsdp_kwargs) @@ -422,7 +456,9 @@ def apply(self, model: nn.Module) -> FSDPWrapper: use_orig_params=self.cfg.FSDP.USE_ORIG_PARAMS, device_id=torch.cuda.current_device(), limit_all_gathers=self.cfg.FSDP.LIMIT_ALL_GATHERS, - distributed_init=self.cfg.FSDP.DISTRIBUTED_INIT, + distributed_init=self.cfg.FSDP.DISTRIBUTED_INIT.ENABLED, + distributed_init_rank0_broadcast=self.cfg.FSDP.DISTRIBUTED_INIT.RANK0_BROADCAST, + distributed_init_verbose=self.cfg.FSDP.DISTRIBUTED_INIT.VERBOSE, ) return wrapped_model