diff --git a/petastorm/reader.py b/petastorm/reader.py index 008bfac9d..6e5a7126d 100644 --- a/petastorm/reader.py +++ b/petastorm/reader.py @@ -532,9 +532,7 @@ def _apply_predicate_to_row_groups(self, dataset, row_groups, predicate): raise ValueError('predicate parameter is expected to be derived from PredicateBase') predicate_fields = predicate.get_fields() - if set(predicate_fields) == dataset.partitions.partition_names: - assert len(dataset.partitions.partition_names) == 1, \ - 'Datasets with only a single partition level supported at the moment' + if set(predicate_fields).issubset(dataset.partitions.partition_names): filtered_row_group_indexes = [] for piece_index, piece in enumerate(row_groups): diff --git a/petastorm/tests/test_predicates.py b/petastorm/tests/test_predicates.py index 07333bc0d..13f8eb06f 100644 --- a/petastorm/tests/test_predicates.py +++ b/petastorm/tests/test_predicates.py @@ -154,12 +154,14 @@ def test_predicate_on_partitioned_dataset(tmpdir): """ TestSchema = Unischema('TestSchema', [ UnischemaField('id', np.int32, (), ScalarCodec(IntegerType()), False), + UnischemaField('id2', np.int32, (), ScalarCodec(IntegerType()), False), UnischemaField('test_field', np.int32, (), ScalarCodec(IntegerType()), False), ]) def test_row_generator(x): """Returns a single entry in the generated dataset.""" return {'id': x, + 'id2': x+1, 'test_field': x*x} rowgroup_size_mb = 256 @@ -177,11 +179,13 @@ def test_row_generator(x): spark.createDataFrame(rows_rdd, TestSchema.as_spark_schema()) \ .write \ - .partitionBy('id') \ + .partitionBy('id', 'id2') \ .parquet(dataset_url) with make_reader(dataset_url, predicate=in_lambda(['id'], lambda x: x == 3)) as reader: assert next(reader).id == 3 + with make_reader(dataset_url, predicate=in_lambda(['id2'], lambda x: x == 5)) as reader: + assert next(reader).id2 == 5 with make_reader(dataset_url, predicate=in_lambda(['id'], lambda x: x == '3')) as reader: with pytest.raises(StopIteration): # Predicate should have 
selected none, so a StopIteration should be raised.