Skip to content

Commit 727955d

Browse files
committed
Optimize set operations with O(n+m) merge-walk algorithm
This replaces the previous O(n log n) and O(n * log m) implementations of setUnion, setInter, and setDiff with efficient O(n+m) merge-walk algorithms that exploit the fact that sets are already sorted. Changes: - setUnion: Changed from concat + sort + dedup to merge-walk (was O((n+m) log(n+m)), now O(n+m)) - setInter: Changed from iterate + binary search to merge-walk (was O(min(n,m) * log(max(n,m))), now O(n+m)) - setDiff: Changed from iterate + binary search to merge-walk (was O(n * log m), now O(n+m)) - Added applyKeyF helper function for cleaner key function handling - Added validateSet calls to setUnion for consistency with other ops - Pre-sized ArrayBuffer allocations for better memory efficiency The setMember function still uses binary search which is optimal for single-element lookups. These optimizations address the same performance issue as upstream PR databricks#574, but avoid the O(n²) bug in that PR's uniqArr implementation (which calls ArrayBuilder.result() on each iteration). All changes respect the throwErrorForInvalidSets setting.
1 parent d6bc48b commit 727955d

File tree

1 file changed

+107
-16
lines changed

1 file changed

+107
-16
lines changed

sjsonnet/src/sjsonnet/Std.scala

