diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml new file mode 100644 index 000000000..cc17c19e0 --- /dev/null +++ b/examples/evaluate_precision/smol.yaml @@ -0,0 +1,59 @@ +# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M. +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/smol.yaml +# +# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id +# (auto-downloaded via `huggingface_hub.snapshot_download` on first use). +pretrained: + path: HuggingFaceTB/SmolLM2-135M + format: llama +output_dir: /tmp/fast_llm_tests/evaluate_precision/features +sequence_length: 2048 +variants: + # Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 gradient accumulation, bf16 residual, bf16 lm_head). + bf16: + model.distributed.compute_dtype: bfloat16 + # Turn ON full-precision residual stream. + bf16_fp32_residual: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + # Turn ON fp32 LM head matmul (PR #526). + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + # Both stability features on (most precise bf16-compute configuration). + bf16_max_precision: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true + # Diagnostic: enable bf16 reduced-precision reductions in cuBLAS GEMMs. Tests whether the + # within-engine bf16-vs-fp32 gap is sensitive to the partial-sum reduction precision (the + # MMA accumulator is fp32 by hardware on H100/A100; this flag affects split-K reductions). + bf16_reduced_reduction: + model.distributed.compute_dtype: bfloat16 + _torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true + # Diagnostic: simulate a "bf16 inputs, fp32 output" lm-head matmul kernel. 
fp32_lm_head=True + # upcasts inputs+weights to fp32, then matmul_precision='medium' runs the matmul through + # bf16 Tensor Cores anyway, then logits stay fp32. Tests whether fp32_lm_head's gain comes + # from input precision or from skipping the bf16 output cast. + bf16_in_fp32_out: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + _torch_matmul_precision: medium + # fp16 sweep: probes whether the precision-vs-noise picture (rms noise ~0.1 nats per token + # for bf16) shrinks ~8× for fp16 (10 mantissa bits vs 7), as the literature's "switch to + # fp16" recommendation implies. Default dynamic grad-scaler (initial 2^16) is uniform + # across variants, so relative comparisons stay meaningful. + fp16: + model.distributed.compute_dtype: float16 + fp16_fp32_residual: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true + fp16_max_precision: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml new file mode 100644 index 000000000..b0e8e319d --- /dev/null +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -0,0 +1,52 @@ +# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M +# with the GSPO policy-gradient loss (uses advantages and old log-probabilities). +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/smol_gspo.yaml +# +# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id +# (auto-downloaded via `huggingface_hub.snapshot_download` on first use). 
+pretrained: + path: HuggingFaceTB/SmolLM2-135M + format: llama +model: + base_model: + head: + losses: + gspo: + type: gspo +output_dir: /tmp/fast_llm_tests/evaluate_precision/gspo +data_path: /tmp/fast_llm_tests/evaluate_precision/gspo_data +sequence_length: 2048 +variants: + bf16: + model.distributed.compute_dtype: bfloat16 + bf16_fp32_residual: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + bf16_max_precision: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true + bf16_reduced_reduction: + model.distributed.compute_dtype: bfloat16 + _torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true + bf16_in_fp32_out: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + _torch_matmul_precision: medium + fp16: + model.distributed.compute_dtype: float16 + fp16_fp32_residual: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true + fp16_max_precision: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index a90bcdebc..fbfe60ac3 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -80,6 +80,12 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig): use_preference_spans: bool = Field(default=False) use_grpo_data: bool = Field(default=False) return_label_counts: bool = Field(default=False) + output_hidden_states: list[str] = Field( + default_factory=list, + desc="Regex patterns to add to each 
model input's `output_hidden_states` set." + " Matching `_debug`-named tensors get populated into `kwargs[hidden_states]`" + " and (when running under a `Run` context) emitted into `tensor_logs`.", + ) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 16114cb80..000fcc01d 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -161,6 +161,13 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis self._set_target_inputs(model_inputs, config) + if config.output_hidden_states: + import re + + patterns = {re.compile(pattern) for pattern in config.output_hidden_states} + for model_input in model_inputs: + model_input.output_hidden_states.update(patterns) + return model_inputs def _set_target_inputs( diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index c055a7f2c..4c99798c5 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -100,6 +100,18 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str: Assert.eq(shard_name, "weights") return parameter_name + @classmethod + def _resolve_path(cls, path: pathlib.Path) -> pathlib.Path: + """Resolve a local directory or HF Hub model id (e.g. ``meta-llama/Llama-3.2-1B``) to a + local snapshot directory. Local directories pass through unchanged; everything else is + materialized via :func:`huggingface_hub.snapshot_download` (cached on subsequent calls). 
+ """ + if path.is_dir(): + return path + import huggingface_hub + + return pathlib.Path(huggingface_hub.snapshot_download(str(path))) + # Use custom config instead of relying on the transformers library @classmethod def _load_config(cls, directory: pathlib.Path | str) -> dict: @@ -128,20 +140,32 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: { # transformers PretrainedConfig "_name_or_path", + "add_cross_attention", "architectures", "auto_map", "chunk_size_feed_forward", + "cross_attention_hidden_size", "dtype", + "finetuning_task", "id2label", + "is_decoder", "is_encoder_decoder", "label2id", "model_type", "output_attentions", "output_hidden_states", + "prefix", "problem_type", + "pruned_heads", "return_dict", + "task_specific_params", + "tf_legacy_loss", + "tie_encoder_decoder", + "tokenizer_class", "torch_dtype", + "torchscript", "transformers_version", + "use_bfloat16", "use_cache", # Token ids — generation/inference, not architecture. "bos_token_id", @@ -149,10 +173,39 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: "eos_token_id", "pad_token_id", "sep_token_id", + # Generation defaults — never architecture. + "bad_words_ids", + "begin_suppress_tokens", + "diversity_penalty", + "do_sample", + "early_stopping", + "encoder_no_repeat_ngram_size", + "exponential_decay_length_penalty", + "forced_bos_token_id", + "forced_eos_token_id", + "length_penalty", + "max_length", + "min_length", + "no_repeat_ngram_size", + "num_beam_groups", + "num_beams", + "num_return_sequences", + "output_scores", + "remove_invalid_values", + "repetition_penalty", + "return_dict_in_generate", + "suppress_tokens", + "temperature", + "top_k", + "top_p", + "typical_p", # Initialization / pretraining metadata Fast-LLM does not consume. "initializer_range", "max_position_embeddings", "pretraining_tp", + # Family markers / default-valued knobs serialized by recent transformers versions. 
+ "is_llama_config", + "rope_interleaved", } ) @@ -181,28 +234,29 @@ def _load_weights( import transformers Assert.eq(self.get_shard_names(config), ("weights",)) - if (config.path / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): - paths = {config.path / transformers.utils.SAFE_WEIGHTS_NAME} - elif (config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): - logger.info(f"Loading index from {config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") + directory = self._resolve_path(config.path) + if (directory / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): + paths = {directory / transformers.utils.SAFE_WEIGHTS_NAME} + elif (directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): + logger.info(f"Loading index from {directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") paths = { - config.path / path - for path in json.loads((config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ + directory / path + for path in json.loads((directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } - elif (config.path / transformers.utils.WEIGHTS_NAME).is_file(): - paths = {config.path / transformers.utils.WEIGHTS_NAME} - elif (config.path / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): - logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}") + elif (directory / transformers.utils.WEIGHTS_NAME).is_file(): + paths = {directory / transformers.utils.WEIGHTS_NAME} + elif (directory / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): + logger.info(f"Loading index from {directory / transformers.utils.WEIGHTS_INDEX_NAME}") paths = { - config.path / path - for path in json.loads((config.path / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ + directory / path + for path in json.loads((directory / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } else: - raise FileNotFoundError(f"No compatible checkpoint found in {config.path}") + raise 
FileNotFoundError(f"No compatible checkpoint found in {directory}") for path in paths: logger.info(f"Loading from {path}") diff --git a/tests/utils/compare_tensor_logs.py b/fast_llm/engine/config_utils/compare_tensor_logs.py similarity index 69% rename from tests/utils/compare_tensor_logs.py rename to fast_llm/engine/config_utils/compare_tensor_logs.py index f02d62c79..dbad78a25 100644 --- a/tests/utils/compare_tensor_logs.py +++ b/fast_llm/engine/config_utils/compare_tensor_logs.py @@ -87,6 +87,52 @@ def _compare_dict_keys(self, dict_ref, dict_test, errors, name): # Avoid set to preserve ordering. return [key for key in dict_test if key in dict_ref] + def _compute_diff(self, tensor_ref, tensor_test, step_name, tensor_name) -> dict | None: + # Returns per-tensor error metrics, or None on shape/sampling mismatch. + if tensor_ref["shape"] != tensor_test["shape"]: + return None + if tensor_ref["step"] != tensor_test["step"]: + return None + sub_config = self._get_sub_config(step_name, tensor_name) + samples_ref = tensor_ref["samples"].flatten().float() + samples_test = tensor_test["samples"].flatten().float() + if sub_config.scale != 1.0: + samples_test = samples_test / sub_config.scale + scale_unreg = (samples_ref**2).mean() ** 0.5 + rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 + diff = samples_test - samples_ref + rms = (diff**2).mean() ** 0.5 + max_diff = diff.abs().max() + bias = diff.mean() + # Linear-regression decomposition: `test ≈ slope * ref + intercept + residual`. + # Useful for separating systematic distortion (slope ≠ 1) from per-position decorrelated + # noise (residual). For RL importance ratios, slope ≠ 1 indicates likely-token-dependent + # bias which is more dangerous than a uniform shift. 
+ centered_test = samples_test - samples_test.mean() + centered_ref = samples_ref - samples_ref.mean() + var_ref = (centered_ref**2).mean() + var_test = (centered_test**2).mean() + cov = (centered_test * centered_ref).mean() + denom = (var_test * var_ref) ** 0.5 + correlation = (cov / denom).item() if denom > 0 else float("nan") + slope = (cov / var_ref).item() if var_ref > 0 else float("nan") + residual_var = (var_test - cov**2 / var_ref).clamp(min=0.0) if var_ref > 0 else var_test + residual_rms = residual_var**0.5 + return { + "rms_abs": rms.item(), + "rms_rel": (rms / rms_scale).item(), + "max_abs": max_diff.item(), + "max_rel": (max_diff / rms_scale).item(), + "ref_scale": scale_unreg.item(), + "ref_scale_regularized": rms_scale.item(), + "bias_abs": bias.item(), + "bias_rel": (bias / rms_scale).item(), + "correlation": correlation, + "slope": slope, + "residual_rms_abs": residual_rms.item(), + "residual_rms_rel": (residual_rms / rms_scale).item(), + } + def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_name): sub_config = self._get_sub_config(step_name, tensor_name) if tensor_ref["shape"] != tensor_test["shape"]: @@ -108,34 +154,33 @@ def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_nam ) return - samples_ref = tensor_ref["samples"].flatten().float() - samples_test = tensor_test["samples"].flatten().float() - if sub_config.scale != 1.0: - samples_test = samples_test / sub_config.scale - scale_unreg = (samples_ref**2).mean() ** 0.5 - rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 - rms = ((samples_ref - samples_test) ** 2).mean() ** 0.5 - max_diff = (samples_ref - samples_test).abs().max() + metrics = self._compute_diff(tensor_ref, tensor_test, step_name, tensor_name) + rms_scale = metrics["ref_scale_regularized"] + scale_unreg = metrics["ref_scale"] tensor_errors = [] - if rms > sub_config.rms_abs_tolerance: - tensor_errors.append(f" * RMS diff absolute = {rms} > 
{sub_config.rms_abs_tolerance}") + if metrics["rms_abs"] > sub_config.rms_abs_tolerance: + tensor_errors.append(f" * RMS diff absolute = {metrics['rms_abs']} > {sub_config.rms_abs_tolerance}") - if rms / rms_scale > sub_config.rms_rel_tolerance: + if metrics["rms_rel"] > sub_config.rms_rel_tolerance: tensor_errors.append( - f" * RMS diff scaled = {rms / rms_scale} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" + f" * RMS diff scaled = {metrics['rms_rel']} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" ) - if max_diff > sub_config.max_abs_tolerance: - tensor_errors.append(f" * Max diff absolute = {max_diff} > {sub_config.max_abs_tolerance}") + if metrics["max_abs"] > sub_config.max_abs_tolerance: + tensor_errors.append(f" * Max diff absolute = {metrics['max_abs']} > {sub_config.max_abs_tolerance}") - if max_diff / rms_scale > sub_config.max_rel_tolerance: + if metrics["max_rel"] > sub_config.max_rel_tolerance: tensor_errors.append( - f" * Max diff scaled = {max_diff / rms_scale} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" + f" * Max diff scaled = {metrics['max_rel']} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" ) if tensor_errors: + samples_ref = tensor_ref["samples"].flatten().float() + samples_test = tensor_test["samples"].flatten().float() + if sub_config.scale != 1.0: + samples_test = samples_test / sub_config.scale tensor_errors.extend( [ f" Test samples: " + "".join(f"{x:12.4e}" for x in samples_test[: self.show_samples].tolist()), diff --git a/fast_llm/engine/config_utils/logging.py b/fast_llm/engine/config_utils/logging.py index 32deb4562..b82d4c847 100644 --- a/fast_llm/engine/config_utils/logging.py +++ b/fast_llm/engine/config_utils/logging.py @@ -76,6 +76,15 @@ class TensorLogsConfig(Config): valid=check_field(Assert.gt, 0), ) full_tensors: bool = Field(default=False, desc="Save and/or print entire 
tensors.") + sample_level_overrides: dict[str, int] = Field( + default_factory=dict, + desc="Per-tensor sample-density overrides (regex pattern -> level)." + " For tensors whose logged name matches a pattern, the effective `log_tensor` level is" + " raised to the matching override (samples = 2 ** (level - 3))." + " Useful for sparse tensors like embedding-weight gradients where the default sampling" + " stride misses most non-zero rows.", + hint=FieldHint.logging, + ) class TensorLogs: diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 958a3d228..96cb52f09 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -139,6 +139,14 @@ class StageConfig(Config): desc="Check for tensor-parallel desyncs and log an error if a desync is found. High overhead", hint=FieldHint.logging, ) + debug_hidden_states_log: list[str] = Field( + default_factory=list, + desc="Regex patterns for `_debug`-named tensors (`.`, e.g. `head.logits`," + " `decoder.0.norm_1`) to log to `tensor_logs`. Patterns are appended to each model" + " input's `output_hidden_states` set, so matching tensors are both populated into" + " `kwargs[hidden_states]` for downstream consumers and emitted into `tensor_logs`.", + hint=FieldHint.logging, + ) @config_class() diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 805eae1e5..0476a8107 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -18,6 +18,12 @@ logger = logging.getLogger(__name__) +# Verbosity used for `output_hidden_states`-driven tensor logging. `log_tensor` collects sampled +# tensor values only at level >= 3; 13 matches the convention in the layer-comparison tests +# (1024 sampled values per tensor). +_HIDDEN_STATE_LOG_LEVEL = 13 + + class DebugLayer: """ A debugging utility for blocks. 
@@ -55,11 +61,14 @@ def __call__( if level > 1: log_pipeline_parallel_main_rank(lambda: log_memory_usage(name, str)) - if level > 0 and tensor is not None: + # `output_hidden_state` requests full-fidelity capture even when `model_debug_level` is + # off — clamp the log level so samples are saved alongside summary stats. + log_level = max(level, _HIDDEN_STATE_LOG_LEVEL) if output_hidden_state else level + if log_level > 0 and tensor is not None: log_distributed_tensor( "", tensor, - level=level, + level=log_level, meta=meta, **logging_kwargs, ) @@ -67,7 +76,7 @@ def __call__( log_distributed_grad( "", tensor, - level=level, + level=log_level, meta=self._get_meta(tensor, f"{name}.grad", dims), **logging_kwargs, ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index bde33f297..6a0bfcfd6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -131,6 +131,13 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.architecture, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + fp32_lm_head: bool = Field( + default=False, + desc="Upcast input and weight to float32 before the lm_head linear. 
" + "Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs " + "are computed at the same numerical precision, keeping the IS ratio near 1 at init.", + hint=FieldHint.feature, + ) prediction_heads: int = Field( default=1, desc="Prediction heads.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 22c750082..8dd511480 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,7 +22,7 @@ ) from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -from fast_llm.tensor import TensorMeta +from fast_llm.tensor import TensorMeta, accumulate_gradient from fast_llm.utils import Assert, safe_merge_dicts logger = logging.getLogger(__name__) @@ -252,9 +252,17 @@ def _logits_loss_forward_backward_partial( split_index: int = 0, return_logits: bool = False, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if self._config.fp32_lm_head: + input_dtype = input_.dtype + input_ = input_.to(torch.float32) + # detach → requires_grad=False → output_parallel_linear_backward skips weight grad + weight = self.output_weights.detach().to(torch.float32) + else: + weight = self.output_weights + logits, context = output_parallel_linear_forward( input_=input_, - weight=self.output_weights, + weight=weight, bias=None, group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, @@ -285,12 +293,38 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) - if grad is not None and self._config.final_logit_softcap is not None: + if grad is not None: + # `logits` has `requires_grad=False` (custom-autograd), so the existing + # `_debug(logits, ...)` can't auto-capture the gradient. 
Log it explicitly here + # so `output_hidden_states` patterns covering `head.logits` also catch the grad. + self._debug( + grad, + f"logits.grad{"" if self._config.cross_entropy_splits == 1 else f"_{split_index}"}", + (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._vocab_dim), + kwargs, + scale=self._config.logits_scale_factor, + ) + + if not self.training or grad is None: + return sum(losses_) if losses_ else None, None + + if self._config.final_logit_softcap is not None: grad = _softcap_backward(grad, logits, self._config.final_logit_softcap) - return sum(losses_) if losses_ else None, ( - output_parallel_linear_backward(grad, context) if self.training else None - ) + input_grad = output_parallel_linear_backward(grad, context) + if self._config.fp32_lm_head: + # Weight grad was skipped because weight.requires_grad=False; accumulate manually. + # context: (input_, weight, bias, group, sequence_parallel, ...) + saved_input = context[0] + if context[4]: # sequence_parallel + from fast_llm.core.ops import gather_op + + saved_input = gather_op(saved_input, context[3], dim=0) + grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2)) + accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype)) + input_grad = input_grad.to(input_dtype) + + return sum(losses_) if losses_ else None, input_grad def get_loss_definitions(self) -> list[LossDef]: return [ diff --git a/fast_llm/layers/language_model/loss/chosen_logprob.py b/fast_llm/layers/language_model/loss/chosen_logprob.py new file mode 100644 index 000000000..cb99e7c17 --- /dev/null +++ b/fast_llm/layers/language_model/loss/chosen_logprob.py @@ -0,0 +1,41 @@ +import math +import typing + +import torch + +from fast_llm.layers.language_model.loss.config import LanguageModelChosenLogprobLossConfig +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss +from fast_llm.logging import log_tensor + + +class LanguageModelChosenLogprobLoss[ConfigType: 
LanguageModelChosenLogprobLossConfig](LanguageModelLoss[ConfigType]): + """Logs log π(label) per position via the tensor-log pipeline; contributes nothing to gradients.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Don't surface a "chosen_logprob: 0" line in the training metrics. + self._do_register_loss = False + + def _forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + losses: dict | None = None, + split_index: int = 0, + grad_logits: torch.Tensor | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + if self._vocab_parallel: + raise NotImplementedError("chosen_logprob loss does not support vocab parallel") + labels = self._get_labels(kwargs, split_index).reshape(-1).long() + with torch.no_grad(): + log_probs = torch.log_softmax(logits.float() * self._logits_scale_factor, dim=-1) + # Mask out-of-range labels (e.g. -100 for prompt tokens in RL data) before gather to + # avoid CUDA assert. Fast-LLM convention: any label < 0 is masked. + valid = labels >= 0 + safe_labels = labels.clamp(min=0) + chosen_logprob = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1) + chosen_logprob = chosen_logprob[valid] + # Capture the full tensor: bias is the mean over all positions, not a sampled subset. 
+ level = math.ceil(math.log2(max(chosen_logprob.numel(), 1))) + 3 + log_tensor(f"Global : {self._name}", chosen_logprob, level=level) + return torch.zeros((), dtype=logits.dtype, device=logits.device), grad_logits diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 9a220aacf..aa05fbb9a 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -9,6 +9,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.layers.language_model.loss.chosen_logprob import LanguageModelChosenLogprobLoss from fast_llm.layers.language_model.loss.dpo import LanguageModelDPOLoss from fast_llm.layers.language_model.loss.entropy_loss import ( LanguageModelDistillationLoss, @@ -186,6 +187,30 @@ def get_reference_models(self) -> set[str]: return {self.reference_model} +@config_class(dynamic_type={LanguageModelLossConfig: "chosen_logprob"}) +class LanguageModelChosenLogprobLossConfig(LanguageModelLossConfig): + """No-gradient diagnostic loss that logs log π(label) per position via the tensor-log pipeline. + + The chosen-token log-prob is the scalar that policy-gradient importance ratios depend on, + so its precision drift is a more direct signal than bulk-logit RMS. 
+ """ + + _abstract: typing.ClassVar[bool] = False + + weight: float = Field( + default=0.0, + hint=FieldHint.derived, + desc="Forced to 0: this loss has no gradient contribution.", + valid=check_field(Assert.eq, 0.0), + ) + + @property + def loss_class(self) -> "type[LanguageModelChosenLogprobLoss]": + from fast_llm.layers.language_model.loss.chosen_logprob import LanguageModelChosenLogprobLoss + + return LanguageModelChosenLogprobLoss + + @config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) class LanguageModelZLossConfig(LanguageModelLossConfig): """Z-loss regularization to prevent overconfidence.""" diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 2619883d6..6326e7e4b 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -131,6 +131,15 @@ def log_tensor[T]( ) -> T | None: if level < 1: return + # Per-tensor sample-density override: lets users boost the effective level for specific + # tensors (e.g. sparse embedding-weight gradients) via `TensorLogsConfig`. 
+ overrides = TensorLogs.config.sample_level_overrides if TensorLogs.config else None + if overrides: + import re + + for pattern, override in overrides.items(): + if re.search(pattern, name): + level = max(level, override) tensor = tensor.detach() if tensor.ndim == 0: tensor = tensor[None] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2e9b4365b..f4d4b286a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -112,6 +112,7 @@ def get_preprocessing_config( return LanguageModelBatchPreprocessingConfig( phase=phase, micro_batch_splits=micro_batch_splits, + output_hidden_states=list(self._config.multi_stage.debug_hidden_states_log), **self._base_model.get_preprocessing_config(), ) diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index 184294551..04a24e2ae 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -2,13 +2,13 @@ from fast_llm.data.preparation.tokenizer import Tokenizer, TokenizerConfig from fast_llm.utils import Assert -from tests.utils.dataset import download_santacoder_tokenizer +from tests.utils.dataset import download_test_tokenizer from tests.utils.global_variables import TOKENIZER_PATH @pytest.fixture(scope="session") def common_tokenizer() -> Tokenizer: - download_santacoder_tokenizer() + download_test_tokenizer() return TokenizerConfig(path=TOKENIZER_PATH).get_tokenizer() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 0b4dbafc1..f3febae4b 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -18,9 +18,9 @@ ModelConfigType, ) from fast_llm.engine.checkpoint.convert import ConvertConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.utils import Assert, header -from tests.utils.compare_tensor_logs import CompareConfig from 
tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index 7ae26c2d6..c8b5fd004 100644 --- a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -3,7 +3,7 @@ import pytest -from tests.utils.dataset import download_santacoder_tokenizer +from tests.utils.dataset import download_test_tokenizer from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import TOKENIZER_PATH from tests.utils.model_configs import ModelTestingGroup @@ -15,7 +15,7 @@ @pytest.fixture(scope="module") def tokenizer_path(): - download_santacoder_tokenizer() + download_test_tokenizer() return TOKENIZER_PATH diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 03ebac757..3c95d0dea 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -18,9 +18,9 @@ from fast_llm.data.dataset.sampled import logger from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.data.preparation.tokenizer import TokenizerConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert -from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_common_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_NAME diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index a2ea2f46e..e7b206cf5 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -14,7 +14,7 @@ from tests.utils.global_variables import DATASET_CACHE, 
MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH -def download_santacoder_tokenizer(): +def download_test_tokenizer(): if not TOKENIZER_FILE.is_file(): import transformers @@ -218,7 +218,7 @@ def _get_test_dataset( if has_grpo_data: source_schema["advantages"] = "advantages" - download_santacoder_tokenizer() + download_test_tokenizer() preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( { "dataset": { diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index f3bbbac8d..d08b023b9 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -4,7 +4,7 @@ import torch -from tests.utils.compare_tensor_logs import CompareConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig logger = logging.getLogger(__name__) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py new file mode 100644 index 000000000..9da8904a1 --- /dev/null +++ b/tools/evaluate_precision.py @@ -0,0 +1,637 @@ +import json +import logging +import math +import pathlib +import shutil +import statistics +import typing + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.models.gpt.config import PretrainedGPTModelConfig + +# Populate the trainer dynamic-type registry. +import fast_llm.data.auto # noqa: F401 # isort:skip +import fast_llm.engine.checkpoint.convert # noqa: F401 # isort:skip +import fast_llm.models.auto # noqa: F401 # isort:skip + +logger = logging.getLogger(__name__) + + +_REFERENCE_NAME = "reference" +_MODEL_TYPE = "gpt" +# Embedding-weight gradients are row-sparse (only input-token rows non-zero), so a +# uniformly-spaced sample of vocab_size entries usually misses all of them. 
# The pattern is applied via `TensorLogsConfig.sample_level_overrides` and picked up inside
# `log_tensor` (samples = 2 ** (level - 3) -> level 23 yields ~1M samples per tensor).
_SPARSE_GRAD_LEVEL = 23
_SPARSE_GRAD_OVERRIDES = {r"Global gradient: embeddings\.": _SPARSE_GRAD_LEVEL}
_CHOSEN_LOGPROB_NAME = "chosen_logprob"
# Auto-calibration of the constant gradient scaler. Each variant runs a calibration pass at
# `scale=1` (no overflow risk), then the actual run uses the largest power-of-2 scale that
# keeps logged gradient magnitudes (and a small safety factor for hidden in-kernel
# intermediates like norm partial sums) within fp16's representable range. Per-variant
# unscaling at compare time lets different variants pick different scales without polluting
# the relative metrics.
_HIDDEN_INTERMEDIATE_HEADROOM = 4.0  # safety factor for fused-kernel partial sums we don't log
_CALIBRATION_SUBDIR_PREFIX = ".calibration_"
# Variant-override keys starting with this prefix are interpreted as `torch.backends.<path>` and
# applied before each run. Used for diagnostics (e.g. enabling bf16 reduced-precision reductions);
# entries are listed in `_TORCH_BACKEND_DEFAULTS` and reset to their defaults before applying.
_TORCH_BACKEND_PREFIX = "_torch_backend."
_TORCH_BACKEND_DEFAULTS = {
    "cuda.matmul.allow_bf16_reduced_precision_reduction": False,
}
_TORCH_MATMUL_PRECISION_KEY = "_torch_matmul_precision"


@config_class()
class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig):
    """Evaluate layer-wise numerical-error propagation against an fp32 reference.

    Inherits `model` and `pretrained` from `PretrainedGPTModelConfig`: either or both
    can be set in the YAML. For the fp32 reference and for every variant the tool runs
    a calibration pass (gradient scale 1) followed by the measured trainer invocation,
    captures per-layer forward activations and gradients via the standard tensor-logs
    pipeline, and reports per-tensor RMS / max diffs against the reference.
    """

    _abstract = False
    variants: dict[str, typing.Any] = Field(
        desc="Named override bundles to evaluate against the fp32 reference."
        " Each value is a flat dict mapping dotted-path keys (same syntax as the Fast-LLM CLI) to values.",
        hint=FieldHint.core,
    )
    output_dir: pathlib.Path = Field(
        desc="Directory for per-run tensor-log artifacts and the final JSON report.",
        hint=FieldHint.core,
    )
    num_samples: int = Field(
        default=8192,
        desc="Number of sampled values stored per logged tensor (rounded up to next power of 2)."
        " Sparse tensors (e.g. embedding-weight gradients) get a higher level via"
        " `TensorLogsConfig.sample_level_overrides`.",
        hint=FieldHint.feature,
    )
    sequence_length: int = Field(
        default=2048,
        desc="Sequence length per micro-batch sample. Drives both `data.micro_batch_size` (the"
        " per-sample token count, despite the name) and `data.maximum_document_length`.",
        hint=FieldHint.feature,
    )
    data_path: pathlib.Path | None = Field(
        default=None,
        desc="If set, prepare a tokenized memmap dataset with advantages and `old_log_probabilities`"
        " at this path (using the test helper `_get_test_dataset`) and use it as the training"
        " input — required for policy-gradient losses like GSPO/GRPO. If unset, uses random tokens.",
        hint=FieldHint.feature,
    )

    def _validate(self) -> None:
        """Reject a variant named like the reference run and any non-flat / non-string-keyed variant."""
        super()._validate()
        assert _REFERENCE_NAME not in self.variants, f"'{_REFERENCE_NAME}' is reserved for the fp32 baseline."
        for name, overrides in self.variants.items():
            assert isinstance(overrides, dict) and all(
                isinstance(k, str) for k in overrides
            ), f"Variant {name!r} must be a flat dict of dotted-path string keys."

    def run(self) -> None:
        """Run the reference and every variant, compare them, write the JSON report and print tables."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self._prepare_data()
        # The reference goes first (empty override bundle), then the variants in declaration order.
        runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}}
        runs.update(self.variants)
        scales: dict[str, float] = {}
        for name, variant_overrides in runs.items():
            scales[name] = self._calibrate_and_run(name, variant_overrides)

        ref_artifacts = self._artifact_path(_REFERENCE_NAME)
        results = {
            name: self._compare(ref_artifacts, self._artifact_path(name), scales[_REFERENCE_NAME], scales[name])
            for name in self.variants
        }

        report_path = self.output_dir / "precision_report.json"
        report_path.write_text(json.dumps({"scales": scales, "variants": results}, indent=2))
        logger.info(f"Wrote report to {report_path}")
        logger.info(f"Per-variant gradient scales: {scales}")

        for name, rows in results.items():
            _print_table(name, rows)
        # With no variants there is nothing to summarize; the summary samples the first variant.
        if results:
            _print_summary(results)

    def _calibrate_and_run(self, name: str, variant_overrides: dict[str, typing.Any]) -> float:
        """Pick a power-of-2 gradient scale for this variant via a calibration pass, then run with it.

        Calibration runs with `constant=1.0`; scanning its logged gradients gives `max_unscaled`.
        The chosen scale is the largest power of 2 (never below 1) keeping `scale * max_unscaled`
        below `fp16_max / _HIDDEN_INTERMEDIATE_HEADROOM`, where the headroom reserves room for
        partial sums inside fused kernels (e.g. norm-weight grads sum over the sequence dimension).

        Returns the scale used for the measured run.
        """
        import torch

        cal_dir = self.output_dir / f"{_CALIBRATION_SUBDIR_PREFIX}{name}"
        self._run_one(name, variant_overrides, constant_scale=1.0, experiment_dir=cal_dir)
        max_unscaled = _scan_max_grad(cal_dir / "runs" / "0" / "artifacts")
        # Calibration artifacts are only needed for the scan above.
        shutil.rmtree(cal_dir)
        if max_unscaled <= 0.0:
            scale = 1.0
            logger.warning(f"[{name}] calibration found no nonzero gradient — falling back to scale=1.0")
        else:
            fp16_max = torch.finfo(torch.float16).max
            optimal_unrounded = fp16_max / max_unscaled / _HIDDEN_INTERMEDIATE_HEADROOM
            # Round down to a power of two and clamp the exponent at 0 so the scale is >= 1.
            scale = float(2 ** max(0, math.floor(math.log2(optimal_unrounded))))
            logger.info(f"[{name}] calibration: max_unscaled={max_unscaled:.4e} -> gradient_scaler.constant={scale:g}")
        self._run_one(name, variant_overrides, constant_scale=scale)
        return scale

    def _prepare_data(self) -> None:
        """Build the memmap dataset at `data_path` unless it is unset or already prepared."""
        if self.data_path is None:
            return
        if (self.data_path / "fast_llm_config.yaml").is_file():
            return
        # Couples `tools/` to `tests/utils/` for now — extract later if it sticks.
        from tests.utils.dataset import _get_test_dataset

        self.data_path.mkdir(parents=True, exist_ok=True)
        logger.info(f"Preparing memmap dataset at {self.data_path}")
        _get_test_dataset(
            self.data_path,
            seed=42,
            has_grpo_data=True,
            max_vocab_size=self.model.base_model.embeddings.vocab_size,
        )

    def _artifact_path(self, name: str) -> pathlib.Path:
        """Return the tensor-log artifact directory of the measured run called `name`."""
        return self.output_dir / name / "runs" / "0" / "artifacts"

    def _run_one(
        self,
        name: str,
        variant_overrides: dict[str, typing.Any],
        *,
        constant_scale: float | None = None,
        experiment_dir: pathlib.Path | None = None,
    ) -> None:
        """Run a single one-iteration trainer invocation with tensor logging enabled.

        Args:
            name: Run name; also the default experiment sub-directory under `output_dir`.
            variant_overrides: Flat dotted-path overrides. Keys with the `_torch_backend.` prefix
                and the `_torch_matmul_precision` key are applied to torch global state instead
                of being passed to the trainer config.
            constant_scale: If given, sets `optimizer.gradient_scaler.constant`.
            experiment_dir: Overrides the experiment directory (used for calibration passes).
        """
        # The trainer's Run picks the next `runs/<n>` subdir based on what already exists; wipe
        # any prior contents so each invocation lands in `runs/0` and stale artifacts can't be
        # read back via `_artifact_path`.
        if experiment_dir is None:
            experiment_dir = self.output_dir / name
        if experiment_dir.exists():
            shutil.rmtree(experiment_dir)
        # Base config: hardcoded training/optimizer/data/run skeleton plus the user's model/pretrained.
        # Forced fp32 on the reference baseline lives in here too so a variant can override it.
        optimizer_config: dict[str, typing.Any] = {
            "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0},
        }
        if constant_scale is not None:
            optimizer_config["gradient_scaler"] = {"constant": float(constant_scale)}
        base_dict: dict[str, typing.Any] = {
            "pretrained": self.pretrained.to_dict(),
            "model": self.model.to_dict(),
            "training": {
                "train_iters": 1,
                "num_workers": 0,
                "logs": {"interval": 1},
            },
            "optimizer": optimizer_config,
            "data": {
                "datasets": {
                    "training": (
                        {"type": "file", "path": str(self.data_path / "fast_llm_config.yaml")}
                        if self.data_path is not None
                        else {"type": "random"}
                    )
                },
                # Despite the name, Fast-LLM's `data.micro_batch_size` is the per-sample sequence
                # length, not the batch dimension. Default 2048 → 2048-token sample.
                "micro_batch_size": self.sequence_length,
                "maximum_document_length": self.sequence_length,
            },
            "run": {
                "experiment_dir": str(experiment_dir.resolve()),
                "tensor_logs": {
                    "save": True,
                    "show": False,
                    "sample_level_overrides": _SPARSE_GRAD_OVERRIDES,
                },
            },
        }
        # Translate `num_samples` to a `log_tensor` level: 2**(level-3) = samples.
        log_level = math.ceil(math.log2(max(self.num_samples, 1))) + 3
        fp32_dtypes = {
            ("model", "distributed", "compute_dtype"): "float32",
            ("model", "distributed", "optimization_dtype"): "float32",
        }
        # Split off torch-backend overrides before passing the rest to Fast-LLM's config system.
        backend_overrides = {
            key[len(_TORCH_BACKEND_PREFIX) :]: value
            for key, value in variant_overrides.items()
            if key.startswith(_TORCH_BACKEND_PREFIX)
        }
        _apply_torch_backend_overrides(backend_overrides)
        # "highest" is re-applied on every run so a previous variant's setting cannot leak.
        matmul_precision = variant_overrides.get(_TORCH_MATMUL_PRECISION_KEY, "highest")
        _apply_torch_matmul_precision(matmul_precision)
        variant_updates = {
            tuple(key.split(".")): value
            for key, value in variant_overrides.items()
            if not key.startswith(_TORCH_BACKEND_PREFIX) and key != _TORCH_MATMUL_PRECISION_KEY
        }
        # Tool-required overrides win over variants — a variant must not silently disable tensor logging.
        tool_overrides: dict[tuple[str, ...], typing.Any] = {
            ("model", "multi_stage", "debug_layer_outputs"): log_level,
            ("model", "multi_stage", "debug_layer_gradients"): log_level,
            ("model", "multi_stage", "debug_all_param_gradients"): log_level,
            # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's
            # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`.
            ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"],
            # Diagnostic loss that logs log π(label) per position via the tensor-log pipeline.
            # Contributes no gradient (weight=0); the comparison code picks it up by name.
            ("model", "base_model", "head", "losses", _CHOSEN_LOGPROB_NAME): {"type": "chosen_logprob"},
        }
        # When the user hasn't configured any loss, the head defaults to cross-entropy. Adding a
        # loss explicitly suppresses that default, so re-add it so gradients still flow.
        if not (self.model.base_model.head.losses or {}):
            tool_overrides[("model", "base_model", "head", "losses", "cross_entropy")] = {"type": "label"}
        logger.info(f"=== Running {name!r} ===")
        if variant_overrides:
            logger.info(f"Variant overrides: {variant_overrides}")
        trainer_class = TrainerConfig.get_subclass(_MODEL_TYPE)
        # Later positional update dicts take precedence: fp32 defaults < variant < tool overrides.
        trainer_config = trainer_class.from_dict(base_dict, fp32_dtypes, variant_updates, tool_overrides)
        trainer_config.configure_logging()
        trainer_config._get_runnable()()

    def _compare(
        self,
        ref_path: pathlib.Path,
        test_path: pathlib.Path,
        ref_scale: float,
        test_scale: float,
    ) -> list[dict[str, typing.Any]]:
        """Compare the tensor logs of one variant against the reference.

        Returns one row per tensor present in both logs, holding the step, tensor name,
        classification (`_classify`), reference shape and the metrics from `_compute_diff`.
        """
        compare_config = CompareConfig()
        errors: list[str] = []
        ref_logs = compare_config._extract_tensor_logs(ref_path, errors)
        test_logs = compare_config._extract_tensor_logs(test_path, errors)
        for error in errors:
            logger.warning(error)
        # Each variant's gradient logs are scaled by its own `constant` factor (auto-calibrated).
        # Undo per-variant scaling so the relative comparison reflects unscaled gradient diffs.
        _unscale_gradients_in_place(ref_logs, ref_scale)
        _unscale_gradients_in_place(test_logs, test_scale)
        rows: list[dict[str, typing.Any]] = []
        for step_name in sorted(ref_logs):
            if step_name not in test_logs:
                logger.warning(f"Step {step_name!r} missing from test logs")
                continue
            step_ref = ref_logs[step_name]
            step_test = test_logs[step_name]
            for tensor_name, ref in step_ref.items():
                if tensor_name not in step_test:
                    continue
                metrics = compare_config._compute_diff(ref, step_test[tensor_name], step_name, tensor_name)
                if metrics is None:
                    continue
                rows.append(
                    {
                        "step": step_name,
                        "tensor_name": tensor_name,
                        "kind": _classify(tensor_name),
                        "shape": ref["shape"],
                        **metrics,
                    }
                )
        return rows


