Compare commits

4 Commits

Author SHA1 Message Date
Suryash Malviya
0408ba0a76  Optimizing Complex Matrix Multiplication using Karatsuba’s Algorithm (#2220)  2025-06-02 15:58:46 -07:00

    * Implementing Complex Matmul using Karatsuba Algorithm
    * Implemented Karatsuba's Algorithm for complex matmul and pre-commit them
    * fix

    Co-authored-by: Awni Hannun <awni@apple.com>

Awni Hannun
cbad6c3093  version (#2237)  2025-06-02 15:58:33 -07:00

Cheng
1b021f6984  Fast primitives decide when to use the fallback (#2216)  2025-06-02 13:26:37 -07:00

Cheng
95b7551d65  Do not check event.is_signaled() in eval_impl (#2230)  2025-06-02 13:23:34 -07:00
11 changed files with 142 additions and 65 deletions

View File

@@ -43,12 +43,29 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   });
 }
 
+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 #define NO_GPU_MULTI(func)                                             \
   void func::eval_gpu(                                                 \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
     throw std::runtime_error(#func " has no CUDA implementation.");    \
   }
 
+#define NO_GPU_USE_FALLBACK(func)     \
+  bool func::use_fallback(Stream s) { \
+    return true;                      \
+  }                                   \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func)                                                  \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no CUDA implementation.");   \
@@ -144,11 +161,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)
 
 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
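
For reference, here is what the new NO_GPU_USE_FALLBACK macro expands to for one primitive, written out by hand from the macro above (an illustration of the expansion, not code taken from the repository):

// Hand-expanded form of NO_GPU_USE_FALLBACK(RMSNorm) on a backend with no
// fused kernels: the primitive always requests the fallback, and eval_gpu
// only exists to fail loudly if it is ever reached.
bool RMSNorm::use_fallback(Stream s) {
  return true;
}
void RMSNorm::eval_gpu(
    const std::vector<array>& inputs, std::vector<array>& outputs) {
  throw std::runtime_error("RMSNorm has no CUDA implementation.");
}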

View File

@@ -10,6 +10,10 @@
 
 namespace mlx::core::fast {
 
+bool RMSNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu(
   }
 }
 
+bool LayerNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void LayerNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {

View File

@@ -7,6 +7,10 @@ namespace mlx::core::fast {
 
 constexpr int n_per_thread = 4;
 
+bool RoPE::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void RoPE::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {

View File

@@ -4,10 +4,10 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels/steel/attn/params.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
 namespace mlx::core::fast {
@@ -339,6 +339,46 @@ void sdpa_vector_2pass(
 
 } // namespace
 
+bool ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  if (detail::in_grad_tracing()) {
+    return true;
+  }
+  if (s.device == Device::cpu) {
+    return true;
+  }
+
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+  const int key_sequence_length = k.shape(2);
+
+  const bool sdpa_vector_supported_head_dim =
+      query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
+       query_head_dim == 256);
+  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
+
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
+
+  const bool supports_sdpa_full =
+      sdpa_full_supported_mask && sdpa_full_supported_head_dim;
+
+  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
+      (query_sequence_length <= key_sequence_length) &&
+      sdpa_vector_supported_head_dim;
+
+  return !(supports_sdpa_full || supports_sdpa_vector);
+}
+
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     array& out) {
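
To make the new routing rule concrete, the following standalone sketch (not MLX code; the shapes are hypothetical) re-evaluates the same predicate for a typical single-token decode step with head_dim 128 and a 1024-entry key/value cache:

// Standalone restatement of the use_fallback() predicate above, applied to
// one hypothetical case: 1 query token, 1024 cached keys, head_dim 128.
#include <cstdio>

int main() {
  const int query_head_dim = 128;       // q.shape(-1)
  const int value_head_dim = 128;       // v.shape(-1)
  const int query_sequence_length = 1;  // q.shape(2)
  const int key_sequence_length = 1024; // k.shape(2)
  const bool has_mask = false, has_arr_mask = false, do_causal = false;

  const bool vector_head_dim_ok = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
       query_head_dim == 256);
  const bool full_head_dim_ok = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
  const bool full_mask_ok = !has_mask || has_arr_mask ||
      (query_sequence_length <= key_sequence_length && do_causal);

  const bool supports_sdpa_full = full_mask_ok && full_head_dim_ok;
  const bool supports_sdpa_vector = query_sequence_length <= 8 &&
      query_sequence_length <= key_sequence_length && vector_head_dim_ok;

  // Prints "use_fallback = 0": the fused Metal path is taken here.
  std::printf("use_fallback = %d\n", !(supports_sdpa_full || supports_sdpa_vector));
  return 0;
}

Changing query_head_dim to an unsupported value such as 72 makes both head-dim checks fail, and the predicate reports the fallback instead.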

View File

@@ -10,6 +10,12 @@
     throw std::runtime_error(#func " has no GPU implementation."); \
   }
 
+#define NO_GPU_USE_FALLBACK(func)     \
+  bool func::use_fallback(Stream s) { \
+    return true;                      \
+  }                                   \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func)                                                  \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no GPU implementation.");    \
@@ -17,6 +23,17 @@
 
 namespace mlx::core {
 
+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 NO_GPU(Abs)
 NO_GPU(Add)
 NO_GPU(AddMM)
@@ -130,11 +147,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU(View)
 
 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)

View File

@@ -9,7 +9,6 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
-#include "mlx/transforms_impl.h"
 
 namespace mlx::core::fast {
@@ -112,7 +111,8 @@ array rms_norm(
   auto passed_weight =
       (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
-  if (s.device == Device::gpu) {
+  if (!RMSNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -256,7 +256,7 @@ array layer_norm(
   auto passed_bias =
       (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);
-  if (s.device == Device::gpu) {
+  if (!LayerNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,
@@ -470,7 +470,7 @@ array rope(
     }
   };
   auto stream = to_stream(s);
-  if (stream.device == Device::gpu) {
+  if (!RoPE::use_fallback(stream)) {
     return array(
         x.shape(),
         x.dtype(),
@@ -727,31 +727,6 @@ array scaled_dot_product_attention(
   };
 
   auto stream = to_stream(s);
-  const int value_head_dim = v.shape(-1);
-  const int query_head_dim = q.shape(-1);
-  const int query_sequence_length = q.shape(2);
-  const int key_sequence_length = k.shape(2);
-
-  const bool sdpa_vector_supported_head_dim =
-      query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
-  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
-
-  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
-      (query_sequence_length <= key_sequence_length && do_causal);
-
-  const bool supports_sdpa_full = sdpa_full_supported_mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
-
-  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
-
-  const bool implementation_supports_use_case =
-      supports_sdpa_full || supports_sdpa_vector;
-
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
     // Check type
@@ -770,7 +745,8 @@ array scaled_dot_product_attention(
     mask_shape.back() = keys.shape(-2);
     inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
-  if (!detail::in_grad_tracing() && implementation_supports_use_case) {
+  if (!ScaledDotProductAttention::use_fallback(
+          q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
@@ -779,7 +755,7 @@ array scaled_dot_product_attention(
             stream, fallback, scale, do_causal),
         std::move(inputs));
   }
-  return fallback(inputs)[0];
+  return fallback(std::move(inputs))[0];
 }
 
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
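
The same dispatch shape repeats for every fast primitive touched here: the op builder asks the primitive's static use_fallback() and only constructs the fused primitive when it returns false, otherwise it evaluates the pure-ops fallback closure. A hypothetical new primitive would slot in the same way (sketch only; Foo, foo, and the exp-based fallback are invented for illustration and mirror the pattern in the diff above rather than any existing MLX code):

// Illustrative sketch, not MLX source. "Foo"/"foo" are made-up names; the
// array/Stream types and the Custom-style (stream, fallback) constructor
// follow the pattern shown in the surrounding diffs.
bool Foo::use_fallback(Stream s) {
  // Only the GPU backend has a fused kernel in this sketch.
  return s.device == Device::cpu;
}

array foo(const array& x, StreamOrDevice s) {
  auto fallback = [s](std::vector<array> inputs) {
    // Reference implementation in terms of existing ops; also used on CPU
    // and whenever the fused path does not apply.
    return std::vector<array>{exp(inputs[0], s)};
  };
  auto stream = to_stream(s);
  if (!Foo::use_fallback(stream)) {
    return array(
        x.shape(), x.dtype(), std::make_shared<Foo>(stream, fallback), {x});
  }
  return fallback({x})[0];
}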

View File

@@ -43,6 +43,8 @@ class RMSNorm : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
 
+  static bool use_fallback(Stream stream);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -65,7 +67,6 @@ class RMSNorm : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
@@ -91,7 +92,6 @@ class RMSNormVJP : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
@@ -103,6 +103,8 @@ class LayerNorm : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}
 
+  static bool use_fallback(Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -124,7 +126,6 @@ class LayerNorm : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
@@ -150,7 +151,6 @@ class LayerNormVJP : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };
@@ -171,6 +171,8 @@ class RoPE : public Custom {
         scale_(scale),
         forward_(forward) {}
 
+  static bool use_fallback(Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -193,7 +195,6 @@ class RoPE : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int dims_;
   bool traditional_;
   float base_;
@@ -210,6 +211,15 @@ class ScaledDotProductAttention : public Custom {
       const bool do_causal)
       : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
 
+  static bool use_fallback(
+      const array& q,
+      const array& k,
+      const array& v,
+      bool has_mask,
+      bool has_arr_mask,
+      bool do_causal,
+      Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");
@@ -230,7 +240,6 @@ class ScaledDotProductAttention : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float scale_;
   bool do_causal_;
 };
@@ -263,7 +272,6 @@ class AffineQuantize : public Custom {
   }
 
  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int group_size_;
   int bits_;
   bool dequantize_;

View File

@@ -2862,21 +2862,30 @@ array matmul(
         << " second input with shape " << b.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
-  // Type promotion
-  auto out_type = promote_types(a.dtype(), b.dtype());
 
-  // Complex matmul in terms of real matmuls
-  if (out_type == complex64) {
+  // complex matmul using Karatsuba's Algorithm
+  if (a.dtype() == complex64 || b.dtype() == complex64) {
+    // Extract real and imaginary parts
     auto a_real = real(a, s);
-    auto b_real = real(b, s);
     auto a_imag = imag(a, s);
+    auto b_real = real(b, s);
     auto b_imag = imag(b, s);
-    auto c_real =
-        subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
-    auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);
+
+    // Compute real and imaginary components of the result
+    auto m1 = matmul(a_real, b_real, s);
+    auto m2 = matmul(a_imag, b_imag, s);
+    auto m3 = matmul(add(a_real, a_imag, s), add(b_real, b_imag, s), s);
+
+    auto c_real = subtract(m1, m2, s);
+    auto c_imag = subtract(m3, add(m1, m2, s), s);
+
     return add(
         c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
   }
 
+  // Type promotion
+  auto out_type = promote_types(a.dtype(), b.dtype());
+
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
     msg << "[matmul] Only real floating point types are supported but "

View File

@@ -208,9 +208,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
         // output arrays stream
         fences[it->second].wait(stream, in);
       } else if (in.event().valid()) {
-        if (in.event().is_signaled()) {
-          in.detach_event();
-        } else if (in.event().stream() != stream) {
+        if (in.event().stream() != stream) {
           // Use event to wait across async eval
           in.event().wait(stream);
         }

View File

@@ -3,8 +3,8 @@
 #pragma once
 
 #define MLX_VERSION_MAJOR 0
-#define MLX_VERSION_MINOR 25
-#define MLX_VERSION_PATCH 2
+#define MLX_VERSION_MINOR 26
+#define MLX_VERSION_PATCH 0
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -1210,13 +1210,6 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(c, c_np))
 
         # Test addmm
-        M = 16
-        K = 50
-        N = 32
-
-        def rand(shape):
-            return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
-
         a = rand((M, K))
         b = rand((K, N))
         c = rand((M, N))
@@ -1224,6 +1217,13 @@ class TestBlas(mlx_tests.MLXTestCase):
         out_np = 2.0 * np.matmul(a, b) + 2.0 * c
         self.assertTrue(np.allclose(out, out_np))
 
+        # complex with real
+        a = rand((M, K)).real
+        b = rand((K, N))
+        c = mx.matmul(a, b)
+        c_np = np.matmul(a, b)
+        self.assertTrue(np.allclose(out, out_np))
+
 
 if __name__ == "__main__":
     unittest.main()