@@ -32,60 +32,67 @@ object GaussianElimination {
3232 }
3333
3434 def solve [A : ClassTag ](initialA : Seq [Seq [A ]], initialb : Seq [A ])(using aIntegral : Integral [A ]): Solution [A ] = {
35- val rows = initialA zip initialb // TODO: lazyZip
36- val m : mutable.ArraySeq [mutable.ArraySeq [A ]] = rows.map((a, b) => (a :+ b).to(mutable.ArraySeq )).to(mutable.ArraySeq )
35+ val m = initialA.size
3736 val n = initialA.head.size
37+ require(initialb.sizeIs == m)
38+
39+ val A : mutable.ArraySeq [mutable.ArraySeq [A ]] = initialA.map(_.to(mutable.ArraySeq )).to(mutable.ArraySeq )
40+ val b : mutable.ArraySeq [A ] = initialb.to(mutable.ArraySeq )
3841
3942 def swapRows (y1 : Int , y2 : Int ): Unit = {
40- val row1 = m(y1)
41- m(y1) = m(y2)
42- m(y2) = row1
43+ val A1 = A (y1)
44+ A (y1) = A (y2)
45+ A (y2) = A1
46+ val b1 = b(y1)
47+ b(y1) = b(y2)
48+ b(y2) = b1
4349 }
4450
4551 def simplifyRow (y : Int ): Unit = {
46- val factor = NumberTheory .gcd(m( y).toSeq) // TODO: avoid conversion
52+ val factor = NumberTheory .gcd(NumberTheory .gcd( A ( y).toSeq), b(y) ) // TODO: avoid conversion
4753 if (factor.abs > summon[Integral [A ]].one) {
48- for (x <- 0 until (n + 1 ) )
49- m(y)(x ) /= factor
54+ A (y).mapInPlace(_ / factor )
55+ b(y ) /= factor
5056 }
5157 }
5258
5359 def reduceRow (x : Int , y1 : Int , y2 : Int ): Unit = {
54- val c2 = m (y2)(x)
60+ val c2 = A (y2)(x)
5561 if (c2 != 0 ) {
56- val c1 = m (y1)(x)
62+ val c1 = A (y1)(x)
5763 val (factor1, factor2) = NumberTheory .extendedGcd(c1, c2)._3
5864 for (x2 <- 0 until x) // must start from 0 because we're now multiplying entire row y2
59- m(y2)(x2) = factor2 * m(y2)(x2)
60- for (x2 <- x until (n + 1 ))
61- m(y2)(x2) = factor2 * m(y2)(x2) + factor1 * m(y1)(x2)
65+ A (y2)(x2) = factor2 * A (y2)(x2)
66+ for (x2 <- x until n)
67+ A (y2)(x2) = factor2 * A (y2)(x2) + factor1 * A (y1)(x2)
68+ b(y2) = factor2 * b(y2) + factor1 * b(y1)
6269 // simplifyRow(y2) // TODO: helps?
6370 }
6471 }
6572
6673 // forward elimination
6774 var y = 0
6875 for (x <- 0 until n) {
69- (y until m.size ).find(m (_)(x) != 0 ) match {
76+ (y until m).find(A (_)(x) != 0 ) match {
7077 case None => // move to next x
7178 case Some (y2) =>
7279 swapRows(y, y2)
73- for (y3 <- (y + 1 ) until m.size )
80+ for (y3 <- (y + 1 ) until m)
7481 reduceRow(x, y, y3)
7582 y += 1
7683 }
7784 }
7885
7986 // check consistency
80- for (y2 <- y until m .size)
81- assert(m (y2).last == 0 ) // TODO: return Option
87+ for (y2 <- y until b .size)
88+ assert(b (y2) == 0 ) // TODO: return Option
8289
8390 // backward elimination
8491 val dependentVars = mutable.ArrayBuffer .empty[Int ]
8592 val freeVars = mutable.ArrayBuffer .empty[Int ]
8693 y = 0
8794 for (x <- 0 until n) {
88- if (y >= m.size || m (y)(x) == 0 )
95+ if (y >= m || A (y)(x) == 0 )
8996 freeVars += x
9097 else {
9198 dependentVars += x
@@ -95,12 +102,13 @@ object GaussianElimination {
95102 }
96103 }
97104
105+ val Aview = A .view.take(dependentVars.size)
98106 Solution (
99107 dependentVars = dependentVars.toSeq,
100- dependentGenerator = (dependentVars lazyZip m).view. map((v, row) => row(v )).toSeq,
108+ dependentGenerator = (A lazyZip dependentVars). map(_(_ )).toSeq,
101109 freeVars = freeVars.toSeq,
102- freeGenerators = freeVars.view.map(x => m.view.take(dependentVars.size) .map(_(x)).toSeq).toSeq,
103- const = m .view.take(dependentVars.size).map(_.last ).toSeq
110+ freeGenerators = freeVars.view.map(x => Aview .map(_(x)).toSeq).toSeq,
111+ const = b .view.take(dependentVars.size).toSeq
104112 )
105113 }
106114}
0 commit comments