mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00

Compare commits: db5a7c6192...v0.26.0 (4 commits)

| SHA1 |
|---|
| 0408ba0a76 |
| cbad6c3093 |
| 1b021f6984 |
| 95b7551d65 |
@@ -43,12 +43,29 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   });
 }

+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 #define NO_GPU_MULTI(func) \
   void func::eval_gpu( \
       const std::vector<array>& inputs, std::vector<array>& outputs) { \
     throw std::runtime_error(#func " has no CUDA implementation."); \
   }

+#define NO_GPU_USE_FALLBACK(func) \
+  bool func::use_fallback(Stream s) { \
+    return true; \
+  } \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no CUDA implementation."); \

@@ -144,11 +161,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU_MULTI(Eigh)

 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
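To make the stub pattern concrete, here is a self-contained sketch (not MLX source; `Stream`, `array`, and `DemoOp` are placeholders) of what a `NO_GPU_USE_FALLBACK(DemoOp)` expansion produces on a backend with no kernels: `use_fallback` unconditionally answers true, and `eval_gpu` stays a hard error in case it is ever reached anyway.

```cpp
// Sketch only: placeholder types standing in for mlx::core::Stream and
// mlx::core::array; the macros mirror the stubs in the hunk above.
#include <stdexcept>
#include <vector>

struct Stream {};
struct array {};

struct DemoOp {
  static bool use_fallback(Stream s);
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs);
};

#define NO_GPU_MULTI(func) \
  void func::eval_gpu( \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    throw std::runtime_error(#func " has no CUDA implementation."); \
  }

#define NO_GPU_USE_FALLBACK(func) \
  bool func::use_fallback(Stream s) { \
    return true; \
  } \
  NO_GPU_MULTI(func)

// Expands to DemoOp::use_fallback (always true) and a throwing DemoOp::eval_gpu.
NO_GPU_USE_FALLBACK(DemoOp)
```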
@@ -10,6 +10,10 @@

 namespace mlx::core::fast {

+bool RMSNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {

@@ -207,6 +211,10 @@ void RMSNormVJP::eval_gpu(
 }
 }

+bool LayerNorm::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void LayerNorm::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {

@@ -7,6 +7,10 @@ namespace mlx::core::fast {

 constexpr int n_per_thread = 4;

+bool RoPE::use_fallback(Stream s) {
+  return s.device == Device::cpu;
+}
+
 void RoPE::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
@@ -4,10 +4,10 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
-
 #include "mlx/backend/metal/kernels/steel/attn/params.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"

 namespace mlx::core::fast {

@@ -339,6 +339,46 @@ void sdpa_vector_2pass(

 } // namespace

+bool ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  if (detail::in_grad_tracing()) {
+    return true;
+  }
+  if (s.device == Device::cpu) {
+    return true;
+  }
+
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+  const int key_sequence_length = k.shape(2);
+
+  const bool sdpa_vector_supported_head_dim =
+      query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
+       query_head_dim == 256);
+  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
+
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
+
+  const bool supports_sdpa_full =
+      sdpa_full_supported_mask && sdpa_full_supported_head_dim;
+
+  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
+      (query_sequence_length <= key_sequence_length) &&
+      sdpa_vector_supported_head_dim;
+
+  return !(supports_sdpa_full || supports_sdpa_vector);
+}
+
 void ScaledDotProductAttention::eval_gpu(
     const std::vector<array>& inputs,
     array& out) {
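The routing predicate above depends only on shapes and mask flags, so it can be exercised on its own. Below is a standalone sketch (illustrative, not MLX code; plain integers stand in for the array shapes queried above) that mirrors the same decision:

```cpp
// Standalone sketch of the SDPA routing predicate, on plain integers.
// The shapes and names here are illustrative only.
#include <cstdio>

bool needs_fallback(int query_head_dim, int value_head_dim,
                    int query_len, int key_len,
                    bool has_mask, bool has_arr_mask, bool do_causal) {
  const bool vector_head_dim_ok = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 96 ||
       query_head_dim == 128 || query_head_dim == 256);
  const bool full_head_dim_ok = query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
  const bool full_mask_ok =
      !has_mask || has_arr_mask || (query_len <= key_len && do_causal);
  const bool full_ok = full_mask_ok && full_head_dim_ok;
  const bool vector_ok =
      query_len <= 8 && query_len <= key_len && vector_head_dim_ok;
  return !(full_ok || vector_ok);
}

int main() {
  // Single-token decode, head_dim 128: a fused kernel is eligible -> prints 0.
  std::printf("%d\n", needs_fallback(128, 128, 1, 1024, false, false, false));
  // Unsupported head_dim 72: route back to the composed-ops graph -> prints 1.
  std::printf("%d\n", needs_fallback(72, 72, 64, 64, false, false, false));
  return 0;
}
```

A single-token decode with head dimension 128 stays on a fused kernel, while an unsupported head dimension such as 72 falls back to the unfused graph.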
@@ -10,6 +10,12 @@
     throw std::runtime_error(#func " has no GPU implementation."); \
   }

+#define NO_GPU_USE_FALLBACK(func) \
+  bool func::use_fallback(Stream s) { \
+    return true; \
+  } \
+  NO_GPU_MULTI(func)
+
 #define NO_GPU(func) \
   void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
     throw std::runtime_error(#func " has no GPU implementation."); \

@@ -17,6 +23,17 @@

 namespace mlx::core {

+bool fast::ScaledDotProductAttention::use_fallback(
+    const array& q,
+    const array& k,
+    const array& v,
+    bool has_mask,
+    bool has_arr_mask,
+    bool do_causal,
+    Stream s) {
+  return true;
+}
+
 NO_GPU(Abs)
 NO_GPU(Add)
 NO_GPU(AddMM)

@@ -130,11 +147,11 @@ NO_GPU_MULTI(Eig)
 NO_GPU(View)

 namespace fast {
-NO_GPU_MULTI(LayerNorm)
+NO_GPU_USE_FALLBACK(LayerNorm)
 NO_GPU_MULTI(LayerNormVJP)
-NO_GPU_MULTI(RMSNorm)
+NO_GPU_USE_FALLBACK(RMSNorm)
 NO_GPU_MULTI(RMSNormVJP)
-NO_GPU_MULTI(RoPE)
+NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
mlx/fast.cpp (38 changed lines)

@@ -9,7 +9,6 @@
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/transforms.h"
-#include "mlx/transforms_impl.h"

 namespace mlx::core::fast {

@@ -112,7 +111,8 @@ array rms_norm(
   auto passed_weight =
       (has_weight) ? astype(*weight, out_type, s) : array(1, out_type);
-  if (s.device == Device::gpu) {
+
+  if (!RMSNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,

@@ -256,7 +256,7 @@ array layer_norm(
   auto passed_bias =
       (has_bias) ? astype(*bias, out_type, s) : array(0, out_type);

-  if (s.device == Device::gpu) {
+  if (!LayerNorm::use_fallback(s)) {
     return array(
         x.shape(),
         out_type,

@@ -470,7 +470,7 @@ array rope(
     }
   };
   auto stream = to_stream(s);
-  if (stream.device == Device::gpu) {
+  if (!RoPE::use_fallback(stream)) {
     return array(
         x.shape(),
         x.dtype(),

@@ -727,31 +727,6 @@ array scaled_dot_product_attention(
   };

   auto stream = to_stream(s);
-  const int value_head_dim = v.shape(-1);
-  const int query_head_dim = q.shape(-1);
-  const int query_sequence_length = q.shape(2);
-  const int key_sequence_length = k.shape(2);
-
-  const bool sdpa_vector_supported_head_dim =
-      query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
-  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
-
-  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
-      (query_sequence_length <= key_sequence_length && do_causal);
-
-  const bool supports_sdpa_full = sdpa_full_supported_mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
-
-  const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= key_sequence_length) &&
-      sdpa_vector_supported_head_dim && stream.device == Device::gpu;
-
-  const bool implementation_supports_use_case =
-      supports_sdpa_full || supports_sdpa_vector;
-
   std::vector<array> inputs = {q, k, v};
   if (has_arr_mask) {
     // Check type

@@ -770,7 +745,8 @@ array scaled_dot_product_attention(
     mask_shape.back() = keys.shape(-2);
     inputs.push_back(broadcast_to(mask_arr, mask_shape, stream));
   }
-  if (!detail::in_grad_tracing() && implementation_supports_use_case) {
+  if (!ScaledDotProductAttention::use_fallback(
+          q, k, v, has_mask, has_arr_mask, do_causal, stream)) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),

@@ -779,7 +755,7 @@ array scaled_dot_product_attention(
             stream, fallback, scale, do_causal),
         std::move(inputs));
   }
-  return fallback(inputs)[0];
+  return fallback(std::move(inputs))[0];
 }

 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
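Each call site in fast.cpp now follows the same shape: instead of checking `stream.device == Device::gpu` inline, the front end asks the primitive itself whether the fallback is needed, builds the fused primitive if not, and otherwise runs the composed-ops fallback directly. A minimal sketch of that control flow follows; `Op`, `Value`, `Fallback`, and `apply` are placeholders, not MLX API.

```cpp
// Sketch of the use_fallback dispatch pattern at a call site.
// All names and types here are stand-ins, not MLX code.
#include <functional>
#include <vector>

struct Stream { bool on_gpu; };
using Value = int;  // stand-in for mlx::core::array
using Fallback = std::function<std::vector<Value>(std::vector<Value>)>;

struct Op {
  // The primitive owns the decision: fall back when it cannot run fused here.
  static bool use_fallback(Stream s) { return !s.on_gpu; }
};

Value apply(Stream s, Fallback fallback, std::vector<Value> inputs) {
  if (!Op::use_fallback(s)) {
    // Build the fused primitive; the fallback stays attached so transforms
    // (grad, vmap, shapeless compile) can still fall back later.
    return inputs[0] * 2;  // placeholder for the fused path
  }
  // Otherwise evaluate the composed-ops fallback graph directly.
  return fallback(std::move(inputs))[0];
}
```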
@@ -43,6 +43,8 @@ class RMSNorm : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}

+  static bool use_fallback(Stream stream);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");

@@ -65,7 +67,6 @@ class RMSNorm : public Custom {
   }

  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };

@@ -91,7 +92,6 @@ class RMSNormVJP : public Custom {
   }

  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };

@@ -103,6 +103,8 @@ class LayerNorm : public Custom {
       float eps)
       : Custom(stream, fallback), eps_(eps) {}

+  static bool use_fallback(Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");

@@ -124,7 +126,6 @@ class LayerNorm : public Custom {
   }

  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };

@@ -150,7 +151,6 @@ class LayerNormVJP : public Custom {
   }

  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float eps_;
 };

@@ -171,6 +171,8 @@ class RoPE : public Custom {
         scale_(scale),
         forward_(forward) {}

+  static bool use_fallback(Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");

@@ -193,7 +195,6 @@ class RoPE : public Custom {
   }

  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int dims_;
   bool traditional_;
   float base_;

@@ -210,6 +211,15 @@ class ScaledDotProductAttention : public Custom {
       const bool do_causal)
       : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}

+  static bool use_fallback(
+      const array& q,
+      const array& k,
+      const array& v,
+      bool has_mask,
+      bool has_arr_mask,
+      bool do_causal,
+      Stream s);
+
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
     throw std::runtime_error("NYI");

@@ -230,7 +240,6 @@ class ScaledDotProductAttention : public Custom {
   }

  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   float scale_;
   bool do_causal_;
 };

@@ -263,7 +272,6 @@ class AffineQuantize : public Custom {
   }

  private:
-  std::function<std::vector<array>(std::vector<array>)> fallback_;
   int group_size_;
   int bits_;
   bool dequantize_;
mlx/ops.cpp (25 changed lines)

@@ -2862,21 +2862,30 @@ array matmul(
         << " second input with shape " << b.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
-  // Type promotion
-  auto out_type = promote_types(a.dtype(), b.dtype());
-  // Complex matmul in terms of real matmuls
-  if (out_type == complex64) {
+  // complex matmul using Karatsuba's Algorithm
+  if (a.dtype() == complex64 || b.dtype() == complex64) {
+    // Extract real and imaginary parts
     auto a_real = real(a, s);
-    auto b_real = real(b, s);
     auto a_imag = imag(a, s);
+    auto b_real = real(b, s);
     auto b_imag = imag(b, s);
-    auto c_real =
-        subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s);
-    auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s);
+
+    // Compute real and imaginary components of the result
+    auto m1 = matmul(a_real, b_real, s);
+    auto m2 = matmul(a_imag, b_imag, s);
+    auto m3 = matmul(add(a_real, a_imag, s), add(b_real, b_imag, s), s);
+
+    auto c_real = subtract(m1, m2, s);
+    auto c_imag = subtract(m3, add(m1, m2, s), s);
+
     return add(
         c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
   }

+  // Type promotion
+  auto out_type = promote_types(a.dtype(), b.dtype());
+
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
     msg << "[matmul] Only real floating point types are supported but "
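The rewritten branch applies Karatsuba's trick to complex matrix multiplication: with $A = A_r + iA_i$ and $B = B_r + iB_i$, three real matmuls replace the four of the schoolbook expansion, at the cost of a few extra element-wise additions.

$$M_1 = A_r B_r, \qquad M_2 = A_i B_i, \qquad M_3 = (A_r + A_i)(B_r + B_i)$$

$$AB = (M_1 - M_2) + i\,(M_3 - M_1 - M_2), \qquad \text{since } M_3 - M_1 - M_2 = A_r B_i + A_i B_r.$$

The previous code computed $A_r B_r$, $A_i B_i$, $A_r B_i$, and $A_i B_r$ separately; for large matrices the saved matmul outweighs the added element-wise work.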
@@ -208,9 +208,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
       // output arrays stream
       fences[it->second].wait(stream, in);
     } else if (in.event().valid()) {
-      if (in.event().is_signaled()) {
-        in.detach_event();
-      } else if (in.event().stream() != stream) {
+      if (in.event().stream() != stream) {
         // Use event to wait across async eval
         in.event().wait(stream);
       }
@@ -3,8 +3,8 @@
 #pragma once

 #define MLX_VERSION_MAJOR 0
-#define MLX_VERSION_MINOR 25
-#define MLX_VERSION_PATCH 2
+#define MLX_VERSION_MINOR 26
+#define MLX_VERSION_PATCH 0
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

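For reference, the numeric constant defined by the macros above works out to

$$\mathrm{MLX\_VERSION\_NUMERIC} = 100000 \cdot 0 + 1000 \cdot 26 + 0 = 26000$$

for v0.26.0, compared with $100000 \cdot 0 + 1000 \cdot 25 + 2 = 25002$ for the previous v0.25.2.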
@@ -1210,13 +1210,6 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(c, c_np))

         # Test addmm
-        M = 16
-        K = 50
-        N = 32
-
-        def rand(shape):
-            return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
-
         a = rand((M, K))
         b = rand((K, N))
         c = rand((M, N))

@@ -1224,6 +1217,13 @@ class TestBlas(mlx_tests.MLXTestCase):
         out_np = 2.0 * np.matmul(a, b) + 2.0 * c
         self.assertTrue(np.allclose(out, out_np))

+        # complex with real
+        a = rand((M, K)).real
+        b = rand((K, N))
+        c = mx.matmul(a, b)
+        c_np = np.matmul(a, b)
+        self.assertTrue(np.allclose(out, out_np))
+

 if __name__ == "__main__":
     unittest.main()