diff --git a/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts b/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts index fdc3ce5..283dab0 100644 --- a/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts +++ b/lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts @@ -54,10 +54,27 @@ type AutomaticSectionSearchResult = { candidateInitMs: number candidateSolveMs: number candidateReplayScoreMs: number - winningCandidateLabel?: string - winningCandidateFamily?: TinyHyperGraphSectionCandidateFamily + winningCandidateLabels: string[] + winningCandidateFamilies: TinyHyperGraphSectionCandidateFamily[] } +type CandidateEvaluationResult = { + label: string + family: TinyHyperGraphSectionCandidateFamily + regionIds: RegionId[] + portSectionMask: Int8Array + finalMaxRegionCost: number + improvement: number + eligibilityMs: number + initMs: number + solveMs: number + replayScoreMs: number +} + +const isCandidateEvaluationResult = ( + value: CandidateEvaluationResult | null, +): value is CandidateEvaluationResult => value !== null + const DEFAULT_SOLVE_GRAPH_OPTIONS: TinyHyperGraphSolverOptions = { RIP_THRESHOLD_RAMP_ATTEMPTS: 5, } @@ -81,6 +98,7 @@ const DEFAULT_CANDIDATE_FAMILIES: TinyHyperGraphSectionCandidateFamily[] = [ const DEFAULT_MAX_HOT_REGIONS = 2 const IMPROVEMENT_EPSILON = 1e-9 +const DEFAULT_PARALLEL_CANDIDATE_SOLVE_LIMIT = 4 const getMaxRegionCost = (solver: TinyHyperGraphSolver) => solver.state.regionIntersectionCaches.reduce( @@ -248,8 +266,8 @@ const findBestAutomaticSectionMask = ( let bestFinalMaxRegionCost = baselineMaxRegionCost let bestPortSectionMask = new Int8Array(topology.portCount) - let winningCandidateLabel: string | undefined - let winningCandidateFamily: TinyHyperGraphSectionCandidateFamily | undefined + const winningCandidateLabels: string[] = [] + const winningCandidateFamilies: TinyHyperGraphSectionCandidateFamily[] = [] let generatedCandidateCount = 0 let candidateCount = 0 let duplicateCandidateCount = 0 @@ -263,6 +281,11 @@ const findBestAutomaticSectionMask = ( sectionSolverOptions.MAX_HOT_REGIONS ?? DEFAULT_MAX_HOT_REGIONS + const candidatesToEvaluate: Array<{ + candidate: SectionMaskCandidate + candidateProblem: TinyHyperGraphProblem + }> = [] + for (const candidate of getSectionMaskCandidates( solvedSolver, topology, @@ -286,7 +309,16 @@ const findBestAutomaticSectionMask = ( } seenPortSectionMasks.add(portSectionMaskKey) + candidatesToEvaluate.push({ candidate, candidateProblem }) + } + const evaluateCandidate = ({ + candidate, + candidateProblem, + }: { + candidate: SectionMaskCandidate + candidateProblem: TinyHyperGraphProblem + }): CandidateEvaluationResult | null => { try { const eligibilityStartTime = performance.now() const activeRouteIds = getActiveSectionRouteIds( @@ -294,14 +326,13 @@ const findBestAutomaticSectionMask = ( candidateProblem, solution, ) - candidateEligibilityMs += performance.now() - eligibilityStartTime + const eligibilityMs = performance.now() - eligibilityStartTime if (activeRouteIds.length === 0) { - continue + candidateEligibilityMs += eligibilityMs + return null } - candidateCount += 1 - const candidateInitStartTime = performance.now() const sectionSolver = new TinyHyperGraphSectionSolver( topology, @@ -309,41 +340,121 @@ const findBestAutomaticSectionMask = ( solution, sectionSolverOptions, ) - candidateInitMs += performance.now() - candidateInitStartTime + const initMs = performance.now() - candidateInitStartTime const candidateSolveStartTime = performance.now() sectionSolver.solve() - candidateSolveMs += performance.now() - candidateSolveStartTime + const solveMs = performance.now() - candidateSolveStartTime if (sectionSolver.failed || !sectionSolver.solved) { - continue + candidateEligibilityMs += eligibilityMs + candidateInitMs += initMs + candidateSolveMs += solveMs + return null } - const finalMaxRegionCost = Number( - sectionSolver.stats.finalMaxRegionCost ?? - getMaxRegionCost(sectionSolver.getSolvedSolver()), + const candidateReplayScoreStartTime = performance.now() + const replayedFinalMaxRegionCost = getSerializedOutputMaxRegionCost( + sectionSolver.getOutput(), ) - - if (finalMaxRegionCost < bestFinalMaxRegionCost - IMPROVEMENT_EPSILON) { - const candidateReplayScoreStartTime = performance.now() - const replayedFinalMaxRegionCost = getSerializedOutputMaxRegionCost( - sectionSolver.getOutput(), - ) - candidateReplayScoreMs += - performance.now() - candidateReplayScoreStartTime - - if ( - replayedFinalMaxRegionCost < - bestFinalMaxRegionCost - IMPROVEMENT_EPSILON - ) { - bestFinalMaxRegionCost = replayedFinalMaxRegionCost - bestPortSectionMask = new Int8Array(candidateProblem.portSectionMask) - winningCandidateLabel = candidate.label - winningCandidateFamily = candidate.family - } + const replayScoreMs = performance.now() - candidateReplayScoreStartTime + + candidateEligibilityMs += eligibilityMs + candidateInitMs += initMs + candidateSolveMs += solveMs + candidateReplayScoreMs += replayScoreMs + + return { + label: candidate.label, + family: candidate.family, + regionIds: candidate.regionIds, + portSectionMask: new Int8Array(candidateProblem.portSectionMask), + finalMaxRegionCost: replayedFinalMaxRegionCost, + improvement: baselineMaxRegionCost - replayedFinalMaxRegionCost, + eligibilityMs, + initMs, + solveMs, + replayScoreMs, } } catch { // Skip invalid section masks that split a route into multiple spans. + return null + } + } + + const parallelCandidateSolveLimit = Math.max( + 1, + sectionSolverOptions.PARALLEL_CANDIDATE_SOLVE_LIMIT ?? + DEFAULT_PARALLEL_CANDIDATE_SOLVE_LIMIT, + ) + const evaluatedCandidates: CandidateEvaluationResult[] = [] + for (let i = 0; i < candidatesToEvaluate.length; i += parallelCandidateSolveLimit) { + const chunk = candidatesToEvaluate.slice(i, i + parallelCandidateSolveLimit) + const chunkResults = chunk + .map(evaluateCandidate) + .filter(isCandidateEvaluationResult) + evaluatedCandidates.push(...chunkResults) + } + + candidateCount = evaluatedCandidates.length + + const sortedByImprovement = [...evaluatedCandidates].sort((left, right) => { + if (left.improvement !== right.improvement) { + return right.improvement - left.improvement + } + return left.finalMaxRegionCost - right.finalMaxRegionCost + }) + + const chosenCandidates: CandidateEvaluationResult[] = [] + const chosenRegionIds = new Set() + for (const evaluatedCandidate of sortedByImprovement) { + if (evaluatedCandidate.improvement <= IMPROVEMENT_EPSILON) { + continue + } + const overlapsChosenRegion = evaluatedCandidate.regionIds.some((regionId) => + chosenRegionIds.has(regionId), + ) + if (overlapsChosenRegion) { + continue + } + chosenCandidates.push(evaluatedCandidate) + for (const regionId of evaluatedCandidate.regionIds) { + chosenRegionIds.add(regionId) + } + } + + if (chosenCandidates.length > 0) { + const mergedMask = new Int8Array(topology.portCount) + for (const candidate of chosenCandidates) { + for (let portId = 0; portId < topology.portCount; portId++) { + if (candidate.portSectionMask[portId] === 1) { + mergedMask[portId] = 1 + } + } + winningCandidateLabels.push(candidate.label) + winningCandidateFamilies.push(candidate.family) + } + + const mergedProblem = createProblemWithPortSectionMask(problem, mergedMask) + const mergedSectionSolver = new TinyHyperGraphSectionSolver( + topology, + mergedProblem, + solution, + sectionSolverOptions, + ) + mergedSectionSolver.solve() + + if (mergedSectionSolver.solved && !mergedSectionSolver.failed) { + const mergedReplayedFinalMaxRegionCost = getSerializedOutputMaxRegionCost( + mergedSectionSolver.getOutput(), + ) + if ( + mergedReplayedFinalMaxRegionCost < + bestFinalMaxRegionCost - IMPROVEMENT_EPSILON + ) { + bestFinalMaxRegionCost = mergedReplayedFinalMaxRegionCost + bestPortSectionMask = mergedMask + } } } @@ -360,8 +471,8 @@ const findBestAutomaticSectionMask = ( candidateInitMs, candidateSolveMs, candidateReplayScoreMs, - winningCandidateLabel, - winningCandidateFamily, + winningCandidateLabels, + winningCandidateFamilies, } } @@ -469,9 +580,9 @@ export class TinyHyperGraphSectionPipelineSolver extends BasePipelineSolver