fix broadcast bug in bitwise ops (#1157)

commit a87ef5bfc1
parent 9f9cb7a2ef
Author: Awni Hannun
Date:   2024-05-24 11:44:40 -07:00
Committed by: GitHub (GPG Key ID: B5690EEEBB952194)

3 changed files with 12 additions and 4 deletions

docs/src/install.rst

@@ -186,8 +186,8 @@ should point to the path to the built metal library.
 Binary Size Minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~
 
-To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
-and `BUILD_SHARED_LIBS=ON`.
+To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
+and ``BUILD_SHARED_LIBS=ON``.
 
 The MLX CMake build has several additional options to make smaller binaries.
 For example, if you don't need the CPU backend or support for safetensors and
@@ -203,7 +203,7 @@ GGUF, you can do:
     -DMLX_BUILD_GGUF=OFF \
     -DMLX_METAL_JIT=ON
 
-The `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which
+The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
 contains pre-built GPU kernels. This substantially reduces the size of the
 Metal library by run-time compiling kernels the first time they are used in MLX
 on a given machine. Note run-time compilation incurs a cold-start cost which can
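For context, the full invocation this docs hunk is excerpted from would look roughly like the following; the ``MLX_BUILD_CPU`` and ``MLX_BUILD_SAFETENSORS`` flag names are inferred from the surrounding prose about disabling the CPU backend and safetensors support, so treat them as assumptions:

    cmake .. \
      -DCMAKE_BUILD_TYPE=MinSizeRel \
      -DBUILD_SHARED_LIBS=ON \
      -DMLX_BUILD_CPU=OFF \
      -DMLX_BUILD_SAFETENSORS=OFF \
      -DMLX_BUILD_GGUF=OFF \
      -DMLX_METAL_JIT=ON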

mlx/ops.cpp

@@ -4321,8 +4321,9 @@ array bitwise_impl(
   }
   auto inputs =
       broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+  auto& out_shape = inputs[0].shape();
   return array(
-      a.shape(),
+      out_shape,
       out_type,
       std::make_shared<BitwiseBinary>(to_stream(s), op),
       std::move(inputs));
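This hunk is the substance of the fix: the output array previously took its shape from ``a`` alone, even though both inputs had just been broadcast against each other, so a bitwise op whose operands had different (but broadcast-compatible) shapes produced an array whose shape disagreed with its data. A minimal sketch of the corrected user-facing behavior, assuming an MLX build with this fix and ``import mlx.core as mx``:

    import mlx.core as mx

    # Operands with different but broadcast-compatible shapes:
    # (3, 1, 5) and (1, 2, 5) broadcast to (3, 2, 5).
    a = mx.ones((3, 1, 5), dtype=mx.int32)
    b = mx.full((1, 2, 5), 2, dtype=mx.int32)

    # Before this commit the result's shape came from `a`, i.e. (3, 1, 5);
    # with the fix every bitwise op reports the broadcast shape.
    for out in (a & b, a | b, a ^ b, a << b):
        print(out.shape)  # (3, 2, 5) in each case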

python/tests/test_ops.py

@@ -2291,6 +2291,13 @@ class TestOps(mlx_tests.MLXTestCase):
             out_np = getattr(np, op)(a_np, b_np)
             self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
 
+        # Check broadcasting
+        a = mx.ones((3, 1, 5), dtype=mx.bool_)
+        b = mx.zeros((1, 2, 5), dtype=mx.bool_)
+        c = a | b
+        self.assertEqual(c.shape, (3, 2, 5))
+        self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
+
     def test_conjugate(self):
         shape = (3, 5, 7)
         a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)