mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Start to cleanup/unify accelerate and common back-ends (Part 1/N) (#1777)
* start to cleanup/unify accelerate and common back-ends
* more progress
* simplify
* add half type and allow infs in simd exp
* unify softmax + quantized, more dispatches to simd quantized mm
* add sin/cos, use simd in vector-scalar ops
* faster CPU vectorize quant
* faster erf/erfinv
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/simd/simd.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -38,8 +39,19 @@ void unary_op(const array& a, array& out, Op op) {
|
||||
if (a.flags().contiguous) {
|
||||
set_unary_output_data(a, out);
|
||||
U* dst = out.data<U>();
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
dst[i] = op(a_ptr[i]);
|
||||
constexpr int N = simd::max_size<T>;
|
||||
size_t size = a.data_size();
|
||||
while (size >= N) {
|
||||
simd::store(dst, op(simd::load<T, N>(a_ptr)));
|
||||
size -= N;
|
||||
a_ptr += N;
|
||||
dst += N;
|
||||
}
|
||||
while (size > 0) {
|
||||
*dst = op(*a_ptr);
|
||||
size--;
|
||||
dst++;
|
||||
a_ptr++;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
Reference in New Issue
Block a user