Skip to content

Commit 70ca4f2

Browse files
authored
Merge pull request #7 from harry0000/fenwick-tree
Add FenwickTree
2 parents 6997e81 + 94b70f8 commit 70ca4f2

File tree

5 files changed

+365
-1
lines changed

5 files changed

+365
-1
lines changed

build.sbt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ lazy val root = project
1616
"-unchecked",
1717
"-Wunused:all"
1818
),
19-
libraryDependencies += "org.scalameta" %% "munit" % "0.7.29" % Test
19+
libraryDependencies += "org.scalameta" %% "munit" % "0.7.29" % Test,
20+
Test / parallelExecution := false
2021
)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package io.github.acl4s
2+
3+
import scala.reflect.ClassTag
4+
5+
import io.github.acl4s.internal.rightOpenInterval
6+
7+
/**
8+
* Reference: https://en.wikipedia.org/wiki/Fenwick_tree
9+
*
10+
* @param n
11+
* @param m
12+
* @tparam T
13+
*/
14+
case class FenwickTree[T: ClassTag](n: Int)(using m: AddSub[T]) {
15+
private val data: Array[T] = Array.fill(n)(m.e())
16+
17+
def add(index: Int, x: T): Unit = {
18+
assert(0 <= index && index < n)
19+
var p = index + 1
20+
while (p <= n) {
21+
data(p - 1) = m.combine(data(p - 1), x)
22+
p += p & -p
23+
}
24+
}
25+
26+
private def sum(i: Int): T = {
27+
var s = m.e()
28+
var r = i
29+
while (r > 0) {
30+
s = m.combine(s, data(r - 1))
31+
r -= r & -r
32+
}
33+
s
34+
}
35+
36+
def sum(range: Range): T = {
37+
val (l, r) = rightOpenInterval(range)
38+
sum(l, r)
39+
}
40+
41+
def sum(l: Int, r: Int): T = {
42+
assert(0 <= l && l <= r && r <= n)
43+
m.subtract(sum(r), sum(l))
44+
}
45+
}
46+
47+
object FenwickTree {
48+
def apply[T: AddSub: ClassTag](array: Array[T]): FenwickTree[T] = {
49+
val ft = FenwickTree[T](array.length)
50+
array.indices.foreach(i => {
51+
ft.add(i, array(i))
52+
})
53+
ft
54+
}
55+
}
56+
57+
trait AddSub[T] extends Add[T] {
58+
def subtract(a: T, b: T): T
59+
}
60+
61+
object AddSub {
62+
given (using m: Add[Char]): AddSub[Char] with {
63+
override def e(): Char = m.e()
64+
override def combine(a: Char, b: Char): Char = m.combine(a, b)
65+
override def subtract(a: Char, b: Char): Char = (a - b).asInstanceOf[Char]
66+
}
67+
68+
given (using m: Add[Byte]): AddSub[Byte] with {
69+
override def e(): Byte = m.e()
70+
override def combine(a: Byte, b: Byte): Byte = m.combine(a, b)
71+
override def subtract(a: Byte, b: Byte): Byte = (a - b).asInstanceOf[Byte]
72+
}
73+
74+
given (using m: Add[Short]): AddSub[Short] with {
75+
override def e(): Short = m.e()
76+
override def combine(a: Short, b: Short): Short = m.combine(a, b)
77+
override def subtract(a: Short, b: Short): Short = (a - b).asInstanceOf[Short]
78+
}
79+
80+
given (using m: Add[Int]): AddSub[Int] with {
81+
override def e(): Int = m.e()
82+
override def combine(a: Int, b: Int): Int = m.combine(a, b)
83+
override def subtract(a: Int, b: Int): Int = a - b
84+
}
85+
86+
given (using m: Add[Long]): AddSub[Long] with {
87+
override def e(): Long = m.e()
88+
override def combine(a: Long, b: Long): Long = m.combine(a, b)
89+
override def subtract(a: Long, b: Long): Long = a - b
90+
}
91+
92+
given (using m: Add[Float]): AddSub[Float] with {
93+
override def e(): Float = m.e()
94+
override def combine(a: Float, b: Float): Float = m.combine(a, b)
95+
override def subtract(a: Float, b: Float): Float = a - b
96+
}
97+
98+
given (using m: Add[Double]): AddSub[Double] with {
99+
override def e(): Double = m.e()
100+
override def combine(a: Double, b: Double): Double = m.combine(a, b)
101+
override def subtract(a: Double, b: Double): Double = a - b
102+
}
103+
104+
given (using m: Add[DynamicModInt]): AddSub[DynamicModInt] with {
105+
override def e(): DynamicModInt = m.e()
106+
override def combine(a: DynamicModInt, b: DynamicModInt): DynamicModInt = m.combine(a, b)
107+
override def subtract(a: DynamicModInt, b: DynamicModInt): DynamicModInt = a - b
108+
}
109+
110+
given (using m: Add[ModInt998244353]): AddSub[ModInt998244353] with {
111+
override def e(): ModInt998244353 = m.e()
112+
override def combine(a: ModInt998244353, b: ModInt998244353): ModInt998244353 = m.combine(a, b)
113+
override def subtract(a: ModInt998244353, b: ModInt998244353): ModInt998244353 = a - b
114+
}
115+
116+
given (using m: Add[ModInt1000000007]): AddSub[ModInt1000000007] with {
117+
override def e(): ModInt1000000007 = m.e()
118+
override def combine(a: ModInt1000000007, b: ModInt1000000007): ModInt1000000007 = m.combine(a, b)
119+
override def subtract(a: ModInt1000000007, b: ModInt1000000007): ModInt1000000007 = a - b
120+
}
121+
122+
given [T <: Int](using m: Add[StaticModInt[T]]): AddSub[StaticModInt[T]] with {
123+
override def e(): StaticModInt[T] = m.e()
124+
override def combine(a: StaticModInt[T], b: StaticModInt[T]): StaticModInt[T] = m.combine(a, b)
125+
override def subtract(a: StaticModInt[T], b: StaticModInt[T]): StaticModInt[T] = a - b
126+
}
127+
}

