本文为PointNet++ CUDA代码阅读系列的第二部分,其他详见:
(一)PointNet++代码梳理
(二)PointNet++中的FPS的CUDA实现
(三)PointNet++中ball query的CUDA实现
(四)PointNet++中的Three_nn的CUDA实现
之前只是使用PointNet++,没有想过它内部是怎么实现的。最近学了一些CUDA编程,这里就以FPS(最远点采样,Furthest Point Sampling)为例详细讲解一下。
本文使用的代码是PointRCNN中PointNet++的实现。
FPS是用C++(.cpp)和CUDA(.cu)实现的,所以先看一下它在PyTorch中的定义,位于pointnet2/pointnet2_utils.py中:
class FurthestPointSampling(Function):
    """Autograd wrapper around the CUDA furthest point sampling (FPS) kernel.

    Sampling produces integer indices, so backward returns no gradients.
    """

    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """Uses iterative furthest point sampling to select a set of npoint
        features that have the largest minimum distance.

        :param ctx: autograd context (unused).
        :param xyz: (B, N, 3) contiguous CUDA float tensor, where N > npoint.
        :param npoint: int, number of features in the sampled set.
        :return: output: (B, npoint) int32 tensor of sampled point indices.
        """
        assert xyz.is_contiguous()
        B, N, _ = xyz.size()
        # The legacy torch.cuda.*Tensor constructors allocate on the *current*
        # CUDA device, and the C++ wrapper launches on the current device's
        # stream. Pin both to xyz's device so inputs on a non-default GPU work.
        with torch.cuda.device_of(xyz):
            output = torch.empty((B, npoint), dtype=torch.int32, device=xyz.device)
            # Running minimum distance from each point to the selected set,
            # initialised to a large value ("infinitely far").
            temp = torch.full((B, N), 1e10, dtype=torch.float32, device=xyz.device)
            pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
        return output

    @staticmethod
    def backward(ctx, grad_out=None):
        # No gradient w.r.t. xyz or npoint.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
