diff --git a/specs b/specs index c75927fdc..b1720e4a5 160000 --- a/specs +++ b/specs @@ -1 +1 @@ -Subproject commit c75927fdc0aa6714f638ab28165ca78bf2759a76 +Subproject commit b1720e4a5010942d412624d342b636070a41502b diff --git a/tests/CLAUDE.md b/tests/CLAUDE.md new file mode 100644 index 000000000..faa4047fc --- /dev/null +++ b/tests/CLAUDE.md @@ -0,0 +1,31 @@ +# Testing Guidelines + +## General Principles + +- **No test classes** - Use plain functions, not class-based test organization +- **Minimal mocking** - Use real backends with temp directories whenever possible; only mock when absolutely necessary +- **Use mocker fixture** - When mocking is required, use the `mocker` fixture (pytest-mock), never unittest mocks + +## File Organization + +``` +tests/unit/layer/volumetric/<layer_name>/ +├── test_<component>.py # One file per component +``` + +## Naming Conventions + +- Test files: `test_<component>.py` +- Test functions: `test_<function>_<scenario>` + +## Fixtures + +- Function-scoped (default) +- Use `tempfile.TemporaryDirectory()` for temp dirs +- Explicit cleanup via context managers or yield pattern + +## Assertions + +- `pytest.approx()` for float comparisons +- `np.testing.assert_array_equal()` for array comparisons +- `pytest.raises(ExceptionType, match="pattern")` for exception testing diff --git a/tests/unit/layer/volumetric/seg_contact/test_backend.py b/tests/unit/layer/volumetric/seg_contact/test_backend.py new file mode 100644 index 000000000..fa35d230f --- /dev/null +++ b/tests/unit/layer/volumetric/seg_contact/test_backend.py @@ -0,0 +1,617 @@ +import json +import os +import tempfile + +import numpy as np +import pytest + +from zetta_utils.geometry import BBox3D, Vec3D +from zetta_utils.layer.volumetric import VolumetricIndex +from zetta_utils.layer.volumetric.seg_contact import SegContact, SegContactLayerBackend + +# --- Chunk naming tests --- + + +def test_get_chunk_name(): + """Test chunk naming follows precomputed format.""" + backend = SegContactLayerBackend( + path="/tmp/test", 
resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + # First chunk at origin + assert backend.get_chunk_name((0, 0, 0)) == "0-256_0-256_0-128" + + # Second chunk in x + assert backend.get_chunk_name((1, 0, 0)) == "256-512_0-256_0-128" + + # Chunk at (1, 2, 3) + assert backend.get_chunk_name((1, 2, 3)) == "256-512_512-768_384-512" + + +def test_get_chunk_name_with_offset(): + """Test chunk naming with non-zero voxel offset.""" + backend = SegContactLayerBackend( + path="/tmp/test", + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(100, 200, 50), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + # First chunk starts at offset + assert backend.get_chunk_name((0, 0, 0)) == "100-356_200-456_50-178" + + +def test_get_chunk_path(): + """Test chunk path includes contacts subdirectory.""" + backend = SegContactLayerBackend( + path="/tmp/test_dataset", + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + chunk_path = backend.get_chunk_path((0, 0, 0)) + assert chunk_path == "/tmp/test_dataset/contacts/0-256_0-256_0-128" + + +# --- COM to chunk index tests --- + + +def test_com_to_chunk_idx(): + """Test converting COM in nanometers to chunk grid index.""" + backend = SegContactLayerBackend( + path="/tmp/test", + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + # COM at origin -> chunk (0, 0, 0) + assert backend.com_to_chunk_idx(Vec3D(0.0, 0.0, 0.0)) == (0, 0, 0) + + # COM in middle of first chunk (chunk_size * resolution / 2) + # chunk_size[0] * resolution[0] = 256 * 16 = 4096 nm + # So 2000 nm is still in first chunk + assert backend.com_to_chunk_idx(Vec3D(2000.0, 2000.0, 2000.0)) == (0, 0, 0) + + # COM at 
start of second chunk in x + # 256 voxels * 16 nm/voxel = 4096 nm + assert backend.com_to_chunk_idx(Vec3D(4096.0, 0.0, 0.0)) == (1, 0, 0) + + +def test_com_to_chunk_idx_with_offset(): + """Test COM to chunk index with non-zero voxel offset.""" + backend = SegContactLayerBackend( + path="/tmp/test", + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(100, 0, 0), # offset of 100 voxels in x = 1600 nm + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + # COM at 1600 nm (which is voxel 100, the start) -> chunk (0, 0, 0) + assert backend.com_to_chunk_idx(Vec3D(1600.0, 0.0, 0.0)) == (0, 0, 0) + + # COM at 1600 + 4096 = 5696 nm -> chunk (1, 0, 0) + assert backend.com_to_chunk_idx(Vec3D(5696.0, 0.0, 0.0)) == (1, 0, 0) + + +def test_com_to_chunk_idx_different_resolutions(): + """Test COM to chunk index with anisotropic resolution.""" + backend = SegContactLayerBackend( + path="/tmp/test", + resolution=Vec3D(8, 8, 40), # different z resolution + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + # chunk_size_nm = (256*8, 256*8, 128*40) = (2048, 2048, 5120) + assert backend.com_to_chunk_idx(Vec3D(0.0, 0.0, 0.0)) == (0, 0, 0) + assert backend.com_to_chunk_idx(Vec3D(2048.0, 0.0, 0.0)) == (1, 0, 0) + assert backend.com_to_chunk_idx(Vec3D(0.0, 0.0, 5120.0)) == (0, 0, 1) + + +# --- Info file tests --- + + +def test_write_info_creates_file(): + """Test that write_info creates a valid info JSON file.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + backend.write_info() + + info_path = os.path.join(temp_dir, "info") + assert os.path.exists(info_path) + + with open(info_path, "r") as f: + info = json.load(f) + + assert info["type"] == "seg_contact" + 
assert info["resolution"] == [16, 16, 40] + assert info["chunk_size"] == [256, 256, 128] + + +def test_write_info_all_fields(): + """Test that write_info writes all required fields.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(100, 200, 50), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + backend.write_info() + + with open(os.path.join(temp_dir, "info"), "r") as f: + info = json.load(f) + + assert info["format_version"] == "1.0" + assert info["type"] == "seg_contact" + assert info["resolution"] == [16, 16, 40] + assert info["voxel_offset"] == [100, 200, 50] + assert info["size"] == [1000, 1000, 500] + assert info["chunk_size"] == [256, 256, 128] + assert info["max_contact_span"] == 512 + + +def test_from_path_loads_info(): + """Test loading backend from existing info file.""" + with tempfile.TemporaryDirectory() as temp_dir: + # First create and write + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(100, 200, 50), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + backend.write_info() + + # Then load from path + loaded = SegContactLayerBackend.from_path(temp_dir) + + assert loaded.resolution == Vec3D(16, 16, 40) + assert loaded.voxel_offset == Vec3D(100, 200, 50) + assert loaded.size == Vec3D(1000, 1000, 500) + assert loaded.chunk_size == Vec3D(256, 256, 128) + assert loaded.max_contact_span == 512 + + +def test_from_path_missing_info_raises(): + """Test that from_path raises when info file doesn't exist.""" + with tempfile.TemporaryDirectory() as temp_dir: + with pytest.raises(FileNotFoundError): + SegContactLayerBackend.from_path(temp_dir) + + +# --- Chunk write/read tests --- + + +def make_seg_contact( + id: int, + seg_a: int, + seg_b: int, + com: tuple[float, float, float], + n_faces: int = 3, +) -> 
SegContact: + """Helper to create a SegContact for testing.""" + contact_faces = np.array( + [[com[0] + i, com[1] + i, com[2] + i, 0.5 + i * 0.1] for i in range(n_faces)], + dtype=np.float32, + ) + return SegContact( + id=id, + seg_a=seg_a, + seg_b=seg_b, + com=Vec3D(*com), + contact_faces=contact_faces, + ) + + +def test_write_chunk_creates_file(): + """Test that write_chunk creates a chunk file.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + contact = make_seg_contact(id=1, seg_a=100, seg_b=200, com=(100.0, 100.0, 100.0)) + backend.write_chunk((0, 0, 0), [contact]) + + chunk_path = backend.get_chunk_path((0, 0, 0)) + assert os.path.exists(chunk_path) + + +def test_write_read_chunk_single_contact(): + """Test round-trip of a single contact through write_chunk/read_chunk.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + contact = make_seg_contact( + id=42, seg_a=100, seg_b=200, com=(100.0, 200.0, 300.0), n_faces=5 + ) + backend.write_chunk((0, 0, 0), [contact]) + + contacts_read = backend.read_chunk((0, 0, 0)) + + assert len(contacts_read) == 1 + c = contacts_read[0] + assert c.id == 42 + assert c.seg_a == 100 + assert c.seg_b == 200 + assert c.com == Vec3D(100.0, 200.0, 300.0) + assert c.contact_faces.shape == (5, 4) + np.testing.assert_array_almost_equal(c.contact_faces, contact.contact_faces) + + +def test_write_read_chunk_multiple_contacts(): + """Test round-trip of multiple contacts in a single chunk.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 
16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + contacts = [ + make_seg_contact(id=1, seg_a=100, seg_b=200, com=(100.0, 100.0, 100.0), n_faces=3), + make_seg_contact(id=2, seg_a=100, seg_b=300, com=(200.0, 200.0, 200.0), n_faces=7), + make_seg_contact(id=3, seg_a=200, seg_b=300, com=(300.0, 300.0, 300.0), n_faces=1), + ] + backend.write_chunk((0, 0, 0), contacts) + + contacts_read = backend.read_chunk((0, 0, 0)) + + assert len(contacts_read) == 3 + # Verify each contact + for orig, read in zip(contacts, contacts_read): + assert read.id == orig.id + assert read.seg_a == orig.seg_a + assert read.seg_b == orig.seg_b + assert read.com == orig.com + np.testing.assert_array_almost_equal(read.contact_faces, orig.contact_faces) + + +def test_write_read_chunk_empty(): + """Test writing and reading an empty chunk.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + backend.write_chunk((0, 0, 0), []) + + contacts_read = backend.read_chunk((0, 0, 0)) + assert len(contacts_read) == 0 + + +def test_read_chunk_nonexistent_returns_empty(): + """Test reading a chunk that doesn't exist returns empty list.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + contacts_read = backend.read_chunk((0, 0, 0)) + assert len(contacts_read) == 0 + + +def test_write_read_chunk_with_partner_metadata(): + """Test round-trip of contact with partner_metadata.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 
40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + contact = SegContact( + id=1, + seg_a=100, + seg_b=200, + com=Vec3D(100.0, 100.0, 100.0), + contact_faces=np.array([[100, 100, 100, 0.5]], dtype=np.float32), + partner_metadata={100: {"type": "axon"}, 200: {"type": "dendrite"}}, + ) + backend.write_chunk((0, 0, 0), [contact]) + + contacts_read = backend.read_chunk((0, 0, 0)) + + assert len(contacts_read) == 1 + assert contacts_read[0].partner_metadata == { + 100: {"type": "axon"}, + 200: {"type": "dendrite"}, + } + + +def test_write_read_chunk_with_no_partner_metadata(): + """Test round-trip of contact without partner_metadata.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + contact = SegContact( + id=1, + seg_a=100, + seg_b=200, + com=Vec3D(100.0, 100.0, 100.0), + contact_faces=np.array([[100, 100, 100, 0.5]], dtype=np.float32), + partner_metadata=None, + ) + backend.write_chunk((0, 0, 0), [contact]) + + contacts_read = backend.read_chunk((0, 0, 0)) + + assert len(contacts_read) == 1 + assert contacts_read[0].partner_metadata is None + + +def test_write_read_chunk_large_contact_faces(): + """Test contact with many faces.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + + n_faces = 1000 + contact_faces = np.random.rand(n_faces, 4).astype(np.float32) + contact = SegContact( + id=1, + seg_a=100, + seg_b=200, + com=Vec3D(100.0, 100.0, 100.0), + contact_faces=contact_faces, + ) + backend.write_chunk((0, 0, 0), [contact]) + + contacts_read = 
backend.read_chunk((0, 0, 0)) + + assert len(contacts_read) == 1 + assert contacts_read[0].contact_faces.shape == (n_faces, 4) + np.testing.assert_array_almost_equal(contacts_read[0].contact_faces, contact_faces) + + +# --- High-level read/write tests --- + + +def test_write_distributes_contacts_to_chunks(): + """Test that write() distributes contacts to correct chunks based on COM.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + backend.write_info() + + # chunk_size_nm in x = 256 * 16 = 4096 + # Contact 1: COM in chunk (0,0,0) + # Contact 2: COM in chunk (1,0,0) + contacts = [ + make_seg_contact(id=1, seg_a=100, seg_b=200, com=(100.0, 100.0, 100.0)), + make_seg_contact(id=2, seg_a=100, seg_b=300, com=(5000.0, 100.0, 100.0)), # > 4096 + ] + + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 512), slice(0, 256), slice(0, 128))), + ) + backend.write(idx, contacts) + + # Check chunk (0,0,0) + chunk_0 = backend.read_chunk((0, 0, 0)) + assert len(chunk_0) == 1 + assert chunk_0[0].id == 1 + + # Check chunk (1,0,0) + chunk_1 = backend.read_chunk((1, 0, 0)) + assert len(chunk_1) == 1 + assert chunk_1[0].id == 2 + + +def test_read_filters_by_bbox(): + """Test that read() filters contacts to those within the query bbox.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + backend.write_info() + + # Write contacts at different positions within chunk (0,0,0) + # All within first chunk (0-4096 nm in x) + contacts = [ + make_seg_contact(id=1, seg_a=100, seg_b=200, com=(100.0, 100.0, 100.0)), + make_seg_contact(id=2, 
seg_a=100, seg_b=300, com=(2000.0, 100.0, 100.0)), + make_seg_contact(id=3, seg_a=200, seg_b=300, com=(3500.0, 100.0, 100.0)), + ] + backend.write_chunk((0, 0, 0), contacts) + + # Query only first half of chunk (0-128 voxels = 0-2048 nm in x) + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 128), slice(0, 256), slice(0, 128))), + ) + result = backend.read(idx) + + # Should only get contacts 1 and 2 (COM < 2048 nm) + assert len(result) == 2 + ids = {c.id for c in result} + assert ids == {1, 2} + + +def test_read_spans_multiple_chunks(): + """Test that read() can span multiple chunks.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + backend.write_info() + + # Write to chunk (0,0,0) + backend.write_chunk( + (0, 0, 0), + [ + make_seg_contact(id=1, seg_a=100, seg_b=200, com=(100.0, 100.0, 100.0)), + ], + ) + + # Write to chunk (1,0,0) + backend.write_chunk( + (1, 0, 0), + [ + make_seg_contact(id=2, seg_a=100, seg_b=300, com=(5000.0, 100.0, 100.0)), + ], + ) + + # Query spanning both chunks (0-512 voxels in x) + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 512), slice(0, 256), slice(0, 128))), + ) + result = backend.read(idx) + + assert len(result) == 2 + ids = {c.id for c in result} + assert ids == {1, 2} + + +def test_read_empty_region(): + """Test reading from a region with no contacts.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + backend.write_info() + + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), 
slice(0, 256), slice(0, 128))), + ) + result = backend.read(idx) + + assert len(result) == 0 + + +def test_round_trip_full(): + """Full round-trip test: write via write(), read via read().""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + backend.write_info() + + contacts = [ + make_seg_contact(id=1, seg_a=100, seg_b=200, com=(100.0, 100.0, 100.0), n_faces=5), + make_seg_contact(id=2, seg_a=100, seg_b=300, com=(200.0, 200.0, 200.0), n_faces=10), + ] + + write_idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), slice(0, 256), slice(0, 128))), + ) + backend.write(write_idx, contacts) + + # Read back + read_idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), slice(0, 256), slice(0, 128))), + ) + result = backend.read(read_idx) + + assert len(result) == 2 + result_by_id = {c.id: c for c in result} + + for orig in contacts: + read = result_by_id[orig.id] + assert read.seg_a == orig.seg_a + assert read.seg_b == orig.seg_b + assert read.com == orig.com + np.testing.assert_array_almost_equal(read.contact_faces, orig.contact_faces) diff --git a/tests/unit/layer/volumetric/seg_contact/test_build.py b/tests/unit/layer/volumetric/seg_contact/test_build.py new file mode 100644 index 000000000..ed8e66111 --- /dev/null +++ b/tests/unit/layer/volumetric/seg_contact/test_build.py @@ -0,0 +1,282 @@ +import os +import tempfile + +import pytest + +from zetta_utils import builder +from zetta_utils.geometry import BBox3D, Vec3D +from zetta_utils.layer.volumetric.seg_contact import ( + SegContactInfoSpec, + SegContactInfoSpecParams, + SegContactLayerBackend, + VolumetricSegContactLayer, + build_seg_contact_info_spec, +) +from zetta_utils.layer.volumetric.seg_contact.build import 
build_seg_contact_layer + + +def make_backend(temp_dir: str) -> SegContactLayerBackend: + """Helper to create a backend for testing.""" + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + backend.write_info() + return backend + + +# --- Read mode tests --- + + +def test_build_seg_contact_layer_read_mode(): + """Test building a contact layer in read mode.""" + with tempfile.TemporaryDirectory() as temp_dir: + make_backend(temp_dir) + + layer = build_seg_contact_layer(path=temp_dir, mode="read") + + assert isinstance(layer, VolumetricSegContactLayer) + assert layer.readonly is True + + +def test_build_seg_contact_layer_read_nonexistent(): + """Test that read mode fails for nonexistent path.""" + with pytest.raises(FileNotFoundError): + build_seg_contact_layer(path="/path/to/nonexistent", mode="read") + + +# --- Update mode tests --- + + +def test_build_seg_contact_layer_update_mode(): + """Test building a contact layer in update mode.""" + with tempfile.TemporaryDirectory() as temp_dir: + make_backend(temp_dir) + + layer = build_seg_contact_layer(path=temp_dir, mode="update") + + assert isinstance(layer, VolumetricSegContactLayer) + assert layer.readonly is False + + +def test_build_seg_contact_layer_update_nonexistent(): + """Test that update mode fails for nonexistent path.""" + with pytest.raises(FileNotFoundError): + build_seg_contact_layer(path="/path/to/nonexistent", mode="update") + + +def test_build_seg_contact_layer_preserves_backend_properties(): + """Test that built layer has correct backend properties.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + + layer = build_seg_contact_layer(path=temp_dir, mode="read") + + assert layer.backend.resolution == backend.resolution + assert layer.backend.voxel_offset == backend.voxel_offset + assert layer.backend.size == 
backend.size + assert layer.backend.chunk_size == backend.chunk_size + assert layer.backend.max_contact_span == backend.max_contact_span + + +# --- Write mode tests --- + + +def test_build_seg_contact_layer_write_mode(): + """Test creating a new contact layer in write mode.""" + with tempfile.TemporaryDirectory() as temp_dir: + path = os.path.join(temp_dir, "new_layer") + bbox = BBox3D.from_coords([0, 0, 0], [16000, 16000, 20000], [16, 16, 40]) + info_spec = SegContactInfoSpec( + info_spec_params=SegContactInfoSpecParams( + resolution=Vec3D(16, 16, 40), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + bbox=bbox, + ) + ) + + layer = build_seg_contact_layer(path=path, mode="write", info_spec=info_spec) + + assert isinstance(layer, VolumetricSegContactLayer) + assert layer.readonly is False + assert os.path.exists(os.path.join(path, "info")) + + +def test_build_seg_contact_layer_write_existing_fails(): + """Test that write mode fails if layer exists.""" + with tempfile.TemporaryDirectory() as temp_dir: + make_backend(temp_dir) + bbox = BBox3D.from_coords([0, 0, 0], [16000, 16000, 20000], [16, 16, 40]) + info_spec = SegContactInfoSpec( + info_spec_params=SegContactInfoSpecParams( + resolution=Vec3D(16, 16, 40), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + bbox=bbox, + ) + ) + + with pytest.raises(FileExistsError): + build_seg_contact_layer(path=temp_dir, mode="write", info_spec=info_spec) + + +def test_build_seg_contact_layer_write_no_info_spec_fails(): + """Test that write mode fails without info_spec.""" + with tempfile.TemporaryDirectory() as temp_dir: + path = os.path.join(temp_dir, "new_layer") + + with pytest.raises(ValueError, match="info_spec is required"): + build_seg_contact_layer(path=path, mode="write") + + +# --- Builder system tests --- + + +def test_build_seg_contact_layer_via_builder_read(): + """Test building contact layer via builder system in read mode.""" + with tempfile.TemporaryDirectory() as temp_dir: + 
make_backend(temp_dir) + + spec = { + "@type": "build_seg_contact_layer", + "path": temp_dir, + "mode": "read", + } + + layer = builder.build(spec) + + assert isinstance(layer, VolumetricSegContactLayer) + assert layer.readonly is True + + +def test_build_seg_contact_layer_via_builder_update(): + """Test building contact layer via builder system in update mode.""" + with tempfile.TemporaryDirectory() as temp_dir: + make_backend(temp_dir) + + spec = { + "@type": "build_seg_contact_layer", + "path": temp_dir, + "mode": "update", + } + + layer = builder.build(spec) + + assert isinstance(layer, VolumetricSegContactLayer) + assert layer.readonly is False + + +def test_build_seg_contact_layer_via_builder_write(): + """Test building new contact layer via builder system.""" + with tempfile.TemporaryDirectory() as temp_dir: + path = os.path.join(temp_dir, "new_layer") + + spec = { + "@type": "build_seg_contact_layer", + "path": path, + "mode": "write", + "info_spec": { + "@type": "build_seg_contact_info_spec", + "resolution": [16, 16, 40], + "chunk_size": [256, 256, 128], + "max_contact_span": 512, + "bbox": { + "@type": "BBox3D.from_coords", + "start_coord": [0, 0, 0], + "end_coord": [1000, 1000, 500], + "resolution": [16, 16, 40], + }, + }, + } + + layer = builder.build(spec) + + assert isinstance(layer, VolumetricSegContactLayer) + assert layer.readonly is False + + +# --- SegContactInfoSpec tests --- + + +def test_contact_info_spec_from_params(): + """Test creating SegContactInfoSpec from params.""" + bbox = BBox3D.from_coords([0, 0, 0], [16000, 16000, 20000], [16, 16, 40]) + spec = SegContactInfoSpec( + info_spec_params=SegContactInfoSpecParams( + resolution=Vec3D(16, 16, 40), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + bbox=bbox, + ) + ) + + info = spec.make_info() + + assert info["type"] == "seg_contact" + assert info["resolution"] == [16, 16, 40] + assert info["chunk_size"] == [256, 256, 128] + assert info["max_contact_span"] == 512 + + +def 
test_contact_info_spec_from_path(): + """Test creating SegContactInfoSpec from existing path.""" + with tempfile.TemporaryDirectory() as temp_dir: + make_backend(temp_dir) + + spec = SegContactInfoSpec(info_path=temp_dir) + info = spec.make_info() + + assert info["type"] == "seg_contact" + assert info["resolution"] == [16, 16, 40] + + +def test_contact_info_spec_write_info(): + """Test writing info file.""" + with tempfile.TemporaryDirectory() as temp_dir: + path = os.path.join(temp_dir, "new_layer") + bbox = BBox3D.from_coords([0, 0, 0], [16000, 16000, 20000], [16, 16, 40]) + spec = SegContactInfoSpec( + info_spec_params=SegContactInfoSpecParams( + resolution=Vec3D(16, 16, 40), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + bbox=bbox, + ) + ) + + spec.write_info(path) + + assert os.path.exists(os.path.join(path, "info")) + # Verify we can load it back + backend = SegContactLayerBackend.from_path(path) + assert backend.resolution == Vec3D(16, 16, 40) + + +def test_build_seg_contact_info_spec_from_params(): + """Test builder function with direct params.""" + bbox = BBox3D.from_coords([0, 0, 0], [16000, 16000, 20000], [16, 16, 40]) + spec = build_seg_contact_info_spec( + resolution=[16, 16, 40], + chunk_size=[256, 256, 128], + max_contact_span=512, + bbox=bbox, + ) + + assert spec.info_spec_params is not None + assert spec.info_spec_params.resolution == Vec3D(16, 16, 40) + + +def test_build_seg_contact_info_spec_from_reference(): + """Test builder function with reference path.""" + with tempfile.TemporaryDirectory() as temp_dir: + make_backend(temp_dir) + + spec = build_seg_contact_info_spec(reference_path=temp_dir) + + assert spec.info_spec_params is not None + assert spec.info_spec_params.resolution == Vec3D(16, 16, 40) diff --git a/tests/unit/layer/volumetric/seg_contact/test_contact.py b/tests/unit/layer/volumetric/seg_contact/test_contact.py new file mode 100644 index 000000000..cc63a8770 --- /dev/null +++ 
b/tests/unit/layer/volumetric/seg_contact/test_contact.py @@ -0,0 +1,228 @@ +import numpy as np +import pytest + +from zetta_utils.geometry import BBox3D, Vec3D +from zetta_utils.layer.volumetric import VolumetricIndex +from zetta_utils.layer.volumetric.seg_contact import SegContact + + +def make_seg_contact( + id: int = 1, + seg_a: int = 100, + seg_b: int = 200, + com: tuple[float, float, float] = (100.0, 100.0, 100.0), + n_faces: int = 3, +) -> SegContact: + """Helper to create a SegContact for testing.""" + contact_faces = np.array( + [[com[0] + i, com[1] + i, com[2] + i, 0.5] for i in range(n_faces)], + dtype=np.float32, + ) + return SegContact( + id=id, + seg_a=seg_a, + seg_b=seg_b, + com=Vec3D(*com), + contact_faces=contact_faces, + ) + + +# --- Basic instantiation tests --- + + +def test_seg_contact_instantiation(): + """Test basic SegContact creation.""" + contact = make_seg_contact(id=42, seg_a=100, seg_b=200, com=(50.0, 60.0, 70.0)) + + assert contact.id == 42 + assert contact.seg_a == 100 + assert contact.seg_b == 200 + assert contact.com == Vec3D(50.0, 60.0, 70.0) + assert contact.contact_faces.shape == (3, 4) + + +def test_seg_contact_with_optional_fields(): + """Test SegContact with all optional fields.""" + contact = SegContact( + id=1, + seg_a=100, + seg_b=200, + com=Vec3D(100.0, 100.0, 100.0), + contact_faces=np.array([[100, 100, 100, 0.5]], dtype=np.float32), + local_pointclouds={100: np.zeros((10, 3)), 200: np.ones((10, 3))}, + merge_decisions={"ground_truth": True, "model_v1": False}, + partner_metadata={100: {"type": "axon"}, 200: {"type": "dendrite"}}, + ) + + assert contact.local_pointclouds is not None + assert 100 in contact.local_pointclouds + assert contact.merge_decisions == {"ground_truth": True, "model_v1": False} + assert contact.partner_metadata == {100: {"type": "axon"}, 200: {"type": "dendrite"}} + + +def test_seg_contact_defaults_to_none(): + """Test that optional fields default to None.""" + contact = make_seg_contact() + + 
assert contact.local_pointclouds is None + assert contact.merge_decisions is None + assert contact.partner_metadata is None + + +# --- in_bounds tests --- + + +def test_in_bounds_com_inside(): + """Test in_bounds returns True when COM is inside bbox.""" + # COM at (100, 100, 100) nm + contact = make_seg_contact(com=(100.0, 100.0, 100.0)) + + # Bbox from (0, 0, 0) to (200, 200, 200) in voxels at resolution (1, 1, 1) + # So in nm: (0, 0, 0) to (200, 200, 200) + idx = VolumetricIndex( + resolution=Vec3D(1, 1, 1), + bbox=BBox3D.from_slices((slice(0, 200), slice(0, 200), slice(0, 200))), + ) + + assert contact.in_bounds(idx) is True + + +def test_in_bounds_com_outside(): + """Test in_bounds returns False when COM is outside bbox.""" + # COM at (300, 300, 300) nm + contact = make_seg_contact(com=(300.0, 300.0, 300.0)) + + # Bbox from (0, 0, 0) to (200, 200, 200) nm + idx = VolumetricIndex( + resolution=Vec3D(1, 1, 1), + bbox=BBox3D.from_slices((slice(0, 200), slice(0, 200), slice(0, 200))), + ) + + assert contact.in_bounds(idx) is False + + +def test_in_bounds_com_on_boundary_start(): + """Test in_bounds with COM exactly on start boundary (inclusive).""" + contact = make_seg_contact(com=(100.0, 100.0, 100.0)) + + # Bbox starts exactly at COM + idx = VolumetricIndex( + resolution=Vec3D(1, 1, 1), + bbox=BBox3D.from_slices((slice(100, 200), slice(100, 200), slice(100, 200))), + ) + + assert contact.in_bounds(idx) is True + + +def test_in_bounds_com_on_boundary_end(): + """Test in_bounds with COM exactly on end boundary (exclusive).""" + contact = make_seg_contact(com=(200.0, 200.0, 200.0)) + + # Bbox ends exactly at COM + idx = VolumetricIndex( + resolution=Vec3D(1, 1, 1), + bbox=BBox3D.from_slices((slice(100, 200), slice(100, 200), slice(100, 200))), + ) + + assert contact.in_bounds(idx) is False + + +def test_in_bounds_with_resolution(): + """Test in_bounds with non-unit resolution.""" + # COM at (1600, 1600, 4000) nm + contact = make_seg_contact(com=(1600.0, 1600.0, 
4000.0)) + + # Bbox from (0, 0, 0) to (200, 200, 200) voxels at resolution (16, 16, 40) + # In nm: (0, 0, 0) to (3200, 3200, 8000) + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 200), slice(0, 200), slice(0, 200))), + ) + + assert contact.in_bounds(idx) is True + + +def test_in_bounds_partial_outside(): + """Test in_bounds when COM is outside in one dimension only.""" + # COM at (100, 100, 300) nm - outside in z + contact = make_seg_contact(com=(100.0, 100.0, 300.0)) + + idx = VolumetricIndex( + resolution=Vec3D(1, 1, 1), + bbox=BBox3D.from_slices((slice(0, 200), slice(0, 200), slice(0, 200))), + ) + + assert contact.in_bounds(idx) is False + + +# --- with_converted_coordinates tests --- + + +def test_with_converted_coordinates_same_resolution(): + """Test coordinate conversion with same resolution (no change).""" + contact = make_seg_contact(com=(100.0, 100.0, 100.0)) + + converted = contact.with_converted_coordinates( + from_res=Vec3D(16, 16, 40), + to_res=Vec3D(16, 16, 40), + ) + + assert converted.com == contact.com + np.testing.assert_array_equal(converted.contact_faces, contact.contact_faces) + + +def test_with_converted_coordinates_upscale(): + """Test coordinate conversion to higher resolution (smaller voxels).""" + contact = SegContact( + id=1, + seg_a=100, + seg_b=200, + com=Vec3D(160.0, 160.0, 400.0), # nm + contact_faces=np.array([[160, 160, 400, 0.5]], dtype=np.float32), + ) + + # From 16nm to 8nm resolution - coordinates stay in nm, no change + converted = contact.with_converted_coordinates( + from_res=Vec3D(16, 16, 40), + to_res=Vec3D(8, 8, 20), + ) + + # Coordinates are in nm, they shouldn't change + assert converted.com == Vec3D(160.0, 160.0, 400.0) + + +def test_with_converted_coordinates_preserves_other_fields(): + """Test that coordinate conversion preserves non-coordinate fields.""" + contact = SegContact( + id=42, + seg_a=100, + seg_b=200, + com=Vec3D(100.0, 100.0, 100.0), + 
contact_faces=np.array([[100, 100, 100, 0.5]], dtype=np.float32), + partner_metadata={100: {"type": "axon"}}, + merge_decisions={"gt": True}, + ) + + converted = contact.with_converted_coordinates( + from_res=Vec3D(16, 16, 40), + to_res=Vec3D(8, 8, 20), + ) + + assert converted.id == 42 + assert converted.seg_a == 100 + assert converted.seg_b == 200 + assert converted.partner_metadata == {100: {"type": "axon"}} + assert converted.merge_decisions == {"gt": True} + + +def test_with_converted_coordinates_returns_new_instance(): + """Test that with_converted_coordinates returns a new SegContact.""" + contact = make_seg_contact() + + converted = contact.with_converted_coordinates( + from_res=Vec3D(16, 16, 40), + to_res=Vec3D(16, 16, 40), + ) + + # Should be different object (Contact is frozen/immutable) + assert converted is not contact diff --git a/tests/unit/layer/volumetric/seg_contact/test_layer.py b/tests/unit/layer/volumetric/seg_contact/test_layer.py new file mode 100644 index 000000000..ec92d2d0a --- /dev/null +++ b/tests/unit/layer/volumetric/seg_contact/test_layer.py @@ -0,0 +1,276 @@ +import tempfile + +import numpy as np +import pytest + +from zetta_utils.geometry import BBox3D, Vec3D +from zetta_utils.layer.volumetric import VolumetricIndex +from zetta_utils.layer.volumetric.seg_contact import ( + SegContact, + SegContactLayerBackend, + VolumetricSegContactLayer, +) + + +def make_backend(temp_dir: str) -> SegContactLayerBackend: + """Helper to create a backend for testing.""" + backend = SegContactLayerBackend( + path=temp_dir, + resolution=Vec3D(16, 16, 40), + voxel_offset=Vec3D(0, 0, 0), + size=Vec3D(1000, 1000, 500), + chunk_size=Vec3D(256, 256, 128), + max_contact_span=512, + ) + backend.write_info() + return backend + + +def make_seg_contact( + id: int = 1, + seg_a: int = 100, + seg_b: int = 200, + com: tuple[float, float, float] = (100.0, 100.0, 100.0), + n_faces: int = 3, +) -> SegContact: + """Helper to create a SegContact for testing.""" + 
contact_faces = np.array( + [[com[0] + i, com[1] + i, com[2] + i, 0.5] for i in range(n_faces)], + dtype=np.float32, + ) + return SegContact( + id=id, + seg_a=seg_a, + seg_b=seg_b, + com=Vec3D(*com), + contact_faces=contact_faces, + ) + + +# --- Basic instantiation tests --- + + +def test_layer_instantiation(): + """Test basic layer creation.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + layer = VolumetricSegContactLayer(backend=backend) + + assert layer.backend is backend + assert layer.readonly is False + + +def test_layer_instantiation_readonly(): + """Test layer creation with readonly=True.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + layer = VolumetricSegContactLayer(backend=backend, readonly=True) + + assert layer.readonly is True + + +# --- Read tests --- + + +def test_getitem_empty(): + """Test reading from empty layer.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + layer = VolumetricSegContactLayer(backend=backend) + + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), slice(0, 256), slice(0, 128))), + ) + result = layer[idx] + + assert len(result) == 0 + + +def test_getitem_single_contact(): + """Test reading a single contact.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + contact = make_seg_contact(id=42, com=(100.0, 100.0, 100.0)) + backend.write_chunk((0, 0, 0), [contact]) + + layer = VolumetricSegContactLayer(backend=backend) + + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), slice(0, 256), slice(0, 128))), + ) + result = layer[idx] + + assert len(result) == 1 + assert result[0].id == 42 + + +def test_getitem_multiple_contacts(): + """Test reading multiple contacts.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + contacts = [ + 
make_seg_contact(id=1, com=(100.0, 100.0, 100.0)), + make_seg_contact(id=2, com=(200.0, 200.0, 200.0)), + make_seg_contact(id=3, com=(300.0, 300.0, 300.0)), + ] + backend.write_chunk((0, 0, 0), contacts) + + layer = VolumetricSegContactLayer(backend=backend) + + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), slice(0, 256), slice(0, 128))), + ) + result = layer[idx] + + assert len(result) == 3 + ids = {c.id for c in result} + assert ids == {1, 2, 3} + + +def test_getitem_filters_by_bbox(): + """Test that getitem filters contacts by bbox.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + # All in first chunk but different positions + contacts = [ + make_seg_contact(id=1, com=(100.0, 100.0, 100.0)), # in query + make_seg_contact(id=2, com=(2000.0, 100.0, 100.0)), # outside query in x + ] + backend.write_chunk((0, 0, 0), contacts) + + layer = VolumetricSegContactLayer(backend=backend) + + # Query only first part (0-1000 nm in x = 0-62.5 voxels at 16nm) + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 64), slice(0, 256), slice(0, 128))), + ) + result = layer[idx] + + assert len(result) == 1 + assert result[0].id == 1 + + +# --- Write tests --- + + +def test_setitem_single_contact(): + """Test writing a single contact.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + layer = VolumetricSegContactLayer(backend=backend) + + contact = make_seg_contact(id=42, com=(100.0, 100.0, 100.0)) + + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), slice(0, 256), slice(0, 128))), + ) + layer[idx] = [contact] + + # Read back + result = layer[idx] + assert len(result) == 1 + assert result[0].id == 42 + + +def test_setitem_multiple_contacts(): + """Test writing multiple contacts.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + 
layer = VolumetricSegContactLayer(backend=backend) + + contacts = [ + make_seg_contact(id=1, com=(100.0, 100.0, 100.0)), + make_seg_contact(id=2, com=(200.0, 200.0, 200.0)), + ] + + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), slice(0, 256), slice(0, 128))), + ) + layer[idx] = contacts + + result = layer[idx] + assert len(result) == 2 + + +def test_setitem_readonly_raises(): + """Test that writing to readonly layer raises.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + layer = VolumetricSegContactLayer(backend=backend, readonly=True) + + contact = make_seg_contact() + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), slice(0, 256), slice(0, 128))), + ) + + with pytest.raises(Exception): # Could be IOError, PermissionError, etc. + layer[idx] = [contact] + + +def test_setitem_distributes_to_chunks(): + """Test that writing distributes contacts to correct chunks.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + layer = VolumetricSegContactLayer(backend=backend) + + # Contact 1 in chunk (0,0,0), Contact 2 in chunk (1,0,0) + # chunk_size_nm = 256 * 16 = 4096 + contacts = [ + make_seg_contact(id=1, com=(100.0, 100.0, 100.0)), + make_seg_contact(id=2, com=(5000.0, 100.0, 100.0)), # > 4096 + ] + + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 512), slice(0, 256), slice(0, 128))), + ) + layer[idx] = contacts + + # Check chunks directly + chunk_0 = backend.read_chunk((0, 0, 0)) + chunk_1 = backend.read_chunk((1, 0, 0)) + + assert len(chunk_0) == 1 + assert chunk_0[0].id == 1 + assert len(chunk_1) == 1 + assert chunk_1[0].id == 2 + + +# --- Round-trip tests --- + + +def test_round_trip(): + """Test full round-trip: write then read.""" + with tempfile.TemporaryDirectory() as temp_dir: + backend = make_backend(temp_dir) + layer = 
VolumetricSegContactLayer(backend=backend) + + contacts = [ + make_seg_contact(id=1, seg_a=100, seg_b=200, com=(100.0, 100.0, 100.0), n_faces=5), + make_seg_contact(id=2, seg_a=100, seg_b=300, com=(200.0, 200.0, 200.0), n_faces=10), + ] + + idx = VolumetricIndex( + resolution=Vec3D(16, 16, 40), + bbox=BBox3D.from_slices((slice(0, 256), slice(0, 256), slice(0, 128))), + ) + layer[idx] = contacts + + result = layer[idx] + + assert len(result) == 2 + result_by_id = {c.id: c for c in result} + + for orig in contacts: + read = result_by_id[orig.id] + assert read.seg_a == orig.seg_a + assert read.seg_b == orig.seg_b + assert read.com == orig.com + np.testing.assert_array_almost_equal(read.contact_faces, orig.contact_faces) diff --git a/tests/unit/mazepa_layer_processing/common/test_seg_contact_op.py b/tests/unit/mazepa_layer_processing/common/test_seg_contact_op.py new file mode 100644 index 000000000..8da35610a --- /dev/null +++ b/tests/unit/mazepa_layer_processing/common/test_seg_contact_op.py @@ -0,0 +1,276 @@ +import numpy as np + +from zetta_utils.geometry import Vec3D +from zetta_utils.mazepa_layer_processing.common.seg_contact_op import ( + SegContactOp, + _blackout_segments, + _build_seg_to_ref, + _compute_affinity_weighted_com, + _compute_contact_counts, + _compute_overlaps, + _filter_pairs_by_com, + _filter_pairs_touching_boundary, + _find_axis_contacts, + _find_contacts, + _find_merger_segment_ids, + _find_small_segment_ids, + _find_unclaimed_segment_ids, +) + + +# --- Unit tests for helper functions --- + + +def test_find_axis_contacts_basic(): + """Test finding contacts along one axis.""" + # Two segments touching along x axis + seg_lo = np.array([[[1, 1]], [[1, 1]]], dtype=np.int64) + seg_hi = np.array([[[2, 2]], [[2, 2]]], dtype=np.int64) + aff = np.ones((2, 1, 2), dtype=np.float32) * 0.8 + + seg_a, seg_b, aff_vals, x, y, z = _find_axis_contacts( + seg_lo, seg_hi, aff, offset=(0.5, 0, 0) + ) + + assert len(seg_a) == 4 + assert set(seg_a) == {1} + assert 
set(seg_b) == {2} + np.testing.assert_array_almost_equal(aff_vals, [0.8, 0.8, 0.8, 0.8]) + + +def test_find_axis_contacts_no_contacts(): + """Test no contacts when segments are identical.""" + seg = np.array([[[1, 1]], [[1, 1]]], dtype=np.int64) + aff = np.ones((2, 1, 2), dtype=np.float32) + + seg_a, seg_b, aff_vals, x, y, z = _find_axis_contacts(seg, seg, aff, offset=(0, 0, 0)) + + assert len(seg_a) == 0 + + +def test_find_axis_contacts_ignores_zero(): + """Test that contacts with segment 0 are ignored.""" + seg_lo = np.array([[[0, 1]]], dtype=np.int64) + seg_hi = np.array([[[1, 2]]], dtype=np.int64) + aff = np.ones((1, 1, 2), dtype=np.float32) + + seg_a, seg_b, aff_vals, x, y, z = _find_axis_contacts(seg_lo, seg_hi, aff, offset=(0, 0, 0)) + + # Only (1, 2) contact should be found, not (0, 1) + assert len(seg_a) == 1 + assert seg_a[0] == 1 + assert seg_b[0] == 2 + + +def test_find_contacts_normalizes_order(): + """Test that segment order is normalized (seg_a < seg_b).""" + # Create data where seg_b > seg_a in raw data + seg = np.array([[[1, 2, 1]]], dtype=np.int64) + aff = np.zeros((3, 1, 1, 3), dtype=np.float32) + aff[2] = 0.5 # z-axis affinity + + seg_a, seg_b, aff_vals, x, y, z = _find_contacts(seg, aff, Vec3D(0, 0, 0)) + + # All pairs should have seg_a < seg_b + assert all(a < b for a, b in zip(seg_a, seg_b)) + + +def test_filter_pairs_touching_boundary(): + """Test filtering contacts touching padded boundary.""" + seg_a = np.array([1, 1, 2], dtype=np.int64) + seg_b = np.array([2, 2, 3], dtype=np.int64) + aff = np.array([0.5, 0.6, 0.7], dtype=np.float32) + x = np.array([0.0, 10.0, 10.0], dtype=np.float32) # 0 is on boundary + y = np.array([10.0, 10.0, 10.0], dtype=np.float32) + z = np.array([10.0, 10.0, 10.0], dtype=np.float32) + + shape = (20, 20, 20) + start = Vec3D(0, 0, 0) + + result = _filter_pairs_touching_boundary(seg_a, seg_b, aff, x, y, z, start, shape) + seg_a_f, seg_b_f, aff_f, x_f, y_f, z_f = result + + # Pair (1, 2) has contact at x=0 which is on 
boundary + # So only (2, 3) should remain + assert len(seg_a_f) == 1 + assert seg_a_f[0] == 2 + assert seg_b_f[0] == 3 + + +def test_filter_pairs_by_com(): + """Test filtering contacts by COM outside kernel region.""" + # Pair (1, 2) has contacts at x=4 and x=6, COM at x=5 which is on kernel boundary + # Pair (2, 3) has contact at x=10 which is inside kernel + seg_a = np.array([1, 1, 2], dtype=np.int64) + seg_b = np.array([2, 2, 3], dtype=np.int64) + aff = np.array([0.5, 0.5, 0.7], dtype=np.float32) + x = np.array([4.0, 6.0, 10.0], dtype=np.float32) + y = np.array([10.0, 10.0, 10.0], dtype=np.float32) + z = np.array([10.0, 10.0, 10.0], dtype=np.float32) + + crop_pad = (5, 5, 5) + shape = (20, 20, 20) + start = Vec3D(0, 0, 0) + + result = _filter_pairs_by_com(seg_a, seg_b, aff, x, y, z, start, shape, crop_pad) + seg_a_f, seg_b_f, aff_f, x_f, y_f, z_f = result + + # Pair (1, 2) has COM at x=5 which is exactly on kernel start boundary (included) + # Pair (2, 3) has COM at x=10 which is inside kernel + # Both should remain + assert len(seg_a_f) == 3 + + +def test_filter_pairs_by_com_outside(): + """Test filtering contacts by COM outside kernel region.""" + # Pair (1, 2) has contacts at x=2 and x=4, COM at x=3 which is outside kernel + seg_a = np.array([1, 1], dtype=np.int64) + seg_b = np.array([2, 2], dtype=np.int64) + aff = np.array([0.5, 0.5], dtype=np.float32) + x = np.array([2.0, 4.0], dtype=np.float32) + y = np.array([10.0, 10.0], dtype=np.float32) + z = np.array([10.0, 10.0], dtype=np.float32) + + crop_pad = (5, 5, 5) + shape = (20, 20, 20) + start = Vec3D(0, 0, 0) + + result = _filter_pairs_by_com(seg_a, seg_b, aff, x, y, z, start, shape, crop_pad) + seg_a_f, seg_b_f, aff_f, x_f, y_f, z_f = result + + # Pair (1, 2) has COM at x=3 which is outside kernel (5-15) + assert len(seg_a_f) == 0 + + +def test_compute_overlaps_basic(): + """Test computing overlaps between segments and reference.""" + seg = np.array([[[1, 1, 2], [1, 1, 2]]], dtype=np.int64) + ref = 
np.array([[[1, 1, 1], [1, 1, 2]]], dtype=np.int64) + + seg_ids, ref_ids, counts = _compute_overlaps(seg, ref) + + # Segment 1 overlaps ref 1 (4 voxels) + # Segment 2 overlaps ref 1 (1 voxel) and ref 2 (1 voxel) + assert len(seg_ids) > 0 + + +def test_find_small_segment_ids(): + """Test finding segments below size threshold.""" + seg = np.zeros((10, 10, 10), dtype=np.int64) + seg[:5, :, :] = 1 # 500 voxels + seg[5:6, :, :] = 2 # 100 voxels + seg[6:, :, :] = 3 # 400 voxels + + small_ids = _find_small_segment_ids(seg, min_seg_size_vx=200) + + assert 2 in small_ids + assert 1 not in small_ids + assert 3 not in small_ids + + +def test_find_merger_segment_ids(): + """Test finding segments that overlap multiple reference CCs.""" + seg_ids = np.array([1, 1, 2, 2], dtype=np.int64) + ref_ids = np.array([10, 20, 30, 30], dtype=np.int64) + counts = np.array([100, 100, 100, 100], dtype=np.int32) + + merger_ids = _find_merger_segment_ids(seg_ids, ref_ids, counts, min_overlap_vx=50) + + # Segment 1 overlaps ref 10 and 20 -> merger + # Segment 2 overlaps only ref 30 -> not merger + assert 1 in merger_ids + assert 2 not in merger_ids + + +def test_find_unclaimed_segment_ids(): + """Test finding segments without sufficient overlap.""" + seg_ids = np.array([1, 2, 3], dtype=np.int64) + counts = np.array([100, 50, 10], dtype=np.int32) + + unclaimed = _find_unclaimed_segment_ids(seg_ids, counts, min_overlap_vx=60) + + assert 2 in unclaimed + assert 3 in unclaimed + assert 1 not in unclaimed + + +def test_build_seg_to_ref(): + """Test building segment to reference mapping.""" + seg_ids = np.array([1, 1, 2], dtype=np.int64) + ref_ids = np.array([10, 20, 30], dtype=np.int64) + counts = np.array([100, 50, 100], dtype=np.int32) + + seg_to_ref = _build_seg_to_ref(seg_ids, ref_ids, counts, min_overlap_vx=60) + + assert seg_to_ref[1] == {10} # 20 filtered out due to low count + assert seg_to_ref[2] == {30} + + +def test_blackout_segments(): + """Test setting segment IDs to 0.""" + seg = 
np.array([[[1, 2, 3], [1, 2, 3]]], dtype=np.int64) + result = _blackout_segments(seg, {2, 3}) + + assert np.all(result[seg == 1] == 1) + assert np.all(result[seg == 2] == 0) + assert np.all(result[seg == 3] == 0) + + +def test_blackout_segments_empty(): + """Test blackout with empty set does nothing.""" + seg = np.array([[[1, 2, 3]]], dtype=np.int64) + result = _blackout_segments(seg, set()) + + np.testing.assert_array_equal(result, seg) + + +def test_compute_contact_counts(): + """Test counting contacts per segment pair.""" + seg_a = np.array([1, 1, 2, 2, 2], dtype=np.int64) + seg_b = np.array([3, 3, 4, 4, 4], dtype=np.int64) + + counts = _compute_contact_counts(seg_a, seg_b) + + assert counts[(1, 3)] == 2 + assert counts[(2, 4)] == 3 + + +def test_compute_affinity_weighted_com(): + """Test affinity-weighted center of mass computation.""" + contacts = [(0.0, 0.0, 0.0, 0.9), (10.0, 0.0, 0.0, 0.1)] + resolution = np.array([16.0, 16.0, 40.0]) + + com = _compute_affinity_weighted_com(contacts, resolution) + + # COM should be closer to x=0 due to higher affinity weight + assert com[0] < 5.0 * 16.0 # Would be 5.0 * 16 = 80 if unweighted + + +def test_compute_affinity_weighted_com_zero_affinity(): + """Test COM computation when all affinities are zero.""" + contacts = [(0.0, 0.0, 0.0, 0.0), (10.0, 0.0, 0.0, 0.0)] + resolution = np.array([16.0, 16.0, 40.0]) + + com = _compute_affinity_weighted_com(contacts, resolution) + + # Should fall back to simple mean + np.testing.assert_array_almost_equal(com, [5.0 * 16.0, 0.0, 0.0]) + + +# --- SegContactOp method tests --- + + +def test_seg_contact_op_with_added_crop_pad(): + """Test with_added_crop_pad method.""" + op = SegContactOp(sphere_radius_nm=1000.0, crop_pad=(10, 10, 10)) + op2 = op.with_added_crop_pad(Vec3D(5, 5, 5)) + + assert tuple(op2.crop_pad) == (15, 15, 15) + + +def test_seg_contact_op_get_input_resolution(): + """Test get_input_resolution returns same resolution.""" + op = SegContactOp(sphere_radius_nm=1000.0) + 
res = Vec3D(16.0, 16.0, 40.0) + + assert op.get_input_resolution(res) == res diff --git a/zetta_utils/internal b/zetta_utils/internal index 6016b81e0..94738bbc2 160000 --- a/zetta_utils/internal +++ b/zetta_utils/internal @@ -1 +1 @@ -Subproject commit 6016b81e0c3db3afed5bcbc74aef9a5197aceee4 +Subproject commit 94738bbc2e139a0c5d8fc1a9af9e4926a164f6af diff --git a/zetta_utils/layer/volumetric/seg_contact/DESIGN.md b/zetta_utils/layer/volumetric/seg_contact/DESIGN.md new file mode 100644 index 000000000..5e4e91132 --- /dev/null +++ b/zetta_utils/layer/volumetric/seg_contact/DESIGN.md @@ -0,0 +1,275 @@ +# SegContact Layer Format Design + +A chunked spatial storage format for contact data between segmentation objects. + +## Overview + +SegContacts represent interfaces between two segments. Each contact has: +- A unique integer ID +- A center of mass (COM) in 3D space +- Contact faces (3D points with affinity values) +- Optional local point clouds (mesh samples around COM) +- Optional merge decisions from various authorities + +SegContacts are spatially indexed by their COM and stored in chunks following a precomputed-like naming convention. + +The `max_contact_span` constraint ensures that any contact can be fully computed within a single processing window. When generating contacts, processing chunks must have padding >= `max_contact_span / 2`. Contacts exceeding this span are filtered out during generation. + +**Indexing (bounds, chunks) uses voxels at a specified resolution. 
SegContact data (COM, faces, pointclouds) is stored in nanometers.** + +## SegContact Dataclass + +```python +@attrs.frozen +class SegContact: + id: int + seg_a: int + seg_b: int + com: Vec3D[float] # center of mass in nm + contact_faces: np.ndarray # (N, 4) float32: x, y, z, affinity in nm + local_pointclouds: dict[int, np.ndarray] | None # segment_id -> (n_points, 3) in nm + merge_decisions: dict[str, bool] | None # authority -> yes/no + partner_metadata: dict[int, Any] | None # segment_id -> metadata + + def in_bounds(self, idx: VolumetricIndex) -> bool: + """Check if COM falls within the given volumetric index.""" + ... + + def with_converted_coordinates( + self, from_res: Vec3D, to_res: Vec3D + ) -> SegContact: + """Return new SegContact with coordinates converted between resolutions.""" + ... +``` + +## Info File Structure + +The `info` JSON file at the dataset root: + +```json +{ + "format_version": "1.0", + "type": "seg_contact", + + "resolution": [16, 16, 40], + "voxel_offset": [0, 0, 0], + "size": [6250, 6250, 1250], + "chunk_size": [256, 256, 128], + "max_contact_span": 512, + + "affinity_path": "gs://bucket/affinities", + "segmentation_path": "gs://bucket/segmentation", + "image_path": "gs://bucket/image", + + "local_point_clouds": [ + {"radius_nm": 200, "n_points": 1024}, + {"radius_nm": 2000, "n_points": 4096} + ], + + "merge_decisions": ["ground_truth", "model_v1"], + + "filter_settings": { + "min_seg_size_vx": 2000, + "min_overlap_vx": 1000, + "min_contact_vx": 5, + "max_contact_vx": 2048 + } +} +``` + +### Field Descriptions + +| Field | Type | Description | +|-------|------|-------------| +| `format_version` | string | Format version for compatibility | +| `type` | string | Always `"seg_contact"` | +| `resolution` | [x, y, z] | Voxel size in nanometers | +| `voxel_offset` | [x, y, z] | Dataset start in voxels | +| `size` | [x, y, z] | Dataset dimensions in voxels | +| `chunk_size` | [x, y, z] | Chunk dimensions in voxels | +| `max_contact_span` | 
int | Maximum contact span in voxels | +| `affinity_path` | string | Path to source affinity layer | +| `segmentation_path` | string | Path to source segmentation layer | +| `image_path` | string? | Optional path to image layer for visualization | +| `local_point_clouds` | array | Configurations for local point cloud sampling | +| `merge_decisions` | array | List of merge decision authority names | +| `filter_settings` | object | Filter parameters used during generation | + +## Directory Structure + +``` +seg_contact_dataset/ +├── info +├── contacts/ +│ ├── 0-256_0-256_0-128 +│ ├── 256-512_0-256_0-128 +│ └── ... +├── local_point_clouds/ +│ ├── 200nm_1024pts/ +│ │ ├── 0-256_0-256_0-128 +│ │ └── ... +│ └── 2000nm_4096pts/ +│ ├── 0-256_0-256_0-128 +│ └── ... +└── merge_decisions/ + ├── ground_truth/ + │ ├── 0-256_0-256_0-128 + │ └── ... + └── model_v1/ + └── ... +``` + +## Chunk Naming Convention + +Follows precomputed format: `{x_start}-{x_end}_{y_start}-{y_end}_{z_start}-{z_end}` + +Coordinates are in voxels at the specified resolution. For grid position `(gx, gy, gz)`: +``` +x_start = voxel_offset[0] + gx * chunk_size[0] +x_end = x_start + chunk_size[0] +... +filename = f"{x_start}-{x_end}_{y_start}-{y_end}_{z_start}-{z_end}" +``` + +## SegContact Assignment Rule + +A seg_contact is assigned to the chunk containing its **center of mass (COM)**. The `max_contact_span` constraint ensures contacts don't extend beyond what can be processed in a single operation. + +## Binary Data Formats + +All chunk files use a custom binary format with little-endian encoding. + +### contacts/ + +Each chunk file contains all contacts whose COM falls within that chunk. 
+ +``` +Header: + - n_contacts: uint32 (number of contacts in chunk) + +Per contact: + - id: int64 + - seg_a: int64 + - seg_b: int64 + - com: float32[3] (x, y, z in nm) + - n_faces: uint32 + - contact_faces: float32[n_faces, 4] (x, y, z, affinity per face) + - partner_metadata_len: uint32 + - partner_metadata: uint8[partner_metadata_len] (JSON-encoded dict) +``` + +### local_point_clouds/{radius}nm_{n_points}pts/ + +Each chunk contains point clouds for segments involved in contacts in that chunk. + +``` +Header: + - n_entries: uint32 + +Per entry: + - contact_id: int64 + - seg_a_points: float32[n_points, 3] + - seg_b_points: float32[n_points, 3] +``` + +Points are sampled from segment meshes within a sphere of `radius_nm` around the contact COM. +The `n_points` is fixed per configuration (from info file). + +### merge_decisions/{authority}/ + +Each chunk contains binary merge decisions for contacts in that chunk. + +``` +Header: + - n_decisions: uint32 + +Per decision: + - contact_id: int64 + - should_merge: uint8 (0 or 1) +``` + +## Reading Contacts + +To read contacts in a bounding box: + +1. Load `info` file +2. Calculate which chunks intersect the query bbox +3. For each chunk: + - Load chunk file + - Filter contacts whose COM is within query bbox +4. Optionally load corresponding local_point_clouds and merge_decisions + +## Writing SegContacts + +SegContacts are typically generated via a subchunkable operation: + +1. Process each chunk with padding >= `max_contact_span / 2` (in voxels) +2. Find contacts, compute COM for each +3. Assign contacts to chunks based on COM +4. 
Write to appropriate chunk files + +## Layer Architecture + +Following the pattern of `VolumetricAnnotationLayer`: + +### VolumetricSegContactLayer + +```python +@attrs.frozen +class VolumetricSegContactLayer(Layer[VolumetricIndex, Sequence[SegContact], Sequence[SegContact]]): + backend: SegContactLayerBackend + readonly: bool = False + + index_procs: tuple[IndexProcessor[VolumetricIndex], ...] = () + read_procs: tuple[SegContactDataProcT, ...] = () + write_procs: tuple[SegContactDataProcT, ...] = () + + def __getitem__(self, idx: VolumetricIndex) -> Sequence[SegContact]: + ... + + def __setitem__(self, idx: VolumetricIndex, data: Sequence[SegContact]): + ... +``` + +### SegContactLayerBackend + +```python +@attrs.define +class SegContactLayerBackend(Backend[VolumetricIndex, Sequence[SegContact], Sequence[SegContact]]): + path: str + resolution: Vec3D[int] # voxel size in nm + voxel_offset: Vec3D[int] # dataset start in voxels + size: Vec3D[int] # dataset dimensions in voxels + chunk_size: Vec3D[int] # chunk dimensions in voxels + max_contact_span: int # in voxels + # ... other info fields + + def read(self, idx: VolumetricIndex) -> Sequence[SegContact]: + ... + + def write(self, idx: VolumetricIndex, data: Sequence[SegContact]): + ... +``` + +## File Structure + +``` +zetta_utils/layer/volumetric/seg_contact/ +├── __init__.py +├── contact.py # SegContact dataclass +├── backend.py # SegContactLayerBackend +├── layer.py # VolumetricSegContactLayer +└── build.py # Builder functions +``` + +## Design Rationale + +### Why COM-based assignment? +- Deterministic: each contact belongs to exactly one chunk +- Efficient queries: spatial indexing by a single point +- Avoids duplication across chunk boundaries + +### Why max_contact_span constraint? 
+- Ensures contacts can be fully computed within a processing window +- Processing chunk must have padding >= max_contact_span / 2 (in voxels) +- Contacts exceeding this span are filtered out during generation diff --git a/zetta_utils/layer/volumetric/seg_contact/__init__.py b/zetta_utils/layer/volumetric/seg_contact/__init__.py new file mode 100644 index 000000000..9f9429843 --- /dev/null +++ b/zetta_utils/layer/volumetric/seg_contact/__init__.py @@ -0,0 +1,5 @@ +from .contact import SegContact +from .backend import SegContactLayerBackend +from .layer import VolumetricSegContactLayer +from .build import build_seg_contact_layer +from .info_spec import SegContactInfoSpec, SegContactInfoSpecParams, build_seg_contact_info_spec diff --git a/zetta_utils/layer/volumetric/seg_contact/backend.py b/zetta_utils/layer/volumetric/seg_contact/backend.py new file mode 100644 index 000000000..3eb1c5fb7 --- /dev/null +++ b/zetta_utils/layer/volumetric/seg_contact/backend.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import json +import os +from collections.abc import Sequence + +import attrs +import numpy as np + +from zetta_utils.geometry import Vec3D +from zetta_utils.layer.backend_base import Backend +from zetta_utils.layer.volumetric import VolumetricIndex + +from .contact import SegContact + + +@attrs.define +class SegContactLayerBackend(Backend[VolumetricIndex, Sequence[SegContact], Sequence[SegContact]]): + """Backend for reading/writing seg_contact data in chunked format.""" + + path: str + resolution: Vec3D[int] # voxel size in nm + voxel_offset: Vec3D[int] # dataset start in voxels + size: Vec3D[int] # dataset dimensions in voxels + chunk_size: Vec3D[int] # chunk dimensions in voxels + max_contact_span: int # in voxels + + @property + def name(self) -> str: + return self.path + + def with_changes(self, **kwargs) -> SegContactLayerBackend: + return attrs.evolve(self, **kwargs) + + @classmethod + def from_path(cls, path: str) -> SegContactLayerBackend: + 
"""Load backend from existing info file.""" + info_path = os.path.join(path, "info") + if not os.path.exists(info_path): + raise FileNotFoundError(f"Info file not found: {info_path}") + with open(info_path, "r") as f: + info = json.load(f) + return cls( + path=path, + resolution=Vec3D(*info["resolution"]), + voxel_offset=Vec3D(*info["voxel_offset"]), + size=Vec3D(*info["size"]), + chunk_size=Vec3D(*info["chunk_size"]), + max_contact_span=info["max_contact_span"], + ) + + def write_info(self) -> None: + """Write info file to disk.""" + info = { + "format_version": "1.0", + "type": "seg_contact", + "resolution": list(self.resolution), + "voxel_offset": list(self.voxel_offset), + "size": list(self.size), + "chunk_size": list(self.chunk_size), + "max_contact_span": self.max_contact_span, + } + os.makedirs(self.path, exist_ok=True) + with open(os.path.join(self.path, "info"), "w") as f: + json.dump(info, f, indent=2) + + def read(self, idx: VolumetricIndex) -> Sequence[SegContact]: + """Read contacts whose COM falls within the given index.""" + # Get bbox in nm + bbox = idx.bbox + start_nm: Vec3D = Vec3D( + bbox.start[0] * idx.resolution[0], + bbox.start[1] * idx.resolution[1], + bbox.start[2] * idx.resolution[2], + ) + end_nm: Vec3D = Vec3D( + bbox.end[0] * idx.resolution[0], + bbox.end[1] * idx.resolution[1], + bbox.end[2] * idx.resolution[2], + ) + + # Find which chunks to read + start_chunk = self.com_to_chunk_idx(start_nm) + end_chunk = self.com_to_chunk_idx( + Vec3D(end_nm[0] - 0.001, end_nm[1] - 0.001, end_nm[2] - 0.001) + ) + + result = [] + for gx in range(start_chunk[0], end_chunk[0] + 1): + for gy in range(start_chunk[1], end_chunk[1] + 1): + for gz in range(start_chunk[2], end_chunk[2] + 1): + contacts = self.read_chunk((gx, gy, gz)) + # Filter by COM within bbox + for c in contacts: + if ( + start_nm[0] <= c.com[0] < end_nm[0] + and start_nm[1] <= c.com[1] < end_nm[1] + and start_nm[2] <= c.com[2] < end_nm[2] + ): + result.append(c) + return result + + def 
write(self, idx: VolumetricIndex, data: Sequence[SegContact]) -> None: + """Write contacts to appropriate chunks based on their COM.""" + # Group contacts by chunk + chunk_contacts: dict[tuple[int, int, int], list[SegContact]] = {} + for contact in data: + chunk_idx = self.com_to_chunk_idx(contact.com) + if chunk_idx not in chunk_contacts: + chunk_contacts[chunk_idx] = [] + chunk_contacts[chunk_idx].append(contact) + + # Write each chunk + for chunk_idx, contacts in chunk_contacts.items(): + self.write_chunk(chunk_idx, contacts) + + def get_chunk_path(self, chunk_idx: tuple[int, int, int]) -> str: + """Get file path for a chunk given grid indices.""" + return os.path.join(self.path, "contacts", self.get_chunk_name(chunk_idx)) + + def get_chunk_name(self, chunk_idx: tuple[int, int, int]) -> str: + """Get chunk filename in precomputed format.""" + gx, gy, gz = chunk_idx + x_start = self.voxel_offset[0] + gx * self.chunk_size[0] + x_end = x_start + self.chunk_size[0] + y_start = self.voxel_offset[1] + gy * self.chunk_size[1] + y_end = y_start + self.chunk_size[1] + z_start = self.voxel_offset[2] + gz * self.chunk_size[2] + z_end = z_start + self.chunk_size[2] + return f"{x_start}-{x_end}_{y_start}-{y_end}_{z_start}-{z_end}" + + def com_to_chunk_idx(self, com_nm: Vec3D[float]) -> tuple[int, int, int]: + """Convert COM in nanometers to chunk grid index.""" + # Convert COM from nm to voxels + com_vx = Vec3D( + com_nm[0] / self.resolution[0], + com_nm[1] / self.resolution[1], + com_nm[2] / self.resolution[2], + ) + # Subtract offset and divide by chunk size + gx = int((com_vx[0] - self.voxel_offset[0]) // self.chunk_size[0]) + gy = int((com_vx[1] - self.voxel_offset[1]) // self.chunk_size[1]) + gz = int((com_vx[2] - self.voxel_offset[2]) // self.chunk_size[2]) + return (gx, gy, gz) + + def write_chunk(self, chunk_idx: tuple[int, int, int], contacts: Sequence[SegContact]) -> None: + """Write contacts to a specific chunk file.""" + import struct + + chunk_path = 
self.get_chunk_path(chunk_idx) + os.makedirs(os.path.dirname(chunk_path), exist_ok=True) + + with open(chunk_path, "wb") as f: + # Header: n_contacts + f.write(struct.pack(" Sequence[SegContact]: + """Read contacts from a specific chunk file.""" + import struct + + chunk_path = self.get_chunk_path(chunk_idx) + if not os.path.exists(chunk_path): + return [] + + contacts = [] + with open(chunk_path, "rb") as f: + # Header: n_contacts + n_contacts = struct.unpack(" 0: + metadata_bytes = f.read(metadata_len) + partner_metadata = json.loads(metadata_bytes.decode("utf-8")) + # Convert string keys back to int + partner_metadata = {int(k): v for k, v in partner_metadata.items()} + else: + partner_metadata = None + + contacts.append( + SegContact( + id=id_, + seg_a=seg_a, + seg_b=seg_b, + com=Vec3D(*com), + contact_faces=contact_faces.copy(), + partner_metadata=partner_metadata, + ) + ) + + return contacts diff --git a/zetta_utils/layer/volumetric/seg_contact/build.py b/zetta_utils/layer/volumetric/seg_contact/build.py new file mode 100644 index 000000000..ac8ca949f --- /dev/null +++ b/zetta_utils/layer/volumetric/seg_contact/build.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import Literal + +from zetta_utils.builder import register + +from .backend import SegContactLayerBackend +from .info_spec import SegContactInfoSpec +from .layer import VolumetricSegContactLayer + + +@register("build_seg_contact_layer") +def build_seg_contact_layer( + path: str, + readonly: bool = False, + info_spec: SegContactInfoSpec | None = None, + mode: Literal["read", "write", "update"] = "read", +) -> VolumetricSegContactLayer: + """Build a VolumetricSegContactLayer from a path. + + :param path: Path to seg_contact layer. + :param readonly: Whether the layer should be read-only. + :param info_spec: Info specification for creating new layer. + :param mode: How the layer should be opened: + - "read": for reading only; layer must exist. 
+ - "write": for writing; creates new layer, fails if exists. + - "update": for writing to existing layer. + :return: VolumetricSegContactLayer instance. + """ + import os + + info_path = os.path.join(path, "info") + layer_exists = os.path.exists(info_path) + + if mode == "read": + if not layer_exists: + raise FileNotFoundError(f"SegContact layer not found at {path}") + backend = SegContactLayerBackend.from_path(path) + return VolumetricSegContactLayer(backend=backend, readonly=True) + + if mode == "write": + if layer_exists: + raise FileExistsError(f"SegContact layer already exists at {path}") + if info_spec is None: + raise ValueError("info_spec is required when mode='write'") + info_spec.write_info(path) + backend = SegContactLayerBackend.from_path(path) + return VolumetricSegContactLayer(backend=backend, readonly=readonly) + + if mode == "update": + if not layer_exists: + raise FileNotFoundError(f"SegContact layer not found at {path}") + backend = SegContactLayerBackend.from_path(path) + return VolumetricSegContactLayer(backend=backend, readonly=readonly) + + raise ValueError(f"Invalid mode: {mode}") diff --git a/zetta_utils/layer/volumetric/seg_contact/contact.py b/zetta_utils/layer/volumetric/seg_contact/contact.py new file mode 100644 index 000000000..3aa066041 --- /dev/null +++ b/zetta_utils/layer/volumetric/seg_contact/contact.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Any + +import attrs +import numpy as np + +from zetta_utils.geometry import BBox3D, Vec3D +from zetta_utils.layer.volumetric import VolumetricIndex + + +@attrs.frozen +class SegContact: + """Represents a contact interface between two segments.""" + + id: int + seg_a: int + seg_b: int + com: Vec3D[float] # center of mass in nm + contact_faces: np.ndarray # (N, 4) float32: x, y, z, affinity in nm + local_pointclouds: dict[int, np.ndarray] | None = None # segment_id -> (n_points, 3) in nm + merge_decisions: dict[str, bool] | None = None # authority -> yes/no + 
partner_metadata: dict[int, Any] | None = None # segment_id -> metadata + + def in_bounds(self, idx: VolumetricIndex) -> bool: + """Check if COM falls within the given volumetric index.""" + bbox = idx.bbox + # Convert bbox to nm + start_nm = ( + bbox.start[0] * idx.resolution[0], + bbox.start[1] * idx.resolution[1], + bbox.start[2] * idx.resolution[2], + ) + end_nm = ( + bbox.end[0] * idx.resolution[0], + bbox.end[1] * idx.resolution[1], + bbox.end[2] * idx.resolution[2], + ) + return ( + start_nm[0] <= self.com[0] < end_nm[0] + and start_nm[1] <= self.com[1] < end_nm[1] + and start_nm[2] <= self.com[2] < end_nm[2] + ) + + def with_converted_coordinates(self, from_res: Vec3D, to_res: Vec3D) -> SegContact: + """Return new Contact with coordinates converted between resolutions. + + Note: Contact coordinates are stored in nanometers, so resolution + conversion doesn't change the values - this method exists for API + consistency with other layer types. + """ + # Coordinates are in nm, they don't change with resolution + # Just return a copy with same values + return SegContact( + id=self.id, + seg_a=self.seg_a, + seg_b=self.seg_b, + com=self.com, + contact_faces=self.contact_faces.copy(), + local_pointclouds=( + {k: v.copy() for k, v in self.local_pointclouds.items()} + if self.local_pointclouds is not None + else None + ), + merge_decisions=self.merge_decisions, + partner_metadata=self.partner_metadata, + ) diff --git a/zetta_utils/layer/volumetric/seg_contact/info_spec.py b/zetta_utils/layer/volumetric/seg_contact/info_spec.py new file mode 100644 index 000000000..3eaf1fe2e --- /dev/null +++ b/zetta_utils/layer/volumetric/seg_contact/info_spec.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +from typing import Sequence + +import attrs +from typeguard import typechecked + +from zetta_utils import builder +from zetta_utils.geometry import BBox3D, Vec3D + + +@typechecked +@attrs.mutable +class SegContactInfoSpecParams: + """Parameters for creating a 
seg_contact layer info file.""" + + resolution: Vec3D[int] + chunk_size: Vec3D[int] + max_contact_span: int + bbox: BBox3D + + @classmethod + def from_reference( + cls, + reference_path: str, + resolution: Sequence[int] | None = None, + chunk_size: Sequence[int] | None = None, + max_contact_span: int | None = None, + bbox: BBox3D | None = None, + ) -> SegContactInfoSpecParams: + """Create params from a reference seg_contact layer path.""" + from .backend import SegContactLayerBackend + + ref = SegContactLayerBackend.from_path(reference_path) + + if resolution is None: + resolution = ref.resolution + if chunk_size is None: + chunk_size = ref.chunk_size + if max_contact_span is None: + max_contact_span = ref.max_contact_span + if bbox is None: + bbox = BBox3D.from_coords( + start_coord=ref.voxel_offset, + end_coord=Vec3D(*ref.voxel_offset) + Vec3D(*ref.size), + resolution=ref.resolution, + ) + + return cls( + resolution=Vec3D(*resolution), + chunk_size=Vec3D(*chunk_size), + max_contact_span=max_contact_span, + bbox=bbox, + ) + + +@typechecked +@attrs.mutable +class SegContactInfoSpec: + """Specification for seg_contact layer info file, similar to PrecomputedInfoSpec.""" + + info_path: str | None = None + info_spec_params: SegContactInfoSpecParams | None = None + + def __attrs_post_init__(self): + if (self.info_path is None and self.info_spec_params is None) or ( + self.info_path is not None and self.info_spec_params is not None + ): + raise ValueError("Exactly one of `info_path`/`info_spec_params` must be provided") + + def make_info(self) -> dict: + """Generate info dict from spec params.""" + if self.info_path is not None: + from .backend import SegContactLayerBackend + + backend = SegContactLayerBackend.from_path(self.info_path) + return { + "format_version": "1.0", + "type": "seg_contact", + "resolution": list(backend.resolution), + "voxel_offset": list(backend.voxel_offset), + "size": list(backend.size), + "chunk_size": list(backend.chunk_size), + 
"max_contact_span": backend.max_contact_span, + } + else: + assert self.info_spec_params is not None + params = self.info_spec_params + voxel_offset = [int(params.bbox.start[i] / params.resolution[i]) for i in range(3)] + size = [int(params.bbox.shape[i] / params.resolution[i]) for i in range(3)] + return { + "format_version": "1.0", + "type": "seg_contact", + "resolution": list(params.resolution), + "voxel_offset": voxel_offset, + "size": size, + "chunk_size": list(params.chunk_size), + "max_contact_span": params.max_contact_span, + } + + def write_info(self, path: str) -> None: + """Write info file to the given path.""" + import json + import os + + info = self.make_info() + os.makedirs(path, exist_ok=True) + with open(os.path.join(path, "info"), "w") as f: + json.dump(info, f, indent=2) + + def set_bbox(self, bbox: BBox3D) -> None: + """Update the bounding box.""" + assert self.info_spec_params is not None + self.info_spec_params.bbox = bbox + + +@builder.register("build_seg_contact_info_spec") +def build_seg_contact_info_spec( + info_path: str | None = None, + reference_path: str | None = None, + resolution: Sequence[int] | None = None, + chunk_size: Sequence[int] | None = None, + max_contact_span: int | None = None, + bbox: BBox3D | None = None, +) -> SegContactInfoSpec: + """Build a SegContactInfoSpec for use in specs. + + :param info_path: Path to existing seg_contact layer to use as info source. + :param reference_path: Path to reference seg_contact layer to inherit params from. + :param resolution: Voxel resolution in nm (x, y, z). + :param chunk_size: Chunk size in voxels (x, y, z). + :param max_contact_span: Maximum contact span in voxels. + :param bbox: Bounding box for the dataset. + :return: SegContactInfoSpec instance. 
+ """ + if info_path is not None: + if any(p is not None for p in [reference_path, resolution, chunk_size, max_contact_span]): + raise ValueError("When `info_path` is provided, other params should not be specified") + return SegContactInfoSpec(info_path=info_path) + + if reference_path is not None: + params = SegContactInfoSpecParams.from_reference( + reference_path=reference_path, + resolution=resolution, + chunk_size=chunk_size, + max_contact_span=max_contact_span, + bbox=bbox, + ) + else: + if resolution is None or chunk_size is None or max_contact_span is None or bbox is None: + raise ValueError( + "When no reference is provided, resolution, chunk_size, " + "max_contact_span, and bbox are all required" + ) + params = SegContactInfoSpecParams( + resolution=Vec3D(*resolution), + chunk_size=Vec3D(*chunk_size), + max_contact_span=max_contact_span, + bbox=bbox, + ) + + return SegContactInfoSpec(info_spec_params=params) diff --git a/zetta_utils/layer/volumetric/seg_contact/layer.py b/zetta_utils/layer/volumetric/seg_contact/layer.py new file mode 100644 index 000000000..638ac3ce2 --- /dev/null +++ b/zetta_utils/layer/volumetric/seg_contact/layer.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from collections.abc import Sequence + +import attrs + +from zetta_utils.layer import ( + DataProcessor, + IndexProcessor, + JointIndexDataProcessor, + Layer, +) +from zetta_utils.layer.volumetric import VolumetricIndex + +from .backend import SegContactLayerBackend +from .contact import SegContact + +SegContactDataProcT = ( + DataProcessor[Sequence[SegContact]] + | JointIndexDataProcessor[Sequence[SegContact], VolumetricIndex] +) + + +@attrs.frozen +class VolumetricSegContactLayer( + Layer[VolumetricIndex, Sequence[SegContact], Sequence[SegContact]] +): + """Layer for reading/writing seg_contact data.""" + + backend: SegContactLayerBackend + readonly: bool = False + + index_procs: tuple[IndexProcessor[VolumetricIndex], ...] 
= () + read_procs: tuple[SegContactDataProcT, ...] = () + write_procs: tuple[SegContactDataProcT, ...] = () + + def __getitem__(self, idx: VolumetricIndex) -> Sequence[SegContact]: + return self.read_with_procs(idx) + + def __setitem__(self, idx: VolumetricIndex, data: Sequence[SegContact]) -> None: + if self.readonly: + raise IOError("Cannot write to readonly layer") + self.write_with_procs(idx, data) diff --git a/zetta_utils/mazepa_layer_processing/common/__init__.py b/zetta_utils/mazepa_layer_processing/common/__init__.py index 5d8303f25..fc057a179 100644 --- a/zetta_utils/mazepa_layer_processing/common/__init__.py +++ b/zetta_utils/mazepa_layer_processing/common/__init__.py @@ -20,3 +20,4 @@ ) from .. import ChunkableOpProtocol, VolumetricOpProtocol from .interpolate_flow import build_interpolate_flow +from .seg_contact_op import SegContactOp diff --git a/zetta_utils/mazepa_layer_processing/common/seg_contact_op.py b/zetta_utils/mazepa_layer_processing/common/seg_contact_op.py new file mode 100644 index 000000000..08db7b97b --- /dev/null +++ b/zetta_utils/mazepa_layer_processing/common/seg_contact_op.py @@ -0,0 +1,520 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor + +import attrs +import cc3d +import numpy as np +import pandas as pd +import trimesh +from cloudvolume import CloudVolume + +from zetta_utils import builder +from zetta_utils.geometry import Vec3D +from zetta_utils.layer.volumetric import VolumetricIndex, VolumetricLayer +from zetta_utils.layer.volumetric.seg_contact import SegContact, VolumetricSegContactLayer +from zetta_utils.mazepa import taskable_operation_cls +from zetta_utils.mazepa.semaphores import semaphore + + +def _read_layers_parallel( + segmentation: VolumetricLayer, + reference: VolumetricLayer, + affinity: VolumetricLayer, + idx: VolumetricIndex, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Read 
segmentation, reference, and affinity layers in parallel.""" + + def read_layer(layer: VolumetricLayer, index: VolumetricIndex) -> np.ndarray: + return np.asarray(layer[index]) + + with ThreadPoolExecutor(max_workers=3) as executor: + seg_future = executor.submit(read_layer, segmentation, idx) + ref_future = executor.submit(read_layer, reference, idx) + aff_future = executor.submit(read_layer, affinity, idx) + return seg_future.result().squeeze(), ref_future.result().squeeze(), aff_future.result() + + +def _compute_overlaps( + seg: np.ndarray, reference: np.ndarray +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute overlaps between segments and reference connected components.""" + cc_ref = cc3d.connected_components(reference, connectivity=6) + flat_seg = seg.ravel() + flat_cc_ref = cc_ref.ravel() + valid_mask = (flat_seg != 0) & (flat_cc_ref != 0) + df = pd.DataFrame({"seg": flat_seg[valid_mask], "cc_ref": flat_cc_ref[valid_mask]}) + counts_df = df.groupby(["seg", "cc_ref"]).size().reset_index(name="count") + return ( + counts_df["seg"].values.astype(np.int64), + counts_df["cc_ref"].values.astype(np.int64), + counts_df["count"].values.astype(np.int32), + ) + + +def _find_small_segment_ids(seg: np.ndarray, min_seg_size_vx: int) -> set[int]: + """Find segment IDs with total voxel count below threshold.""" + unique, counts = np.unique(seg, return_counts=True) + return {int(s) for s, cnt in zip(unique, counts) if s != 0 and cnt < min_seg_size_vx} + + +def _find_merger_segment_ids( + seg_ids: np.ndarray, ref_ids: np.ndarray, counts: np.ndarray, min_overlap_vx: int +) -> set[int]: + """Find merger segments (overlap 2+ reference CCs with >= min_overlap each).""" + seg_to_ref: dict[int, set[int]] = defaultdict(set) + for seg, ref, cnt in zip(seg_ids, ref_ids, counts): + if cnt >= min_overlap_vx: + seg_to_ref[int(seg)].add(int(ref)) + return {seg for seg, refs in seg_to_ref.items() if len(refs) >= 2} + + +def _find_unclaimed_segment_ids( + seg_ids: np.ndarray, 
counts: np.ndarray, min_overlap_vx: int +) -> set[int]: + """Find segments without sufficient reference overlap.""" + seg_max_overlap: dict[int, int] = defaultdict(int) + for seg, cnt in zip(seg_ids, counts): + seg_max_overlap[int(seg)] = max(seg_max_overlap[int(seg)], int(cnt)) + return {seg for seg, max_ovl in seg_max_overlap.items() if max_ovl < min_overlap_vx} + + +def _build_seg_to_ref( + seg_ids: np.ndarray, ref_ids: np.ndarray, counts: np.ndarray, min_overlap_vx: int +) -> dict[int, set[int]]: + """Build mapping from segment to reference CCs it overlaps with.""" + result: dict[int, set[int]] = defaultdict(set) + for seg, ref, cnt in zip(seg_ids, ref_ids, counts): + if cnt >= min_overlap_vx: + result[int(seg)].add(int(ref)) + return result + + +def _blackout_segments(seg: np.ndarray, ids_to_remove: set[int]) -> np.ndarray: + """Set specified segment IDs to 0.""" + if not ids_to_remove: + return seg + seg = seg.copy() + mask = np.isin(seg, list(ids_to_remove)) + seg[mask] = 0 + return seg + + +def _find_axis_contacts( + seg_lo: np.ndarray, + seg_hi: np.ndarray, + aff_slice: np.ndarray, + offset: tuple[float, float, float], +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Find contacts along one axis. 
Returns face centers.""" + mask = (seg_lo != seg_hi) & (seg_lo != 0) & (seg_hi != 0) + idx = np.nonzero(mask) + if len(idx[0]) == 0: + empty_i, empty_f = np.array([], dtype=np.int64), np.array([], dtype=np.float32) + return empty_i, empty_i, empty_f, empty_f, empty_f, empty_f + return ( + seg_lo[mask], + seg_hi[mask], + aff_slice[mask], + idx[0].astype(np.float32) + offset[0], + idx[1].astype(np.float32) + offset[1], + idx[2].astype(np.float32) + offset[2], + ) + + +def _find_contacts( + seg: np.ndarray, aff: np.ndarray, start: Vec3D +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Find contacts between segments using affinity data.""" + sx, sy, sz = float(start[0]), float(start[1]), float(start[2]) + results = [] + + for seg_lo, seg_hi, aff_slice, offset in [ + (seg[:-1], seg[1:], aff[0, 1:], (sx + 0.5, sy, sz)), + (seg[:, :-1], seg[:, 1:], aff[1, :, 1:], (sx, sy + 0.5, sz)), + (seg[:, :, :-1], seg[:, :, 1:], aff[2, :, :, 1:], (sx, sy, sz + 0.5)), + ]: + r = _find_axis_contacts(seg_lo, seg_hi, aff_slice, offset) + if len(r[0]) > 0: + results.append(r) + + if not results: + empty_i, empty_f = np.array([], dtype=np.int64), np.array([], dtype=np.float32) + return empty_i, empty_i, empty_f, empty_f, empty_f, empty_f + + seg_a = np.concatenate([r[0] for r in results]) + seg_b = np.concatenate([r[1] for r in results]) + aff_vals = np.concatenate([r[2] for r in results]) + x = np.concatenate([r[3] for r in results]) + y = np.concatenate([r[4] for r in results]) + z = np.concatenate([r[5] for r in results]) + + swap = seg_a > seg_b + seg_a, seg_b = np.where(swap, seg_b, seg_a), np.where(swap, seg_a, seg_b) + + return seg_a, seg_b, aff_vals, x, y, z + + +def _filter_pairs_touching_boundary( + seg_a: np.ndarray, + seg_b: np.ndarray, + aff: np.ndarray, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + start: Vec3D, + shape: tuple[int, ...], +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + 
"""Exclude pairs that have any contact touching the padded boundary.""" + padded_start = np.array([start[0], start[1], start[2]]) + padded_end = padded_start + np.array(shape) + on_boundary = ( + (x <= padded_start[0]) + | (x >= padded_end[0] - 1) + | (y <= padded_start[1]) + | (y >= padded_end[1] - 1) + | (z <= padded_start[2]) + | (z >= padded_end[2] - 1) + ) + pairs_on_boundary: set[tuple[int, int]] = set() + for a, b, on_b in zip(seg_a, seg_b, on_boundary): + if on_b: + pairs_on_boundary.add((int(a), int(b))) + keep = np.array([(int(a), int(b)) not in pairs_on_boundary for a, b in zip(seg_a, seg_b)]) + return seg_a[keep], seg_b[keep], aff[keep], x[keep], y[keep], z[keep] + + +def _filter_pairs_by_com( + seg_a: np.ndarray, + seg_b: np.ndarray, + aff: np.ndarray, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + start: Vec3D, + shape: tuple[int, ...], + crop_pad: Sequence[int], +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Exclude pairs whose affinity-weighted COM falls outside the kernel region.""" + kernel_start = np.array([start[0], start[1], start[2]]) + np.array(crop_pad) + kernel_end = np.array([start[0], start[1], start[2]]) + np.array(shape) - np.array(crop_pad) + + contact_data = _build_contact_lookup(seg_a, seg_b, x, y, z, aff) + pairs_outside: set[tuple[int, int]] = set() + + for (a, b), contacts in contact_data.items(): + xs = np.array([c[0] for c in contacts]) + ys = np.array([c[1] for c in contacts]) + zs = np.array([c[2] for c in contacts]) + affs = np.array([c[3] for c in contacts]) + aff_sum = affs.sum() + if aff_sum > 0: + com_x = (xs * affs).sum() / aff_sum + com_y = (ys * affs).sum() / aff_sum + com_z = (zs * affs).sum() / aff_sum + else: + com_x, com_y, com_z = xs.mean(), ys.mean(), zs.mean() + + if ( + com_x < kernel_start[0] + or com_x >= kernel_end[0] + or com_y < kernel_start[1] + or com_y >= kernel_end[1] + or com_z < kernel_start[2] + or com_z >= kernel_end[2] + ): + pairs_outside.add((a, 
b)) + + keep = np.array([(int(a), int(b)) not in pairs_outside for a, b in zip(seg_a, seg_b)]) + return seg_a[keep], seg_b[keep], aff[keep], x[keep], y[keep], z[keep] + + +def _compute_contact_counts(seg_a: np.ndarray, seg_b: np.ndarray) -> dict[tuple[int, int], int]: + """Count contacts per segment pair.""" + counts: dict[tuple[int, int], int] = defaultdict(int) + for a, b in zip(seg_a, seg_b): + counts[(int(a), int(b))] += 1 + return counts + + +def _build_contact_lookup( + seg_a: np.ndarray, + seg_b: np.ndarray, + cx: np.ndarray, + cy: np.ndarray, + cz: np.ndarray, + aff: np.ndarray, +) -> dict[tuple[int, int], list[tuple[float, float, float, float]]]: + """Build lookup from segment pair to contact face centers.""" + data: dict[tuple[int, int], list[tuple[float, float, float, float]]] = defaultdict(list) + for a, b, x, y, z, af in zip(seg_a, seg_b, cx, cy, cz, aff): + data[(int(a), int(b))].append((float(x), float(y), float(z), float(af))) + return data + + +def _download_and_clip_meshes( + cv: CloudVolume, segment_ids: list[int], bbox_start: np.ndarray, bbox_end: np.ndarray +) -> dict[int, trimesh.Trimesh]: + """Download meshes and clip to bounding box.""" + if not segment_ids: + return {} + + meshes = cv.mesh.get(segment_ids, progress=False) + box = trimesh.creation.box(extents=bbox_end - bbox_start) + box.apply_translation((bbox_start + bbox_end) / 2) + + result: dict[int, trimesh.Trimesh] = {} + for seg_id in segment_ids: + mesh_obj = meshes.get(seg_id) + if mesh_obj is None or len(mesh_obj.vertices) == 0 or len(mesh_obj.faces) == 0: + continue + mesh = trimesh.Trimesh(vertices=mesh_obj.vertices, faces=mesh_obj.faces) + clipped = mesh.slice_plane(box.facets_origin, -np.array(box.facets_normal)) + if clipped is not None and len(clipped.vertices) > 0 and len(clipped.faces) > 0: + result[seg_id] = clipped + return result + + +def _crop_mesh_to_sphere( + mesh: trimesh.Trimesh, center: np.ndarray, radius: float +) -> trimesh.Trimesh | None: + """Clip mesh to 
sphere, keeping only faces fully inside.""" + vertex_dists = np.linalg.norm(mesh.vertices - center, axis=1) + vertex_inside = vertex_dists <= radius + face_inside = vertex_inside[mesh.faces].all(axis=1) + if not face_inside.any(): + return None + result = mesh.submesh([face_inside], append=True) + if isinstance(result, list): + return result[0] if result else None + return result + + +def _sample_mesh_points(mesh: trimesh.Trimesh | None, n: int) -> np.ndarray: + """Sample N points from mesh surface, area-weighted.""" + if mesh is None or len(mesh.faces) == 0: + return np.zeros((n, 3), dtype=np.float32) + result = trimesh.sample.sample_surface(mesh, n) + points = result[0] + return points.astype(np.float32) + + +def _compute_affinity_weighted_com( + contacts: list[tuple[float, float, float, float]], resolution: np.ndarray +) -> np.ndarray: + """Compute affinity-weighted center of mass in nm.""" + x = np.array([c[0] for c in contacts]) + y = np.array([c[1] for c in contacts]) + z = np.array([c[2] for c in contacts]) + aff = np.array([c[3] for c in contacts]) + aff_sum = aff.sum() + if aff_sum == 0: + return np.array( + [x.mean() * resolution[0], y.mean() * resolution[1], z.mean() * resolution[2]] + ) + return np.array( + [ + (x * aff).sum() / aff_sum * resolution[0], + (y * aff).sum() / aff_sum * resolution[1], + (z * aff).sum() / aff_sum * resolution[2], + ] + ) + + +def _all_contacts_in_sphere( + contacts: list[tuple[float, float, float, float]], + center: np.ndarray, + radius: float, + resolution: np.ndarray, +) -> bool: + """Check if all contacts are within sphere radius of center.""" + for x, y, z, _ in contacts: + pos_nm = np.array([x * resolution[0], y * resolution[1], z * resolution[2]]) + if np.linalg.norm(pos_nm - center) > radius: + return False + return True + + +def _make_contact_faces_array( + contacts: list[tuple[float, float, float, float]], resolution: np.ndarray +) -> np.ndarray: + """Create contact faces array (N, 4) with x, y, z, affinity in 
nm.""" + x = np.array([c[0] for c in contacts]) * resolution[0] + y = np.array([c[1] for c in contacts]) * resolution[1] + z = np.array([c[2] for c in contacts]) * resolution[2] + aff = np.array([c[3] for c in contacts]) + return np.stack([x, y, z, aff], axis=1).astype(np.float32) + + +def _generate_seg_contact( + contact_id: int, + seg_a_id: int, + seg_b_id: int, + meshes: dict[int, trimesh.Trimesh], + contact_data: dict[tuple[int, int], list[tuple[float, float, float, float]]], + seg_to_ref: dict[int, set[int]], + resolution: np.ndarray, + sphere_radius_nm: float, + n_pointcloud_points: int, + merge_authority: str, +) -> SegContact | None: + """Generate a single SegContact for a contact pair.""" + mesh_a, mesh_b = meshes.get(seg_a_id), meshes.get(seg_b_id) + if mesh_a is None or mesh_b is None: + return None + + contacts = contact_data[(seg_a_id, seg_b_id)] + com = _compute_affinity_weighted_com(contacts, resolution) + + mesh_a_cropped = _crop_mesh_to_sphere(mesh_a, com, sphere_radius_nm) + mesh_b_cropped = _crop_mesh_to_sphere(mesh_b, com, sphere_radius_nm) + if mesh_a_cropped is None or mesh_b_cropped is None: + return None + + if not _all_contacts_in_sphere(contacts, com, sphere_radius_nm, resolution): + return None + + contact_faces = _make_contact_faces_array(contacts, resolution) + + # Sample pointclouds + pointcloud_a = _sample_mesh_points(mesh_a_cropped, n_pointcloud_points) + pointcloud_b = _sample_mesh_points(mesh_b_cropped, n_pointcloud_points) + local_pointclouds = {seg_a_id: pointcloud_a, seg_b_id: pointcloud_b} + + # Compute merge decision: should merge if both segments overlap same reference CC + should_merge = bool(seg_to_ref.get(seg_a_id, set()) & seg_to_ref.get(seg_b_id, set())) + merge_decisions = {merge_authority: should_merge} + + return SegContact( + id=contact_id, + seg_a=seg_a_id, + seg_b=seg_b_id, + com=Vec3D(float(com[0]), float(com[1]), float(com[2])), + contact_faces=contact_faces, + local_pointclouds=local_pointclouds, + 
merge_decisions=merge_decisions, + ) + + +@builder.register("SegContactOp") +@taskable_operation_cls +@attrs.frozen +class SegContactOp: + """Operation to find and write segment contacts with pointclouds and merge decisions.""" + + sphere_radius_nm: float + crop_pad: Sequence[int] = (0, 0, 0) + min_seg_size_vx: int = 2000 + min_overlap_vx: int = 1000 + min_contact_vx: int = 5 + max_contact_vx: int = 2048 + n_pointcloud_points: int = 2048 + merge_authority: str = "reference_overlap" + + def get_input_resolution(self, dst_resolution: Vec3D[float]) -> Vec3D[float]: + return dst_resolution + + def with_added_crop_pad(self, crop_pad: Vec3D[int]) -> SegContactOp: + return attrs.evolve(self, crop_pad=Vec3D[int](*self.crop_pad) + crop_pad) + + def __call__( + self, + idx: VolumetricIndex, + dst: VolumetricSegContactLayer, + segmentation_layer: VolumetricLayer, + reference_layer: VolumetricLayer, + affinity_layer: VolumetricLayer, + ) -> None: + idx_padded = idx.padded(Vec3D[int](*self.crop_pad)) + resolution = np.array([idx.resolution[0], idx.resolution[1], idx.resolution[2]]) + + # Read all layers + with semaphore("read"): + seg, reference, aff = _read_layers_parallel( + segmentation_layer, reference_layer, affinity_layer, idx_padded + ) + + # Compute overlaps and find segments to exclude + overlap_seg, overlap_ref, overlap_count = _compute_overlaps(seg, reference) + small_ids = _find_small_segment_ids(seg, self.min_seg_size_vx) + merger_ids = _find_merger_segment_ids( + overlap_seg, overlap_ref, overlap_count, self.min_overlap_vx + ) + unclaimed_ids = _find_unclaimed_segment_ids(overlap_seg, overlap_count, self.min_overlap_vx) + exclude_ids = small_ids | merger_ids | unclaimed_ids + + # Build seg_to_ref mapping for merge decisions + seg_to_ref = _build_seg_to_ref( + overlap_seg, overlap_ref, overlap_count, self.min_overlap_vx + ) + + # Blackout excluded segments + seg = _blackout_segments(seg, exclude_ids) + + # Find contacts + seg_a, seg_b, aff_vals, x, y, z = 
_find_contacts(seg, aff, idx_padded.start) + if len(seg_a) == 0: + return + + # Filter out pairs touching padded boundary (may have incomplete contacts) + seg_a, seg_b, aff_vals, x, y, z = _filter_pairs_touching_boundary( + seg_a, seg_b, aff_vals, x, y, z, idx_padded.start, seg.shape + ) + if len(seg_a) == 0: + return + + # Filter out pairs with COM outside kernel region + seg_a, seg_b, aff_vals, x, y, z = _filter_pairs_by_com( + seg_a, seg_b, aff_vals, x, y, z, idx_padded.start, seg.shape, self.crop_pad + ) + if len(seg_a) == 0: + return + + # Filter pairs by contact count + contact_counts = _compute_contact_counts(seg_a, seg_b) + valid_pairs: list[tuple[int, int]] = [] + segs_needing_mesh: set[int] = set() + for (a, b), count in contact_counts.items(): + if self.min_contact_vx <= count <= self.max_contact_vx: + valid_pairs.append((a, b)) + segs_needing_mesh.update([a, b]) + + if not valid_pairs: + return + + # Build contact lookup + contact_data = _build_contact_lookup(seg_a, seg_b, x, y, z, aff_vals) + + # Download meshes + mesh_cv = CloudVolume(segmentation_layer.backend.name, use_https=True, progress=False) + bbox_start = ( + np.array([idx_padded.start[0], idx_padded.start[1], idx_padded.start[2]]) * resolution + ) + bbox_end = ( + np.array([idx_padded.stop[0], idx_padded.stop[1], idx_padded.stop[2]]) * resolution + ) + meshes = _download_and_clip_meshes(mesh_cv, list(segs_needing_mesh), bbox_start, bbox_end) + + # Generate SegContact objects + contacts: list[SegContact] = [] + for contact_id, (seg_a_id, seg_b_id) in enumerate(valid_pairs): + contact = _generate_seg_contact( + contact_id=contact_id, + seg_a_id=seg_a_id, + seg_b_id=seg_b_id, + meshes=meshes, + contact_data=contact_data, + seg_to_ref=seg_to_ref, + resolution=resolution, + sphere_radius_nm=self.sphere_radius_nm, + n_pointcloud_points=self.n_pointcloud_points, + merge_authority=self.merge_authority, + ) + if contact is not None: + contacts.append(contact) + + if contacts: + with 
semaphore("write"): + dst[idx] = contacts