def _is_gradient_like(tensor_name: str) -> bool:
    """Return True for log entries whose values are multiplied by the gradient scale."""
    # Anything affected by the loss-scaling multiplier: parameter gradients from `Fsdp.log_shard`,
    # backward activations from layer hooks, and explicit `.grad` debug entries (e.g. logits.grad).
    return ("gradient:" in tensor_name) or (" bw" in tensor_name) or (".grad" in tensor_name)


def _scan_max_grad(artifact_path: pathlib.Path) -> float:
    """Return the largest finite absolute value among gradient-like tensors logged at `artifact_path`.

    Returns 0.0 when no gradient-like entry has a finite non-zero magnitude.
    """
    max_abs = 0.0
    compare_config = CompareConfig()
    errors: list[str] = []
    logs = compare_config._extract_tensor_logs(artifact_path, errors)
    # Surface extraction problems the same way `_compare` does instead of dropping them.
    for error in errors:
        logger.warning(error)
    for step_logs in logs.values():
        for tensor_name, entry in step_logs.items():
            if not _is_gradient_like(tensor_name):
                continue
            # Saved stats include min/max; fall back to samples if absent. A stat stored as
            # `None` counts as absent, matching the guard in `_unscale_gradients_in_place`.
            stat_max = entry["max"] if "max" in entry else None
            stat_min = entry["min"] if "min" in entry else None
            if stat_max is not None and stat_min is not None:
                value = max(abs(float(stat_max)), abs(float(stat_min)))
            else:
                value = float(entry["samples"].abs().max().item())
            # Non-finite magnitudes (inf / nan) are ignored rather than driving the scale.
            if math.isfinite(value) and value > max_abs:
                max_abs = value
    return max_abs


