Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
MergeTensorInput,
)
from mergekit.sparsify import RescaleNorm, SparsificationMethod, sparsify

from mergekit.subspace_helpers import iso_c, compute_and_sum_svd_mem_reduction, subspace_boosting

class ConsensusMethod(str, Enum):
count = "count"
Expand Down Expand Up @@ -55,6 +55,8 @@ def parameters(self) -> List[ConfigParameterDef]:
name="rescale", required=False, default_value=self.default_rescale
),
ConfigParameterDef(name="lambda", required=False, default_value=1.0),
ConfigParameterDef(name="svd_thresh", required=False, default_value=0.01),
ConfigParameterDef(name="cumsum", required=False, default_value=True),
]

def tensor_parameters(self) -> List[ConfigParameterDef]:
Expand Down Expand Up @@ -96,6 +98,8 @@ def make_task(
lambda_=parameters["lambda"],
rescale_norm=RescaleNorm.l1 if parameters["rescale"] else None,
weight_info=output_weight,
svd_thresh=parameters["svd_thresh"],
cumsum=parameters["cumsum"],
)


Expand All @@ -109,6 +113,8 @@ class GTATask(Task[torch.Tensor]):
normalize: bool
lambda_: float
rescale_norm: Optional[RescaleNorm]
svd_thresh: float
cumsum: bool

def uses_accelerator(self) -> bool:
return True
Expand All @@ -130,7 +136,6 @@ def execute(
)
if not tvs:
return base

# sparsify
if self.method.sparsification_method:
for tv_info in tvs:
Expand All @@ -148,9 +153,8 @@ def execute(
rescale_norm=self.rescale_norm,
**kwargs,
)

deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)

weights = torch.tensor(
[tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device
)
Expand All @@ -175,6 +179,16 @@ def execute(
divisor = weights.sum(dim=0)
divisor[divisor.abs() < 1e-8] = 1

param_key = self.weight_info.name
subspace_input = [tv["delta"] for tv in tvs]

if self.method.name() == "iso_c":
mixed_delta = iso_c(subspace_input, param_key, deltas.device)
elif self.method.name() == "tsvm":
mixed_delta = compute_and_sum_svd_mem_reduction(subspace_input, param_key, deltas.device)
elif self.method.name() in ["task_arithmetic_sb", "ties_sb"]:
mixed_delta = subspace_boosting(param_key, mixed_delta, svd_thresh=self.svd_thresh, cumsum=self.cumsum)

if self.normalize:
mixed_delta /= divisor

Expand Down
36 changes: 36 additions & 0 deletions mergekit/merge_methods/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,42 @@
method_pretty_name="Linear DELLA",
method_reference_url="https://arxiv.org/abs/2406.11617",
),
GeneralizedTaskArithmeticMerge(
consensus_method=None,
sparsification_method=None,
default_normalize=False,
default_rescale=False,
method_name="tsvm",
method_pretty_name="TSV-M",
method_reference_url="https://arxiv.org/abs/2412.00081",
),
GeneralizedTaskArithmeticMerge(
consensus_method=None,
sparsification_method=None,
default_normalize=False,
default_rescale=False,
method_name="iso_c",
method_pretty_name="ISO-C",
method_reference_url="https://www.arxiv.org/pdf/2502.04959",
),
GeneralizedTaskArithmeticMerge(
consensus_method=None,
sparsification_method=None,
default_normalize=False,
default_rescale=False,
method_name="task_arithmetic_sb",
method_pretty_name="Task Arithmetic with Subspace Boosting",
method_reference_url="https://arxiv.org/abs/2212.04089",
),
GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.magnitude,
default_normalize=True,
default_rescale=False,
method_name="ties_sb",
method_pretty_name="TIES with Subspace Boosting",
method_reference_url="https://arxiv.org/abs/2306.01708",
),
]