src/main/scala/io/github/acl4s/ModInt.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,15 @@ type ModInt998244353 = StaticModInt[Mod998244353.value.type]
189189
object ModInt1000000007 {
190190
given Modulus[1_000_000_007] = Mod1000000007
191191

192+
def apply(): ModInt1000000007 = StaticModInt()
192193
def apply(value: Int): ModInt1000000007 = StaticModInt(value)
193194
def apply(value: Long): ModInt1000000007 = StaticModInt(value)
194195
}
195196

196197
object ModInt998244353 {
197198
given Modulus[998_244_353] = Mod998244353
198199

200+
def apply(): ModInt998244353 = StaticModInt()
199201
def apply(value: Int): ModInt998244353 = StaticModInt(value)
200202
def apply(value: Long): ModInt998244353 = StaticModInt(value)
201203
}

src/main/scala/io/github/acl4s/Monoid.scala

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,67 @@ trait Monoid[T] {
88
object Monoid {
99
def apply[T](using m: Monoid[T]): Monoid[T] = m
1010
}
11+
12+
trait Add[T] extends Monoid[T]
13+
object Add {
14+
given Add[Char] with {
15+
override def e(): Char = 0
16+
override def combine(a: Char, b: Char): Char = (a + b).asInstanceOf[Char]
17+
}
18+
19+
given Add[Byte] with {
20+
override def e(): Byte = 0
21+
override def combine(a: Byte, b: Byte): Byte = (a + b).asInstanceOf[Byte]
22+
}
23+
24+
given Add[Short] with {
25+
override def e(): Short = 0
26+
override def combine(a: Short, b: Short): Short = (a + b).asInstanceOf[Short]
27+
}
28+
29+
given Add[Int] with {
30+
override def e(): Int = 0
31+
override def combine(a: Int, b: Int): Int = a + b
32+
}
33+
34+
given Add[Long] with {
35+
override def e(): Long = 0L
36+
override def combine(a: Long, b: Long): Long = a + b
37+
}
38+
39+
given Add[Float] with {
40+
override def e(): Float = 0f
41+
override def combine(a: Float, b: Float): Float = a + b
42+
}
43+
44+
given Add[Double] with {
45+
override def e(): Double = 0d
46+
override def combine(a: Double, b: Double): Double = a + b
47+
}
48+
49+
given Add[DynamicModInt] with {
50+
private val zero = DynamicModInt()
51+
override def e(): DynamicModInt = zero
52+
override def combine(a: DynamicModInt, b: DynamicModInt): DynamicModInt = a + b
53+
}
54+
55+
given Add[ModInt998244353] with {
56+
private val zero = ModInt998244353()
57+
override def e(): ModInt998244353 = zero
58+
override def combine(a: ModInt998244353, b: ModInt998244353): ModInt998244353 = a + b
59+
}
60+
61+
given Add[ModInt1000000007] with {
62+
private val zero = ModInt1000000007()
63+
override def e(): ModInt1000000007 = zero
64+
override def combine(a: ModInt1000000007, b: ModInt1000000007): ModInt1000000007 = a + b
65+
}
66+
67+
given [T <: Int](using Modulus[T]): Add[StaticModInt[T]] with {
68+
private val zero = StaticModInt[T]()
69+
override def e(): StaticModInt[T] = zero
70+
override def combine(a: StaticModInt[T], b: StaticModInt[T]): StaticModInt[T] = a + b
71+
}
72+
73+
def apply[T](using m: Add[T]): Add[T] = m
74+
}
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
package io.github.acl4s
2+
3+
class FenwickTreeSuite extends munit.FunSuite {
4+
5+
/**
6+
* @see https://atcoder.jp/contests/practice2/tasks/practice2_b
7+
*/
8+
test("AtCoder Library Practice Contest B - Fenwick Tree") {
9+
val fw = FenwickTree(Array(1L, 2L, 3L, 4L, 5L))
10+
11+
assertEquals(fw.sum(0, 5), 15L)
12+
assertEquals(fw.sum(2, 4), 7L)
13+
14+
fw.add(3, 10)
15+
16+
assertEquals(fw.sum(0, 5), 25L)
17+
assertEquals(fw.sum(0, 3), 6L)
18+
}
19+
20+
test("zero") {
21+
{
22+
val fw = FenwickTree[Long](0)
23+
assertEquals(fw.sum(0, 0), 0L)
24+
}
25+
26+
{
27+
type ModInt = DynamicModInt
28+
val ModInt = DynamicModInt
29+
30+
val fw = FenwickTree[ModInt](0)
31+
assertEquals(fw.sum(0, 0), ModInt(0))
32+
}
33+
34+
{
35+
type ModInt = ModInt998244353
36+
val ModInt = ModInt998244353
37+
38+
val fw = FenwickTree[ModInt](0)
39+
assertEquals(fw.sum(0, 0), ModInt(0))
40+
}
41+
42+
{
43+
type ModInt = ModInt1000000007
44+
val ModInt = ModInt1000000007
45+
46+
val fw = FenwickTree[ModInt](0)
47+
assertEquals(fw.sum(0, 0), ModInt(0))
48+
}
49+
50+
{
51+
given Modulus[1_000_000_009] = Modulus[1_000_000_009]()
52+
type ModInt = StaticModInt[1_000_000_009]
53+
val ModInt = StaticModInt
54+
55+
val fw = FenwickTree[ModInt](0)
56+
assertEquals(fw.sum(0, 0), ModInt(0))
57+
}
58+
}
59+
60+
test("naive") {
61+
(0 to 50).foreach(n => {
62+
val fw = FenwickTree[Long](n)
63+
(0 until n).foreach(i => {
64+
fw.add(i, i.toLong * i)
65+
})
66+
67+
for {
68+
l <- 0 to n
69+
r <- l to n
70+
} {
71+
val sum = (l until r).map(i => i.toLong * i).sum
72+
assertEquals(fw.sum(l, r), sum)
73+
}
74+
})
75+
}
76+
77+
test("bound int") {
78+
val fw = FenwickTree[Int](10)
79+
80+
fw.add(3, Int.MaxValue)
81+
fw.add(5, Int.MinValue)
82+
83+
assertEquals(fw.sum(0, 10), -1)
84+
assertEquals(fw.sum(3, 6), -1)
85+
86+
assertEquals(fw.sum(3, 4), Int.MaxValue)
87+
assertEquals(fw.sum(4, 10), Int.MinValue)
88+
}
89+
90+
test("bound long") {
91+
val fw = FenwickTree[Long](10)
92+
93+
fw.add(3, Long.MaxValue)
94+
fw.add(5, Long.MinValue)
95+
96+
assertEquals(fw.sum(0, 10), -1L)
97+
assertEquals(fw.sum(3, 6), -1L)
98+
99+
assertEquals(fw.sum(3, 4), Long.MaxValue)
100+
assertEquals(fw.sum(4, 10), Long.MinValue)
101+
}
102+
103+
test("overflow") {
104+
val fw = FenwickTree[Int](20)
105+
val a = new Array[Long](20)
106+
(0 until 10).foreach(i => {
107+
fw.add(i, Int.MaxValue)
108+
a(i) += Int.MaxValue
109+
})
110+
(10 until 20).foreach(i => {
111+
fw.add(i, Int.MinValue)
112+
a(i) += Int.MinValue
113+
})
114+
115+
fw.add(5, 11_111)
116+
a(5) += 11_111
117+
118+
for {
119+
l <- 0 to 20
120+
r <- l to 20
121+
} {
122+
val sum = (l until r).map(i => a(i)).sum
123+
val dif = sum - fw.sum(l, r)
124+
assertEquals(dif % (1L << 32), 0L)
125+
}
126+
}
127+
128+
test("StaticModInt") {
129+
given Modulus[11] = Modulus[11]()
130+
type ModInt = StaticModInt[11]
131+
val ModInt = StaticModInt
132+
133+
(0 to 50).foreach(n => {
134+
val fw = FenwickTree[ModInt](n)
135+
(0 until n).foreach(i => {
136+
fw.add(i, ModInt(i.toLong * i))
137+
})
138+
139+
for {
140+
l <- 0 to n
141+
r <- l to n
142+
} {
143+
val sum = (l until r).map(i => i.toLong * i).sum
144+
assertEquals(fw.sum(l, r), ModInt(sum))
145+
}
146+
})
147+
}
148+
149+
test("DynamicModInt") {
150+
type ModInt = DynamicModInt
151+
val ModInt = DynamicModInt
152+
ModInt.setMod(11)
153+
154+
(0 to 50).foreach(n => {
155+
val fw = FenwickTree[ModInt](n)
156+
(0 until n).foreach(i => {
157+
fw.add(i, ModInt(i.toLong * i))
158+
})
159+
160+
for {
161+
l <- 0 to n
162+
r <- l to n
163+
} {
164+
val sum = (l until r).map(i => i.toLong * i).sum
165+
assertEquals(fw.sum(l, r), ModInt(sum))
166+
}
167+
})
168+
}
169+
170+
}

0 commit comments

Comments
 (0)