Skip to content

Commit 2947005

Browse files
authored
Merge pull request #18 from tscircuit/optimization
Set default maxHotRegions to 2, minor snapshotting to sectionSolver to improve speed
2 parents bb0c967 + 4d537b7 commit 2947005

8 files changed

Lines changed: 455 additions & 18 deletions
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import type { SerializedHyperGraph } from "@tscircuit/hypergraph"
2+
3+
export interface SerializedHyperGraphPortPointPathingSolverParams {
4+
format: "serialized-hg-port-point-pathing-solver-params"
5+
graph: {
6+
regions: SerializedHyperGraph["regions"]
7+
ports: SerializedHyperGraph["ports"]
8+
}
9+
connections: NonNullable<SerializedHyperGraph["connections"]>
10+
}
11+
12+
export type SerializedHyperGraphPortPointPathingSolverInput =
13+
| SerializedHyperGraphPortPointPathingSolverParams
14+
| SerializedHyperGraphPortPointPathingSolverParams[]
15+
16+
const isSerializedHyperGraphPortPointPathingSolverParams = (
17+
value: unknown,
18+
): value is SerializedHyperGraphPortPointPathingSolverParams =>
19+
typeof value === "object" &&
20+
value !== null &&
21+
(value as { format?: unknown }).format ===
22+
"serialized-hg-port-point-pathing-solver-params" &&
23+
Array.isArray((value as { graph?: { regions?: unknown } }).graph?.regions) &&
24+
Array.isArray((value as { graph?: { ports?: unknown } }).graph?.ports) &&
25+
Array.isArray((value as { connections?: unknown }).connections)
26+
27+
const getSinglePortPointPathingSolverParams = (
28+
input: SerializedHyperGraphPortPointPathingSolverInput,
29+
) => {
30+
if (!Array.isArray(input)) {
31+
return input
32+
}
33+
34+
const params = input[0]
35+
if (!params) {
36+
throw new Error(
37+
"Port point pathing solver input array must contain at least one item",
38+
)
39+
}
40+
41+
return params
42+
}
43+
44+
export const convertPortPointPathingSolverInputToSerializedHyperGraph = (
45+
input: SerializedHyperGraphPortPointPathingSolverInput,
46+
): SerializedHyperGraph => {
47+
const params = getSinglePortPointPathingSolverParams(input)
48+
49+
if (!isSerializedHyperGraphPortPointPathingSolverParams(params)) {
50+
throw new Error(
51+
"Expected serialized-hg-port-point-pathing-solver-params input",
52+
)
53+
}
54+
55+
return {
56+
regions: params.graph.regions,
57+
ports: params.graph.ports,
58+
connections: params.connections,
59+
}
60+
}

