mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster bits and bernoulli (#1535)
* faster bits and bernoulli * fix bernoulli
This commit is contained in:
@@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||
[[kernel]] void rbitsc(
|
||||
device const uint32_t* keys,
|
||||
device char* out,
|
||||
device const bool& odd,
|
||||
device const uint& bytes_per_key,
|
||||
constant const bool& odd,
|
||||
constant const uint& bytes_per_key,
|
||||
uint2 grid_dim [[threads_per_grid]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto kidx = 2 * index.x;
|
||||
@@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||
[[kernel]] void rbits(
|
||||
device const uint32_t* keys,
|
||||
device char* out,
|
||||
device const bool& odd,
|
||||
device const uint& bytes_per_key,
|
||||
constant const bool& odd,
|
||||
constant const uint& bytes_per_key,
|
||||
constant const int& ndim,
|
||||
constant const int* key_shape,
|
||||
constant const size_t* key_strides,
|
||||
|
||||
Reference in New Issue
Block a user