diff --git a/mlx/backend/metal/kernels/random.metal b/mlx/backend/metal/kernels/random.metal index 39bc15953..f61663f8a 100644 --- a/mlx/backend/metal/kernels/random.metal +++ b/mlx/backend/metal/kernels/random.metal @@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { [[kernel]] void rbitsc( device const uint32_t* keys, device char* out, - device const bool& odd, - device const uint& bytes_per_key, + constant const bool& odd, + constant const uint& bytes_per_key, uint2 grid_dim [[threads_per_grid]], uint2 index [[thread_position_in_grid]]) { auto kidx = 2 * index.x; @@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { [[kernel]] void rbits( device const uint32_t* keys, device char* out, - device const bool& odd, - device const uint& bytes_per_key, + constant const bool& odd, + constant const uint& bytes_per_key, constant const int& ndim, constant const int* key_shape, constant const size_t* key_strides, diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index e5a7d885b..67ec2949a 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -273,7 +273,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { // organize into grid nkeys x elem_per_key MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + MTL::Size group_dims = MTL::Size(1, thread_group_size, 1); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder->setComputePipelineState(kernel); compute_encoder.set_input_array(keys, 0); diff --git a/mlx/random.cpp b/mlx/random.cpp index 6368bdeed..ba6434191 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -290,8 +290,15 @@ array bernoulli( throw std::invalid_argument( "[bernoulli] bernoulli probability `p` must be a float type."); } - auto res = uniform(shape, p.dtype(), key, s); - res = less(res, p, s); + + // Place p on the scale [0, nexthigher(UINT32_MAX)] so that if p >= 1.0 we + // get all true and if p <= 0.0 we get all false + auto upper = array( + std::nextafter( + static_cast(std::numeric_limits::max()), + std::numeric_limits::max()), + float32); + auto res = less(bits(shape, key, s), multiply(p, upper, s), s); if (res.shape() != shape) { throw std::invalid_argument( "[bernoulli] shape of `p` is incompatible with argument `shape`.");