mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Faster bits and bernoulli (#1535)
* faster bits and bernoulli * fix bernoulli
This commit is contained in:
parent
91f6c499d7
commit
d3cd26820e
@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
|||||||
[[kernel]] void rbitsc(
|
[[kernel]] void rbitsc(
|
||||||
device const uint32_t* keys,
|
device const uint32_t* keys,
|
||||||
device char* out,
|
device char* out,
|
||||||
device const bool& odd,
|
constant const bool& odd,
|
||||||
device const uint& bytes_per_key,
|
constant const uint& bytes_per_key,
|
||||||
uint2 grid_dim [[threads_per_grid]],
|
uint2 grid_dim [[threads_per_grid]],
|
||||||
uint2 index [[thread_position_in_grid]]) {
|
uint2 index [[thread_position_in_grid]]) {
|
||||||
auto kidx = 2 * index.x;
|
auto kidx = 2 * index.x;
|
||||||
@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
|||||||
[[kernel]] void rbits(
|
[[kernel]] void rbits(
|
||||||
device const uint32_t* keys,
|
device const uint32_t* keys,
|
||||||
device char* out,
|
device char* out,
|
||||||
device const bool& odd,
|
constant const bool& odd,
|
||||||
device const uint& bytes_per_key,
|
constant const uint& bytes_per_key,
|
||||||
constant const int& ndim,
|
constant const int& ndim,
|
||||||
constant const int* key_shape,
|
constant const int* key_shape,
|
||||||
constant const size_t* key_strides,
|
constant const size_t* key_strides,
|
||||||
|
@ -273,7 +273,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// organize into grid nkeys x elem_per_key
|
// organize into grid nkeys x elem_per_key
|
||||||
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
|
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
|
||||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
compute_encoder.set_input_array(keys, 0);
|
compute_encoder.set_input_array(keys, 0);
|
||||||
|
@ -290,8 +290,15 @@ array bernoulli(
|
|||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[bernoulli] bernoulli probability `p` must be a float type.");
|
"[bernoulli] bernoulli probability `p` must be a float type.");
|
||||||
}
|
}
|
||||||
auto res = uniform(shape, p.dtype(), key, s);
|
|
||||||
res = less(res, p, s);
|
// Place p on the scale [0, nexthigher(UINT32_MAX)] so that if p >= 1.0 we
|
||||||
|
// get all true and if p <= 0.0 we get all false
|
||||||
|
auto upper = array(
|
||||||
|
std::nextafter(
|
||||||
|
static_cast<float>(std::numeric_limits<uint32_t>::max()),
|
||||||
|
std::numeric_limits<float>::max()),
|
||||||
|
float32);
|
||||||
|
auto res = less(bits(shape, key, s), multiply(p, upper, s), s);
|
||||||
if (res.shape() != shape) {
|
if (res.shape() != shape) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[bernoulli] shape of `p` is incompatible with argument `shape`.");
|
"[bernoulli] shape of `p` is incompatible with argument `shape`.");
|
||||||
|
Loading…
Reference in New Issue
Block a user