lib/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
export * from "./core"
2+
export { convertPortPointPathingSolverInputToSerializedHyperGraph } from "./compat/convertPortPointPathingSolverInputToSerializedHyperGraph"
23
export {
34
TinyHyperGraphSectionSolver,
45
type TinyHyperGraphSectionSolverOptions,

lib/section-solver/TinyHyperGraphSectionPipelineSolver.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ const DEFAULT_CANDIDATE_FAMILIES: TinyHyperGraphSectionCandidateFamily[] = [
7878
"twohop-all",
7979
"twohop-touch",
8080
]
81+
const DEFAULT_MAX_HOT_REGIONS = 2
8182

8283
const IMPROVEMENT_EPSILON = 1e-9
8384

@@ -257,11 +258,15 @@ const findBestAutomaticSectionMask = (
257258
let candidateSolveMs = 0
258259
let candidateReplayScoreMs = 0
259260
const seenPortSectionMasks = new Set<string>()
261+
const maxHotRegions =
262+
searchConfig?.maxHotRegions ??
263+
sectionSolverOptions.MAX_HOT_REGIONS ??
264+
DEFAULT_MAX_HOT_REGIONS
260265

261266
for (const candidate of getSectionMaskCandidates(
262267
solvedSolver,
263268
topology,
264-
searchConfig?.maxHotRegions ?? 9,
269+
maxHotRegions,
265270
searchConfig?.candidateFamilies ?? DEFAULT_CANDIDATE_FAMILIES,
266271
)) {
267272
const candidateProblem = createProblemWithPortSectionMask(

lib/section-solver/index.ts

Lines changed: 125 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ export interface TinyHyperGraphSectionSolverOptions
4343
MAX_RIPS?: number
4444
MAX_RIPS_WITHOUT_MAX_REGION_COST_IMPROVEMENT?: number
4545
EXTRA_RIPS_AFTER_BEATING_BASELINE_MAX_REGION_COST?: number
46+
/**
47+
* Pipeline convenience option for automatic section-mask search.
48+
* When `sectionSearchConfig.maxHotRegions` is omitted, the section pipeline
49+
* falls back to this value before using its built-in default.
50+
*/
51+
MAX_HOT_REGIONS?: number
4652
}
4753

4854
const applyTinyHyperGraphSectionSolverOptions = (
@@ -118,6 +124,16 @@ const cloneSolvedStateSnapshot = (
118124
),
119125
})
120126

127+
const restoreSolvedStateSnapshot = (
128+
solver: TinyHyperGraphSolver,
129+
snapshot: SolvedStateSnapshot,
130+
) => {
131+
const clonedSnapshot = cloneSolvedStateSnapshot(snapshot)
132+
solver.state.portAssignment = clonedSnapshot.portAssignment
133+
solver.state.regionSegments = clonedSnapshot.regionSegments
134+
solver.state.regionIntersectionCaches = clonedSnapshot.regionIntersectionCaches
135+
}
136+
121137
const summarizeRegionIntersectionCaches = (
122138
regionIntersectionCaches: ArrayLike<RegionIntersectionCache>,
123139
): RegionCostSummary => {
@@ -161,6 +177,35 @@ const summarizeRegionIntersectionCachesForRegionIds = (
161177
}
162178
}
163179

180+
const summarizeRegionIntersectionCachesExcludingRegionIds = (
181+
regionIntersectionCaches: ArrayLike<RegionIntersectionCache>,
182+
excludedRegionIds: RegionId[],
183+
): RegionCostSummary => {
184+
const excludedRegionIdSet = new Set(excludedRegionIds)
185+
let maxRegionCost = 0
186+
let totalRegionCost = 0
187+
188+
for (
189+
let regionId = 0;
190+
regionId < regionIntersectionCaches.length;
191+
regionId++
192+
) {
193+
if (excludedRegionIdSet.has(regionId)) {
194+
continue
195+
}
196+
197+
const regionCost =
198+
regionIntersectionCaches[regionId]?.existingRegionCost ?? 0
199+
maxRegionCost = Math.max(maxRegionCost, regionCost)
200+
totalRegionCost += regionCost
201+
}
202+
203+
return {
204+
maxRegionCost,
205+
totalRegionCost,
206+
}
207+
}
208+
164209
const compareRegionCostSummaries = (
165210
left: RegionCostSummary,
166211
right: RegionCostSummary,
@@ -538,6 +583,7 @@ const getSectionRegionIds = (
538583

539584
class TinyHyperGraphSectionSearchSolver extends TinyHyperGraphSolver {
540585
bestSnapshot?: SolvedStateSnapshot
586+
fixedSnapshot?: SolvedStateSnapshot
541587
bestSummary?: RegionCostSummary
542588
baselineBeatRipCount?: number
543589
previousBestMaxRegionCost = Number.POSITIVE_INFINITY
@@ -552,13 +598,20 @@ class TinyHyperGraphSectionSearchSolver extends TinyHyperGraphSolver {
552598
problem: TinyHyperGraphProblem,
553599
private routePlans: SectionRoutePlan[],
554600
private activeRouteIds: RouteId[],
601+
private mutableRegionIds: RegionId[],
602+
private immutableRegionSummary: RegionCostSummary,
555603
private baselineSummary: RegionCostSummary,
556604
options?: TinyHyperGraphSectionSolverOptions,
557605
) {
558606
super(topology, problem, options)
559607
applyTinyHyperGraphSectionSolverOptions(this, options)
560608
this.state.unroutedRoutes = [...activeRouteIds]
561609
this.applyFixedSegments()
610+
this.fixedSnapshot = cloneSolvedStateSnapshot({
611+
portAssignment: this.state.portAssignment,
612+
regionSegments: this.state.regionSegments,
613+
regionIntersectionCaches: this.state.regionIntersectionCaches,
614+
})
562615
}
563616

564617
applyFixedSegments() {
@@ -601,10 +654,7 @@ class TinyHyperGraphSectionSearchSolver extends TinyHyperGraphSolver {
601654
return
602655
}
603656

604-
const snapshot = cloneSolvedStateSnapshot(this.bestSnapshot)
605-
this.state.portAssignment = snapshot.portAssignment
606-
this.state.regionSegments = snapshot.regionSegments
607-
this.state.regionIntersectionCaches = snapshot.regionIntersectionCaches
657+
restoreSolvedStateSnapshot(this, this.bestSnapshot)
608658
this.state.currentRouteId = undefined
609659
this.state.currentRouteNetId = undefined
610660
this.state.unroutedRoutes = []
@@ -626,16 +676,30 @@ class TinyHyperGraphSectionSearchSolver extends TinyHyperGraphSolver {
626676
}
627677

628678
override resetRoutingStateForRerip() {
629-
super.resetRoutingStateForRerip()
679+
if (!this.fixedSnapshot) {
680+
super.resetRoutingStateForRerip()
681+
this.state.unroutedRoutes = shuffle(
682+
[...this.activeRouteIds],
683+
this.state.ripCount,
684+
)
685+
this.applyFixedSegments()
686+
return
687+
}
688+
689+
restoreSolvedStateSnapshot(this, this.fixedSnapshot)
690+
this.state.currentRouteId = undefined
691+
this.state.currentRouteNetId = undefined
630692
this.state.unroutedRoutes = shuffle(
631693
[...this.activeRouteIds],
632694
this.state.ripCount,
633695
)
634-
this.applyFixedSegments()
696+
this.state.candidateQueue.clear()
697+
this.resetCandidateBestCosts()
698+
this.state.goalPortId = -1
635699
}
636700

637701
override onAllRoutesRouted() {
638-
const { topology, state } = this
702+
const { state } = this
639703
const maxRips = Math.min(this.MAX_RIPS, this.RIP_THRESHOLD_RAMP_ATTEMPTS)
640704
const ripThresholdProgress =
641705
maxRips <= 0 ? 1 : Math.min(1, state.ripCount / maxRips)
@@ -644,22 +708,34 @@ class TinyHyperGraphSectionSearchSolver extends TinyHyperGraphSolver {
644708
(this.RIP_THRESHOLD_END - this.RIP_THRESHOLD_START) * ripThresholdProgress
645709

646710
const regionIdsOverCostThreshold: RegionId[] = []
647-
const regionCosts = new Float64Array(topology.regionCount)
648-
let maxRegionCost = 0
649-
let totalRegionCost = 0
711+
const mutableRegionCosts = new Float64Array(this.mutableRegionIds.length)
712+
let mutableMaxRegionCost = 0
713+
let mutableTotalRegionCost = 0
650714

651-
for (let regionId = 0; regionId < topology.regionCount; regionId++) {
715+
for (
716+
let mutableRegionIndex = 0;
717+
mutableRegionIndex < this.mutableRegionIds.length;
718+
mutableRegionIndex++
719+
) {
720+
const regionId = this.mutableRegionIds[mutableRegionIndex]!
652721
const regionCost =
653722
state.regionIntersectionCaches[regionId]?.existingRegionCost ?? 0
654-
regionCosts[regionId] = regionCost
655-
maxRegionCost = Math.max(maxRegionCost, regionCost)
656-
totalRegionCost += regionCost
723+
mutableRegionCosts[mutableRegionIndex] = regionCost
724+
mutableMaxRegionCost = Math.max(mutableMaxRegionCost, regionCost)
725+
mutableTotalRegionCost += regionCost
657726

658727
if (regionCost > currentRipThreshold) {
659728
regionIdsOverCostThreshold.push(regionId)
660729
}
661730
}
662731

732+
const maxRegionCost = Math.max(
733+
this.immutableRegionSummary.maxRegionCost,
734+
mutableMaxRegionCost,
735+
)
736+
const totalRegionCost =
737+
this.immutableRegionSummary.totalRegionCost + mutableTotalRegionCost
738+
663739
this.captureBestState({
664740
maxRegionCost,
665741
totalRegionCost,
@@ -713,9 +789,15 @@ class TinyHyperGraphSectionSearchSolver extends TinyHyperGraphSolver {
713789
return
714790
}
715791

716-
for (let regionId = 0; regionId < topology.regionCount; regionId++) {
792+
for (
793+
let mutableRegionIndex = 0;
794+
mutableRegionIndex < this.mutableRegionIds.length;
795+
mutableRegionIndex++
796+
) {
797+
const regionId = this.mutableRegionIds[mutableRegionIndex]!
717798
state.regionCongestionCost[regionId] +=
718-
regionCosts[regionId] * this.RIP_CONGESTION_REGION_COST_FACTOR
799+
mutableRegionCosts[mutableRegionIndex]! *
800+
this.RIP_CONGESTION_REGION_COST_FACTOR
719801
}
720802

721803
state.ripCount += 1
@@ -727,6 +809,25 @@ class TinyHyperGraphSectionSearchSolver extends TinyHyperGraphSolver {
727809
}
728810
}
729811

812+
override onOutOfCandidates() {
813+
const { state } = this
814+
815+
for (const regionId of this.mutableRegionIds) {
816+
const regionCost =
817+
state.regionIntersectionCaches[regionId]?.existingRegionCost ?? 0
818+
state.regionCongestionCost[regionId] +=
819+
regionCost * this.RIP_CONGESTION_REGION_COST_FACTOR
820+
}
821+
822+
state.ripCount += 1
823+
this.resetRoutingStateForRerip()
824+
this.stats = {
825+
...this.stats,
826+
ripCount: state.ripCount,
827+
reripReason: "out_of_candidates",
828+
}
829+
}
830+
730831
override tryFinalAcceptance() {
731832
if (!this.bestSnapshot) {
732833
return
@@ -750,6 +851,7 @@ export class TinyHyperGraphSectionSolver extends BaseSolver {
750851
baselineSolver: TinyHyperGraphSolver
751852
baselineSummary: RegionCostSummary
752853
sectionBaselineSummary: RegionCostSummary
854+
outsideSectionBaselineSummary: RegionCostSummary
753855
sectionRegionIds: RegionId[]
754856
optimizedSolver?: TinyHyperGraphSolver
755857
sectionSolver?: TinyHyperGraphSectionSearchSolver
@@ -790,6 +892,11 @@ export class TinyHyperGraphSectionSolver extends BaseSolver {
790892
this.baselineSolver.state.regionIntersectionCaches,
791893
this.sectionRegionIds,
792894
)
895+
this.outsideSectionBaselineSummary =
896+
summarizeRegionIntersectionCachesExcludingRegionIds(
897+
this.baselineSolver.state.regionIntersectionCaches,
898+
this.sectionRegionIds,
899+
)
793900
this.applySectionRipPolicy()
794901
}
795902

@@ -828,6 +935,8 @@ export class TinyHyperGraphSectionSolver extends BaseSolver {
828935
sectionProblem,
829936
routePlans,
830937
activeRouteIds,
938+
this.sectionRegionIds,
939+
this.outsideSectionBaselineSummary,
831940
this.baselineSummary,
832941
getTinyHyperGraphSectionSolverOptions(this),
833942
)

package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
"typecheck": "tsc -p tsconfig.json --pretty false",
88
"benchmark1": "bun run --cpu-prof-md scripts/profiling/hg07-first10.ts",
99
"benchmark:section": "bun run scripts/benchmarking/hg07-section-pipeline.ts",
10-
"benchmark:section:profile": "bun run scripts/benchmarking/hg07-first40-section-profile.ts"
10+
"benchmark:section:profile": "bun run scripts/benchmarking/hg07-first40-section-profile.ts",
11+
"benchmark:port-point-pathing": "bun run scripts/benchmarking/port-point-pathing-section-pipeline.ts"
1112
},
1213
"devDependencies": {
1314
"@biomejs/biome": "^2.4.8",

0 commit comments

Comments
 (0)