fix broadcast bug in bitwise ops (#1157)
parent 9f9cb7a2ef
commit a87ef5bfc1
@@ -186,8 +186,8 @@ should point to the path to the built metal library.
 Binary Size Minimization
 ~~~~~~~~~~~~~~~~~~~~~~~~

-To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
-and `BUILD_SHARED_LIBS=ON`.
+To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
+and ``BUILD_SHARED_LIBS=ON``.

 The MLX CMake build has several additional options to make smaller binaries.
 For example, if you don't need the CPU backend or support for safetensors and
@@ -203,7 +203,7 @@ GGUF, you can do:
     -DMLX_BUILD_GGUF=OFF \
     -DMLX_METAL_JIT=ON

-The `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which
+The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
 contains pre-built GPU kernels. This substantially reduces the size of the
 Metal library by run-time compiling kernels the first time they are used in MLX
 on a given machine. Note run-time compilation incurs a cold-start cost which can
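The cold-start cost mentioned in that paragraph is easy to observe. Below is a minimal sketch (not part of this commit) that times the first and second evaluation of the same operation; in a ``MLX_METAL_JIT`` build the first call includes kernel compilation. The matmul and the array size are arbitrary choices for illustration.

import time
import mlx.core as mx

a = mx.random.uniform(shape=(1024, 1024))
mx.eval(a)  # materialize the input so the timing isolates the matmul

t0 = time.perf_counter()
mx.eval(a @ a)  # first use: a JIT build compiles the kernel here
t1 = time.perf_counter()
mx.eval(a @ a)  # second use: the compiled kernel is cached
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f}s (includes any JIT compile)")
print(f"second call: {t2 - t1:.4f}s")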
@@ -4321,8 +4321,9 @@ array bitwise_impl(
   }
   auto inputs =
       broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+  auto& out_shape = inputs[0].shape();
   return array(
-      a.shape(),
+      out_shape,
       out_type,
       std::make_shared<BitwiseBinary>(to_stream(s), op),
       std::move(inputs));
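The bug being fixed: ``bitwise_impl`` already broadcast its inputs via ``broadcast_arrays``, but constructed the output with ``a.shape()`` instead of the broadcast shape, so the result's shape was wrong whenever ``a`` and ``b`` had different shapes. As a refresher, here is a minimal Python sketch (not MLX code) of the NumPy-style broadcast rule that ``broadcast_arrays`` implements: align shapes from the right; paired dimensions must be equal or one of them must be 1.

# A sketch of the broadcast shape rule, for illustration only.
def broadcast_shape(s1, s2):
    out = []
    # Left-pad the shorter shape with 1s so both shapes align from the right.
    for d1, d2 in zip((1,) * (len(s2) - len(s1)) + tuple(s1),
                      (1,) * (len(s1) - len(s2)) + tuple(s2)):
        if d1 != d2 and 1 not in (d1, d2):
            raise ValueError(f"shapes {s1} and {s2} do not broadcast")
        out.append(max(d1, d2))
    return tuple(out)

print(broadcast_shape((3, 1, 5), (1, 2, 5)))  # (3, 2, 5)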
@@ -2291,6 +2291,13 @@ class TestOps(mlx_tests.MLXTestCase):
            out_np = getattr(np, op)(a_np, b_np)
            self.assertTrue(np.array_equal(np.array(out_mlx), out_np))

+        # Check broadcasting
+        a = mx.ones((3, 1, 5), dtype=mx.bool_)
+        b = mx.zeros((1, 2, 5), dtype=mx.bool_)
+        c = a | b
+        self.assertEqual(c.shape, (3, 2, 5))
+        self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_)))
+
     def test_conjugate(self):
         shape = (3, 5, 7)
         a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
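Beyond the regression test above, a quick REPL check of the fixed behavior (a sketch assuming a build with this fix; the shapes and dtypes are arbitrary):

import mlx.core as mx

x = mx.arange(4, dtype=mx.uint8).reshape(4, 1)  # shape (4, 1)
y = mx.arange(3, dtype=mx.uint8)                # shape (3,)

# All bitwise binary ops now report the broadcast output shape.
print((x & y).shape)  # (4, 3)
print((x | y).shape)  # (4, 3)
print((x ^ y).shape)  # (4, 3)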