diff --git a/pyproject.toml b/pyproject.toml index 1d558a6e6..6f1d03fb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,12 @@ chunkedgraph = [ # conda install -c conda-forge graph-tool-base ] convnet = ["torch >= 2.0", "onnx2torch", "nvidia-ml-py >= 13.580"] +pointcloud = [ + "zetta_utils[convnet]", + "pointnet >= 0.1.1", + # torch-geometric is optional for DGCNN; install separately if needed: + # pip install torch-geometric +] databackends = ["google-cloud-datastore", "google-cloud-firestore"] docs = [ "piccolo_theme >= 0.24.0", @@ -166,6 +172,7 @@ modules = [ "zetta_utils[meshing]", "zetta_utils[skeletonization]", "zetta_utils[chunkedgraph]", + "zetta_utils[pointcloud]", ] montaging = ["zetta_utils[cloudvol, databackends, mazepa]", "torch >= 2.0"] public = [ diff --git a/requirements.all.txt b/requirements.all.txt index db588cc2b..26843420f 100644 --- a/requirements.all.txt +++ b/requirements.all.txt @@ -126,7 +126,9 @@ cloud-volume==12.8.0 # meshparty # pcg-skel colorama==0.4.6 - # via awscli + # via + # awscli + # taichi colorlog==6.10.1 # via trimesh comm==0.2.3 @@ -183,6 +185,7 @@ dill==0.4.0 # multiprocess # pathos # pylint + # taichi distlib==0.4.0 # via virtualenv docker==7.1.0 @@ -213,7 +216,9 @@ edt==3.1.0 # fastmorph # kimimaro einops==0.8.1 - # via zetta-utils (pyproject.toml) + # via + # zetta-utils (pyproject.toml) + # pointnet embreex==2.17.7.post7 # via trimesh executing==2.2.1 @@ -663,6 +668,7 @@ numpy==2.3.5 # shapely # simplejpeg # tables + # taichi # task-queue # tensorstore # tifffile @@ -821,6 +827,8 @@ pluggy==1.6.0 # via # pytest # pytest-cov +pointnet==0.1.1 + # via zetta-utils (pyproject.toml) posix-ipc==1.3.2 # via cloud-volume pox==0.3.6 @@ -1015,7 +1023,9 @@ responses==0.25.8 rfc3339==6.2 # via python-logging-loki rich==14.2.0 - # via zetta-utils (pyproject.toml) + # via + # zetta-utils (pyproject.toml) + # taichi roman==5.2 # via sphinx-toolbox rpds-py==0.30.0 @@ -1174,6 +1184,8 @@ tabulate==0.9.0 # via # zetta-utils (pyproject.toml) # sphinx-toolbox +taichi==1.7.4 + # via pointnet task-queue==2.14.3 # via zetta-utils (pyproject.toml) tenacity==9.1.2 @@ -1204,6 +1216,7 @@ torch==2.9.1 # kornia # lightning # onnx2torch + # pointnet # pytorch-lightning # torchfields # torchmetrics diff --git a/zetta_utils/convnet/CLAUDE.md b/zetta_utils/convnet/CLAUDE.md index bd2bc953e..55886c946 100644 --- a/zetta_utils/convnet/CLAUDE.md +++ b/zetta_utils/convnet/CLAUDE.md @@ -5,7 +5,7 @@ PyTorch-based neural network utilities for connectomics and computer vision work ## Core Components Model Management (utils.py): load_model() (supports loading models from JSON builder format, JIT, and ONNX formats), load_weights_file() (loads pre-trained weights with component filtering), save_model() (saves model state dictionaries), load_and_run_model() (convenience function for inference with automatic type conversion) -Architecture Components (architecture/): ConvBlock (flexible convolutional block with residual connections, normalization, and activation), UNet (complete U-Net implementation with skip connections sum/concat modes), Primitives (extensive collection of neural network building blocks) +Architecture Components (architecture/): ConvBlock (flexible convolutional block with residual connections, normalization, and activation), UNet (complete U-Net implementation with skip connections sum/concat modes), Primitives (extensive collection of neural network building blocks), Point Cloud Networks (PointNet, PointNet++) Inference Runner (simple_inference_runner.py): SimpleInferenceRunner (cached inference runner with GPU memory management), supports 2D/3D operations, sigmoid activation, and zero-skipping optimizations @@ -16,9 +16,13 @@ UNet Architecture: Traditional U-Net with encoder-decoder structure, configurabl Primitive Components: Tensor Operations (View, Flatten, Unflatten, Crop, CenterCrop), Pooling (MaxPool2DFlatten, AvgPool2DFlatten), Utilities (RescaleValues, Clamp, UpConv, SplitTuple), Multi-head (MultiHeaded, MultiHeadedOutput for multi-task learning) +Point Cloud Networks (pointcloud.py): PointNet/PointNet++ (wraps pointnet package for classification and segmentation) + ## Builder Registrations Core Components: ConvBlock (versioned >=0.0.2), UNet (versioned >=0.0.2), SimpleInferenceRunner, load_model, load_weights_file +Point Cloud (requires pip install zetta_utils[pointcloud]): pointnet.PointNetCls, pointnet.PointNetSeg, pointnet.PointNet2ClsSSG, pointnet.PointNet2SegSSG, pointnet.PointNet2ClsMSG, pointnet.PointNet2SegMSG, pointnet.STN + PyTorch Integrations: All torch.nn.* classes auto-registered, All torch.optim.* classes auto-registered, PyTorch functional operations registered, custom builders for Sequential, Upsample, GroupNorm Primitive Components: 15+ custom layer types registered, tensor manipulation utilities, pooling and cropping operations diff --git a/zetta_utils/convnet/architecture/__init__.py b/zetta_utils/convnet/architecture/__init__.py index 0088bc443..c340e84e1 100644 --- a/zetta_utils/convnet/architecture/__init__.py +++ b/zetta_utils/convnet/architecture/__init__.py @@ -2,3 +2,8 @@ from .convblock import ConvBlock from .unet import UNet from . import deprecated + +try: + from . import pointcloud +except ImportError: + pass diff --git a/zetta_utils/convnet/architecture/pointcloud.py b/zetta_utils/convnet/architecture/pointcloud.py new file mode 100644 index 000000000..793d81a39 --- /dev/null +++ b/zetta_utils/convnet/architecture/pointcloud.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import pointnet + +from zetta_utils import builder + +builder.register("pointnet.PointNetCls")(pointnet.PointNetCls) +builder.register("pointnet.PointNetSeg")(pointnet.PointNetSeg) +builder.register("pointnet.PointNet2ClsSSG")(pointnet.PointNet2ClsSSG) +builder.register("pointnet.PointNet2SegSSG")(pointnet.PointNet2SegSSG) +builder.register("pointnet.PointNet2ClsMSG")(pointnet.PointNet2ClsMSG) +builder.register("pointnet.PointNet2SegMSG")(pointnet.PointNet2SegMSG) +builder.register("pointnet.STN")(pointnet.STN) diff --git a/zetta_utils/internal b/zetta_utils/internal index 6016b81e0..49f871cc1 160000 --- a/zetta_utils/internal +++ b/zetta_utils/internal @@ -1 +1 @@ -Subproject commit 6016b81e0c3db3afed5bcbc74aef9a5197aceee4 +Subproject commit 49f871cc1739fe6f0fbbdfe25793011de80d7091