mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix boolean all reduce bug (#1355)
This commit is contained in:
parent
64bec4fad7
commit
8081df79be
@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
|||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
|
||||||
if(NOT MLX_VERSION)
|
if(NOT MLX_VERSION)
|
||||||
set(MLX_VERSION 0.17.0)
|
set(MLX_VERSION 0.17.1)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# --------------------- Processor tests -------------------------
|
# --------------------- Processor tests -------------------------
|
||||||
|
@ -308,7 +308,11 @@ void all_reduce_dispatch(
|
|||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
|
|
||||||
// 2nd pass
|
// 2nd pass
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
std::ostringstream kname_2nd_pass;
|
||||||
|
kname_2nd_pass << "all_reduce_" << op_name << type_to_name(intermediate);
|
||||||
|
auto kernel_2nd_pass =
|
||||||
|
get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out);
|
||||||
|
compute_encoder->setComputePipelineState(kernel_2nd_pass);
|
||||||
size_t intermediate_size = n_rows;
|
size_t intermediate_size = n_rows;
|
||||||
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
|
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
|
||||||
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
|
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
|
||||||
|
@ -124,6 +124,13 @@ class TestReduce(mlx_tests.MLXTestCase):
|
|||||||
z = np.array(x).sum((0, 2, 3))
|
z = np.array(x).sum((0, 2, 3))
|
||||||
self.assertTrue(np.all(z == y))
|
self.assertTrue(np.all(z == y))
|
||||||
|
|
||||||
|
def test_sum_bool(self):
|
||||||
|
x = np.random.uniform(0, 1, size=(10, 10, 10)) > 0.5
|
||||||
|
y = mx.array(x)
|
||||||
|
npsum = x.sum().item()
|
||||||
|
mxsum = y.sum().item()
|
||||||
|
self.assertEqual(npsum, mxsum)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(failfast=True)
|
unittest.main(failfast=True)
|
||||||
|
2
setup.py
2
setup.py
@ -163,7 +163,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="mlx",
|
name="mlx",
|
||||||
version=get_version("0.17.0"),
|
version=get_version("0.17.1"),
|
||||||
author="MLX Contributors",
|
author="MLX Contributors",
|
||||||
author_email="mlx@group.apple.com",
|
author_email="mlx@group.apple.com",
|
||||||
description="A framework for machine learning on Apple silicon.",
|
description="A framework for machine learning on Apple silicon.",
|
||||||
|
Loading…
Reference in New Issue
Block a user