diff --git a/python/cppyy/_cpython_cppyy.py b/python/cppyy/_cpython_cppyy.py index 8d6f1e0f..eb597ee1 100644 --- a/python/cppyy/_cpython_cppyy.py +++ b/python/cppyy/_cpython_cppyy.py @@ -114,6 +114,14 @@ def __call__(self, *args): # most common cases are covered if args: args0 = args[0] + + if ( + type(args0).__module__ == "numpy" + and type(args0).__name__ == "ndarray" + and hasattr(args0, "dtype") + ): + # Handle arrays of arbitrary dimension recursively + return _np_vector(args0) if args0 and (type(args0) is tuple or type(args0) is list): t = type(args0[0]) if t is float: t = 'double' @@ -209,3 +217,41 @@ def _end_capture_stderr(): pass return "C++ issued an error message that could not be decoded (%s)" % str(original_error) return "" + +def _np_vector(arr): + CPP_EXPLICIT_TYPES = {"float64": "double", "int64": "long"} + + def build_nested_vector_type(ndim, base_type, cache={}): + key = (ndim, base_type) + if key not in cache: + vector_t = gbl.std.vector[base_type] + for _ in range(ndim - 1): + vector_t = gbl.std.vector[vector_t] + cache[key] = vector_t + return cache[key] + + def convert(arr): + ndim = arr.ndim + if arr.size > 0: + base_type = CPP_EXPLICIT_TYPES.get( + arr.dtype.type.__name__, type(arr.flat[0].item()) + ) + else: + base_type = float + + if ndim == 1: + vector = build_nested_vector_type(1, base_type)() + vector.reserve(arr.size) + for elem in arr.flat: + vector.push_back(elem.item()) + return vector + + vector_type = build_nested_vector_type(ndim, base_type) + result = vector_type() + result.reserve(arr.shape[0]) + for subarr in arr: + result.push_back(convert(subarr)) + + return result + + return convert(arr) \ No newline at end of file diff --git a/test/test_stltypes.py b/test/test_stltypes.py index 771cc941..673396a0 100644 --- a/test/test_stltypes.py +++ b/test/test_stltypes.py @@ -789,6 +789,45 @@ def test23_copy_conversion(self): for f, d in zip(x, v): assert f == d +def test_ndarray_template_less(self): + import cppyy + + try: + import numpy as np + except ImportError: + self.skipTest("numpy is not installed") + dtype_mappings = { + np.int32: "int", + np.int64: "long", + np.float32: "float", + np.float64: "double", + } + + shapes = [ + (10,), # 1D array + (5, 5), # 2D array + (4, 4, 4), # 3D array + (2, 3, 3, 3), # 4D array + ] + + for np_dtype, cpp_dtype in dtype_mappings.items(): + for shape in shapes: + rng = np.random.default_rng(seed=42) + + if np.issubdtype(np_dtype, np.integer): + x = rng.integers(low=0, high=100, size=shape, dtype=np_dtype) + else: + x = rng.random(size=shape).astype(np_dtype) + + cpp_vector = cppyy.gbl.std.vector(x) + assert len(cpp_vector) == shape[0] + + if len(shape) > 1: + assert len(cpp_vector[0]) == shape[1] + if len(shape) > 2: + assert len(cpp_vector[0][0]) == shape[2] + if len(shape) > 3: + assert len(cpp_vector[0][0][0]) == shape[3] class TestSTLSTRING: def setup_class(cls):