mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Faster bits and bernoulli (#1535)
* faster bits and bernoulli * fix bernoulli
This commit is contained in:
@@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||
[[kernel]] void rbitsc(
|
||||
device const uint32_t* keys,
|
||||
device char* out,
|
||||
device const bool& odd,
|
||||
device const uint& bytes_per_key,
|
||||
constant const bool& odd,
|
||||
constant const uint& bytes_per_key,
|
||||
uint2 grid_dim [[threads_per_grid]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto kidx = 2 * index.x;
|
||||
@@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
|
||||
[[kernel]] void rbits(
|
||||
device const uint32_t* keys,
|
||||
device char* out,
|
||||
device const bool& odd,
|
||||
device const uint& bytes_per_key,
|
||||
constant const bool& odd,
|
||||
constant const uint& bytes_per_key,
|
||||
constant const int& ndim,
|
||||
constant const int* key_shape,
|
||||
constant const size_t* key_strides,
|
||||
|
@@ -273,7 +273,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// organize into grid nkeys x elem_per_key
|
||||
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(keys, 0);
|
||||
|
Reference in New Issue
Block a user