diff --git a/brainscore_language/model_helpers/localize.py b/brainscore_language/model_helpers/localize.py index 95b09897..a7cd4b7b 100644 --- a/brainscore_language/model_helpers/localize.py +++ b/brainscore_language/model_helpers/localize.py @@ -106,7 +106,8 @@ def extract_representations( logger.debug(f"> Using Device: {device}") model.eval() - model.to(device) + if not getattr(model, "hf_device_map", None): # skip when already sharded via device_map + model.to(device) final_layer_representations = { "sentences": {layer_name: np.zeros((len(langloc_dataset.sentences), hidden_dim)) for layer_name in layer_names},