Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
02fd39c
Add tool to evaluate layer-wise numerical-error propagation
jlamypoirier May 26, 2026
4dd6c14
Collapse to a single config; require a checkpoint
jlamypoirier May 27, 2026
5ebea33
Expose `model:` alongside `pretrained:` in the tool config
jlamypoirier May 27, 2026
4c444d8
Inherit PretrainedGPTModelConfig; use Config update mechanism
jlamypoirier May 27, 2026
35206a6
Expand HF metadata allowlist for newer transformers configs
jlamypoirier May 27, 2026
bde1efa
Reshape console table for readability
jlamypoirier May 27, 2026
8099b51
Merge tensor+kind, fix decimal precision in console table
jlamypoirier May 27, 2026
dbd7702
Switch back to fixed-decimal formatting in the table
jlamypoirier May 27, 2026
152ffc3
Wipe per-variant experiment dir before each run
jlamypoirier May 27, 2026
7e98500
Support pre-generated memmap dataset; misc table-format polish
jlamypoirier May 28, 2026
173ae0d
Print per-variant summary at the end of the run
jlamypoirier May 28, 2026
005fd62
Reshape end-of-run summary: variants × aggregations, relative only
jlamypoirier May 28, 2026
c594658
Clarify intermediate aggregation in summary header
jlamypoirier May 28, 2026
3159f73
Split summary across fw/bw rows; one extra precision digit
jlamypoirier May 28, 2026
6ef153e
Two-row column header in summary; chronological column order
jlamypoirier May 28, 2026
7327932
Add fp32_lm_head flag for vLLM precision parity
jlamypoirier May 28, 2026
76335df
Extract layer-name labels for summary first/last columns
jlamypoirier May 28, 2026
8122946
Add `debug_hidden_states_log` to capture named tensors via output_hid…
jlamypoirier May 28, 2026
4633bfd
Capture logit gradients; expose them in the summary
jlamypoirier May 28, 2026
9ca1711
Place logits after head in bw summary; widen format for sub-percent v…
jlamypoirier May 28, 2026
f2655f3
Pick per-column decimals to guarantee ≥2 sig figs
jlamypoirier May 28, 2026
7f8ef96
Tighten summary table spacing
jlamypoirier May 28, 2026
08b1637
Support HF Hub model ids in pretrained.path
jlamypoirier May 28, 2026
77eae22
Add example precision-evaluation configs
jlamypoirier May 28, 2026
efa95b1
Drop bf16_no_fp32_gradients variant from example configs
jlamypoirier May 28, 2026
46bc5b8
Add weight gradients to per-variant report tables
jlamypoirier May 28, 2026
bef2f0d
Separate fw/bw/grad rows in per-variant tables
jlamypoirier May 28, 2026
4fecad4
Split summary into three tables (fw, bw, grad)
jlamypoirier May 28, 2026
4f47dc0
Split grad summary by parameter category
jlamypoirier May 28, 2026
5198c25
Per-tensor sample-density overrides in TensorLogsConfig
jlamypoirier May 28, 2026
312343e
Chosen-logprob loss, per-variant grad-scale auto-calibration, fp16 va…
jlamypoirier May 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions examples/evaluate_precision/smol.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M.
#
# Run with:
# python -m tools.evaluate_precision -c examples/evaluate_precision/smol.yaml
#
# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id
# (auto-downloaded via `huggingface_hub.snapshot_download` on first use).
pretrained:
path: HuggingFaceTB/SmolLM2-135M
format: llama
output_dir: /tmp/fast_llm_tests/evaluate_precision/features
sequence_length: 2048
variants:
# Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 gradient accumulation, bf16 residual, bf16 lm_head).
bf16:
model.distributed.compute_dtype: bfloat16
# Turn ON full-precision residual stream.
bf16_fp32_residual:
model.distributed.compute_dtype: bfloat16
model.base_model.embeddings.full_precision_residual: true
# Turn ON fp32 LM head matmul (PR #526).
bf16_fp32_lm_head:
model.distributed.compute_dtype: bfloat16
model.base_model.head.fp32_lm_head: true
# Both stability features on (most precise bf16-compute configuration).
bf16_max_precision:
model.distributed.compute_dtype: bfloat16
model.base_model.embeddings.full_precision_residual: true
model.base_model.head.fp32_lm_head: true
# Diagnostic: enable bf16 reduced-precision reductions in cuBLAS GEMMs. Tests whether the
# within-engine bf16-vs-fp32 gap is sensitive to the partial-sum reduction precision (the
# MMA accumulator is fp32 by hardware on H100/A100; this flag affects split-K reductions).
bf16_reduced_reduction:
model.distributed.compute_dtype: bfloat16
_torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true
# Diagnostic: simulate a "bf16 inputs, fp32 output" lm-head matmul kernel. fp32_lm_head=True
# upcasts inputs+weights to fp32, then matmul_precision='medium' runs the matmul through
# bf16 Tensor Cores anyway, then logits stay fp32. Tests whether fp32_lm_head's gain comes
# from input precision or from skipping the bf16 output cast.
bf16_in_fp32_out:
model.distributed.compute_dtype: bfloat16
model.base_model.head.fp32_lm_head: true
_torch_matmul_precision: medium
# fp16 sweep: probes whether the precision-vs-noise picture (rms noise ~0.1 nats per token
# for bf16) shrinks ~8× for fp16 (10 mantissa bits vs 7), as the literature's "switch to
# fp16" recommendation implies. Default dynamic grad-scaler (initial 2^16) is uniform
# across variants, so relative comparisons stay meaningful.
fp16:
model.distributed.compute_dtype: float16
fp16_fp32_residual:
model.distributed.compute_dtype: float16
model.base_model.embeddings.full_precision_residual: true
fp16_fp32_lm_head:
model.distributed.compute_dtype: float16
model.base_model.head.fp32_lm_head: true
fp16_max_precision:
model.distributed.compute_dtype: float16
model.base_model.embeddings.full_precision_residual: true
model.base_model.head.fp32_lm_head: true
52 changes: 52 additions & 0 deletions examples/evaluate_precision/smol_gspo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M
# with the GSPO policy-gradient loss (uses advantages and old log-probabilities).
#
# Run with:
# python -m tools.evaluate_precision -c examples/evaluate_precision/smol_gspo.yaml
#
# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id
# (auto-downloaded via `huggingface_hub.snapshot_download` on first use).
pretrained:
path: HuggingFaceTB/SmolLM2-135M
format: llama
model:
base_model:
head:
losses:
gspo:
type: gspo
output_dir: /tmp/fast_llm_tests/evaluate_precision/gspo
data_path: /tmp/fast_llm_tests/evaluate_precision/gspo_data
sequence_length: 2048
variants:
bf16:
model.distributed.compute_dtype: bfloat16
bf16_fp32_residual:
model.distributed.compute_dtype: bfloat16
model.base_model.embeddings.full_precision_residual: true
bf16_fp32_lm_head:
model.distributed.compute_dtype: bfloat16
model.base_model.head.fp32_lm_head: true
bf16_max_precision:
model.distributed.compute_dtype: bfloat16
model.base_model.embeddings.full_precision_residual: true
model.base_model.head.fp32_lm_head: true
bf16_reduced_reduction:
model.distributed.compute_dtype: bfloat16
_torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true
bf16_in_fp32_out:
model.distributed.compute_dtype: bfloat16
model.base_model.head.fp32_lm_head: true
_torch_matmul_precision: medium
fp16:
model.distributed.compute_dtype: float16
fp16_fp32_residual:
model.distributed.compute_dtype: float16
model.base_model.embeddings.full_precision_residual: true
fp16_fp32_lm_head:
model.distributed.compute_dtype: float16
model.base_model.head.fp32_lm_head: true
fp16_max_precision:
model.distributed.compute_dtype: float16
model.base_model.embeddings.full_precision_residual: true
model.base_model.head.fp32_lm_head: true
6 changes: 6 additions & 0 deletions fast_llm/data/document/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig):
use_preference_spans: bool = Field(default=False)
use_grpo_data: bool = Field(default=False)
return_label_counts: bool = Field(default=False)
output_hidden_states: list[str] = Field(
default_factory=list,
desc="Regex patterns to add to each model input's `output_hidden_states` set."
" Matching `_debug`-named tensors get populated into `kwargs[hidden_states]`"
" and (when running under a `Run` context) emitted into `tensor_logs`.",
)

