update to 0.6.2
This commit is contained in:
@@ -441,7 +441,7 @@ __global__ void reduce_kernel(
|
||||
rho_buffer[threadIdx.x * batchsize + batch] = rho;
|
||||
}
|
||||
}
|
||||
template <typename scalar_t, typename psum_t>
|
||||
template <typename scalar_t, typename psum_t, bool rolled_ys>
|
||||
__global__ void lbfgs_update_buffer_and_step_v1(
|
||||
scalar_t *step_vec, // b x 175
|
||||
scalar_t *rho_buffer, // m x b x 1
|
||||
@@ -452,7 +452,6 @@ __global__ void lbfgs_update_buffer_and_step_v1(
|
||||
scalar_t *grad_0, // b x 175
|
||||
const scalar_t *grad_q, // b x 175
|
||||
const float epsilon, const int batchsize, const int m, const int v_dim,
|
||||
const bool rolled_ys = false,
|
||||
const bool stable_mode =
|
||||
false) // s_buffer and y_buffer are not rolled by default
|
||||
{
|
||||
@@ -485,6 +484,7 @@ __global__ void lbfgs_update_buffer_and_step_v1(
|
||||
scalar_t s =
|
||||
q[batch * v_dim + threadIdx.x] - x_0[batch * v_dim + threadIdx.x];
|
||||
reduce_v1(y * s, v_dim, &data[0], &result);
|
||||
//reduce(y * s, v_dim, &data[0], &result);
|
||||
scalar_t numerator = result;
|
||||
// scalar_t rho = 1.0/numerator;
|
||||
|
||||
@@ -824,14 +824,14 @@ lbfgs_cuda_fuse(torch::Tensor step_vec, torch::Tensor rho_buffer,
|
||||
y_buffer.scalar_type(), "lbfgs_cuda_fuse_kernel_v1", [&] {
|
||||
smemsize = 3 * m * threadsPerBlock * sizeof(scalar_t) +
|
||||
m * batch_size * sizeof(scalar_t);
|
||||
lbfgs_update_buffer_and_step_v1<scalar_t, scalar_t>
|
||||
lbfgs_update_buffer_and_step_v1<scalar_t, scalar_t, false>
|
||||
<<<blocksPerGrid, threadsPerBlock, smemsize, stream>>>(
|
||||
step_vec.data_ptr<scalar_t>(),
|
||||
rho_buffer.data_ptr<scalar_t>(),
|
||||
y_buffer.data_ptr<scalar_t>(), s_buffer.data_ptr<scalar_t>(),
|
||||
q.data_ptr<scalar_t>(), x_0.data_ptr<scalar_t>(),
|
||||
grad_0.data_ptr<scalar_t>(), grad_q.data_ptr<scalar_t>(),
|
||||
epsilon, batch_size, m, v_dim, false, stable_mode);
|
||||
epsilon, batch_size, m, v_dim, stable_mode);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user