From f1edd47605b1b789a173198e33543cc5014140b4 Mon Sep 17 00:00:00 2001 From: idejie Date: Thu, 7 Jul 2022 16:40:03 +0800 Subject: [PATCH 1/4] Migrate from THC headers into new ones --- maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu | 16 +++++++------ maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu | 17 ++++++++------ .../csrc/cuda/SigmoidFocalLoss_cuda.cu | 17 ++++++++------ .../csrc/cuda/deform_conv_cuda.cu | 5 +++- .../csrc/cuda/deform_pool_cuda.cu | 5 +++- maskrcnn_benchmark/csrc/cuda/nms.cu | 23 +++++++++++-------- 6 files changed, 50 insertions(+), 33 deletions(-) diff --git a/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu index 3791e5f..f8e28b6 100644 --- a/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu +++ b/maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu @@ -2,7 +2,9 @@ #include #include -#include +// #include +#include +#include #include #include @@ -272,11 +274,11 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, auto output_size = num_rois * pooled_height * pooled_width * channels; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); + dim3 grid(std::min(at::ceil_div(output_size, 512L), 4096L)); dim3 block(512); if (output.numel() == 0) { - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return output; } @@ -294,7 +296,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, rois.contiguous().data_ptr(), output.data_ptr()); }); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return output; } @@ -317,12 +319,12 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); + dim3 grid(std::min(at::ceil_div(grad.numel(), 512L), 4096L)); dim3 block(512); // handle possibly empty gradients if (grad.numel() == 0) { - THCudaCheck(cudaGetLastError()); + 
C10_CUDA_CHECK(cudaGetLastError()); return grad_input; } @@ -341,6 +343,6 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, grad_input.data_ptr(), rois.contiguous().data_ptr()); }); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return grad_input; } diff --git a/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu index 1bede94..f98ff64 100644 --- a/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu +++ b/maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu @@ -2,7 +2,10 @@ #include #include -#include +// #include +#include + +#include #include #include @@ -126,11 +129,11 @@ std::tuple ROIPool_forward_cuda(const at::Tensor& input, cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); + dim3 grid(std::min(at::ceil_div(output_size, 512L), 4096L)); dim3 block(512); if (output.numel() == 0) { - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(output, argmax); } @@ -148,7 +151,7 @@ std::tuple ROIPool_forward_cuda(const at::Tensor& input, output.data_ptr(), argmax.data_ptr()); }); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(output, argmax); } @@ -173,12 +176,12 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); + dim3 grid(std::min(at::ceil_div(grad.numel(), 512L), 4096L)); dim3 block(512); // handle possibly empty gradients if (grad.numel() == 0) { - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return grad_input; } @@ -197,6 +200,6 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, grad_input.data_ptr(), rois.contiguous().data_ptr()); }); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return grad_input; } diff --git 
a/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu index f2e4ebb..20d29cb 100644 --- a/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu +++ b/maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu @@ -5,7 +5,10 @@ #include #include -#include +// #include +#include + +#include #include #include @@ -117,11 +120,11 @@ at::Tensor SigmoidFocalLoss_forward_cuda( auto losses_size = num_samples * logits.size(1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - dim3 grid(std::min(THCCeilDiv(losses_size, 512L), 4096L)); + dim3 grid(std::min(at::ceil_div(losses_size, 512L), 4096L)); dim3 block(512); if (losses.numel() == 0) { - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return losses; } @@ -136,7 +139,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda( num_samples, losses.data_ptr()); }); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return losses; } @@ -161,11 +164,11 @@ at::Tensor SigmoidFocalLoss_backward_cuda( auto d_logits_size = num_samples * logits.size(1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - dim3 grid(std::min(THCCeilDiv(d_logits_size, 512L), 4096L)); + dim3 grid(std::min(at::ceil_div(d_logits_size, 512L), 4096L)); dim3 block(512); if (d_logits.numel() == 0) { - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return d_logits; } @@ -182,7 +185,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda( d_logits.data_ptr()); }); - THCudaCheck(cudaGetLastError()); + C10_CUDA_CHECK(cudaGetLastError()); return d_logits; } diff --git a/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu index 087adfe..750818a 100644 --- a/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu +++ b/maskrcnn_benchmark/csrc/cuda/deform_conv_cuda.cu @@ -4,7 +4,10 @@ #include #include -#include +// #include +#include + +#include #include #include diff --git 
a/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu b/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu index 0be4cb8..184acb7 100644 --- a/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu +++ b/maskrcnn_benchmark/csrc/cuda/deform_pool_cuda.cu @@ -8,7 +8,10 @@ #include #include -#include +// #include +#include + +#include #include #include diff --git a/maskrcnn_benchmark/csrc/cuda/nms.cu b/maskrcnn_benchmark/csrc/cuda/nms.cu index 2ba785d..2df8138 100644 --- a/maskrcnn_benchmark/csrc/cuda/nms.cu +++ b/maskrcnn_benchmark/csrc/cuda/nms.cu @@ -2,7 +2,10 @@ #include #include -#include +// #include +#include + +#include #include #include @@ -61,7 +64,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, t |= 1ULL << i; } } - const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock); + const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock); dev_mask[cur_box_idx * col_blocks + col_start] = t; } } @@ -76,20 +79,20 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) { int boxes_num = boxes.size(0); - const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock); + const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock); scalar_t* boxes_dev = boxes_sorted.data_ptr(); - THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState + // THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState unsigned long long* mask_dev = NULL; - //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev, + //C10_CUDA_CHECK(THCudaMalloc(state, (void**) &mask_dev, // boxes_num * col_blocks * sizeof(unsigned long long))); - mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long)); + mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc( boxes_num * col_blocks * sizeof(unsigned long long)); - dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock), - THCCeilDiv(boxes_num, threadsPerBlock)); + dim3 
blocks(at::ceil_div(boxes_num, threadsPerBlock), + at::ceil_div(boxes_num, threadsPerBlock)); dim3 threads(threadsPerBlock); nms_kernel<<>>(boxes_num, nms_overlap_thresh, @@ -97,7 +100,7 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) { mask_dev); std::vector mask_host(boxes_num * col_blocks); - THCudaCheck(cudaMemcpy(&mask_host[0], + C10_CUDA_CHECK(cudaMemcpy(&mask_host[0], mask_dev, sizeof(unsigned long long) * boxes_num * col_blocks, cudaMemcpyDeviceToHost)); @@ -122,7 +125,7 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) { } } - THCudaFree(state, mask_dev); + // THCudaFree(state, mask_dev); // TODO improve this part return std::get<0>(order_t.index({ keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to( From d7dd9128d463f04c30b82b562b9902224d4843b3 Mon Sep 17 00:00:00 2001 From: idejie Date: Thu, 7 Jul 2022 17:03:18 +0800 Subject: [PATCH 2/4] Migrate from THC headers into new ones --- maskrcnn_benchmark/utils/imports.py | 4 ++-- maskrcnn_benchmark/utils/model_zoo.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/maskrcnn_benchmark/utils/imports.py b/maskrcnn_benchmark/utils/imports.py index 53e27e2..b5fdd49 100644 --- a/maskrcnn_benchmark/utils/imports.py +++ b/maskrcnn_benchmark/utils/imports.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
import torch -if torch._six.PY3: +try: import importlib import importlib.util import sys @@ -15,7 +15,7 @@ def import_file(module_name, file_path, make_importable=False): if make_importable: sys.modules[module_name] = module return module -else: +except: import imp def import_file(module_name, file_path, make_importable=None): diff --git a/maskrcnn_benchmark/utils/model_zoo.py b/maskrcnn_benchmark/utils/model_zoo.py index 2128ad7..8b20bad 100644 --- a/maskrcnn_benchmark/utils/model_zoo.py +++ b/maskrcnn_benchmark/utils/model_zoo.py @@ -2,14 +2,11 @@ import os import sys -try: - from torch.hub import _download_url_to_file - from torch.hub import urlparse - from torch.hub import HASH_REGEX -except ImportError: - from torch.utils.model_zoo import _download_url_to_file - from torch.utils.model_zoo import urlparse - from torch.utils.model_zoo import HASH_REGEX + +from torch.hub import download_url_to_file as _download_url_to_file +from torch.hub import urlparse +from torch.hub import HASH_REGEX + from maskrcnn_benchmark.utils.comm import is_main_process from maskrcnn_benchmark.utils.comm import synchronize @@ -39,6 +36,7 @@ def cache_url(url, model_dir=None, progress=True): model_dir = os.getenv("TORCH_MODEL_ZOO", os.path.join(torch_home, "models")) if not os.path.exists(model_dir): os.makedirs(model_dir) + print(url) parts = urlparse(url) filename = os.path.basename(parts.path) if filename == "model_final.pkl": From 4bb9847f0694fb47ce02aeae5d4c2684fa518a44 Mon Sep 17 00:00:00 2001 From: idejie Date: Thu, 7 Jul 2022 17:09:00 +0800 Subject: [PATCH 3/4] Migrate from THC headers into new ones --- INSTALL.md | 17 +++++++++-------- README.md | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index e399b63..49fee56 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -1,15 +1,14 @@ ## Installation ### Requirements: -- PyTorch 1.7 +- PyTorch 1.12 - torchvision - cocoapi -- yacs>=0.1.8 -- numpy>=1.19.5 +- yacs +- numpy - matplotlib -- GCC >= 
4.9 - OpenCV -- CUDA >= 10.1 +- CUDA == 11.6 ### Option 1: Step-by-step installation @@ -19,16 +18,16 @@ # for that, check that `which conda`, `which pip` and `which python` points to the # right path. From a clean conda env, this is what you need to do -conda create --name sg_benchmark python=3.7 -y +conda create --name sg_benchmark python=3.9 -y conda activate sg_benchmark # this installs the right pip and dependencies for the fresh python conda install ipython h5py nltk joblib jupyter pandas scipy # maskrcnn_benchmark and coco api dependencies -pip install ninja yacs>=0.1.8 cython matplotlib tqdm opencv-python numpy>=1.19.5 +pip install ninja yacs cython matplotlib tqdm opencv-python numpy -conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch +conda install pytorch==1.12.0 torchvision torchaudio cudatoolkit=11.6 -c pytorch conda install -c conda-forge timm einops # install pycocotools @@ -51,6 +50,8 @@ python setup.py build develop ``` ### Option 2: Docker Image (Requires CUDA, Linux only) +> please use torch1.7.1 version + Build image with defaults (`CUDA=10.1`, `CUDNN=7`, `FORCE_CUDA=1`): nvidia-docker build -t scene_graph_benchmark docker/ diff --git a/README.md b/README.md index f345d88..d7fb6bc 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ ## Highlights -- **Upgrad to pytorch 1.7** +- **Upgrade to pytorch 1.12** - **Multi-GPU training and inference** - **Batched inference:** can perform inference using multiple images per batch per GPU. 
- **Fast and flexible tsv dataset format** From 5a635c96c916f739d45e6136b493ab4818ab957c Mon Sep 17 00:00:00 2001 From: idejie Date: Thu, 7 Jul 2022 17:12:42 +0800 Subject: [PATCH 4/4] torch 1.12 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d7fb6bc..c5b1ba1 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Scene Graph Benchmark in PyTorch 1.7 +# Scene Graph Benchmark in PyTorch 1.12 **This project is based on [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark)**