def _validate(self) -> None:
super()._validate()
Expand Down
7 changes: 7 additions & 0 deletions fast_llm/data/document/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis

self._set_target_inputs(model_inputs, config)

if config.output_hidden_states:
import re

patterns = {re.compile(pattern) for pattern in config.output_hidden_states}
for model_input in model_inputs:
model_input.output_hidden_states.update(patterns)

return model_inputs

def _set_target_inputs(
Expand Down
80 changes: 67 additions & 13 deletions fast_llm/engine/checkpoint/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str:
Assert.eq(shard_name, "weights")
return parameter_name

@classmethod
def _resolve_path(cls, path: pathlib.Path) -> pathlib.Path:
"""Resolve a local directory or HF Hub model id (e.g. ``meta-llama/Llama-3.2-1B``) to a
local snapshot directory. Local directories pass through unchanged; everything else is
materialized via :func:`huggingface_hub.snapshot_download` (cached on subsequent calls).
"""
if path.is_dir():
return path
import huggingface_hub

return pathlib.Path(huggingface_hub.snapshot_download(str(path)))

# Use custom config instead of relying on the transformers library
@classmethod
def _load_config(cls, directory: pathlib.Path | str) -> dict:
Expand Down Expand Up @@ -128,31 +140,72 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
{
# transformers PretrainedConfig
"_name_or_path",
"add_cross_attention",
"architectures",
"auto_map",
"chunk_size_feed_forward",
"cross_attention_hidden_size",
"dtype",
"finetuning_task",
"id2label",
"is_decoder",
"is_encoder_decoder",
"label2id",
"model_type",
"output_attentions",
"output_hidden_states",
"prefix",
"problem_type",
"pruned_heads",
"return_dict",
"task_specific_params",
"tf_legacy_loss",
"tie_encoder_decoder",
"tokenizer_class",
"torch_dtype",
"torchscript",
"transformers_version",
"use_bfloat16",
"use_cache",
# Token ids — generation/inference, not architecture.
"bos_token_id",
"decoder_start_token_id",
"eos_token_id",
"pad_token_id",
"sep_token_id",
# Generation defaults — never architecture.
"bad_words_ids",
"begin_suppress_tokens",
"diversity_penalty",
"do_sample",
"early_stopping",
"encoder_no_repeat_ngram_size",
"exponential_decay_length_penalty",
"forced_bos_token_id",
"forced_eos_token_id",
"length_penalty",
"max_length",
"min_length",
"no_repeat_ngram_size",
"num_beam_groups",
"num_beams",
"num_return_sequences",
"output_scores",
"remove_invalid_values",
"repetition_penalty",
"return_dict_in_generate",
"suppress_tokens",
"temperature",
"top_k",
"top_p",
"typical_p",
# Initialization / pretraining metadata Fast-LLM does not consume.
"initializer_range",
"max_position_embeddings",
"pretraining_tp",
# Family markers / default-valued knobs serialized by recent transformers versions.
"is_llama_config",
"rope_interleaved",
}
)

