Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,8 @@ def _replace_with_config(self, child, name):
# No matching spec found
if self.partition_config.strict_mode:
raise ValueError(f"No matching spec for {param_name}")
# Default: column parallel for Linear layers
spec = TPLayerSpec(patterns=[], partition_type=PartitionType.COLUMN)
# With partition_config, rely only on explicit specs and skip unmatched layers.
return child

setattr(child, "replaced", True)

Expand All @@ -439,6 +439,8 @@ def _replace_with_config(self, child, name):

def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):
"""Create row-parallel layer (AllReduce after forward)."""
if self.conv_linear_layer:
return Conv_LinearALlreduce(module, self.mp_group, name=name)
# Check for lm_head / embed_out
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(module, self.mp_group)
Expand All @@ -455,6 +457,12 @@ def _create_row_parallel_layer(self, module, spec: TPLayerSpec, name: str):

def _create_column_parallel_layer(self, module, spec: TPLayerSpec, name: str):
"""Create column-parallel layer (AllReduce in backward)."""
if self.conv_linear_layer:
return conv_LinearLayer(module, self.mp_group, name=name)
# Only use fused-QKV heuristics when no partition_config is provided.
elif self.partition_config is None and require_tp_fused_qkvw(name, self.mp_size):
# Check and handle fused qkv for TP
return fused_LinearLayer(module, self.mp_group, fused_module=self.module)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these fixes exposed by a test? i.e., a model with a conv linear layer or a fused QKV weight.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch! I added a test to validate that we use the correct layers for new custom patterns when a partition config is given.

if spec.shape is not None:
return SubParamLinearLayer(
module,
Expand Down Expand Up @@ -488,6 +496,7 @@ def _get_model_type(self) -> Optional[str]:
def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
return

mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

if hasattr(child.weight, 'ds_tensor'):
Expand Down Expand Up @@ -551,7 +560,30 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
continue
if len(child._buffers) != 0 and self.state_dict is not None:
Loading.load_buffer(child, self.state_dict, checking_key)
if child.__class__ in self.linear_policies:

# When using partition_config (custom patterns/presets), use pattern-based routing
# instead of linear_policies. This keeps all pattern logic centralized here.
if self.partition_config is not None:
full_name = prev_name + '.' + name if prev_name else name
if isinstance(child, nn.Embedding):
# Check if embedding matches any pattern
param_name = full_name + ".weight"
model_type = self._get_model_type()
spec = self.partition_config.find_matching_spec(param_name, model_type)
if spec is not None and spec.partition_type != PartitionType.SKIP:
new_child = self._slice_embedding(child, full_name, False)
if new_child is not None:
setattr(r_module, name, new_child)
# If no pattern matched or skip, leave embedding unchanged
elif hasattr(child, "weight") and getattr(child.weight, "dim", lambda: 0)() == 2:
new_child = self._replace_with_config(child, full_name)
if new_child is not None:
setattr(r_module, name, new_child)
else:
self.update_mp_params(child)
self._replace_module(child, name, class_name)
# Traditional path: use linear_policies for type-based routing
elif child.__class__ in self.linear_policies:
setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
self.conv_linear_layer))
elif any(isinstance(child, lp) for lp in self.linear_policies):
Expand Down
4 changes: 3 additions & 1 deletion deepspeed/runtime/tensor_parallel/init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def merge_tp_model_init_into_config(config_dict: dict, mpu, mesh_param, dist_mod
if tp_group is not None and mpu is not None:
raise ValueError("tp_model_init provided tp_group; deepspeed.initialize must not receive mpu.")
if tp_group is None and mpu is None and mesh_param is None:
raise ValueError("tp_model_init did not provide tp_group; deepspeed.initialize requires mpu or mesh_param.")
# Auto-create TP groups for compatibility with HF Trainer (mpu is not passed).
from deepspeed.utils import groups
groups._init_tp_mesh_device(tensor_model_parallel_size=tp_size)

tp_section = config_dict.get("tensor_parallel")
if tp_section is None:
Expand Down
104 changes: 104 additions & 0 deletions tests/unit/model_parallelism/test_autotp_custom_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,29 @@ def forward(self, x):
return x


class CustomLinearModule(torch.nn.Module):
    """Linear-like module that exposes raw ``weight``/``bias`` parameters.

    Intentionally NOT an ``nn.Linear`` subclass, so tests can verify that
    pattern-based AutoTP replacement works on custom modules that merely
    look like linear layers (2-D weight plus bias).
    """

    def __init__(self, hidden_dim):
        super(CustomLinearModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.bias = torch.nn.Parameter(torch.empty(hidden_dim))
        # Small uniform init keeps activations well-scaled for the test.
        for param in (self.weight, self.bias):
            torch.nn.init.uniform_(param, -0.02, 0.02)

    def forward(self, x):
        # Equivalent to x @ weight.T + bias.
        return torch.nn.functional.linear(x, self.weight, self.bias)


class CustomLinearModel(torch.nn.Module):
    """Minimal model wrapping a single ``CustomLinearModule``.

    Used to check that AutoTP pattern matching (e.g. ``.*custom\\.weight$``)
    reaches a custom, non-``nn.Linear`` submodule named ``custom``.
    """

    def __init__(self, hidden_dim):
        super(CustomLinearModel, self).__init__()
        self.custom = CustomLinearModule(hidden_dim)

    def forward(self, x):
        return self.custom(x)


def init_tp_engine(tp_size, partition_config=None):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
Expand Down Expand Up @@ -178,6 +201,87 @@ def test_custom_patterns_applied_via_config(self):
assert isinstance(engine.module.linears[1], LinearLayer)
assert isinstance(engine.module.linears[2], nn.Linear)

def test_use_default_specs_false_skips_unmatched_layers(self):
    skip_on_device()
    # With use_default_specs disabled, only layers matched by an explicit
    # spec are sharded; every unmatched layer must stay a plain nn.Linear.
    layer_specs = [
        {
            "patterns": [".*linears\\.0\\.weight$"],
            "partition_type": "row",
        },
        {
            "patterns": [".*linears\\.1\\.weight$"],
            "partition_type": "column",
        },
    ]
    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-6
            }
        },
        "tensor_parallel": {
            "autotp_size": 2,
            "partition_config": {
                "use_default_specs": False,
                "layer_specs": layer_specs,
            },
        },
        "zero_optimization": {
            "stage": 0,
        }
    }
    dtype = preferred_dtype()
    if dtype is torch.float16:
        config_dict["fp16"] = {"enabled": True}
    elif dtype is torch.bfloat16:
        config_dict["bf16"] = {"enabled": True}

    model = SequentialLinearModel(hidden_dim=16, nlayers=3)
    engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
    # linears.0 / linears.1 matched row/column specs; linears.2 did not match.
    assert isinstance(engine.module.linears[0], LinearAllreduce)
    assert isinstance(engine.module.linears[1], LinearLayer)
    assert isinstance(engine.module.linears[2], nn.Linear)

