diff --git a/lib/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.ts b/lib/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.ts new file mode 100644 index 000000000..869c126d2 --- /dev/null +++ b/lib/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.ts @@ -0,0 +1,207 @@ +import { BaseSolver } from "lib/solvers/BaseSolver/BaseSolver" +import type { SolvedTracePath } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceLinesSolver" +import type { InputProblem } from "lib/types/InputProblem" +import type { GraphicsObject } from "graphics-debug" +import { visualizeInputProblem } from "lib/solvers/SchematicTracePipelineSolver/visualizeInputProblem" +import type { Point } from "@tscircuit/math-utils" + +const SNAP_THRESHOLD = 0.15 + +interface Segment { + traceIndex: number + segIndex: number + p1: Point + p2: Point + isHorizontal: boolean +} + +/** + * Combines same-net trace segments that run close together (parallel and + * nearly coincident) by snapping them onto a shared coordinate. + * + * Placed after TraceOverlapShiftSolver to clean up same-net redundancy + * before net-label placement. + */ +export class SameNetTraceCombiningSolver extends BaseSolver { + inputProblem: InputProblem + inputTraces: SolvedTracePath[] + outputTraces: SolvedTracePath[] + + private processed = false + + constructor(params: { + inputProblem: InputProblem + traces: SolvedTracePath[] + }) { + super() + this.inputProblem = params.inputProblem + this.inputTraces = params.traces + this.outputTraces = params.traces.map((t) => ({ + ...t, + tracePath: t.tracePath.map((p) => ({ ...p })), + })) + } + + override getConstructorParams(): ConstructorParameters< + typeof SameNetTraceCombiningSolver + >[0] { + return { + inputProblem: this.inputProblem, + traces: this.inputTraces, + } + } + + override _step() { + if (this.processed) { + this.solved = true + return + } + this.processed = true + this.combineTraces() + this.solved = true + } + + private combineTraces() { + const netGroups = new Map() + for (let i = 0; i < this.outputTraces.length; i++) { + const netId = this.outputTraces[i]!.globalConnNetId + if (!netGroups.has(netId)) netGroups.set(netId, []) + netGroups.get(netId)!.push(i) + } + + for (const indices of netGroups.values()) { + if (indices.length < 2) continue + this.combineGroup(indices) + } + } + + private combineGroup(traceIndices: number[]) { + let changed = true + let iterations = 0 + const maxIter = 20 + + while (changed && iterations < maxIter) { + changed = false + iterations++ + + const segments = this.collectSegments(traceIndices) + + const horizontals = segments.filter((s) => s.isHorizontal) + const verticals = segments.filter((s) => !s.isHorizontal) + + if (this.snapParallelSegments(horizontals, true)) changed = true + if (this.snapParallelSegments(verticals, false)) changed = true + } + } + + private collectSegments(traceIndices: number[]): Segment[] { + const EPS = 1e-6 + const segments: Segment[] = [] + + for (const ti of traceIndices) { + const path = this.outputTraces[ti]!.tracePath + for (let si = 0; si < path.length - 1; si++) { + const p1 = path[si]! + const p2 = path[si + 1]! + const isHorizontal = Math.abs(p1.y - p2.y) < EPS + const isVertical = Math.abs(p1.x - p2.x) < EPS + if (!isHorizontal && !isVertical) continue + segments.push({ traceIndex: ti, segIndex: si, p1, p2, isHorizontal }) + } + } + return segments + } + + private snapParallelSegments( + segments: Segment[], + horizontal: boolean, + ): boolean { + let changed = false + + for (let i = 0; i < segments.length; i++) { + for (let j = i + 1; j < segments.length; j++) { + const a = segments[i]! + const b = segments[j]! + if (a.traceIndex === b.traceIndex) continue + + if (horizontal) { + const dy = Math.abs(a.p1.y - b.p1.y) + if (dy < 1e-6 || dy > SNAP_THRESHOLD) continue + if (!this.rangesOverlap(a.p1.x, a.p2.x, b.p1.x, b.p2.x)) continue + + const avgY = (a.p1.y + b.p1.y) / 2 + this.shiftSegmentY(a, avgY) + this.shiftSegmentY(b, avgY) + changed = true + } else { + const dx = Math.abs(a.p1.x - b.p1.x) + if (dx < 1e-6 || dx > SNAP_THRESHOLD) continue + if (!this.rangesOverlap(a.p1.y, a.p2.y, b.p1.y, b.p2.y)) continue + + const avgX = (a.p1.x + b.p1.x) / 2 + this.shiftSegmentX(a, avgX) + this.shiftSegmentX(b, avgX) + changed = true + } + } + } + + return changed + } + + private rangesOverlap( + a1: number, + a2: number, + b1: number, + b2: number, + ): boolean { + const minA = Math.min(a1, a2) + const maxA = Math.max(a1, a2) + const minB = Math.min(b1, b2) + const maxB = Math.max(b1, b2) + return Math.min(maxA, maxB) - Math.max(minA, minB) > 1e-6 + } + + private shiftSegmentY(seg: Segment, newY: number) { + const path = this.outputTraces[seg.traceIndex]!.tracePath + const isFirstSeg = seg.segIndex === 0 + const isLastSeg = seg.segIndex === path.length - 2 + + if (!isFirstSeg) { + path[seg.segIndex]!.y = newY + } + if (!isLastSeg) { + path[seg.segIndex + 1]!.y = newY + } + } + + private shiftSegmentX(seg: Segment, newX: number) { + const path = this.outputTraces[seg.traceIndex]!.tracePath + const isFirstSeg = seg.segIndex === 0 + const isLastSeg = seg.segIndex === path.length - 2 + + if (!isFirstSeg) { + path[seg.segIndex]!.x = newX + } + if (!isLastSeg) { + path[seg.segIndex + 1]!.x = newX + } + } + + getOutput(): { traces: SolvedTracePath[] } { + return { traces: this.outputTraces } + } + + override visualize(): GraphicsObject { + const lines = this.outputTraces.map((trace) => ({ + points: trace.tracePath.map((p) => ({ x: p.x, y: p.y })), + strokeColor: "blue", + })) + + const base = visualizeInputProblem(this.inputProblem) + return { + ...base, + lines: [...(base.lines ?? []), ...lines], + } + } +} diff --git a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts index 59821f0c1..5dfab4fce 100644 --- a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts +++ b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts @@ -26,6 +26,7 @@ import { AvailableNetOrientationSolver } from "../AvailableNetOrientationSolver/ import { VccNetLabelCornerPlacementSolver } from "../VccNetLabelCornerPlacementSolver/VccNetLabelCornerPlacementSolver" import { TraceAnchoredNetLabelOverlapSolver } from "../TraceAnchoredNetLabelOverlapSolver/TraceAnchoredNetLabelOverlapSolver" import { NetLabelTraceCollisionSolver } from "../NetLabelTraceCollisionSolver/NetLabelTraceCollisionSolver" +import { SameNetTraceCombiningSolver } from "../SameNetTraceCombiningSolver/SameNetTraceCombiningSolver" type PipelineStep BaseSolver> = { solverName: string @@ -79,6 +80,7 @@ export class SchematicTracePipelineSolver extends BaseSolver { availableNetOrientationSolver?: AvailableNetOrientationSolver vccNetLabelCornerPlacementSolver?: VccNetLabelCornerPlacementSolver traceAnchoredNetLabelOverlapSolver?: TraceAnchoredNetLabelOverlapSolver + sameNetTraceCombiningSolver?: SameNetTraceCombiningSolver netLabelTraceCollisionSolver?: NetLabelTraceCollisionSolver startTimeOfPhase: Record @@ -154,19 +156,29 @@ export class SchematicTracePipelineSolver extends BaseSolver { onSolved: (_solver) => {}, }, ), + definePipelineStep( + "sameNetTraceCombiningSolver", + SameNetTraceCombiningSolver, + (instance) => [ + { + inputProblem: instance.inputProblem, + traces: instance.traceOverlapShiftSolver?.correctedTraceMap + ? Object.values(instance.traceOverlapShiftSolver.correctedTraceMap) + : instance.longDistancePairSolver!.getOutput().allTracesMerged, + }, + ], + ), definePipelineStep( "netLabelPlacementSolver", NetLabelPlacementSolver, - () => [ + (instance) => [ { - inputProblem: this.inputProblem, - inputTraceMap: - this.traceOverlapShiftSolver?.correctedTraceMap ?? - Object.fromEntries( - this.longDistancePairSolver!.getOutput().allTracesMerged.map( - (p) => [p.mspPairId, p], - ), - ), + inputProblem: instance.inputProblem, + inputTraceMap: Object.fromEntries( + instance + .sameNetTraceCombiningSolver!.getOutput() + .traces.map((p) => [p.mspPairId, p]), + ), }, ], { @@ -179,14 +191,11 @@ export class SchematicTracePipelineSolver extends BaseSolver { "traceLabelOverlapAvoidanceSolver", TraceLabelOverlapAvoidanceSolver, (instance) => { - const traceMap = - instance.traceOverlapShiftSolver?.correctedTraceMap ?? - Object.fromEntries( - instance - .longDistancePairSolver!.getOutput() - .allTracesMerged.map((p) => [p.mspPairId, p]), - ) - const traces = Object.values(traceMap) + const traces = instance.sameNetTraceCombiningSolver + ? instance.sameNetTraceCombiningSolver.getOutput().traces + : instance.traceOverlapShiftSolver?.correctedTraceMap + ? Object.values(instance.traceOverlapShiftSolver.correctedTraceMap) + : instance.longDistancePairSolver!.getOutput().allTracesMerged const netLabelPlacements = instance.netLabelPlacementSolver!.netLabelPlacements diff --git a/tests/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.test.ts b/tests/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.test.ts new file mode 100644 index 000000000..d059b7cb8 --- /dev/null +++ b/tests/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver.test.ts @@ -0,0 +1,186 @@ +import { test, expect } from "bun:test" +import { SameNetTraceCombiningSolver } from "lib/solvers/SameNetTraceCombiningSolver/SameNetTraceCombiningSolver" +import type { SolvedTracePath } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceLinesSolver" +import type { InputProblem } from "lib/types/InputProblem" + +const makeTrace = ( + id: string, + netId: string, + path: { x: number; y: number }[], +): SolvedTracePath => ({ + mspPairId: id, + dcConnNetId: netId, + globalConnNetId: netId, + pins: [ + { pinId: `${id}_p1`, x: path[0]!.x, y: path[0]!.y, chipId: "U1" }, + { + pinId: `${id}_p2`, + x: path[path.length - 1]!.x, + y: path[path.length - 1]!.y, + chipId: "U2", + }, + ], + tracePath: path, + mspConnectionPairIds: [id], + pinIds: [`${id}_p1`, `${id}_p2`], +}) + +const emptyInput: InputProblem = { + chips: [], + directConnections: [], + netConnections: [], + availableNetLabelOrientations: {}, +} + +test("snaps close horizontal same-net segments to average Y", () => { + const t1 = makeTrace("t1", "VCC", [ + { x: 0, y: 0 }, + { x: 2, y: 0 }, + { x: 2, y: 1 }, + ]) + const t2 = makeTrace("t2", "VCC", [ + { x: 0, y: 3 }, + { x: 1, y: 3 }, + { x: 1, y: 3.1 }, + { x: 3, y: 3.1 }, + ]) + + const solver = new SameNetTraceCombiningSolver({ + inputProblem: emptyInput, + traces: [t1, t2], + }) + solver.solve() + + const output = solver.getOutput() + expect(output.traces).toHaveLength(2) + + const t2Path = output.traces[1]!.tracePath + const seg2Y = t2Path[2]!.y + const seg3Y = t2Path[3]!.y + expect(seg2Y).toBeCloseTo(seg3Y, 5) +}) + +test("does not snap segments from different nets", () => { + const t1 = makeTrace("t1", "VCC", [ + { x: 0, y: 0 }, + { x: 2, y: 0 }, + { x: 2, y: 1.05 }, + ]) + const t2 = makeTrace("t2", "GND", [ + { x: 0, y: 3 }, + { x: 1, y: 3 }, + { x: 1, y: 1 }, + { x: 3, y: 1 }, + ]) + + const solver = new SameNetTraceCombiningSolver({ + inputProblem: emptyInput, + traces: [t1, t2], + }) + solver.solve() + + const output = solver.getOutput() + expect(output.traces[0]!.tracePath[1]!.y).toBe(0) + expect(output.traces[1]!.tracePath[2]!.y).toBe(1) +}) + +test("preserves terminal pin endpoints", () => { + const t1 = makeTrace("t1", "VCC", [ + { x: 0, y: 0 }, + { x: 2, y: 0 }, + ]) + const t2 = makeTrace("t2", "VCC", [ + { x: 0, y: 0.1 }, + { x: 2, y: 0.1 }, + ]) + + const solver = new SameNetTraceCombiningSolver({ + inputProblem: emptyInput, + traces: [t1, t2], + }) + solver.solve() + + const output = solver.getOutput() + expect(output.traces[0]!.tracePath[0]!.y).toBe(0) + expect(output.traces[1]!.tracePath[0]!.y).toBe(0.1) +}) + +test("handles single trace without errors", () => { + const t1 = makeTrace("t1", "VCC", [ + { x: 0, y: 0 }, + { x: 2, y: 0 }, + { x: 2, y: 1 }, + ]) + + const solver = new SameNetTraceCombiningSolver({ + inputProblem: emptyInput, + traces: [t1], + }) + solver.solve() + + const output = solver.getOutput() + expect(output.traces).toHaveLength(1) + expect(output.traces[0]!.tracePath).toEqual(t1.tracePath) +}) + +test("snaps close vertical same-net segments to average X", () => { + const t1 = makeTrace("t1", "NET1", [ + { x: 0, y: 0 }, + { x: 1, y: 0 }, + { x: 1, y: 3 }, + ]) + const t2 = makeTrace("t2", "NET1", [ + { x: 0, y: 5 }, + { x: 1.1, y: 5 }, + { x: 1.1, y: 2 }, + ]) + + const solver = new SameNetTraceCombiningSolver({ + inputProblem: emptyInput, + traces: [t1, t2], + }) + solver.solve() + + const output = solver.getOutput() + const t1X = output.traces[0]!.tracePath[1]!.x + const t2X = output.traces[1]!.tracePath[1]!.x + expect(t1X).toBeCloseTo(t2X, 5) + expect(t1X).toBeCloseTo(1.05, 5) +}) + +test("does not snap segments farther apart than threshold", () => { + const t1 = makeTrace("t1", "VCC", [ + { x: 0, y: 0 }, + { x: 2, y: 0 }, + { x: 2, y: 1 }, + ]) + const t2 = makeTrace("t2", "VCC", [ + { x: 0, y: 5 }, + { x: 2, y: 5 }, + { x: 2, y: 1.5 }, + ]) + + const solver = new SameNetTraceCombiningSolver({ + inputProblem: emptyInput, + traces: [t1, t2], + }) + solver.solve() + + const output = solver.getOutput() + expect(output.traces[0]!.tracePath[1]!.y).toBe(0) + expect(output.traces[1]!.tracePath[1]!.y).toBe(5) +}) + +test("integrates into full pipeline without breaking existing tests", () => { + const inputProblem = require("../../assets/example01.json") + const { + SchematicTracePipelineSolver, + } = require("lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver") + + const solver = new SchematicTracePipelineSolver(inputProblem) + solver.solve() + + expect(solver.solved).toBe(true) + expect(solver.failed).toBe(false) + expect(solver.sameNetTraceCombiningSolver).toBeDefined() +})