From 01f0d27894628b637f0d33ead33a35778855db8a Mon Sep 17 00:00:00 2001 From: Sergiy Date: Thu, 1 Jan 2026 14:45:43 +0000 Subject: [PATCH 1/3] Add point cloud network support (PointNet, PointNet++, DGCNN) --- pyproject.toml | 7 +++++++ zetta_utils/convnet/CLAUDE.md | 6 +++++- zetta_utils/convnet/architecture/__init__.py | 1 + zetta_utils/convnet/architecture/pointcloud.py | 13 +++++++++++++ 4 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 zetta_utils/convnet/architecture/pointcloud.py diff --git a/pyproject.toml b/pyproject.toml index 1d558a6e6..6f1d03fb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,12 @@ chunkedgraph = [ # conda install -c conda-forge graph-tool-base ] convnet = ["torch >= 2.0", "onnx2torch", "nvidia-ml-py >= 13.580"] +pointcloud = [ + "zetta_utils[convnet]", + "pointnet >= 0.1.1", + # torch-geometric is optional for DGCNN; install separately if needed: + # pip install torch-geometric +] databackends = ["google-cloud-datastore", "google-cloud-firestore"] docs = [ "piccolo_theme >= 0.24.0", @@ -166,6 +172,7 @@ modules = [ "zetta_utils[meshing]", "zetta_utils[skeletonization]", "zetta_utils[chunkedgraph]", + "zetta_utils[pointcloud]", ] montaging = ["zetta_utils[cloudvol, databackends, mazepa]", "torch >= 2.0"] public = [ diff --git a/zetta_utils/convnet/CLAUDE.md b/zetta_utils/convnet/CLAUDE.md index bd2bc953e..55886c946 100644 --- a/zetta_utils/convnet/CLAUDE.md +++ b/zetta_utils/convnet/CLAUDE.md @@ -5,7 +5,7 @@ PyTorch-based neural network utilities for connectomics and computer vision work ## Core Components Model Management (utils.py): load_model() (supports loading models from JSON builder format, JIT, and ONNX formats), load_weights_file() (loads pre-trained weights with component filtering), save_model() (saves model state dictionaries), load_and_run_model() (convenience function for inference with automatic type conversion) -Architecture Components (architecture/): ConvBlock (flexible convolutional block with residual connections, normalization, and activation), UNet (complete U-Net implementation with skip connections sum/concat modes), Primitives (extensive collection of neural network building blocks) +Architecture Components (architecture/): ConvBlock (flexible convolutional block with residual connections, normalization, and activation), UNet (complete U-Net implementation with skip connections sum/concat modes), Primitives (extensive collection of neural network building blocks), Point Cloud Networks (PointNet, PointNet++) Inference Runner (simple_inference_runner.py): SimpleInferenceRunner (cached inference runner with GPU memory management), supports 2D/3D operations, sigmoid activation, and zero-skipping optimizations @@ -16,9 +16,13 @@ UNet Architecture: Traditional U-Net with encoder-decoder structure, configurabl Primitive Components: Tensor Operations (View, Flatten, Unflatten, Crop, CenterCrop), Pooling (MaxPool2DFlatten, AvgPool2DFlatten), Utilities (RescaleValues, Clamp, UpConv, SplitTuple), Multi-head (MultiHeaded, MultiHeadedOutput for multi-task learning) +Point Cloud Networks (pointcloud.py): PointNet/PointNet++ (wraps pointnet package for classification and segmentation) + ## Builder Registrations Core Components: ConvBlock (versioned >=0.0.2), UNet (versioned >=0.0.2), SimpleInferenceRunner, load_model, load_weights_file +Point Cloud (requires pip install zetta_utils[pointcloud]): pointnet.PointNetCls, pointnet.PointNetSeg, pointnet.PointNet2ClsSSG, pointnet.PointNet2SegSSG, pointnet.PointNet2ClsMSG, pointnet.PointNet2SegMSG, pointnet.STN + PyTorch Integrations: All torch.nn.* classes auto-registered, All torch.optim.* classes auto-registered, PyTorch functional operations registered, custom builders for Sequential, Upsample, GroupNorm Primitive Components: 15+ custom layer types registered, tensor manipulation utilities, pooling and cropping operations diff --git a/zetta_utils/convnet/architecture/__init__.py b/zetta_utils/convnet/architecture/__init__.py index 0088bc443..4b53b3629 100644 --- a/zetta_utils/convnet/architecture/__init__.py +++ b/zetta_utils/convnet/architecture/__init__.py @@ -1,4 +1,5 @@ from . import primitives +from . import pointcloud from .convblock import ConvBlock from .unet import UNet from . import deprecated diff --git a/zetta_utils/convnet/architecture/pointcloud.py b/zetta_utils/convnet/architecture/pointcloud.py new file mode 100644 index 000000000..793d81a39 --- /dev/null +++ b/zetta_utils/convnet/architecture/pointcloud.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import pointnet + +from zetta_utils import builder + +builder.register("pointnet.PointNetCls")(pointnet.PointNetCls) +builder.register("pointnet.PointNetSeg")(pointnet.PointNetSeg) +builder.register("pointnet.PointNet2ClsSSG")(pointnet.PointNet2ClsSSG) +builder.register("pointnet.PointNet2SegSSG")(pointnet.PointNet2SegSSG) +builder.register("pointnet.PointNet2ClsMSG")(pointnet.PointNet2ClsMSG) +builder.register("pointnet.PointNet2SegMSG")(pointnet.PointNet2SegMSG) +builder.register("pointnet.STN")(pointnet.STN) From 27104f19133e4abb365d1b3abffa6dd68470ed28 Mon Sep 17 00:00:00 2001 From: Sergiy Date: Wed, 7 Jan 2026 10:09:37 +0000 Subject: [PATCH 2/3] update requirements --- requirements.all.txt | 19 ++++++++++++++++--- zetta_utils/convnet/architecture/__init__.py | 1 - zetta_utils/internal | 2 +- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/requirements.all.txt b/requirements.all.txt index db588cc2b..26843420f 100644 --- a/requirements.all.txt +++ b/requirements.all.txt @@ -126,7 +126,9 @@ cloud-volume==12.8.0 # meshparty # pcg-skel colorama==0.4.6 - # via awscli + # via + # awscli + # taichi colorlog==6.10.1 # via trimesh comm==0.2.3 @@ -183,6 +185,7 @@ dill==0.4.0 # multiprocess # pathos # pylint + # taichi distlib==0.4.0 # via virtualenv docker==7.1.0 @@ -213,7 +216,9 @@ edt==3.1.0 # fastmorph # kimimaro einops==0.8.1 - # via zetta-utils (pyproject.toml) + # via + # zetta-utils (pyproject.toml) + # pointnet embreex==2.17.7.post7 # via trimesh executing==2.2.1 @@ -663,6 +668,7 @@ numpy==2.3.5 # shapely # simplejpeg # tables + # taichi # task-queue # tensorstore # tifffile @@ -821,6 +827,8 @@ pluggy==1.6.0 # via # pytest # pytest-cov +pointnet==0.1.1 + # via zetta-utils (pyproject.toml) posix-ipc==1.3.2 # via cloud-volume pox==0.3.6 @@ -1015,7 +1023,9 @@ responses==0.25.8 rfc3339==6.2 # via python-logging-loki rich==14.2.0 - # via zetta-utils (pyproject.toml) + # via + # zetta-utils (pyproject.toml) + # taichi roman==5.2 # via sphinx-toolbox rpds-py==0.30.0 @@ -1174,6 +1184,8 @@ tabulate==0.9.0 # via # zetta-utils (pyproject.toml) # sphinx-toolbox +taichi==1.7.4 + # via pointnet task-queue==2.14.3 # via zetta-utils (pyproject.toml) tenacity==9.1.2 @@ -1204,6 +1216,7 @@ torch==2.9.1 # kornia # lightning # onnx2torch + # pointnet # pytorch-lightning # torchfields # torchmetrics diff --git a/zetta_utils/convnet/architecture/__init__.py b/zetta_utils/convnet/architecture/__init__.py index 4b53b3629..0088bc443 100644 --- a/zetta_utils/convnet/architecture/__init__.py +++ b/zetta_utils/convnet/architecture/__init__.py @@ -1,5 +1,4 @@ from . import primitives -from . import pointcloud from .convblock import ConvBlock from .unet import UNet from . import deprecated diff --git a/zetta_utils/internal b/zetta_utils/internal index 6016b81e0..49f871cc1 160000 --- a/zetta_utils/internal +++ b/zetta_utils/internal @@ -1 +1 @@ -Subproject commit 6016b81e0c3db3afed5bcbc74aef9a5197aceee4 +Subproject commit 49f871cc1739fe6f0fbbdfe25793011de80d7091 From 2151ca21d116b3c27dc72e818b3019472f32167b Mon Sep 17 00:00:00 2001 From: Sergiy Date: Wed, 7 Jan 2026 11:17:01 +0000 Subject: [PATCH 3/3] feat: auto-import pointcloud module during load_all_modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add try/except import in architecture/__init__.py so pointcloud network registrations are loaded when zetta_utils modules are loaded. The import is wrapped in try/except to handle environments without the pointnet package installed. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- zetta_utils/convnet/architecture/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/zetta_utils/convnet/architecture/__init__.py b/zetta_utils/convnet/architecture/__init__.py index 0088bc443..c340e84e1 100644 --- a/zetta_utils/convnet/architecture/__init__.py +++ b/zetta_utils/convnet/architecture/__init__.py @@ -2,3 +2,8 @@ from .convblock import ConvBlock from .unet import UNet from . import deprecated + +try: + from . import pointcloud +except ImportError: + pass