From e142aaf8a1b1530c5e13b5931336af0b8cde02df Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 4 Apr 2024 08:32:35 -0700
Subject: [PATCH] Option for precise softmax (#953)

* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise cpu softmax

* precise option on cpu

* remove print

---------

Co-authored-by: Angelos Katharopoulos
---
 mlx/backend/accelerate/softmax.cpp      |  92 ++++++++++++++-----
 mlx/backend/common/softmax.cpp          |  40 ++++++---
 mlx/backend/metal/kernels/softmax.metal | 115 ++++++++++++++----------
 mlx/backend/metal/softmax.cpp           |   5 +-
 mlx/fast.cpp                            |   5 +-
 mlx/ops.cpp                             |  21 +++--
 mlx/ops.h                               |   8 +-
 mlx/primitives.cpp                      |   9 +-
 mlx/primitives.h                        |   7 +-
 python/src/ops.cpp                      |   5 +-
 python/tests/test_ops.py                |   7 ++
 11 files changed, 215 insertions(+), 99 deletions(-)

diff --git a/mlx/backend/accelerate/softmax.cpp b/mlx/backend/accelerate/softmax.cpp
index 8b95e32d4..0f415c2b8 100644
--- a/mlx/backend/accelerate/softmax.cpp
+++ b/mlx/backend/accelerate/softmax.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <cassert>
 #include <limits>
@@ -201,7 +201,7 @@ struct NeonFp16SimdOps {
   }
 };
 
-template <typename T, typename VT, typename Ops, int N>
+template <typename T, typename AccT, typename VT, typename Ops, int N>
 void softmax(const array& in, array& out) {
   Ops ops;
 
@@ -218,13 +218,21 @@ void softmax(const array& in, array& out) {
     VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
     size_t s = M;
     while (s >= N) {
-      vmaximum = ops.max(ops.load(current_in_ptr), vmaximum);
+      VT vals;
+      if constexpr (std::is_same<T, AccT>::value) {
+        vals = ops.load(current_in_ptr);
+      } else {
+        for (int i = 0; i < N; ++i) {
+          vals[i] = static_cast<AccT>(current_in_ptr[i]);
+        }
+      }
+      vmaximum = ops.max(vals, vmaximum);
       current_in_ptr += N;
       s -= N;
     }
-    T maximum = ops.reduce_max(vmaximum);
+    AccT maximum = ops.reduce_max(vmaximum);
     while (s-- > 0) {
-      maximum = std::max(maximum, *current_in_ptr);
+      maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
       current_in_ptr++;
     }
 
@@ -234,18 +242,29 @@ void softmax(const array& in, array& out) {
     current_in_ptr = in_ptr;
     s = M;
     while (s >= N) {
-      VT vexp = ops.exp(ops.sub(*(VT*)current_in_ptr, maximum));
-      ops.store(current_out_ptr, vexp);
-      *(VT*)current_out_ptr = vexp;
+      VT vexp;
+      if constexpr (std::is_same<T, AccT>::value) {
+        vexp = ops.load(current_in_ptr);
+      } else {
+        for (int i = 0; i < N; ++i) {
+          vexp[i] = static_cast<AccT>(current_in_ptr[i]);
+        }
+      }
+      vexp = ops.exp(ops.sub(vexp, maximum));
+      if constexpr (std::is_same<T, AccT>::value) {
+        ops.store(current_out_ptr, vexp);
+      }
       vnormalizer = ops.add(vnormalizer, vexp);
       current_in_ptr += N;
       current_out_ptr += N;
       s -= N;
     }
-    T normalizer = ops.reduce_add(vnormalizer);
+    AccT normalizer = ops.reduce_add(vnormalizer);
     while (s-- > 0) {
-      T _exp = std::exp(*current_in_ptr - maximum);
-      *current_out_ptr = _exp;
+      AccT _exp = std::exp(*current_in_ptr - maximum);
+      if (std::is_same<T, AccT>::value) {
+        *current_out_ptr = _exp;
+      }
       normalizer += _exp;
       current_in_ptr++;
       current_out_ptr++;
@@ -254,14 +273,33 @@ void softmax(const array& in, array& out) {
 
     // Normalize
     current_out_ptr = out_ptr;
+    current_in_ptr = in_ptr;
    s = M;
     while (s >= N) {
-      ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
+      if constexpr (std::is_same<T, AccT>::value) {
+        ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
+      } else {
+        VT vexp;
+        for (int i = 0; i < N; ++i) {
+          vexp[i] = static_cast<AccT>(current_in_ptr[i]);
+        }
+        vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
+        for (int i = 0; i < N; ++i) {
+          current_out_ptr[i] = vexp[i];
+        }
+        current_in_ptr += N;
+      }
       current_out_ptr += N;
       s -= N;
     }
     while (s-- > 0) {
-      *current_out_ptr *= normalizer;
+      if constexpr (std::is_same<T, AccT>::value) {
+        *current_out_ptr *= normalizer;
+      } else {
+        AccT _exp = std::exp(*current_in_ptr - maximum);
+        *current_out_ptr = static_cast<T>(_exp * normalizer);
+        current_in_ptr++;
+      }
       current_out_ptr++;
     }
   }
@@ -308,15 +346,29 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
           "Softmax is defined only for floating point types");
       break;
     case float32:
-      softmax<float, simd_float16, AccelerateSimdOps<float, simd_float16>, 16>(
-          in, out);
+      softmax<
+          float,
+          float,
+          simd_float16,
+          AccelerateSimdOps<float, simd_float16>,
+          16>(in, out);
       break;
     case float16:
-      softmax<
-          float16_t,
-          float16x8_t,
-          NeonFp16SimdOps<float16_t, float16x8_t>,
-          8>(in, out);
+      if (precise_) {
+        softmax<
+            float16_t,
+            float,
+            simd_float16,
+            AccelerateSimdOps<float, simd_float16>,
+            16>(in, out);
+      } else {
+        softmax<
+            float16_t,
+            float16_t,
+            float16x8_t,
+            NeonFp16SimdOps<float16_t, float16x8_t>,
+            8>(in, out);
+      }
       break;
     case bfloat16:
       eval(inputs, out);
diff --git a/mlx/backend/common/softmax.cpp b/mlx/backend/common/softmax.cpp
index 777163b9f..ed4e3958b 100644
--- a/mlx/backend/common/softmax.cpp
+++ b/mlx/backend/common/softmax.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <cassert>
 #include <cmath>
@@ -10,7 +10,7 @@ namespace mlx::core {
 
 namespace {
 
-template <typename T>
+template <typename T, typename AccT>
 void softmax(const array& in, array& out) {
   const T* in_ptr = in.data<T>();
   T* out_ptr = out.data<T>();
@@ -22,26 +22,36 @@ void softmax(const array& in, array& out) {
   for (int i = 0; i < M; i++, in_ptr += N, out_ptr += N) {
     // Find the maximum
     current_in_ptr = in_ptr;
-    T maximum = *current_in_ptr;
+    AccT maximum = *current_in_ptr;
     for (int j = 0; j < N; j++, current_in_ptr++) {
-      maximum = (maximum < *current_in_ptr) ? *current_in_ptr : maximum;
+      maximum = (maximum < *current_in_ptr) ? static_cast<AccT>(*current_in_ptr)
+                                            : maximum;
     }
 
     // Compute the normalizer and the exponentials
-    T normalizer = 0;
+    AccT normalizer = 0;
     current_out_ptr = out_ptr;
     current_in_ptr = in_ptr;
     for (int j = 0; j < N; j++, current_out_ptr++, current_in_ptr++) {
-      T expv = std::exp(*current_in_ptr - maximum);
+      AccT expv = std::exp(*current_in_ptr - maximum);
       normalizer += expv;
-      *current_out_ptr = expv;
+      if constexpr (std::is_same<T, AccT>::value) {
+        *current_out_ptr = expv;
+      }
     }
     normalizer = 1 / normalizer;
 
     // Normalize
+    current_in_ptr = in_ptr;
     current_out_ptr = out_ptr;
     for (int j = 0; j < N; j++, current_out_ptr++) {
-      *current_out_ptr *= normalizer;
+      if constexpr (std::is_same<T, AccT>::value) {
+        *current_out_ptr *= normalizer;
+      } else {
+        auto v = std::exp(*current_in_ptr - maximum);
+        *current_out_ptr = static_cast<T>(v * normalizer);
+        current_in_ptr++;
+      }
     }
   }
 }
@@ -91,13 +101,21 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
           "Softmax is defined only for floating point types");
       break;
     case float32:
-      softmax<float>(in, out);
+      softmax<float, float>(in, out);
       break;
     case float16:
-      softmax<float16_t>(in, out);
+      if (precise_) {
+        softmax<float16_t, float>(in, out);
+      } else {
+        softmax<float16_t, float16_t>(in, out);
+      }
       break;
     case bfloat16:
-      softmax<bfloat16_t>(in, out);
+      if (precise_) {
+        softmax<bfloat16_t, float>(in, out);
+      } else {
+        softmax<bfloat16_t, bfloat16_t>(in, out);
+      }
       break;
     case complex64:
       throw std::invalid_argument(
diff --git a/mlx/backend/metal/kernels/softmax.metal b/mlx/backend/metal/kernels/softmax.metal
index 2fdcaaa56..4bf9f2916 100644
--- a/mlx/backend/metal/kernels/softmax.metal
+++ b/mlx/backend/metal/kernels/softmax.metal
@@ -11,46 +11,48 @@ using namespace metal;
 
 template <typename T>
 inline T softmax_exp(T x) {
-  // Softmax doesn't need high precision exponential cause it is gonna be x
-  // will be in (-oo, 0] anyway and subsequently it will be divided by
-  // sum(exp(x_i)).
+  // Softmax doesn't need high precision exponential cause x is gonna be in
+  // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
   return fast::exp(x);
 }
 
-template <typename T, int N_READS = SOFTMAX_N_READS>
+template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
 [[kernel]] void softmax_single_row(
     const device T* in,
     device T* out,
     constant int& axis_size,
-    threadgroup T* local_max [[threadgroup(0)]],
-    threadgroup T* local_normalizer [[threadgroup(1)]],
     uint gid [[threadgroup_position_in_grid]],
     uint _lid [[thread_position_in_threadgroup]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   int lid = _lid;
 
-  T ld[N_READS];
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup AccT local_max[SIMD_SIZE];
+  threadgroup AccT local_normalizer[SIMD_SIZE];
+
+  AccT ld[N_READS];
 
   in += gid * axis_size + lid * N_READS;
   if (lid * N_READS + N_READS <= axis_size) {
-    for (int i=0; i<N_READS; i++) {
-      ld[i] = in[i];
-    }
+    for (int i = 0; i < N_READS; i++) {
+      ld[i] = AccT(in[i]);
+    }
   } else {
-    for (int i=0; i<N_READS; i++) {
-      ld[i] =
-          ((lid * N_READS + i) < axis_size) ? in[i] : T(Limits<T>::finite_min);
-    }
+    for (int i = 0; i < N_READS; i++) {
+      ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
+                                                : Limits<AccT>::finite_min;
+    }
   }
   if (simd_group_id == 0) {
-    local_max[simd_lane_id] = Limits<T>::finite_min;
+    local_max[simd_lane_id] = Limits<AccT>::finite_min;
     local_normalizer[simd_lane_id] = 0;
   }
   threadgroup_barrier(mem_flags::mem_threadgroup);
 
   // Get the max
-  T maxval = Limits<T>::finite_min;
+  AccT maxval = Limits<AccT>::finite_min;
   for (int i = 0; i < N_READS; i++) {
     maxval = (maxval < ld[i]) ? ld[i] : maxval;
   }
@@ -69,9 +71,9 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
   maxval = local_max[0];
 
   // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
-  T normalizer = 0;
+  AccT normalizer = 0;
   for (int i = 0; i < N_READS; i++) {
-    T exp_x = softmax_exp(ld[i] - maxval);
+    AccT exp_x = softmax_exp(ld[i] - maxval);
     ld[i] = exp_x;
     normalizer += exp_x;
   }
@@ -92,25 +94,23 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
   // Normalize and write to the output
   out += gid * axis_size + lid * N_READS;
   if (lid * N_READS + N_READS <= axis_size) {
-    for (int i=0; i<N_READS; i++) {
-      out[i] = ld[i] * normalizer;
-    }
+    for (int i = 0; i < N_READS; i++) {
+      out[i] = T(ld[i] * normalizer);
+    }
   } else {
-    for (int i=0; i<N_READS; i++) {
-      if ((lid * N_READS + i) < axis_size) {
-        out[i] = ld[i] * normalizer;
-      }
-    }
+    for (int i = 0; i < N_READS; i++) {
+      if ((lid * N_READS + i) < axis_size) {
+        out[i] = T(ld[i] * normalizer);
+      }
+    }
   }
 }
 
-template <typename T, int N_READS = SOFTMAX_N_READS>
+template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
 [[kernel]] void softmax_looped(
     const device T* in,
     device T* out,
     constant int& axis_size,
-    threadgroup T* local_max [[threadgroup(0)]],
-    threadgroup T* local_normalizer [[threadgroup(1)]],
     uint gid [[threadgroup_position_in_grid]],
     uint lid [[thread_position_in_threadgroup]],
     uint lsize [[threads_per_threadgroup]],
@@ -118,22 +118,27 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
     uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
   in += gid * axis_size;
 
+  constexpr int SIMD_SIZE = 32;
+
+  threadgroup AccT local_max[SIMD_SIZE];
+  threadgroup AccT local_normalizer[SIMD_SIZE];
+
   // Get the max and the normalizer in one go
-  T prevmax;
-  T maxval = Limits<T>::finite_min;
-  T normalizer = 0;
+  AccT prevmax;
+  AccT maxval = Limits<AccT>::finite_min;
+  AccT normalizer = 0;
   for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
        r++) {
     int offset = r * lsize * N_READS + lid * N_READS;
-    T vals[N_READS];
+    AccT vals[N_READS];
     if (offset + N_READS <= axis_size) {
       for (int i = 0; i < N_READS; i++) {
-        vals[i] = in[offset + i];
+        vals[i] = AccT(in[offset + i]);
       }
     } else {
       for (int i = 0; i < N_READS; i++) {
-        vals[i] =
-            (offset + i < axis_size) ? in[offset + i] : T(Limits<T>::finite_min);
+        vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
+                                           : Limits<AccT>::finite_min;
       }
     }
     prevmax = maxval;
@@ -179,50 +184,66 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
        r++) {
     int offset = r * lsize * N_READS + lid * N_READS;
     if (offset + N_READS <= axis_size) {
-      for (int i=0; i<N_READS; i++) {
-        out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
-      }
+      for (int i = 0; i < N_READS; i++) {
+        out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
+      }
     } else {
-      for (int i=0; i<N_READS; i++) {
-        if (offset + i < axis_size) {
-          out[offset + i] = softmax_exp(in[offset + i] - maxval) * normalizer;
-        }
-      }
+      for (int i = 0; i < N_READS; i++) {
+        if (offset + i < axis_size) {
+          out[offset + i] =
+              T(softmax_exp(in[offset + i] - maxval) * normalizer);
+        }
+      }
     }
   }
 }
 
-#define instantiate_softmax_single_row(name, itype) \
+// clang-format off
+#define instantiate_softmax(name, itype) \
   template [[host_name("softmax_" #name)]] [[kernel]] void \
   softmax_single_row<itype>( \
       const device itype* in, \
       device itype* out, \
       constant int& axis_size, \
-      threadgroup itype* local_max [[threadgroup(0)]], \
-      threadgroup itype* local_normalizer [[threadgroup(1)]], \
      uint gid [[thread_position_in_grid]], \
       uint _lid [[thread_position_in_threadgroup]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
-
-#define instantiate_softmax_looped(name, itype) \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
   template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
   softmax_looped<itype>( \
       const device itype* in, \
       device itype* out, \
       constant int& axis_size, \
-      threadgroup itype* local_max [[threadgroup(0)]], \
-      threadgroup itype* local_normalizer [[threadgroup(1)]], \
       uint gid [[threadgroup_position_in_grid]], \
       uint lid [[thread_position_in_threadgroup]], \
       uint lsize [[threads_per_threadgroup]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
-#define instantiate_softmax(name, itype) \
-  instantiate_softmax_single_row(name, itype) \
-  instantiate_softmax_looped(name, itype)
+#define instantiate_softmax_precise(name, itype) \
+  template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
+  softmax_single_row<itype, float>( \
+      const device itype* in, \
+      device itype* out, \
+      constant int& axis_size, \
+      uint gid [[thread_position_in_grid]], \
+      uint _lid [[thread_position_in_threadgroup]], \
+      uint simd_lane_id [[thread_index_in_simdgroup]], \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
+  template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
+  softmax_looped<itype, float>( \
+      const device itype* in, \
+      device itype* out, \
+      constant int& axis_size, \
+      uint gid [[threadgroup_position_in_grid]], \
+      uint lid [[thread_position_in_threadgroup]], \
+      uint lsize [[threads_per_threadgroup]], \
+      uint simd_lane_id [[thread_index_in_simdgroup]], \
+      uint simd_group_id [[simdgroup_index_in_threadgroup]]);
 
 instantiate_softmax(float32, float)
 instantiate_softmax(float16, half)
 instantiate_softmax(bfloat16, bfloat16_t)
+instantiate_softmax_precise(float16, half)
+instantiate_softmax_precise(bfloat16, bfloat16_t)
+// clang-format on
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index 12d89a665..58b19141c 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -56,6 +56,9 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (axis_size > looped_limit) {
     op_name += "looped_";
   }
+  if (in.dtype() != float32 && precise_) {
+    op_name += "precise_";
+  }
   op_name += type_to_name(out);
   auto compute_encoder = d.get_command_encoder(s.index);
   {
@@ -82,8 +85,6 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
         compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
     set_array_buffer(compute_encoder, out, 1);
     compute_encoder->setBytes(&axis_size, sizeof(int), 2);
-    compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
-    compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
     compute_encoder->dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 6c8f33d79..52e7c8c21 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -550,10 +550,7 @@ array scaled_dot_product_attention(
   if (needs_mask) {
     scores = add(scores, inputs[3], s);
   }
-  scores = astype(
-      softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
-      final_type,
-      s);
+  scores = softmax(scores, std::vector<int>{-1}, true, s);
   auto out = matmul(scores, v, s);
   if (n_repeats > 1) {
     out = reshape(out, {B, n_q_heads, L, -1}, s);
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index c1e9c8e0f..1c9e930fc 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2619,25 +2619,34 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
 array softmax(
     const array& a,
     const std::vector<int>& axes,
+    bool precise /* = false */,
     StreamOrDevice s /* = {}*/) {
   if (axes.size() == 1 && (a.ndim() == axes[0] + 1 || axes[0] == -1)) {
     auto dtype = at_least_float(a.dtype());
     return array(
         a.shape(),
         dtype,
-        std::make_shared<Softmax>(to_stream(s)),
+        std::make_shared<Softmax>(to_stream(s), precise),
         {astype(a, dtype, s)});
   } else {
-    auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
-    auto ex = exp(subtract(a, a_max, s), s);
-    return divide(ex, sum(ex, axes, /*keepdims = */ true, s), s);
+    auto in = a;
+    if (precise) {
+      in = astype(a, float32, s);
+    }
+    auto a_max = stop_gradient(max(in, axes, /*keepdims = */ true, s), s);
+    auto ex = exp(subtract(in, a_max, s), s);
+    return astype(
+        divide(ex, sum(ex, axes, /*keepdims = */ true, s), s), a.dtype(), s);
   }
 }
 
-array softmax(const array& a, StreamOrDevice s /* = {}*/) {
+array softmax(
+    const array& a,
+    bool precise /* = false */,
+    StreamOrDevice s /* = {}*/) {
   std::vector<int> axes(a.ndim());
   std::iota(axes.begin(), axes.end(), 0);
-  return softmax(a, axes, s);
+  return softmax(a, axes, precise, s);
 }
 
 array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
diff --git a/mlx/ops.h b/mlx/ops.h
index 69059ed75..e59ceaddb 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -976,14 +976,16 @@ array rsqrt(const array& a, StreamOrDevice s = {});
 array softmax(
     const array& a,
     const std::vector<int>& axes,
+    bool precise = false,
     StreamOrDevice s = {});
 
 /** Softmax of an array. */
-array softmax(const array& a, StreamOrDevice s = {});
+array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
 
 /** Softmax of an array. */
-inline array softmax(const array& a, int axis, StreamOrDevice s = {}) {
-  return softmax(a, std::vector<int>{axis}, s);
+inline array
+softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
+  return softmax(a, std::vector<int>{axis}, precise, s);
 }
 
 /** Raise elements of a to the power of b element-wise */
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 4529e0131..09a456c47 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2975,7 +2975,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
   } else {
     softmax_axes.push_back(-2);
   }
-  return {{softmax(inputs[0], softmax_axes, stream())}, axes};
+  return {{softmax(inputs[0], softmax_axes, precise_, stream())}, axes};
 }
 
 std::vector<array> Softmax::vjp(
@@ -2998,13 +2998,18 @@ std::vector<array> Softmax::jvp(
     const std::vector<int>& argnums) {
   assert(primals.size() == 1);
   assert(tangents.size() == 1);
-  auto s = softmax(primals[0], std::vector<int>{-1}, stream());
+  auto s = softmax(primals[0], std::vector<int>{-1}, precise_, stream());
   auto sv = multiply(s, tangents[0], stream());
   return {subtract(
       sv,
       multiply(s, sum(sv, std::vector<int>{-1}, true, stream()), stream()))};
 }
 
+bool Softmax::is_equivalent(const Primitive& other) const {
+  const Softmax& s_other = static_cast<const Softmax&>(other);
+  return precise_ == s_other.precise_;
+}
+
 std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 9e36afbd1..fd05b23f0 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1702,7 +1702,8 @@ class SliceUpdate : public UnaryPrimitive {
 
 class Softmax : public UnaryPrimitive {
  public:
-  explicit Softmax(Stream stream) : UnaryPrimitive(stream){};
+  explicit Softmax(Stream stream, bool precise)
+      : UnaryPrimitive(stream), precise_(precise){};
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1710,11 +1711,13 @@ class Softmax : public UnaryPrimitive {
   DEFINE_VMAP()
   DEFINE_GRADS()
   DEFINE_PRINT(Softmax)
-  DEFINE_DEFAULT_IS_EQUIVALENT()
   DEFINE_INPUT_OUTPUT_SHAPE()
 
+  bool is_equivalent(const Primitive& other) const override;
+
  private:
   void eval(const std::vector<array>& inputs, array& out);
+  bool precise_;
 };
 
 class Sort : public UnaryPrimitive {
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 87d483c50..d8e31daf0 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -2430,12 +2430,13 @@ void init_ops(nb::module_& m) {
       )pbdoc");
   m.def(
       "softmax",
-      [](const array& a, const IntOrVec& axis, StreamOrDevice s) {
-        return softmax(a, get_reduce_axes(axis, a.ndim()), s);
+      [](const array& a, const IntOrVec& axis, bool precise, StreamOrDevice s) {
+        return softmax(a, get_reduce_axes(axis, a.ndim()), precise, s);
       },
       nb::arg(),
       "axis"_a = nb::none(),
       nb::kw_only(),
+      "precise"_a = false,
      "stream"_a = nb::none(),
       nb::sig(
           "def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 417272ad8..7f2e98e70 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1430,6 +1430,13 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.softmax(y[:, 0:2], axis=-1)
         self.assertAlmostEqual(out.sum().item(), 8.0, 5)
 
+        # Precise
+        for t in [mx.float16, mx.bfloat16]:
+            a = (10 * mx.random.normal(shape=(1024,))).astype(t)
+            out_expect = mx.softmax(a.astype(mx.float32)).astype(t)
+            out = mx.softmax(a, axis=-1, precise=True)
+            self.assertTrue(mx.allclose(out_expect, out))
+
     def test_concatenate(self):
         a_npy = np.random.randn(32, 32, 32)
         b_npy = np.random.randn(32, 32, 32)
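
Example usage of the new flag (editor's sketch, not part of the upstream patch; it assumes an MLX build that includes this change). With precise=True the output keeps the input dtype, but the max, exponentials, and normalization are accumulated in float32 internally, which is what the new test compares against:

    import mlx.core as mx

    a = (10 * mx.random.normal(shape=(1024,))).astype(mx.float16)

    # Default path: the reduction runs in the input dtype (float16 here).
    fast = mx.softmax(a, axis=-1)

    # precise=True: accumulate in float32 internally, still return float16.
    accurate = mx.softmax(a, axis=-1, precise=True)

    # Reference: upcast, softmax in float32, downcast (the fallback path in mlx::core::softmax).
    reference = mx.softmax(a.astype(mx.float32), axis=-1).astype(mx.float16)

    print(mx.allclose(accurate, reference).item())  # expected: True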