def _unscale_gradients_in_place(logs: dict, scale: float) -> None:
    """Divide the samples and summary stats of every gradient-like entry in `logs` by `scale`."""
    if scale == 1.0:
        return
    inv = 1.0 / scale
    for step_logs in logs.values():
        for tensor_name, entry in step_logs.items():
            if not _is_gradient_like(tensor_name):
                continue
            entry["samples"] = entry["samples"].float() * inv
            for key in ("min", "max", "mu", "std"):
                if key in entry and entry[key] is not None:
                    entry[key] = float(entry[key]) * inv


def _apply_torch_backend_overrides(overrides: dict[str, typing.Any]) -> None:
    """Set every flag in `_TORCH_BACKEND_DEFAULTS` to its override, or to its default if not overridden.

    Keys that are not listed in `_TORCH_BACKEND_DEFAULTS` are reported with a warning and ignored.
    """
    import torch

    unknown = set(overrides) - set(_TORCH_BACKEND_DEFAULTS)
    if unknown:
        logger.warning(f"Unknown torch backend overrides (ignored): {sorted(unknown)}")
    for path, default in _TORCH_BACKEND_DEFAULTS.items():
        value = overrides.get(path, default)
        # Walk the dotted path below `torch.backends` down to the object owning the final attribute.
        obj: typing.Any = torch.backends
        parts = path.split(".")
        for part in parts[:-1]:
            obj = getattr(obj, part)
        setattr(obj, parts[-1], value)


