diff --git a/libs/@local/hashql/core/src/graph/algorithms/color/mod.rs b/libs/@local/hashql/core/src/graph/algorithms/color/mod.rs new file mode 100644 index 00000000000..3b85ab3bcfb --- /dev/null +++ b/libs/@local/hashql/core/src/graph/algorithms/color/mod.rs @@ -0,0 +1,238 @@ +//! Three-color depth-first search for directed graphs. +//! +//! Implements a DFS where each node transitions through three states: +//! +//! - **White** (unvisited): not yet encountered. +//! - **Gray** (in the `gray` set): discovered but not yet finished; still on the DFS stack. +//! - **Black** (in the `black` set): all successors have been processed. +//! +//! The color of a node when it is re-encountered determines the edge classification: +//! +//! | Re-encounter color | Meaning | +//! |--------------------|------------------| +//! | `None` (white) | Tree edge | +//! | `Some(Gray)` | Back edge (cycle)| +//! | `Some(Black)` | Cross/forward | +//! +//! This is an iterative (stack-based) implementation. The visitor receives callbacks +//! at two points: when a node is first examined ([`node_examined`]) and when all its +//! successors are finished ([`node_finished`]). The `node_finished` callback fires in +//! postorder. +//! +//! [`node_examined`]: TriColorVisitor::node_examined +//! [`node_finished`]: TriColorVisitor::node_finished + +use alloc::alloc::Global; +use core::{alloc::Allocator, ops::Try}; + +use crate::{ + graph::{DirectedGraph, Successors}, + id::bit_vec::DenseBitSet, +}; + +#[cfg(test)] +mod tests; + +/// DFS node state. +/// +/// Passed to [`TriColorVisitor::node_examined`] as the `before` parameter to indicate +/// what state a node was in when it was re-encountered. A value of `None` means the node +/// was white (first discovery). +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum NodeColor { + /// On the current DFS path. Re-encountering a gray node means a back edge (cycle). + Gray, + /// Fully processed. Re-encountering a black node means a cross or forward edge. + Black, +} + +/// Internal event pushed onto the DFS stack. +/// +/// Each node generates two events: `Gray` on discovery (explore successors) +/// and `Black` when all successors are done (finish the node). +struct Event { + node: N, + next: NodeColor, +} + +/// Iterative three-color DFS over a directed graph. +/// +/// Reusable across multiple `run` calls. Each call to [`run`](Self::run) resets all +/// internal state before starting from the given root. +/// +/// The graph's full node domain is used to size the internal bitsets, so node IDs +/// from the graph can be used directly without remapping. +pub struct TriColorDepthFirstSearch<'graph, G: ?Sized, N, A: Allocator = Global> { + graph: &'graph G, + stack: Vec, A>, + + /// Nodes that have been discovered (entered the DFS stack). + gray: DenseBitSet, + /// Nodes whose successors have all been processed. + black: DenseBitSet, +} + +impl<'graph, G: DirectedGraph + ?Sized> TriColorDepthFirstSearch<'graph, G, G::NodeId, Global> { + #[inline] + pub fn new(graph: &'graph G) -> Self { + Self::new_in(graph, Global) + } +} + +impl<'graph, G: DirectedGraph + ?Sized, A: Allocator> + TriColorDepthFirstSearch<'graph, G, G::NodeId, A> +{ + pub fn new_in(graph: &'graph G, alloc: A) -> Self { + Self { + graph, + stack: Vec::new_in(alloc), + gray: DenseBitSet::new_empty(graph.node_count()), + black: DenseBitSet::new_empty(graph.node_count()), + } + } + + /// Clears all traversal state (gray set, black set, stack). + /// + /// Call this before a sequence of [`run_from`](Self::run_from) calls to start + /// with a clean slate. + pub fn reset(&mut self) { + self.stack.clear(); + self.gray.clear(); + self.black.clear(); + } + + /// Run a DFS from `root`, resetting all state first. + /// + /// Equivalent to calling [`reset`](Self::reset) followed by + /// [`run_from`](Self::run_from). Use this when each DFS should be independent. + pub fn run(&mut self, root: G::NodeId, visitor: &mut V) -> V::Result + where + V: TriColorVisitor, + G: Successors, + { + self.reset(); + self.run_from(root, visitor) + } + + /// Run a DFS from `root` without resetting state. + /// + /// Nodes already in the gray or black sets from previous calls are treated as + /// previously visited. This allows running DFS from multiple roots while + /// accumulating state: a node finished (black) by an earlier root is skipped, + /// so each connected component is explored at most once. + /// + /// Stops early if the visitor returns a residual (e.g., `Err` or + /// `ControlFlow::Break`). Edges for which [`TriColorVisitor::ignore_edge`] + /// returns `true` are not followed. + pub fn run_from(&mut self, root: G::NodeId, visitor: &mut V) -> V::Result + where + V: TriColorVisitor, + G: Successors, + { + self.stack.push(Event { + node: root, + next: NodeColor::Gray, + }); + + while let Some(Event { node, next }) = self.stack.pop() { + match next { + NodeColor::Black => { + let not_previously_finished = self.black.insert(node); + debug_assert!( + not_previously_finished, + "a node should be finished exactly once" + ); + + visitor.node_finished(node)?; + } + NodeColor::Gray => { + let newly_discovered = self.gray.insert(node); + let previous_color = if newly_discovered { + None + } else if self.black.contains(node) { + Some(NodeColor::Black) + } else { + Some(NodeColor::Gray) + }; + + visitor.node_examined(node, previous_color)?; + + // Already visited through another path: nothing more to do. + if previous_color.is_some() { + continue; + } + + self.stack.push(Event { + node, + next: NodeColor::Black, + }); + for successor in self.graph.successors(node) { + if !visitor.ignore_edge(node, successor) { + self.stack.push(Event { + node: successor, + next: NodeColor::Gray, + }); + } + } + } + } + } + + Try::from_output(()) + } +} + +/// Callbacks for [`TriColorDepthFirstSearch`]. +/// +/// All methods have default no-op implementations, so visitors only need to +/// override the events they care about. +pub trait TriColorVisitor { + /// The control-flow type returned by each callback. + /// + /// Use `Result<(), E>` or `ControlFlow` to support early termination. + type Result: Try; + + /// Called when a node is encountered during DFS. + /// + /// `before` indicates the node's color at the time of re-encounter: + /// - `None`: first discovery (white to gray transition). + /// - `Some(Gray)`: back edge, indicating a cycle. + /// - `Some(Black)`: cross or forward edge. + #[expect(unused_variables)] + fn node_examined(&mut self, node: G::NodeId, before: Option) -> Self::Result { + Try::from_output(()) + } + + /// Called after all successors of `node` have been fully processed. + /// + /// Fires in postorder: a node finishes only after all its descendants finish. + #[expect(unused_variables)] + fn node_finished(&mut self, node: G::NodeId) -> Self::Result { + Try::from_output(()) + } + + /// Return `true` to skip this edge during traversal. + /// + /// Allows restricting the DFS to a subgraph without constructing a + /// separate graph data structure. + #[expect(unused_variables)] + fn ignore_edge(&mut self, source: G::NodeId, target: G::NodeId) -> bool { + false + } +} + +/// A [`TriColorVisitor`] that detects cycles. +/// +/// Returns `Err(())` as soon as a back edge (re-encounter of a gray node) is found. +pub struct CycleDetector; + +impl TriColorVisitor for CycleDetector { + type Result = Result<(), ()>; + + fn node_examined(&mut self, _: G::NodeId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => Err(()), + _ => Ok(()), + } + } +} diff --git a/libs/@local/hashql/core/src/graph/algorithms/color/tests.rs b/libs/@local/hashql/core/src/graph/algorithms/color/tests.rs new file mode 100644 index 00000000000..fccbed31f01 --- /dev/null +++ b/libs/@local/hashql/core/src/graph/algorithms/color/tests.rs @@ -0,0 +1,250 @@ +use core::ops::ControlFlow; + +use super::{NodeColor, TriColorDepthFirstSearch, TriColorVisitor}; +use crate::{ + graph::{DirectedGraph as _, NodeId, tests::TestGraph}, + id::Id as _, +}; + +macro_rules! n { + ($id:expr) => { + NodeId::from_usize($id) + }; +} + +struct CycleDetector; + +impl TriColorVisitor for CycleDetector { + type Result = ControlFlow; + + fn node_examined(&mut self, node: NodeId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => ControlFlow::Break(node), + _ => ControlFlow::Continue(()), + } + } +} + +fn has_cycle(graph: &TestGraph) -> bool { + let mut search = TriColorDepthFirstSearch::new(graph); + (0..graph.node_count()).any(|i| search.run(n!(i), &mut CycleDetector).is_break()) +} + +fn cycle_target(graph: &TestGraph) -> Option { + let mut search = TriColorDepthFirstSearch::new(graph); + for i in 0..graph.node_count() { + if let ControlFlow::Break(target) = search.run(n!(i), &mut CycleDetector) { + return Some(target); + } + } + None +} + +struct PostOrderCollector { + order: Vec, +} + +impl TriColorVisitor for PostOrderCollector { + type Result = ControlFlow<()>; + + fn node_finished(&mut self, node: NodeId) -> Self::Result { + self.order.push(node); + ControlFlow::Continue(()) + } +} + +fn postorder(graph: &TestGraph, root: usize) -> Vec { + let mut search = TriColorDepthFirstSearch::new(graph); + let mut collector = PostOrderCollector { order: Vec::new() }; + let _: ControlFlow<()> = search.run(n!(root), &mut collector); + collector.order +} + +#[test] +fn self_loop_is_cyclic() { + let graph = TestGraph::new(&[(0, 0)]); + assert!(has_cycle(&graph)); + assert_eq!(cycle_target(&graph), Some(n!(0))); +} + +#[test] +fn two_node_cycle() { + let graph = TestGraph::new(&[(0, 1), (1, 0)]); + assert!(has_cycle(&graph)); +} + +#[test] +fn three_node_cycle() { + let graph = TestGraph::new(&[(0, 1), (1, 2), (2, 0)]); + assert!(has_cycle(&graph)); +} + +#[test] +fn linear_chain_no_cycle() { + let graph = TestGraph::new(&[(0, 1), (1, 2), (2, 3)]); + assert!(!has_cycle(&graph)); +} + +#[test] +fn diamond_no_cycle() { + let graph = TestGraph::new(&[(0, 1), (0, 2), (1, 3), (2, 3)]); + assert!(!has_cycle(&graph)); +} + +#[test] +fn diamond_with_back_edge_is_cyclic() { + let graph = TestGraph::new(&[(0, 1), (0, 2), (1, 3), (2, 3), (3, 0)]); + assert!(has_cycle(&graph)); +} + +#[test] +fn disconnected_with_cycle_in_second_component() { + // Component 1: 0 -> 1 (no cycle) + // Component 2: 2 -> 3 -> 2 (cycle) + let graph = TestGraph::new(&[(0, 1), (2, 3), (3, 2)]); + assert!(has_cycle(&graph)); +} + +#[test] +fn disconnected_no_cycle() { + let graph = TestGraph::new(&[(0, 1), (2, 3)]); + assert!(!has_cycle(&graph)); +} + +#[test] +fn isolated_node_no_cycle() { + // Single node, no edges (TestGraph needs at least one edge to set node_count, + // so use two disconnected nodes with one edge). + let graph = TestGraph::new(&[(0, 1)]); + assert!(!has_cycle(&graph)); +} + +#[test] +fn postorder_linear_chain() { + // 0 -> 1 -> 2 + let graph = TestGraph::new(&[(0, 1), (1, 2)]); + let order = postorder(&graph, 0); + assert_eq!(order, [n!(2), n!(1), n!(0)]); +} + +#[test] +fn postorder_diamond() { + // 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3 + let graph = TestGraph::new(&[(0, 1), (0, 2), (1, 3), (2, 3)]); + let order = postorder(&graph, 0); + + // 3 must come before both 1 and 2; 1 and 2 must come before 0. + assert_eq!(order.len(), 4); + assert_eq!(*order.last().expect("non-empty"), n!(0)); + + let pos = |id: usize| order.iter().position(|&n| n == n!(id)).expect("non-empty"); + assert!(pos(3) < pos(1)); + assert!(pos(3) < pos(2)); + assert!(pos(1) < pos(0)); + assert!(pos(2) < pos(0)); +} + +#[test] +fn postorder_unreachable_node_not_visited() { + // 0 -> 1, node 2 exists but is unreachable from 0 + let graph = TestGraph::new(&[(0, 1), (2, 2)]); + let order = postorder(&graph, 0); + + // Only nodes reachable from root 0 + assert_eq!(order, [n!(1), n!(0)]); +} + +struct FilteredCycleDetector { + ignored: (usize, usize), +} + +impl TriColorVisitor for FilteredCycleDetector { + type Result = ControlFlow<()>; + + fn node_examined(&mut self, _: NodeId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => ControlFlow::Break(()), + _ => ControlFlow::Continue(()), + } + } + + fn ignore_edge(&mut self, source: NodeId, target: NodeId) -> bool { + source == n!(self.ignored.0) && target == n!(self.ignored.1) + } +} + +#[test] +fn ignore_edge_breaks_cycle() { + // 0 -> 1 -> 2 -> 0 (cycle); ignoring 2 -> 0 removes the cycle + let graph = TestGraph::new(&[(0, 1), (1, 2), (2, 0)]); + + let mut search = TriColorDepthFirstSearch::new(&graph); + let mut visitor = FilteredCycleDetector { ignored: (2, 0) }; + let result = search.run(n!(0), &mut visitor); + assert!(result.is_continue()); +} + +#[test] +fn ignore_edge_wrong_edge_keeps_cycle() { + // 0 -> 1 -> 2 -> 0 (cycle); ignoring 0 -> 1 still leaves 1 -> 2 -> 0 reachable + // from 0? No: if 0 -> 1 is ignored, DFS from 0 has no successors, no cycle found. + // But the cycle B -> C -> A still exists if we start from 1. + let graph = TestGraph::new(&[(0, 1), (1, 2), (2, 0)]); + + let mut search = TriColorDepthFirstSearch::new(&graph); + let mut visitor = FilteredCycleDetector { ignored: (0, 1) }; + + // From node 0: no successors after filtering, no cycle + assert!(search.run(n!(0), &mut visitor).is_continue()); + + // From node 1: 1 -> 2 -> 0 -> (0->1 ignored) -> done, no back edge to gray + // Wait: 0's successor 1 is ignored, so from 0 we go nowhere. But from 1: 1->2->0, + // then 0 has no unignored successors. 0 finishes. No cycle. + assert!(search.run(n!(1), &mut visitor).is_continue()); +} + +#[test] +fn run_resets_between_calls() { + let graph = TestGraph::new(&[(0, 1), (1, 0)]); + let mut search = TriColorDepthFirstSearch::new(&graph); + + // First run: finds cycle + assert!(search.run(n!(0), &mut CycleDetector).is_break()); + + // Second run on same search: state is reset, should find cycle again + assert!(search.run(n!(0), &mut CycleDetector).is_break()); +} + +#[test] +fn run_from_accumulates_state() { + // 0->1->2, 3->1 (node 1 reachable from both roots) + // Without accumulation, run_from(3) would re-explore 1->2 and emit them again. + // With accumulation, nodes 1 and 2 are already black after run_from(0). + let graph = TestGraph::new(&[(0, 1), (1, 2), (3, 1)]); + let mut search = TriColorDepthFirstSearch::new(&graph); + let mut collector = PostOrderCollector { order: Vec::new() }; + + search.reset(); + let _: ControlFlow<()> = search.run_from(n!(0), &mut collector); + let _: ControlFlow<()> = search.run_from(n!(3), &mut collector); + + // Nodes 1 and 2 finished during first run_from; second run_from only finishes 3. + assert_eq!(collector.order, [n!(2), n!(1), n!(0), n!(3)]); +} + +#[test] +fn run_from_skips_already_finished_nodes() { + // 0->1->2, 3->2 (shared sink at 2) + let graph = TestGraph::new(&[(0, 1), (1, 2), (3, 2)]); + let mut search = TriColorDepthFirstSearch::new(&graph); + let mut collector = PostOrderCollector { order: Vec::new() }; + + search.reset(); + let _: ControlFlow<()> = search.run_from(n!(0), &mut collector); + let _: ControlFlow<()> = search.run_from(n!(3), &mut collector); + + // Node 2 should appear exactly once (finished during first run_from), + // not re-emitted when reached from node 3. + assert_eq!(collector.order.iter().filter(|&&n| n == n!(2)).count(), 1); + assert_eq!(collector.order.len(), 4); // 2, 1, 0, 3 +} diff --git a/libs/@local/hashql/core/src/graph/algorithms/mod.rs b/libs/@local/hashql/core/src/graph/algorithms/mod.rs index 5509a2f47aa..bfd5647551d 100644 --- a/libs/@local/hashql/core/src/graph/algorithms/mod.rs +++ b/libs/@local/hashql/core/src/graph/algorithms/mod.rs @@ -25,6 +25,7 @@ //! # assert_eq!(visited, [n1, n2]); //! ``` +pub mod color; pub mod dominators; pub mod tarjan; @@ -32,6 +33,7 @@ use alloc::collections::VecDeque; use core::iter::FusedIterator; pub use self::{ + color::{CycleDetector, TriColorDepthFirstSearch, TriColorVisitor}, dominators::{ DominanceFrontier, DominatorFrontiers, Dominators, IteratedDominanceFrontier, dominance_frontiers, dominators, iterated_dominance_frontier, diff --git a/libs/@local/hashql/core/src/graph/algorithms/tarjan/mod.rs b/libs/@local/hashql/core/src/graph/algorithms/tarjan/mod.rs index edda7708ba7..3a1bd412feb 100644 --- a/libs/@local/hashql/core/src/graph/algorithms/tarjan/mod.rs +++ b/libs/@local/hashql/core/src/graph/algorithms/tarjan/mod.rs @@ -203,6 +203,59 @@ where pub fn of(&self, id: S) -> &[N] { self.as_slice().of(id) } + + #[inline] + pub fn iter(&self) -> impl ExactSizeIterator + DoubleEndedIterator { + self.sccs().map(|scc| (scc, self.of(scc))) + } + + // TODO: miri tests + #[expect(unsafe_code)] + pub fn iter_mut( + &mut self, + ) -> impl ExactSizeIterator + DoubleEndedIterator + '_ { + let ptr = self.nodes.as_mut_ptr(); + let offsets = &self.offsets; + + offsets.ids().take(self.offsets.len() - 1).map(move |scc| { + let start = offsets[scc]; + let end = offsets[scc.plus(1)]; + + // SAFETY: The start and end indices are valid for the nodes slice, and members is + // non-overlapping by construction + (scc, unsafe { + core::slice::from_raw_parts_mut(ptr.add(start), end - start) + }) + }) + } +} + +impl<'this, N, S, A: Allocator> IntoIterator for &'this Members +where + S: Id, +{ + type Item = (S, &'this [N]); + + type IntoIter = impl ExactSizeIterator + DoubleEndedIterator; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'this, N, S, A: Allocator> IntoIterator for &'this mut Members +where + S: Id, +{ + type Item = (S, &'this mut [N]); + + type IntoIter = impl ExactSizeIterator + DoubleEndedIterator; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } } /// Storage for the computed SCCs and their relationships. diff --git a/libs/@local/hashql/mir/src/pass/analysis/callgraph/mod.rs b/libs/@local/hashql/mir/src/pass/analysis/callgraph/mod.rs index 8c26ca917bf..028911c12dc 100644 --- a/libs/@local/hashql/mir/src/pass/analysis/callgraph/mod.rs +++ b/libs/@local/hashql/mir/src/pass/analysis/callgraph/mod.rs @@ -228,6 +228,17 @@ impl CallGraph<'_, A> { Some(DefId::new(edge.source().as_u32())) } + + #[inline] + pub fn callers(&self, def: DefId) -> impl Iterator { + let node = NodeId::from_usize(def.as_usize()); + + self.inner.incoming_edges(node).map(move |edge| CallSite { + caller: DefId::new(edge.source().as_u32()), + kind: edge.data, + target: def, + }) + } } impl fmt::Display for CallGraph<'_, A> { diff --git a/libs/@local/hashql/mir/src/pass/transform/inline/find.rs b/libs/@local/hashql/mir/src/pass/transform/inline/find.rs index 66efbc1cc12..103a3a97aa0 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inline/find.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inline/find.rs @@ -23,10 +23,8 @@ use crate::{ /// /// A callsite is eligible if: /// - It's a direct call (function is a constant `FnPtr`). -/// - Its target SCC has not already been inlined into this caller. -/// -/// The SCC check prevents cycles: once we've inlined a function (or any function -/// in its SCC) into a filter, we won't inline it again. +/// - It's not a self-call. +/// - Its target is not a loop breaker. pub(crate) struct FindCallsiteVisitor<'ctx, 'state, 'env, 'heap, A: Allocator> { /// The filter function we're finding callsites in. pub caller: DefId, @@ -53,10 +51,10 @@ impl<'heap, A: Allocator> Visitor<'heap> for FindCallsiteVisitor<'_, '_, '_, 'he return Ok(()); }; - let target_component = self.state.components.scc(ptr); - - // Skip if we've already inlined this SCC into this caller. - if self.state.inlined.contains(self.caller, target_component) { + // Skip self-calls and calls to loop breakers. Breakers are the cycle + // cut points: inlining them would reintroduce the recursion that + // breaker selection removed. + if ptr == self.caller || self.state.loop_breakers.contains(ptr) { return Ok(()); } diff --git a/libs/@local/hashql/mir/src/pass/transform/inline/loop_breaker.rs b/libs/@local/hashql/mir/src/pass/transform/inline/loop_breaker.rs new file mode 100644 index 00000000000..4f8dc405935 --- /dev/null +++ b/libs/@local/hashql/mir/src/pass/transform/inline/loop_breaker.rs @@ -0,0 +1,413 @@ +//! Loop-breaker selection for recursive SCCs. +//! +//! When functions form a mutually recursive group (SCC), the inliner cannot inline all +//! calls without diverging. This module selects which functions to mark as loop breakers: +//! calls to a breaker within its SCC are skipped, while calls from a breaker to non-breakers +//! are still inlined. This flattens most of the call chain without infinite expansion. +//! +//! The approach follows GHC's loop-breaker strategy (Peyton Jones & Marlow 2002): select +//! which *nodes* to mark as non-inlineable rather than which *edges* to cut. All edges +//! targeting a loop breaker become non-inlineable. This reduces the problem from feedback +//! arc set (NP-hard) to feedback vertex set, which is tractable for the small SCCs (at most +//! ~12 nodes) that appear in practice. +//! +//! # Algorithm +//! +//! [`LoopBreaker::run_in`] processes every non-trivial SCC (size > 1) in the call graph: +//! +//! 1. **Score** each member by inverse inlining value via [`LoopBreaker::score`]. Higher score = +//! less valuable to inline = better breaker candidate. +//! 2. **Select** breakers greedily: pick the highest-scored member (least valuable to inline), mark +//! it as a breaker, then check if the remaining members still contain a cycle +//! ([`LoopBreaker::has_cycle`]). Repeat until the remaining subgraph is acyclic. This produces a +//! sufficient (not necessarily minimal) feedback vertex set. +//! 3. **Reorder** the SCC members via [`LoopBreaker::order`]: non-breakers appear in DFS postorder +//! (callees before callers), followed by breakers. This ordering ensures that when a function is +//! processed, its non-breaker callees within the same SCC have already been optimized. +//! +//! The members slice is mutated in place so the caller can iterate it directly. +//! +//! # Scoring +//! +//! The breaker score (see [`InlineLoopBreakerConfig`]) combines: +//! +//! - **Body cost** (positive contribution): large functions are expensive to duplicate. +//! - **Caller count** (negative): functions with many call sites lose more inlining opportunities +//! when chosen as breakers. +//! - **Unique callsite** (negative): a single call site means zero duplication on inline. +//! - **Leaf status** (negative): leaves are safe, cheap inlining targets. +//! - **Inline directive**: `Never` maps to `+inf` (ideal breaker), `Always` to `-inf` (avoided +//! unless every other candidate has been exhausted). +//! +//! # Cycle detection +//! +//! After each breaker is selected, the remaining non-breaker subgraph is checked +//! for cycles using three-color DFS ([`TriColorDepthFirstSearch`]). The DFS runs +//! on the full [`CallGraph`] with an [`ignore_edge`] filter that restricts traversal +//! to non-breaker SCC members. State is accumulated across roots via +//! [`run_from`](TriColorDepthFirstSearch::run_from) so disconnected components +//! (which appear when breaker removal splits the subgraph) are all covered. +//! +//! # Postorder computation +//! +//! Once breakers are selected, the non-breaker members form a DAG. Their processing +//! order is computed as DFS postorder over a [`CallSubgraph`] that filters the +//! call graph to non-breaker members. Breaker members are appended after the +//! non-breakers. +//! +//! [`ignore_edge`]: TriColorVisitor::ignore_edge + +use core::{alloc::Allocator, iter, ops::ControlFlow}; + +use hashql_core::{ + graph::{ + DirectedGraph, Successors, + algorithms::{ + DepthFirstForestPostOrder, TriColorDepthFirstSearch, TriColorVisitor, + color::NodeColor, + tarjan::{Members, SccId}, + }, + }, + heap::BumpAllocator, + id::bit_vec::DenseBitSet, +}; + +use super::analysis::{BodyProperties, InlineDirective}; +use crate::{ + def::{DefId, DefIdSlice}, + pass::analysis::{CallGraph, CallKind}, +}; + +/// Configuration for loop-breaker selection within recursive SCCs. +/// +/// Controls the scoring function that determines which SCC members are selected +/// as loop breakers. Higher breaker scores indicate better breaker candidates +/// (less valuable to inline). +/// +/// # Scoring Formula +/// +/// ```text +/// score = cost_weight * body_cost +/// - caller_penalty * apply_caller_count +/// - unique_callsite_penalty (if exactly one callsite targets this function) +/// - leaf_penalty (if function has no outgoing calls) +/// ``` +/// +/// Functions with `InlineDirective::Never` get score `+inf` (ideal breakers). +/// Functions with `InlineDirective::Always` get score `-inf` (avoided unless +/// every other candidate has been exhausted). +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct InlineLoopBreakerConfig { + /// Weight applied to body cost. + /// + /// Large functions are expensive to duplicate at each call site, making them + /// good breaker candidates. + /// + /// Default: `1.0`. + pub cost_weight: f32, + + /// Penalty per apply-callsite caller. + /// + /// Functions called from many sites provide more inlining opportunities. + /// Selecting them as breakers loses those opportunities for every caller. + /// + /// Default: `5.0`. + pub caller_penalty: f32, + + /// Penalty for functions with exactly one callsite. + /// + /// A unique callsite means inlining causes zero code duplication, making + /// the function a poor breaker choice. + /// + /// Default: `15.0`. + pub unique_callsite_penalty: f32, + + /// Penalty for leaf functions. + /// + /// Leaves have no outgoing calls (except intrinsics) and cannot trigger + /// further inlining cascades, making them safe and valuable to inline. + /// + /// Default: `10.0`. + pub leaf_penalty: f32, +} + +impl Default for InlineLoopBreakerConfig { + fn default() -> Self { + Self { + cost_weight: 1.0, + caller_penalty: 5.0, + unique_callsite_penalty: 15.0, + leaf_penalty: 10.0, + } + } +} + +/// A view of the [`CallGraph`] induced on the non-breaker members of a single SCC. +/// +/// Both source and target are filtered: a node outside the non-breaker member set +/// has no successors, and edges targeting nodes outside it are dropped. +/// +/// [`node_count`](DirectedGraph::node_count) returns the full call graph domain so +/// that traversal algorithms size their bitsets correctly for the global `DefId` space. +struct CallSubgraph<'ctx, 'heap, A: Allocator> { + inner: &'ctx CallGraph<'heap, A>, + members: &'ctx [DefId], + breakers: &'ctx DenseBitSet, +} + +impl DirectedGraph for CallSubgraph<'_, '_, A> { + type Edge<'this> + = (DefId, DefId) + where + Self: 'this; + type EdgeId = (DefId, DefId); + type Node<'this> + = DefId + where + Self: 'this; + type NodeId = DefId; + + fn node_count(&self) -> usize { + // Must match the full DefId domain so that DenseBitSet/MixedBitSet + // in traversal algorithms are sized correctly for any DefId index. + self.inner.node_count() + } + + fn edge_count(&self) -> usize { + self.inner.edge_count() + } + + #[expect(unreachable_code)] + fn iter_nodes(&self) -> impl ExactSizeIterator> + DoubleEndedIterator { + unimplemented!(); + iter::empty() + } + + #[expect(unreachable_code)] + fn iter_edges(&self) -> impl ExactSizeIterator> + DoubleEndedIterator { + unimplemented!(); + iter::empty() + } +} + +impl Successors for CallSubgraph<'_, '_, A> { + type SuccIter<'this> + = impl Iterator + where + Self: 'this; + + fn successors(&self, node: Self::NodeId) -> Self::SuccIter<'_> { + let in_subgraph = self.members.contains(&node) && !self.breakers.contains(node); + + self.inner.successors(node).filter(move |&succ| { + in_subgraph && self.members.contains(&succ) && !self.breakers.contains(succ) + }) + } +} + +/// Entry point for loop-breaker selection and SCC reordering. +pub(crate) struct LoopBreaker<'ctx, 'heap, A: Allocator> { + pub config: InlineLoopBreakerConfig, + pub graph: &'ctx CallGraph<'heap, A>, + pub properties: &'ctx DefIdSlice>, + pub search: TriColorDepthFirstSearch<'ctx, CallGraph<'heap, A>, DefId, A>, +} + +impl LoopBreaker<'_, '_, A> { + /// Select loop breakers and reorder members for every non-trivial SCC. + /// + /// After this call, for each non-trivial SCC: + /// - A sufficient set of breakers has been selected to make the remainder acyclic. + /// - The member slice is reordered: non-breaker callees before their callers, breakers last. + /// + /// Returns a bitset of all selected breakers across every SCC. + pub(crate) fn run_in( + &mut self, + members: &mut Members, + scratch: &S, + ) -> DenseBitSet { + let mut breakers = DenseBitSet::new_empty(self.properties.len()); + + for (_, members) in members { + if members.len() < 2 { + continue; + } + + self.select_in(members, &mut breakers, scratch); + + #[expect( + clippy::debug_assert_with_mut_call, + reason = "the call only resets and uses the search state, therefore is safe to be \ + mut" + )] + { + debug_assert!( + !self.has_cycle(members, &breakers), + "select_in must produce an acyclic remainder" + ); + } + + let postorder = self.order(members, &breakers, scratch); + members.copy_from_slice(postorder); + } + + breakers + } + + /// Greedily select breakers for a single non-trivial SCC. + /// + /// Postcondition: the non-breaker remainder of `members` is acyclic. + fn select_in( + &mut self, + members: &[DefId], + breakers: &mut DenseBitSet, + scratch: &B, + ) { + // Sort descending: highest breaker score (least valuable to inline) first. + let scored = scratch + .allocate_slice_uninit(members.len()) + .write_with(|index| (members[index], self.score(members[index]))); + scored.sort_by(|(_, lhs_score), (_, rhs_score)| lhs_score.total_cmp(rhs_score).reverse()); + + // The full SCC is cyclic by definition, so we always need at least one breaker. + for &(candidate, _) in &*scored { + breakers.insert(candidate); + + if !self.has_cycle(members, breakers) { + break; + } + } + } + + /// Returns whether the non-breaker members still contain a cycle. + fn has_cycle(&mut self, members: &[DefId], breakers: &DenseBitSet) -> bool { + struct SubgraphCycleDetector<'ctx> { + members: &'ctx [DefId], + breakers: &'ctx DenseBitSet, + } + + impl TriColorVisitor for SubgraphCycleDetector<'_> + where + G: DirectedGraph, + { + type Result = ControlFlow<()>; + + fn node_examined(&mut self, _: DefId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => ControlFlow::Break(()), + _ => ControlFlow::Continue(()), + } + } + + fn ignore_edge(&mut self, source: DefId, target: DefId) -> bool { + self.breakers.contains(source) + || self.breakers.contains(target) + || !self.members.contains(&source) + || !self.members.contains(&target) + } + } + + let mut detector = SubgraphCycleDetector { members, breakers }; + + // Accumulate visited state across roots: breaker removal can disconnect + // the subgraph, and a cycle in an unreachable component would be missed + // by a single-root search. + self.search.reset(); + for &member in members { + if breakers.contains(member) { + continue; + } + + if self.search.run_from(member, &mut detector).is_break() { + return true; + } + } + + false + } + + /// Compute the breaker score for a single function. + /// + /// Higher score = better breaker candidate (less valuable to inline). + /// See [`InlineLoopBreakerConfig`] for the formula and weight descriptions. + #[expect(clippy::cast_precision_loss)] + fn score(&self, body: DefId) -> f32 { + let props = &self.properties[body]; + + match props.directive { + InlineDirective::Never => return f32::INFINITY, + InlineDirective::Always => return f32::NEG_INFINITY, + InlineDirective::Heuristic => {} + } + + let caller_count = self + .graph + .callers(body) + .filter(|cs| matches!(cs.kind, CallKind::Apply(_))) + .count(); + + let mut score = self.config.cost_weight * props.cost; + score = self + .config + .caller_penalty + .mul_add(-(caller_count as f32), score); + + if self.graph.unique_caller(body).is_some() { + score -= self.config.unique_callsite_penalty; + } + + if props.is_leaf { + score -= self.config.leaf_penalty; + } + + score + } + + /// Compute the processing order for a non-trivial SCC. + /// + /// Returns non-breaker members ordered so that callees appear before their + /// callers, followed by breaker members. + #[expect(unsafe_code)] + fn order<'alloc, S: BumpAllocator>( + &self, + members: &[DefId], + breakers: &DenseBitSet, + alloc: &'alloc S, + ) -> &'alloc [DefId] { + let subgraph = CallSubgraph { + inner: self.graph, + members, + breakers, + }; + + let mut index = 0; + let order = alloc.allocate_slice_uninit(members.len()); + + // The forest traversal covers the full DefId domain (since node_count + // must match the DefId index space for bitset sizing). Non-member nodes + // have no successors in the induced subgraph and yield as isolated + // nodes, so we filter them out. + for node in DepthFirstForestPostOrder::new(&subgraph) { + if !breakers.contains(node) && members.contains(&node) { + order[index].write(node); + index += 1; + } + } + + // Breakers last, in original order. + for &member in members { + if breakers.contains(member) { + order[index].write(member); + index += 1; + } + } + + debug_assert_eq!(index, members.len()); + + // SAFETY: All `members.len()` elements are initialized: + // - The forest traversal yields every non-breaker member exactly once (reachable in the + // full domain, filtered to SCC members). + // - The final loop writes all breaker members. + unsafe { order.assume_init_mut() } + } +} diff --git a/libs/@local/hashql/mir/src/pass/transform/inline/mod.rs b/libs/@local/hashql/mir/src/pass/transform/inline/mod.rs index ce26d8bc987..fca2f6875f2 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inline/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inline/mod.rs @@ -17,12 +17,15 @@ //! # Normal Phase //! //! For non-filter functions, the normal phase: -//! 1. Processes SCCs in dependency order (callees before callers). -//! 2. For each callsite, computes a score using [`InlineHeuristics::score`]. -//! 3. Selects candidates with positive scores, limited by per-caller budget. -//! 4. Updates caller costs after inlining to prevent cascade explosions. -//! -//! Recursive calls (same SCC) are never inlined to prevent infinite expansion. +//! 1. Selects loop breakers for each non-trivial SCC (see [`loop_breaker`]). +//! 2. Processes SCCs in dependency order (callees before callers). Within non-trivial SCCs, +//! non-breaker members are processed in postorder of the breaker-removed DAG, then breaker +//! members. +//! 3. For each callsite, computes a score using [`InlineHeuristics::score`]. +//! 4. Calls to loop breakers within their SCC are skipped. Calls to non-breakers within the same +//! SCC are eligible for inlining. +//! 5. Selects candidates with positive scores, limited by per-caller budget. +//! 6. Updates caller costs after inlining to prevent cascade explosions. //! //! # Aggressive Phase //! @@ -30,7 +33,7 @@ //! aggressive inlining to fully flatten the filter logic. The aggressive phase: //! 1. Iterates up to `aggressive_inline_cutoff` times per filter. //! 2. On each iteration, inlines all eligible callsites found in the filter. -//! 3. Tracks which SCCs have been inlined to prevent cycles. +//! 3. Calls to loop breakers and self-calls are skipped to prevent cycles. //! 4. Emits a diagnostic if the cutoff is reached. //! //! # Budget System @@ -50,26 +53,24 @@ use alloc::collections::BinaryHeap; use core::{alloc::Allocator, cmp, mem}; use hashql_core::{ - graph::{ - DirectedGraph as _, - algorithms::{ - Tarjan, - tarjan::{SccId, StronglyConnectedComponents}, - }, + graph::algorithms::{ + Tarjan, TriColorDepthFirstSearch, + tarjan::{Members, SccId, StronglyConnectedComponents}, }, heap::{BumpAllocator, Heap}, - id::{ - Id as _, IdSlice, - bit_vec::{DenseBitSet, SparseBitMatrix}, - }, + id::{Id as _, IdSlice, bit_vec::DenseBitSet}, span::SpanId, }; -pub use self::{analysis::InlineCostEstimationConfig, heuristics::InlineHeuristicsConfig}; +pub use self::{ + analysis::InlineCostEstimationConfig, heuristics::InlineHeuristicsConfig, + loop_breaker::InlineLoopBreakerConfig, +}; use self::{ analysis::{BodyAnalysis, BodyProperties, CostEstimationResidual}, find::FindCallsiteVisitor, heuristics::InlineHeuristics, + loop_breaker::LoopBreaker, rename::RenameVisitor, }; use crate::{ @@ -100,6 +101,7 @@ mod find; mod heuristics; mod rename; +mod loop_breaker; #[cfg(test)] mod tests; @@ -141,9 +143,11 @@ pub struct InlineConfig { pub cost: InlineCostEstimationConfig, /// Thresholds and bonuses for scoring callsites. pub heuristics: InlineHeuristicsConfig, + /// Configuration for loop-breaker selection in recursive SCCs. + pub loop_breaker: InlineLoopBreakerConfig, /// Multiplier for computing per-caller budget. /// - /// Budget = `heuristics.max × budget_multiplier`. + /// Budget = `heuristics.max * budget_multiplier`. /// Limits how much code can be inlined into a single function. /// /// Default: `2.0` (budget of 120 with default max of 60). @@ -163,6 +167,7 @@ impl Default for InlineConfig { Self { cost: InlineCostEstimationConfig::default(), heuristics: InlineHeuristicsConfig::default(), + loop_breaker: InlineLoopBreakerConfig::default(), budget_multiplier: 2.0, aggressive_inline_cutoff: 16, } @@ -198,39 +203,38 @@ struct InlineState<'ctx, 'state, 'env, 'heap, A: Allocator> { /// Functions that require aggressive inlining (filter closures). filters: DenseBitSet, - /// Tracks which SCCs have been inlined into each function. + /// Functions selected as loop breakers within their SCC. /// - /// Used to prevent cycles during aggressive inlining: once an SCC - /// has been inlined into a filter, it won't be inlined again. - inlined: SparseBitMatrix, + /// Calls to a breaker within its SCC are skipped during inlining. + /// Calls from a breaker to non-breakers are still inlined. + loop_breakers: DenseBitSet, // cost estimation properties costs: CostEstimationResidual<'heap, A>, /// SCC membership for cycle detection. components: StronglyConnectedComponents, + component_members: Option>, global: &'ctx mut GlobalTransformState<'state>, } impl<'heap, A: Allocator> InlineState<'_, '_, '_, 'heap, A> { - /// Collect all non-recursive callsites for aggressive inlining. + /// Collect all callsites for aggressive inlining. /// /// Used for filter functions which bypass normal heuristics. - /// Records inlined SCCs to prevent cycles in subsequent iterations. - fn collect_all_callsites(&mut self, body: DefId, mem: &mut InlineStateMemory) { + /// Self-calls are excluded to prevent panics in `get_disjoint_mut`. + fn collect_all_callsites(&self, body: DefId, mem: &mut InlineStateMemory) { let component = self.components.scc(body); self.graph .apply_callsites(body) - .filter(|callsite| self.components.scc(callsite.target) != component) + .filter(|callsite| { + callsite.target != body + && (self.components.scc(callsite.target) != component + || !self.loop_breakers.contains(callsite.target)) + }) .collect_into(&mut mem.callsites); - - self.inlined.insert(body, component); - for callsite in &mem.callsites { - self.inlined - .insert(body, self.components.scc(callsite.target)); - } } /// Collect callsites using heuristic scoring and budget. @@ -260,7 +264,13 @@ impl<'heap, A: Allocator> InlineState<'_, '_, '_, 'heap, A> { let candidates = &mut mem.candidates; for callsite in self.graph.apply_callsites(body) { - if self.components.scc(callsite.target) == component { + // Within an SCC, only skip calls to loop breakers (they break the cycle). + // Calls to non-breakers within the SCC are eligible because we're now inside of a DAG. + let same_scc = self.components.scc(callsite.target) == component; + if same_scc && self.loop_breakers.contains(callsite.target) { + continue; + } + if callsite.target == body { continue; } @@ -536,15 +546,25 @@ impl Inline { let tarjan = Tarjan::new_in(&graph, &self.alloc); let components = tarjan.run(); + let mut component_members = components.members_in(&self.alloc); + + let mut loop_breaker = LoopBreaker { + config: self.config.loop_breaker, + graph: &graph, + properties: &costs.properties, + search: TriColorDepthFirstSearch::new_in(&graph, &self.alloc), + }; + let loop_breakers = loop_breaker.run_in(&mut component_members, &self.alloc); InlineState { config: self.config, filters, - inlined: SparseBitMatrix::new_in(components.node_count(), &self.alloc), + loop_breakers, interner, graph, costs, components, + component_members: Some(component_members), global: state, } } @@ -552,18 +572,22 @@ impl Inline { /// Run the normal inlining phase. /// /// Processes SCCs in dependency order (callees before callers) so that - /// cost updates propagate correctly. + /// cost updates propagate correctly. Within non-trivial SCCs, non-breaker + /// members are processed in postorder (callees before callers in the + /// breaker-removed DAG), followed by breaker members. fn normal<'heap, 'alloc>( - &self, state: &mut InlineState<'_, '_, '_, 'heap, &'alloc A>, bodies: &mut IdSlice>, mem: &mut InlineStateMemory<&'alloc A>, ) -> Changed { - let members = state.components.members_in(&self.alloc); - let mut any_changed = Changed::No; - for scc in members.sccs() { - for &id in members.of(scc) { + let component_members = state + .component_members + .take() + .unwrap_or_else(|| panic!("scc component members have been taken twice")); + + for (_, scc_members) in &component_members { + for &id in scc_members { let changed = state.run(bodies, id, mem); any_changed |= changed; state.global.mark(id, changed); @@ -605,9 +629,6 @@ impl Inline { mem.callsites .sort_unstable_by(|lhs, rhs| lhs.kind.cmp(&rhs.kind).reverse()); for callsite in mem.callsites.drain(..) { - let target_component = state.components.scc(callsite.target); - state.inlined.insert(filter, target_component); - state.inline(bodies, callsite); } @@ -637,7 +658,7 @@ impl<'env, 'heap, A: BumpAllocator> GlobalTransformPass<'env, 'heap> for Inline< let mut mem = InlineStateMemory::new(&self.alloc); let mut changed = Changed::No; - changed |= self.normal(&mut state, bodies, &mut mem); + changed |= Self::normal(&mut state, bodies, &mut mem); changed |= self.aggressive(context, &mut state, bodies, &mut mem); changed } diff --git a/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs b/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs index fca219e0be6..632256db6af 100644 --- a/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs +++ b/libs/@local/hashql/mir/src/pass/transform/inline/tests.rs @@ -6,7 +6,9 @@ use std::path::PathBuf; use bstr::ByteVec as _; use hashql_core::{ + graph::algorithms::{Tarjan, TriColorDepthFirstSearch, tarjan::SccId}, heap::Heap, + id::{Id as _, bit_vec::DenseBitSet}, pretty::Formatter, symbol::sym, r#type::{TypeFormatter, TypeFormatterOptions, environment::Environment}, @@ -16,6 +18,7 @@ use insta::{Settings, assert_snapshot}; use super::{ BodyAnalysis, Inline, InlineConfig, InlineCostEstimationConfig, InlineHeuristicsConfig, + InlineLoopBreakerConfig, loop_breaker::LoopBreaker, }; use crate::{ body::{Body, Source, basic_block::BasicBlockId, location::Location}, @@ -1113,3 +1116,616 @@ fn heuristics_no_unique_callsite_bonus_multiple_calls() { // 10 + 5 - 30 * 0.875 = 15.0 - 26.25 = -11.25 assert!((heuristics.score(default_callsite()) - (-11.25)).abs() < f32::EPSILON); } + +/// Two mutually recursive functions A and B, with a caller C. +/// +/// A is large (many statements), B is small (single return). The loop breaker +/// should select A (high cost = good breaker), leaving B as a non-breaker. +/// When the inliner processes the SCC: +/// - B's call to A (the breaker) is skipped. +/// - A's call to B (non-breaker) is inlined into A. +/// - C's call to either is cross-SCC and inlined normally. +#[test] +fn loop_breaker_mutual_recursion() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + let c_id = DefId::new(2); + + // A: large function that calls B + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, cond: Bool, tmp1: Int, tmp2: Int, tmp3: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + tmp1 = bin.+ n 1; + tmp2 = bin.+ tmp1 2; + tmp3 = apply (b_id), tmp2; + goto bb3(tmp3); + }, + bb3(result) { return result; } + }); + + // B: small function that calls A + let b = body!(interner, env; fn@b_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (a_id), x; + return result; + } + }); + + // C: external caller + let c = simple_caller(&interner, &env, c_id, b_id); + + let mut bodies = [a, b, c]; + + assert_inline_pass( + "loop_breaker_mutual_recursion", + &mut bodies, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + InlineConfig::default(), + ); +} + +/// Verifies breaker selection picks the highest-cost member. +/// +/// Given SCC {A, B} where A has high cost and B has low cost, +/// A should be selected as the breaker (high cost = good breaker). +#[test] +fn loop_breaker_selects_highest_cost() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + + // A: expensive, calls B + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, t3: Int, t4: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + t1 = bin.+ n 1; + t2 = bin.+ t1 2; + t3 = bin.- t2 3; + t4 = apply (b_id), t3; + goto bb3(t4); + }, + bb3(result) { return result; } + }); + + // B: cheap, calls A + let b = body!(interner, env; fn@b_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (a_id), x; + return result; + } + }); + + let bodies = [a, b]; + let bodies_slice = DefIdSlice::from_raw(&bodies); + + let graph = CallGraph::analyze_in(bodies_slice, &heap); + let mut analysis = BodyAnalysis::new( + &graph, + bodies_slice, + InlineCostEstimationConfig::default(), + &heap, + ); + for body in &bodies { + analysis.run(body); + } + let costs = analysis.finish(); + + let tarjan: Tarjan<_, _, SccId, _, _> = Tarjan::new_in(&graph, &heap); + let components = tarjan.run(); + let mut members = components.members_in(&heap); + + let mut breaker = LoopBreaker { + config: InlineLoopBreakerConfig::default(), + graph: &graph, + properties: &costs.properties, + search: TriColorDepthFirstSearch::new_in(&graph, &heap), + }; + let breakers = breaker.run_in(&mut members, &heap); + + // A has higher cost, so A should be the breaker. + assert!( + breakers.contains(a_id), + "expected A (high cost) to be selected as breaker" + ); + assert!( + !breakers.contains(b_id), + "expected B (low cost) to not be a breaker" + ); +} + +/// Three-way mutual recursion: A -> B -> C -> A. +/// +/// One breaker should be sufficient to break the single cycle. +/// The member with highest cost should be selected. +#[test] +fn loop_breaker_three_way_cycle() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + let c_id = DefId::new(2); + + // A: expensive + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, t3: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + t1 = bin.+ n 1; + t2 = bin.+ t1 2; + t3 = apply (b_id), t2; + goto bb3(t3); + }, + bb3(result) { return result; } + }); + + // B: medium + let b = body!(interner, env; fn@b_id/1 -> Int { + decl x: Int, t1: Int, result: Int; + bb0() { + t1 = bin.- x 1; + result = apply (c_id), t1; + return result; + } + }); + + // C: cheap, calls A + let c = body!(interner, env; fn@c_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (a_id), x; + return result; + } + }); + + let bodies = [a, b, c]; + let bodies_slice = DefIdSlice::from_raw(&bodies); + + let graph = CallGraph::analyze_in(bodies_slice, &heap); + let mut analysis = BodyAnalysis::new( + &graph, + bodies_slice, + InlineCostEstimationConfig::default(), + &heap, + ); + for body in &bodies { + analysis.run(body); + } + let costs = analysis.finish(); + + let tarjan: Tarjan<_, _, SccId, _, _> = Tarjan::new_in(&graph, &heap); + let components = tarjan.run(); + let mut members = components.members_in(&heap); + + let mut breaker = LoopBreaker { + config: InlineLoopBreakerConfig::default(), + graph: &graph, + properties: &costs.properties, + search: TriColorDepthFirstSearch::new_in(&graph, &heap), + }; + let breakers = breaker.run_in(&mut members, &heap); + + // Exactly one breaker should suffice for a single cycle. + assert_eq!( + breakers.count(), + 1, + "expected exactly 1 breaker for a 3-node cycle, got {}", + breakers.count() + ); + + // A has the highest cost. + assert!( + breakers.contains(a_id), + "expected A (highest cost) to be the breaker" + ); +} + +/// Helper: run loop-breaker selection on a set of bodies and return the breaker bitset +/// and the reordered members. +fn run_loop_breaker<'heap>( + bodies: &[Body<'heap>], + heap: &'heap Heap, +) -> (DenseBitSet, Vec>) { + let bodies_slice = DefIdSlice::from_raw(bodies); + + let graph = CallGraph::analyze_in(bodies_slice, heap); + let mut analysis = BodyAnalysis::new( + &graph, + bodies_slice, + InlineCostEstimationConfig::default(), + heap, + ); + for body in bodies { + analysis.run(body); + } + let costs = analysis.finish(); + + let tarjan: Tarjan<_, _, SccId, _, _> = Tarjan::new_in(&graph, heap); + let components = tarjan.run(); + let mut members = components.members_in(heap); + + let mut breaker = LoopBreaker { + config: InlineLoopBreakerConfig::default(), + graph: &graph, + properties: &costs.properties, + search: TriColorDepthFirstSearch::new_in(&graph, heap), + }; + let breakers = breaker.run_in(&mut members, heap); + + let scc_orders: Vec> = members + .iter() + .filter(|(_, m)| m.len() > 1) + .map(|(_, m)| m.to_vec()) + .collect(); + + (breakers, scc_orders) +} + +use core::ops::ControlFlow; + +use hashql_core::graph::algorithms::color::NodeColor; + +struct RemainderCycleDetector<'a> { + members: &'a [DefId], + breakers: &'a DenseBitSet, +} + +impl> + hashql_core::graph::algorithms::TriColorVisitor for RemainderCycleDetector<'_> +{ + type Result = ControlFlow<()>; + + fn node_examined(&mut self, _: DefId, before: Option) -> Self::Result { + match before { + Some(NodeColor::Gray) => ControlFlow::Break(()), + _ => ControlFlow::Continue(()), + } + } + + fn ignore_edge(&mut self, source: DefId, target: DefId) -> bool { + self.breakers.contains(source) + || self.breakers.contains(target) + || !self.members.contains(&source) + || !self.members.contains(&target) + } +} + +/// SCC with two independent 2-cycles joined into one component. +/// Requires at least two breakers. +/// +/// Structure: A <-> B, C <-> D, with B -> C and D -> A connecting them. +#[test] +fn loop_breaker_multi_breaker_scc() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + let c_id = DefId::new(2); + let d_id = DefId::new(3); + + // A: calls B + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (b_id), n; + return result; + } + }); + + // B: calls A and C + let b = body!(interner, env; fn@b_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { + t1 = apply (a_id), n; + goto bb3(t1); + }, + bb2() { + t2 = apply (c_id), n; + goto bb3(t2); + }, + bb3(result) { return result; } + }); + + // C: calls D + let c = body!(interner, env; fn@c_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (d_id), n; + return result; + } + }); + + // D: calls C and A (completing both sub-cycles) + let d = body!(interner, env; fn@d_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { + t1 = apply (c_id), n; + goto bb3(t1); + }, + bb2() { + t2 = apply (a_id), n; + goto bb3(t2); + }, + bb3(result) { return result; } + }); + + let bodies = [a, b, c, d]; + let (breakers, _) = run_loop_breaker(&bodies, &heap); + + // Two overlapping sub-cycles (A<->B and C<->D) need exactly 2 breakers: + // no single node participates in both cycles. + assert_eq!( + breakers.count(), + 2, + "expected exactly 2 breakers, got {}", + breakers.count() + ); + + // Verify the remainder is actually acyclic. + let bodies_slice = DefIdSlice::from_raw(&bodies); + let graph = CallGraph::analyze_in(bodies_slice, &heap); + let mut search = TriColorDepthFirstSearch::new_in(&graph, &heap); + let mut cycle_found = false; + + let all_members: Vec = (0..bodies.len()).map(DefId::from_usize).collect(); + let mut detector = RemainderCycleDetector { + members: &all_members, + breakers: &breakers, + }; + + search.reset(); + for &member in &all_members { + if !breakers.contains(member) && search.run_from(member, &mut detector).is_break() { + cycle_found = true; + break; + } + } + + assert!( + !cycle_found, + "remainder after breaker selection must be acyclic" + ); +} + +/// Ordering test with 3 non-breakers forming a chain. +/// +/// SCC: {breaker, X, Y, W} where breaker removal leaves X -> Y -> W. +/// Postorder must satisfy: W before Y, Y before X. +#[test] +fn loop_breaker_ordering_chain() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let breaker_id = DefId::new(0); + let x_id = DefId::new(1); + let y_id = DefId::new(2); + let w_id = DefId::new(3); + + // breaker: expensive, calls X, completes the cycle from W + let breaker_fn = body!(interner, env; fn@breaker_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, t3: Int, t4: Int, t5: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + t1 = bin.+ n 1; + t2 = bin.+ t1 2; + t3 = bin.- t2 3; + t4 = bin.+ t3 4; + t5 = apply (x_id), t4; + goto bb3(t5); + }, + bb3(result) { return result; } + }); + + // X: calls Y + let x = body!(interner, env; fn@x_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (y_id), n; + return result; + } + }); + + // Y: calls W + let y = body!(interner, env; fn@y_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (w_id), n; + return result; + } + }); + + // W: calls breaker (closing the cycle) + let w = body!(interner, env; fn@w_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (breaker_id), n; + return result; + } + }); + + let bodies = [breaker_fn, x, y, w]; + let (breakers, scc_orders) = run_loop_breaker(&bodies, &heap); + + assert_eq!(breakers.count(), 1); + assert!(breakers.contains(breaker_id)); + + // There should be exactly one non-trivial SCC. + assert_eq!(scc_orders.len(), 1); + let order = &scc_orders[0]; + assert_eq!(order.len(), 4); + + // Non-breakers in postorder: W before Y before X. + let pos = |id: DefId| { + order + .iter() + .position(|&node| node == id) + .expect("node exists inside of order") + }; + assert!( + pos(w_id) < pos(y_id), + "W (leaf) must come before Y in postorder" + ); + assert!(pos(y_id) < pos(x_id), "Y must come before X in postorder"); + // Breaker is last. + assert!( + pos(x_id) < pos(breaker_id), + "all non-breakers must come before the breaker" + ); +} + +/// All members have `Always` directive. The algorithm must still select a breaker +/// to break the cycle, even though all candidates score `-inf`. +#[test] +fn loop_breaker_all_always_directive() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + + // Both are constructors (Always directive) + let mut a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (b_id), n; + return result; + } + }); + a.source = Source::Ctor(hashql_core::symbol::sym::Some); + + let mut b = body!(interner, env; fn@b_id/1 -> Int { + decl n: Int, result: Int; + bb0() { + result = apply (a_id), n; + return result; + } + }); + b.source = Source::Ctor(hashql_core::symbol::sym::None); + + let bodies = [a, b]; + let (breakers, _) = run_loop_breaker(&bodies, &heap); + + // A 2-node cycle needs exactly 1 breaker, even when both score -inf. + assert_eq!( + breakers.count(), + 1, + "expected exactly 1 breaker for a 2-node cycle, got {}", + breakers.count() + ); + // Both are Always with equal cost, so either is a valid choice. + assert!( + breakers.contains(a_id) || breakers.contains(b_id), + "the selected breaker must be one of the SCC members" + ); +} + +/// A filter function that calls into a mutually recursive SCC. +/// +/// The aggressive phase should inline non-breaker B into the filter, but the +/// breaker A (visible after B's inlining) must not be expanded. Without the +/// unconditional breaker check in `FindCallsiteVisitor`, the aggressive phase +/// would re-expand A on each iteration until the cutoff. +#[test] +fn loop_breaker_aggressive_filter_with_recursive_scc() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let a_id = DefId::new(0); + let b_id = DefId::new(1); + let filter_id = DefId::new(2); + + // A: expensive, calls B (will be selected as breaker) + let a = body!(interner, env; fn@a_id/1 -> Int { + decl n: Int, cond: Bool, t1: Int, t2: Int, t3: Int, result: Int; + bb0() { + cond = bin.== n 0; + if cond then bb1() else bb2(); + }, + bb1() { goto bb3(n); }, + bb2() { + t1 = bin.+ n 1; + t2 = bin.+ t1 2; + t3 = apply (b_id), t2; + goto bb3(t3); + }, + bb3(result) { return result; } + }); + + // B: cheap, calls A + let b = body!(interner, env; fn@b_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (a_id), x; + return result; + } + }); + + // Filter: calls B + let filter = body!(interner, env; [graph::read::filter]@filter_id/1 -> Int { + decl x: Int, result: Int; + bb0() { + result = apply (b_id), x; + return result; + } + }); + + let mut bodies = [a, b, filter]; + + assert_inline_pass( + "loop_breaker_aggressive_filter", + &mut bodies, + &mut MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }, + InlineConfig::default(), + ); +} diff --git a/libs/@local/hashql/mir/src/pass/transform/mod.rs b/libs/@local/hashql/mir/src/pass/transform/mod.rs index c6819a45d9f..ddda2fa2369 100644 --- a/libs/@local/hashql/mir/src/pass/transform/mod.rs +++ b/libs/@local/hashql/mir/src/pass/transform/mod.rs @@ -22,7 +22,10 @@ pub use self::{ dle::DeadLocalElimination, dse::DeadStoreElimination, forward_substitution::ForwardSubstitution, - inline::{Inline, InlineConfig, InlineCostEstimationConfig, InlineHeuristicsConfig}, + inline::{ + Inline, InlineConfig, InlineCostEstimationConfig, InlineHeuristicsConfig, + InlineLoopBreakerConfig, + }, inst_simplify::InstSimplify, post_inline::PostInline, pre_inline::PreInline, diff --git a/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_aggressive_filter.snap b/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_aggressive_filter.snap new file mode 100644 index 00000000000..e219ef754d7 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_aggressive_filter.snap @@ -0,0 +1,130 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/inline/tests.rs +assertion_line: 136 +expression: output +--- +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Boolean + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + + bb0(): { + %1 = %0 == 0 + + switchInt(%1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + %2 = %0 + 1 + %3 = %2 + 2 + %4 = apply ({def@1} as FnPtr) %3 + + goto -> bb3(%4) + } + + bb3(%5): { + return %5 + } +} + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@0} as FnPtr) %0 + + return %1 + } +} + +fn {graph::read::filter@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@1} as FnPtr) %0 + + return %1 + } +} + +================= After Inlining ================= + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Boolean + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + let %6: Integer + let %7: Integer + + bb0(): { + %1 = %0 == 0 + + switchInt(%1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + %2 = %0 + 1 + %3 = %2 + 2 + %6 = %3 + + goto -> bb5() + } + + bb3(%5): { + return %5 + } + + bb4(%4): { + goto -> bb3(%4) + } + + bb5(): { + %7 = apply ({def@0} as FnPtr) %6 + + goto -> bb4(%7) + } +} + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@0} as FnPtr) %0 + + return %1 + } +} + +fn {graph::read::filter@4294967040}(%0: Integer) -> Integer { + let %1: Integer + let %2: Integer + let %3: Integer + + bb0(): { + %2 = %0 + + goto -> bb2() + } + + bb1(%1): { + return %1 + } + + bb2(): { + %3 = apply ({def@0} as FnPtr) %2 + + goto -> bb1(%3) + } +} diff --git a/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_mutual_recursion.snap b/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_mutual_recursion.snap new file mode 100644 index 00000000000..635712c8756 --- /dev/null +++ b/libs/@local/hashql/mir/tests/ui/pass/inline/loop_breaker_mutual_recursion.snap @@ -0,0 +1,130 @@ +--- +source: libs/@local/hashql/mir/src/pass/transform/inline/tests.rs +assertion_line: 136 +expression: output +--- +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Boolean + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + + bb0(): { + %1 = %0 == 0 + + switchInt(%1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + %2 = %0 + 1 + %3 = %2 + 2 + %4 = apply ({def@1} as FnPtr) %3 + + goto -> bb3(%4) + } + + bb3(%5): { + return %5 + } +} + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@0} as FnPtr) %0 + + return %1 + } +} + +thunk {thunk@4294967040}() -> Integer { + let %0: Integer + + bb0(): { + %0 = apply ({def@1} as FnPtr) 1 + + return %0 + } +} + +================= After Inlining ================= + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Boolean + let %2: Integer + let %3: Integer + let %4: Integer + let %5: Integer + let %6: Integer + let %7: Integer + + bb0(): { + %1 = %0 == 0 + + switchInt(%1) -> [0: bb2(), 1: bb1()] + } + + bb1(): { + goto -> bb3(%0) + } + + bb2(): { + %2 = %0 + 1 + %3 = %2 + 2 + %6 = %3 + + goto -> bb5() + } + + bb3(%5): { + return %5 + } + + bb4(%4): { + goto -> bb3(%4) + } + + bb5(): { + %7 = apply ({def@0} as FnPtr) %6 + + goto -> bb4(%7) + } +} + +fn {closure@4294967040}(%0: Integer) -> Integer { + let %1: Integer + + bb0(): { + %1 = apply ({def@0} as FnPtr) %0 + + return %1 + } +} + +thunk {thunk@4294967040}() -> Integer { + let %0: Integer + let %1: Integer + let %2: Integer + + bb0(): { + %1 = 1 + + goto -> bb2() + } + + bb1(%0): { + return %0 + } + + bb2(): { + %2 = apply ({def@0} as FnPtr) %1 + + goto -> bb1(%2) + } +}