Faster bits and bernoulli (#1535)

* faster bits and bernoulli

* fix bernoulli
This commit is contained in:
Awni Hannun 2024-10-28 11:11:00 -07:00 committed by GitHub
parent 91f6c499d7
commit d3cd26820e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 7 deletions

View File

@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
[[kernel]] void rbitsc( [[kernel]] void rbitsc(
device const uint32_t* keys, device const uint32_t* keys,
device char* out, device char* out,
device const bool& odd, constant const bool& odd,
device const uint& bytes_per_key, constant const uint& bytes_per_key,
uint2 grid_dim [[threads_per_grid]], uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) { uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x; auto kidx = 2 * index.x;
@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
[[kernel]] void rbits( [[kernel]] void rbits(
device const uint32_t* keys, device const uint32_t* keys,
device char* out, device char* out,
device const bool& odd, constant const bool& odd,
device const uint& bytes_per_key, constant const uint& bytes_per_key,
constant const int& ndim, constant const int& ndim,
constant const int* key_shape, constant const int* key_shape,
constant const size_t* key_strides, constant const size_t* key_strides,

View File

@ -273,7 +273,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// organize into grid nkeys x elem_per_key // organize into grid nkeys x elem_per_key
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1); MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(keys, 0); compute_encoder.set_input_array(keys, 0);

View File

@ -290,8 +290,15 @@ array bernoulli(
throw std::invalid_argument( throw std::invalid_argument(
"[bernoulli] bernoulli probability `p` must be a float type."); "[bernoulli] bernoulli probability `p` must be a float type.");
} }
auto res = uniform(shape, p.dtype(), key, s);
res = less(res, p, s); // Place p on the scale [0, nexthigher(UINT32_MAX)] so that if p >= 1.0 we
// get all true and if p <= 0.0 we get all false
auto upper = array(
std::nextafter(
static_cast<float>(std::numeric_limits<uint32_t>::max()),
std::numeric_limits<float>::max()),
float32);
auto res = less(bits(shape, key, s), multiply(p, upper, s), s);
if (res.shape() != shape) { if (res.shape() != shape) {
throw std::invalid_argument( throw std::invalid_argument(
"[bernoulli] shape of `p` is incompatible with argument `shape`."); "[bernoulli] shape of `p` is incompatible with argument `shape`.");