def _apply_torch_matmul_precision(precision: str) -> None:
    """Forward `precision` to `torch.set_float32_matmul_precision`."""
    import torch

    torch.set_float32_matmul_precision(precision)


def _layer_name(tensor_name: str) -> str:
    """Derive a short layer label from a logged tensor name; returns "?" when nothing is left."""
    # Stage hooks name tensors `Global <layer> fw: ...` / `Global <layer> bw: ...`;
    # Fsdp.log_shard names weight gradients `Global gradient: <parameter>`.
    prefix = tensor_name.split(":", 1)[0].strip().split()
    if prefix == ["Global", "gradient"]:
        # For parameter gradients the label is the first dotted component of the parameter name.
        param = tensor_name.split(":", 1)[1].strip()
        return param.split(".")[0]
    if prefix and prefix[0] == "Global":
        prefix = prefix[1:]
    if prefix and prefix[-1] in ("fw", "bw"):
        prefix = prefix[:-1]
    return " ".join(prefix) if prefix else "?"
+ + +def _named_row(rows: list[dict[str, typing.Any]], name: str) -> dict[str, typing.Any] | None: + return next((r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == name), None) + + +_LM_HEAD_NAME = "head.output_weights" +_EMBEDDINGS_NAME = "embeddings.word_embeddings_weight" + + +def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: + sample = next(iter(results.values())) + has_fw_logits = _named_row(sample, "head.logits") is not None + has_bw_logits = _named_row(sample, "head.logits.grad") is not None + has_bias = any( + r["kind"] == "grad" and r["tensor_name"].split(":", 1)[-1].strip().endswith(".bias") for r in sample + ) + # Each kind's aggregation columns are listed chronologically (left-to-right matches + # the order tensors are logged). Logits show up via `output_hidden_states` on the + # fw/bw boundary; weight gradients have no logits hook. + fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) + bw_aggs = ("first",) + (("logits",) if has_bw_logits else ()) + ("median", "max", "last") + grad_aggs = ( + ("lm_head", "linear_med", "linear_max", "norm_med", "norm_max") + + (("bias_med", "bias_max") if has_bias else ()) + + ("embeddings",) + ) + aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs, "grad": grad_aggs} + for kind in ("fw", "bw", "grad"): + _print_summary_table(results, kind, aggs_per_kind[kind]) + if _named_row(sample, _CHOSEN_LOGPROB_NAME) is not None: + _print_chosen_logprob_summary(results) + + +def _print_chosen_logprob_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: + rows_by_variant = {name: _named_row(rows, _CHOSEN_LOGPROB_NAME) for name, rows in results.items()} + # log π(label) is the scalar that policy-gradient importance ratios depend on. Bias persists + # under per-document averaging where RMS shrinks ~1/√T, so for RL stability it's the more + # informative signal — surface it alongside RMS, slope and residual. 
+ rms_rel_decimals = _column_decimals((r["rms_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5) + bias_rel_decimals = _column_decimals((r["bias_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5) + resid_rel_decimals = _column_decimals( + (r["residual_rms_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5 + ) + name_width = max((len(name) for name in results), default=7) + 1 + cols = [ + ("RMS rel", lambda r: f"{r['rms_rel'] * 100:.{rms_rel_decimals}f}%"), + ("Bias rel", lambda r: f"{r['bias_rel'] * 100:+.{bias_rel_decimals}f}%"), + ("Resid rel", lambda r: f"{r['residual_rms_rel'] * 100:.{resid_rel_decimals}f}%"), + ("Corr", lambda r: f"{r['correlation']:.5f}"), + ("Slope", lambda r: f"{r['slope']:+.5f}"), + ("Max abs", lambda r: f"{r['max_abs']:.4g}"), + ("Scale", lambda r: f"{r['ref_scale']:.4g}"), + ] + widths = [max(len(label), max(len(fn(r)) for r in rows_by_variant.values())) for label, fn in cols] + print(f"\n=== Summary: chosen_logprob (per-token) ===") + header = f"{'Variant':<{name_width}}" + " ".join( + f"{label:<{w}}" for (label, _), w in zip(cols, widths, strict=True) + ) + print(header) + print("-" * len(header)) + for name, row in rows_by_variant.items(): + cells = [fn(row) for _, fn in cols] + print(f"{name:<{name_width}}" + " ".join(f"{c:<{w}}" for c, w in zip(cells, widths, strict=True))) + + +def _grad_category(tensor_name: str) -> str: + name = tensor_name.split(":", 1)[-1].strip() + if name.endswith(".bias"): + return "bias" + if ".norm_" in name or name.endswith(".norm.weight"): + return "norm" + return "linear" + + +def _print_summary_table(results: dict[str, list[dict[str, typing.Any]]], kind: str, aggs: tuple[str, ...]) -> None: + sample = next(iter(results.values())) + group = [r for r in sample if r["kind"] == kind] + if not group: + return + endpoint_labels = { + "first": _layer_name(group[0]["tensor_name"]), + "last": _layer_name(group[-1]["tensor_name"]), + } + mid_labels = { + 
"median": "mid med", + "max": "mid max", + "logits": "logits", + "lm_head": "lm head", + "embeddings": "embeddings", + "linear_med": "linear med", + "linear_max": "linear max", + "norm_med": "norm med", + "norm_max": "norm max", + "bias_med": "bias med", + "bias_max": "bias max", + } + + def _label(agg: str) -> str: + return endpoint_labels[agg] if agg in endpoint_labels else mid_labels[agg] + + name_width = max((len(name) for name in results), default=7) + 1 + cell_width = max(len(_label(a)) for a in aggs) + cell_sep = " " + raw: dict[str, dict[str, float | None]] = {} + for name, rows in results.items(): + logits_fw = _named_row(rows, "head.logits") + logits_bw = _named_row(rows, "head.logits.grad") + logits_value = { + "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), + "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), + } + kind_rows = [r for r in rows if r["kind"] == kind] + values = [r["rms_rel"] for r in kind_rows] + if kind == "grad": + decoder_rows = [r for r in kind_rows if r["tensor_name"].split(":", 1)[-1].strip().startswith("decoder.")] + category_values: dict[str, list[float]] = {"linear": [], "norm": [], "bias": []} + for r in decoder_rows: + category_values[_grad_category(r["tensor_name"])].append(r["rms_rel"]) + lm_head_row = _named_row(kind_rows, _LM_HEAD_NAME) + embeddings_row = _named_row(kind_rows, _EMBEDDINGS_NAME) + else: + category_values = {} + lm_head_row = embeddings_row = None + intermediate = values[1:-1] or values + cells: dict[str, float | None] = {} + for agg in aggs: + if agg == "first": + cells[agg] = values[0] if values else None + elif agg == "last": + cells[agg] = values[-1] if values else None + elif agg == "logits": + cells[agg] = logits_value[kind] + elif agg == "lm_head": + cells[agg] = lm_head_row["rms_rel"] if lm_head_row else None + elif agg == "embeddings": + cells[agg] = embeddings_row["rms_rel"] if embeddings_row else None + elif "_" in agg and agg.split("_", 1)[0] in category_values: + cat, stat = 
agg.split("_", 1) + cat_values = category_values[cat] + if not cat_values: + cells[agg] = None + elif stat == "max": + cells[agg] = max(cat_values) + else: + cells[agg] = statistics.median(cat_values) + elif agg == "max": + cells[agg] = max(intermediate) if intermediate else None + else: + cells[agg] = statistics.median(intermediate) if intermediate else None + raw[name] = cells + + column_decimals = { + agg: _column_decimals(cells[agg] for cells in raw.values() if cells[agg] is not None) for agg in aggs + } + if kind == "grad": + subtitle = " (Relative %)" + else: + subtitle = " (Relative %; mid = excluding first/last)" + print(f"\n=== Summary: {kind}{subtitle} ===") + header = f"{'Variant':<{name_width}}" + cell_sep.join(f"{_label(a):<{cell_width}}" for a in aggs) + print(header) + print("-" * len(header)) + for name, cells in raw.items(): + formatted = [ + f"{cells[agg] * 100:.{column_decimals[agg]}f}%" if cells[agg] is not None else "n/a" for agg in aggs + ] + print(f"{name:<{name_width}}" + cell_sep.join(f"{c:<{cell_width}}" for c in formatted)) + + +def _column_decimals( + values: typing.Iterable[float], min_sig_figs: int = 2, default: int = 3, max_decimals: int | None = None +) -> int: + # Keep the default precision, but bump up so the smallest non-zero value carries at least + # `min_sig_figs` significant digits when formatted as percent. `max_decimals` caps the + # bump so a single tiny noisy value doesn't widen the whole column. 
+ smallest = min((abs(v) * 100 for v in values if v != 0), default=None) + if smallest is None or smallest >= 10 ** -(default - min_sig_figs + 1): + result = default + else: + result = max(default, -math.floor(math.log10(smallest)) + min_sig_figs - 1) + return min(result, max_decimals) if max_decimals is not None else result + + +def _display_group(row: dict[str, typing.Any]) -> str: + # Map each row to one of "fw"/"bw"/"grad" for the per-variant table, independent + # of `kind`: head.logits is a forward activation, head.logits.grad is a backward + # quantity, parameter gradients are their own group. + if row["kind"] == "grad": + return "grad" + if row["kind"] == "bw" or row["tensor_name"].endswith(".grad"): + return "bw" + return "fw" + + +def _classify(tensor_name: str) -> str: + # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" + # and " bw[, mb=…]"; log_distributed_tensor may prefix the name + # with "Global " and append a ": " suffix when reconstructing a + # tensor-parallel-global tensor. Per-parameter gradient logs come from + # `Fsdp.log_shard(name="gradient", ...)` and are tagged "grad" so they appear + # in the per-variant table but stay out of the fw/bw summary aggregation. + # Other entries (e.g. `Global : head.logits`, `Global : head.logits.grad`) come + # from the `_debug` / `output_hidden_states` path and are surfaced via dedicated + # logits columns in the summary. 
+ if "gradient:" in tensor_name: + return "grad" + for kind in ("fw", "bw"): + if f" {kind}:" in tensor_name or f" {kind}," in tensor_name or tensor_name.endswith(f" {kind}"): + return kind + return "other" + + +def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: + print(f"\n=== Variant: {name} ===") + if not rows: + print("(no matching tensors)") + return + name_fn = lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})" + name_width = max(len("Tensor"), max(len(name_fn(r)) for r in rows)) + # Adaptive precision for the relative column: bump decimals so small but real values + # (typical for weight gradients) stay legible, capped at 5 to bound column width. + relative_decimals = _column_decimals((r["rms_rel"] for r in rows), default=2, max_decimals=5) + relative_fn = lambda r: f"{r['rms_rel'] * 100:.{relative_decimals}f}%" + bias_decimals = _column_decimals((r["bias_rel"] for r in rows), default=2, max_decimals=5) + bias_fn = lambda r: f"{r['bias_rel'] * 100:+.{bias_decimals}f}%" + relative_width = max(len("Relative"), max(len(relative_fn(r)) for r in rows)) + bias_width = max(len("Bias"), max(len(bias_fn(r)) for r in rows)) + columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ + ("Tensor", name_width, name_fn), + ("Relative", relative_width, relative_fn), + ("Bias", bias_width, bias_fn), + ("Absolute", 10, lambda r: f"{r['rms_abs']:.4g}"), + ("Max", 10, lambda r: f"{r['max_abs']:.4g}"), + ("Scale", 10, lambda r: f"{r['ref_scale']:.4g}"), + ] + header = " ".join(f"{title:<{width}}" for title, width, _ in columns) + print(header) + print("-" * len(header)) + # Display grouping (fw / bw / grad) separates the chronologically-interleaved + # backward and reduce_gradients hooks. Independent of `kind` so the summary + # aggregation isn't affected. 
+ groups = ("fw", "bw", "grad") + grouped: dict[str, list[dict[str, typing.Any]]] = {g: [] for g in groups} + for row in rows: + grouped[_display_group(row)].append(row) + first = True + for group in groups: + if not grouped[group]: + continue + if not first: + print() + first = False + for row in grouped[group]: + print(" ".join(f"{format_fn(row):<{width}}" for _, width, format_fn in columns)) + + +if __name__ == "__main__": + EvaluatePrecisionConfig.parse_and_run()