-
Notifications
You must be signed in to change notification settings - Fork 4.8k
Merging AutoSP into DeepSpeed #7860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
1e3a836
Add AutoSP to DeepSpeed
neeldani 4a194c6
Merge branch 'master' into autosp
sfc-gh-truwase 6f73ea2
Add AutoSP unit and end-to-end tests
spikerheado1234 6ba4117
AutoSP: fix torch 2.9 fake propagation issues (#2)
tohtana cd27fa1
update docs
neeldani e487441
Merge branch 'master' into autosp
tohtana 62ed944
Merge branch 'master' into autosp
PKUWZP 4ef9419
Update torch required version
neeldani 733272c
Revert check for stage 3
neeldani 5c8bf2c
Merge branch 'master' into autosp
tohtana File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| ######################################### | ||
| # AUTOSP | ||
| ######################################### | ||
| AUTOSP_INPUT_ID_KEY = "input_id" | ||
| AUTOSP_LABEL_ID_KEY = "label_id" | ||
| AUTOSP_POSITION_ID_KEY = "position_id" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| from .all_to_all import all_to_all | ||
| from . import sp_dp_registry | ||
|
|
||
| __all__ = ["all_to_all", "sp_dp_registry", "sp_compat"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| import torch | ||
| import deepspeed.comm as dist | ||
| from .sp_dp_registry import get_group, is_setup, sp_size | ||
|
|
||
|
|
||
| @torch.library.custom_op("autosp::all_to_all", mutates_args=()) | ||
| def all_to_all( | ||
| input: torch.Tensor, | ||
| scatter_idx: int, | ||
| gather_idx: int, | ||
| name: str, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| All-to-all collective for SDPA tensors [B, N, S, H]. | ||
|
|
||
| For QKV (scatter_idx=1, gather_idx=2): | ||
| [B, N, S/P, H] -> [B, N/P, S, H] | ||
| For O (scatter_idx=2, gather_idx=1): | ||
| [B, N/P, S, H] -> [B, N, S/P, H] | ||
| """ | ||
| assert is_setup(), 'Incorrect initialization of SP/DP mesh.' | ||
| B, dim1, dim2, H = input.shape | ||
| gid = dist.get_rank() // sp_size() | ||
| group = get_group(gid) | ||
|
|
||
| if scatter_idx == 1: | ||
| N, local_S = dim1, dim2 | ||
| input_t = input.reshape(B, sp_size(), N // sp_size(), local_S, H) | ||
| input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() | ||
|
|
||
| output = torch.empty_like(input_t) | ||
| dist.all_to_all_single(output, input_t, group=group) | ||
|
|
||
| output = output.permute(1, 2, 0, 3, 4).contiguous() | ||
| output = output.reshape(B, N // sp_size(), sp_size() * local_S, H) | ||
| else: | ||
| local_N, S = dim1, dim2 | ||
| input_t = input.reshape(B, local_N, sp_size(), S // sp_size(), H) | ||
| input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() | ||
|
|
||
| output = torch.empty_like(input_t) | ||
| dist.all_to_all_single(output, input_t, group=group) | ||
|
|
||
| output = output.permute(1, 0, 2, 3, 4).contiguous() | ||
| output = output.reshape(B, sp_size() * local_N, S // sp_size(), H) | ||
|
|
||
| return output | ||
|
|
||
|
|
||
| @torch.library.register_fake("autosp::all_to_all") | ||
| def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, name: str): | ||
| B, dim1, dim2, H = input.shape | ||
| if scatter_idx == 1: | ||
| return input.new_empty(B, dim1 // sp_size(), dim2 * sp_size(), H) | ||
| else: | ||
| return input.new_empty(B, dim1 * sp_size(), dim2 // sp_size(), H) | ||
|
|
||
|
|
||
| def _all_to_all_backward_setup(ctx, inputs, output): | ||
| _, scatter_idx, gather_idx, name = inputs | ||
| ctx.scatter_idx = gather_idx | ||
| ctx.gather_idx = scatter_idx | ||
| ctx.name = name + "_grad" | ||
|
|
||
|
|
||
| def _all_to_all_backward(ctx, grad): | ||
| return (all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.name), None, None, None) | ||
|
|
||
|
|
||
| torch.library.register_autograd("autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| import torch | ||
| from packaging.version import Version | ||
|
|
||
|
|
||
| def _check_autosp_compatibility(): | ||
| # Strip the local version segment (e.g. +cu128) so CUDA builds don't sort | ||
| # above the max bound when using packaging's local-version ordering rules. | ||
| torch_version = Version(torch.__version__.split("+")[0]) | ||
| if torch_version < Version("2.6") or torch_version >= Version("2.8"): | ||
| raise RuntimeError( | ||
| "AutoSP requires PyTorch >= 2.6 and <= 2.7, found " | ||
| f"{torch.__version__}." | ||
| ) | ||
|
|
||
| try: | ||
| import transformers | ||
| if Version(transformers.__version__) > Version("4.50.3"): | ||
| raise RuntimeError( | ||
| "AutoSP requires transformers <= 4.50.3, found " | ||
| f"{transformers.__version__}." | ||
| ) | ||
| except ImportError: | ||
| pass # transformers not installed; skip the check |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| import deepspeed.comm as dist | ||
|
|
||
| GROUP_REGISTRY = {} # int -> dist.ProcessGroup | ||
|
|
||
|
|
||
| def register_groups(groups): | ||
| """groups: List[List[int]], e.g. [[0,1],[2,3]]""" | ||
| for gid, ranks in enumerate(groups): | ||
| if gid not in GROUP_REGISTRY: | ||
| GROUP_REGISTRY[gid] = dist.new_group(ranks) | ||
|
|
||
|
|
||
| def get_group(gid: int): | ||
| return GROUP_REGISTRY[gid] if gid is not None else dist.get_world_group() | ||
|
|
||
|
|
||
| def get_registry(): | ||
| return GROUP_REGISTRY | ||
|
|
||
|
|
||
| def is_setup(): | ||
| return GROUP_REGISTRY['is_reg'] if 'is_reg' in GROUP_REGISTRY else False | ||
|
|
||
|
|
||
| def extract_mesh_size(param_dict): | ||
| sp_size = param_dict.get('sequence_parallel_size', 1) | ||
| assert dist.get_world_size() % sp_size == 0, 'World mesh-size should be divisible by SP_SIZE' | ||
| dp_size = dist.get_world_size() // sp_size | ||
|
|
||
| return sp_size, dp_size | ||
|
|
||
|
|
||
| def sp_size(): | ||
| assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.' | ||
|
|
||
| return GROUP_REGISTRY['SP_SIZE'] | ||
|
|
||
|
|
||
| def dp_size(): | ||
| assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly' | ||
|
|
||
| return GROUP_REGISTRY['DP_SIZE'] | ||
|
|
||
|
|
||
| def populate_registry(SP_SIZE, DP_SIZE): | ||
| """ Populate rank to SP/DP mesh index. """ | ||
|
|
||
| if GROUP_REGISTRY.get('is_reg', False): | ||
| return | ||
|
|
||
| group_listing = [] | ||
| offset = 0 | ||
| for _ in range(DP_SIZE): | ||
| group_listing.append([i + offset for i in range(SP_SIZE)]) | ||
| offset += SP_SIZE | ||
|
|
||
| register_groups(group_listing) | ||
|
|
||
| ## Extraneous metadata required for proper instatiation. ## | ||
| GROUP_REGISTRY['SP_SIZE'] = SP_SIZE | ||
| GROUP_REGISTRY['DP_SIZE'] = DP_SIZE | ||
| GROUP_REGISTRY['is_reg'] = True |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| import torch | ||
| from torch.fx import GraphModule | ||
| from .passes.sp_compile import apply_autosp | ||
| from .passes.long_context_checkpointing import register_long_context_checkpointing | ||
| from .custom_ops.sp_dp_registry import extract_mesh_size | ||
| from .custom_ops.sp_compat import _check_autosp_compatibility | ||
|
|
||
|
|
||
| def init_autosp(config): | ||
| _check_autosp_compatibility() | ||
| sp_size, dp_size = extract_mesh_size(config._param_dict) | ||
| register_long_context_checkpointing() | ||
|
|
||
| def backend_fn(gm: GraphModule, real_inputs): | ||
| apply_autosp(gm, real_inputs, debug=False, sp_size=sp_size, dp_size=dp_size) | ||
| return torch._inductor.compile(gm, real_inputs) | ||
|
|
||
| return backend_fn |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| import inspect | ||
| import textwrap | ||
| import torch._functorch.partitioners as _partitioners | ||
|
|
||
| # The custom should_ban_recomputation to splice into solve_min_cut. | ||
| # All names it references (aten, operator, config, op_types, min_cut_options, | ||
| # is_materialized_backwards, get_aten_target, _size_of, fx, torch, | ||
| # CheckpointPolicy) are either module-level in torch._functorch.partitioners | ||
| # or local variables already in scope when this function executes inside | ||
| # solve_min_cut. | ||
| _CUSTOM_SHOULD_BAN = """\ | ||
| def should_ban_recomputation(node): | ||
| \"\"\"Sequence-aware recomputation banning logic\"\"\" | ||
| if node.op != "call_function": | ||
| return False | ||
| if node.target == operator.getitem: | ||
| return False | ||
| if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE: | ||
| return True | ||
| if config.recompute_views and op_types.is_view(node): | ||
| return False | ||
| if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: | ||
| return False | ||
|
|
||
| must_save_set = [ | ||
| aten.convolution, | ||
| aten.convolution_backward, | ||
| aten._scaled_dot_product_flash_attention, | ||
| aten._scaled_dot_product_efficient_attention, | ||
| aten._flash_attention_forward, | ||
| aten._efficient_attention_forward, | ||
| aten.upsample_bilinear2d, | ||
| aten.native_dropout, | ||
| aten.rand_like, | ||
| aten.randn_like, | ||
| ] | ||
|
|
||
| if get_aten_target(node) in must_save_set: | ||
| return True | ||
|
|
||
| def heuristic(node): | ||
| if "val" in node.meta: | ||
| if isinstance(node.meta["val"], torch.Tensor) and node.meta["val"].dim() >= 2: | ||
| return node.meta["val"].shape[1] >= 4096 | ||
| return False | ||
|
|
||
| if min_cut_options.ban_if_not_in_allowlist: | ||
| if not op_types.is_recomputable(node): | ||
| return False | ||
|
|
||
| if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(node): | ||
| if heuristic(node): | ||
| return False | ||
| return True | ||
|
|
||
| if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: | ||
| return False | ||
|
|
||
| if min_cut_options.ban_if_reduction: | ||
| input_tensors_size = sum( | ||
| _size_of(i) for i in node.args if isinstance(i, fx.Node) | ||
| ) | ||
| output_size = _size_of(node) | ||
| return output_size * 4 < input_tensors_size | ||
| return False | ||
| """ | ||
|
|
||
|
|
||
| def register_long_context_checkpointing(): | ||
| """Splice the custom should_ban_recomputation into solve_min_cut. | ||
|
|
||
| Uses inspect.getsource to extract solve_min_cut's source, replaces the | ||
| original should_ban_recomputation with _CUSTOM_SHOULD_BAN, then execs the | ||
| result directly in _partitioners.__dict__. | ||
|
|
||
| The exec'd function's __globals__ is the real partitioners module dict, so | ||
| every other nested function (is_fusible, is_materialized_backwards, | ||
| can_fuse_into_*, etc.) and every local/closure variable (op_types, | ||
| min_cut_options, node_info, config, …) is exactly as in the original — | ||
| nothing else changes. | ||
|
|
||
| Backward compatible: if solve_min_cut gains new heuristics in a future | ||
| PyTorch version the exec automatically picks them up; only | ||
| _CUSTOM_SHOULD_BAN needs to stay in sync with any changes to the | ||
| original should_ban_recomputation signature/contract. | ||
| """ | ||
| src = inspect.getsource(_partitioners.solve_min_cut) | ||
| lines = src.split('\n') | ||
|
|
||
| # Locate the original should_ban_recomputation and the function after it. | ||
| start = next(i for i, l in enumerate(lines) if l.startswith(' def should_ban_recomputation(')) | ||
| end = next(i for i, l in enumerate(lines) if i > start and l.startswith(' def ')) | ||
|
|
||
| # Indent the replacement to the nesting level inside solve_min_cut (4 spaces). | ||
| replacement = textwrap.indent(_CUSTOM_SHOULD_BAN, ' ') | ||
|
|
||
| new_src = '\n'.join(lines[:start]) + '\n' + replacement + '\n'.join(lines[end:]) | ||
| exec(new_src, _partitioners.__dict__) # redefines _partitioners.solve_min_cut | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.