mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix broadcast bug in bitwise ops (#1157)
This commit is contained in:
parent
9f9cb7a2ef
commit
a87ef5bfc1
@ -186,8 +186,8 @@ should point to the path to the built metal library.
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
|
||||
and `BUILD_SHARED_LIBS=ON`.
|
||||
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
|
||||
and ``BUILD_SHARED_LIBS=ON``.
|
||||
|
||||
The MLX CMake build has several additional options to make smaller binaries.
|
||||
For example, if you don't need the CPU backend or support for safetensors and
|
||||
@ -203,7 +203,7 @@ GGUF, you can do:
|
||||
-DMLX_BUILD_GGUF=OFF \
|
||||
-DMLX_METAL_JIT=ON
|
||||
|
||||
THE `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which
|
||||
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
|
||||
contains pre-built GPU kernels. This substantially reduces the size of the
|
||||
Metal library by run-time compiling kernels the first time they are used in MLX
|
||||
on a given machine. Note run-time compilation incurs a cold-start cost which can
|
||||
|
@ -4321,8 +4321,9 @@ array bitwise_impl(
|
||||
}
|
||||
auto inputs =
|
||||
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
|
||||
auto& out_shape = inputs[0].shape();
|
||||
return array(
|
||||
a.shape(),
|
||||
out_shape,
|
||||
out_type,
|
||||
std::make_shared<BitwiseBinary>(to_stream(s), op),
|
||||
std::move(inputs));
|
||||
|
@ -2291,6 +2291,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out_np = getattr(np, op)(a_np, b_np)
|
||||
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
|
||||
|
||||
# Check broadcasting
|
||||
a = mx.ones((3, 1, 5), dtype=mx.bool_)
|
||||
b = mx.zeros((1, 2, 5), dtype=mx.bool_)
|
||||
c = a | b
|
||||
self.assertEqual(c.shape, (3, 2, 5))
|
||||
self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
|
||||
|
||||
def test_conjugate(self):
|
||||
shape = (3, 5, 7)
|
||||
a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
|
||||
|
Loading…
Reference in New Issue
Block a user