REGISTERED_MERGE_METHODS: Dict[str, MergeMethod] = {
Expand Down
204 changes: 204 additions & 0 deletions mergekit/subspace_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import torch
from typing import List, Dict, Any, Optional
import time
import logging

def iso_c(task_vectors: List[torch.Tensor], tv_key: str, device: torch.device) -> Dict[str, Any]:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function signature indicates a return type of Dict[str, Any], but the implementation returns a torch.Tensor. The return type annotation should be updated to match the actual implementation:

def iso_c(task_vectors: List[torch.Tensor], tv_key: str, device: torch.device) -> torch.Tensor:

This will ensure type consistency and help with static type checking.

Suggested change
def iso_c(task_vectors: List[torch.Tensor], tv_key: str, device: torch.device) -> Dict[str, Any]:
def iso_c(task_vectors: List[torch.Tensor], tv_key: str, device: torch.device) -> torch.Tensor:

Spotted by Graphite Agent

Fix in Graphite


Is this helpful? React 👍 or 👎 to let us know.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused device parameter in iso_c function

Low Severity

The iso_c function accepts a device parameter but never uses it anywhere in the function body. All tensor operations naturally inherit their device from the input task_vectors. This is inconsistent with compute_and_sum_svd_mem_reduction, which uses the device parameter to explicitly place accumulator tensors. The unused parameter adds confusion about whether device placement is being controlled.

Fix in Cursor Fix in Web

with torch.no_grad():
tvs = task_vectors
new_vector = sum(tvs) / len(tvs)
original_dtype = new_vector.dtype # Store original dtype

if (len(task_vectors[0].shape) == 2 and "embed_tokens" not in tv_key and "lm_head" not in tv_key):
print(f"Computing SVD for {tv_key}... with shape {task_vectors[0].shape}")
new_vector *= len(tvs)
# Convert to float32 for SVD
vec_fp32 = new_vector.to(torch.float32)
U, S, V = torch.linalg.svd(vec_fp32, full_matrices=False)
S_mean = torch.ones_like(S) * S.mean()

# Perform matrix multiplication in float32 and convert back to original dtype
new_vector = torch.linalg.multi_dot(
(
U,
torch.diag(S_mean),
V,
)
).to(original_dtype) # Convert back to original dtype

return new_vector

###############
#### TSV Merge Orthogonalization
def compute_and_sum_svd_mem_reduction(task_vectors: List[torch.Tensor], tv_key: str, device: torch.device) -> Dict[str, Any]:
"""
Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors,
reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate
the low-rank matrices. If the vector is not a 2D tensor or is "text_projection", it computes the mean of the vectors.
Computation of the SVD is performed also for the second operation.

Args:
task_vectors (list): A list of task vector objects, where each object contains a
dictionary of vectors.
Returns:
dict: A dictionary containing the new vectors after SVD computation and merging.
"""
sv_reduction = 1 / len(task_vectors)
with torch.no_grad():
new_vector = {}
for i, task_vector in enumerate(task_vectors):
vec = task_vector
original_dtype = vec.dtype # Store original dtype

if (
len(task_vector.shape) == 2
and "embed_tokens" not in tv_key
and "lm_head" not in tv_key
):
print(f"Computing SVD for {tv_key}... with shape {task_vector.shape}")
# Convert to float32 for SVD
vec_fp32 = vec.to(torch.float32)
u, s, v = torch.linalg.svd(vec_fp32, full_matrices=False)

if i == 0:
sum_u = torch.zeros_like(u, device=device, dtype=torch.float32)
sum_s = torch.zeros_like(s, device=device, dtype=torch.float32)
sum_v = torch.zeros_like(v, device=device, dtype=torch.float32)
reduced_index_s = int(s.shape[0] * sv_reduction)

# select only the first reduced_index_s columns of u and place them
sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
:, :reduced_index_s
]
sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
:reduced_index_s
]
# select only the first reduced_index_s rows of v and place them
sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
:reduced_index_s, :
]

else:
if i == 0:
new_vector = vec.clone()
else:
new_vector += (vec - new_vector) / (i + 1)

if (
len(task_vector.shape) == 2
and "embed_tokens" not in tv_key
and "lm_head" not in tv_key
):
# Perform final SVD operations in float32

try:
u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
except torch._C._LinAlgError:
print(f"[Retry with 'gesvd'] SVD failed for {tv_key}.")
u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False, driver='gesvd')

try:
u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)
except torch._C._LinAlgError:
print(f"[Retry with 'gesvd'] SVD failed for {tv_key}.")
u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False, driver='gesvd')

# Perform matrix multiplication in float32
new_vector = torch.linalg.multi_dot(
(
u_u,
v_u,
torch.diag(sum_s),
u_v,
v_v,
)
).to(original_dtype) # Convert back to original dtype

