Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def choose_logger(
log_args: dict | None = None,
resume: bool | None = None,
**kwargs: Any,
):
) -> Any:
if logger_name == "csv":
return CSVLogger(root_dir=(out_dir / "logs"), name="csv", flush_logs_every_n_steps=log_interval, **kwargs)
if logger_name == "tensorboard":
Expand Down Expand Up @@ -607,7 +607,7 @@ def choose_logger(
)


def get_argument_names(cls):
def get_argument_names(cls: type) -> set[str]:
sig = inspect.signature(cls.__init__)
return {
name
Expand All @@ -616,7 +616,7 @@ def get_argument_names(cls):
}


def instantiate_bnb_optimizer(optimizer, model_parameters):
def instantiate_bnb_optimizer(optimizer: str | dict, model_parameters: Any) -> Any:
if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
):
Expand All @@ -633,7 +633,7 @@ def instantiate_bnb_optimizer(optimizer, model_parameters):
return optimizer


def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
def instantiate_torch_optimizer(optimizer: str | dict, model_parameters: Any, **kwargs: Any) -> torch.optim.Optimizer:
# Special care taken where some optimizers do not have some parameters referenced in some of the code, for example "fused" in the pretrain.py script:
# bnb.optim.AdamW8bit
# grokadamw.GrokAdamW
Expand Down Expand Up @@ -679,7 +679,9 @@ def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
return new_checkpoint_dir if should_return_new_dir else checkpoint_dir


def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_715_660):
def check_file_size_on_cpu_and_warn(
checkpoint_path: Path | str, device: torch.device | str, size_limit: int = 4_509_715_660
) -> float:
"""
Checks the file size and raises a warning if it exceeds the size_limit.
The default size limit is 4.2 GB, the size of TinyLlama 1.1B: 4.2 * 1024 * 1024 * 1024 = 4_509_715_660
Expand All @@ -695,7 +697,9 @@ def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_71
return size


def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_files=False):
def auto_download_checkpoint(
model_name: str, access_token: str | None = None, ignore_tokenizer_files: bool = False
) -> Path:
from litgpt.scripts.download import download_from_hub # moved here due to circular import issue

checkpoint_dir = extend_checkpoint_dir(Path(model_name))
Expand All @@ -716,7 +720,7 @@ def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_fil
return checkpoint_dir


def check_nvlink_connectivity(fabric=None):
def check_nvlink_connectivity(fabric: L.Fabric | None = None) -> None:
"""Checks GPU connectivity for both NVIDIA and AMD GPUs.

This function delegates to vendor-specific implementations based on
Expand Down Expand Up @@ -744,7 +748,7 @@ def check_nvlink_connectivity(fabric=None):
custom_print(f"An error occurred while checking GPU connectivity: {e}")


def _check_nvidia_connectivity(custom_print):
def _check_nvidia_connectivity(custom_print: Any) -> None:
"""Checks NVLink connectivity on NVIDIA GPUs."""
result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
Expand Down Expand Up @@ -779,7 +783,7 @@ def _check_nvidia_connectivity(custom_print):
)


def _check_amd_connectivity(custom_print):
def _check_amd_connectivity(custom_print: Any) -> None:
"""Checks XGMI connectivity on AMD GPUs."""
result = subprocess.run(["rocm-smi", "--showtopotype"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
Expand Down Expand Up @@ -825,7 +829,7 @@ def _check_amd_connectivity(custom_print):
)


def fix_and_load_json(s):
def fix_and_load_json(s: str) -> Any:
# Remove trailing commas before } or ]
s = re.sub(r",(\s*[}\]])", r"\1", s)

Expand All @@ -842,7 +846,7 @@ def fix_and_load_json(s):
raise ValueError(f"Failed to parse JSON after fixing: {e}")


def create_finetuning_performance_report(training_time, token_counts, device_type):
def create_finetuning_performance_report(training_time: float, token_counts: dict[str, int], device_type: str) -> str:
tok_sec = token_counts["raw_tokens_plus_prompt_template_and_padding"] / training_time
output = f"""
| ------------------------------------------------------
Expand All @@ -866,7 +870,7 @@ def create_finetuning_performance_report(training_time, token_counts, device_typ
return output


def select_sft_generate_example(eval, data):
def select_sft_generate_example(eval: Any, data: Any) -> str:
if eval.evaluate_example == "first":
if len(data.test_dataset.data):
instruction = data.test_dataset.data[0]["instruction"]
Expand Down Expand Up @@ -895,7 +899,7 @@ def select_sft_generate_example(eval, data):
return instruction


def _RunIf(thunder: bool = False, **kwargs):
def _RunIf(thunder: bool = False, **kwargs: Any) -> Any:
import pytest
from lightning.fabric.utilities.testing import _runif_reasons

Expand All @@ -908,7 +912,7 @@ def _RunIf(thunder: bool = False, **kwargs):
return pytest.mark.skipif(condition=len(reasons) > 0, reason=f"Requires: [{' + '.join(reasons)}]", **marker_kwargs)


def kill_process_tree(pid: int):
def kill_process_tree(pid: int) -> None:
"""
Kill a process and all its child processes given the parent PID.
"""
Expand Down
Loading