update to 0.6.2

Author: Balakumar Sundaralingam
Date: 2023-12-15 02:01:33 -08:00
parent d85ae41fba
commit 58958bbcce
105 changed files with 2514 additions and 934 deletions


@@ -441,7 +441,7 @@ __global__ void reduce_kernel(
rho_buffer[threadIdx.x * batchsize + batch] = rho;
}
}
-template <typename scalar_t, typename psum_t>
+template <typename scalar_t, typename psum_t, bool rolled_ys>
__global__ void lbfgs_update_buffer_and_step_v1(
scalar_t *step_vec, // b x 175
scalar_t *rho_buffer, // m x b x 1
@@ -452,7 +452,6 @@ __global__ void lbfgs_update_buffer_and_step_v1(
scalar_t *grad_0, // b x 175
const scalar_t *grad_q, // b x 175
const float epsilon, const int batchsize, const int m, const int v_dim,
-const bool rolled_ys = false,
const bool stable_mode =
false) // s_buffer and y_buffer are not rolled by default
{
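
Note: the two hunks above move `rolled_ys` from a defaulted runtime argument into a non-type template parameter, so the rolled/unrolled code path is selected once at compile time instead of being branched on per thread. A minimal sketch of that pattern (identifiers are illustrative, not from this file):

```cpp
// Hoisting a runtime bool into a template parameter: nvcc sees a
// compile-time constant and eliminates the dead branch in each
// instantiation, so the flag costs nothing at run time.
template <typename scalar_t, bool rolled_ys>
__global__ void demo_kernel(scalar_t *out, const scalar_t *in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n)
    return;
  if (rolled_ys)
    out[i] = in[i] * scalar_t(2); // stands in for the "rolled" path
  else
    out[i] = in[i];               // stands in for the default path
}
```
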
@@ -485,6 +484,7 @@ __global__ void lbfgs_update_buffer_and_step_v1(
scalar_t s =
q[batch * v_dim + threadIdx.x] - x_0[batch * v_dim + threadIdx.x];
-reduce(y * s, v_dim, &data[0], &result);
+reduce_v1(y * s, v_dim, &data[0], &result);
+//reduce(y * s, v_dim, &data[0], &result);
scalar_t numerator = result;
// scalar_t rho = 1.0/numerator;
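
The hunk above routes the y·s dot product through `reduce_v1`, keeping the previous `reduce` call as a comment. This diff does not show `reduce_v1` itself; a block-wide sum in shared memory, the usual shape of such a helper, looks roughly like the following sketch (assumed: one partial value per thread, `blockDim.x` a power of two, and a `scalar_t`-typed shared buffer):

```cpp
// Hypothetical block-level sum reduction in the spirit of reduce_v1;
// the real cuRobo implementation is not shown in this diff.
template <typename scalar_t>
__device__ void block_reduce_sum(scalar_t val, scalar_t *shared,
                                 scalar_t *result) {
  shared[threadIdx.x] = val; // each thread deposits its partial product
  __syncthreads();
  // Tree reduction: halve the number of active threads each step.
  for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s)
      shared[threadIdx.x] += shared[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0)
    *result = shared[0];
  __syncthreads(); // make the sum visible to every thread in the block
}
```
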
@@ -824,14 +824,14 @@ lbfgs_cuda_fuse(torch::Tensor step_vec, torch::Tensor rho_buffer,
y_buffer.scalar_type(), "lbfgs_cuda_fuse_kernel_v1", [&] {
smemsize = 3 * m * threadsPerBlock * sizeof(scalar_t) +
m * batch_size * sizeof(scalar_t);
-lbfgs_update_buffer_and_step_v1<scalar_t, scalar_t>
+lbfgs_update_buffer_and_step_v1<scalar_t, scalar_t, false>
<<<blocksPerGrid, threadsPerBlock, smemsize, stream>>>(
step_vec.data_ptr<scalar_t>(),
rho_buffer.data_ptr<scalar_t>(),
y_buffer.data_ptr<scalar_t>(), s_buffer.data_ptr<scalar_t>(),
q.data_ptr<scalar_t>(), x_0.data_ptr<scalar_t>(),
grad_0.data_ptr<scalar_t>(), grad_q.data_ptr<scalar_t>(),
-epsilon, batch_size, m, v_dim, false, stable_mode);
+epsilon, batch_size, m, v_dim, stable_mode);
});
}
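
At this call site the buffers are never rolled, so only the `false` specialization is instantiated and the trailing `false` argument disappears from the launch. If both paths were needed, the host side would branch once on the flag inside the dispatch lambda; a sketch of that pattern with placeholder names (not code from this commit):

```cpp
#include <torch/extension.h>

template <typename scalar_t, bool rolled_ys>
__global__ void dummy_kernel(scalar_t *p, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n)
    p[i] = rolled_ys ? p[i] + scalar_t(1) : p[i];
}

void launch(torch::Tensor t, bool rolled_ys, cudaStream_t stream) {
  const int64_t n = t.numel();
  const int threads = 128;
  const int blocks = (n + threads - 1) / threads;
  AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "dummy_launch", [&] {
    if (rolled_ys) // one host-side branch picks the specialization
      dummy_kernel<scalar_t, true>
          <<<blocks, threads, 0, stream>>>(t.data_ptr<scalar_t>(), n);
    else
      dummy_kernel<scalar_t, false>
          <<<blocks, threads, 0, stream>>>(t.data_ptr<scalar_t>(), n);
  });
}
```
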