Skip to content

Commit 2af0082

Browse files
committed
Clean up GaussianElimination more
1 parent a1c7281 commit 2af0082

File tree

1 file changed

+29
-21
lines changed

1 file changed

+29
-21
lines changed

src/main/scala/eu/sim642/adventofcodelib/GaussianElimination.scala

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,60 +32,67 @@ object GaussianElimination {
3232
}
3333

3434
def solve[A: ClassTag](initialA: Seq[Seq[A]], initialb: Seq[A])(using aIntegral: Integral[A]): Solution[A] = {
35-
val rows = initialA zip initialb // TODO: lazyZip
36-
val m: mutable.ArraySeq[mutable.ArraySeq[A]] = rows.map((a, b) => (a :+ b).to(mutable.ArraySeq)).to(mutable.ArraySeq)
35+
val m = initialA.size
3736
val n = initialA.head.size
37+
require(initialb.sizeIs == m)
38+
39+
val A: mutable.ArraySeq[mutable.ArraySeq[A]] = initialA.map(_.to(mutable.ArraySeq)).to(mutable.ArraySeq)
40+
val b: mutable.ArraySeq[A] = initialb.to(mutable.ArraySeq)
3841

3942
def swapRows(y1: Int, y2: Int): Unit = {
40-
val row1 = m(y1)
41-
m(y1) = m(y2)
42-
m(y2) = row1
43+
val A1 = A(y1)
44+
A(y1) = A(y2)
45+
A(y2) = A1
46+
val b1 = b(y1)
47+
b(y1) = b(y2)
48+
b(y2) = b1
4349
}
4450

4551
def simplifyRow(y: Int): Unit = {
46-
val factor = NumberTheory.gcd(m(y).toSeq) // TODO: avoid conversion
52+
val factor = NumberTheory.gcd(NumberTheory.gcd(A(y).toSeq), b(y)) // TODO: avoid conversion
4753
if (factor.abs > summon[Integral[A]].one) {
48-
for (x <- 0 until (n + 1))
49-
m(y)(x) /= factor
54+
A(y).mapInPlace(_ / factor)
55+
b(y) /= factor
5056
}
5157
}
5258

5359
def reduceRow(x: Int, y1: Int, y2: Int): Unit = {
54-
val c2 = m(y2)(x)
60+
val c2 = A(y2)(x)
5561
if (c2 != 0) {
56-
val c1 = m(y1)(x)
62+
val c1 = A(y1)(x)
5763
val (factor1, factor2) = NumberTheory.extendedGcd(c1, c2)._3
5864
for (x2 <- 0 until x) // must start from 0 because we're now multiplying entire row y2
59-
m(y2)(x2) = factor2 * m(y2)(x2)
60-
for (x2 <- x until (n + 1))
61-
m(y2)(x2) = factor2 * m(y2)(x2) + factor1 * m(y1)(x2)
65+
A(y2)(x2) = factor2 * A(y2)(x2)
66+
for (x2 <- x until n)
67+
A(y2)(x2) = factor2 * A(y2)(x2) + factor1 * A(y1)(x2)
68+
b(y2) = factor2 * b(y2) + factor1 * b(y1)
6269
//simplifyRow(y2) // TODO: helps?
6370
}
6471
}
6572

6673
// forward elimination
6774
var y = 0
6875
for (x <- 0 until n) {
69-
(y until m.size).find(m(_)(x) != 0) match {
76+
(y until m).find(A(_)(x) != 0) match {
7077
case None => // move to next x
7178
case Some(y2) =>
7279
swapRows(y, y2)
73-
for (y3 <- (y + 1) until m.size)
80+
for (y3 <- (y + 1) until m)
7481
reduceRow(x, y, y3)
7582
y += 1
7683
}
7784
}
7885

7986
// check consistency
80-
for (y2 <- y until m.size)
81-
assert(m(y2).last == 0) // TODO: return Option
87+
for (y2 <- y until b.size)
88+
assert(b(y2) == 0) // TODO: return Option
8289

8390
// backward elimination
8491
val dependentVars = mutable.ArrayBuffer.empty[Int]
8592
val freeVars = mutable.ArrayBuffer.empty[Int]
8693
y = 0
8794
for (x <- 0 until n) {
88-
if (y >= m.size || m(y)(x) == 0)
95+
if (y >= m || A(y)(x) == 0)
8996
freeVars += x
9097
else {
9198
dependentVars += x
@@ -95,12 +102,13 @@ object GaussianElimination {
95102
}
96103
}
97104

105+
val Aview = A.view.take(dependentVars.size)
98106
Solution(
99107
dependentVars = dependentVars.toSeq,
100-
dependentGenerator = (dependentVars lazyZip m).view.map((v, row) => row(v)).toSeq,
108+
dependentGenerator = (A lazyZip dependentVars).map(_(_)).toSeq,
101109
freeVars = freeVars.toSeq,
102-
freeGenerators = freeVars.view.map(x => m.view.take(dependentVars.size).map(_(x)).toSeq).toSeq,
103-
const = m.view.take(dependentVars.size).map(_.last).toSeq
110+
freeGenerators = freeVars.view.map(x => Aview.map(_(x)).toSeq).toSeq,
111+
const = b.view.take(dependentVars.size).toSeq
104112
)
105113
}
106114
}

0 commit comments

Comments
 (0)