Faster complex matmul (#2571)
@@ -142,9 +142,7 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
     t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)

     c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
-    c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
-        np.float32
-    )
+    c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)

     atol = 1e-5 if np_dtype == np.float32 else 1e-4

@@ -163,7 +161,7 @@ def get_gflop_count(B, M, N, K):
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

-    dtypes = ("float32", "float16")
+    dtypes = ("float32", "float16", "complex64")
    transposes = ("nn", "nt", "tn")
    shapes = (
        (16, 234, 768, 3072),

@@ -187,7 +185,7 @@ if __name__ == "__main__":
            diff = gflops_mx / gflops_pt - 1.0

            print(
-                f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
+                f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
            )
            if gflops_pt >= 2.0 * gflops_mx:
                print("ATTENTION ^^^^^^^")

@@ -196,7 +196,7 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):


for transpose in (False, True):
-    for dtype in ("float32", "float16"):
+    for dtype in ("float32", "float16", "complex64"):
        fig, axs = plt.subplots(
            len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
        )

@@ -215,7 +215,7 @@ for transpose in (False, True):
        fig.suptitle(f"{device_name}: {dtype} {op_name}")
        fig.savefig(
            os.path.join(
-                results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
+                results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
            )
        )
        plt.close(fig)

@@ -88,4 +88,47 @@ void matmul<double>(
  }
}

+template <>
+void matmul<complex64_t>(
+    const complex64_t* a,
+    const complex64_t* b,
+    complex64_t* out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    size_t ldc,
+    float alpha,
+    float beta,
+    size_t batch_size,
+    const Shape& a_shape,
+    const Strides& a_strides,
+    const Shape& b_shape,
+    const Strides& b_strides) {
+  auto ndim = a_shape.size();
+  size_t M = a_shape[ndim - 2];
+  size_t N = b_shape[ndim - 1];
+  size_t K = a_shape[ndim - 1];
+  auto calpha = static_cast<complex64_t>(alpha);
+  auto cbeta = static_cast<complex64_t>(beta);
+
+  for (int i = 0; i < batch_size; ++i) {
+    cblas_cgemm(
+        CblasRowMajor,
+        a_transposed ? CblasTrans : CblasNoTrans, // transA
+        b_transposed ? CblasTrans : CblasNoTrans, // transB
+        M,
+        N,
+        K,
+        &calpha,
+        a + elem_to_loc(M * K * i, a_shape, a_strides),
+        lda,
+        b + elem_to_loc(K * N * i, b_shape, b_strides),
+        ldb,
+        &cbeta,
+        out + M * N * i,
+        ldc);
+  }
+}
+
} // namespace mlx::core

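Unlike sgemm/dgemm, the complex BLAS routines take alpha and beta by pointer, which is why the specialization materializes calpha/cbeta above. The same cgemm entry point is reachable from Python through SciPy; a minimal sketch, assuming SciPy is installed (this is an editor's illustration, not MLX code):

    import numpy as np
    from scipy.linalg.blas import cgemm  # single-precision complex GEMM

    rng = np.random.default_rng(0)
    a = (rng.standard_normal((4, 3)) + 1j * rng.standard_normal((4, 3))).astype(np.complex64)
    b = (rng.standard_normal((3, 5)) + 1j * rng.standard_normal((3, 5))).astype(np.complex64)

    # alpha/beta are complex scalars, matching the calpha/cbeta casts above
    c = cgemm(alpha=1.0, a=a, b=b)
    assert np.allclose(c, a @ b, atol=1e-5)
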
@@ -108,6 +108,9 @@ void matmul_general(
  } else if (out.dtype() == float64) {
    matmul_dispatch<double>(
        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
+  } else if (out.dtype() == complex64) {
+    matmul_dispatch<complex64_t>(
+        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta, stream);
  } else {
    throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
  }

@@ -50,8 +50,10 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
      return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
                                           : CUBLAS_COMPUTE_32F;
    case float64:
-    case complex64:
      return CUBLAS_COMPUTE_64F;
+    case complex64:
+      return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
+                                           : CUBLAS_COMPUTE_32F;
    default:
      throw std::runtime_error(fmt::format(
          "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));

@@ -126,12 +128,13 @@ CublasGemm::CublasGemm(
      N_(b_cols) {
  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;

-  auto scale_type = dtype_to_cublas_type(dtype);
+  scale_type_ = dtype_to_cublas_type(dtype);
  if (dtype == bfloat16 || dtype == float16) {
-    scale_type = CUDA_R_32F;
+    scale_type_ = CUDA_R_32F;
  }

  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
-      &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
+      &matmul_desc_, dtype_to_compute_type(dtype), scale_type_));
  int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
      matmul_desc_,

@@ -352,6 +355,16 @@ void CublasGemm::execute(
    }
  }

+  const void* alpha_ptr = &alpha;
+  const void* beta_ptr = &beta;
+  complex64_t alpha_c, beta_c;
+  if (scale_type_ == CUDA_C_32F) {
+    alpha_c = complex64_t{alpha, 0.0f};
+    beta_c = complex64_t{beta, 0.0f};
+    alpha_ptr = &alpha_c;
+    beta_ptr = &beta_c;
+  }
+
  void* workspace_ptr = nullptr;
  if (heuristic_.workspaceSize > 0) {
    // Ensure workspace is 256-byte aligned

@@ -368,12 +381,12 @@ void CublasGemm::execute(
  CHECK_CUBLAS_ERROR(cublasLtMatmul(
      handle_,
      matmul_desc_,
-      &alpha,
+      alpha_ptr,
      b, // a and b are swapped
      a_desc_,
      a,
      b_desc_,
-      &beta,
+      beta_ptr,
      c ? c : out,
      c ? c_desc_ : out_desc_,
      out,

@@ -115,6 +115,7 @@ class CublasGemm {

  uint64_t M_;
  uint64_t N_;
+  cudaDataType_t scale_type_;
  cublasLtMatmulPreference_t pref_{nullptr};
  cublasLtHandle_t handle_{nullptr};
  cublasLtMatmulDesc_t matmul_desc_{nullptr};

@@ -13,6 +13,37 @@ namespace cg = cooperative_groups;

static constexpr int rows_per_block = 8;

+// Accumulator type selection per input element type T.
+template <typename T>
+struct GemvAccType {
+  using type = T;
+};
+
+template <>
+struct GemvAccType<__half> {
+  using type = float;
+};
+
+template <>
+struct GemvAccType<__nv_bfloat16> {
+  using type = float;
+};
+
+template <>
+struct GemvAccType<float> {
+  using type = float;
+};
+
+template <>
+struct GemvAccType<double> {
+  using type = double;
+};
+
+template <>
+struct GemvAccType<cu::complex64_t> {
+  using type = cu::complex64_t;
+};
+
template <typename T, int rows_per_block, int n_per_thread>
__device__ void
gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {

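The trait keeps half and bfloat16 inputs accumulating in float while double and complex keep their own width. A quick numpy sketch of why low-precision accumulation drifts (editor's illustration, not MLX code):

    import numpy as np

    x = np.full(4096, 0.1, dtype=np.float16)

    # Once the running float16 sum is large relative to each addend, the
    # additions round away; a float32 accumulator stays near the exact 409.6.
    print(x.sum(dtype=np.float16))  # saturates far below 409.6
    print(x.sum(dtype=np.float32))  # ~409.6
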
@@ -24,7 +55,8 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
  int row = g_idx.x * rows_per_block + t_idx.y;

  if (row < rows) {
-    float sum = 0.0f;
+    using Acc = typename GemvAccType<T>::type;
+    Acc sum = Acc(0);
    for (int col = n_per_thread * warp.thread_rank(); col < cols;
         col += (WARP_SIZE * n_per_thread)) {
      auto local_mat =

@@ -32,12 +64,11 @@ gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) {
      auto local_vec = unsafe_load_vector<n_per_thread>(vec + col, 0);
#pragma unroll
      for (int j = 0; j < n_per_thread; ++j) {
-        sum +=
-            static_cast<float>(local_mat[j]) * static_cast<float>(local_vec[j]);
+        sum += static_cast<Acc>(local_mat[j]) * static_cast<Acc>(local_vec[j]);
      }
    }

-    sum = cg::reduce(warp, sum, cg::plus<float>{});
+    sum = cg::reduce(warp, sum, cg::plus<Acc>{});
    if (warp.thread_rank() == 0) {
      out[row] = static_cast<T>(sum);
    }

@@ -107,7 +138,7 @@ void gemv(
  encoder.set_input_array(a);
  encoder.set_input_array(b);
  encoder.set_output_array(out);
-  dispatch_float_types(out.dtype(), "gemv", [&](auto type_tag) {
+  dispatch_inexact_types(out.dtype(), "gemv", [&](auto type_tag) {
    using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
    dim3 block_dims{WARP_SIZE, rows_per_block};
    const DataType* mat;

@@ -93,6 +93,10 @@ void gemm_and_bias(
      a_batch_strides.back(),
      b_batch_strides.back());
  if (bias) {
+    if (a.dtype() == complex64) {
+      throw std::runtime_error(
+          "[gemm_and_bias] complex64 bias epilogue isn't supported in cublasLtMatmul.");
+    }
    gemm.set_bias(encoder, *bias);
  }
  gemm.run(

@@ -157,7 +161,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  /////////////////////////////////////////////////////////////////////////////
  // Dispatch to GEMM with epilogue or AddMM

-  if (beta_ == 1 && c.strides(-1) == 1 && c.data_size() == out.shape(-1)) {
+  if (beta_ == 1 && a.dtype() != complex64 && c.strides(-1) == 1 &&
+      c.data_size() == out.shape(-1)) {
    out.set_data(allocator::malloc(out.nbytes()));
    gemm_and_bias(
        encoder,

@@ -104,6 +104,27 @@ constexpr bool operator==(complex64_t a, complex64_t b) {
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return {a.real + b.real, a.imag + b.imag};
}

+constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
+  a.real += b.real;
+  a.imag += b.imag;
+  return a;
+}
+
+constexpr threadgroup complex64_t& operator+=(
+    threadgroup complex64_t& a,
+    complex64_t b) {
+  a.real += b.real;
+  a.imag += b.imag;
+  return a;
+}
+
+constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
+  a.real += b.real;
+  a.imag += b.imag;
+  return a;
+}
+
constexpr complex64_t operator+(float a, complex64_t b) {
  return {a + b.real, b.imag};
}

@@ -15,6 +15,15 @@ using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

+template <typename U>
+struct DefaultAccT {
+  using type = float;
+};
+template <>
+struct DefaultAccT<complex64_t> {
+  using type = complex64_t;
+};
+
template <
    typename T,
    const int BM, /* Threadgroup rows (in simdgroups) */

@@ -24,8 +33,10 @@ template <
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
-    typename AccT = float>
+    typename AccT = typename DefaultAccT<T>::type>
struct GEMVKernel {
+  using acc_type = AccT;
+
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

@@ -246,8 +257,10 @@ template <
    const int TM, /* Thread rows (in elements) */
    const int TN, /* Thread cols (in elements) */
    const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
-    typename AccT = float>
+    typename AccT = typename DefaultAccT<T>::type>
struct GEMVTKernel {
+  using acc_type = AccT;
+
  MLX_MTL_CONST int threadsM = BM * SM;
  MLX_MTL_CONST int threadsN = BN * SN;

@@ -453,7 +466,7 @@ template <
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
-  threadgroup float tgp_memory
+  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  // Update batch offsets

@@ -530,6 +543,7 @@ template <
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_blocks(complex64, complex64_t);

template <
    typename T,

@@ -565,7 +579,7 @@ template <
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
-  threadgroup float tgp_memory
+  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  uint32_t indx_vec;

@@ -636,6 +650,7 @@ template <
instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_bs_blocks(complex64, complex64_t);

///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication

@@ -672,7 +687,7 @@ template <
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
-  threadgroup float tgp_memory
+  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  // Update batch offsets

@@ -738,7 +753,8 @@ template <
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
-instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
+instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on

template <
    typename T,

@@ -773,8 +789,8 @@ template <
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
-  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false, float>;
-  threadgroup float tgp_memory
+  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
+  threadgroup typename gemv_kernel::acc_type tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  uint32_t indx_vec;

@@ -848,4 +864,5 @@ template <
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
-instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on
+instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on

@@ -3,9 +3,8 @@
#include <metal_stdlib>

// clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
-
#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"

@@ -3,9 +3,8 @@
#include <metal_stdlib>

// clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
-
#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"

@@ -23,10 +23,12 @@
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
-  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
+  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
+  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 8, 4, 1)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);
// clang-format on

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.

// clang-format off
-#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
+#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"

#define instantiate_gemm( \

@@ -60,6 +60,7 @@
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
+instantiate_gemm_shapes_helper(complex64, complex64_t, complex64, complex64_t);

#define instantiate_accum(oname, otype, aname, atype) \
  instantiate_kernel( \

@@ -71,4 +72,5 @@ instantiate_gemm_shapes_helper(float32, float, float32, float);

instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
-instantiate_accum(float32, float, float32, float); // clang-format on
+instantiate_accum(float32, float, float32, float);
+instantiate_accum(complex64, complex64_t, complex64, complex64_t); // clang-format on

@@ -421,6 +421,16 @@ METAL_FUNC void tile_matmad(
    }
  }

+template <typename InT>
+struct TransformNone<complex64_t, InT> {
+  static METAL_FUNC complex64_t apply(complex64_t x) {
+    return x;
+  }
+  static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) {
+    return x;
+  }
+};
+
template <
    typename T,
    typename U,

@@ -731,5 +741,406 @@ struct BlockMMA {
  }
};

+template <
+    typename U,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose_a,
+    bool transpose_b,
+    short lda_tgp,
+    short ldb_tgp,
+    typename AccumType,
+    typename Epilogue>
+struct BlockMMA<
+    complex64_t,
+    U,
+    BM,
+    BN,
+    BK,
+    WM,
+    WN,
+    transpose_a,
+    transpose_b,
+    lda_tgp,
+    ldb_tgp,
+    AccumType,
+    Epilogue> {
+  static_assert(
+      metal::is_same_v<AccumType, float>,
+      "BlockMMA<complex64_t,...> expects float accumulators");
+  static_assert(
+      metal::is_same_v<U, complex64_t>,
+      "For complex BlockMMA, U must be complex64_t; use a different epilogue for projections");
+  // MMAFrag size
+  STEEL_CONST short kFragSize = 8;
+  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
+
+  // Warp tile simdgroup matrix strides along M
+  STEEL_CONST short TM_stride = kFragSize * WM;
+  // Warp tile simdgroup matrix strides along N
+  STEEL_CONST short TN_stride = kFragSize * WN;
+
+  // Warp tile size along M
+  STEEL_CONST short TM = BM / (kFragSize * WM);
+  // Warp tile size along N
+  STEEL_CONST short TN = BN / (kFragSize * WN);
+
+  // Threadgroup A strides
+  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
+  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
+
+  // Threadgroup B strides
+  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
+  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
+
+  // Threadgroup strides along K
+  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
+  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
+
+  // When indexing complex as float[2]
+  STEEL_CONST short A_str_m_f = A_str_m * 2;
+  STEEL_CONST short A_str_k_f = A_str_k * 2;
+  STEEL_CONST short B_str_k_f = B_str_k * 2;
+  STEEL_CONST short B_str_n_f = B_str_n * 2;
+  STEEL_CONST short tile_stride_a_f = tile_stride_a * 2;
+  STEEL_CONST short tile_stride_b_f = tile_stride_b * 2;
+
+  // Accumulators (real/imag)
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_r;
+  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile_i;
+
+  // Offsets within threadgroup
+  short sm, sn;
+  short As_offset, Bs_offset;
+
+  /* Constructor */
+  METAL_FUNC BlockMMA(
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
+    // Determine thread position in simdgroup matrix
+    short tm = kFragSize * (simd_group_id / WN);
+    short tn = kFragSize * (simd_group_id % WN);
+
+    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
+    sm = simd_coord.y;
+    sn = simd_coord.x;
+
+    // Determine thread and simdgroup offset
+    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K)
+    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N)
+
+    sm += tm;
+    sn += tn;
+  }
+
+  /* Karatsuba MMA: 3 real MMAs per K-chunk */
+  METAL_FUNC void mma(
+      const threadgroup complex64_t* As,
+      const threadgroup complex64_t* Bs) {
+    // Adjust for simdgroup and thread location
+    As += As_offset;
+    Bs += Bs_offset;
+    threadgroup const float* As_f =
+        reinterpret_cast<threadgroup const float*>(As);
+    threadgroup const float* Bs_f =
+        reinterpret_cast<threadgroup const float*>(Bs);
+
+    // Iterate over BK in blocks of kFragSize
+    STEEL_PRAGMA_UNROLL
+    for (short kk = 0; kk < BK; kk += kFragSize) {
+      simdgroup_barrier(mem_flags::mem_none);
+
+      MMATile<AccumType, TM, 1, MMAFrag_acc_t> Ar, Ai;
+      Ar.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 0);
+      Ai.template load<float, WM, 1, A_str_m_f, A_str_k_f>(As_f + 1);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      MMATile<AccumType, 1, TN, MMAFrag_acc_t> Br, Bi;
+      Br.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 0);
+      Bi.template load<float, 1, WN, B_str_k_f, B_str_n_f>(Bs_f + 1);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi)
+      MMATile<AccumType, TM, TN, MMAFrag_acc_t> P, Q, R;
+
+      tile_matmad(P, Ar, Br, P);
+      tile_matmad(Q, Ai, Bi, Q);
+
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i)
+        Ar.elems()[i] += Ai.elems()[i];
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i)
+        Br.elems()[i] += Bi.elems()[i];
+
+      tile_matmad(R, Ar, Br, R);
+
+      // C_r += P - Q ; C_i += R - P - Q
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) {
+        const auto p = P.elems()[i];
+        const auto q = Q.elems()[i];
+        const auto r = R.elems()[i];
+        Ctile_r.elems()[i] += (p - q);
+        Ctile_i.elems()[i] += (r - p - q);
+      }
+
+      // Progress to next simdgroup tile
+      As_f += tile_stride_a_f;
+      Bs_f += tile_stride_b_f;
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(device U* D, const int ldd) {
+    // Adjust for simdgroup and thread location
+    D += sm * ldd + sn;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        thread const auto& r = Ctile_r.frag_at(i, j);
+        thread const auto& im = Ctile_i.frag_at(i, j);
+        int off = (i * TM_stride) * ldd + (j * TN_stride);
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+          D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
+    D += sm * ldd + sn;
+    start -= short2(sn, sm);
+    stop -= short2(sn, sm);
+
+    if (stop.y <= 0 || stop.x <= 0)
+      return;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; ++i) {
+      const int row = i * TM_stride;
+      if (row >= start.y && row < stop.y) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; ++j) {
+          const int off = row * ldd + (j * TN_stride);
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) {
+            const int col = j * TN_stride + k;
+            if (col >= start.x && col < stop.x) {
+              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void
+  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
+    D += sm * ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < TN; j++) {
+          int off = (i * TM_stride) * ldd + (j * TN_stride);
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[off + k] = Epilogue::apply(complex64_t(r[k], im[k]));
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename UnaryEpilogue>
+  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) {
+      complex64_t out = epilogue_op.apply(
+          complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i]));
+      Ctile_r.elems()[i] = out.real;
+      Ctile_i.elems()[i] = out.imag;
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread auto& r = Ctile_r.frag_at(i, j);
+        thread auto& im = Ctile_i.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) {
+          complex64_t out = epilogue_op.apply(
+              complex64_t(r[k], im[k]), C[offset_c + k * fdc]);
+          r[k] = out.real;
+          im[k] = out.imag;
+        }
+      }
+    }
+  }
+
+  /* Apply epilogue */
+  template <typename BinaryEpilogue>
+  METAL_FUNC void apply_epilogue_safe(
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const BinaryEpilogue& epilogue_op) {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread auto& r = Ctile_r.frag_at(i, j);
+        thread auto& im = Ctile_i.frag_at(i, j);
+        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+
+        constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+        complex64_t tmp[kelems];
+
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          if ((j * TN_stride + k) < dst_tile_dims.x &&
+              (i * TM_stride) < dst_tile_dims.y) {
+            tmp[k] = C[offset_c + k * fdc];
+          } else {
+            tmp[k] = complex64_t(0.0f, 0.0f);
+          }
+        }
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]);
+          r[k] = out.real;
+          im[k] = out.imag;
+        }
+      }
+    }
+  }
+
+  /* Store results from simdgroup_matrix results into device memory */
+  METAL_FUNC void store_result(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+
+    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+
+    // Loop over all simdgroup tiles
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < TM; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < TN; j++) {
+        // Get accumulated result and associated offset in Cr, Ci
+        thread const auto& r = Ctile_r.frag_at(i, j);
+        thread const auto& im = Ctile_i.frag_at(i, j);
+        int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+        int off_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+        // Apply epilogue
+        STEEL_PRAGMA_UNROLL
+        for (short k = 0; k < kelems; k++) {
+          D[off_d + k] =
+              epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]);
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void store_result_safe(
+      device U* D,
+      const int ldd,
+      const device U* C,
+      const int ldc,
+      const int fdc,
+      short2 dst_tile_dims,
+      thread const Epilogue& epilogue_op) const {
+    // Adjust for simdgroup and thread location
+    C += (sm)*ldc + (sn)*fdc;
+    D += (sm)*ldd + sn;
+    dst_tile_dims -= short2(sn, sm);
+
+    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
+      return;
+
+    constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag;
+
+    STEEL_PRAGMA_UNROLL
+    for (int i = 0; i < TM; i++) {
+      if (i * TM_stride < dst_tile_dims.y) {
+        STEEL_PRAGMA_UNROLL
+        for (int j = 0; j < TN; j++) {
+          // Get accumulated result and associated offset in Cr, Ci
+          thread const auto& r = Ctile_r.frag_at(i, j);
+          thread const auto& im = Ctile_i.frag_at(i, j);
+          int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
+          int off_d = (i * TM_stride) * ldd + (j * TN_stride);
+
+          // Apply epilogue
+          STEEL_PRAGMA_UNROLL
+          for (short k = 0; k < kelems; k++) {
+            if ((j * TN_stride + k) < dst_tile_dims.x) {
+              D[off_d + k] = epilogue_op.apply(
+                  complex64_t(r[k], im[k]), C[off_c + k * fdc]);
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
} // namespace steel
} // namespace mlx

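The mma() above is the heart of the speedup: Karatsuba's trick replaces the schoolbook complex product's 4 real multiply-accumulates with 3 (P, Q, R), recovering real = P - Q and imag = R - P - Q. A minimal numpy sketch of the same identity at matrix granularity (editor's illustration, not part of the commit):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((64, 32)) + 1j * rng.standard_normal((64, 32))
    b = rng.standard_normal((32, 48)) + 1j * rng.standard_normal((32, 48))

    ar, ai = a.real, a.imag
    br, bi = b.real, b.imag

    P = ar @ br                      # 3 real matmuls instead of 4
    Q = ai @ bi
    R = (ar + ai) @ (br + bi)

    c = (P - Q) + 1j * (R - P - Q)   # R - P - Q == ar@bi + ai@br
    assert np.allclose(c, a @ b)
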
@@ -48,7 +48,8 @@ struct TransformAxpby {
  }

  METAL_FUNC OutT apply(InT x, OutT c) const {
-    return static_cast<OutT>(x * alpha + (beta * c));
+    return static_cast<OutT>(
+        x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
  }
};

@@ -86,7 +86,13 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {

#define GEMM_TPARAM_MACRO(devc) \
  if (devc == 'g' || devc == 'p') { /* Small device */ \
-    if (!transpose_a && transpose_b) { /* nt */ \
+    if (out.dtype() == complex64) { \
+      bm = 64; \
+      bn = 32; \
+      bk = 8; \
+      wm = 4; \
+      wn = 1; \
+    } else if (!transpose_a && transpose_b) { /* nt */ \
      bm = 64; \
      bn = 32; \
      bk = 32; \

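One plausible reading of the smaller bk for the complex tile: a complex64 element is 8 bytes, twice a float32, so shrinking the K-dimension of the block keeps the A and B threadgroup tiles within a comparable byte budget. Rough arithmetic under that assumption (ignoring padding; editor's illustration):

    bm, bn = 64, 32

    bk_f32 = 32      # float32 nt tile
    bytes_f32 = (bm * bk_f32 + bk_f32 * bn) * 4   # 12288 bytes

    bk_c64 = 8       # complex64 tile, 8 bytes per element
    bytes_c64 = (bm * bk_c64 + bk_c64 * bn) * 8   # 6144 bytes

    print(bytes_f32, bytes_c64)
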
@@ -362,7 +368,11 @@ void steel_gemm_splitk_axpby(
  int gemm_k_iterations = (K / bk) / split_k_partitions;
  int split_k_partition_size = gemm_k_iterations * bk;

-  array C_split({split_k_partitions, M, N}, float32, nullptr, {});
+  array C_split(
+      {split_k_partitions, M, N},
+      issubdtype(out.dtype(), complexfloating) ? complex64 : float32,
+      nullptr,
+      {});
  C_split.set_data(allocator::malloc(C_split.nbytes()));
  copies.push_back(C_split);

@@ -807,9 +817,8 @@ inline void gemv(

void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
-  if (!issubdtype(out.dtype(), floating)) {
-    throw std::runtime_error(
-        "[matmul] Does not yet support non-floating point types.");
+  if (!issubdtype(out.dtype(), inexact)) {
+    throw std::runtime_error("[matmul] dtype must be inexact.");
  }
  auto& s = stream();
  auto& d = metal::device(s.device);

@@ -1338,7 +1347,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
        << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
        << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
        << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
-        << "aligned" << "_K_" << (k_aligned ? "t" : "n") << "aligned";
+        << "aligned"
+        << "_K_" << (k_aligned ? "t" : "n") << "aligned";

  // Encode and dispatch kernel
  auto& compute_encoder = d.get_command_encoder(s.index);

@@ -76,6 +76,19 @@ void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {
  }
}

+template <typename F>
+void dispatch_inexact_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only inexact (float/complex) types supported but " << dt
+          << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
+
template <typename F>
void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {
  switch (dt) {

mlx/ops.cpp
@@ -1228,14 +1228,14 @@ array pad(
    if (low_pad_size[i] < 0) {
      std::ostringstream msg;
      msg << "Invalid low padding size (" << low_pad_size[i]
-          << ") passed to pad" << " for axis " << i
+          << ") passed to pad for axis " << i
          << ". Padding sizes must be non-negative";
      throw std::invalid_argument(msg.str());
    }
    if (high_pad_size[i] < 0) {
      std::ostringstream msg;
      msg << "Invalid high padding size (" << high_pad_size[i]
-          << ") passed to pad" << " for axis " << i
+          << ") passed to pad for axis " << i
          << ". Padding sizes must be non-negative";
      throw std::invalid_argument(msg.str());
    }

@@ -2859,26 +2859,6 @@ array matmul(
        "have at least one dimension.");
  }

-  // complex matmul using Karatsuba's Algorithm
-  if (a.dtype() == complex64 || b.dtype() == complex64) {
-    // Extract real and imaginary parts
-    auto a_real = real(a, s);
-    auto a_imag = imag(a, s);
-    auto b_real = real(b, s);
-    auto b_imag = imag(b, s);
-
-    // Compute real and imaginary components of the result
-    auto m1 = matmul(a_real, b_real, s);
-    auto m2 = matmul(a_imag, b_imag, s);
-    auto m3 = matmul(add(a_real, a_imag, s), add(b_real, b_imag, s), s);
-
-    auto c_real = subtract(m1, m2, s);
-    auto c_imag = subtract(m3, add(m1, m2, s), s);
-
-    return add(
-        c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
-  }
-
  if (a.ndim() == 1) {
    // Insert a singleton dim in the beginning
    a = expand_dims(a, 0, s);

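With this graph-level Karatsuba decomposition removed, complex operands now flow straight into the specialized CPU, CUDA, and Metal kernels above. A usage sketch, assuming an MLX build that includes this commit (the atol mirrors the benchmark's non-float32 tolerance):

    import numpy as np
    import mlx.core as mx

    a_np = (np.random.randn(128, 64) + 1j * np.random.randn(128, 64)).astype(np.complex64)
    b_np = (np.random.randn(64, 32) + 1j * np.random.randn(64, 32)).astype(np.complex64)

    c = mx.array(a_np) @ mx.array(b_np)   # dispatches to the complex GEMM directly
    assert np.allclose(np.array(c), a_np @ b_np, atol=1e-4)
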
@@ -2898,11 +2878,11 @@ array matmul(
  // Type promotion
  auto out_type = promote_types(a.dtype(), b.dtype());

-  if (!issubdtype(out_type, floating)) {
+  if (!issubdtype(out_type, inexact)) {
    std::ostringstream msg;
-    msg << "[matmul] Only real floating point types are supported but "
-        << a.dtype() << " and " << b.dtype() << " were provided which results"
-        << " in " << out_type << ", which is not a real floating point type.";
+    msg << "[matmul] Only inexact types are supported but " << a.dtype()
+        << " and " << b.dtype() << " were provided which results"
+        << " in " << out_type << ", which is not a floating point type.";
    throw std::invalid_argument(msg.str());
  }
  if (a.dtype() != out_type) {

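The relaxed check admits any inexact result type, so mixed real/complex operands promote instead of throwing. numpy's promotion table gives the same answers (shown as an analogy, not MLX's implementation):

    import numpy as np

    print(np.promote_types(np.float32, np.complex64))   # complex64
    print(np.promote_types(np.float16, np.complex64))   # complex64
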
@@ -3146,8 +3126,8 @@ array take_along_axis(
    StreamOrDevice s /* = {} */) {
  if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
    std::ostringstream msg;
-    msg << "[take_along_axis] Received invalid axis " << " for array with "
-        << a.ndim() << " dimensions.";
+    msg << "[take_along_axis] Received invalid axis for array with " << a.ndim()
+        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

@@ -3184,7 +3164,7 @@ array scatter_axis(
      (mode == ScatterAxis::None) ? "[put_along_axis]" : "[scatter_add_axis]";
  if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
    std::ostringstream msg;
-    msg << prefix << " Received invalid axis " << " for array with " << a.ndim()
+    msg << prefix << " Received invalid axis for array with " << a.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

@@ -3592,7 +3572,7 @@ run_conv_checks(const array& in, const array& wt, int n_dim, int groups) {
  if (in.ndim() != n_dim + 2) {
    std::ostringstream msg;
    msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for "
-        << n_dim << "D convolution." << " Expected an array with " << n_dim + 2
+        << n_dim << "D convolution. Expected an array with " << n_dim + 2
        << " dimensions following the format [N, ..., C_in].";
    throw std::invalid_argument(msg.str());
  }

@@ -4141,7 +4121,8 @@ std::vector<array> quantize(
    std::ostringstream msg;
    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
        << "the quantization group size " << group_size
-        << ". However the provided " << " matrix has shape " << w.shape();
+        << ". However the provided "
+        << " matrix has shape " << w.shape();
    throw std::invalid_argument(msg.str());
  }