def test_custom_module_replacement_with_patterns(self):
    skip_on_device()
    # A custom (non-nn.Linear) module with a 2-D weight should still be
    # partitioned when an explicit pattern targets its weight parameter.
    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-6
            }
        },
        "tensor_parallel": {
            "autotp_size": 2,
            "partition_config": {
                "use_default_specs": False,
                "layer_specs": [
                    {
                        "patterns": [".*custom\\.weight$"],
                        "partition_type": "column",
                    },
                ],
            },
        },
        "zero_optimization": {
            "stage": 0,
        }
    }
    dtype = preferred_dtype()
    if dtype is torch.float16:
        config_dict["fp16"] = {"enabled": True}
    elif dtype is torch.bfloat16:
        config_dict["bf16"] = {"enabled": True}

    model = CustomLinearModel(hidden_dim=16)
    engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
    # The custom submodule must be replaced by a column-parallel LinearLayer.
    assert isinstance(engine.module.custom, LinearLayer)

def test_first_match_precedence(self):
skip_on_device()
partition_config = {
Expand Down
21 changes: 17 additions & 4 deletions tests/unit/model_parallelism/test_autotp_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,19 +165,32 @@ def test_tp_model_init_config_autotp_size_mismatch(self):
with pytest.raises(ValueError, match="tensor_parallel.autotp_size"):
deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict, mpu=DummyMPU())

def test_tp_model_init_requires_mpu_or_mesh_param(self):
def test_tp_model_init_autocreates_tp_group(self):
skip_on_device()
reset_tp_model_init_state()
# Verify tp_model_init creates TP groups when no mpu is provided.
model = SimpleModel(hidden_dim=8)
deepspeed.tp_model_init(model, tp_size=1, dtype=preferred_dtype())
tp_size = 2
deepspeed.tp_model_init(model, tp_size=tp_size, dtype=preferred_dtype())
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"tensor_parallel": {
"partition_config": {
"use_default_specs": False,
"layer_specs": [{
"patterns": [".*\\.weight$"],
"partition_type": "skip",
}],
}
},
"zero_optimization": {
"stage": 0,
}
}
with pytest.raises(ValueError, match="requires mpu or mesh_param"):
deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
assert engine.autotp_size() == tp_size
assert groups.get_tensor_model_parallel_world_size() == tp_size
assert groups.get_data_parallel_world_size() == dist.get_world_size() // tp_size

def test_tp_model_init_tp_group_rejects_mpu(self):
skip_on_device()
Expand Down
Loading