Compare commits

...

4 Commits

Author SHA1 Message Date
Gaétan Lepage
8b86183ab5
Merge d2e0b0465c into 8402a2acf4 2025-06-14 09:53:58 +12:00
Awni Hannun
8402a2acf4
Fix complex power and print (#2286)
* fix complex power and print

* fix complex matmul shape
2025-06-13 11:13:00 -07:00
Jagrit Digani
fddb6933e1
Collection of refactors (#2274)
* Refactor gemv into a function

* Refactor splitk step 1

* Refactor split k axpby

* Rearrange steel_gemm_regular

* Redirect steel_gemm_regular

* Add axpby routing to steel_matmul_regular

* Refactor AddMM step 1

* Redirect steel_gemm

* Update addmm

* Comments and format

* Some cleanup

* Add architecture gen to device

* Update no copy condition in normalization to account for axis size 1
2025-06-13 10:44:56 -07:00
Gaetan Lepage
d2e0b0465c Feat: add USE_SYSTEM_FMT CMake option 2025-05-23 15:19:48 +02:00
14 changed files with 794 additions and 674 deletions

View File

@ -42,6 +42,7 @@ option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF) option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF) option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
# --------------------- Processor tests ------------------------- # --------------------- Processor tests -------------------------
message( message(
@ -234,12 +235,16 @@ target_include_directories(
# Do not add mlx_EXPORTS define for shared library. # Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "") set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
FetchContent_Declare( if(USE_SYSTEM_FMT)
fmt find_package(fmt REQUIRED)
GIT_REPOSITORY https://github.com/fmtlib/fmt.git else()
GIT_TAG 10.2.1 FetchContent_Declare(
EXCLUDE_FROM_ALL) fmt
FetchContent_MakeAvailable(fmt) GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>) target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
if(MLX_BUILD_PYTHON_BINDINGS) if(MLX_BUILD_PYTHON_BINDINGS)

View File

@ -194,6 +194,13 @@ struct Power {
} }
return res; return res;
} else if constexpr (cuda::std::is_same_v<T, cuComplex>) { } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (base.y == 0 && base.x == 0) {
if (isnan(exp.x) || isnan(exp.y)) {
auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
return make_cuFloatComplex(nan, nan);
}
return make_cuFloatComplex(0.0, 0.0);
}
auto x_theta = atan2f(base.y, base.x); auto x_theta = atan2f(base.y, base.x);
auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y); auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
auto mag = expf(exp.x * x_ln_r - exp.y * x_theta); auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);

View File

