gen_data_curobo/src/curobo/curobolib/cpp/lbfgs_step_kernel.cu

/*
* Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#include <cuda.h>
#include <torch/extension.h>
#include <vector>
#include <c10/cuda/CUDAStream.h>
#include <cuda_fp16.h>
// #include "helper_cuda.h"
#include "helper_math.h"
#include <assert.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <math.h>
// #include <stdio.h>
//
// For the CUDA runtime routines (prefixed with "cuda_")
// #include <cuda_runtime.h>
// #include <cuda_fp16.h>
// #include <helper_cuda.h>
#define M_MAX 512
#define HALF_MAX 65504.0
#define M 15
#define VDIM 175 // 25 * 7,
#define FULL_MASK 0xffffffff
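// The kernels below implement the standard L-BFGS two-loop recursion, which
// turns the current gradient g into a step direction using the last m
// curvature pairs (s_i, y_i) and rho_i = 1 / (y_i . s_i):
//
//   q = g
//   for i = m-1 .. 0:   alpha_i = rho_i * (s_i . q);   q -= alpha_i * y_i
//   q *= gamma, where gamma = (s_{m-1} . y_{m-1}) / (y_{m-1} . y_{m-1})
//   for i = 0 .. m-1:   beta = rho_i * (y_i . q);      q += (alpha_i - beta) * s_i
//   step = -q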
namespace Curobo {
namespace Optimization {
template <typename scalar_t>
__device__ __forceinline__ void
scalar_vec_product(const scalar_t a, const scalar_t *b, scalar_t *out,
const int v_dim) {
for (int i = 0; i < v_dim; i++) {
out[i] = a * b[i];
}
}
template <typename scalar_t>
__device__ __forceinline__ void
m_scalar_vec_product(const scalar_t *a, const scalar_t *b, scalar_t *out,
const int v_dim, const int m) {
for (int j = 0; j < m; j++) {
for (int i = 0; i < v_dim; i++) {
out[j * v_dim + i] = a[j] * b[j * v_dim + i];
}
}
}
template <typename scalar_t>
__device__ __forceinline__ void vec_vec_dot(const scalar_t *a,
const scalar_t *b, scalar_t &out,
const int v_dim) {
for (int i = 0; i < v_dim; i++) {
out += a[i] * b[i];
}
}
template <typename scalar_t>
__device__ __forceinline__ void update_r(const scalar_t *rho_y,
const scalar_t *s_buffer, scalar_t *r,
scalar_t &alpha, const int v_dim) {
// dot product: and subtract with alpha
for (int i = 0; i < v_dim; i++) {
alpha -= rho_y[i] * r[i];
}
// scalar vector product:
for (int i = 0; i < v_dim; i++) {
r[i] = r[i] + alpha * s_buffer[i];
}
}
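// One iteration of the recursion's second loop: on entry alpha holds alpha_i,
// rho_y holds rho_i * y_i and r holds q; the dot product turns alpha into
// (alpha_i - beta_i), after which r += (alpha_i - beta_i) * s_i.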
template <typename scalar_t>
__device__ __forceinline__ void update_q(const scalar_t *y_buffer, scalar_t *gq,
const scalar_t alpha,
const int v_dim) {
//
for (int i = 0; i < v_dim; i++) {
gq[i] = gq[i] - (alpha * y_buffer[i]);
}
}
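// One iteration of the recursion's first loop: q -= alpha_i * y_i.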
template <typename scalar_t>
__global__ void
lbfgs_step_kernel_old(scalar_t *step_vec, scalar_t *rho_buffer,
const scalar_t *y_buffer, const scalar_t *s_buffer,
const scalar_t *grad_q, const float epsilon,
const int batchSize, const int m, const int v_dim) {
// one thread handles one batch element
const int t = blockDim.x * blockIdx.x + threadIdx.x; // batch
const int b_idx = t;
if (t >= (batchSize)) {
return;
}
// get thread start address:
const int b_start_scalar_adrs = b_idx * m;
const int b_start_vec_adrs = b_idx * m * v_dim;
const int b_step_start_adrs = b_idx * v_dim;
scalar_t rho_s[M * VDIM];
// copy y_buffer, s_buffer and rho_buffer into per-thread local arrays
scalar_t loc_ybuf[M * VDIM];
scalar_t loc_sbuf[M * VDIM];
scalar_t loc_rho[M];
scalar_t gq[VDIM]; //, l_q[VDIM];
scalar_t alpha_buffer[M];
scalar_t t_1 = 0, t_2 = 0; // vec_vec_dot accumulates, so zero-initialize
for (int i = 0; i < m * v_dim; i++) {
loc_ybuf[i] = y_buffer[b_start_vec_adrs + i];
loc_sbuf[i] = s_buffer[b_start_vec_adrs + i];
}
for (int i = 0; i < v_dim; i++) {
gq[i] = grad_q[b_step_start_adrs + i];
}
for (int i = 0; i < m; i++) {
loc_rho[i] = rho_buffer[b_start_scalar_adrs + i];
}
m_scalar_vec_product(&loc_rho[0], &loc_sbuf[0], &rho_s[0], v_dim, m);
// first loop of the two-loop recursion, from newest to oldest pair:
for (int i = m - 1; i > -1; i--) {
alpha_buffer[i] = 0; // vec_vec_dot accumulates, so zero-initialize
vec_vec_dot(&rho_s[i * v_dim], &gq[0], alpha_buffer[i], v_dim);
update_q(&loc_ybuf[(i * v_dim)], &gq[0], alpha_buffer[i], v_dim);
}
// compute the initial Hessian scaling: gamma = relu((s . y) / (y . y)) + epsilon
vec_vec_dot(&loc_sbuf[(m - 1) * v_dim], &loc_ybuf[(m - 1) * v_dim], t_1,
v_dim);
vec_vec_dot(&loc_ybuf[(m - 1) * v_dim], &loc_ybuf[(m - 1) * v_dim], t_2,
v_dim);
t_1 = t_1 / t_2;
if (t_1 < 0) {
t_1 = 0;
}
t_1 += epsilon;
scalar_vec_product(t_1, &gq[0], &gq[0], v_dim);
m_scalar_vec_product(&loc_rho[0], &loc_ybuf[0], &rho_s[0], v_dim, m); // reuse rho_s for rho * y
// second loop of the recursion, from oldest to newest pair:
for (int i = 0; i < m; i++) {
update_r(&rho_s[i * v_dim], &loc_sbuf[i * v_dim], &gq[0], alpha_buffer[i],
v_dim);
}
// write gq to out grad:
for (int i = 0; i < v_dim; i++) {
step_vec[b_step_start_adrs + i] = -1.0 * gq[i];
}
}
template <typename psum_t>
__forceinline__ __device__ psum_t warpReduce(psum_t v, int elems,
unsigned mask) {
psum_t val = v;
int shift = 1;
for (int i = elems; i > 1; i /= 2) {
val += __shfl_down_sync(mask, val, shift);
shift *= 2;
}
// val += __shfl_down_sync(mask, val, 1); // i=32
// val += __shfl_down_sync(mask, val, 2); // i=16
// val += __shfl_down_sync(mask, val, 4); // i=8
// val += __shfl_down_sync(mask, val, 8); // i=4
// val += __shfl_down_sync(mask, val, 16); // i=2
return val;
}
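// Trace for elems = 8 with lanes holding v0..v7: after shift 1 lane k holds
// v_k + v_{k+1}, after shift 2 a sum of four values, and after shift 4 lane 0
// holds v0 + ... + v7. Doubling shifts reduce `elems` values into lane 0 in
// ceil(log2(elems)) steps.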
template <typename scalar_t, typename psum_t>
__forceinline__ __device__ void reduce(scalar_t v, int m, psum_t *data,
scalar_t *result) {
psum_t val = v;
unsigned mask = __ballot_sync(FULL_MASK, threadIdx.x < m);
val += __shfl_down_sync(mask, val, 1);
val += __shfl_down_sync(mask, val, 2);
val += __shfl_down_sync(mask, val, 4);
val += __shfl_down_sync(mask, val, 8);
val += __shfl_down_sync(mask, val, 16);
// int leader = __ffs(mask) - 1; // select a leader lane
int leader = 0;
if (threadIdx.x % 32 == leader) {
if (m < 32) {
result[0] = (scalar_t)val;
} else {
data[threadIdx.x / 32] = val; // leader lane's warp index
}
}
if (m >= 32) {
__syncthreads();
int elems = (m + 31) / 32;
unsigned mask2 = __ballot_sync(FULL_MASK, threadIdx.x < elems);
if (threadIdx.x / 32 == 0) { // only the first warp will do this work
psum_t val2 = data[threadIdx.x % 32];
int shift = 1;
for (int i = elems - 1; i > 0; i /= 2) {
val2 += __shfl_down_sync(mask2, val2, shift);
shift *= 2;
}
// int leader = __ffs(mask2) - 1; // select a leader lane
int leader = 0;
if (threadIdx.x % 32 == leader) {
result[0] = (scalar_t)val2;
}
}
}
__syncthreads();
}
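// reduce() sums one value per thread in two stages: shuffles fold each warp
// into its leader lane; with a single warp (m < 32) the leader writes
// result[0] directly, otherwise leaders deposit partial sums into the shared
// data[] buffer and the first warp folds those into result[0]. The final
// __syncthreads() makes result visible to the whole block.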
// blockReduce
template <typename scalar_t, typename psum_t>
__forceinline__ __device__ void reduce_v1(scalar_t v, int m, psum_t *data,
scalar_t *result) {
unsigned mask = __ballot_sync(FULL_MASK, threadIdx.x < m);
psum_t val = warpReduce(v, 32, mask);
// int leader = __ffs(mask) - 1; // select a leader lane
int leader = 0;
if (threadIdx.x % 32 == leader) {
data[threadIdx.x / 32] = val; // leader lane's warp index
}
if (m >= 32) {
__syncthreads();
int elems = (m + 31) / 32;
unsigned mask2 = __ballot_sync(FULL_MASK, threadIdx.x < elems);
if (threadIdx.x / 32 == 0) { // only the first warp will do this work
psum_t val2 = warpReduce(data[threadIdx.x % 32], elems, mask2);
// int leader = __ffs(mask2) - 1; // select a leader lane
if (threadIdx.x == leader) {
result[0] = (scalar_t)val2;
}
}
} else {
if (threadIdx.x == leader) {
result[0] = (scalar_t)val;
}
}
__syncthreads();
}
template <typename scalar_t, typename psum_t>
__inline__ __device__ void dot(const scalar_t *mat1, const scalar_t *mat2,
const int m, psum_t *data, scalar_t *result) {
scalar_t val = mat1[threadIdx.x] * mat2[threadIdx.x];
reduce(val, m, data, result);
}
template <typename scalar_t> __inline__ __device__ scalar_t relu(scalar_t var) {
if (var < 0)
return 0;
else
return var;
}
//////////////////////////////////////////////////////////
// one block per batch
// v_dim threads per block
//////////////////////////////////////////////////////////
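// Buffer layout: y_buffer and s_buffer are stored as [m, batch, v_dim]; pair
// i of a given batch starts at i * batchsize * v_dim + batch * v_dim, and its
// rho lives at i * batchsize + batch in rho_buffer.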
template <typename scalar_t, typename psum_t>
__global__ void lbfgs_step_kernel(scalar_t *step_vec, // b x 175
scalar_t *rho_buffer, // m x b x 1
const scalar_t *y_buffer, // m x b x 175
const scalar_t *s_buffer, // m x b x 175
const scalar_t *grad_q, // b x 175
const float epsilon, const int batchsize,
const int m, const int v_dim) {
__shared__ psum_t
data[32]; // temporary buffer needed for block-wide reduction
__shared__ scalar_t
result; // result of the reduction or vector-vector dot product
__shared__ scalar_t gq[175]; // this batch's gradient, v_dim entries
assert(v_dim <= VDIM);
int batch = blockIdx.x; // one block per batch
if (threadIdx.x >= v_dim)
return;
gq[threadIdx.x] = grad_q[batch * v_dim + threadIdx.x]; // copy grad_q to gq
scalar_t alpha_buffer[16];
// assert(m<16); // allocating a buffer assuming m < 16
for (int i = m - 1; i > -1; i--) {
dot(&gq[0], &s_buffer[i * batchsize * v_dim + batch * v_dim], v_dim,
&data[0], &result);
alpha_buffer[i] = result * rho_buffer[i * batchsize + batch];
gq[threadIdx.x] =
gq[threadIdx.x] -
alpha_buffer[i] *
y_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
}
// compute var1
scalar_t val1 =
y_buffer[(m - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x];
scalar_t val2 =
s_buffer[(m - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x];
reduce(val1 * val1, v_dim, data, &result);
scalar_t denominator = result;
reduce(val1 * val2, v_dim, data, &result);
scalar_t numerator = result;
scalar_t var1 = numerator / denominator;
scalar_t gamma = relu(var1) + epsilon; // epsilon
gq[threadIdx.x] = gamma * gq[threadIdx.x];
for (int i = 0; i < m; i++) {
dot(&gq[0], &y_buffer[i * batchsize * v_dim + batch * v_dim], v_dim,
&data[0], &result);
gq[threadIdx.x] =
gq[threadIdx.x] +
(alpha_buffer[i] - result * rho_buffer[i * batchsize + batch]) *
s_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
}
step_vec[batch * v_dim + threadIdx.x] =
-1 * gq[threadIdx.x]; // copy from shared memory to global memory
}
template <typename scalar_t, typename psum_t>
__global__ void lbfgs_update_buffer_kernel(scalar_t *rho_buffer, // m x b x 1
scalar_t *y_buffer, // m x b x 175
scalar_t *s_buffer, // m x b x 175
scalar_t *q, // b x 175
scalar_t *x_0, // b x 175
scalar_t *grad_0, // b x 175
const scalar_t *grad_q, // b x 175
const int batchsize, const int m,
const int v_dim) {
__shared__ psum_t
data[32]; // temporary buffer needed for block-wide reduction
__shared__ scalar_t
result; // result of the reduction or vector-vector dot product
// __shared__ scalar_t y[175]; // temporary shared memory storage
// __shared__ scalar_t s[175]; // temporary shared memory storage
assert(v_dim <= VDIM);
int batch = blockIdx.x; // one block per batch
if (threadIdx.x >= v_dim)
return;
scalar_t y =
grad_q[batch * v_dim + threadIdx.x] - grad_0[batch * v_dim + threadIdx.x];
scalar_t s =
q[batch * v_dim + threadIdx.x] - x_0[batch * v_dim + threadIdx.x];
reduce(y * s, v_dim, &data[0], &result);
for (int i = 1; i < m; i++) {
s_buffer[(i - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] =
s_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
y_buffer[(i - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] =
y_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
}
// __syncthreads();
s_buffer[(m - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] = s;
y_buffer[(m - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] = y;
grad_0[batch * v_dim + threadIdx.x] = grad_q[batch * v_dim + threadIdx.x];
x_0[batch * v_dim + threadIdx.x] = q[batch * v_dim + threadIdx.x];
if (threadIdx.x == 0) {
scalar_t rho = 1 / result;
for (int i = 1; i < m; i++) {
rho_buffer[(i - 1) * batchsize + batch] =
rho_buffer[i * batchsize + batch];
}
rho_buffer[(m - 1) * batchsize + batch] = rho;
}
}
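// The history buffers form a FIFO over the last m curvature pairs: slots
// 1..m-1 are shifted down one position and the newest pair
// (s, y) = (q - x_0, grad_q - grad_0), with rho = 1 / (y . s), is written
// into slot m - 1.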
template <typename scalar_t, typename psum_t>
__global__ void reduce_kernel(
scalar_t *vec1, // b x 175
scalar_t *vec2, // b x 175
scalar_t *rho_buffer, // m x b x 1
scalar_t *sum_out, // m x b x 1
const int batchsize, const int m,
const int v_dim) // s_buffer and y_buffer are not rolled by default
{
__shared__ psum_t
data[32]; // temporary buffer needed for block-wide reduction
__shared__ scalar_t
result; // result of the reduction or vector-vector dot product
int batch = blockIdx.x; // one block per batch
if (threadIdx.x >= v_dim)
return;
////////////////////
// update_buffer
////////////////////
scalar_t y = vec1[batch * v_dim + threadIdx.x];
scalar_t s = vec2[batch * v_dim + threadIdx.x];
reduce(y * s, v_dim, &data[0], &result);
scalar_t numerator = result;
if (threadIdx.x == 0) {
sum_out[batch] = 1 / numerator;
}
// return;
if (threadIdx.x < m - 1) {
// m threads participate to shift the values;
// this is safe as m < 32 and a warp executes in lockstep
rho_buffer[threadIdx.x * batchsize + batch] =
rho_buffer[(threadIdx.x + 1) * batchsize + batch];
} else if (threadIdx.x == m - 1) {
scalar_t rho = (1 / numerator);
rho_buffer[threadIdx.x * batchsize + batch] = rho;
}
}
template <typename scalar_t, typename psum_t, bool rolled_ys>
__global__ void lbfgs_update_buffer_and_step_v1(
scalar_t *step_vec, // b x 175
scalar_t *rho_buffer, // m x b x 1
scalar_t *y_buffer, // m x b x 175
scalar_t *s_buffer, // m x b x 175
scalar_t *q, // b x 175
scalar_t *x_0, // b x 175
scalar_t *grad_0, // b x 175
const scalar_t *grad_q, // b x 175
const float epsilon, const int batchsize, const int m, const int v_dim,
const bool stable_mode =
false) // s_buffer and y_buffer are not rolled by default
{
// extern __shared__ scalar_t alpha_buffer_sh[];
extern __shared__ __align__(sizeof(scalar_t)) unsigned char my_smem[];
scalar_t *my_smem_rc = reinterpret_cast<scalar_t *>(my_smem);
scalar_t *alpha_buffer_sh = &my_smem_rc[0]; // m*blockDim.x
scalar_t *rho_buffer_sh = &my_smem_rc[m * blockDim.x]; // batchsize*m
scalar_t *s_buffer_sh =
&my_smem_rc[m * blockDim.x + m * batchsize]; // m*blockDim.x
scalar_t *y_buffer_sh =
&my_smem_rc[2 * m * blockDim.x + m * batchsize]; // m*blockDim.x
__shared__ psum_t
data[32]; // temporary buffer needed for block-wide reduction
__shared__ scalar_t
result; // result of the reduction or vector-vector dot product
int batch = blockIdx.x; // one block per batch
if (threadIdx.x >= v_dim)
return;
scalar_t gq;
gq = grad_q[batch * v_dim + threadIdx.x]; // copy grad_q to gq
////////////////////
// update_buffer
////////////////////
scalar_t y = gq - grad_0[batch * v_dim + threadIdx.x];
// if y is close to zero
scalar_t s =
q[batch * v_dim + threadIdx.x] - x_0[batch * v_dim + threadIdx.x];
reduce_v1(y * s, v_dim, &data[0], &result);
//reduce(y * s, v_dim, &data[0], &result);
scalar_t numerator = result;
// scalar_t rho = 1.0/numerator;
if (!rolled_ys) {
#pragma unroll
for (int i = 1; i < m; i++) {
scalar_t st =
s_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
scalar_t yt =
y_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
s_buffer[(i - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] = st;
y_buffer[(i - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] = yt;
s_buffer_sh[m * threadIdx.x + i - 1] = st;
y_buffer_sh[m * threadIdx.x + i - 1] = yt;
}
}
s_buffer[(m - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] = s;
y_buffer[(m - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] = y;
s_buffer_sh[m * threadIdx.x + m - 1] = s;
y_buffer_sh[m * threadIdx.x + m - 1] = y;
grad_0[batch * v_dim + threadIdx.x] = gq;
x_0[batch * v_dim + threadIdx.x] = q[batch * v_dim + threadIdx.x];
if (threadIdx.x < m - 1) {
// m threads participate to shift the values;
// this is safe as m < 32 and a warp executes in lockstep
scalar_t rho = rho_buffer[(threadIdx.x + 1) * batchsize + batch];
rho_buffer[threadIdx.x * batchsize + batch] = rho;
rho_buffer_sh[threadIdx.x * batchsize + batch] = rho;
}
if (threadIdx.x == m - 1) {
scalar_t rho = 1.0 / numerator;
// if this is nan, make it zero:
if (stable_mode && numerator == 0.0) {
rho = 0.0;
}
rho_buffer[threadIdx.x * batchsize + batch] = rho;
rho_buffer_sh[threadIdx.x * batchsize + batch] = rho;
}
// return;
__syncthreads();
////////////////////
// step
////////////////////
// scalar_t alpha_buffer[16];
// assert(m<16); // allocating a buffer assuming m < 16
#pragma unroll
for (int i = m - 1; i > -1; i--) {
// reduce(gq * s_buffer[i*batchsize*v_dim + batch*v_dim + threadIdx.x],
// v_dim, &data[0], &result);
reduce_v1(gq * s_buffer_sh[m * threadIdx.x + i], v_dim, &data[0], &result);
alpha_buffer_sh[threadIdx.x * m + i] =
result * rho_buffer_sh[i * batchsize + batch];
// gq = gq - alpha_buffer_sh[threadIdx.x*m+i]*y_buffer[i*batchsize*v_dim +
// batch*v_dim + threadIdx.x];
gq = gq - alpha_buffer_sh[threadIdx.x * m + i] *
y_buffer_sh[m * threadIdx.x + i];
}
// compute var1
reduce_v1(y * y, v_dim, data, &result);
scalar_t denominator = result;
// reduce(s*y, v_dim, data, &result); // redundant - already computed it above
// scalar_t numerator = result;
scalar_t var1 = numerator / denominator;
// Guarding against a zero denominator improves stability, though it can
// hurt convergence; active only in stable_mode:
if (stable_mode && denominator == 0.0) {
var1 = epsilon;
}
scalar_t gamma = relu(var1);
gq = gamma * gq;
#pragma unroll
for (int i = 0; i < m; i++) {
// reduce(gq * y_buffer[i*batchsize*v_dim + batch*v_dim + threadIdx.x],
// v_dim, &data[0], &result); gq = gq + (alpha_buffer_sh[threadIdx.x*m+i] -
// result * rho_buffer_sh[i*batchsize+batch]) * s_buffer[i*batchsize*v_dim +
// batch*v_dim + threadIdx.x];
reduce_v1(gq * y_buffer_sh[m * threadIdx.x + i], v_dim, &data[0], &result);
gq = gq + (alpha_buffer_sh[threadIdx.x * m + i] -
result * rho_buffer_sh[i * batchsize + batch]) *
s_buffer_sh[m * threadIdx.x + i];
}
step_vec[batch * v_dim + threadIdx.x] =
-1.0 * gq; // write the negated step from register to global memory
}
// (32 / m_t) rolls per warp.
// Threads in a warp execute in lockstep; we leverage that to perform the
// roll without temporary storage or explicit synchronization.
template <typename scalar_t>
__global__ void lbfgs_roll(scalar_t *a, // m x b x 175
scalar_t *b, // m x b x 175
const int m_t, const int batchsize, const int m,
const int v_dim) {
assert(m_t <= 32);
int t = blockDim.x * blockIdx.x + threadIdx.x;
if (t >= m_t * v_dim * batchsize)
return;
int _m = t % m_t;
int _v_dim = (t / m_t) % v_dim;
int batch = t / (m_t * v_dim); // matches the decomposition of t above
if (_m < m - 1) {
a[_m * batchsize * v_dim + batch * v_dim + _v_dim] =
a[(_m + 1) * batchsize * v_dim + batch * v_dim + _v_dim];
b[_m * batchsize * v_dim + batch * v_dim + _v_dim] =
b[(_m + 1) * batchsize * v_dim + batch * v_dim + _v_dim];
}
}
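// Index decomposition used above, e.g. for m_t = 4, v_dim = 175: thread t
// handles (slot, coordinate, batch) = (t % 4, (t / 4) % 175, t / (4 * 175))
// and copies one element from slot _m + 1 down to slot _m.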
template <typename scalar_t, typename psum_t, bool rolled_ys>
__global__ void lbfgs_update_buffer_and_step(
scalar_t *step_vec, // b x 175
scalar_t *rho_buffer, // m x b x 1
scalar_t *y_buffer, // m x b x 175
scalar_t *s_buffer, // m x b x 175
scalar_t *q, // b x 175
scalar_t *x_0, // b x 175
scalar_t *grad_0, // b x 175
const scalar_t *grad_q, // b x 175
const float epsilon, const int batchsize, const int m, const int v_dim,
const bool stable_mode =
false) // s_buffer and y_buffer are not rolled by default
{
// extern __shared__ scalar_t alpha_buffer_sh[];
extern __shared__ __align__(sizeof(scalar_t)) unsigned char my_smem[];
scalar_t *alpha_buffer_sh = reinterpret_cast<scalar_t *>(my_smem);
__shared__ psum_t
data[32]; // temporary buffer needed for block-wide reduction
__shared__ scalar_t
result; // result of the reduction or vector-vector dot product
int batch = blockIdx.x; // one block per batch
if (threadIdx.x >= v_dim)
return;
scalar_t gq;
gq = grad_q[batch * v_dim + threadIdx.x]; // copy grad_q to gq
////////////////////
// update_buffer
////////////////////
scalar_t y = gq - grad_0[batch * v_dim + threadIdx.x];
// if y is close to zero
scalar_t s =
q[batch * v_dim + threadIdx.x] - x_0[batch * v_dim + threadIdx.x];
reduce(y * s, v_dim, &data[0], &result);
scalar_t numerator = result;
// scalar_t rho = 1.0/numerator;
if (!rolled_ys) {
for (int i = 1; i < m; i++) {
s_buffer[(i - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] =
s_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
y_buffer[(i - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] =
y_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
}
}
s_buffer[(m - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] = s;
y_buffer[(m - 1) * batchsize * v_dim + batch * v_dim + threadIdx.x] = y;
grad_0[batch * v_dim + threadIdx.x] = gq;
x_0[batch * v_dim + threadIdx.x] = q[batch * v_dim + threadIdx.x];
if (threadIdx.x < m - 1) {
// m threads participate to shift the values;
// this is safe as m < 32 and a warp executes in lockstep
rho_buffer[threadIdx.x * batchsize + batch] =
rho_buffer[(threadIdx.x + 1) * batchsize + batch];
}
if (threadIdx.x == m - 1) {
scalar_t rho = 1.0 / numerator;
// if this is nan, make it zero:
if (stable_mode && numerator == 0.0) {
rho = 0.0;
}
rho_buffer[threadIdx.x * batchsize + batch] = rho;
}
// return;
//__syncthreads();
////////////////////
// step
////////////////////
// scalar_t alpha_buffer[16];
// assert(m<16); // allocating a buffer assuming m < 16
#pragma unroll
for (int i = m - 1; i > -1; i--) {
reduce(gq * s_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x],
v_dim, &data[0], &result);
alpha_buffer_sh[threadIdx.x * m + i] =
result * rho_buffer[i * batchsize + batch];
gq = gq - alpha_buffer_sh[threadIdx.x * m + i] *
y_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
}
// compute var1
reduce(y * y, v_dim, data, &result);
scalar_t denominator = result;
// reduce(s*y, v_dim, data, &result); // redundant - already computed it above
// scalar_t numerator = result;
scalar_t var1 = numerator / denominator;
// Guarding against a zero denominator improves stability, though it can
// hurt convergence; active only in stable_mode:
if (stable_mode && denominator == 0.0) {
var1 = epsilon;
}
scalar_t gamma = relu(var1);
gq = gamma * gq;
#pragma unroll
for (int i = 0; i < m; i++) {
reduce(gq * y_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x],
v_dim, &data[0], &result);
gq = gq + (alpha_buffer_sh[threadIdx.x * m + i] -
result * rho_buffer[i * batchsize + batch]) *
s_buffer[i * batchsize * v_dim + batch * v_dim + threadIdx.x];
}
step_vec[batch * v_dim + threadIdx.x] =
-1.0 * gq; // write the negated step from register to global memory
}
} // namespace Optimization
} // namespace Curobo
std::vector<torch::Tensor>
lbfgs_step_cuda(torch::Tensor step_vec, torch::Tensor rho_buffer,
torch::Tensor y_buffer, torch::Tensor s_buffer,
torch::Tensor grad_q, const float epsilon, const int batch_size,
const int m, const int v_dim) {
using namespace Curobo::Optimization;
const int threadsPerBlock = 128;
const int blocksPerGrid =
((batch_size) + threadsPerBlock - 1) / threadsPerBlock;
// launch threads per batch:
// int threadsPerBlock = pow(2,((int)log2(v_dim))+1);
// const int blocksPerGrid = batch_size; //((batch_size) + threadsPerBlock -
// 1) / threadsPerBlock;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(
step_vec.scalar_type(), "lbfgs_step_cu", ([&] {
lbfgs_step_kernel_old<scalar_t>
<<<blocksPerGrid, threadsPerBlock,
v_dim * sizeof(scalar_t), stream>>>(
step_vec.data_ptr<scalar_t>(), rho_buffer.data_ptr<scalar_t>(),
y_buffer.data_ptr<scalar_t>(), s_buffer.data_ptr<scalar_t>(),
grad_q.data_ptr<scalar_t>(), epsilon, batch_size, m, v_dim);
}));
C10_CUDA_KERNEL_LAUNCH_CHECK();
return {step_vec};
}
std::vector<torch::Tensor>
lbfgs_update_cuda(torch::Tensor rho_buffer, torch::Tensor y_buffer,
torch::Tensor s_buffer, torch::Tensor q, torch::Tensor grad_q,
torch::Tensor x_0, torch::Tensor grad_0, const int batch_size,
const int m, const int v_dim) {
using namespace Curobo::Optimization;
// const int threadsPerBlock = 128;
// launch threads per batch:
// int threadsPerBlock = pow(2,((int)log2(v_dim))+1);
int threadsPerBlock = v_dim;
const int blocksPerGrid =
batch_size; //((batch_size) + threadsPerBlock - 1) / threadsPerBlock;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(
y_buffer.scalar_type(), "lbfgs_update_cu", ([&] {
lbfgs_update_buffer_kernel<scalar_t, scalar_t>
<<<blocksPerGrid, threadsPerBlock,
v_dim * sizeof(scalar_t), stream>>>(
rho_buffer.data_ptr<scalar_t>(), y_buffer.data_ptr<scalar_t>(),
s_buffer.data_ptr<scalar_t>(), q.data_ptr<scalar_t>(),
x_0.data_ptr<scalar_t>(), grad_0.data_ptr<scalar_t>(),
grad_q.data_ptr<scalar_t>(), batch_size, m, v_dim);
}));
C10_CUDA_KERNEL_LAUNCH_CHECK();
return {rho_buffer, y_buffer, s_buffer, x_0, grad_0};
}
std::vector<torch::Tensor>
lbfgs_cuda_fuse(torch::Tensor step_vec, torch::Tensor rho_buffer,
torch::Tensor y_buffer, torch::Tensor s_buffer, torch::Tensor q,
torch::Tensor grad_q, torch::Tensor x_0, torch::Tensor grad_0,
const float epsilon, const int batch_size, const int m,
const int v_dim, const bool stable_mode) {
using namespace Curobo::Optimization;
// call first kernel:
int threadsPerBlock = v_dim;
assert(threadsPerBlock <= 1024);
assert(m < M_MAX);
int blocksPerGrid =
batch_size; //((batch_size) + threadsPerBlock - 1) / threadsPerBlock;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int smemsize = 0;
if (true) { // always take this path; the v1 kernel below does not work
AT_DISPATCH_FLOATING_TYPES(
y_buffer.scalar_type(), "lbfgs_cuda_fuse_kernel", [&] {
smemsize = m * threadsPerBlock * sizeof(scalar_t);
lbfgs_update_buffer_and_step<scalar_t, scalar_t, false>
<<<blocksPerGrid, threadsPerBlock, smemsize, stream>>>(
step_vec.data_ptr<scalar_t>(),
rho_buffer.data_ptr<scalar_t>(),
y_buffer.data_ptr<scalar_t>(), s_buffer.data_ptr<scalar_t>(),
q.data_ptr<scalar_t>(), x_0.data_ptr<scalar_t>(),
grad_0.data_ptr<scalar_t>(), grad_q.data_ptr<scalar_t>(),
epsilon, batch_size, m, v_dim, stable_mode);
});
} else {
// v1 does not work
AT_DISPATCH_FLOATING_TYPES(
y_buffer.scalar_type(), "lbfgs_cuda_fuse_kernel_v1", [&] {
smemsize = 3 * m * threadsPerBlock * sizeof(scalar_t) +
m * batch_size * sizeof(scalar_t);
lbfgs_update_buffer_and_step_v1<scalar_t, scalar_t, false>
<<<blocksPerGrid, threadsPerBlock, smemsize, stream>>>(
step_vec.data_ptr<scalar_t>(),
rho_buffer.data_ptr<scalar_t>(),
y_buffer.data_ptr<scalar_t>(), s_buffer.data_ptr<scalar_t>(),
q.data_ptr<scalar_t>(), x_0.data_ptr<scalar_t>(),
grad_0.data_ptr<scalar_t>(), grad_q.data_ptr<scalar_t>(),
epsilon, batch_size, m, v_dim, stable_mode);
});
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
return {step_vec, rho_buffer, y_buffer, s_buffer, x_0, grad_0};
}
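/*
  Minimal host-side usage sketch (illustrative only: tensor names, sizes and
  the epsilon value are assumptions, not part of this file). History buffers
  are shaped [m, b, v_dim]; per-step tensors are shaped [b, v_dim]:

    const int b = 64, m = 15, v_dim = 175;
    auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
    auto step_vec   = torch::zeros({b, v_dim}, opts);
    auto rho_buffer = torch::zeros({m, b, 1}, opts);
    auto y_buffer   = torch::zeros({m, b, v_dim}, opts);
    auto s_buffer   = torch::zeros({m, b, v_dim}, opts);
    auto q      = torch::zeros({b, v_dim}, opts); // current iterate
    auto x_0    = torch::zeros({b, v_dim}, opts); // previous iterate
    auto grad_0 = torch::zeros({b, v_dim}, opts); // previous gradient
    auto grad_q = torch::randn({b, v_dim}, opts); // current gradient

    auto out = lbfgs_cuda_fuse(step_vec, rho_buffer, y_buffer, s_buffer,
                               q, grad_q, x_0, grad_0,
                               1e-8f, b, m, v_dim, true);
    // out[0] is the step; rho/y/s buffers, x_0 and grad_0 update in place.
*/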
std::vector<torch::Tensor> reduce_cuda(torch::Tensor vec, torch::Tensor vec2,
torch::Tensor rho_buffer,
torch::Tensor sum, const int batch_size,
const int m, const int v_dim) {
using namespace Curobo::Optimization;
int threadsPerBlock = pow(2, ((int)log2(v_dim)) + 1);
int blocksPerGrid =
batch_size; //((batch_size) + threadsPerBlock - 1) / threadsPerBlock;
// printf("threadsPerBlock:%d, blocksPerGrid: %d\n",
// threadsPerBlock, blocksPerGrid);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(
vec.scalar_type(), "reduce_cu", ([&] {
reduce_kernel<scalar_t, scalar_t>
<<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
vec.data_ptr<scalar_t>(), vec2.data_ptr<scalar_t>(),
rho_buffer.data_ptr<scalar_t>(), sum.data_ptr<scalar_t>(),
batch_size, m, v_dim);
}));
C10_CUDA_KERNEL_LAUNCH_CHECK();
return {sum, rho_buffer};
}