Lines changed: 107 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,15 +1707,53 @@ class Std(
17071707
builtin(Set_),
17081708
builtinWithDefaults("setUnion", "a" -> null, "b" -> null, "keyF" -> Val.False(dummyPos)) {
17091709
(args, pos, ev) =>
1710+
val keyF = args(2)
1711+
validateSet(ev, pos, keyF, args(0))
1712+
validateSet(ev, pos, keyF, args(1))
1713+
17101714
val a = toSetArr(args, 0, pos, ev)
17111715
val b = toSetArr(args, 1, pos, ev)
1716+
17121717
if (a.isEmpty) {
1713-
uniqArr(pos, ev, sortArr(pos, ev, args(1), args(2)), args(2))
1718+
args(1)
17141719
} else if (b.isEmpty) {
1715-
uniqArr(pos, ev, sortArr(pos, ev, args(0), args(2)), args(2))
1720+
args(0)
17161721
} else {
1717-
val concat = Val.Arr(pos, a ++ b)
1718-
uniqArr(pos, ev, sortArr(pos, ev, concat, args(2)), args(2))
1722+
// Use merge-walk algorithm: O(n + m) instead of O((n+m) log(n+m))
1723+
val out = new mutable.ArrayBuffer[Lazy](a.length + b.length)
1724+
var idxA = 0
1725+
var idxB = 0
1726+
1727+
while (idxA < a.length && idxB < b.length) {
1728+
val elemA = a(idxA).force
1729+
val elemB = b(idxB).force
1730+
val keyA = applyKeyF(elemA, keyF, pos, ev)
1731+
val keyB = applyKeyF(elemB, keyF, pos, ev)
1732+
1733+
val cmp = ev.compare(keyA, keyB)
1734+
if (cmp < 0) {
1735+
out.append(a(idxA))
1736+
idxA += 1
1737+
} else if (cmp > 0) {
1738+
out.append(b(idxB))
1739+
idxB += 1
1740+
} else {
1741+
// Equal keys: take from a, skip both
1742+
out.append(a(idxA))
1743+
idxA += 1
1744+
idxB += 1
1745+
}
1746+
}
1747+
// Append remaining elements
1748+
while (idxA < a.length) {
1749+
out.append(a(idxA))
1750+
idxA += 1
1751+
}
1752+
while (idxB < b.length) {
1753+
out.append(b(idxB))
1754+
idxB += 1
1755+
}
1756+
Val.Arr(pos, out.toArray)
17191757
}
17201758
},
17211759
builtinWithDefaults("setInter", "a" -> null, "b" -> null, "keyF" -> Val.False(dummyPos)) {
@@ -1727,13 +1765,27 @@ class Std(
17271765
val a = toSetArr(args, 0, pos, ev)
17281766
val b = toSetArr(args, 1, pos, ev)
17291767

1730-
val out = new mutable.ArrayBuffer[Lazy]
1731-
1732-
// The intersection will always be, at most, the size of the smallest set.
1733-
val sets = if (b.length < a.length) (b, a) else (a, b)
1734-
for (v <- sets._1) {
1735-
if (existsInSet(ev, pos, keyF, sets._2, v.force)) {
1736-
out.append(v)
1768+
// Use merge-walk algorithm: O(n + m) instead of O(min(n,m) * log(max(n,m)))
1769+
val out = new mutable.ArrayBuffer[Lazy](math.min(a.length, b.length))
1770+
var idxA = 0
1771+
var idxB = 0
1772+
1773+
while (idxA < a.length && idxB < b.length) {
1774+
val elemA = a(idxA).force
1775+
val elemB = b(idxB).force
1776+
val keyA = applyKeyF(elemA, keyF, pos, ev)
1777+
val keyB = applyKeyF(elemB, keyF, pos, ev)
1778+
1779+
val cmp = ev.compare(keyA, keyB)
1780+
if (cmp < 0) {
1781+
idxA += 1
1782+
} else if (cmp > 0) {
1783+
idxB += 1
1784+
} else {
1785+
// Equal keys: found intersection element
1786+
out.append(a(idxA))
1787+
idxA += 1
1788+
idxB += 1
17371789
}
17381790
}
17391791
Val.Arr(pos, out.toArray)
@@ -1746,14 +1798,45 @@ class Std(
17461798

17471799
val a = toSetArr(args, 0, pos, ev)
17481800
val b = toSetArr(args, 1, pos, ev)
1749-
val out = new mutable.ArrayBuffer[Lazy]
17501801

1751-
for (v <- a) {
1752-
if (!existsInSet(ev, pos, keyF, b, v.force)) {
1753-
out.append(v)
1802+
if (b.isEmpty) {
1803+
args(0)
1804+
} else {
1805+
// Use merge-walk algorithm: O(n + m) instead of O(n * log(m))
1806+
val out = new mutable.ArrayBuffer[Lazy](a.length)
1807+
var idxA = 0
1808+
var idxB = 0
1809+
1810+
while (idxA < a.length) {
1811+
val elemA = a(idxA).force
1812+
val keyA = applyKeyF(elemA, keyF, pos, ev)
1813+
1814+
// Advance idxB past elements smaller than keyA
1815+
while (idxB < b.length && {
1816+
val keyB = applyKeyF(b(idxB).force, keyF, pos, ev)
1817+
ev.compare(keyA, keyB) > 0
1818+
}) {
1819+
idxB += 1
1820+
}
1821+
1822+
// Check if current b element matches keyA
1823+
if (idxB < b.length) {
1824+
val keyB = applyKeyF(b(idxB).force, keyF, pos, ev)
1825+
if (ev.compare(keyA, keyB) != 0) {
1826+
// keyA not found in b, add to output
1827+
out.append(a(idxA))
1828+
} else {
1829+
// Found match, skip this element and advance b
1830+
idxB += 1
1831+
}
1832+
} else {
1833+
// Exhausted b, all remaining elements of a are in the diff
1834+
out.append(a(idxA))
1835+
}
1836+
idxA += 1
17541837
}
1838+
Val.Arr(pos, out.toArray)
17551839
}
1756-
Val.Arr(pos, out.toArray)
17571840
},
17581841
builtinWithDefaults("setMember", "x" -> null, "arr" -> null, "keyF" -> Val.False(dummyPos)) {
17591842
(args, pos, ev) =>
@@ -1972,6 +2055,14 @@ class Std(
19722055
}
19732056
}
19742057

2058+
private def applyKeyF(elem: Val, keyF: Val, pos: Position, ev: EvalScope): Val = {
2059+
keyF match {
2060+
case keyFFunc: Val.Func =>
2061+
keyFFunc.apply1(elem, pos.noOffset)(ev, TailstrictModeDisabled)
2062+
case _ => elem
2063+
}
2064+
}
2065+
19752066
private def existsInSet(
19762067
ev: EvalScope,
19772068
pos: Position,

0 commit comments

Comments
 (0)