Expand Down Expand Up @@ -181,28 +234,29 @@ def _load_weights(
import transformers

Assert.eq(self.get_shard_names(config), ("weights",))
if (config.path / transformers.utils.SAFE_WEIGHTS_NAME).is_file():
paths = {config.path / transformers.utils.SAFE_WEIGHTS_NAME}
elif (config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file():
logger.info(f"Loading index from {config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}")
directory = self._resolve_path(config.path)
if (directory / transformers.utils.SAFE_WEIGHTS_NAME).is_file():
paths = {directory / transformers.utils.SAFE_WEIGHTS_NAME}
elif (directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file():
logger.info(f"Loading index from {directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}")
paths = {
config.path / path
for path in json.loads((config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[
directory / path
for path in json.loads((directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[
"weight_map"
].values()
}
elif (config.path / transformers.utils.WEIGHTS_NAME).is_file():
paths = {config.path / transformers.utils.WEIGHTS_NAME}
elif (config.path / transformers.utils.WEIGHTS_INDEX_NAME).is_file():
logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}")
elif (directory / transformers.utils.WEIGHTS_NAME).is_file():
paths = {directory / transformers.utils.WEIGHTS_NAME}
elif (directory / transformers.utils.WEIGHTS_INDEX_NAME).is_file():
logger.info(f"Loading index from {directory / transformers.utils.WEIGHTS_INDEX_NAME}")
paths = {
config.path / path
for path in json.loads((config.path / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[
directory / path
for path in json.loads((directory / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[
"weight_map"
].values()
}
else:
raise FileNotFoundError(f"No compatible checkpoint found in {config.path}")
raise FileNotFoundError(f"No compatible checkpoint found in {directory}")

for path in paths:
logger.info(f"Loading from {path}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,52 @@ def _compare_dict_keys(self, dict_ref, dict_test, errors, name):
# Avoid set to preserve ordering.
return [key for key in dict_test if key in dict_ref]

def _compute_diff(self, tensor_ref, tensor_test, step_name, tensor_name) -> dict | None:
# Returns per-tensor error metrics, or None on shape/sampling mismatch.
if tensor_ref["shape"] != tensor_test["shape"]:
return None
if tensor_ref["step"] != tensor_test["step"]:
return None
sub_config = self._get_sub_config(step_name, tensor_name)
samples_ref = tensor_ref["samples"].flatten().float()
samples_test = tensor_test["samples"].flatten().float()
if sub_config.scale != 1.0:
samples_test = samples_test / sub_config.scale
scale_unreg = (samples_ref**2).mean() ** 0.5
rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5
diff = samples_test - samples_ref
rms = (diff**2).mean() ** 0.5
max_diff = diff.abs().max()
bias = diff.mean()
# Linear-regression decomposition: `test ≈ slope * ref + intercept + residual`.
# Useful for separating systematic distortion (slope ≠ 1) from per-position decorrelated
# noise (residual). For RL importance ratios, slope ≠ 1 indicates likely-token-dependent
# bias which is more dangerous than a uniform shift.
centered_test = samples_test - samples_test.mean()
centered_ref = samples_ref - samples_ref.mean()
var_ref = (centered_ref**2).mean()
var_test = (centered_test**2).mean()
cov = (centered_test * centered_ref).mean()
denom = (var_test * var_ref) ** 0.5
correlation = (cov / denom).item() if denom > 0 else float("nan")
slope = (cov / var_ref).item() if var_ref > 0 else float("nan")
residual_var = (var_test - cov**2 / var_ref).clamp(min=0.0) if var_ref > 0 else var_test
residual_rms = residual_var**0.5
return {
"rms_abs": rms.item(),
"rms_rel": (rms / rms_scale).item(),
"max_abs": max_diff.item(),
"max_rel": (max_diff / rms_scale).item(),
"ref_scale": scale_unreg.item(),
"ref_scale_regularized": rms_scale.item(),
"bias_abs": bias.item(),
"bias_rel": (bias / rms_scale).item(),
"correlation": correlation,
"slope": slope,
"residual_rms_abs": residual_rms.item(),
"residual_rms_rel": (residual_rms / rms_scale).item(),
}

def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_name):
sub_config = self._get_sub_config(step_name, tensor_name)
if tensor_ref["shape"] != tensor_test["shape"]:
Expand All @@ -108,34 +154,33 @@ def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_nam
)
return

samples_ref = tensor_ref["samples"].flatten().float()
samples_test = tensor_test["samples"].flatten().float()
if sub_config.scale != 1.0:
samples_test = samples_test / sub_config.scale
scale_unreg = (samples_ref**2).mean() ** 0.5
rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5
rms = ((samples_ref - samples_test) ** 2).mean() ** 0.5
max_diff = (samples_ref - samples_test).abs().max()
metrics = self._compute_diff(tensor_ref, tensor_test, step_name, tensor_name)
rms_scale = metrics["ref_scale_regularized"]
scale_unreg = metrics["ref_scale"]

tensor_errors = []

if rms > sub_config.rms_abs_tolerance:
tensor_errors.append(f" * RMS diff absolute = {rms} > {sub_config.rms_abs_tolerance}")
if metrics["rms_abs"] > sub_config.rms_abs_tolerance:
tensor_errors.append(f" * RMS diff absolute = {metrics['rms_abs']} > {sub_config.rms_abs_tolerance}")

if rms / rms_scale > sub_config.rms_rel_tolerance:
if metrics["rms_rel"] > sub_config.rms_rel_tolerance:
tensor_errors.append(
f" * RMS diff scaled = {rms / rms_scale} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})"
f" * RMS diff scaled = {metrics['rms_rel']} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})"
)

if max_diff > sub_config.max_abs_tolerance:
tensor_errors.append(f" * Max diff absolute = {max_diff} > {sub_config.max_abs_tolerance}")
if metrics["max_abs"] > sub_config.max_abs_tolerance:
tensor_errors.append(f" * Max diff absolute = {metrics['max_abs']} > {sub_config.max_abs_tolerance}")

if max_diff / rms_scale > sub_config.max_rel_tolerance:
if metrics["max_rel"] > sub_config.max_rel_tolerance:
tensor_errors.append(
f" * Max diff scaled = {max_diff / rms_scale} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})"
f" * Max diff scaled = {metrics['max_rel']} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})"
)

if tensor_errors:
samples_ref = tensor_ref["samples"].flatten().float()
samples_test = tensor_test["samples"].flatten().float()
if sub_config.scale != 1.0:
samples_test = samples_test / sub_config.scale
tensor_errors.extend(
[
f" Test samples: " + "".join(f"{x:12.4e}" for x in samples_test[: self.show_samples].tolist()),
Expand Down
9 changes: 9 additions & 0 deletions fast_llm/engine/config_utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ class TensorLogsConfig(Config):
valid=check_field(Assert.gt, 0),
)
full_tensors: bool = Field(default=False, desc="Save and/or print entire tensors.")
sample_level_overrides: dict[str, int] = Field(
default_factory=dict,
desc="Per-tensor sample-density overrides (regex pattern -> level)."
" For tensors whose logged name matches a pattern, the effective `log_tensor` level is"
" raised to the matching override (samples = 2 ** (level - 3))."
" Useful for sparse tensors like embedding-weight gradients where the default sampling"
" stride misses most non-zero rows.",
hint=FieldHint.logging,
)


class TensorLogs:
Expand Down
Loading
Loading