mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
WIP (cpu)
@@ -10,7 +10,7 @@ namespace mlx::core {
 namespace {
 
 template <typename T>
-void arange(T start, T next, array& out, size_t size, Stream stream) {
+void arange(T start, T next, array& out, int64_t size, Stream stream) {
   auto ptr = out.data<T>();
   auto step_size = next - start;
   auto& encoder = cpu::get_command_encoder(stream);

@@ -17,7 +17,12 @@ namespace mlx::core {
 namespace {
 
 template <typename Op>
-void binary(const array& a, const array& b, array& out, Op /* op */, Stream stream) {
+void binary(
+    const array& a,
+    const array& b,
+    array& out,
+    Op /* op */,
+    Stream stream) {
   auto bopt = get_binary_op_type(a, b);
   set_binary_op_output_data(a, b, out, bopt);
 

@@ -33,8 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
                     N = a.shape(-1),
                     size = a.size()]() mutable {
     char uplo = (upper) ? 'L' : 'U';
-    size_t num_matrices = size / (N * N);
-    for (int i = 0; i < num_matrices; i++) {
+    int64_t num_matrices = size / (N * N);
+    for (int64_t i = 0; i < num_matrices; i++) {
       // Compute Cholesky factorization.
       int info;
       potrf<T>(

@@ -12,12 +12,12 @@ void matmul(
     T* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
@@ -34,7 +34,7 @@ void matmul_bnns(
     bool b_transposed,
     size_t lda,
     size_t ldb,
-    size_t ldc,
+    size_t /* ldc */,
     float alpha,
     float beta,
     size_t batch_size,
@@ -52,7 +52,7 @@ void matmul_bnns(
 #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
   if (beta != 1.0 && beta != 0.0) {
     // scale the output
-    for (auto i = 0; i < batch_size * M * N; ++i) {
+    for (size_t i = 0; i < batch_size * M * N; ++i) {
       out[i] *= beta;
     }
     beta = 1.0;
@@ -127,7 +127,7 @@ void matmul_bnns(
   auto bnns_filter =
       BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
 
-  for (int i = 0; i < batch_size; ++i) {
+  for (size_t i = 0; i < batch_size; ++i) {
     BNNSFilterApplyTwoInput(
         bnns_filter,
         reinterpret_cast<const uint8_t*>(
@@ -148,12 +148,12 @@ void matmul<float16_t>(
     float16_t* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
@@ -183,12 +183,12 @@ void matmul<bfloat16_t>(
     bfloat16_t* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,

@@ -13,20 +13,20 @@ void matmul<float>(
     float* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides) {
   auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
+  int64_t M = a_shape[ndim - 2];
+  int64_t N = b_shape[ndim - 1];
+  int64_t K = a_shape[ndim - 1];
 
   for (int i = 0; i < batch_size; ++i) {
     cblas_sgemm(
@@ -54,20 +54,20 @@ void matmul<double>(
     double* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides) {
   auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
+  int64_t M = a_shape[ndim - 2];
+  int64_t N = b_shape[ndim - 1];
+  int64_t K = a_shape[ndim - 1];
 
   for (int i = 0; i < batch_size; ++i) {
     cblas_dgemm(
@@ -95,20 +95,20 @@ void matmul<complex64_t>(
     complex64_t* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
-    size_t ldc,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc,
     float alpha,
     float beta,
-    size_t batch_size,
+    int64_t batch_size,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides) {
   auto ndim = a_shape.size();
-  size_t M = a_shape[ndim - 2];
-  size_t N = b_shape[ndim - 1];
-  size_t K = a_shape[ndim - 1];
+  int64_t M = a_shape[ndim - 2];
+  int64_t N = b_shape[ndim - 1];
+  int64_t K = a_shape[ndim - 1];
   auto calpha = static_cast<complex64_t>(alpha);
   auto cbeta = static_cast<complex64_t>(beta);

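Each of the `for (int i = 0; i < batch_size; ++i)` loops above issues one ordinary CBLAS GEMM per matrix in the batch. As a point of reference only (not part of this diff, and independent of the mlx types), a minimal standalone cblas_sgemm call with the same argument order is sketched below; the matrices and leading dimensions are made up for illustration.

#include <cblas.h>
#include <cstdio>

int main() {
  // C (2x2) = 1.0 * A (2x3) * B (3x2) + 0.0 * C, all row-major.
  float A[6] = {1, 2, 3, 4, 5, 6};
  float B[6] = {1, 0, 0, 1, 1, 1};
  float C[4] = {0, 0, 0, 0};
  cblas_sgemm(
      CblasRowMajor,
      CblasNoTrans, // the diff passes a_transposed ? CblasTrans : CblasNoTrans
      CblasNoTrans,
      2, // M
      2, // N
      3, // K
      1.0f, // alpha
      A,
      3, // lda: row-major leading dimension of A
      B,
      2, // ldb
      0.0f, // beta
      C,
      2); // ldc
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // expect 4 5 / 10 11
  return 0;
}

In the reference CBLAS interface the size and leading-dimension parameters are plain int (64-bit only in ILP64 builds), so int64_t values such as lda/ldb/ldc above are converted at the call site.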
@@ -11,9 +11,9 @@ namespace mlx::core {
 
 // n = 2^k component
 template <typename T>
-void hadamard_n(T* out, int n, int m, float scale, size_t size) {
+void hadamard_n(T* out, int n, int /* m */, float scale, int64_t size) {
   for (int b = 0; b < size / n; b++) {
-    size_t loc = b * n;
+    int64_t loc = b * n;
     T* data_ptr = out + loc;
     int h = 1;
     int n_over_2 = n / 2;
@@ -37,7 +37,7 @@ void hadamard_n(T* out, int n, int m, float scale, size_t size) {
 
 // m component
 template <typename T>
-void hadamard_m(T* out, int n, int m, float scale, size_t size) {
+void hadamard_m(T* out, int n, int m, float scale, int64_t size) {
   auto h_matrices = hadamard_matrices();
   auto& matrix = h_matrices[m];
   auto start = 1;
@@ -45,7 +45,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
   std::vector<bool> hmat_vec;
   while (end != std::string_view::npos) {
     auto row = matrix.substr(start, end - start);
-    for (int i = 0; i < row.length(); i++) {
+    for (int i = 0; i < std::ssize(row); i++) {
       hmat_vec.push_back(row[i] == '+');
     }
     start = end + 1;
@@ -53,7 +53,7 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
   }
 
   for (int b = 0; b < size / m / n; b++) {
-    size_t loc = b * n * m;
+    int64_t loc = b * n * m;
     T* data_ptr = out + loc;
     for (int i = 0; i < n; i++) {
       std::vector<float> out(m);

@@ -78,7 +78,7 @@ void gather(
     can_copy = true;
 
     // Ignore leading 1s
-    int i = 0;
+    int64_t i = 0;
     for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
       ;
 
@@ -91,7 +91,7 @@ void gather(
     can_copy = true;
 
     // Ignore trailing 1s
-    int i = slice_sizes.size() - 1;
+    int64_t i = slice_sizes.size() - 1;
     for (; i >= 0 && slice_sizes[i] == 1; --i)
       ;
 
@@ -101,11 +101,11 @@ void gather(
       can_copy = (src.shape(i) == slice_sizes[i]);
     }
   }
-  size_t slice_size = 1;
+  int64_t slice_size = 1;
   for (auto s : slice_sizes) {
     slice_size *= s;
   }
-  size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
+  int64_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
   const T* src_ptr = src.data<T>();
   T* dst_ptr = out.data<T>();
 
@@ -115,10 +115,10 @@ void gather(
     src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
   }
 
-  size_t out_idx = 0;
-  for (int idx = 0; idx < ind_size; idx++) {
-    size_t src_idx = 0;
-    for (int ii = 0; ii < inds.size(); ++ii) {
+  int64_t out_idx = 0;
+  for (int64_t idx = 0; idx < ind_size; idx++) {
+    int64_t src_idx = 0;
+    for (int ii = 0; ii < std::ssize(inds); ++ii) {
       auto ax = axes[ii];
       auto idx_loc = its[ii].loc;
       its[ii].step();
@@ -134,7 +134,7 @@ void gather(
           src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
       out_idx += slice_size;
     } else {
-      for (int jj = 0; jj < slice_size; jj++) {
+      for (int64_t jj = 0; jj < slice_size; jj++) {
        dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
        src_it.step();
      }
@@ -403,11 +403,11 @@ void scatter(
     const std::vector<int>& axes) {
   int nind = inds.size();
   auto inds_ndim = updates.ndim() - out.ndim();
-  size_t n_updates = nind ? inds[0].size() : 1;
+  int64_t n_updates = nind ? inds[0].size() : 1;
 
   Shape update_shape(
       updates.shape().begin() + inds_ndim, updates.shape().end());
-  size_t update_size = 1;
+  int64_t update_size = 1;
   for (auto us : update_shape) {
     update_size *= us;
   }
@@ -418,9 +418,9 @@ void scatter(
 
   auto out_ptr = out.data<InT>();
   auto upd_ptr = updates.data<InT>();
-  for (int i = 0; i < n_updates; ++i) {
-    size_t out_offset = 0;
-    for (int j = 0; j < inds.size(); ++j) {
+  for (int64_t i = 0; i < n_updates; ++i) {
+    int64_t out_offset = 0;
+    for (int j = 0; j < std::ssize(inds); ++j) {
       auto ax = axes[j];
       auto idx_loc = its[j].loc;
       its[j].step();
@@ -429,7 +429,7 @@ void scatter(
       out_offset += (idx_val * out.strides()[ax]);
     }
     update_it.seek(i * update_size);
-    for (int j = 0; j < update_size; ++j) {
+    for (int64_t j = 0; j < update_size; ++j) {
       OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
       update_it.step();
       out_it.step();

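The reverse scan over slice_sizes in the gather hunks above is the clearest case for signed index types: with an unsigned index, the test `i >= 0` is always true and `--i` wraps around instead of terminating at -1. The following standalone sketch is not taken from the diff; a plain std::vector<int> stands in for the Shape type.

#include <cstdint>
#include <cstdio>
#include <iterator>
#include <vector>

// Find the last axis whose slice size is not 1, scanning from the back.
// A signed index (or std::ssize, C++20) keeps the `i >= 0` test meaningful;
// with size_t the loop would wrap past zero instead of stopping at -1.
int64_t last_non_unit_axis(const std::vector<int>& slice_sizes) {
  int64_t i = std::ssize(slice_sizes) - 1;
  for (; i >= 0 && slice_sizes[i] == 1; --i)
    ;
  return i; // -1 if every slice size is 1
}

int main() {
  std::printf("%lld\n", (long long)last_non_unit_axis({3, 4, 1, 1})); // prints 1
  return 0;
}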
@@ -25,7 +25,7 @@ inline void mask_matrix(
     const int64_t Y_data_str,
     const int64_t X_mask_str,
     const int64_t Y_mask_str,
-    const size_t mask_offset) {
+    const int64_t mask_offset) {
   int tX = (X + block_size - 1) / block_size;
   int tY = (Y + block_size - 1) / block_size;
 
@@ -61,13 +61,13 @@ inline void segmented_mm(
     T* out,
     bool a_transposed,
     bool b_transposed,
-    size_t lda,
-    size_t ldb,
+    int64_t lda,
+    int64_t ldb,
     const Shape& a_shape,
     const Strides& a_strides,
     const Shape& b_shape,
     const Strides& b_strides,
-    size_t num_segments,
+    int64_t num_segments,
     const Shape& segments_shape,
     const Strides& segments_strides) {
   int ndim = a_shape.size();
@@ -149,9 +149,9 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto [b_transposed, ldb, b, b_copied] =
       check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
 
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
+  int64_t M = a.shape(-2);
+  int64_t N = b.shape(-1);
+  int64_t K = a.shape(-1);
 
   if (M == 0 || N == 0) {
     return;
@@ -172,8 +172,8 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     int batch_idx,
     int X,
     int Y,
-    size_t X_data_str,
-    size_t Y_data_str,
+    int64_t X_data_str,
+    int64_t Y_data_str,
     const Shape& mask_shape,
     const Strides& mask_strides,
     bool is_bool) {
@@ -253,7 +253,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto a_ptr = a.data<float>();
   auto b_ptr = b.data<float>();
   auto out_ptr = out.data<float>();
-  size_t num_matrices = out.size() / (M * size_t(N));
+  int64_t num_matrices = out.size() / (M * int64_t(N));
   auto ldc = out.shape(-1);
 
   encoder.dispatch([a_ptr,
@@ -394,9 +394,9 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto [a_transposed, lda, a] = check_transpose(a_pre);
   auto [b_transposed, ldb, b] = check_transpose(b_pre);
 
-  size_t M = a.shape(-2);
-  size_t N = b.shape(-1);
-  size_t K = a.shape(-1);
+  int64_t M = a.shape(-2);
+  int64_t N = b.shape(-1);
+  int64_t K = a.shape(-1);
 
   if (M == 0 || N == 0) {
     return;
@@ -413,7 +413,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
 
   // Get batch dims
   auto batch_size_out = out.size() / (M * N);
-  size_t matrix_stride_out = M * N;
+  int64_t matrix_stride_out = M * N;
 
   auto get_batch_dims = [](const auto& v) {
     return decltype(v){v.begin(), v.end() - 2};

@@ -48,7 +48,7 @@ static std::pair<array, bool> compute_dynamic_offset(
   auto compute_offset =
       [strides, axes, offset = offset.data<int64_t>()](const auto* indices) {
         int64_t offset_ = 0;
-        for (int i = 0; i < axes.size(); ++i) {
+        for (int i = 0; i < std::ssize(axes); ++i) {
           offset_ += indices[i] * strides[axes[i]];
         }
         offset[0] = offset_;
@@ -193,9 +193,9 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
   flags.row_contiguous = false;
   flags.col_contiguous = false;
   flags.contiguous = false;
-  for (int i = 0; i < inputs.size(); i++) {
+  for (int i = 0; i < std::ssize(inputs); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
-    size_t data_offset = strides[axis_] * sizes[i];
+    int64_t data_offset = strides[axis_] * sizes[i];
     out_slice.copy_shared_buffer(
         out, strides, flags, out_slice.size(), data_offset);
     copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
@@ -205,7 +205,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
 void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  constexpr size_t extra_bytes = 16384;
+  constexpr int64_t extra_bytes = 16384;
   if (in.buffer_size() <= out.nbytes() + extra_bytes &&
       (in.flags().row_contiguous ||
        (allow_col_major_ && in.flags().col_contiguous))) {
@@ -254,8 +254,8 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
   copy_cpu(val, out, CopyType::Scalar, stream());
 
   // Find offset for start of input values
-  size_t data_offset = 0;
-  for (int i = 0; i < axes_.size(); i++) {
+  int64_t data_offset = 0;
+  for (int i = 0; i < std::ssize(axes_); i++) {
     auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i];
     data_offset += out.strides()[ax] * low_pad_size_[i];
   }
@@ -274,10 +274,10 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
   // keys has shape (N1, ..., NK, 2)
   // out has shape (N1, ..., NK, M1, M2, ...)
   auto& keys = inputs[0];
-  size_t num_keys = keys.size() / 2;
+  int64_t num_keys = keys.size() / 2;
 
-  size_t elems_per_key = out.size() / num_keys;
-  size_t bytes_per_key = out.itemsize() * elems_per_key;
+  int64_t elems_per_key = out.size() / num_keys;
+  int64_t bytes_per_key = out.itemsize() * elems_per_key;
   out.set_data(allocator::malloc(out.nbytes()));
 
   auto kptr = inputs[0].data<uint32_t>();
@@ -291,8 +291,8 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
                     num_keys,
                     kshape = keys.shape(),
                     kstrides = keys.strides()]() mutable {
-    size_t out_skip = (bytes_per_key + 4 - 1) / 4;
-    auto half_size = out_skip / 2;
+    int64_t out_skip = (bytes_per_key + 4 - 1) / 4;
+    uintptr_t half_size = out_skip / 2;
     bool even = out_skip % 2 == 0;
     for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) {
       auto ptr = reinterpret_cast<uint32_t*>(cptr);

@@ -13,7 +13,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   const int M = a.shape(-2);
   const int N = a.shape(-1);
   const int lda = M;
-  size_t num_matrices = a.size() / (M * N);
+  int64_t num_matrices = a.size() / (M * N);
 
   // Copy A to inplace input and make it col-contiguous
   array in(a.shape(), a.dtype(), nullptr, {});
@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   auto work = allocator::malloc(sizeof(T) * lwork);
 
   // Loop over matrices
-  for (int i = 0; i < num_matrices; ++i) {
+  for (int64_t i = 0; i < num_matrices; ++i) {
     // Solve
     geqrf<T>(
         &M,
@@ -68,7 +68,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   }
   allocator::free(work);
 
-  for (int i = 0; i < num_matrices; ++i) {
+  for (int64_t i = 0; i < num_matrices; ++i) {
     /// num_reflectors x N
     for (int j = 0; j < num_reflectors; ++j) {
       for (int k = 0; k < j; ++k) {
@@ -97,7 +97,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   work = allocator::malloc(sizeof(T) * lwork);
 
   // Loop over matrices
-  for (int i = 0; i < num_matrices; ++i) {
+  for (int64_t i = 0; i < num_matrices; ++i) {
     // Compute Q
     orgqr<T>(
         &M,
@@ -111,7 +111,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
         &info);
   }
 
-  for (int i = 0; i < num_matrices; ++i) {
+  for (int64_t i = 0; i < num_matrices; ++i) {
     // M x num_reflectors
     for (int j = 0; j < M; ++j) {
       for (int k = 0; k < num_reflectors; ++k) {

@@ -79,7 +79,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
 
   // Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4
   // and another one for Pi/4<x<=Pi/2. Both branches will be computed.
-  auto poly_mask = (emm2 & 2) != 0;
+  auto poly_mask =
+      (emm2 & static_cast<uint32_t>(2)) != static_cast<uint32_t>(0);
 
   // The magic pass: "Extended precision modular arithmetic"
   // x = ((x - y * DP1) - y * DP2) - y * DP3
@@ -87,8 +88,8 @@ Simd<T, N> sincos(Simd<T, N> in) {
   x = fma(y, Simd<float, N>(-2.4187564849853515625e-4f), x);
   x = fma(y, Simd<float, N>(-3.77489497744594108e-8f), x);
 
-  sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0);
-  auto sign_mask_cos = ((emm2 - 2) & 4) != 0;
+  sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != static_cast<uint32_t>(0));
+  auto sign_mask_cos = ((emm2 - 2) & 4) != static_cast<uint32_t>(0);
 
   // Evaluate the first polynom (0 <= x <= Pi/4) in y1,
   // and the second polynom (Pi/4 <= x <= 0) in y2

@@ -120,8 +120,8 @@ template <typename T>
 void sort(array& out, int axis) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + out.ndim() : axis;
-  size_t in_size = out.size();
-  size_t n_rows = in_size / out.shape(axis);
+  int64_t in_size = out.size();
+  int64_t n_rows = in_size / out.shape(axis);
 
   auto remaining_shape = out.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);
@@ -136,7 +136,7 @@ void sort(array& out, int axis) {
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
   auto out_ptr = out.data<T>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     T* data_ptr = out_ptr + src_it.loc;
 
     StridedIterator st(data_ptr, axis_stride, 0);
@@ -151,7 +151,7 @@ template <typename T, typename IdxT = uint32_t>
 void argsort(const array& in, array& out, int axis) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t n_rows = in.size() / in.shape(axis);
+  int64_t n_rows = in.size() / in.shape(axis);
 
   auto in_remaining_shape = in.shape();
   in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -176,7 +176,7 @@ void argsort(const array& in, array& out, int axis) {
       out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
   auto in_ptr = in.data<T>();
   auto out_ptr = out.data<IdxT>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     const T* data_ptr = in_ptr + in_it.loc;
     IdxT* idx_ptr = out_ptr + out_it.loc;
 
@@ -214,8 +214,8 @@ template <typename T>
 void partition(array& out, int axis, int kth) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + out.ndim() : axis;
-  size_t in_size = out.size();
-  size_t n_rows = in_size / out.shape(axis);
+  int64_t in_size = out.size();
+  int64_t n_rows = in_size / out.shape(axis);
 
   auto remaining_shape = out.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);
@@ -232,7 +232,7 @@ void partition(array& out, int axis, int kth) {
   ContiguousIterator src_it(
       remaining_shape, remaining_strides, remaining_shape.size());
   auto out_ptr = out.data<T>();
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     T* data_ptr = out_ptr + src_it.loc;
     src_it.step();
 
@@ -248,7 +248,7 @@ template <typename T, typename IdxT = uint32_t>
 void argpartition(const array& in, array& out, int axis, int kth) {
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t n_rows = in.size() / in.shape(axis);
+  int64_t n_rows = in.size() / in.shape(axis);
 
   auto in_remaining_shape = in.shape();
   in_remaining_shape.erase(in_remaining_shape.begin() + axis);
@@ -277,7 +277,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
   auto in_ptr = in.data<T>();
   auto out_ptr = out.data<IdxT>();
 
-  for (int i = 0; i < n_rows; i++) {
+  for (int64_t i = 0; i < n_rows; i++) {
     const T* data_ptr = in_ptr + in_it.loc;
     IdxT* idx_ptr = out_ptr + out_it.loc;
     in_it.step();

@@ -27,7 +27,7 @@ void svd_impl(
   const int N = a.shape(-1);
   const int K = std::min(M, N);
 
-  size_t num_matrices = a.size() / (M * N);
+  int64_t num_matrices = a.size() / (M * N);
 
   // lapack clobbers the input, so we have to make a copy.
   array in(a.shape(), a.dtype(), nullptr, {});
@@ -121,7 +121,7 @@ void svd_impl(
   auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
 
   // Loop over matrices.
-  for (int i = 0; i < num_matrices; i++) {
+  for (int64_t i = 0; i < num_matrices; i++) {
     gesdd<T>(
         /* jobz = */ jobz,
         // M and N are swapped since lapack expects column-major.
@@ -153,10 +153,10 @@ void svd_impl(
 
 template <typename T>
 void compute_svd(
-    const array& a,
-    bool compute_uv,
-    std::vector<array>& outputs,
-    Stream stream) {}
+    const array& /* a */,
+    bool /* compute_uv */,
+    std::vector<array>& /* outputs */,
+    Stream /* stream */) {}
 
 void SVD::eval_cpu(
     const std::vector<array>& inputs,

@@ -136,7 +136,7 @@ void ternary_op(
   if (topt == TernaryOpType::ScalarScalarScalar) {
     *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
   } else if (topt == TernaryOpType::VectorVectorVector) {
-    for (size_t i = 0; i < out.size(); ++i) {
+    for (int64_t i = 0; i < out.size(); ++i) {
       *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
       a_ptr++;
       b_ptr++;

@@ -10,8 +10,8 @@
 namespace mlx::core {
 
 template <typename T, typename U = T, typename Op>
-void unary_op(const T* a, U* out, size_t shape, size_t stride) {
-  for (size_t i = 0; i < shape; i += 1) {
+void unary_op(const T* a, U* out, int64_t shape, int64_t stride) {
+  for (int64_t i = 0; i < shape; i += 1) {
     out[i] = Op{}(*a);
     a += stride;
   }
@@ -38,14 +38,14 @@ void unary_op(const array& a, array& out, Op) {
       src++;
     }
   } else {
-    size_t shape = ndim > 0 ? a.shape().back() : 1;
-    size_t stride = ndim > 0 ? a.strides().back() : 1;
+    int64_t shape = ndim > 0 ? a.shape().back() : 1;
+    int64_t stride = ndim > 0 ? a.strides().back() : 1;
     if (ndim <= 1) {
       unary_op<T, U, Op>(src, dst, shape, stride);
       return;
     }
     auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
-    for (size_t elem = 0; elem < a.size(); elem += shape) {
+    for (int64_t elem = 0; elem < a.size(); elem += shape) {
       unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
       it.step();
     }