diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py index 5bc35590cb1..8ebe866600e 100644 --- a/src/accelerate/utils/launch.py +++ b/src/accelerate/utils/launch.py @@ -130,6 +130,8 @@ def _apply_kt_config_to_env(args: argparse.Namespace, current_env: dict[str, str "kt_model_max_length": "ACCELERATE_KT_MODEL_MAX_LENGTH", "kt_skip_expert_loading": "ACCELERATE_KT_SKIP_EXPERT_LOADING", "kt_share_backward_bb": "ACCELERATE_KT_SHARE_BACKWARD_BB", + "kt_train_mode": "ACCELERATE_KT_TRAIN_MODE", + "kt_full_weight_grad": "ACCELERATE_KT_FULL_WEIGHT_GRAD", } for key, env_key in mapping.items():