Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ chunkedgraph = [
# conda install -c conda-forge graph-tool-base
]
convnet = ["torch >= 2.0", "onnx2torch", "nvidia-ml-py >= 13.580"]
pointcloud = [
"zetta_utils[convnet]",
"pointnet >= 0.1.1",
# torch-geometric is optional for DGCNN; install separately if needed:
# pip install torch-geometric
]
databackends = ["google-cloud-datastore", "google-cloud-firestore"]
docs = [
"piccolo_theme >= 0.24.0",
Expand Down Expand Up @@ -166,6 +172,7 @@ modules = [
"zetta_utils[meshing]",
"zetta_utils[skeletonization]",
"zetta_utils[chunkedgraph]",
"zetta_utils[pointcloud]",
]
montaging = ["zetta_utils[cloudvol, databackends, mazepa]", "torch >= 2.0"]
public = [
Expand Down
6 changes: 5 additions & 1 deletion zetta_utils/convnet/CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ PyTorch-based neural network utilities for connectomics and computer vision work
## Core Components
Model Management (utils.py): load_model() (supports loading models from JSON builder format, JIT, and ONNX formats), load_weights_file() (loads pre-trained weights with component filtering), save_model() (saves model state dictionaries), load_and_run_model() (convenience function for inference with automatic type conversion)

Architecture Components (architecture/): ConvBlock (flexible convolutional block with residual connections, normalization, and activation), UNet (complete U-Net implementation with skip connections sum/concat modes), Primitives (extensive collection of neural network building blocks)
Architecture Components (architecture/): ConvBlock (flexible convolutional block with residual connections, normalization, and activation), UNet (complete U-Net implementation with skip connections sum/concat modes), Primitives (extensive collection of neural network building blocks), Point Cloud Networks (PointNet, PointNet++)

Inference Runner (simple_inference_runner.py): SimpleInferenceRunner (cached inference runner with GPU memory management), supports 2D/3D operations, sigmoid activation, and zero-skipping optimizations

Expand All @@ -16,9 +16,13 @@ UNet Architecture: Traditional U-Net with encoder-decoder structure, configurabl

Primitive Components: Tensor Operations (View, Flatten, Unflatten, Crop, CenterCrop), Pooling (MaxPool2DFlatten, AvgPool2DFlatten), Utilities (RescaleValues, Clamp, UpConv, SplitTuple), Multi-head (MultiHeaded, MultiHeadedOutput for multi-task learning)

Point Cloud Networks (pointcloud.py): PointNet/PointNet++ (wraps pointnet package for classification and segmentation)

## Builder Registrations
Core Components: ConvBlock (versioned >=0.0.2), UNet (versioned >=0.0.2), SimpleInferenceRunner, load_model, load_weights_file

Point Cloud (requires pip install zetta_utils[pointcloud]): pointnet.PointNetCls, pointnet.PointNetSeg, pointnet.PointNet2ClsSSG, pointnet.PointNet2SegSSG, pointnet.PointNet2ClsMSG, pointnet.PointNet2SegMSG, pointnet.STN

PyTorch Integrations: All torch.nn.* classes auto-registered, All torch.optim.* classes auto-registered, PyTorch functional operations registered, custom builders for Sequential, Upsample, GroupNorm

Primitive Components: 15+ custom layer types registered, tensor manipulation utilities, pooling and cropping operations
Expand Down
1 change: 1 addition & 0 deletions zetta_utils/convnet/architecture/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import primitives
from . import pointcloud
from .convblock import ConvBlock
from .unet import UNet
from . import deprecated
13 changes: 13 additions & 0 deletions zetta_utils/convnet/architecture/pointcloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

import pointnet

from zetta_utils import builder

builder.register("pointnet.PointNetCls")(pointnet.PointNetCls)
builder.register("pointnet.PointNetSeg")(pointnet.PointNetSeg)
builder.register("pointnet.PointNet2ClsSSG")(pointnet.PointNet2ClsSSG)
builder.register("pointnet.PointNet2SegSSG")(pointnet.PointNet2SegSSG)
builder.register("pointnet.PointNet2ClsMSG")(pointnet.PointNet2ClsMSG)
builder.register("pointnet.PointNet2SegMSG")(pointnet.PointNet2SegMSG)
builder.register("pointnet.STN")(pointnet.STN)
Loading