Skip to content

Commit a1c7281

Browse files
committed
Clean up GaussianElimination
1 parent 2763310 commit a1c7281

File tree

1 file changed

+21
-33
lines changed

1 file changed

+21
-33
lines changed

src/main/scala/eu/sim642/adventofcodelib/GaussianElimination.scala

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -42,42 +42,36 @@ object GaussianElimination {
4242
m(y2) = row1
4343
}
4444

45-
def multiplyRow(y: Int, factor: A): Unit = {
46-
for (x2 <- 0 until (n + 1))
47-
m(y)(x2) *= factor
48-
}
49-
5045
def simplifyRow(y: Int): Unit = {
5146
val factor = NumberTheory.gcd(m(y).toSeq) // TODO: avoid conversion
5247
if (factor.abs > summon[Integral[A]].one) {
53-
for (x2 <- 0 until (n + 1))
54-
m(y)(x2) /= factor
48+
for (x <- 0 until (n + 1))
49+
m(y)(x) /= factor
5550
}
5651
}
5752

58-
def reduceDown(x: Int, y1: Int, y2: Int): Unit = {
53+
def reduceRow(x: Int, y1: Int, y2: Int): Unit = {
5954
val c2 = m(y2)(x)
6055
if (c2 != 0) {
6156
val c1 = m(y1)(x)
62-
val (_, _, (factor, factor2)) = NumberTheory.extendedGcd(c1, c2)
57+
val (factor1, factor2) = NumberTheory.extendedGcd(c1, c2)._3
6358
for (x2 <- 0 until x) // must start from 0 because we're now multiplying entire row y2
6459
m(y2)(x2) = factor2 * m(y2)(x2)
6560
for (x2 <- x until (n + 1))
66-
m(y2)(x2) = factor2 * m(y2)(x2) + factor * m(y1)(x2)
61+
m(y2)(x2) = factor2 * m(y2)(x2) + factor1 * m(y1)(x2)
6762
//simplifyRow(y2) // TODO: helps?
6863
}
6964
}
7065

66+
// forward elimination
7167
var y = 0
7268
for (x <- 0 until n) {
73-
val y2opt = m.indices.find(y2 => y2 >= y && m(y2)(x) != 0)
74-
y2opt match {
69+
(y until m.size).find(m(_)(x) != 0) match {
7570
case None => // move to next x
7671
case Some(y2) =>
7772
swapRows(y, y2)
7873
for (y3 <- (y + 1) until m.size)
79-
reduceDown(x, y, y3)
80-
74+
reduceRow(x, y, y3)
8175
y += 1
8276
}
8377
}
@@ -86,33 +80,27 @@ object GaussianElimination {
8680
for (y2 <- y until m.size)
8781
assert(m(y2).last == 0) // TODO: return Option
8882

89-
val mainVars = mutable.ArrayBuffer.empty[Int]
83+
// backward elimination
84+
val dependentVars = mutable.ArrayBuffer.empty[Int]
9085
val freeVars = mutable.ArrayBuffer.empty[Int]
9186
y = 0
9287
for (x <- 0 until n) {
93-
if (y < m.size) { // TODO: break if y too big
94-
if (m(y)(x) == 0) {
95-
freeVars += x
96-
()
97-
} // move to next x
98-
else {
99-
mainVars += x
100-
for (y3 <- 0 until y)
101-
reduceDown(x, y, y3)
102-
103-
y += 1
104-
}
88+
if (y >= m.size || m(y)(x) == 0)
89+
freeVars += x
90+
else {
91+
dependentVars += x
92+
for (y2 <- 0 until y)
93+
reduceRow(x, y, y2)
94+
y += 1
10595
}
106-
else
107-
freeVars += x // can't break if this is here
10896
}
10997

11098
Solution(
111-
dependentVars = mainVars.toSeq,
112-
dependentGenerator = (mainVars lazyZip m).view.map((v, row) => row(v)).toSeq,
99+
dependentVars = dependentVars.toSeq,
100+
dependentGenerator = (dependentVars lazyZip m).view.map((v, row) => row(v)).toSeq,
113101
freeVars = freeVars.toSeq,
114-
freeGenerators = freeVars.view.map(x => m.view.take(mainVars.size).map(_(x)).toSeq).toSeq,
115-
const = m.view.take(mainVars.size).map(_.last).toSeq
102+
freeGenerators = freeVars.view.map(x => m.view.take(dependentVars.size).map(_(x)).toSeq).toSeq,
103+
const = m.view.take(dependentVars.size).map(_.last).toSeq
116104
)
117105
}
118106
}

0 commit comments

Comments
 (0)