diff --git a/lib/solvers/SameNetTraceConsolidationSolver/SameNetTraceConsolidationSolver.ts b/lib/solvers/SameNetTraceConsolidationSolver/SameNetTraceConsolidationSolver.ts new file mode 100644 index 000000000..3e4e75a05 --- /dev/null +++ b/lib/solvers/SameNetTraceConsolidationSolver/SameNetTraceConsolidationSolver.ts @@ -0,0 +1,475 @@ +import type { Point } from "@tscircuit/math-utils" +import { doSegmentsIntersect } from "@tscircuit/math-utils" +import type { GraphicsObject } from "graphics-debug" +import { BaseSolver } from "lib/solvers/BaseSolver/BaseSolver" +import type { MspConnectionPairId } from "lib/solvers/MspConnectionPairSolver/MspConnectionPairSolver" +import type { SolvedTracePath } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceLinesSolver" +import { + isHorizontal, + isVertical, + segmentIntersectsRect, +} from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceSingleLineSolver2/collisions" +import { getObstacleRects } from "lib/solvers/SchematicTraceLinesSolver/SchematicTraceSingleLineSolver2/rect" +import { simplifyPath } from "lib/solvers/TraceCleanupSolver/simplifyPath" +import { visualizeInputProblem } from "lib/solvers/SchematicTracePipelineSolver/visualizeInputProblem" +import type { InputProblem } from "lib/types/InputProblem" +import { getColorFromString } from "lib/utils/getColorFromString" + +type Axis = "horizontal" | "vertical" + +type SegmentRef = { + mspPairId: MspConnectionPairId + segmentIndex: number + axis: Axis + coord: number + min: number + max: number + length: number + stableKey: string +} + +export interface SameNetTraceConsolidationSolverInput { + inputProblem: InputProblem + inputTraces: SolvedTracePath[] + mergeDistance?: number + intervalGap?: number +} + +const DEFAULT_MERGE_DISTANCE = 0.12 +const DEFAULT_INTERVAL_GAP = 0.12 +const MAX_CONSOLIDATION_PASSES = 1000 +const EPS = 1e-6 + +const cloneTrace = (trace: SolvedTracePath): SolvedTracePath => ({ + ...trace, + pins: [{ ...trace.pins[0] }, { ...trace.pins[1] }], + pinIds: [...trace.pinIds], + mspConnectionPairIds: [...trace.mspConnectionPairIds], + tracePath: trace.tracePath.map((p) => ({ ...p })), +}) + +const samePoint = (a: Point, b: Point) => + Math.abs(a.x - b.x) < EPS && Math.abs(a.y - b.y) < EPS + +const dedupePath = (path: Point[]): Point[] => { + const deduped: Point[] = [] + for (const point of path) { + if ( + deduped.length === 0 || + !samePoint(deduped[deduped.length - 1]!, point) + ) { + deduped.push(point) + } + } + return deduped +} + +const normalizePath = (path: Point[]) => + dedupePath(simplifyPath(dedupePath(path))) + +const intervalDistance = (a: SegmentRef, b: SegmentRef) => + Math.max(0, Math.max(a.min, b.min) - Math.min(a.max, b.max)) + +const segmentRefsCompatible = ( + a: SegmentRef, + b: SegmentRef, + mergeDistance: number, + intervalGap: number, +) => + a.axis === b.axis && + a.mspPairId !== b.mspPairId && + Math.abs(a.coord - b.coord) <= mergeDistance && + intervalDistance(a, b) <= intervalGap + +const compareSegmentRefs = (a: SegmentRef, b: SegmentRef) => + a.axis.localeCompare(b.axis) || + a.coord - b.coord || + a.min - b.min || + a.max - b.max || + a.mspPairId.localeCompare(b.mspPairId) || + a.segmentIndex - b.segmentIndex + +const chooseCanonicalSegment = (segments: SegmentRef[]) => + [...segments].sort( + (a, b) => + b.length - a.length || + a.stableKey.localeCompare(b.stableKey) || + a.segmentIndex - b.segmentIndex, + )[0]! + +const getSegmentRefs = (trace: SolvedTracePath): SegmentRef[] => { + const refs: SegmentRef[] = [] + const pts = trace.tracePath + + for (let i = 0; i < pts.length - 1; i++) { + const start = pts[i]! + const end = pts[i + 1]! + if (samePoint(start, end)) continue + + if (isHorizontal(start, end, EPS)) { + const min = Math.min(start.x, end.x) + const max = Math.max(start.x, end.x) + refs.push({ + mspPairId: trace.mspPairId, + segmentIndex: i, + axis: "horizontal", + coord: start.y, + min, + max, + length: max - min, + stableKey: `${trace.mspPairId}:${i}`, + }) + } else if (isVertical(start, end, EPS)) { + const min = Math.min(start.y, end.y) + const max = Math.max(start.y, end.y) + refs.push({ + mspPairId: trace.mspPairId, + segmentIndex: i, + axis: "vertical", + coord: start.x, + min, + max, + length: max - min, + stableKey: `${trace.mspPairId}:${i}`, + }) + } + } + + return refs +} + +const clusterSegments = ( + refs: SegmentRef[], + mergeDistance: number, + intervalGap: number, +) => { + const parent = refs.map((_, index) => index) + const find = (index: number): number => { + while (parent[index] !== index) { + parent[index] = parent[parent[index]!]! + index = parent[index]! + } + return index + } + const union = (a: number, b: number) => { + const rootA = find(a) + const rootB = find(b) + if (rootA !== rootB) parent[rootB] = rootA + } + + for (let i = 0; i < refs.length; i++) { + for (let j = i + 1; j < refs.length; j++) { + if ( + segmentRefsCompatible(refs[i]!, refs[j]!, mergeDistance, intervalGap) + ) { + union(i, j) + } + } + } + + const clusters = new Map() + for (let i = 0; i < refs.length; i++) { + const root = find(i) + if (!clusters.has(root)) clusters.set(root, []) + clusters.get(root)!.push(refs[i]!) + } + + return Array.from(clusters.values()).filter((cluster) => cluster.length > 1) +} + +const hasOnlyOrthogonalSegments = (path: Point[]) => { + for (let i = 0; i < path.length - 1; i++) { + const start = path[i]! + const end = path[i + 1]! + if (samePoint(start, end)) continue + if (!isHorizontal(start, end, EPS) && !isVertical(start, end, EPS)) { + return false + } + } + return true +} + +const countChipCollisions = (path: Point[], inputProblem: InputProblem) => { + const rects = getObstacleRects(inputProblem) + let count = 0 + for (let i = 0; i < path.length - 1; i++) { + const start = path[i]! + const end = path[i + 1]! + for (const rect of rects) { + if (segmentIntersectsRect(start, end, rect, EPS)) count++ + } + } + return count +} + +const countDifferentNetIntersections = ( + trace: SolvedTracePath, + path: Point[], + traces: SolvedTracePath[], +) => { + let count = 0 + for (let i = 0; i < path.length - 1; i++) { + const a = path[i]! + const b = path[i + 1]! + for (const otherTrace of traces) { + if ( + otherTrace.mspPairId === trace.mspPairId || + otherTrace.globalConnNetId === trace.globalConnNetId + ) { + continue + } + for (let j = 0; j < otherTrace.tracePath.length - 1; j++) { + const c = otherTrace.tracePath[j]! + const d = otherTrace.tracePath[j + 1]! + if (doSegmentsIntersect(a, b, c, d)) count++ + } + } + } + return count +} + +const snappedPathForSegment = ( + trace: SolvedTracePath, + segment: SegmentRef, + coord: number, +): Point[] | null => { + const pts = trace.tracePath.map((p) => ({ ...p })) + const segmentStart = pts[segment.segmentIndex]! + const segmentEnd = pts[segment.segmentIndex + 1]! + const lastIndex = pts.length - 1 + const isFirstSegment = segment.segmentIndex === 0 + const isLastSegment = segment.segmentIndex + 1 === lastIndex + + if (isFirstSegment && isLastSegment) return null + if (Math.abs(segment.coord - coord) < EPS) return null + + if (segment.axis === "horizontal") { + if (isFirstSegment) { + pts.splice( + 1, + 1, + { x: segmentStart.x, y: coord }, + { x: segmentEnd.x, y: coord }, + ) + } else if (isLastSegment) { + pts.splice( + segment.segmentIndex, + 1, + { x: segmentStart.x, y: coord }, + { x: segmentEnd.x, y: coord }, + ) + } else { + segmentStart.y = coord + segmentEnd.y = coord + } + } else { + if (isFirstSegment) { + pts.splice( + 1, + 1, + { x: coord, y: segmentStart.y }, + { x: coord, y: segmentEnd.y }, + ) + } else if (isLastSegment) { + pts.splice( + segment.segmentIndex, + 1, + { x: coord, y: segmentStart.y }, + { x: coord, y: segmentEnd.y }, + ) + } else { + segmentStart.x = coord + segmentEnd.x = coord + } + } + + return normalizePath(pts) +} + +export class SameNetTraceConsolidationSolver extends BaseSolver { + inputProblem: InputProblem + inputTraces: SolvedTracePath[] + mergeDistance: number + intervalGap: number + + outputTraces: SolvedTracePath[] + correctedTraceMap: Record + private consolidationPassCount = 0 + + constructor(params: SameNetTraceConsolidationSolverInput) { + super() + this.inputProblem = params.inputProblem + this.inputTraces = params.inputTraces + this.mergeDistance = params.mergeDistance ?? DEFAULT_MERGE_DISTANCE + this.intervalGap = params.intervalGap ?? DEFAULT_INTERVAL_GAP + + this.outputTraces = params.inputTraces.map(cloneTrace) + this.correctedTraceMap = Object.fromEntries( + this.outputTraces.map((trace) => [trace.mspPairId, trace]), + ) + } + + override getConstructorParams(): ConstructorParameters< + typeof SameNetTraceConsolidationSolver + >[0] { + return { + inputProblem: this.inputProblem, + inputTraces: this.inputTraces, + mergeDistance: this.mergeDistance, + intervalGap: this.intervalGap, + } + } + + override _step() { + const changed = this.applyNextConsolidationPass() + if (!changed) { + this.solved = true + return + } + + this.consolidationPassCount++ + if (this.consolidationPassCount >= MAX_CONSOLIDATION_PASSES) { + this.stats.consolidationPassLimitExceeded = true + this.stats.consolidationPassCount = this.consolidationPassCount + this.solved = true + } + } + + private applyNextConsolidationPass() { + const tracesByNet = new Map() + for (const trace of this.outputTraces) { + if (!tracesByNet.has(trace.globalConnNetId)) { + tracesByNet.set(trace.globalConnNetId, []) + } + tracesByNet.get(trace.globalConnNetId)!.push(trace) + } + + for (const globalConnNetId of [...tracesByNet.keys()].sort()) { + const netTraces = tracesByNet.get(globalConnNetId)! + for (const axis of ["horizontal", "vertical"] as const) { + const refs = netTraces + .flatMap(getSegmentRefs) + .filter((ref) => ref.axis === axis) + .sort(compareSegmentRefs) + + const clusters = clusterSegments( + refs, + this.mergeDistance, + this.intervalGap, + ).sort((a, b) => compareSegmentRefs(a[0]!, b[0]!)) + + for (const cluster of clusters) { + const canonical = chooseCanonicalSegment(cluster) + const targets = cluster + .filter((segment) => segment.stableKey !== canonical.stableKey) + .sort( + (a, b) => + a.length - b.length || a.stableKey.localeCompare(b.stableKey), + ) + + let changed = false + const updatedTraceIds = new Set() + for (const target of targets) { + if (updatedTraceIds.has(target.mspPairId)) continue + const trace = this.correctedTraceMap[target.mspPairId] + if (!trace) continue + const candidatePath = snappedPathForSegment( + trace, + target, + canonical.coord, + ) + if (!candidatePath) continue + if (!this.isCandidateSafe(trace, candidatePath)) continue + + const updatedTrace = { + ...trace, + tracePath: candidatePath, + } + this.correctedTraceMap[trace.mspPairId] = updatedTrace + this.outputTraces = this.outputTraces.map((existingTrace) => + existingTrace.mspPairId === trace.mspPairId + ? updatedTrace + : existingTrace, + ) + updatedTraceIds.add(trace.mspPairId) + changed = true + } + + if (changed) return true + } + } + } + + return false + } + + private isCandidateSafe(trace: SolvedTracePath, path: Point[]) { + if (path.length < 2) return false + if (!samePoint(path[0]!, trace.tracePath[0]!)) return false + if ( + !samePoint( + path[path.length - 1]!, + trace.tracePath[trace.tracePath.length - 1]!, + ) + ) { + return false + } + if (!hasOnlyOrthogonalSegments(path)) return false + + const originalChipCollisions = countChipCollisions( + trace.tracePath, + this.inputProblem, + ) + const candidateChipCollisions = countChipCollisions(path, this.inputProblem) + if (candidateChipCollisions > originalChipCollisions) return false + + const originalDifferentNetIntersections = countDifferentNetIntersections( + trace, + trace.tracePath, + this.outputTraces, + ) + const candidateDifferentNetIntersections = countDifferentNetIntersections( + trace, + path, + this.outputTraces, + ) + if ( + candidateDifferentNetIntersections > originalDifferentNetIntersections + ) { + return false + } + + return true + } + + getOutput() { + return { + traces: this.outputTraces, + correctedTraceMap: this.correctedTraceMap, + } + } + + override visualize(): GraphicsObject { + const graphics = visualizeInputProblem(this.inputProblem, { + chipAlpha: 0.1, + connectionAlpha: 0.1, + }) + + for (const trace of this.inputTraces) { + graphics.lines!.push({ + points: trace.tracePath, + strokeColor: "rgba(120,120,120,0.45)", + strokeDash: "4 2", + }) + } + + for (const trace of this.outputTraces) { + graphics.lines!.push({ + points: trace.tracePath, + strokeColor: getColorFromString(trace.globalConnNetId, 0.9), + }) + } + + return graphics + } +} diff --git a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts index a56b50b7b..1bb4da3f7 100644 --- a/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts +++ b/lib/solvers/SchematicTracePipelineSolver/SchematicTracePipelineSolver.ts @@ -27,6 +27,7 @@ import { VccNetLabelCornerPlacementSolver } from "../VccNetLabelCornerPlacementS import { TraceAnchoredNetLabelOverlapSolver } from "../TraceAnchoredNetLabelOverlapSolver/TraceAnchoredNetLabelOverlapSolver" import { NetLabelTraceCollisionSolver } from "../NetLabelTraceCollisionSolver/NetLabelTraceCollisionSolver" import { NetLabelNetLabelCollisionSolver } from "../NetLabelNetLabelCollisionSolver/NetLabelNetLabelCollisionSolver" +import { SameNetTraceConsolidationSolver } from "../SameNetTraceConsolidationSolver/SameNetTraceConsolidationSolver" type PipelineStep BaseSolver> = { solverName: string @@ -76,6 +77,7 @@ export class SchematicTracePipelineSolver extends BaseSolver { labelMergingSolver?: MergedNetLabelObstacleSolver traceLabelOverlapAvoidanceSolver?: TraceLabelOverlapAvoidanceSolver traceCleanupSolver?: TraceCleanupSolver + sameNetTraceConsolidationSolver?: SameNetTraceConsolidationSolver example28Solver?: Example28Solver availableNetOrientationSolver?: AvailableNetOrientationSolver vccNetLabelCornerPlacementSolver?: VccNetLabelCornerPlacementSolver @@ -219,11 +221,28 @@ export class SchematicTracePipelineSolver extends BaseSolver { }, ] }), + definePipelineStep( + "sameNetTraceConsolidationSolver", + SameNetTraceConsolidationSolver, + (instance) => { + const traces = + instance.traceCleanupSolver?.getOutput().traces ?? + instance.traceLabelOverlapAvoidanceSolver!.getOutput().traces + + return [ + { + inputProblem: instance.inputProblem, + inputTraces: traces, + }, + ] + }, + ), definePipelineStep( "netLabelPlacementSolver", NetLabelPlacementSolver, (instance) => { const traces = + instance.sameNetTraceConsolidationSolver?.getOutput().traces ?? instance.traceCleanupSolver?.getOutput().traces ?? instance.traceLabelOverlapAvoidanceSolver!.getOutput().traces @@ -239,6 +258,7 @@ export class SchematicTracePipelineSolver extends BaseSolver { ), definePipelineStep("example28Solver", Example28Solver, (instance) => { const traces = + instance.sameNetTraceConsolidationSolver?.getOutput().traces ?? instance.traceCleanupSolver?.getOutput().traces ?? instance.traceLabelOverlapAvoidanceSolver!.getOutput().traces diff --git a/site/examples/example42.page.tsx b/site/examples/example42.page.tsx new file mode 100644 index 000000000..4f5f5f1b4 --- /dev/null +++ b/site/examples/example42.page.tsx @@ -0,0 +1,13 @@ +import { useMemo } from "react" +import { SameNetTraceConsolidationSolver } from "lib/solvers/SameNetTraceConsolidationSolver/SameNetTraceConsolidationSolver" +import { GenericSolverDebugger } from "site/components/GenericSolverDebugger" +import inputData from "../../tests/assets/example42.json" + +export default () => { + const solver = useMemo( + () => new SameNetTraceConsolidationSolver(inputData as any), + [], + ) + + return +} diff --git a/tests/assets/example42.json b/tests/assets/example42.json new file mode 100644 index 000000000..9cbd857f6 --- /dev/null +++ b/tests/assets/example42.json @@ -0,0 +1,67 @@ +{ + "inputProblem": { + "chips": [], + "directConnections": [], + "netConnections": [], + "availableNetLabelOrientations": {} + }, + "inputTraces": [ + { + "mspPairId": "same-net-trunk", + "dcConnNetId": "vcc", + "globalConnNetId": "vcc", + "userNetId": "VCC", + "pins": [ + { "chipId": "A", "pinId": "A.1", "x": 0, "y": 0 }, + { "chipId": "B", "pinId": "B.1", "x": 2.4, "y": 0 } + ], + "pinIds": ["A.1", "B.1"], + "mspConnectionPairIds": ["same-net-trunk"], + "tracePath": [ + { "x": 0, "y": 0 }, + { "x": 0.4, "y": 0 }, + { "x": 0.4, "y": 1 }, + { "x": 2.4, "y": 1 }, + { "x": 2.4, "y": 0 } + ] + }, + { + "mspPairId": "same-net-branch", + "dcConnNetId": "vcc", + "globalConnNetId": "vcc", + "userNetId": "VCC", + "pins": [ + { "chipId": "C", "pinId": "C.1", "x": 0, "y": 0.2 }, + { "chipId": "D", "pinId": "D.1", "x": 2, "y": 0.2 } + ], + "pinIds": ["C.1", "D.1"], + "mspConnectionPairIds": ["same-net-branch"], + "tracePath": [ + { "x": 0, "y": 0.2 }, + { "x": 0.4, "y": 0.2 }, + { "x": 0.4, "y": 1.08 }, + { "x": 2, "y": 1.08 }, + { "x": 2, "y": 0.2 } + ] + }, + { + "mspPairId": "different-net-nearby", + "dcConnNetId": "gnd", + "globalConnNetId": "gnd", + "userNetId": "GND", + "pins": [ + { "chipId": "E", "pinId": "E.1", "x": 0, "y": 0.4 }, + { "chipId": "F", "pinId": "F.1", "x": 2, "y": 0.4 } + ], + "pinIds": ["E.1", "F.1"], + "mspConnectionPairIds": ["different-net-nearby"], + "tracePath": [ + { "x": 0, "y": 0.4 }, + { "x": 0.4, "y": 0.4 }, + { "x": 0.4, "y": 1.16 }, + { "x": 2, "y": 1.16 }, + { "x": 2, "y": 0.4 } + ] + } + ] +} diff --git a/tests/examples/__snapshots__/example02.snap.svg b/tests/examples/__snapshots__/example02.snap.svg index 3815fdc0b..311efa946 100644 --- a/tests/examples/__snapshots__/example02.snap.svg +++ b/tests/examples/__snapshots__/example02.snap.svg @@ -58,7 +58,7 @@ orientation: y+" data-x="-1.4574283249999997" data-y="1.3024186000000004" cx="29 +orientation: y-" data-x="-1.5071549750000002" data-y="-0.2000000000000004" cx="288.00904304318556" cy="349.5798500586337" r="3" fill="hsl(40, 100%, 50%, 0.9)" /> - + @@ -202,7 +202,7 @@ available orientations: y+" data-x="-1.4574283249999997" data-y="1.5274186000000 +available orientations: y-" data-x="-1.5071549750000002" data-y="-0.4250000000000004" x="279.54556663155927" y="349.5798500586337" width="16.92695282325252" height="38.085643852318185" fill="#00000066" stroke="#000000" stroke-width="0.011815475714285715" /> - + diff --git a/tests/examples/__snapshots__/example19.snap.svg b/tests/examples/__snapshots__/example19.snap.svg index ac5ba82db..8e54f6c76 100644 --- a/tests/examples/__snapshots__/example19.snap.svg +++ b/tests/examples/__snapshots__/example19.snap.svg @@ -98,10 +98,10 @@ orientation: y+" data-x="3.3884680250000008" data-y="1.2997267500000007" cx="438 - + - + @@ -129,19 +129,23 @@ orientation: y+" data-x="3.3884680250000008" data-y="1.2997267500000007" cx="438 +globalConnNetId: connectivity_net3 +available orientations: any" data-x="2.2284928" data-y="-0.46751595000000035" x="300.34928" y="375.9895275000001" width="45" height="19.999999999999943" fill="hsl(40, 100%, 50%, 0.35)" stroke="black" stroke-width="0.01" /> +globalConnNetId: connectivity_net0 +available orientations: any" data-x="1.6252733499999996" data-y="-0.30120930000000024" x="240.027335" y="359.3588625000001" width="45" height="20" fill="hsl(40, 100%, 50%, 0.35)" stroke="black" stroke-width="0.01" /> +globalConnNetId: connectivity_net1 +available orientations: any" data-x="0.7999316625000001" data-y="0.42500000000000004" x="169.99316625000006" y="274.23793250000006" width="20" height="45" fill="hsl(40, 100%, 50%, 0.35)" stroke="black" stroke-width="0.01" /> +globalConnNetId: connectivity_net2 +available orientations: any" data-x="3.3884680250000008" data-y="1.5247267500000008" x="428.8468025000001" y="164.26525749999996" width="20" height="45.00000000000003" fill="hsl(40, 100%, 50%, 0.35)" stroke="black" stroke-width="0.01" />