diff --git a/litgpt/utils.py b/litgpt/utils.py index ec79fa0764..fa0113e16c 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -568,7 +568,7 @@ def choose_logger( log_args: dict | None = None, resume: bool | None = None, **kwargs: Any, -): +) -> Any: if logger_name == "csv": return CSVLogger(root_dir=(out_dir / "logs"), name="csv", flush_logs_every_n_steps=log_interval, **kwargs) if logger_name == "tensorboard": @@ -607,7 +607,7 @@ def choose_logger( ) -def get_argument_names(cls): +def get_argument_names(cls: type) -> set[str]: sig = inspect.signature(cls.__init__) return { name @@ -616,7 +616,7 @@ def get_argument_names(cls): } -def instantiate_bnb_optimizer(optimizer, model_parameters): +def instantiate_bnb_optimizer(optimizer: str | dict, model_parameters: Any) -> Any: if (isinstance(optimizer, str) and "AdamW" not in optimizer) or ( isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "") ): @@ -633,7 +633,7 @@ def instantiate_bnb_optimizer(optimizer, model_parameters): return optimizer -def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs): +def instantiate_torch_optimizer(optimizer: str | dict, model_parameters: Any, **kwargs: Any) -> torch.optim.Optimizer: # Special care taken where some optimizers do not have some parameters referenced in some of the code, for example "fused" in the pretrain.py script: # bnb.optim.AdamW8bit # grokadamw.GrokAdamW @@ -679,7 +679,9 @@ def extend_checkpoint_dir(checkpoint_dir: Path) -> Path: return new_checkpoint_dir if should_return_new_dir else checkpoint_dir -def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_715_660): +def check_file_size_on_cpu_and_warn( + checkpoint_path: Path | str, device: torch.device | str, size_limit: int = 4_509_715_660 +) -> float: """ Checks the file size and raises a warning if it exceeds the size_limit. The default size limit is 4.2 GB, the size of TinyLlama 1.1B: 4.2 * 1024 * 1024 * 1024 = 4_509_715_660 @@ -695,7 +697,9 @@ def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_71 return size -def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_files=False): +def auto_download_checkpoint( + model_name: str, access_token: str | None = None, ignore_tokenizer_files: bool = False +) -> Path: from litgpt.scripts.download import download_from_hub # moved here due to circular import issue checkpoint_dir = extend_checkpoint_dir(Path(model_name)) @@ -716,7 +720,7 @@ def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_fil return checkpoint_dir -def check_nvlink_connectivity(fabric=None): +def check_nvlink_connectivity(fabric: L.Fabric | None = None) -> None: """Checks GPU connectivity for both NVIDIA and AMD GPUs. This function delegates to vendor-specific implementations based on @@ -744,7 +748,7 @@ def check_nvlink_connectivity(fabric=None): custom_print(f"An error occurred while checking GPU connectivity: {e}") -def _check_nvidia_connectivity(custom_print): +def _check_nvidia_connectivity(custom_print: Any) -> None: """Checks NVLink connectivity on NVIDIA GPUs.""" result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True) if result.returncode != 0: @@ -779,7 +783,7 @@ def _check_nvidia_connectivity(custom_print): ) -def _check_amd_connectivity(custom_print): +def _check_amd_connectivity(custom_print: Any) -> None: """Checks XGMI connectivity on AMD GPUs.""" result = subprocess.run(["rocm-smi", "--showtopotype"], stdout=subprocess.PIPE, text=True) if result.returncode != 0: @@ -825,7 +829,7 @@ def _check_amd_connectivity(custom_print): ) -def fix_and_load_json(s): +def fix_and_load_json(s: str) -> Any: # Remove trailing commas before } or ] s = re.sub(r",(\s*[}\]])", r"\1", s) @@ -842,7 +846,7 @@ def fix_and_load_json(s): raise ValueError(f"Failed to parse JSON after fixing: {e}") -def create_finetuning_performance_report(training_time, token_counts, device_type): +def create_finetuning_performance_report(training_time: float, token_counts: dict[str, int], device_type: str) -> str: tok_sec = token_counts["raw_tokens_plus_prompt_template_and_padding"] / training_time output = f""" | ------------------------------------------------------ @@ -866,7 +870,7 @@ def create_finetuning_performance_report(training_time, token_counts, device_typ return output -def select_sft_generate_example(eval, data): +def select_sft_generate_example(eval: Any, data: Any) -> str: if eval.evaluate_example == "first": if len(data.test_dataset.data): instruction = data.test_dataset.data[0]["instruction"] @@ -895,7 +899,7 @@ def select_sft_generate_example(eval, data): return instruction -def _RunIf(thunder: bool = False, **kwargs): +def _RunIf(thunder: bool = False, **kwargs: Any) -> Any: import pytest from lightning.fabric.utilities.testing import _runif_reasons @@ -908,7 +912,7 @@ def _RunIf(thunder: bool = False, **kwargs): return pytest.mark.skipif(condition=len(reasons) > 0, reason=f"Requires: [{' + '.join(reasons)}]", **marker_kwargs) -def kill_process_tree(pid: int): +def kill_process_tree(pid: int) -> None: """ Kill a process and all its child processes given the parent PID. """