17 changes: 9 additions & 8 deletions INSTALL.md
@@ -1,15 +1,14 @@
## Installation

### Requirements:
-- PyTorch 1.7
+- PyTorch 1.12
- torchvision
- cocoapi
-- yacs>=0.1.8
-- numpy>=1.19.5
+- yacs
+- numpy
- matplotlib
- GCC >= 4.9
- OpenCV
-- CUDA >= 10.1
+- CUDA == 11.6


### Option 1: Step-by-step installation
@@ -19,16 +18,16 @@
# for that, check that `which conda`, `which pip` and `which python` points to the
# right path. From a clean conda env, this is what you need to do

-conda create --name sg_benchmark python=3.7 -y
+conda create --name sg_benchmark python=3.9 -y
conda activate sg_benchmark

# this installs the right pip and dependencies for the fresh python
conda install ipython h5py nltk joblib jupyter pandas scipy

# maskrcnn_benchmark and coco api dependencies
-pip install ninja yacs>=0.1.8 cython matplotlib tqdm opencv-python numpy>=1.19.5
+pip install ninja yacs cython matplotlib tqdm opencv-python numpy

-conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch
+conda install pytorch==1.12.0 torchvision torchaudio cudatoolkit=11.6 -c pytorch
conda install -c conda-forge timm einops

# install pycocotools
@@ -51,6 +50,8 @@ python setup.py build develop
```
### Option 2: Docker Image (Requires CUDA, Linux only)

+> Please use the torch 1.7.1 version with the Docker image.

Build image with defaults (`CUDA=10.1`, `CUDNN=7`, `FORCE_CUDA=1`):

nvidia-docker build -t scene_graph_benchmark docker/
4 changes: 2 additions & 2 deletions README.md
@@ -1,12 +1,12 @@
-# Scene Graph Benchmark in PyTorch 1.7
+# Scene Graph Benchmark in PyTorch 1.12

**This project is based on [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark)**

![alt text](demo/R152FPN_demo.png "from https://storage.googleapis.com/openimages/web/index.html")


## Highlights
-- **Upgrad to pytorch 1.7**
+- **Upgrade to PyTorch 1.12**
- **Multi-GPU training and inference**
- **Batched inference:** can perform inference using multiple images per batch per GPU.
- **Fast and flexible tsv dataset format**
16 changes: 9 additions & 7 deletions maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu
@@ -2,7 +2,9 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
+// #include <THC/THC.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/ceil_div.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

@@ -272,11 +274,11 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
+dim3 grid(std::min(at::ceil_div(output_size, 512L), 4096L));
dim3 block(512);

if (output.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return output;
}

@@ -294,7 +296,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
rois.contiguous().data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
});
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return output;
}

@@ -317,12 +319,12 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
+dim3 grid(std::min(at::ceil_div(grad.numel(), 512L), 4096L));
dim3 block(512);

// handle possibly empty gradients
if (grad.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return grad_input;
}

@@ -341,6 +343,6 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
grad_input.data_ptr<scalar_t>(),
rois.contiguous().data_ptr<scalar_t>());
});
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
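
The edits above follow a mechanical THC-to-ATen migration: `THCCeilDiv` becomes `at::ceil_div` (from `ATen/ceil_div.h`) and `THCudaCheck` becomes `C10_CUDA_CHECK`. The same substitution is applied to ROIPool_cuda.cu, SigmoidFocalLoss_cuda.cu, deform_conv_cuda.cu, deform_pool_cuda.cu, and nms.cu below. Here is a minimal sketch of the post-migration launch pattern, assuming PyTorch >= 1.11 headers; `fill_kernel` and `fill_cuda` are hypothetical names used only for illustration, not code from this PR:

```cuda
// Sketch only: the generic launch pattern after the THC -> ATen migration.
#include <algorithm>
#include <ATen/ATen.h>
#include <ATen/ceil_div.h>           // at::ceil_div replaces THCCeilDiv
#include <ATen/cuda/CUDAContext.h>   // at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAException.h>  // C10_CUDA_CHECK replaces THCudaCheck

// Hypothetical elementwise kernel, standing in for RoIAlignForward etc.
__global__ void fill_kernel(int64_t n, float value, float* out) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) out[i] = value;
}

at::Tensor fill_cuda(int64_t n, float value) {
  auto out = at::empty({n}, at::device(at::kCUDA).dtype(at::kFloat));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // at::ceil_div requires both arguments to have the same type, which is
  // why the 512L / 4096L long literals from the original code are kept.
  dim3 grid(std::min(at::ceil_div(n, 512L), 4096L));
  dim3 block(512);

  if (out.numel() == 0) {
    C10_CUDA_CHECK(cudaGetLastError());
    return out;
  }

  fill_kernel<<<grid, block, 0, stream>>>(n, value, out.data_ptr<float>());
  C10_CUDA_CHECK(cudaGetLastError());
  return out;
}
```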
17 changes: 10 additions & 7 deletions maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu
@@ -2,7 +2,10 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
+// #include <THC/THC.h>
+#include <ATen/ceil_div.h>
+
+#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

@@ -126,11 +129,11 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
+dim3 grid(std::min(at::ceil_div(output_size, 512L), 4096L));
dim3 block(512);

if (output.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, argmax);
}

@@ -148,7 +151,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
output.data_ptr<scalar_t>(),
argmax.data_ptr<int>());
});
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, argmax);
}

@@ -173,12 +176,12 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
+dim3 grid(std::min(at::ceil_div(grad.numel(), 512L), 4096L));
dim3 block(512);

// handle possibly empty gradients
if (grad.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return grad_input;
}

@@ -197,6 +200,6 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
grad_input.data_ptr<scalar_t>(),
rois.contiguous().data_ptr<scalar_t>());
});
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
17 changes: 10 additions & 7 deletions maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu
@@ -5,7 +5,10 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
+// #include <THC/THC.h>
+#include <ATen/ceil_div.h>
+
+#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

@@ -117,11 +120,11 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
auto losses_size = num_samples * logits.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv(losses_size, 512L), 4096L));
+dim3 grid(std::min(at::ceil_div(losses_size, 512L), 4096L));
dim3 block(512);

if (losses.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return losses;
}

@@ -136,7 +139,7 @@
num_samples,
losses.data_ptr<scalar_t>());
});
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return losses;
}

@@ -161,11 +164,11 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
auto d_logits_size = num_samples * logits.size(1);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-dim3 grid(std::min(THCCeilDiv(d_logits_size, 512L), 4096L));
+dim3 grid(std::min(at::ceil_div(d_logits_size, 512L), 4096L));
dim3 block(512);

if (d_logits.numel() == 0) {
-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return d_logits;
}

@@ -182,7 +185,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
d_logits.data_ptr<scalar_t>());
});

-THCudaCheck(cudaGetLastError());
+C10_CUDA_CHECK(cudaGetLastError());
return d_logits;
}

5 changes: 4 additions & 1 deletion maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu
@@ -4,7 +4,10 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
+// #include <THC/THC.h>
+#include <ATen/ceil_div.h>
+
+#include <c10/cuda/CUDAGuard.h>
#include <THC/THCDeviceUtils.cuh>

#include <vector>
5 changes: 4 additions & 1 deletion maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu
@@ -8,7 +8,10 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
+// #include <THC/THC.h>
+#include <ATen/ceil_div.h>
+
+#include <c10/cuda/CUDAGuard.h>
#include <THC/THCDeviceUtils.cuh>

#include <vector>
23 changes: 13 additions & 10 deletions maskrcnn_benchmark/csrc/cuda/nms.cu
@@ -2,7 +2,10 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
+// #include <THC/THC.h>
+#include <ATen/ceil_div.h>
+
+#include <c10/cuda/CUDAGuard.h>
#include <THC/THCDeviceUtils.cuh>

#include <vector>
@@ -61,7 +64,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
t |= 1ULL << i;
}
}
-const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
@@ -76,28 +79,28 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {

int boxes_num = boxes.size(0);

-const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
+const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);

scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();

-THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
+// THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState

unsigned long long* mask_dev = NULL;
-//THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
+//C10_CUDA_CHECK(THCudaMalloc(state, (void**) &mask_dev,
// boxes_num * col_blocks * sizeof(unsigned long long)));

-mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));

-dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
-            THCCeilDiv(boxes_num, threadsPerBlock));
+dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
+            at::ceil_div(boxes_num, threadsPerBlock));
dim3 threads(threadsPerBlock);
nms_kernel<<<blocks, threads>>>(boxes_num,
nms_overlap_thresh,
boxes_dev,
mask_dev);

std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
-THCudaCheck(cudaMemcpy(&mask_host[0],
+C10_CUDA_CHECK(cudaMemcpy(&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
cudaMemcpyDeviceToHost));
@@ -122,7 +125,7 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
}
}

-THCudaFree(state, mask_dev);
+// THCudaFree(state, mask_dev);
// TODO improve this part
return std::get<0>(order_t.index({
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
Expand Down
4 changes: 2 additions & 2 deletions maskrcnn_benchmark/utils/imports.py
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch

-if torch._six.PY3:
+try:
import importlib
import importlib.util
import sys
@@ -15,7 +15,7 @@ def import_file(module_name, file_path, make_importable=False):
if make_importable:
sys.modules[module_name] = module
return module
-else:
+except ImportError:
import imp

def import_file(module_name, file_path, make_importable=None):
14 changes: 6 additions & 8 deletions maskrcnn_benchmark/utils/model_zoo.py
@@ -2,14 +2,11 @@
import os
import sys

-try:
-    from torch.hub import _download_url_to_file
-    from torch.hub import urlparse
-    from torch.hub import HASH_REGEX
-except ImportError:
-    from torch.utils.model_zoo import _download_url_to_file
-    from torch.utils.model_zoo import urlparse
-    from torch.utils.model_zoo import HASH_REGEX
+from torch.hub import download_url_to_file as _download_url_to_file
+from torch.hub import urlparse
+from torch.hub import HASH_REGEX


from maskrcnn_benchmark.utils.comm import is_main_process
from maskrcnn_benchmark.utils.comm import synchronize
@@ -39,6 +36,7 @@ def cache_url(url, model_dir=None, progress=True):
model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models"))
if not os.path.exists(model_dir):
os.makedirs(model_dir)
+print(url)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if filename == "model_final.pkl":