From b9e21fab957a8e30f3046988da2fd6f95232a922 Mon Sep 17 00:00:00 2001 From: Keith Smiley Date: Thu, 26 Mar 2026 10:40:14 -0700 Subject: [PATCH] Add first class `cuda-compile` action This splits off `cuda-compile` from the `c++-compile` action (behind a incompatible flag) which allows toolchain maintainers to customize cuda compiles separately from normal C++ compiles. This is useful for passing the many cuda specific compiler flags. --- .../build/lib/rules/cpp/CcCommon.java | 1 + .../build/lib/rules/cpp/CppActionNames.java | 2 ++ .../build/lib/rules/cpp/CppCompileAction.java | 2 ++ .../rules/cpp/CppCompileActionBuilder.java | 5 +++ .../rules/cpp/CppCompileActionTemplate.java | 1 + .../build/lib/rules/cpp/CppConfiguration.java | 4 +++ .../build/lib/rules/cpp/CppFileTypes.java | 10 ++++-- .../build/lib/rules/cpp/CppOptions.java | 14 ++++++++ .../common/builtin_exec_platforms.bzl | 1 + .../cpp/CcLibraryConfiguredTargetTest.java | 34 +++++++++++++++++++ tools/build_defs/cc/action_names.bzl | 4 +++ 11 files changed, 76 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCommon.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCommon.java index b58223efe5a979..dcfa77c321aa31 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCommon.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CcCommon.java @@ -42,6 +42,7 @@ public final class CcCommon { CppActionNames.CPP_MODULE_DEPS_SCANNING, CppActionNames.CPP20_MODULE_COMPILE, CppActionNames.CPP20_MODULE_CODEGEN, + CppActionNames.CUDA_COMPILE, CppActionNames.ASSEMBLE, CppActionNames.PREPROCESS_ASSEMBLE, CppActionNames.CLIF_MATCH, diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppActionNames.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppActionNames.java index 6b3556b2ee2ccd..a956a3e331301f 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppActionNames.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppActionNames.java @@ -36,6 +36,8 @@ public class CppActionNames { public static final String OBJCPP_COMPILE = "objc++-compile"; /** A string constant for the c++ header parsing. */ public static final String CPP_HEADER_PARSING = "c++-header-parsing"; + /** A string constant for cuda compilation. */ + public static final String CUDA_COMPILE = "cuda-compile"; /** A string constant for the c++20 modules deps scanning */ public static final String CPP_MODULE_DEPS_SCANNING = "c++-module-deps-scanning"; diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java index 52e4ca9fc0b5a0..f9e290ed68a7f2 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java @@ -1977,6 +1977,8 @@ static String actionNameToMnemonic( return "CppHeaderAnalysis"; case CppActionNames.CPP_MODULE_DEPS_SCANNING: return "CppDepsScanning"; + case CppActionNames.CUDA_COMPILE: + return "CudaCompile"; default: return CPP_COMPILE_MNEMONIC; } diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionBuilder.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionBuilder.java index 1da78d3fe556ee..3a1a014ed9b508 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionBuilder.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionBuilder.java @@ -77,6 +77,7 @@ public final class CppCompileActionBuilder implements StarlarkValue { NestedSetBuilder.emptySet(Order.STABLE_ORDER); private ImmutableList additionalOutputs = ImmutableList.of(); private boolean needsIncludeValidation; + private boolean useCudaCompileAction; // New fields need to be added to the copy constructor. @@ -88,6 +89,7 @@ public CppCompileActionBuilder( this.shareable = false; this.configuration = configuration; this.cppConfiguration = configuration.getFragment(CppConfiguration.class); + this.useCudaCompileAction = this.cppConfiguration.useCudaCompileAction(); this.mandatoryInputsBuilder = NestedSetBuilder.stableOrder(); this.additionalIncludeScanningRoots = new ArrayList<>(); this.ccToolchain = ccToolchain; @@ -216,6 +218,9 @@ public String getActionName() { throw new IllegalStateException(); } else if (CppFileTypes.C_SOURCE.matches(sourcePath)) { return CppActionNames.C_COMPILE; + } else if (this.useCudaCompileAction && CppFileTypes.CUDA_SOURCE.matches(sourcePath)) { + // NOTE: Must be checked before C++ until .cu is removed from CPP_SOURCE + return CppActionNames.CUDA_COMPILE; } else if (CppFileTypes.CPP_SOURCE.matches(sourcePath)) { return CppActionNames.CPP_COMPILE; } else if (CppFileTypes.OBJC_SOURCE.matches(sourcePath)) { diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionTemplate.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionTemplate.java index 79bb5523d2030e..20253922ebf009 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionTemplate.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileActionTemplate.java @@ -107,6 +107,7 @@ public final class CppCompileActionTemplate extends ActionKeyComputer FileTypeSet.of( CppFileTypes.CPP_SOURCE, CppFileTypes.CPP_HEADER, + CppFileTypes.CUDA_SOURCE, CppFileTypes.OBJC_SOURCE, CppFileTypes.OBJCPP_SOURCE, CppFileTypes.C_SOURCE, diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppConfiguration.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppConfiguration.java index dcb89301c7d604..6644f1368d79ad 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppConfiguration.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppConfiguration.java @@ -795,6 +795,10 @@ public boolean useSpecificToolFilesForStarlark(StarlarkThread thread) throws Eva return cppOptions.useSpecificToolFiles; } + public boolean useCudaCompileAction() { + return cppOptions.useCudaCompileAction; + } + public boolean disableNoCopts() { return cppOptions.disableNoCopts; } diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppFileTypes.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppFileTypes.java index 9f2c65290768de..c0c3a0914b87b1 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppFileTypes.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppFileTypes.java @@ -30,6 +30,7 @@ public final class CppFileTypes { // FileType is extended to use case-sensitive comparison also on Windows public static final FileType CPP_SOURCE = new FileType() { + // TODO: Remove .cu when --incompatible_cuda_compile_action is removed final ImmutableList extensions = ImmutableList.of(".cc", ".cpp", ".cxx", ".c++", ".C", ".cu", ".cl"); @@ -70,6 +71,7 @@ public ImmutableList getExtensions() { public static final FileType CLIF_INPUT_PROTO = FileType.of(".ipb"); public static final FileType CLIF_OUTPUT_PROTO = FileType.of(".opb"); public static final FileType BC_SOURCE = FileType.of(".bc"); + public static final FileType CUDA_SOURCE = FileType.of(".cu"); public static final FileTypeSet ALL_C_CLASS_SOURCE = FileTypeSet.of( @@ -77,11 +79,15 @@ public ImmutableList getExtensions() { CppFileTypes.C_SOURCE, CppFileTypes.OBJCPP_SOURCE, CppFileTypes.OBJC_SOURCE, - CppFileTypes.CLIF_INPUT_PROTO); + CppFileTypes.CLIF_INPUT_PROTO, + CppFileTypes.CUDA_SOURCE); // Filetypes that generate LLVM bitcode when -flto is specified. public static final FileTypeSet LTO_SOURCE = - FileTypeSet.of(CppFileTypes.CPP_SOURCE, CppFileTypes.C_SOURCE); + FileTypeSet.of( + CppFileTypes.CPP_SOURCE, + CppFileTypes.C_SOURCE, + CppFileTypes.CUDA_SOURCE); public static final FileType CPP_HEADER = FileType.of( diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppOptions.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppOptions.java index eb11271b51d8d4..9c1b4514010de1 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppOptions.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppOptions.java @@ -872,6 +872,20 @@ public Label getMemProfProfileLabel() { + "actions. See https://github.com/bazelbuild/bazel/issues/8531") public boolean useSpecificToolFiles; + @Option( + name = "incompatible_cuda_compile_action", + defaultValue = "false", + documentationCategory = OptionDocumentationCategory.TOOLCHAIN, + effectTags = { + OptionEffectTag.LOADING_AND_ANALYSIS, + OptionEffectTag.ACTION_COMMAND_LINES, + OptionEffectTag.AFFECTS_OUTPUTS + }, + metadataTags = {OptionMetadataTag.INCOMPATIBLE_CHANGE}, + help = + "Compile cuda files using the cuda-compile action in the toolchain.") + public boolean useCudaCompileAction; + @Option( name = "incompatible_disable_nocopts", defaultValue = "true", diff --git a/src/main/starlark/builtins_bzl/common/builtin_exec_platforms.bzl b/src/main/starlark/builtins_bzl/common/builtin_exec_platforms.bzl index 93317b9503f0c5..40cb6517d4fcf9 100644 --- a/src/main/starlark/builtins_bzl/common/builtin_exec_platforms.bzl +++ b/src/main/starlark/builtins_bzl/common/builtin_exec_platforms.bzl @@ -290,6 +290,7 @@ bazel_fragments["CppOptions"] = fragment( "//command_line_option:incompatible_require_ctx_in_configure_features", "//command_line_option:incompatible_make_thinlto_command_lines_standalone", "//command_line_option:incompatible_use_specific_tool_files", + "//command_line_option:incompatible_cuda_compile_action", "//command_line_option:incompatible_disable_nocopts", "//command_line_option:incompatible_validate_top_level_header_inclusions", "//command_line_option:strict_system_includes", diff --git a/src/test/java/com/google/devtools/build/lib/rules/cpp/CcLibraryConfiguredTargetTest.java b/src/test/java/com/google/devtools/build/lib/rules/cpp/CcLibraryConfiguredTargetTest.java index 6d9abd12c97c3e..1c7191cc0b57a1 100644 --- a/src/test/java/com/google/devtools/build/lib/rules/cpp/CcLibraryConfiguredTargetTest.java +++ b/src/test/java/com/google/devtools/build/lib/rules/cpp/CcLibraryConfiguredTargetTest.java @@ -1207,6 +1207,40 @@ public void testIncompatibleUseCppCompileHeaderMnemonic() throws Exception { .isEqualTo("CppCompileHeader"); } + @Test + public void testCudaCompileActionMnemonic() throws Exception { + // TODO: Remove when we bump rules_cc + AnalysisMock.get() + .ccSupport() + .setupCcToolchainConfig( + mockToolsConfig, + CcToolchainConfig.builder().withActionConfigs(CppActionNames.CUDA_COMPILE)); + useConfiguration("--incompatible_cuda_compile_action"); + + ConfiguredTarget x = + scratchConfiguredTarget( + "foo", + "x", + "load('@rules_cc//cc:cc_library.bzl', 'cc_library')", + "cc_library(name = 'x', srcs = ['a.cu'])"); + + assertThat(getGeneratingCompileAction("_objs/x/a.o", x).getMnemonic()) + .isEqualTo("CudaCompile"); + } + + @Test + public void testCudaCompileWithoutFlagUsesCppCompile() throws Exception { + ConfiguredTarget x = + scratchConfiguredTarget( + "foo", + "x", + "load('@rules_cc//cc:cc_library.bzl', 'cc_library')", + "cc_library(name = 'x', srcs = ['a.cu'])"); + + assertThat(getGeneratingCompileAction("_objs/x/a.o", x).getMnemonic()) + .isEqualTo("CppCompile"); + } + private CppCompileAction getGeneratingCompileAction( String packageRelativePath, ConfiguredTarget owner) { return (CppCompileAction) getGeneratingAction(getBinArtifact(packageRelativePath, owner)); diff --git a/tools/build_defs/cc/action_names.bzl b/tools/build_defs/cc/action_names.bzl index e93e08ec72202b..9588c3a66615f4 100644 --- a/tools/build_defs/cc/action_names.bzl +++ b/tools/build_defs/cc/action_names.bzl @@ -19,6 +19,9 @@ C_COMPILE_ACTION_NAME = "c-compile" # Name of the C++ compilation action. CPP_COMPILE_ACTION_NAME = "c++-compile" +# Name of the CUDA compilation action. +CUDA_COMPILE_ACTION_NAME = "cuda-compile" + # Name of the linkstamp-compile action. LINKSTAMP_COMPILE_ACTION_NAME = "linkstamp-compile" @@ -104,6 +107,7 @@ VALIDATE_STATIC_LIBRARY = "validate-static-library" ACTION_NAMES = struct( c_compile = C_COMPILE_ACTION_NAME, cpp_compile = CPP_COMPILE_ACTION_NAME, + cuda_compile = CUDA_COMPILE_ACTION_NAME, linkstamp_compile = LINKSTAMP_COMPILE_ACTION_NAME, cc_flags_make_variable = CC_FLAGS_MAKE_VARIABLE_ACTION_NAME, cpp_module_codegen = CPP_MODULE_CODEGEN_ACTION_NAME,