核心函数是furthest_point_sampling_wrapper,它是用C++写成的。至于它具体如何链接到cpp代码,以及如何再封装成一个PyTorch兼容的函数,可见我的另外一篇博客。
代码在pointnet2/src/sampling.cpp中
// Host entry point called from Python: extracts raw device pointers from the
// ATen tensors and hands them to the CUDA launcher on the current stream.
//   b, n, m      : batch size, number of input points, number of points to sample.
//   points_tensor: (B, N, 3) float coordinates; read as a flat array, so it must
//                  be contiguous (the Python caller asserts this).
//   temp_tensor  : (B, N) float scratch buffer of running min distances; the
//                  Python caller pre-fills it with 1e10.
//   idx_tensor   : (B, M) int output indices.
// Always returns 1.
int furthest_point_sampling_wrapper(int b, int n, int m, at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
    // data_ptr<T>() replaces the deprecated Tensor::data<T>().
    const float *points = points_tensor.data_ptr<float>();
    float *temp = temp_tensor.data_ptr<float>();
    int *idx = idx_tensor.data_ptr<int>();

    // NOTE(review): THCState_getCurrentStream and the global `state` belong to
    // the legacy THC API; newer PyTorch releases use
    // at::cuda::getCurrentCUDAStream() (ATen/cuda/CUDAContext.h) instead.
    cudaStream_t stream = THCState_getCurrentStream(state);
    furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
    return 1;
}
可以看到,cpp中的函数接收由Python函数传入的张量,取出它们的数据指针以及当前的CUDA stream,然后调用cu文件中的kernel_launcher函数。
kernel_launcher函数做的事情也不多:首先根据点数n确定每个block开的线程数量,然后按这个数量选择对应模板参数的kernel来启动,每个batch对应一个block。
// Host-side launcher for furthest_point_sampling_kernel.
//   dataset: (B, N, 3) input coordinates (device pointer)
//   temp   : (B, N) running min-distance scratch buffer (device pointer)
//   idxs   : (B, M) output indices (device pointer)
// The block size comes from opt_n_threads(n) (defined elsewhere; the original
// note says its result is at most 1024). One block is launched per batch
// element on `stream`. The kernel's template argument is the block size it
// was compiled for, so every launch below passes that same value as blockDim.
// On a launch error this prints the CUDA error string and terminates the
// process with exit(-1).
void furthest_point_sampling_kernel_launcher(int b, int n, int m, const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
    // Nothing to do: an empty batch would be a zero-block grid (an invalid
    // launch configuration), and with m <= 0 the kernel returns immediately.
    if (b <= 0 || m <= 0) return;

    unsigned int n_threads = opt_n_threads(n);

    switch (n_threads) {
    case 1024: furthest_point_sampling_kernel<1024><<<b, 1024, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 512:  furthest_point_sampling_kernel<512><<<b, 512, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 256:  furthest_point_sampling_kernel<256><<<b, 256, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 128:  furthest_point_sampling_kernel<128><<<b, 128, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 64:   furthest_point_sampling_kernel<64><<<b, 64, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 32:   furthest_point_sampling_kernel<32><<<b, 32, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 16:   furthest_point_sampling_kernel<16><<<b, 16, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 8:    furthest_point_sampling_kernel<8><<<b, 8, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 4:    furthest_point_sampling_kernel<4><<<b, 4, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 2:    furthest_point_sampling_kernel<2><<<b, 2, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    case 1:    furthest_point_sampling_kernel<1><<<b, 1, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
    default:
        // Unexpected size: fall back to exactly 512 threads. blockDim must match
        // the <512> instantiation. With more threads, tid would index past the
        // kernel's 512-entry shared arrays. With fewer, some shared slots read
        // by the reduction would never be written, and the stride-512 loop
        // would skip points.
        furthest_point_sampling_kernel<512><<<b, 512, 0, stream>>>(b, n, m, dataset, temp, idxs);
        break;
    }

    cudaError_t err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
接下来看真正执行FPS的kernel函数:
// Furthest point sampling kernel.
// block_size is the template argument picked by the launcher's switch (e.g.
// the <1024> in furthest_point_sampling_kernel<1024>). It is used as the loop
// stride and as the size of the shared arrays, so it must equal blockDim.x.
// Launch layout: one block per batch element (blockIdx.x = batch index).
// The `b` parameter is not read here.
//   dataset: (B, N, 3) point coordinates
//   temp   : (B, N) running minimum squared distance from each point to the
//            already-selected set. Updated in place; the caller must
//            pre-fill it with a large value.
//   idxs   : (B, M) output indices; idxs[0] is always point 0.
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m, const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
    if (m <= 0) return;

    // Per-thread reduction candidates: the largest min-distance each thread
    // found (dists) and the index of that point (dists_i).
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    // One block per batch element: advance the pointers to this block's batch.
    int batch_index = blockIdx.x;
    dataset += batch_index * n * 3;
    temp += batch_index * n;
    idxs += batch_index * m;

    int tid = threadIdx.x;
    const int stride = block_size;

    // FPS always starts from point 0; a single thread records it.
    int old = 0;
    if (threadIdx.x == 0)
        idxs[0] = old;
    __syncthreads();

    for (int j = 1; j < m; j++) {
        int besti = 0;
        float best = -1;
        // Coordinates of the point selected in the previous iteration.
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];

        // Each thread handles points tid, tid + block_size, tid + 2*block_size,
        // ... (about n / block_size points per thread).
        for (int k = tid; k < n; k += stride) {
            float x2 = dataset[k * 3 + 0];
            float y2 = dataset[k * 3 + 1];
            float z2 = dataset[k * 3 + 2];
            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            // temp[k]: minimum distance from point k to any point selected so far.
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // Tree reduction over the block to find the point whose min-distance
        // is largest. __update is defined elsewhere; it presumably keeps the
        // larger of slots tid / tid + s (value and index) in slot tid. block_size
        // is a compile-time constant, so this unrolls into the same halving
        // steps as a hand-written chain (512, 256, ..., 1 for block_size 1024).
        // The loop bound is block-uniform, so every thread reaches each barrier.
#pragma unroll
        for (int s = block_size / 2; s > 0; s >>= 1) {
            if (tid < s) {
                __update(dists, dists_i, tid, tid + s);
            }
            __syncthreads();
        }

        // The winner of the reduction is this iteration's sampled point.
        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old;
        // Every thread must finish reading dists_i[0] before thread 0
        // overwrites it (dists_i[tid] = besti) in the next iteration; without
        // this barrier that is a shared-memory write-after-read race.
        __syncthreads();
    }
}
本文发布于:2024-01-28 10:11:56,感谢您对本站的认可!
本文链接:https://www.4u4v.net/it/17064079196690.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
留言与评论(共有 0 条评论) |