diff --git a/maven/JarAssembler.kt b/maven/JarAssembler.kt index cdd4b5e3..acde77f9 100644 --- a/maven/JarAssembler.kt +++ b/maven/JarAssembler.kt @@ -28,17 +28,20 @@ import java.io.BufferedInputStream import java.io.BufferedOutputStream import java.io.File import java.io.FileOutputStream -import java.lang.RuntimeException +import java.io.InputStream import java.nio.charset.Charset import java.nio.file.Path import java.nio.file.Paths import java.util.concurrent.Callable import java.util.zip.ZipEntry import java.util.zip.ZipFile +import java.util.zip.ZipInputStream import java.util.zip.ZipOutputStream -import kotlin.collections.HashMap +import kotlin.RuntimeException import kotlin.system.exitProcess +typealias Entries = MutableMap +private fun Entries(): Entries = mutableMapOf() @Command(name = "jar-assembler", mixinStandardHelpOptions = true) class JarAssembler : Callable { @@ -58,45 +61,110 @@ class JarAssembler : Callable { @Option(names = ["--jars"], split = ";") lateinit var jars: Array - private val entries = HashMap() - private val entryNames = mutableSetOf() - override fun call() { - ZipOutputStream(BufferedOutputStream(FileOutputStream(outputFile))).use { out -> - if (pomFile != null) { + Entries().apply { + pomFile?.readBytes()?.let { pomContents -> val pomPath = "META-INF/maven/${groupId}/${artifactId}/pom.xml" - entries += preCreateDirectories(Paths.get(pomPath)) - entries[pomPath] = pomFile!!.readBytes() + this += preCreateDirectories(Paths.get(pomPath)) + this[pomPath] = pomContents } - for (jar in jars) { - ZipFile(jar).use { jarZip -> - jarZip.entries().asSequence().forEach { entry -> - if (entryNames.contains(entry.name)) { - throw RuntimeException("duplicate entry in the JAR: ${entry.name}") - } - if (entry.name.contains("META-INF")) { - // pom.xml will be added by us - return@forEach - } - if (entry.isDirectory) { - // needed directories would be added by us - return@forEach - } - entryNames.add(entry.name) - BufferedInputStream(jarZip.getInputStream(entry)).use { inputStream -> - val sourceFileBytes = inputStream.readBytes() - val resultLocation = getFinalPath(entry, sourceFileBytes) - entries += preCreateDirectories(Paths.get(resultLocation)) - entries[resultLocation] = sourceFileBytes + + ZipOutputStream(BufferedOutputStream(FileOutputStream(outputFile))).use { + if (outputFile.extension == "aar") assembleAar(it) + else assembleClassesJar(it) + } + } + } + + /** Assemble a class JAR containing the transitive class closure from [jars] and any pre-existing entries in [this] */ + private fun Entries.assembleClassesJar(output: ZipOutputStream, jars: List = this@JarAssembler.jars.toList()) { + for (jar in jars) { + if (jar.extension == "aar") { + throw RuntimeException("cannot package AAR within classes JAR") + } + processZip(ZipFile(jar)) + } + + writeEntries(output) + } + + /** Assemble an AAR from a base AAR containing the transitive class closure from the additional [jars] and any pre-existing entries in [this] */ + private fun Entries.assembleAar(output: ZipOutputStream) { + val classes = Entries() + processZip(jars.single { it.extension == "aar" }.let(::ZipFile)) { aar, entry -> + validateEntry(entry)?.let { + if (entry.name == "classes.jar") { + // pull out classes in nested JAR + entry.let(aar::getInputStream).let(::ZipInputStream).use { classesJar -> + var zipEntry: ZipEntry? = classesJar.nextEntry + while (zipEntry != null) { + classes.processEntry(classesJar, zipEntry) + zipEntry = classesJar.nextEntry } } + } else { + // add to top-level entries + processEntry(aar, entry) } } - entries.keys.sorted().forEach { - val newEntry = ZipEntry(it) - out.putNextEntry(newEntry) - out.write(entries[it]!!) - } + } + + // write classes jar first + ZipEntry("classes.jar").let(output::putNextEntry) + val classJar = ZipOutputStream(output) + classes.assembleClassesJar(classJar, jars.filter { it.extension != "aar" }) + classJar.finish() + + // write the rest of the entries + writeEntries(output) + } + + /** [process] each [ZipEntry] in [file] within the context of [this] */ + private fun Entries.processZip(file: ZipFile, process: Entries.(zip: ZipFile, entry: ZipEntry) -> Unit = { zip, entry -> processEntry(zip, entry) }) = file.use { zip -> + zip.entries().asSequence().forEach { entry -> + process(zip, entry) + } + } + + /** Validate [ZipEntry] and add information to [this] entries map */ + private fun Entries.processEntry(zip: ZipFile, entry: ZipEntry): Unit = BufferedInputStream(zip.getInputStream(entry)).use { + processEntry(it, entry) + } + + /** Validate [ZipEntry] and add information to [this] entries map */ + private fun Entries.processEntry(inputStream: InputStream, entry: ZipEntry) { + validateEntry(entry)?.let { + val sourceFileBytes = inputStream.readBytes() + val resultLocation = getFinalPath(it, sourceFileBytes) + this += preCreateDirectories(Paths.get(resultLocation)) + this[resultLocation] = sourceFileBytes + } + } + + /** Return null if this [entry] shouldn't be processed */ + private fun Entries.validateEntry(entry: ZipEntry): ZipEntry? = when { + entry.isDirectory -> { + // needed directories would be added by us + null + } + entry.name.contains("META-INF/maven") -> { + // pom.xml will be added by us + null + } + keys.contains(entry.name) -> { + // TODO: Investigate why I'm getting duplicates + println("I have a duplicate entry: ${entry.name}") + null + // throw RuntimeException("duplicate entry in the JAR: ${entry.name}") + } + else -> entry + } + + /** Write entries captured in [this] to [output] */ + private fun Entries.writeEntries(output: ZipOutputStream) { + entries.sortedBy(Map.Entry::key).forEach { (key, entry) -> + output.putNextEntry(ZipEntry(key)) + output.write(entry) } } @@ -104,7 +172,7 @@ class JarAssembler : Callable { * For path "a/b/c.java" inserts "a/" and "a/b/ into `entries` */ private fun preCreateDirectories(path: Path): Map { - val newEntries = HashMap() + val newEntries = Entries() for (i in path.nameCount-1 downTo 1) { val subPath = path.subpath(0, i).toString() + "/" newEntries[subPath] = ByteArray(0) diff --git a/maven/PomGenerator.kt b/maven/PomGenerator.kt index 79b2b3f2..ffb985d1 100644 --- a/maven/PomGenerator.kt +++ b/maven/PomGenerator.kt @@ -74,6 +74,9 @@ class PomGenerator : Callable { @Option(names = ["--target_deps_coordinates"]) lateinit var dependencyCoordinates: String + @Option(names = ["--packaging"]) + var packaging = "" + fun getLicenseInfo(license_id: String): Pair { return when { license_id.equals("apache") -> { @@ -224,6 +227,12 @@ class PomGenerator : Callable { versionElem.appendChild(pom.createTextNode(version)) rootElement.appendChild(versionElem) + if (packaging.isNotEmpty() && packaging != "jar") { + val packagingElem = pom.createElement("packaging") + packagingElem.appendChild(pom.createTextNode(packaging)) + rootElement.appendChild(packagingElem) + } + // add dependency information rootElement.appendChild(dependencies(pom, version, workspace_refs)) diff --git a/maven/rules.bzl b/maven/rules.bzl index 52be4e69..fe40c7a6 100644 --- a/maven/rules.bzl +++ b/maven/rules.bzl @@ -17,6 +17,13 @@ # under the License. # +# Known generic labels to automatically not include in closure +_DO_NOT_INCLUDE_IN_TRANSITIVE_CLOSURE_TARGETS = [ + Label("@bazel_tools//tools/android:android_jar"), +] + +def _is_android_library(target): + return AndroidLibraryAarInfo in target def _parse_maven_coordinates(coordinates_string, enforce_version_template=True): coordinates = coordinates_string.split(':') @@ -47,13 +54,14 @@ def _generate_version_file(ctx): def _generate_pom_file(ctx, version_file): target = ctx.attr.target - maven_coordinates = _parse_maven_coordinates(target[JarInfo].name) + jar_info = target[JarInfo] + maven_coordinates = _parse_maven_coordinates(jar_info.name) pom_file = ctx.actions.declare_file("{}_pom.xml".format(ctx.attr.name)) pom_deps = [] - for pom_dependency in [dep for dep in target[JarInfo].deps.to_list() if dep.type == 'pom']: + for pom_dependency in [dep for dep in jar_info.deps.to_list() if dep.type == 'pom']: pom_dependency = pom_dependency.maven_coordinates - if pom_dependency == target[JarInfo].name: + if pom_dependency == jar_info.name: continue pom_dependency_coordinates = _parse_maven_coordinates(pom_dependency, False) pom_dependency_artifact = pom_dependency_coordinates.group_id + ":" + pom_dependency_coordinates.artifact_id @@ -78,6 +86,7 @@ def _generate_pom_file(ctx, version_file): "--version_file=" + version_file.path, "--output_file=" + pom_file.path, "--workspace_refs_file=" + ctx.file.workspace_refs.path, + "--packaging=" + jar_info.packaging ], ) @@ -88,12 +97,14 @@ def _generate_class_jar(ctx, pom_file): maven_coordinates = _parse_maven_coordinates(target[JarInfo].name) jar = None - if hasattr(target, "files") and target.files.to_list() and target.files.to_list()[0].extension == "jar": + if (_is_android_library(target)): + jar = target[AndroidLibraryAarInfo].aar + elif hasattr(target, "files") and target.files.to_list() and target.files.to_list()[0].extension == "jar": jar = target[JavaInfo].outputs.jars[0].class_jar else: fail("Could not find JAR file to deploy in {}".format(target)) - output_jar = ctx.actions.declare_file("{}:{}.jar".format(maven_coordinates.group_id, maven_coordinates.artifact_id)) + output_jar = ctx.actions.declare_file("{}:{}.{}".format(maven_coordinates.group_id, maven_coordinates.artifact_id, target[JarInfo].packaging)) class_jar_deps = [dep.class_jar for dep in target[JarInfo].deps.to_list() if dep.type == 'jar'] class_jar_paths = [jar.path] + [target.path for target in class_jar_deps] @@ -119,7 +130,7 @@ def _generate_source_jar(ctx): srcjar = None - if hasattr(target, "files") and target.files.to_list() and target.files.to_list()[0].extension == "jar": + if _is_android_library(target) or (hasattr(target, "files") and target.files.to_list() and target.files.to_list()[0].extension == "jar"): for output in target[JavaInfo].outputs.jars: if output.source_jar and (output.source_jar.basename.endswith("-src.jar") or output.source_jar.basename.endswith("-sources.jar")): srcjar = output.source_jar @@ -159,7 +170,7 @@ def _assemble_maven_impl(ctx): return [ DefaultInfo(files = depset(output_files)), - MavenDeploymentInfo(jar = class_jar, pom = pom_file, srcjar = source_jar) + MavenDeploymentInfo(packaging = ctx.attr.target[JarInfo].packaging, jar = class_jar, pom = pom_file, srcjar = source_jar) ] def find_maven_coordinates(target, tags): @@ -176,6 +187,7 @@ JarInfo = provider( fields = { "name": "The name of a the JAR (Maven coordinates)", "deps": "The list of dependencies of this JAR. A dependency may be of two types, POM or JAR.", + "packaging": "The type of target to publish (jar, war, aar, etc.)" }, ) @@ -184,22 +196,33 @@ def _aggregate_dependency_info_impl(target, ctx): deps = getattr(ctx.rule.attr, "deps", []) runtime_deps = getattr(ctx.rule.attr, "runtime_deps", []) exports = getattr(ctx.rule.attr, "exports", []) + neverlink = getattr(ctx.rule.attr, "neverlink", False) deps_all = deps + exports + runtime_deps maven_coordinates = find_maven_coordinates(target, tags) dependencies = [] + packaging = "aar" if _is_android_library(target) else "jar" # depend via POM if maven_coordinates: dependencies = [struct( + target = target, type = "pom", maven_coordinates = maven_coordinates )] + # Hacky way to ignore something we don't care about but not crash + elif neverlink or target.label in _DO_NOT_INCLUDE_IN_TRANSITIVE_CLOSURE_TARGETS: + return JarInfo( + name = None, + deps = depset([]), + packaging = None, + ) # include runtime output jars - elif target[JavaInfo].runtime_output_jars: + elif JavaInfo in target: jars = target[JavaInfo].runtime_output_jars source_jars = target[JavaInfo].source_jars dependencies = [struct( + target = target, type = "jar", class_jar = jar, source_jar = source_jar, @@ -209,7 +232,7 @@ def _aggregate_dependency_info_impl(target, ctx): else: fail("Unsure how to package dependency for target: %s" % target) - return JarInfo( + jar_info = JarInfo( name = maven_coordinates, deps = depset(dependencies, transitive = [ # Filter transitive JARs from dependency that has maven coordinates @@ -218,8 +241,11 @@ def _aggregate_dependency_info_impl(target, ctx): depset([dep for dep in target[JarInfo].deps.to_list() if dep.type == 'pom']) if target[JarInfo].name else target[JarInfo].deps for target in deps_all ]), + packaging = packaging, ) + return jar_info + aggregate_dependency_info = aspect( attr_aspects = [ "jars", @@ -301,6 +327,7 @@ assemble_maven = rule( MavenDeploymentInfo = provider( fields = { + 'packaging': 'The type of target to publish (jar, war, aar, etc.)', 'jar': 'JAR file to deploy', 'srcjar': 'JAR file with sources', 'pom': 'Accompanying pom.xml file' @@ -314,6 +341,7 @@ def _deploy_maven_impl(ctx): lib_jar_link = "lib.jar" src_jar_link = "lib.srcjar" pom_xml_link = ctx.attr.target[MavenDeploymentInfo].pom.basename + packaging = ctx.attr.target[MavenDeploymentInfo].packaging ctx.actions.expand_template( template = ctx.file._deployment_script, @@ -323,7 +351,8 @@ def _deploy_maven_impl(ctx): "$SRCJAR_PATH": src_jar_link, "$POM_PATH": pom_xml_link, "{snapshot}": ctx.attr.snapshot, - "{release}": ctx.attr.release + "{release}": ctx.attr.release, + "$PACKAGING": packaging, } ) diff --git a/maven/templates/deploy.py b/maven/templates/deploy.py index a1f7218d..eff57d3e 100644 --- a/maven/templates/deploy.py +++ b/maven/templates/deploy.py @@ -95,6 +95,7 @@ def unpack_args(_, a, b=False): jar_path = "$JAR_PATH" pom_file_path = "$POM_PATH" srcjar_path = "$SRCJAR_PATH" +packaging = "$PACKAGING" namespace = { 'namespace': 'http://maven.apache.org/POM/4.0.0' } root = ElementTree.parse(pom_file_path).getroot() @@ -127,12 +128,13 @@ def unpack_args(_, a, b=False): 'must have a version which complies to this regex: {}' .format(version, repo_type, version_release_regex)) +artifact_extension = '.' + packaging filename_base = '{coordinates}/{artifact}/{version}/{artifact}-{version}'.format( coordinates=group_id.text.replace('.', '/'), version=version, artifact=artifact_id.text) -upload(maven_url, username, password, jar_path, filename_base + '.jar') +upload(maven_url, username, password, jar_path, filename_base + artifact_extension) if should_sign: - upload(maven_url, username, password, sign(jar_path), filename_base + '.jar.asc') + upload(maven_url, username, password, sign(jar_path), filename_base + artifact_extension + '.asc') upload(maven_url, username, password, pom_file_path, filename_base + '.pom') if should_sign: upload(maven_url, username, password, sign(pom_file_path), filename_base + '.pom.asc') @@ -158,12 +160,12 @@ def unpack_args(_, a, b=False): with tempfile.NamedTemporaryFile(mode='wt', delete=True) as jar_md5: jar_md5.write(md5(jar_path)) jar_md5.flush() - upload(maven_url, username, password, jar_md5.name, filename_base + '.jar.md5') + upload(maven_url, username, password, jar_md5.name, filename_base + artifact_extension + '.md5') with tempfile.NamedTemporaryFile(mode='wt', delete=True) as jar_sha1: jar_sha1.write(sha1(jar_path)) jar_sha1.flush() - upload(maven_url, username, password, jar_sha1.name, filename_base + '.jar.sha1') + upload(maven_url, username, password, jar_sha1.name, filename_base + artifact_extension + '.sha1') if os.path.exists(srcjar_path): with tempfile.NamedTemporaryFile(mode='wt', delete=True) as srcjar_md5: