Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions tests/test_ml_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
import pytest

import multipers as mp
import multipers.ml.kernels as kernels


def test_distance_matrix_to_list():
"""Test DistanceMatrix2DistanceList transformer"""
# Create a simple distance matrix
dist_matrix = np.array([[0, 1, 2], [1, 0, 1.5], [2, 1.5, 0]])

transformer = kernels.DistanceMatrix2DistanceList()
result = transformer.fit_transform([dist_matrix])
assert result is not None
assert len(result) == 1


def test_distance_list_to_matrix():
"""Test DistanceList2DistanceMatrix transformer"""
# Create a simple distance list
dist_list = np.array([1, 2, 1.5])

transformer = kernels.DistanceList2DistanceMatrix()
result = transformer.fit_transform([dist_list])
assert result is not None
assert len(result) == 1


def test_distance_matrices_to_lists():
"""Test DistanceMatrices2DistancesList transformer"""
# Create distance matrices
dist_matrices = [
np.array([[0, 1], [1, 0]]),
np.array([[0, 2], [2, 0]])
]

transformer = kernels.DistanceMatrices2DistancesList()
result = transformer.fit_transform([dist_matrices])
assert result is not None
assert len(result) == 1


def test_distance_lists_to_matrices():
"""Test DistancesLists2DistanceMatrices transformer"""
# Create distance lists
dist_lists = [np.array([1]), np.array([2])]

transformer = kernels.DistancesLists2DistanceMatrices()
result = transformer.fit_transform([dist_lists])
assert result is not None
assert len(result) == 1


def test_distance_matrix_to_kernel():
"""Test DistanceMatrix2Kernel transformer"""
# Create a simple distance matrix
dist_matrix = np.array([[0, 1, 2], [1, 0, 1.5], [2, 1.5, 0]])

transformer = kernels.DistanceMatrix2Kernel()
result = transformer.fit_transform([dist_matrix])
assert result is not None
assert len(result) == 1
120 changes: 120 additions & 0 deletions tests/test_ml_pipelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import numpy as np
import pytest
import platform

import multipers as mp
import multipers.ml.mma as mma
import multipers.ml.signed_measures as signed_measures
import multipers.ml.tools as tools
from multipers.tests import random_st


def test_filtered_complex_to_mma():
"""Test FilteredComplex2MMA transformer"""
st = mp.SimplexTreeMulti(num_parameters=2)
st.insert([0], [0, 1])
st.insert([1], [1, 0])
st.insert([0, 1], [1, 1])

transformer = mma.FilteredComplex2MMA()
result = transformer.fit_transform([[st]])
assert result is not None
assert len(result) == 1


def test_mma_formatter():
"""Test MMAFormatter transformer"""
st = mp.SimplexTreeMulti(num_parameters=2)
st.insert([0], [0, 1])
st.insert([1], [1, 0])
st.insert([0, 1], [1, 1])

# First get MMA
mma_transformer = mma.FilteredComplex2MMA()
mma_result = mma_transformer.fit_transform([[st]])

# Then format
formatter = mma.MMAFormatter()
result = formatter.fit_transform(mma_result)
assert result is not None


def test_filtered_complex_to_signed_measure():
"""Test FilteredComplex2SignedMeasure transformer"""
st = mp.SimplexTreeMulti(num_parameters=2)
st.insert([0], [0, 1])
st.insert([1], [1, 0])
st.insert([0, 1], [1, 1])

transformer = signed_measures.FilteredComplex2SignedMeasure()
result = transformer.fit_transform([[st]])
assert result is not None
assert len(result) == 1


def test_signed_measure_formatter():
"""Test SignedMeasureFormatter transformer"""
st = mp.SimplexTreeMulti(num_parameters=2)
st.insert([0], [0, 1])
st.insert([1], [1, 0])
st.insert([0, 1], [1, 1])

# First get signed measures
sm_transformer = signed_measures.FilteredComplex2SignedMeasure()
sm_result = sm_transformer.fit_transform([[st]])

# Then format
formatter = signed_measures.SignedMeasureFormatter()
result = formatter.fit_transform(sm_result)
assert result is not None


def test_simplex_tree_edge_collapser():
"""Test SimplexTreeEdgeCollapser from tools"""
st = random_st(num_parameters=2)

collapser = tools.SimplexTreeEdgeCollapser()
result = collapser.fit_transform([st])
assert result is not None
assert len(result) == 1


@pytest.mark.skipif(
platform.system() == "Windows",
reason="Detected windows. Pykeops is not compatible with windows yet. Skipping this ftm.",
)
def test_point_cloud_to_filtered_complex():
"""Test point cloud to filtered complex pipeline"""
import multipers.ml.point_clouds as mmp

pts = np.array([[1, 1], [2, 2]], dtype=np.float32)

# Test basic functionality
transformer = mmp.PointCloud2FilteredComplex(masses=[0.1])
result = transformer.fit_transform([pts])
assert result is not None
assert len(result) == 1

# Check result type
st = result[0][0]
assert isinstance(st, mp.simplex_tree_multi.SimplexTreeMulti_type)


@pytest.mark.skipif(
platform.system() == "Windows",
reason="Detected windows. Pykeops is not compatible with windows yet. Skipping this ftm.",
)
def test_point_cloud_alpha_complex():
"""Test point cloud with alpha complex"""
import multipers.ml.point_clouds as mmp

pts = np.array([[1, 1], [2, 2]], dtype=np.float32)

transformer = mmp.PointCloud2FilteredComplex(
bandwidths=[-0.1], complex="alpha"
)
result = transformer.fit_transform([pts])
assert result is not None

st = result[0][0]
assert isinstance(st, mp.simplex_tree_multi.SimplexTreeMulti_type)
29 changes: 29 additions & 0 deletions tests/test_ml_sliced_wasserstein.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy as np
import pytest

import multipers as mp
import multipers.ml.sliced_wasserstein as sliced_wasserstein


def test_sliced_wasserstein_distance():
"""Test SlicedWassersteinDistance transformer"""
# Create simple persistence diagrams
dgm1 = np.array([[0.0, 1.0], [0.5, 2.0]])
dgm2 = np.array([[0.2, 1.2], [0.7, 1.8]])

transformer = sliced_wasserstein.SlicedWassersteinDistance()
result = transformer.fit_transform([[dgm1], [dgm2]])
assert result is not None
assert len(result) == 2


def test_wasserstein_distance():
"""Test WassersteinDistance transformer"""
# Create simple persistence diagrams
dgm1 = np.array([[0.0, 1.0], [0.5, 2.0]])
dgm2 = np.array([[0.2, 1.2], [0.7, 1.8]])

transformer = sliced_wasserstein.WassersteinDistance()
result = transformer.fit_transform([[dgm1], [dgm2]])
assert result is not None
assert len(result) == 2
Loading