return new_vector

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Inconsistent Init Triggers SVD Runtime Errors

In compute_and_sum_svd_mem_reduction, new_vector is inconsistently initialized as a dict then used for tensor operations, causing runtime errors. The SVD accumulation tensors (sum_u, sum_s, sum_v) are incorrectly sized and can receive zero-length slices from reduced_index_s, leading to out-of-bounds errors. Also, the final SVD condition check uses the last task_vector.

Fix in Cursor Fix in Web


def subspace_boosting(
merged_tv_key: str,
merged_tv: torch.Tensor,
svd_thresh=0.01,
cumsum=True,
) -> Dict[str, Any]:
"""
Subspace boosting for merging task vectors.

Parameters:
tv_flat_checks: Flattened task vectors.
ptm_check:
Pretrained model.
config:
Configuration object containing method parameters (e.g., config.method.k, config.method.use_ties).
reset_thresh: default 20
Threshold parameter used for ties merging. defaults to 20.
svd_thresh: default 0.01
Threshold for singular value boosting. If cumsum is True, used as a cumulative ratio threshold;
otherwise used as a fraction of the total number of singular values. Defaults to 0.01.
cumsum:
Whether to use the cumulative sum approach for thresholding the singular values.
remove_keys:
Optional list of keys to remove from the state dict conversion.

Returns:
A merged flat vector representing the task vector after subspace boosting.

Raises:
ValueError: If the base_method is not one of the defined options.
"""

# Merging approach for attention weight matrices
#apply_to_attn = config.method.apply_to_attn
# apply_to_attn=False: no subspace boosting for attention weights
#if apply_to_attn not in [False, "full_attn", "per_qkv", "per_head"]:
# raise ValueError(f"Apply to attention method {apply_to_attn} not defined.")

keys_to_eval = [
".self_attn.q_proj.weight",
".self_attn.k_proj.weight",
".self_attn.v_proj.weight",
".self_attn.o_proj.weight",
".mlp.gate_proj.weight",
".mlp.up_proj.weight",
".mlp.down_proj.weight",
]

if any(i in merged_tv_key for i in keys_to_eval) and isinstance(merged_tv, torch.Tensor):
print(f"Applying subspace boosting to {merged_tv_key} with shape {merged_tv.shape}")

# Store original dtype
original_dtype = merged_tv.dtype

# Convert to float32 for SVD
merged_tv_fp32 = merged_tv.to(torch.float32)

U, S, Vh = torch.linalg.svd(merged_tv_fp32, full_matrices=False)

if cumsum:
total_sum = S.sum()
cumulative = torch.cumsum(S, dim=0)

thresh = svd_thresh

k = (cumulative / total_sum >= thresh).nonzero(as_tuple=False)

if k.numel() == 0:
# fallback: use smallest singular value
cutoff_idx = -1
print(f"[Warning] No valid SVD cutoff for {merged_tv_key}. Using full singular spectrum.")
else:
cutoff_idx = k[0].item()

S_damped = torch.clamp(S, min=S[cutoff_idx])
else: # Clamping approach using the threshold as an index
cutoff_idx = int(thresh * S.numel())
S_damped = torch.clamp(S, min=S[cutoff_idx])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Undefined variable causing NameError in subspace_boosting

The subspace_boosting function's else branch (when cumsum is False) uses an undefined thresh variable, causing a NameError; svd_thresh was likely intended. Additionally, in the if cumsum: block's fallback, if no SVD cutoff is found, S is clamped to its smallest singular value. This doesn't modify S, contrary to the full spectrum intent.

Fix in Cursor Fix in Web

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out-of-bounds index when svd_thresh >= 1.0 with cumsum=False

Medium Severity

In the subspace_boosting function's non-cumsum branch, cutoff_idx = int(svd_thresh * S.numel()) computes an index without bounds checking. When svd_thresh >= 1.0, this produces an index equal to or greater than the array length, causing an IndexError on S[cutoff_idx]. Users might naturally try svd_thresh: 1.0 expecting it to represent "use all singular values," but this crashes. The valid range is implicitly [0, 1) but is not enforced.

Fix in Cursor Fix in Web


# Perform matrix multiplication in FP32
merged_tv = (U * S_damped.unsqueeze(0)) @ Vh

# Convert back to original dtype
merged_tv = merged_tv.to(original_dtype)

return merged_tv
Loading