fix broadcast bug in bitwise ops (#1157)

This commit is contained in:
Awni Hannun 2024-05-24 11:44:40 -07:00 committed by GitHub
parent 9f9cb7a2ef
commit a87ef5bfc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 4 deletions

View File

@ -186,8 +186,8 @@ should point to the path to the built metal library.
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
@ -203,7 +203,7 @@ GGUF, you can do:
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can

View File

@ -4321,8 +4321,9 @@ array bitwise_impl(
}
auto inputs =
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
auto& out_shape = inputs[0].shape();
return array(
a.shape(),
out_shape,
out_type,
std::make_shared<BitwiseBinary>(to_stream(s), op),
std::move(inputs));

View File

@ -2291,6 +2291,13 @@ class TestOps(mlx_tests.MLXTestCase):
out_np = getattr(np, op)(a_np, b_np)
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
# Check broadcasting
a = mx.ones((3, 1, 5), dtype=mx.bool_)
b = mx.zeros((1, 2, 5), dtype=mx.bool_)
c = a | b
self.assertEqual(c.shape, (3, 2, 5))
self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
def test_conjugate(self):
shape = (3, 5, 7)
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)