WIP (cpu)

This commit is contained in:
Ronan Collobert
2025-10-30 16:24:51 -07:00
parent 76ef1e98f3
commit 45a8b226af
16 changed files with 121 additions and 115 deletions

View File

@@ -10,7 +10,7 @@ namespace mlx::core {
namespace {
template <typename T>
void arange(T start, T next, array& out, size_t size, Stream stream) {
void arange(T start, T next, array& out, int64_t size, Stream stream) {
auto ptr = out.data<T>();
auto step_size = next - start;
auto& encoder = cpu::get_command_encoder(stream);

View File

@@ -17,7 +17,12 @@ namespace mlx::core {
namespace {
template <typename Op>
void binary(const array& a, const array& b, array& out, Op /* op */, Stream stream) {
void binary(
const array& a,
const array& b,
array& out,
Op /* op */,
Stream stream) {
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt);

View File

@@ -33,8 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
N = a.shape(-1),
size = a.size()]() mutable {
char uplo = (upper) ? 'L' : 'U';
size_t num_matrices = size / (N * N);
for (int i = 0; i < num_matrices; i++) {
int64_t num_matrices = size / (N * N);
for (int64_t i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info;
potrf<T>(

View File

@@ -12,12 +12,12 @@ void matmul(
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
int64_t lda,
int64_t ldb,
int64_t ldc,
float alpha,
float beta,
size_t batch_size,
int64_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,

View File

@@ -34,7 +34,7 @@ void matmul_bnns(
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
size_t /* ldc */,
float alpha,
float beta,
size_t batch_size,
@@ -52,7 +52,7 @@ void matmul_bnns(
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
if (beta != 1.0 && beta != 0.0) {
// scale the output
for (auto i = 0; i < batch_size * M * N; ++i) {
for (size_t i = 0; i < batch_size * M * N; ++i) {
out[i] *= beta;
}
beta = 1.0;
@@ -127,7 +127,7 @@ void matmul_bnns(
auto bnns_filter =
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
for (int i = 0; i < batch_size; ++i) {
for (size_t i = 0; i < batch_size; ++i) {
BNNSFilterApplyTwoInput(
bnns_filter,
reinterpret_cast<const uint8_t*>(
@@ -148,12 +148,12 @@ void matmul<float16_t>(
float16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
int64_t lda,
int64_t ldb,
int64_t ldc,
float alpha,
float beta,
size_t batch_size,
int64_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
@@ -183,12 +183,12 @@ void matmul<bfloat16_t>(
bfloat16_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
int64_t lda,
int64_t ldb,
int64_t ldc,
float alpha,
float beta,
size_t batch_size,
int64_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,

View File

@@ -13,20 +13,20 @@ void matmul<float>(
float* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
int64_t lda,
int64_t ldb,
int64_t ldc,
float alpha,
float beta,
size_t batch_size,
int64_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
int64_t M = a_shape[ndim - 2];
int64_t N = b_shape[ndim - 1];
int64_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
cblas_sgemm(
@@ -54,20 +54,20 @@ void matmul<double>(
double* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
int64_t lda,
int64_t ldb,
int64_t ldc,
float alpha,
float beta,
size_t batch_size,
int64_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
int64_t M = a_shape[ndim - 2];
int64_t N = b_shape[ndim - 1];
int64_t K = a_shape[ndim - 1];
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(
@@ -95,20 +95,20 @@ void matmul<complex64_t>(
complex64_t* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
size_t ldc,
int64_t lda,
int64_t ldb,
int64_t ldc,
float alpha,
float beta,
size_t batch_size,
int64_t batch_size,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides) {
auto ndim = a_shape.size();
size_t M = a_shape[ndim - 2];
size_t N = b_shape[ndim - 1];
size_t K = a_shape[ndim - 1];
int64_t M = a_shape[ndim - 2];
int64_t N = b_shape[ndim - 1];
int64_t K = a_shape[ndim - 1];
auto calpha = static_cast<complex64_t>(alpha);
auto cbeta = static_cast<complex64_t>(beta);

View File

@@ -11,9 +11,9 @@ namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
for (int b = 0; b < size / n; b++) {
size_t loc = b * n;
int64_t loc = b * n;
T* data_ptr = out + loc;
int h = 1;
int n_over_2 = n / 2;
@@ -37,7 +37,7 @@ void hadamard_n(T* out, int n, int m, float scale, size_t size) {
// m component
template <typename T>
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
@@ -45,7 +45,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
for (int i = 0; i < std::ssize(row); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
@@ -53,7 +53,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
}
for (int b = 0; b < size / m / n; b++) {
size_t loc = b * n * m;
int64_t loc = b * n * m;
T* data_ptr = out + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);

View File

@@ -78,7 +78,7 @@ void gather(
can_copy = true;
// Ignore leading 1s
int i = 0;
int64_t i = 0;
for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
;
@@ -91,7 +91,7 @@ void gather(
can_copy = true;
// Ignore trailing 1s
int i = slice_sizes.size() - 1;
int64_t i = slice_sizes.size() - 1;
for (; i >= 0 && slice_sizes[i] == 1; --i)
;
@@ -101,11 +101,11 @@ void gather(
can_copy = (src.shape(i) == slice_sizes[i]);
}
}
size_t slice_size = 1;
int64_t slice_size = 1;
for (auto s : slice_sizes) {
slice_size *= s;
}
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
int64_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();
@@ -115,10 +115,10 @@ void gather(
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
size_t out_idx = 0;
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) {
int64_t out_idx = 0;
for (int64_t idx = 0; idx < ind_size; idx++) {
int64_t src_idx = 0;
for (int ii = 0; ii < std::ssize(inds); ++ii) {
auto ax = axes[ii];
auto idx_loc = its[ii].loc;
its[ii].step();
@@ -134,7 +134,7 @@ void gather(
src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
out_idx += slice_size;
} else {
for (int jj = 0; jj < slice_size; jj++) {
for (int64_t jj = 0; jj < slice_size; jj++) {
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
src_it.step();
}
@@ -403,11 +403,11 @@ void scatter(
const std::vector<int>& axes) {
int nind = inds.size();
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
int64_t n_updates = nind ? inds[0].size() : 1;
Shape update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
size_t update_size = 1;
int64_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
@@ -418,9 +418,9 @@ void scatter(
auto out_ptr = out.data<InT>();
auto upd_ptr = updates.data<InT>();
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;
for (int j = 0; j < inds.size(); ++j) {
for (int64_t i = 0; i < n_updates; ++i) {
int64_t out_offset = 0;
for (int j = 0; j < std::ssize(inds); ++j) {
auto ax = axes[j];
auto idx_loc = its[j].loc;
its[j].step();
@@ -429,7 +429,7 @@ void scatter(
out_offset += (idx_val * out.strides()[ax]);
}
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) {
for (int64_t j = 0; j < update_size; ++j) {
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
update_it.step();
out_it.step();

View File

@@ -25,7 +25,7 @@ inline void mask_matrix(
const int64_t Y_data_str,
const int64_t X_mask_str,
const int64_t Y_mask_str,
const size_t mask_offset) {
const int64_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
@@ -61,13 +61,13 @@ inline void segmented_mm(
T* out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
int64_t lda,
int64_t ldb,
const Shape& a_shape,
const Strides& a_strides,
const Shape& b_shape,
const Strides& b_strides,
size_t num_segments,
int64_t num_segments,
const Shape& segments_shape,
const Strides& segments_strides) {
int ndim = a_shape.size();
@@ -149,9 +149,9 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto [b_transposed, ldb, b, b_copied] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
int64_t M = a.shape(-2);
int64_t N = b.shape(-1);
int64_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
@@ -172,8 +172,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
int batch_idx,
int X,
int Y,
size_t X_data_str,
size_t Y_data_str,
int64_t X_data_str,
int64_t Y_data_str,
const Shape& mask_shape,
const Strides& mask_strides,
bool is_bool) {
@@ -253,7 +253,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto a_ptr = a.data<float>();
auto b_ptr = b.data<float>();
auto out_ptr = out.data<float>();
size_t num_matrices = out.size() / (M * size_t(N));
int64_t num_matrices = out.size() / (M * int64_t(N));
auto ldc = out.shape(-1);
encoder.dispatch([a_ptr,
@@ -394,9 +394,9 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
int64_t M = a.shape(-2);
int64_t N = b.shape(-1);
int64_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
@@ -413,7 +413,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
// Get batch dims
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = M * N;
int64_t matrix_stride_out = M * N;
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};

View File

@@ -48,7 +48,7 @@ static std::pair<array, bool> compute_dynamic_offset(
auto compute_offset =
[strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
int64_t offset_ = 0;
for (int i = 0; i < axes.size(); ++i) {
for (int i = 0; i < std::ssize(axes); ++i) {
offset_ += indices[i] * strides[axes[i]];
}
offset[0] = offset_;
@@ -193,9 +193,9 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
flags.row_contiguous = false;
flags.col_contiguous = false;
flags.contiguous = false;
for (int i = 0; i < inputs.size(); i++) {
for (int i = 0; i < std::ssize(inputs); i++) {
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
size_t data_offset = strides[axis_] * sizes[i];
int64_t data_offset = strides[axis_] * sizes[i];
out_slice.copy_shared_buffer(
out, strides, flags, out_slice.size(), data_offset);
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
@@ -205,7 +205,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
constexpr size_t extra_bytes = 16384;
constexpr int64_t extra_bytes = 16384;
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
(in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous))) {
@@ -254,8 +254,8 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
copy_cpu(val, out, CopyType::Scalar, stream());
// Find offset for start of input values
size_t data_offset = 0;
for (int i = 0; i < axes_.size(); i++) {
int64_t data_offset = 0;
for (int i = 0; i < std::ssize(axes_); i++) {
auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
data_offset += out.strides()[ax] * low_pad_size_[i];
}
@@ -274,10 +274,10 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
// keys has shape (N1, ..., NK, 2)
// out has shape (N1, ..., NK, M1, M2, ...)
auto& keys = inputs[0];
size_t num_keys = keys.size() / 2;
int64_t num_keys = keys.size() / 2;
size_t elems_per_key = out.size() / num_keys;
size_t bytes_per_key = out.itemsize() * elems_per_key;
int64_t elems_per_key = out.size() / num_keys;
int64_t bytes_per_key = out.itemsize() * elems_per_key;
out.set_data(allocator::malloc(out.nbytes()));
auto kptr = inputs[0].data<uint32_t>();
@@ -291,8 +291,8 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
num_keys,
kshape = keys.shape(),
kstrides = keys.strides()]() mutable {
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
int64_t out_skip = (bytes_per_key + 4 - 1) / 4;
uintptr_t half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
auto ptr = reinterpret_cast<uint32_t*>(cptr);

View File

@@ -13,7 +13,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
const int M = a.shape(-2);
const int N = a.shape(-1);
const int lda = M;
size_t num_matrices = a.size() / (M * N);
int64_t num_matrices = a.size() / (M * N);
// Copy A to inplace input and make it col-contiguous
array in(a.shape(), a.dtype(), nullptr, {});
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
auto work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
for (int64_t i = 0; i < num_matrices; ++i) {
// Solve
geqrf<T>(
&M,
@@ -68,7 +68,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
}
allocator::free(work);
for (int i = 0; i < num_matrices; ++i) {
for (int64_t i = 0; i < num_matrices; ++i) {
/// num_reflectors x N
for (int j = 0; j < num_reflectors; ++j) {
for (int k = 0; k < j; ++k) {
@@ -97,7 +97,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
work = allocator::malloc(sizeof(T) * lwork);
// Loop over matrices
for (int i = 0; i < num_matrices; ++i) {
for (int64_t i = 0; i < num_matrices; ++i) {
// Compute Q
orgqr<T>(
&M,
@@ -111,7 +111,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
&info);
}
for (int i = 0; i < num_matrices; ++i) {
for (int64_t i = 0; i < num_matrices; ++i) {
// M x num_reflectors
for (int j = 0; j < M; ++j) {
for (int k = 0; k < num_reflectors; ++k) {

View File

@@ -79,7 +79,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
// Get the polynomial selection mask. There is one polynomial for
// 0 <= x <= Pi/4 and another one for Pi/4 < x <= Pi/2. Both branches will be
// computed.
auto poly_mask = (emm2 & 2) != 0;
auto poly_mask =
(emm2 & static_cast<uint32_t>(2)) != static_cast<uint32_t>(0);
// The magic pass: "Extended precision modular arithmetic"
// x = ((x - y * DP1) - y * DP2) - y * DP3
@@ -87,8 +88,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != static_cast<uint32_t>(0));
auto sign_mask_cos = ((emm2 - 2) & 4) != static_cast<uint32_t>(0);
// Evaluate the first polynomial (0 <= x <= Pi/4) in y1,
// and the second polynomial (Pi/4 < x <= Pi/2) in y2

View File

@@ -120,8 +120,8 @@ template <typename T>
void sort(array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
int64_t in_size = out.size();
int64_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -136,7 +136,7 @@ void sort(array& out, int axis) {
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) {
for (int64_t i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
StridedIterator st(data_ptr, axis_stride, 0);
@@ -151,7 +151,7 @@ template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
int64_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -176,7 +176,7 @@ void argsort(const array& in, array& out, int axis) {
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int i = 0; i < n_rows; i++) {
for (int64_t i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
@@ -214,8 +214,8 @@ template <typename T>
void partition(array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + out.ndim() : axis;
size_t in_size = out.size();
size_t n_rows = in_size / out.shape(axis);
int64_t in_size = out.size();
int64_t n_rows = in_size / out.shape(axis);
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
@@ -232,7 +232,7 @@ void partition(array& out, int axis, int kth) {
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
auto out_ptr = out.data<T>();
for (int i = 0; i < n_rows; i++) {
for (int64_t i = 0; i < n_rows; i++) {
T* data_ptr = out_ptr + src_it.loc;
src_it.step();
@@ -248,7 +248,7 @@ template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
int64_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -277,7 +277,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
auto in_ptr = in.data<T>();
auto out_ptr = out.data<IdxT>();
for (int i = 0; i < n_rows; i++) {
for (int64_t i = 0; i < n_rows; i++) {
const T* data_ptr = in_ptr + in_it.loc;
IdxT* idx_ptr = out_ptr + out_it.loc;
in_it.step();

View File

@@ -27,7 +27,7 @@ void svd_impl(
const int N = a.shape(-1);
const int K = std::min(M, N);
size_t num_matrices = a.size() / (M * N);
int64_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), a.dtype(), nullptr, {});
@@ -121,7 +121,7 @@ void svd_impl(
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
for (int64_t i = 0; i < num_matrices; i++) {
gesdd<T>(
/* jobz = */ jobz,
// M and N are swapped since lapack expects column-major.
@@ -153,10 +153,10 @@ void svd_impl(
template <typename T>
void compute_svd(
const array& a,
bool compute_uv,
std::vector<array>& outputs,
Stream stream) {}
const array& /* a */,
bool /* compute_uv */,
std::vector<array>& /* outputs */,
Stream /* stream */) {}
void SVD::eval_cpu(
const std::vector<array>& inputs,

View File

@@ -136,7 +136,7 @@ void ternary_op(
if (topt == TernaryOpType::ScalarScalarScalar) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
} else if (topt == TernaryOpType::VectorVectorVector) {
for (size_t i = 0; i < out.size(); ++i) {
for (int64_t i = 0; i < out.size(); ++i) {
*out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
a_ptr++;
b_ptr++;

View File

@@ -10,8 +10,8 @@
namespace mlx::core {
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
void unary_op(const T* a, U* out, int64_t shape, int64_t stride) {
for (int64_t i = 0; i < shape; i += 1) {
out[i] = Op{}(*a);
a += stride;
}
@@ -38,14 +38,14 @@ void unary_op(const array& a, array& out, Op) {
src++;
}
} else {
size_t shape = ndim > 0 ? a.shape().back() : 1;
size_t stride = ndim > 0 ? a.strides().back() : 1;
int64_t shape = ndim > 0 ? a.shape().back() : 1;
int64_t stride = ndim > 0 ? a.strides().back() : 1;
if (ndim <= 1) {
unary_op<T, U, Op>(src, dst, shape, stride);
return;
}
auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
for (size_t elem = 0; elem < a.size(); elem += shape) {
for (int64_t elem = 0; elem < a.size(); elem += shape) {
unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
it.step();
}