WIP (cpu)

Ronan Collobert
2025-10-30 16:24:51 -07:00
parent 76ef1e98f3
commit 45a8b226af
16 changed files with 121 additions and 115 deletions

View File

@@ -10,7 +10,7 @@ namespace mlx::core {
 namespace {
 template <typename T>
-void arange(T start, T next, array& out, size_t size, Stream stream) {
+void arange(T start, T next, array& out, int64_t size, Stream stream) {
   auto ptr = out.data<T>();
   auto step_size = next - start;
   auto& encoder = cpu::get_command_encoder(stream);

View File

@@ -17,7 +17,12 @@ namespace mlx::core {
 namespace {
 template <typename Op>
-void binary(const array& a, const array& b, array& out, Op /* op */, Stream stream) {
+void binary(
+    const array& a,
+    const array& b,
+    array& out,
+    Op /* op */,
+    Stream stream) {
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);

View File

@@ -33,8 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
                     N = a.shape(-1),
                     size = a.size()]() mutable {
     char uplo = (upper) ? 'L' : 'U';
-    size_t num_matrices = size / (N * N);
-    for (int i = 0; i < num_matrices; i++) {
+    int64_t num_matrices = size / (N * N);
+    for (int64_t i = 0; i < num_matrices; i++) {
       // Compute Cholesky factorization.
       int info;
       potrf<T>(

View File

@@ -12,12 +12,12 @@ void matmul(
     T* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,

View File

@@ -34,7 +34,7 @@ void matmul_bnns(
     bool b_transposed,
     size_t lda,
     size_t ldb,
-    size_t ldc,
+    size_t /* ldc */,
     float alpha,
     float beta,
     size_t batch_size,
@@ -52,7 +52,7 @@ void matmul_bnns(
 #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
   if (beta != 1.0 && beta != 0.0) {
     // scale the output
-    for (auto i = 0; i < batch_size * M * N; ++i) {
+    for (size_t i = 0; i < batch_size * M * N; ++i) {
       out[i] *= beta;
     }
     beta = 1.0;
@@ -127,7 +127,7 @@ void matmul_bnns(
   auto bnns_filter =
       BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
-  for (int i = 0; i < batch_size; ++i) {
+  for (size_t i = 0; i < batch_size; ++i) {
     BNNSFilterApplyTwoInput(
         bnns_filter,
         reinterpret_cast<const uint8_t*>(
@@ -148,12 +148,12 @@ void matmul<float16_t>(
     float16_t* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
@@ -183,12 +183,12 @@ void matmul<bfloat16_t>(
     bfloat16_t* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,

View File

@@ -13,20 +13,20 @@ void matmul<float>(
     float* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides) {
   auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
+  int64_t M = a_shape[ndim - 2];
+  int64_t N = b_shape[ndim - 1];
+  int64_t K = a_shape[ndim - 1];
   for (int i = 0; i < batch_size; ++i) {
     cblas_sgemm(
@@ -54,20 +54,20 @@ void matmul<double>(
     double* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides) {
   auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
+  int64_t M = a_shape[ndim - 2];
+  int64_t N = b_shape[ndim - 1];
+  int64_t K = a_shape[ndim - 1];
   for (int i = 0; i < batch_size; ++i) {
     cblas_dgemm(
@@ -95,20 +95,20 @@ void matmul<complex64_t>(
     complex64_t* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides) {
   auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
+  int64_t M = a_shape[ndim - 2];
+  int64_t N = b_shape[ndim - 1];
+  int64_t K = a_shape[ndim - 1];
   auto calpha = static_cast<complex64_t>(alpha);
   auto cbeta = static_cast<complex64_t>(beta);

View File

@@ -11,9 +11,9 @@ namespace mlx::core {
 // n = 2^k component
 template <typename T>
-void hadamard_n(T* out, int n, int m, float scale, size_t size) {
+void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
   for (int b = 0; b < size / n; b++) {
-    size_t loc = b * n;
+    int64_t loc = b * n;
     T* data_ptr = out + loc;
     int h = 1;
     int n_over_2 = n / 2;
@@ -37,7 +37,7 @@ void hadamard_n(T* out, int n, int m, float scale, size_t size) {
 // m component
 template <typename T>
-void hadamard_m(T* out, int n, int m, float scale, size_t size) {
+void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
   auto h_matrices = hadamard_matrices();
   auto& matrix = h_matrices[m];
   auto start = 1;
@@ -45,7 +45,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
   std::vector<bool> hmat_vec;
   while (end != std::string_view::npos) {
     auto row = matrix.substr(start, end - start);
-    for (int i = 0; i < row.length(); i++) {
+    for (int i = 0; i < std::ssize(row); i++) {
       hmat_vec.push_back(row[i] == '+');
     }
     start = end + 1;
@@ -53,7 +53,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
   }
   for (int b = 0; b < size / m / n; b++) {
-    size_t loc = b * n * m;
+    int64_t loc = b * n * m;
     T* data_ptr = out + loc;
     for (int i = 0; i < n; i++) {
       std::vector<float> out(m);

View File

@@ -78,7 +78,7 @@ void gather(
       can_copy = true;
       // Ignore leading 1s
-      int i = 0;
+      int64_t i = 0;
       for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
         ;
@@ -91,7 +91,7 @@ void gather(
       can_copy = true;
       // Ignore trailing 1s
-      int i = slice_sizes.size() - 1;
+      int64_t i = slice_sizes.size() - 1;
       for (; i >= 0 && slice_sizes[i] == 1; --i)
         ;
@@ -101,11 +101,11 @@ void gather(
       can_copy = (src.shape(i) == slice_sizes[i]);
     }
   }
-  size_t slice_size = 1;
+  int64_t slice_size = 1;
   for (auto s : slice_sizes) {
     slice_size *= s;
   }
-  size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
+  int64_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
   const T* src_ptr = src.data<T>();
   T* dst_ptr = out.data<T>();
@@ -115,10 +115,10 @@ void gather(
     src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
   }
-  size_t out_idx = 0;
-  for (int idx = 0; idx < ind_size; idx++) {
-    size_t src_idx = 0;
-    for (int ii = 0; ii < inds.size(); ++ii) {
+  int64_t out_idx = 0;
+  for (int64_t idx = 0; idx < ind_size; idx++) {
+    int64_t src_idx = 0;
+    for (int ii = 0; ii < std::ssize(inds); ++ii) {
       auto ax = axes[ii];
       auto idx_loc = its[ii].loc;
       its[ii].step();
@@ -134,7 +134,7 @@ void gather(
           src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
       out_idx += slice_size;
     } else {
-      for (int jj = 0; jj < slice_size; jj++) {
+      for (int64_t jj = 0; jj < slice_size; jj++) {
         dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
         src_it.step();
       }
@@ -403,11 +403,11 @@ void scatter(
     const std::vector<int>& axes) {
   int nind = inds.size();
   auto inds_ndim = updates.ndim() - out.ndim();
-  size_t n_updates = nind ? inds[0].size() : 1;
+  int64_t n_updates = nind ? inds[0].size() : 1;
   Shape update_shape(
       updates.shape().begin() + inds_ndim, updates.shape().end());
-  size_t update_size = 1;
+  int64_t update_size = 1;
   for (auto us : update_shape) {
     update_size *= us;
   }
@@ -418,9 +418,9 @@ void scatter(
   auto out_ptr = out.data<InT>();
   auto upd_ptr = updates.data<InT>();
-  for (int i = 0; i < n_updates; ++i) {
-    size_t out_offset = 0;
-    for (int j = 0; j < inds.size(); ++j) {
+  for (int64_t i = 0; i < n_updates; ++i) {
+    int64_t out_offset = 0;
+    for (int j = 0; j < std::ssize(inds); ++j) {
       auto ax = axes[j];
       auto idx_loc = its[j].loc;
       its[j].step();
@@ -429,7 +429,7 @@ void scatter(
       out_offset += (idx_val * out.strides()[ax]);
     }
     update_it.seek(i * update_size);
-    for (int j = 0; j < update_size; ++j) {
+    for (int64_t j = 0; j < update_size; ++j) {
       OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
       update_it.step();
       out_it.step();

View File

@@ -25,7 +25,7 @@ inline void mask_matrix(
     const int64_t Y_data_str,
     const int64_t X_mask_str,
     const int64_t Y_mask_str,
-    const size_t mask_offset) {
+    const int64_t mask_offset) {
   int tX = (X + block_size - 1) / block_size;
   int tY = (Y + block_size - 1) / block_size;
@@ -61,13 +61,13 @@ inline void segmented_mm(
     T* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
+    int64_t lda,
+    int64_t ldb,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides,
-    size_t num_segments,
+    int64_t num_segments,
     const Shape& segments_shape,
     const Strides& segments_strides) {
   int ndim = a_shape.size();
@@ -149,9 +149,9 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto [b_transposed, ldb, b, b_copied] =
       check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
+  int64_t M = a.shape(-2);
+  int64_t N = b.shape(-1);
+  int64_t K = a.shape(-1);
   if (M == 0 || N == 0) {
     return;
@@ -172,8 +172,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
       int batch_idx,
       int X,
       int Y,
-      size_t X_data_str,
-      size_t Y_data_str,
+      int64_t X_data_str,
+      int64_t Y_data_str,
       const Shape& mask_shape,
       const Strides& mask_strides,
       bool is_bool) {
@@ -253,7 +253,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto a_ptr = a.data<float>();
   auto b_ptr = b.data<float>();
   auto out_ptr = out.data<float>();
-  size_t num_matrices = out.size() / (M * size_t(N));
+  int64_t num_matrices = out.size() / (M * int64_t(N));
   auto ldc = out.shape(-1);
   encoder.dispatch([a_ptr,
@@ -394,9 +394,9 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto [a_transposed, lda, a] = check_transpose(a_pre);
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
+  int64_t M = a.shape(-2);
+  int64_t N = b.shape(-1);
+  int64_t K = a.shape(-1);
   if (M == 0 || N == 0) {
     return;
@@ -413,7 +413,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   // Get batch dims
   auto batch_size_out = out.size() / (M * N);
-  size_t matrix_stride_out = M * N;
+  int64_t matrix_stride_out = M * N;
   auto get_batch_dims = [](const auto& v) {
     return decltype(v){v.begin(), v.end() - 2};

View File

@@ -48,7 +48,7 @@ static std::pair<array, bool> compute_dynamic_offset(
   auto compute_offset =
       [strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
         int64_t offset_ = 0;
-        for (int i = 0; i < axes.size(); ++i) {
+        for (int i = 0; i < std::ssize(axes); ++i) {
           offset_ += indices[i] * strides[axes[i]];
         }
         offset[0] = offset_;
@@ -193,9 +193,9 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
   flags.row_contiguous = false;
   flags.col_contiguous = false;
   flags.contiguous = false;
-  for (int i = 0; i < inputs.size(); i++) {
+  for (int i = 0; i < std::ssize(inputs); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
-    size_t data_offset = strides[axis_] * sizes[i];
+    int64_t data_offset = strides[axis_] * sizes[i];
     out_slice.copy_shared_buffer(
         out, strides, flags, out_slice.size(), data_offset);
     copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
@@ -205,7 +205,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
 void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  constexpr size_t extra_bytes = 16384;
+  constexpr int64_t extra_bytes = 16384;
   if (in.buffer_size() <= out.nbytes() + extra_bytes &&
       (in.flags().row_contiguous ||
        (allow_col_major_ && in.flags().col_contiguous))) {
@@ -254,8 +254,8 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
   copy_cpu(val, out, CopyType::Scalar, stream());
   // Find offset for start of input values
-  size_t data_offset = 0;
-  for (int i = 0; i < axes_.size(); i++) {
+  int64_t data_offset = 0;
+  for (int i = 0; i < std::ssize(axes_); i++) {
     auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
     data_offset += out.strides()[ax] * low_pad_size_[i];
   }
@@ -274,10 +274,10 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
   // keys has shape (N1, ..., NK, 2)
   // out has shape (N1, ..., NK, M1, M2, ...)
   auto& keys = inputs[0];
-  size_t num_keys = keys.size() / 2;
-  size_t elems_per_key = out.size() / num_keys;
-  size_t bytes_per_key = out.itemsize() * elems_per_key;
+  int64_t num_keys = keys.size() / 2;
+  int64_t elems_per_key = out.size() / num_keys;
+  int64_t bytes_per_key = out.itemsize() * elems_per_key;
   out.set_data(allocator::malloc(out.nbytes()));
   auto kptr = inputs[0].data<uint32_t>();
@@ -291,8 +291,8 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
       num_keys,
       kshape = keys.shape(),
       kstrides = keys.strides()]() mutable {
-    size_t out_skip = (bytes_per_key + 4 - 1) / 4;
-    auto half_size = out_skip / 2;
+    int64_t out_skip = (bytes_per_key + 4 - 1) / 4;
+    uintptr_t half_size = out_skip / 2;
     bool even = out_skip % 2 == 0;
     for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
       auto ptr = reinterpret_cast<uint32_t*>(cptr);

View File

@@ -13,7 +13,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   const int M = a.shape(-2);
   const int N = a.shape(-1);
   const int lda = M;
-  size_t num_matrices = a.size() / (M * N);
+  int64_t num_matrices = a.size() / (M * N);
   // Copy A to inplace input and make it col-contiguous
   array in(a.shape(), a.dtype(), nullptr, {});
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   auto work = allocator::malloc(sizeof(T) * lwork);
   // Loop over matrices
-  for (int i = 0; i < num_matrices; ++i) {
+  for (int64_t i = 0; i < num_matrices; ++i) {
     // Solve
     geqrf<T>(
         &M,
@@ -68,7 +68,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   }
   allocator::free(work);
-  for (int i = 0; i < num_matrices; ++i) {
+  for (int64_t i = 0; i < num_matrices; ++i) {
     /// num_reflectors x N
     for (int j = 0; j < num_reflectors; ++j) {
       for (int k = 0; k < j; ++k) {
@@ -97,7 +97,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   work = allocator::malloc(sizeof(T) * lwork);
   // Loop over matrices
-  for (int i = 0; i < num_matrices; ++i) {
+  for (int64_t i = 0; i < num_matrices; ++i) {
     // Compute Q
     orgqr<T>(
         &M,
@@ -111,7 +111,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
         &info);
   }
-  for (int i = 0; i < num_matrices; ++i) {
+  for (int64_t i = 0; i < num_matrices; ++i) {
     // M x num_reflectors
     for (int j = 0; j < M; ++j) {
       for (int k = 0; k < num_reflectors; ++k) {

View File

@@ -79,7 +79,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
   // Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
   // and another one for Pi/4<x<=Pi/2. Both branches will be computed.
-  auto poly_mask = (emm2 & 2) != 0;
+  auto poly_mask =
+      (emm2 & static_cast<uint32_t>(2)) != static_cast<uint32_t>(0);
   // The magic pass: "Extended precision modular arithmetic"
   // x = ((x - y * DP1) - y * DP2) - y * DP3
@@ -87,8 +88,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
   x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
   x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
-  sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
-  auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
+  sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != static_cast<uint32_t>(0));
+  auto sign_mask_cos = ((emm2 - 2) & 4) != static_cast<uint32_t>(0);
   // Evaluate the first polynom (0 <= x <= Pi/4) in y1,
   // and the second polynom (Pi/4 <= x <= 0) in y2

View File

@@ -120,8 +120,8 @@ template <typename T>
 void sort(array& out, int axis) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + out.ndim() : axis;
-  size_t in_size = out.size();
-  size_t n_rows = in_size / out.shape(axis);
+  int64_t in_size = out.size();
+  int64_t n_rows = in_size / out.shape(axis);
   auto remaining_shape = out.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);
@@ -136,7 +136,7 @@ void sort(array& out, int axis) {
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
   auto out_ptr = out.data<T>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     T* data_ptr = out_ptr + src_it.loc;
     StridedIterator st(data_ptr, axis_stride, 0);
@@ -151,7 +151,7 @@ template <typename T, typename IdxT = uint32_t>
 void argsort(const array& in, array& out, int axis) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t n_rows = in.size() / in.shape(axis);
+  int64_t n_rows = in.size() / in.shape(axis);
   auto in_remaining_shape = in.shape();
   in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -176,7 +176,7 @@ void argsort(const array& in, array& out, int axis) {
       out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
   auto in_ptr = in.data<T>();
   auto out_ptr = out.data<IdxT>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     const T* data_ptr = in_ptr + in_it.loc;
     IdxT* idx_ptr = out_ptr + out_it.loc;
@@ -214,8 +214,8 @@ template <typename T>
 void partition(array& out, int axis, int kth) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + out.ndim() : axis;
-  size_t in_size = out.size();
-  size_t n_rows = in_size / out.shape(axis);
+  int64_t in_size = out.size();
+  int64_t n_rows = in_size / out.shape(axis);
   auto remaining_shape = out.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);
@@ -232,7 +232,7 @@ void partition(array& out, int axis, int kth) {
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
   auto out_ptr = out.data<T>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     T* data_ptr = out_ptr + src_it.loc;
     src_it.step();
@@ -248,7 +248,7 @@ template <typename T, typename IdxT = uint32_t>
 void argpartition(const array& in, array& out, int axis, int kth) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t n_rows = in.size() / in.shape(axis);
+  int64_t n_rows = in.size() / in.shape(axis);
   auto in_remaining_shape = in.shape();
   in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -277,7 +277,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
   auto in_ptr = in.data<T>();
   auto out_ptr = out.data<IdxT>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     const T* data_ptr = in_ptr + in_it.loc;
     IdxT* idx_ptr = out_ptr + out_it.loc;
     in_it.step();

View File

@@ -27,7 +27,7 @@ void svd_impl(
   const int N = a.shape(-1);
   const int K = std::min(M, N);
-  size_t num_matrices = a.size() / (M * N);
+  int64_t num_matrices = a.size() / (M * N);
   // lapack clobbers the input, so we have to make a copy.
   array in(a.shape(), a.dtype(), nullptr, {});
@@ -121,7 +121,7 @@ void svd_impl(
   auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
   // Loop over matrices.
-  for (int i = 0; i < num_matrices; i++) {
+  for (int64_t i = 0; i < num_matrices; i++) {
     gesdd<T>(
         /* jobz = */ jobz,
         // M and N are swapped since lapack expects column-major.
@@ -153,10 +153,10 @@ void svd_impl(
 template <typename T>
 void compute_svd(
-    const array& a,
-    bool compute_uv,
-    std::vector<array>& outputs,
-    Stream stream) {}
+    const array& /* a */,
+    bool /* compute_uv */,
+    std::vector<array>& /* outputs */,
+    Stream /* stream */) {}
 void SVD::eval_cpu(
     const std::vector<array>& inputs,

View File

@@ -136,7 +136,7 @@ void ternary_op(
   if (topt == TernaryOpType::ScalarScalarScalar) {
     *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
   } else if (topt == TernaryOpType::VectorVectorVector) {
-    for (size_t i = 0; i < out.size(); ++i) {
+    for (int64_t i = 0; i < out.size(); ++i) {
       *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
       a_ptr++;
       b_ptr++;

View File

@@ -10,8 +10,8 @@
 namespace mlx::core {
 template <typename T, typename U = T, typename Op>
-void unary_op(const T* a, U* out, size_t shape, size_t stride) {
-  for (size_t i = 0; i < shape; i += 1) {
+void unary_op(const T* a, U* out, int64_t shape, int64_t stride) {
+  for (int64_t i = 0; i < shape; i += 1) {
     out[i] = Op{}(*a);
     a += stride;
   }
@@ -38,14 +38,14 @@ void unary_op(const array& a, array& out, Op) {
       src++;
     }
   } else {
-    size_t shape = ndim > 0 ? a.shape().back() : 1;
-    size_t stride = ndim > 0 ? a.strides().back() : 1;
+    int64_t shape = ndim > 0 ? a.shape().back() : 1;
+    int64_t stride = ndim > 0 ? a.strides().back() : 1;
     if (ndim <= 1) {
       unary_op<T, U, Op>(src, dst, shape, stride);
      return;
     }
     auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
-    for (size_t elem = 0; elem < a.size(); elem += shape) {
+    for (int64_t elem = 0; elem < a.size(); elem += shape) {
       unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
       it.step();
     }