diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py index d3d2286f..1a7d5578 100644 --- a/bindings/python/pymongoarrow/api.py +++ b/bindings/python/pymongoarrow/api.py @@ -169,12 +169,12 @@ def args_iterable(): if parallelism == "threads": with ThreadPoolExecutor(max_workers=4) as executor: results = list(executor.map(lambda args: _process_batch(*args), args_iterable())) - return pa.concat_tables(results, promote_options="default") + return pa.concat_tables(results, promote_options="permissive") if parallelism == "processes": with multiprocessing.Pool(processes=4) as pool: results = pool.starmap(_process_batch, args_iterable()) - return pa.concat_tables(results, promote_options="default") + return pa.concat_tables(results, promote_options="permissive") context = PyMongoArrowContext( schema, codec_options=collection.codec_options, allow_invalid=allow_invalid diff --git a/bindings/python/pymongoarrow/lib.pyx b/bindings/python/pymongoarrow/lib.pyx index e2be9084..f8fdaee5 100644 --- a/bindings/python/pymongoarrow/lib.pyx +++ b/bindings/python/pymongoarrow/lib.pyx @@ -244,7 +244,16 @@ cdef class BuilderManager: builder = <_ArrayBuilderBase>self.builder_map.get(full_key, None) # If the inferred type was int32 but the same field has an int64 value, # re-infer the field's type since int32 is a strict subset of int64. - if not self.has_schema and (builder is None or builder.type_marker == BSON_TYPE_INT32 and value_t == BSON_TYPE_INT64): + # If builder already existed, avoid ditching previously appended values. + if not self.has_schema and builder is not None and builder.type_marker == BSON_TYPE_INT32 and value_t == BSON_TYPE_INT64: + old_array = builder.finish().cast('int64') + builder = self.get_builder(full_key, value_t, doc_iter, True) + for val in old_array: + if val.is_valid: + (builder).builder.get().Append(val.as_py()) + else: + (builder).builder.get().AppendNull() + elif not self.has_schema and builder is None: builder = self.get_builder(full_key, value_t, doc_iter, True) if builder is None: continue diff --git a/bindings/python/test/pandas_types/test_int32_overflow.py b/bindings/python/test/pandas_types/test_int32_overflow.py index 71db5a13..ed6717d5 100644 --- a/bindings/python/test/pandas_types/test_int32_overflow.py +++ b/bindings/python/test/pandas_types/test_int32_overflow.py @@ -24,7 +24,8 @@ def test_aggregate_pandas_all_schema_inference_int32_avoids_overflow(): coll.insert_many( [ {"_id": 1, "value": 1}, - {"_id": 2, "value": 2**40}, # much larger than Int32 max + {"_id": 2, "value": None}, + {"_id": 3, "value": 2**40}, # much larger than Int32 max ] ) @@ -35,7 +36,7 @@ def test_aggregate_pandas_all_schema_inference_int32_avoids_overflow(): df = coll.aggregate_pandas_all(pipeline) - assert len(df) == 2 + assert len(df) == 3 assert df["value"].max() == 2**40 client.close() @@ -60,6 +61,6 @@ def test_aggregate_pandas_all_explicit_int64_schema_avoids_overflow(): df = coll.aggregate_pandas_all(pipeline, schema=schema) - assert len(df) == 2 + assert len(df) == 3 assert df["value"].max() == 2**40 client.close() diff --git a/bindings/python/test/test_arrow.py b/bindings/python/test/test_arrow.py index bac86b49..61014ef0 100644 --- a/bindings/python/test/test_arrow.py +++ b/bindings/python/test/test_arrow.py @@ -1401,3 +1401,61 @@ def test_find_arrow_all_parallelism_options(self): table_off.equals(table_thread), msg=f"tables differ:\n{table_off}\n\n{table_thread}", ) + + def test_find_multiple_batches_of_different_schema(self): + docs = [{"_id": i, "value": i} for i in range(10)] + [ + { + "_id": 10, + "value": 2**40, # Value much larger than Int32 max + } + ] + self.coll.insert_many(docs) + + orig_method = self.coll.find_raw_batches + + def mock_find_raw_batches(*args, **kwargs): + kwargs["batch_size"] = 2 + return orig_method(*args, **kwargs) + + with mock.patch.object( + pymongo.collection.Collection, + "find_raw_batches", + wraps=mock_find_raw_batches, + ): + table_off = find_arrow_all( + self.coll, + {}, + parallelism="off", + ) + table_proc = find_arrow_all( + self.coll, + {}, + parallelism="processes", + ) + table_thread = find_arrow_all( + self.coll, + {}, + parallelism="threads", + ) + + self.assertEqual(table_off.num_rows, len(docs)) + self.assertEqual(table_proc.num_rows, len(docs)) + self.assertEqual(table_thread.num_rows, len(docs)) + + self.assertTrue( + table_off.schema.equals(table_proc.schema), + msg=f"{table_off.schema} != {table_proc.schema}", + ) + self.assertTrue( + table_off.schema.equals(table_thread.schema), + msg=f"{table_off.schema} != {table_thread.schema}", + ) + + self.assertTrue( + table_off.equals(table_proc), + msg=f"tables differ:\n{table_off}\n\n{table_proc}", + ) + self.assertTrue( + table_off.equals(table_thread), + msg=f"tables differ:\n{table_off}\n\n{table_thread}", + )