mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00

Compare commits: 4 commits (8f58ee4348 ... 8b86183ab5)

  8b86183ab5
  8402a2acf4
  fddb6933e1
  d2e0b0465c
@@ -42,6 +42,7 @@ option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
 option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
 option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
+option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)

 # --------------------- Processor tests -------------------------
 message(
@@ -234,12 +235,16 @@ target_include_directories(
 # Do not add mlx_EXPORTS define for shared library.
 set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")

-FetchContent_Declare(
-  fmt
-  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
-  GIT_TAG 10.2.1
-  EXCLUDE_FROM_ALL)
-FetchContent_MakeAvailable(fmt)
+if(USE_SYSTEM_FMT)
+  find_package(fmt REQUIRED)
+else()
+  FetchContent_Declare(
+    fmt
+    GIT_REPOSITORY https://github.com/fmtlib/fmt.git
+    GIT_TAG 10.2.1
+    EXCLUDE_FROM_ALL)
+  FetchContent_MakeAvailable(fmt)
+endif()
 target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)

 if(MLX_BUILD_PYTHON_BINDINGS)
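Note for reviewers: with the new option, configuring with `-DUSE_SYSTEM_FMT=ON` resolves fmt through `find_package(fmt REQUIRED)` instead of downloading the pinned 10.2.1 sources via FetchContent; the default (`OFF`) preserves the previous behavior, and the header-only link target is used either way.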
@@ -194,6 +194,13 @@ struct Power {
       }
       return res;
     } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      if (base.y == 0 && base.x == 0) {
+        if (isnan(exp.x) || isnan(exp.y)) {
+          auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
+          return make_cuFloatComplex(nan, nan);
+        }
+        return make_cuFloatComplex(0.0, 0.0);
+      }
       auto x_theta = atan2f(base.y, base.x);
       auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
       auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
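Why the zero-base guard is needed: the kernel evaluates the power in polar form, and for `base == 0` the log of the magnitude is `-inf`, which turns the phase term into `0 * -inf = nan`. A minimal Python sketch of the failure mode (assuming the usual polar-form phase `exp.x * theta + exp.y * ln_r`; only the magnitude line is visible above):

    import math

    c, d = 2.0, 0.0               # exponent 2 + 0j
    theta = math.atan2(0.0, 0.0)  # angle of base 0j -> 0.0
    ln_r = float("-inf")          # log|0j|: C's logf(0) returns -inf

    mag = math.exp(c * ln_r - d * theta)  # exp(-inf) = 0.0
    phase = c * theta + d * ln_r          # 0.0 + 0.0 * -inf = nan
    print(mag, phase)                     # 0.0 nan -> 0j ** 2 would come out nan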
@@ -155,26 +155,26 @@ void explicit_gemm_conv_group_ND_gpu(
   // Perform gemm
   std::vector<array> copies = {in_unfolded, wt_transpose};
   return steel_matmul_regular(
-      s,
-      d,
-      /* a = */ in_unfolded,
-      /* b = */ wt_transpose,
-      /* c = */ out,
-      /* M = */ implicit_M,
-      /* N = */ implicit_N,
-      /* K = */ implicit_K,
-      /* batch_size_out = */ groups,
-      /* a_cols = */ implicit_K * groups,
-      /* b_cols = */ implicit_K,
-      /* out_cols = */ implicit_N * groups,
-      /* a_transposed = */ false,
-      /* b_transposed = */ true,
-      /* batch_shape = */ {1},
-      /* batch_strides = */ {0},
-      /* A_batch_strides = */ size_t(implicit_K),
-      /* B_batch_strides = */ size_t(implicit_N) * implicit_K,
-      /* matrix_stride_out = */ size_t(implicit_N),
-      /*copies = */ copies);
+      /* const Stream& s = */ s,
+      /* Device& d = */ d,
+      /* const array& a = */ in_unfolded,
+      /* const array& b = */ wt_transpose,
+      /* array& c = */ out,
+      /* int M = */ implicit_M,
+      /* int N = */ implicit_N,
+      /* int K = */ implicit_K,
+      /* int batch_size_out = */ groups,
+      /* int lda = */ implicit_K * groups,
+      /* int ldb = */ implicit_K,
+      /* int ldd = */ implicit_N * groups,
+      /* bool transpose_a = */ false,
+      /* bool transpose_b = */ true,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ {1},
+      /* Strides batch_strides = */ {0},
+      /* int64_t A_batch_strides = */ int64_t(implicit_K),
+      /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
+      /* int64_t matrix_stride_out = */ int64_t(implicit_N));
 }

 void implicit_gemm_conv_2D_gpu(
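The call above runs one GEMM per group over slices of the unfolded input; the updated comments map the old a_cols/b_cols/out_cols names onto the leading dimensions lda/ldb/ldd. A rough numpy model of the indexing (toy sizes, illustrative names):

    import numpy as np

    groups, M, N, K = 2, 4, 3, 5                 # M, N, K as in the call above
    unfolded = np.random.randn(M, K * groups)    # lda = K * groups
    weights = np.random.randn(groups, N, K)      # applied transposed (transpose_b)
    out = np.zeros((M, N * groups))              # ldd = N * groups
    for g in range(groups):                      # batch_size_out = groups
        # A batch stride K, B batch stride N*K, output matrix stride N:
        out[:, g * N:(g + 1) * N] = unfolded[:, g * K:(g + 1) * K] @ weights[g].T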
@@ -297,6 +297,9 @@ Device::Device() {
   device_ = load_device();
   default_library_ = load_default_library(device_);
   arch_ = std::string(device_->architecture()->name()->utf8String());
+  int ag_tens = arch_[arch_.size() - 3] - '0';
+  int ag_ones = arch_[arch_.size() - 2] - '0';
+  arch_gen_ = ag_tens * 10 + ag_ones;
   auto arch = arch_.back();
   switch (arch) {
     case 'p': // phone
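The generation number is read from the two digits before the final device-class letter of the architecture name. The same parse as a Python sketch, assuming strings shaped like "applegpu_g15p" (the exact naming is an assumption here; only the trailing-character layout is implied by the indexing above):

    def parse_arch(arch: str) -> tuple[int, str]:
        gen = (ord(arch[-3]) - ord("0")) * 10 + (ord(arch[-2]) - ord("0"))
        return gen, arch[-1]  # trailing letter is the device class, e.g. 'p' = phone

    print(parse_arch("applegpu_g15p"))  # (15, 'p')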
@@ -177,6 +177,10 @@ class Device {
     return arch_;
   }

+  int get_architecture_gen() const {
+    return arch_gen_;
+  }
+
   void new_queue(int index);

   MTL::CommandQueue* get_queue(Stream stream);
@@ -268,6 +272,7 @@ class Device {
       library_kernels_;
   const MTL::ResidencySet* residency_set_{nullptr};
   std::string arch_;
+  int arch_gen_;
   int max_ops_per_buffer_;
   int max_mb_per_buffer_;
 };
@@ -235,6 +235,13 @@ struct Power {

   template <>
   complex64_t operator()(complex64_t x, complex64_t y) {
+    if (x.real == 0 && x.imag == 0) {
+      if (metal::isnan(y.real) || metal::isnan(y.imag)) {
+        auto nan = metal::numeric_limits<float>::quiet_NaN();
+        return {nan, nan};
+      }
+      return {0.0, 0.0};
+    }
     auto x_theta = metal::atan2(x.imag, x.real);
     auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
     auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
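This is the same zero-base guard added to the CUDA kernel above, expressed with the Metal stdlib (`metal::isnan`, `metal::numeric_limits`); see the note after that hunk for why the polar form otherwise produces nan.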
@@ -33,8 +33,8 @@ template <
     device T* D [[buffer(3)]],
     const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
-    const constant int* batch_shape [[buffer(6)]],
-    const constant int64_t* batch_strides [[buffer(7)]],
+    const constant int* batch_shape [[buffer(6), function_constant(has_batch)]],
+    const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
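Gating buffers 6 and 7 on the `has_batch` function constant means kernel specializations compiled with `has_batch == false` omit those arguments entirely, so unbatched GEMM launches no longer need to bind dummy batch shape/stride buffers.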
File diff suppressed because it is too large
@@ -6,7 +6,34 @@

 namespace mlx::core {

-void steel_matmul_regular(
+template <bool CHECK_AB = true>
+void steel_matmul_regular_axpby(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    int ldd,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape,
+    Strides batch_strides,
+    int64_t A_batch_stride,
+    int64_t B_batch_stride,
+    int64_t matrix_stride_out,
+    int64_t C_batch_stride = 0,
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul_regular(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -21,14 +48,61 @@ void steel_matmul_regular(
     int ldd,
     bool transpose_a,
     bool transpose_b,
+    std::vector<array>& copies,
     Shape batch_shape,
     Strides batch_strides,
     int64_t A_batch_stride,
     int64_t B_batch_stride,
-    int64_t matrix_stride_out,
-    std::vector<array>& copies);
+    int64_t matrix_stride_out) {
+  return steel_matmul_regular_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* int ldd = */ ldd,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides batch_strides = */ batch_strides,
+      /* int64_t A_batch_stride = */ A_batch_stride,
+      /* int64_t B_batch_stride = */ B_batch_stride,
+      /* int64_t matrix_stride_out = */ matrix_stride_out);
+}

-void steel_matmul(
+template <bool CHECK_AB = true>
+void steel_matmul_axpby(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies,
+    Shape batch_shape = {},
+    Strides A_batch_stride = {},
+    Strides B_batch_stride = {},
+    Strides C_batch_stride = {},
+    float alpha = 1.0f,
+    float beta = 0.0f);
+
+inline void steel_matmul(
     const Stream& s,
     metal::Device& d,
     const array& a,
@@ -45,6 +119,26 @@ void steel_matmul(
     std::vector<array>& copies,
     Shape batch_shape = {},
     Strides A_batch_stride = {},
-    Strides B_batch_stride = {});
+    Strides B_batch_stride = {}) {
+  return steel_matmul_axpby<false>(
+      /* const Stream& s = */ s,
+      /* metal::Device& d = */ d,
+      /* const array& a = */ a,
+      /* const array& b = */ b,
+      /* const array& c = */ b,
+      /* array& out = */ out,
+      /* int M = */ M,
+      /* int N = */ N,
+      /* int K = */ K,
+      /* int batch_size_out = */ batch_size_out,
+      /* int lda = */ lda,
+      /* int ldb = */ ldb,
+      /* bool transpose_a = */ transpose_a,
+      /* bool transpose_b = */ transpose_b,
+      /* std::vector<array>& copies = */ copies,
+      /* Shape batch_shape = */ batch_shape,
+      /* Strides A_batch_stride = */ A_batch_stride,
+      /* Strides B_batch_stride = */ B_batch_stride);
+}

 } // namespace mlx::core
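For review: the `_axpby` variants generalize the GEMM entry points with an epilogue `out = alpha * (a @ b) + beta * c`, and the old names become inline wrappers that instantiate `CHECK_AB = false` and pass `b` again as a dummy `c`. A numpy reference for the intended contract (illustrative sketch, not the MLX API):

    import numpy as np

    def gemm_axpby(a, b, c, alpha=1.0, beta=0.0):
        # Reference semantics of the *_axpby epilogue above.
        return alpha * (a @ b) + beta * c

    a, b, c = np.random.randn(4, 5), np.random.randn(5, 3), np.random.randn(4, 3)
    plain = gemm_axpby(a, b, np.zeros_like(c))        # beta = 0: plain matmul path
    fused = gemm_axpby(a, b, c, alpha=2.0, beta=0.5)  # fused addmm-style epilogue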
@@ -26,7 +26,7 @@ void RMSNorm::eval_gpu(
   bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
   if (no_copy && x.ndim() > 1) {
     auto s = x.strides()[x.ndim() - 2];
-    no_copy &= (s == 0 || s == x.shape().back());
+    no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
   }
   if (no_copy) {
     if (x.is_donatable()) {
@@ -227,7 +227,7 @@ void LayerNorm::eval_gpu(
   bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
   if (no_copy && x.ndim() > 1) {
     auto s = x.strides()[x.ndim() - 2];
-    no_copy &= (s == 0 || s == x.shape().back());
+    no_copy &= (s == 0 || s == x.shape().back() || x.shape(-2) == 1);
   }
   if (no_copy) {
     if (x.is_donatable()) {
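The relaxed check lets a single-row view take the fast path: when `x.shape(-2) == 1` the stride of that axis is never used to address a second row, so it does not have to equal the row length. A numpy illustration of the layout in question (element strides computed from numpy's byte strides):

    import numpy as np

    x = np.random.randn(8, 32)
    v = x[2:3, :16]                 # shape (1, 16): last axis contiguous
    s = v.strides[0] // v.itemsize  # 32, the parent's row stride, not 16
    # Old check: s == v.shape[-1] fails -> copy. New clause: v.shape[-2] == 1 -> no copy.
    print(v.shape, s)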
mlx/ops.cpp (51 changed lines)
@@ -2847,21 +2847,6 @@ array matmul(
         "[matmul] Got 0 dimension input. Inputs must "
         "have at least one dimension.");
   }
-  if (a.ndim() == 1) {
-    // Insert a singleton dim in the beginning
-    a = expand_dims(a, 0, s);
-  }
-  if (b.ndim() == 1) {
-    // Insert a singleton dim at the end
-    b = expand_dims(b, 1, s);
-  }
-  if (a.shape(-1) != b.shape(-2)) {
-    std::ostringstream msg;
-    msg << "[matmul] Last dimension of first input with shape " << a.shape()
-        << " must match second to last dimension of"
-        << " second input with shape " << b.shape() << ".";
-    throw std::invalid_argument(msg.str());
-  }

   // complex matmul using Karatsuba's Algorithm
   if (a.dtype() == complex64 || b.dtype() == complex64) {
@@ -2883,6 +2868,22 @@ array matmul(
         c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
   }

+  if (a.ndim() == 1) {
+    // Insert a singleton dim in the beginning
+    a = expand_dims(a, 0, s);
+  }
+  if (b.ndim() == 1) {
+    // Insert a singleton dim at the end
+    b = expand_dims(b, 1, s);
+  }
+  if (a.shape(-1) != b.shape(-2)) {
+    std::ostringstream msg;
+    msg << "[matmul] Last dimension of first input with shape " << a.shape()
+        << " must match second to last dimension of"
+        << " second input with shape " << b.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
   // Type promotion
   auto out_type = promote_types(a.dtype(), b.dtype());

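Moving the 1-D fixups after the complex dispatch means the complex branch sees the original ranks, and the real matmuls it lowers to perform their own singleton insertion and squeezing, fixing the output shapes checked in test_blas.py below. For reference, the three-real-multiplication decomposition the "Karatsuba" comment refers to, as a numpy sketch:

    import numpy as np

    def complex_matmul(a, b):
        # (ar + i*ai) @ (br + i*bi) with three real matmuls instead of four.
        ar, ai, br, bi = a.real, a.imag, b.real, b.imag
        t1, t2 = ar @ br, ai @ bi
        t3 = (ar + ai) @ (br + bi)
        return (t1 - t2) + 1j * (t3 - t1 - t2)

    a = np.random.randn(2, 3) + 1j * np.random.randn(2, 3)
    b = np.random.randn(3) + 1j * np.random.randn(3)
    np.testing.assert_allclose(complex_matmul(a, b), a @ b)  # 1-D b -> shape (2,)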
@@ -4240,6 +4241,16 @@ array addmm(
         "have at least one dimension.");
   }

+  // Type promotion
+  auto out_type = result_type(a, b, c);
+
+  if (out_type == complex64) {
+    return add(
+        multiply(matmul(a, b, s), array(alpha), s),
+        multiply(array(beta), c, s),
+        s);
+  }
+
   if (a.ndim() == 1) {
     // Insert a singleton dim in the beginning
     a = expand_dims(a, 0, s);
@@ -4257,16 +4268,6 @@ array addmm(
     throw std::invalid_argument(msg.str());
   }

-  // Type promotion
-  auto out_type = result_type(a, b, c);
-
-  if (out_type == complex64) {
-    return add(
-        multiply(matmul(a, b, s), array(alpha), s),
-        multiply(array(beta), c, s),
-        s);
-  }
-
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
     msg << "[addmm] Only real floating point types are supported but "
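addmm gets the same treatment: for complex types it now decomposes to `alpha * (a @ b) + beta * c` before any singleton-dimension rewriting, so 1-D operands flow through the now complex-capable matmul. A quick usage sketch with the Python bindings (keyword names assumed from the test style below):

    import mlx.core as mx

    a = mx.random.normal((2, 3)).astype(mx.complex64)
    b = mx.random.normal((3, 4)).astype(mx.complex64)
    c = mx.random.normal((2, 4)).astype(mx.complex64)
    out = mx.addmm(c, a, b, alpha=2.0, beta=0.5)  # == 2*(a @ b) + 0.5*c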
@@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) {
   os << val;
 }
 inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
-  os << val;
+  os << val.real();
+  if (val.imag() >= 0 || std::isnan(val.imag())) {
+    os << "+" << val.imag() << "j";
+  } else {
+    os << "-" << -val.imag() << "j";
+  }
 }

 PrintFormatter& get_global_formatter() {
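Complex scalars now print in Python's `a+bj` form instead of the stream's default rendering, with nan kept on the `+` branch (any comparison against nan is false). A Python mimic of the formatting logic (a sketch, not the MLX formatter itself):

    import math

    def fmt(z: complex) -> str:
        if z.imag >= 0 or math.isnan(z.imag):
            return f"{z.real}+{z.imag}j"
        return f"{z.real}-{-z.imag}j"

    print(fmt(1 - 2j))                    # 1.0-2.0j
    print(fmt(complex(0, float("nan"))))  # 0.0+nanj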
@@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase):
         c_np = np.matmul(np.array(a).T, b)
         self.assertTrue(np.allclose(c, c_np))

+        # Check shapes
+        a = mx.random.normal((2, 3)).astype(mx.complex64)
+        b = mx.random.normal((3,))
+        self.assertEqual((a @ b).shape, (2,))
+
+        a = mx.random.normal((2, 3)).astype(mx.complex64)
+        b = mx.random.normal((3,))
+        c = mx.random.normal((2,))
+        self.assertEqual(mx.addmm(c, a, b).shape, (2,))
+
     def test_complex_gemm(self):
         M = 16
         K = 50
@@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase):
         )
         self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))

+    def test_complex_power(self):
+        out = mx.power(mx.array(0j), 2)
+        self.assertEqual(out.item(), 0j)
+
+        out = mx.power(mx.array(0j), float("nan"))
+        self.assertTrue(mx.isnan(out))
+

 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):