diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 65d72018..abb70ec8 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -786,7 +786,7 @@ def __init__( "Communicator needs to be instantiated with communication subsystem" f" derived from `comm_abc.Comm`, got {type(comm)}." ) - if comm.Get_size() != partitioner.total_ranks: + if comm.Get_size() < partitioner.total_ranks: raise ValueError( f"was given a partitioner for {partitioner.total_ranks} ranks but a " f"comm object with only {comm.Get_size()} ranks, are we running "