@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
// Perform gemm // Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose}; std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular( return steel_matmul_regular(
s, /* const Stream& s = */ s,
d, /* Device& d = */ d,
/* a = */ in_unfolded, /* const array& a = */ in_unfolded,
/* b = */ wt_transpose, /* const array& b = */ wt_transpose,
/* c = */ out, /* array& c = */ out,
/* M = */ implicit_M, /* int M = */ implicit_M,
/* N = */ implicit_N, /* int N = */ implicit_N,
/* K = */ implicit_K, /* int K = */ implicit_K,
/* batch_size_out = */ groups, /* int batch_size_out = */ groups,
/* a_cols = */ implicit_K * groups, /* int lda = */ implicit_K * groups,
/* b_cols = */ implicit_K, /* int ldb = */ implicit_K,
/* out_cols = */ implicit_N * groups, /* int ldd = */ implicit_N * groups,
/* a_transposed = */ false, /* bool transpose_a = */ false,
/* b_transposed = */ true, /* bool transpose_b = */ true,
/* batch_shape = */ {1}, /* std::vector<array>& copies = */ copies,
/* batch_strides = */ {0}, /* Shape batch_shape = */ {1},
/* A_batch_strides = */ size_t(implicit_K), /* Strides batch_strides = */ {0},
/* B_batch_strides = */ size_t(implicit_N) * implicit_K, /* int64_t A_batch_strides = */ int64_t(implicit_K),
/* matrix_stride_out = */ size_t(implicit_N), /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
/*copies = */ copies); /* int64_t matrix_stride_out = */ int64_t(implicit_N));
} }
void implicit_gemm_conv_2D_gpu( void implicit_gemm_conv_2D_gpu(

View File

@ -297,6 +297,9 @@ Device::Device() {
device_ = load_device(); device_ = load_device();
default_library_ = load_default_library(device_); default_library_ = load_default_library(device_);
arch_ = std::string(device_->architecture()->name()->utf8String()); arch_ = std::string(device_->architecture()->name()->utf8String());
int ag_tens = arch_[arch_.size() - 3] - '0';
int ag_ones = arch_[arch_.size() - 2] - '0';
arch_gen_ = ag_tens * 10 + ag_ones;
auto arch = arch_.back(); auto arch = arch_.back();
switch (arch) { switch (arch) {
case 'p': // phone case 'p': // phone

View File

@ -177,6 +177,10 @@ class Device {
return arch_; return arch_;
} }
// Returns the cached architecture generation number (arch_gen_). The Device
// constructor computes it from the two characters that precede the final
// character of the architecture name, treating them as decimal digits
// (tens, then ones).
int get_architecture_gen() const {
return arch_gen_;
}
void new_queue(int index); void new_queue(int index);
MTL::CommandQueue* get_queue(Stream stream); MTL::CommandQueue* get_queue(Stream stream);
@ -268,6 +272,7 @@ class Device {
library_kernels_; library_kernels_;
const MTL::ResidencySet* residency_set_{nullptr}; const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_; std::string arch_;
int arch_gen_;
int max_ops_per_buffer_; int max_ops_per_buffer_;
int max_mb_per_buffer_; int max_mb_per_buffer_;
}; };

View File

@ -235,6 +235,13 @@ struct Power {
template <> template <>
complex64_t operator()(complex64_t x, complex64_t y) { complex64_t operator()(complex64_t x, complex64_t y) {
if (x.real == 0 && x.imag == 0) {
if (metal::isnan(y.real) || metal::isnan(y.imag)) {
auto nan = metal::numeric_limits<float>::quiet_NaN();
return {nan, nan};
}
return {0.0, 0.0};
}
auto x_theta = metal::atan2(x.imag, x.real); auto x_theta = metal::atan2(x.imag, x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);

View File

@ -33,8 +33,8 @@ template <
device T* D [[buffer(3)]], device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]], const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]], const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
const constant int64_t* batch_strides [[buffer(7)]], const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,34 @@
namespace mlx::core { namespace mlx::core {
void steel_matmul_regular( template <bool CHECK_AB = true>
void steel_matmul_regular_axpby(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
const array& c,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape,
Strides batch_strides,
int64_t A_batch_stride,
int64_t B_batch_stride,
int64_t matrix_stride_out,
int64_t C_batch_stride = 0,
float alpha = 1.0f,
float beta = 0.0f);
inline void steel_matmul_regular(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
const array& a, const array& a,
@ -21,14 +48,61 @@ void steel_matmul_regular(
int ldd, int ldd,
bool transpose_a, bool transpose_a,
bool transpose_b, bool transpose_b,
std::vector<array>& copies,
Shape batch_shape, Shape batch_shape,
Strides batch_strides, Strides batch_strides,
int64_t A_batch_stride, int64_t A_batch_stride,
int64_t B_batch_stride, int64_t B_batch_stride,
int64_t matrix_stride_out, int64_t matrix_stride_out) {
std::vector<array>& copies); return steel_matmul_regular_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* int ldd = */ ldd,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides batch_strides = */ batch_strides,
/* int64_t A_batch_stride = */ A_batch_stride,
/* int64_t B_batch_stride = */ B_batch_stride,
/* int64_t matrix_stride_out = */ matrix_stride_out);
}
void steel_matmul( template <bool CHECK_AB = true>
void steel_matmul_axpby(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
const array& c,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
Shape batch_shape = {},
Strides A_batch_stride = {},
Strides B_batch_stride = {},
Strides C_batch_stride = {},
float alpha = 1.0f,
float beta = 0.0f);
inline void steel_matmul(
const Stream& s, const Stream& s,
metal::Device& d, metal::Device& d,
const array& a, const array& a,
@ -45,6 +119,26 @@ void steel_matmul(
std::vector<array>& copies, std::vector<array>& copies,
Shape batch_shape = {}, Shape batch_shape = {},
Strides A_batch_stride = {}, Strides A_batch_stride = {},
Strides B_batch_stride = {}); Strides B_batch_stride = {}) {
return steel_matmul_axpby<false>(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* const array& c = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ lda,
/* int ldb = */ ldb,
/* bool transpose_a = */ transpose_a,
/* bool transpose_b = */ transpose_b,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ batch_shape,
/* Strides A_batch_stride = */ A_batch_stride,
/* Strides B_batch_stride = */ B_batch_stride);
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) { if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
} }
if (no_copy) { if (no_copy) {
if (x.is_donatable()) { if (x.is_donatable()) {
@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) { if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2]; auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back()); no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
} }
if (no_copy) { if (no_copy) {
if (x.is_donatable()) { if (x.is_donatable()) {

View File

@ -2847,21 +2847,6 @@ array matmul(
"[matmul] Got 0 dimension input. Inputs must " "[matmul] Got 0 dimension input. Inputs must "
"have at least one dimension."); "have at least one dimension.");
} }
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[matmul] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// complex matmul using Karatsuba's Algorithm // complex matmul using Karatsuba's Algorithm
if (a.dtype() == complex64 || b.dtype() == complex64) { if (a.dtype() == complex64 || b.dtype() == complex64) {
@ -2883,6 +2868,22 @@ array matmul(
c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
} }
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = expand_dims(a, 0, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = expand_dims(b, 1, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[matmul] Last dimension of first input with shape " << a.shape()
<< " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
// Type promotion // Type promotion
auto out_type = promote_types(a.dtype(), b.dtype()); auto out_type = promote_types(a.dtype(), b.dtype());
@ -4240,6 +4241,16 @@ array addmm(
"have at least one dimension."); "have at least one dimension.");
} }
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (a.ndim() == 1) { if (a.ndim() == 1) {
// Insert a singleton dim in the beginning // Insert a singleton dim in the beginning
a = expand_dims(a, 0, s); a = expand_dims(a, 0, s);
@ -4257,16 +4268,6 @@ array addmm(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
// Type promotion
auto out_type = result_type(a, b, c);
if (out_type == complex64) {
return add(
multiply(matmul(a, b, s), array(alpha), s),
multiply(array(beta), c, s),
s);
}
if (!issubdtype(out_type, floating)) { if (!issubdtype(out_type, floating)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but " msg << "[addmm] Only real floating point types are supported but "

View File

@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) {
os << val; os << val;
} }
inline void PrintFormatter::print(std::ostream& os, complex64_t val) { inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
os << val; os << val.real();
if (val.imag() >= 0 || std::isnan(val.imag())) {
os << "+" << val.imag() << "j";
} else {
os << "-" << -val.imag() << "j";
}
} }
PrintFormatter& get_global_formatter() { PrintFormatter& get_global_formatter() {

View File

@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase):
c_np = np.matmul(np.array(a).T, b) c_np = np.matmul(np.array(a).T, b)
self.assertTrue(np.allclose(c, c_np)) self.assertTrue(np.allclose(c, c_np))
# Check shapes
a = mx.random.normal((2, 3)).astype(mx.complex64)
b = mx.random.normal((3,))
self.assertEqual((a @ b).shape, (2,))
a = mx.random.normal((2, 3)).astype(mx.complex64)
b = mx.random.normal((3,))
c = mx.random.normal((2,))
self.assertEqual(mx.addmm(c, a, b).shape, (2,))
def test_complex_gemm(self): def test_complex_gemm(self):
M = 16 M = 16
K = 50 K = 50

View File

@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase):
) )
self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x))) self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
def test_complex_power(self):
out = mx.power(mx.array(0j), 2)
self.assertEqual(out.item(), 0j)
out = mx.power(mx.array(0j), float("nan"))
self.assertTrue(mx.isnan(out))
class TestBroadcast(mlx_tests.MLXTestCase): class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self): def test_broadcast_shapes(self):