Merge branch 'ml-explore:main' into feature/metal-svd-base

This commit is contained in:
Arkar Min Aung 2025-06-14 16:53:43 +10:00 committed by GitHub
commit c67eea520e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 792 additions and 688 deletions

View File

@ -194,6 +194,13 @@ struct Power {
}
return res;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (base.y == 0 && base.x == 0) {
if (isnan(exp.x) || isnan(exp.y)) {
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
return make_cuFloatComplex(nan, nan);
}
return make_cuFloatComplex(0.0, 0.0);
}
auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);

View File

@ -145,7 +145,7 @@ bool compiler_supports_device_sass(Device& device) {
}
}
#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"
#define INCLUDE_PREFIX "mlx/backend/cuda/device/"
constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "atomic_ops.cuh",

View File

@ -44,9 +44,12 @@ class MatMul {
int64_t b_batch_stride) {
heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
auto type = dtype_to_cuda_type(dtype);
auto scale_type = dtype_to_cuda_type(dtype);
if (dtype == bfloat16) {
scale_type = CUDA_R_32F;
}
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
&matmul_desc_, dtype_to_compute_type(dtype), type));
&matmul_desc_, dtype_to_compute_type(dtype), scale_type));
int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
matmul_desc_,
@ -65,6 +68,7 @@ class MatMul {
&op,
sizeof(cublasOperation_t)));
auto type = dtype_to_cuda_type(dtype);
a_desc_ = create_matrix_layout(
type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
b_desc_ = create_matrix_layout(
@ -187,15 +191,10 @@ class MatMul {
private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) {
case uint8:
case uint16:
case int8:
case int16:
case int32:
return CUBLAS_COMPUTE_32I;
case float16:
case bfloat16:
return CUBLAS_COMPUTE_16F;
case bfloat16:
return CUBLAS_COMPUTE_32F;
case float32:
return CUBLAS_COMPUTE_32F;
case float64:
@ -209,16 +208,6 @@ class MatMul {
cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
switch (dtype) {
case uint8:
return CUDA_R_8U;
case uint16:
return CUDA_R_16U;
case int8:
return CUDA_R_8I;
case int16:
return CUDA_R_16I;
case int32:
return CUDA_R_32I;
case float16:
return CUDA_R_16F;
case bfloat16:

View File

@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
s,
d,
/* a = */ in_unfolded,
/* b = */ wt_transpose,
/* c = */ out,
/* M = */ implicit_M,
/* N = */ implicit_N,
/* K = */ implicit_K,
/* batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups,
/* b_cols = */ implicit_K,
/* out_cols = */ implicit_N * groups,
/* a_transposed = */ false,
/* b_transposed = */ true,
/* batch_shape = */ {1},
/* batch_strides = */ {0},
/* A_batch_strides = */ size_t(implicit_K),
/* B_batch_strides = */ size_t(implicit_N) * implicit_K,
/* matrix_stride_out = */ size_t(implicit_N),
/*copies = */ copies);
/* const Stream& s = */ s,
/* Device& d = */ d,
/* const array& a = */ in_unfolded,
/* const array& b = */ wt_transpose,
/* array& c = */ out,
/* int M = */ implicit_M,
/* int N = */ implicit_N,
/* int K = */ implicit_K,
/* int batch_size_out = */ groups,
/* int lda = */ implicit_K * groups,
/* int ldb = */ implicit_K,
/* int ldd = */ implicit_N * groups,
/* bool transpose_a = */ false,
/* bool transpose_b = */ true,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ {1},
/* Strides batch_strides = */ {0},
/* int64_t A_batch_strides = */ int64_t(implicit_K),
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
}
void implicit_gemm_conv_2D_gpu(

View File

@ -297,6 +297,9 @@ Device::Device() {
device_ = load_device();
default_library_ = load_default_library(device_);
arch_ = std::string(device_->architecture()->name()->utf8String());
int ag_tens = arch_[arch_.size() - 3] - '0';
int ag_ones = arch_[arch_.size() - 2] - '0';
arch_gen_ = ag_tens * 10 + ag_ones;
auto arch = arch_.back();
switch (arch) {
case 'p': // phone

View File

@ -177,6 +177,10 @@ class Device {
return arch_;
}
int get_architecture_gen() const {
return arch_gen_;
}
void new_queue(int index);
MTL::CommandQueue* get_queue(Stream stream);
@ -268,6 +272,7 @@ class Device {
library_kernels_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
int arch_gen_;
int max_ops_per_buffer_;
int max_mb_per_buffer_;
};

View File

@ -235,6 +235,13 @@ struct Power {
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (x.real == 0 && x.imag == 0) {
if (metal::isnan(y.real) || metal::isnan(y.imag)) {
auto nan = metal::numeric_limits<float>::quiet_NaN();
return {nan, nan};
}
return {0.0, 0.0};
}
auto x_theta = metal::atan2(x.imag, x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);

View File

@ -33,8 +33,8 @@ template <
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant int64_t* batch_strides [[buffer(7)]],
const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,34 @@
namespace mlx::core {
void steel_matmul_regular(
template <bool CHECK_AB = true>
void steel_matmul_regular_axpby(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
const array& c,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
int64_t C_batch_stride = 0,
float alpha = 1.0f,
float beta = 0.0f);
inline void steel_matmul_regular(
const Stream& s,
metal::Device& d,
const array& a,
@ -21,14 +48,61 @@ void steel_matmul_regular(
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
std::vector<array>& copies);
int64_t matrix_stride_out) {
return steel_matmul_regular_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ ldd,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides batch_strides = */ batch_strides,
/* int64_t A_batch_stride = */ A_batch_stride,
/* int64_t B_batch_stride = */ B_batch_stride,
/* int64_t matrix_stride_out = */ matrix_stride_out);
}
void steel_matmul(
template <bool CHECK_AB = true>
void steel_matmul_axpby(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
const array& c,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape = {},
Strides A_batch_stride = {},
Strides B_batch_stride = {},
Strides C_batch_stride = {},
float alpha = 1.0f,
float beta = 0.0f);
inline void steel_matmul(
const Stream& s,
metal::Device& d,
const array& a,
@ -45,6 +119,26 @@ void steel_matmul(
std::vector<array>& copies,
Shape batch_shape = {},
Strides A_batch_stride = {},
Strides B_batch_stride = {});
Strides B_batch_stride = {}) {
return steel_matmul_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
}
} // namespace mlx::core

View File

@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
}
if (no_copy) {
if (x.is_donatable()) {
@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
}
if (no_copy) {
if (x.is_donatable()) {

View File

@ -2847,21 +2847,6 @@ array matmul(
"[matmul] Got 0 dimension input. Inputs must "
"have at least one dimension.");
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[matmul] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// complex matmul using Karatsuba's Algorithm
if (a.dtype() == complex64 || b.dtype() == complex64) {
@ -2883,6 +2868,22 @@ array matmul(
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[matmul] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = promote_types(a.dtype(), b.dtype());
@ -4240,6 +4241,16 @@ array addmm(
"have at least one dimension.");
}
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
@ -4257,16 +4268,6 @@ array addmm(
throw std::invalid_argument(msg.str());
}
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (!issubdtype(out_type, floating)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "

View File

@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) {
os << val;
}
inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
os << val;
os << val.real();
if (val.imag() >= 0 || std::isnan(val.imag())) {
os << "+" << val.imag() << "j";
} else {
os << "-" << -val.imag() << "j";
}
}
PrintFormatter& get_global_formatter() {

View File

@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase):
c_np = np.matmul(np.array(a).T, b)
self.assertTrue(np.allclose(c, c_np))
# Check shapes
a = mx.random.normal((2, 3)).astype(mx.complex64)
b = mx.random.normal((3,))
self.assertEqual((a @ b).shape, (2,))
a = mx.random.normal((2, 3)).astype(mx.complex64)
b = mx.random.normal((3,))
c = mx.random.normal((2,))
self.assertEqual(mx.addmm(c, a, b).shape, (2,))
def test_complex_gemm(self):
M = 16
K = 50

View File

@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase):
)
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
def test_complex_power(self):
out = mx.power(mx.array(0j), 2)
self.assertEqual(out.item(), 0j)
out = mx.power(mx.array(0j), float("nan"))
self.assertTrue(mx.isnan(out))
class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self):