Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,14 @@ def compile(
# the dict directly. Note that we don't need to check any args, since the pool checks
# this on compile anyway.
if "_compiled_programs" not in self.__dict__:
static_params: tuple[str, ...] = ()
Comment thread
SF-N marked this conversation as resolved.
Outdated
if static_args:
# This is reached by `entry_point.compile(..., **static_args)` before delegating
# to `CompiledProgramsPool.compile()` below.
# Keep this in sync with `compile` in compiled_program.py.
static_params = tuple(static_args.keys())
self.__dict__["_compiled_programs"] = self._make_compiled_programs_pool(
static_params=tuple(static_args.keys()),
static_params=static_params,
static_domains=self.compilation_options.static_domains,
)

Expand Down
23 changes: 14 additions & 9 deletions src/gt4py/next/otf/compiled_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,15 +666,20 @@ def compile(
"""
for offset_provider in offset_providers: # not included in product for better type checking
for static_values in itertools.product(*static_args.values()):
# Only inject a `StaticArg` descriptor mapping when static arguments were
# actually given.
Comment thread
SF-N marked this conversation as resolved.
Outdated
argument_descriptors: ArgStaticDescriptorsByType = {}
if static_args:
# Calls from `Program.compile()`/`FieldOperator.compile()` reach this.
# Keep this in sync with `compile` in decorator.py.
argument_descriptors[arguments.StaticArg] = dict(
zip(
static_args.keys(),
[arguments.StaticArg(value=v) for v in static_values],
strict=True,
)
)
self._compile_variant(
argument_descriptors={
arguments.StaticArg: dict(
zip(
static_args.keys(),
[arguments.StaticArg(value=v) for v in static_values],
strict=True,
)
),
},
argument_descriptors=argument_descriptors,
offset_provider=offset_provider,
)
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,28 @@ def test_compile(cartesian_case, compile_testee):
assert np.allclose(kwargs["out"].ndarray, args[0].ndarray + args[1].ndarray)


def test_compile_with_empty_static_args_matches_no_static_args(cartesian_case, compile_testee):
Comment thread
SF-N marked this conversation as resolved.
Outdated
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

empty_static_args = {}
decorator_path_testee = compile_testee.with_backend(cartesian_case.backend)
decorator_path_testee.compile(
offset_provider=cartesian_case.offset_provider, **empty_static_args
)
decorator_path_mapping = decorator_path_testee._compiled_programs.argument_descriptor_mapping
assert arguments.StaticArg not in decorator_path_mapping

compiled_program_path_testee = compile_testee.with_backend(cartesian_case.backend)
compiled_program_path_pool = compiled_program_path_testee._make_compiled_programs_pool(
static_params=(), static_domains=False
)
compiled_program_path_pool.argument_descriptor_mapping = None
compiled_program_path_pool.compile(offset_providers=[cartesian_case.offset_provider])

assert compiled_program_path_pool.argument_descriptor_mapping == decorator_path_mapping
Comment thread
SF-N marked this conversation as resolved.
Outdated


def test_compile_twice_same_program_errors(cartesian_case, compile_testee):
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")
Expand Down