diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 3d166c55f3..82ff0200e9 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -17,6 +17,7 @@ plugins { alias(kodex) alias(buildconfig) alias(binary.compatibility.validator) + alias(kotlinx.benchmark) // generates keywords using the :generator module alias(keywordGenerator) @@ -75,6 +76,7 @@ dependencies { testImplementation(libs.kotestAssertions) { exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8") } + testImplementation(libs.kotlinx.benchmark.runtime) testImplementation(libs.kotlin.scriptingJvm) testImplementation(libs.jsoup) testImplementation(libs.sl4jsimple) @@ -89,6 +91,17 @@ dependencies { testImplementation(projects.dataframeCsv) } +benchmark { + targets { + register("test") + } + configurations { + register("sort") { + include("SortingBenchmark") + } + } +} + val samplesImplementation by configurations.getting { extendsFrom(configurations.testImplementation.get()) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/DataFrame.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/DataFrame.kt index adff782b74..88d9150f49 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/DataFrame.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/DataFrame.kt @@ -102,9 +102,11 @@ public interface DataFrame : @RequiredByIntellijPlugin public operator fun get(index: Int): DataRow - public operator fun get(indices: Iterable): DataFrame = getRows(indices) + public operator fun get(indices: Iterable): DataFrame = + columns().map { col -> col[indices] }.toDataFrame().cast() - public operator fun get(range: IntRange): DataFrame = getRows(range) + public operator fun get(range: IntRange): DataFrame = + if (range == indices()) this else columns().map { col -> col[range] }.toDataFrame().cast() public operator fun get(first: IntRange, vararg ranges: IntRange): DataFrame = getRows(headPlusArray(first, ranges).asSequence().flatMap { it.asSequence() }.asIterable()) 
diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataFrameGet.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataFrameGet.kt index 77bd83cf1e..6f7f964918 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataFrameGet.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataFrameGet.kt @@ -45,11 +45,9 @@ public fun DataFrame.getColumns(vararg columns: String): List = g public fun DataFrame.getColumnIndex(col: AnyCol): Int = getColumnIndex(col.name()) -public fun DataFrame.getRows(range: IntRange): DataFrame = - if (range == indices()) this else columns().map { col -> col[range] }.toDataFrame().cast() +public fun DataFrame.getRows(range: IntRange): DataFrame = get(range) -public fun DataFrame.getRows(indices: Iterable): DataFrame = - columns().map { col -> col[indices] }.toDataFrame().cast() +public fun DataFrame.getRows(indices: Iterable): DataFrame = get(indices) public fun DataFrame.getOrNull(index: Int): DataRow? = if (index < 0 || index >= nrow) null else get(index) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/KotlinNotebookPluginUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/KotlinNotebookPluginUtils.kt index 657b84e69c..9781af2c90 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/KotlinNotebookPluginUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/KotlinNotebookPluginUtils.kt @@ -3,6 +3,7 @@ package org.jetbrains.kotlinx.dataframe.jupyter import org.jetbrains.kotlinx.dataframe.AnyCol import org.jetbrains.kotlinx.dataframe.AnyFrame import org.jetbrains.kotlinx.dataframe.AnyRow +import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.annotations.RequiredByIntellijPlugin import org.jetbrains.kotlinx.dataframe.api.Convert @@ -29,13 +30,17 @@ import org.jetbrains.kotlinx.dataframe.api.at import 
org.jetbrains.kotlinx.dataframe.api.dataFrameOf import org.jetbrains.kotlinx.dataframe.api.frames import org.jetbrains.kotlinx.dataframe.api.getColumn +import org.jetbrains.kotlinx.dataframe.api.getRows import org.jetbrains.kotlinx.dataframe.api.into -import org.jetbrains.kotlinx.dataframe.api.sortWith +import org.jetbrains.kotlinx.dataframe.api.isFrameColumn +import org.jetbrains.kotlinx.dataframe.api.isList +import org.jetbrains.kotlinx.dataframe.api.rows import org.jetbrains.kotlinx.dataframe.api.toDataFrame import org.jetbrains.kotlinx.dataframe.api.values import org.jetbrains.kotlinx.dataframe.api.valuesAreComparable import org.jetbrains.kotlinx.dataframe.columns.ColumnPath import org.jetbrains.kotlinx.dataframe.impl.ColumnNameGenerator +import java.util.Arrays /** * A class with utility methods for Kotlin Notebook Plugin integration. @@ -68,6 +73,10 @@ public object KotlinNotebookPluginUtils { /** * Sorts a dataframe-like object by multiple columns. * If a column type is not comparable, sorting by string representation is applied instead. + * Frame columns are sorted by the row count of each nested DataFrame (and list columns by list size), because looking at the smallest / biggest groups after groupBy is very popular. + * + * Returns a "lazily materialized" dataframe, which means that a get, getRows or take operation must be applied to turn it into a valid sorted dataframe. + * "Lazily materialized" means that, for example, after sorting 1 million rows with a page size of 100, a dataframe with only 100 rows is created. * * @param dataFrameLike The dataframe-like object to sort. * @param columnPaths The list of columns to sort by. 
Each element in the list represents a column path @@ -103,60 +112,100 @@ public object KotlinNotebookPluginUtils { ColumnPath(path) } - val comparator = createComparator(sortKeys, isDesc) + if (sortKeys.size == 1) { + val column = df.getColumn(sortKeys[0]) + + // Not sure how to have generic logic that would produce both an index comparator and a row comparator without overhead. + // For now, the row comparator is still needed for the fallback case of sorting by multiple columns, although that is currently impossible in the UI. + // Please make sure to change both this and createColumnComparator. + val comparator: Comparator = when { + column.valuesAreComparable() -> compareBy(nullsLast()) { + column[it] as Comparable? + } + + column.isFrameColumn() -> compareBy { column[it].rowsCount() } + + column.isList() -> compareBy { (column[it] as? List<*>)?.size ?: 0 } + + else -> compareBy { column[it]?.toString() ?: "" } + } + + val finalComparator = if (isDesc[0]) comparator.reversed() else comparator + + val permutation = Array(column.size()) { it } + Arrays.parallelSort(permutation, finalComparator) + return SortedDataFrameView(df, permutation.asList()) + } + + val comparator = createComparator(df, sortKeys, isDesc) - return df.sortWith(comparator) + return df.sortWithLazy(comparator) } - private fun createComparator(sortKeys: List, isDesc: List): Comparator> { - return Comparator { row1, row2 -> - for ((key, desc) in sortKeys.zip(isDesc)) { - val comparisonResult = if (row1.df().getColumn(key).valuesAreComparable()) { - compareComparableValues(row1, row2, key, desc) - } else { - compareStringValues(row1, row2, key, desc) + private fun createComparator( + df: AnyFrame, + sortKeys: List, + isDesc: List, + ): Comparator> { + val columnComparators = sortKeys.zip(isDesc).map { (key, desc) -> + val column = df.getColumn(key) + createColumnComparator(column, desc) + } + + return when (columnComparators.size) { + 1 -> columnComparators.single() + + else -> Comparator { row1, row2 -> + for (comparator in columnComparators) { + val 
result = comparator.compare(row1, row2) + // If a comparison result is non-zero, we have resolved the ordering + if (result != 0) return@Comparator result } - // If a comparison result is non-zero, we have resolved the ordering - if (comparisonResult != 0) return@Comparator comparisonResult + // All comparisons are equal + 0 } - // All comparisons are equal - 0 } } - @Suppress("UNCHECKED_CAST") - private fun compareComparableValues( - row1: DataRow<*>, - row2: DataRow<*>, - key: ColumnPath, - desc: Boolean, - ): Int { - val firstValue = row1.getValueOrNull(key) as Comparable? - val secondValue = row2.getValueOrNull(key) as Comparable? - - return when { - firstValue == null && secondValue == null -> 0 - firstValue == null -> if (desc) 1 else -1 - secondValue == null -> if (desc) -1 else 1 - desc -> secondValue.compareTo(firstValue) - else -> firstValue.compareTo(secondValue) + private fun createColumnComparator(column: AnyCol, desc: Boolean): Comparator> { + val comparator: Comparator> = when { + column.valuesAreComparable() -> compareBy(nullsLast()) { + column[it] as Comparable? + } + + // Comparator shows a slight improvement in performance for this case + column.isFrameColumn() -> Comparator { r1, r2 -> + column[r1].rowsCount().compareTo(column[r2].rowsCount()) + } + + column.isList() -> compareBy { (column[it] as? 
List<*>)?.size ?: 0 } + + else -> compareBy { column[it]?.toString() ?: "" } } + return if (desc) comparator.reversed() else comparator } - private fun compareStringValues( - row1: DataRow<*>, - row2: DataRow<*>, - key: ColumnPath, - desc: Boolean, - ): Int { - val firstValue = (row1.getValueOrNull(key)?.toString() ?: "") - val secondValue = (row2.getValueOrNull(key)?.toString() ?: "") - - return if (desc) { - secondValue.compareTo(firstValue) - } else { - firstValue.compareTo(secondValue) + private fun DataFrame.sortWithLazy(comparator: Comparator>): DataFrame { + val permutation = rows().sortedWith(comparator).map { it.index() } + return SortedDataFrameView(this, permutation) + } + + private class SortedDataFrameView(private val source: DataFrame, private val permutation: List) : + DataFrame by source { + + override operator fun get(index: Int): DataRow = source[permutation[index]] + + override operator fun get(range: IntRange): DataFrame { + val indices = range.map { permutation[it] } + return source.getRows(indices) } + + override operator fun get(indices: Iterable): DataFrame { + val mappedIndices = indices.map { permutation[it] } + return source.getRows(mappedIndices) + } + + override fun get(columnName: String): AnyCol = super.get(columnName)[permutation] } internal fun isDataframeConvertable(dataframeLike: Any?): Boolean = diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/KotlinNotebookPluginUtilsTests.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/KotlinNotebookPluginUtilsTests.kt new file mode 100644 index 0000000000..03e3a4f575 --- /dev/null +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/KotlinNotebookPluginUtilsTests.kt @@ -0,0 +1,84 @@ +package org.jetbrains.kotlinx.dataframe + +import io.kotest.matchers.shouldBe +import org.jetbrains.kotlinx.dataframe.api.dataFrameOf +import org.jetbrains.kotlinx.dataframe.api.toColumn +import org.jetbrains.kotlinx.dataframe.jupyter.KotlinNotebookPluginUtils +import 
org.junit.Test +import kotlin.random.Random + +/** + * Other tests are located in Jupyter module: + * org.jetbrains.kotlinx.dataframe.jupyter.RenderingTests + */ +class KotlinNotebookPluginUtilsTests { + @Test + fun `sort lists by size descending`() { + val random = Random(123) + val lists = List(20) { List(random.nextInt(1, 100)) { it } } + null + val df = dataFrameOf("listColumn" to lists) + + val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("listColumn")), desc = listOf(true)) + + res["listColumn"].values() shouldBe lists.sortedByDescending { it?.size ?: 0 } + } + + @Test + fun `sort lists by size ascending`() { + val lists = listOf(listOf(1, 2, 3), listOf(1), listOf(1, 2), null) + val df = dataFrameOf("listColumn" to lists) + + val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("listColumn")), desc = listOf(false)) + + res["listColumn"].values() shouldBe listOf(null, listOf(1), listOf(1, 2), listOf(1, 2, 3)) + } + + @Test + fun `sort empty lists`() { + val lists = listOf(listOf(1, 2), emptyList(), listOf(1), emptyList()) + val df = dataFrameOf("listColumn" to lists) + + val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("listColumn")), desc = listOf(true)) + + res["listColumn"].values() shouldBe listOf(listOf(1, 2), listOf(1), emptyList(), emptyList()) + } + + @Test + fun `sort lists with equal sizes preserves stability`() { + val lists = listOf(listOf("a"), listOf("b"), listOf("c")) + val df = dataFrameOf("listColumn" to lists) + + val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("listColumn")), desc = listOf(true)) + + res["listColumn"].values() shouldBe lists + } + + @Test + fun `sort frame column by row count descending`() { + val frames = listOf( + dataFrameOf("x" to listOf(1)), + dataFrameOf("x" to listOf(1, 2, 3)), + dataFrameOf("x" to listOf(1, 2)), + DataFrame.empty(), + ) + val df = dataFrameOf("nested" to frames.toColumn()) + + val res = 
KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("nested")), desc = listOf(true)) + + res["nested"].values().map { (it as DataFrame<*>).rowsCount() } shouldBe listOf(3, 2, 1, 0) + } + + @Test + fun `sort frame column by row count ascending`() { + val frames = listOf( + dataFrameOf("x" to listOf(1, 2, 3)), + dataFrameOf("x" to listOf(1)), + DataFrame.empty(), + ) + val df = dataFrameOf("nested" to frames.toColumn()) + + val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("nested")), desc = listOf(false)) + + res["nested"].values().map { (it as DataFrame<*>).rowsCount() } shouldBe listOf(0, 1, 3) + } +} diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/SortingBenchmark.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/SortingBenchmark.kt new file mode 100644 index 0000000000..215571722c --- /dev/null +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/SortingBenchmark.kt @@ -0,0 +1,59 @@ +@file:Suppress("EmptyRange") + +package org.jetbrains.kotlinx.dataframe + +import kotlinx.benchmark.Benchmark +import kotlinx.benchmark.BenchmarkMode +import kotlinx.benchmark.BenchmarkTimeUnit +import kotlinx.benchmark.Measurement +import kotlinx.benchmark.Mode +import kotlinx.benchmark.OutputTimeUnit +import kotlinx.benchmark.Param +import kotlinx.benchmark.Scope +import kotlinx.benchmark.Setup +import kotlinx.benchmark.State +import kotlinx.benchmark.Warmup +import org.jetbrains.kotlinx.dataframe.api.dataFrameOf +import org.jetbrains.kotlinx.dataframe.api.toColumn +import org.jetbrains.kotlinx.dataframe.api.toDataFrame +import org.jetbrains.kotlinx.dataframe.jupyter.KotlinNotebookPluginUtils +import kotlin.random.Random + +@State(Scope.Benchmark) +@Warmup(iterations = 5, time = 1) +@Measurement(iterations = 5, time = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(BenchmarkTimeUnit.MILLISECONDS) +open class SortingBenchmark { + + @Param("10000", "100000", "1000000") + var size: Int = 0 + + @Param("int", "string", 
"double", "category", "list", "frame") + lateinit var columnType: String + + private lateinit var df: DataFrame<*> + private lateinit var columnPath: List + + @Setup + fun setup() { + val random = Random(42) + df = (0 until size).toDataFrame { + "int" from { it } + "string" from { "name_${random.nextInt(1000)}" } + "double" from { random.nextDouble() } + "category" from { listOf("A", "B", "C", "D").random(random) } + "list" from { List(random.nextInt(1, 20)) { "tag$it" } } + "frame" from { + dataFrameOf("x" to List(random.nextInt(1, 50)) { random.nextInt() }.toColumn()) + } + } + columnPath = listOf(columnType) + } + + @Benchmark + fun sort(): DataFrame<*> { + val sorted = KotlinNotebookPluginUtils.sortByColumns(df, listOf(columnPath), listOf(false)) + return KotlinNotebookPluginUtils.getRowsSubsetForRendering(sorted, 0, 20).value + } +}