mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster cpu ops (#1434)
* faster binary and cleaner copy * use recursive template for other ops * more cleanup * fix from cleanup * more clean * fix binary * use contiguous iterator * add 3d * nits * fix * fix? * fix * fix rebase
This commit is contained in:
@@ -24,6 +24,14 @@ void set_unary_output_data(const array& in, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) {
|
||||
for (size_t i = 0; i < shape; i += 1) {
|
||||
out[i] = op(*a);
|
||||
a += stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
@@ -36,10 +44,16 @@ void unary_op(const array& a, array& out, Op op) {
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
T* dst = out.data<T>();
|
||||
for (size_t i = 0; i < out.size(); ++i) {
|
||||
// TODO this is super inefficient, need to fix.
|
||||
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||
dst[i] = op(a_ptr[a_idx]);
|
||||
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
|
||||
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
|
||||
if (a.ndim() <= 1) {
|
||||
unary_op(a_ptr, dst, op, shape, stride);
|
||||
return;
|
||||
}
|
||||
ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
|
||||
for (size_t elem = 0; elem < a.size(); elem += shape) {
|
||||
unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
|
||||
it.step();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user