From 1c0c118f7c8a9e7ac7d29e43f2e39e0ae69b6fe0 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 7 Feb 2025 15:52:22 -0800
Subject: [PATCH] Fp64 on the CPU (#1843)

* add fp64 data type
* clean build
* update docs
* fix bug
---
 docs/src/python/data_types.rst   |   9 ++
 docs/src/usage/numpy.rst         |  42 +++++---
 mlx/array.cpp                    |  13 ++-
 mlx/array.h                      |   3 +
 mlx/backend/common/common.cpp    |   3 +
 mlx/backend/cpu/arange.h         |   3 +
 mlx/backend/cpu/arg_reduce.cpp   |   3 +
 mlx/backend/cpu/binary.cpp       |  69 +++++++-----
 mlx/backend/cpu/binary.h         |   3 +
 mlx/backend/cpu/binary_two.h     |   3 +
 mlx/backend/cpu/copy.cpp         |   6 ++
 mlx/backend/cpu/gemms/cblas.cpp  |  35 ++++++
 mlx/backend/cpu/indexing.cpp     |  12 +++
 mlx/backend/cpu/matmul.cpp       |   3 +
 mlx/backend/cpu/reduce.cpp       |  10 ++
 mlx/backend/cpu/scan.cpp         |   4 +
 mlx/backend/cpu/select.cpp       |   3 +
 mlx/backend/cpu/softmax.cpp      |   6 +-
 mlx/backend/cpu/sort.cpp         |   8 ++
 mlx/backend/cpu/unary.cpp        |   9 ++
 mlx/backend/cpu/unary.h          |   6 ++
 mlx/backend/metal/primitives.cpp |   4 +-
 mlx/backend/metal/utils.cpp      |   3 +
 mlx/distributed/mpi/mpi.cpp      |   4 +
 mlx/distributed/ring/ring.cpp    |   4 +
 mlx/dtype.cpp                    |  33 +++---
 mlx/dtype.h                      |   2 +
 mlx/types/limits.h               |   3 +
 mlx/utils.cpp                    |  12 ++-
 mlx/utils.h                      |   5 +-
 python/src/array.cpp             |   2 +
 python/tests/test_double.py      | 178 +++++++++++++++++++++++++++++++
 32 files changed, 438 insertions(+), 65 deletions(-)
 create mode 100644 python/tests/test_double.py

diff --git a/docs/src/python/data_types.rst b/docs/src/python/data_types.rst
index c75bfcb9d..4c4a5910d 100644
--- a/docs/src/python/data_types.rst
+++ b/docs/src/python/data_types.rst
@@ -51,11 +51,20 @@ The default floating point type is ``float32`` and the default integer type is
   * - ``float32``
     - 4
     - 32-bit float
+  * - ``float64``
+    - 8
+    - 64-bit double
   * - ``complex64``
     - 8
     - 64-bit complex float
 
+.. note::
+
+  Arrays with type ``float64`` only work with CPU operations. Using
+  ``float64`` arrays on the GPU will result in an exception.
+
+
 Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
 documentation for more information. Use :func:`issubdtype` to determine if
 one ``dtype`` (or category) is a subtype of another category.
diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst
index c589f1887..09b36ddbd 100644
--- a/docs/src/usage/numpy.rst
+++ b/docs/src/usage/numpy.rst
@@ -21,11 +21,13 @@ Let's convert an array to NumPy and back.
 
 .. note::
 
-  Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
-  ``np.array(a.astype(mx.float32))``.
-  Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
+  Since NumPy does not support ``bfloat16`` arrays, you will need to convert
+  to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
+  Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
+  buffer format string does not match the dtype V item size 0.``
 
-By default, NumPy copies data to a new array. This can be prevented by creating an array view:
+By default, NumPy copies data to a new array. This can be prevented by creating
+an array view:
 
 .. code-block:: python
 
@@ -35,10 +37,16 @@ By default, NumPy copies data to a new array. This can be prevented by creating
   a_view[0] = 1
   print(a[0].item()) # 1
 
-A NumPy array view is a normal NumPy array, except that it does not own its memory.
-This means writing to the view is reflected in the original array.
+.. 
note:: -While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients. + NumPy arrays with type ``float64`` will be default converted to MLX arrays + with type ``float32``. + +A NumPy array view is a normal NumPy array, except that it does not own its +memory. This means writing to the view is reflected in the original array. + +While this is quite powerful to prevent copying arrays, it should be noted that +external changes to the memory of arrays cannot be reflected in gradients. Let's demonstrate this in an example: @@ -56,11 +64,12 @@ Let's demonstrate this in an example: The function ``f`` indirectly modifies the array ``x`` through a memory view. -However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``, -representing the gradient of the sum operation alone. -The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated. -It's important to note that a similar issue arises during array conversion and copying. -For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient, +However, this modification is not reflected in the gradient, as seen in the +last line outputting ``1.0``, representing the gradient of the sum operation +alone. The squaring of ``x`` occurs externally to MLX, meaning that no +gradient is incorporated. It's important to note that a similar issue arises +during array conversion and copying. For instance, a function defined as +``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient, even though no in-place operations on MLX memory are executed. PyTorch @@ -71,7 +80,8 @@ PyTorch PyTorch Support for :obj:`memoryview` is experimental and can break for multi-dimensional arrays. Casting to NumPy first is advised for now. -PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`. +PyTorch supports the buffer protocol, but it requires an explicit +:obj:`memoryview`. .. code-block:: python @@ -82,7 +92,8 @@ PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryvi b = torch.tensor(memoryview(a)) c = mx.array(b.numpy()) -Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``. +Conversion from PyTorch tensors back to arrays must be done via intermediate +NumPy arrays with ``numpy()``. JAX --- @@ -100,7 +111,8 @@ JAX fully supports the buffer protocol. TensorFlow ---------- -TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`. +TensorFlow supports the buffer protocol, but it requires an explicit +:obj:`memoryview`. .. 
code-block:: python diff --git a/mlx/array.cpp b/mlx/array.cpp index c2edb4940..b06de8fa3 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -25,7 +25,18 @@ array::array( std::move(shape), dtype, std::move(primitive), - std::move(inputs))) {} + std::move(inputs))) { + if (has_primitive() && this->primitive().stream().device == Device::gpu) { + for (auto& in : this->inputs()) { + if (in.dtype() == float64) { + throw std::invalid_argument("float64 is not supported on the GPU"); + } + } + if (this->dtype() == float64) { + throw std::invalid_argument("float64 is not supported on the GPU"); + } + } +} std::vector array::make_arrays( std::vector shapes, diff --git a/mlx/array.h b/mlx/array.h index 6ad0e578a..2c1b35cbc 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -594,6 +594,9 @@ void array::init(It src) { case float32: std::copy(src, src + size(), data()); break; + case float64: + std::copy(src, src + size(), data()); + break; case bfloat16: std::copy(src, src + size(), data()); break; diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 310a02a49..a63de9aa1 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -151,6 +151,9 @@ void NumberOfElements::eval(const std::vector& inputs, array& out) { case bfloat16: *out.data() = static_cast(numel); break; + case float64: + *out.data() = static_cast(numel); + break; case complex64: *out.data() = static_cast(numel); break; diff --git a/mlx/backend/cpu/arange.h b/mlx/backend/cpu/arange.h index 5c155dc09..fadeafde3 100644 --- a/mlx/backend/cpu/arange.h +++ b/mlx/backend/cpu/arange.h @@ -62,6 +62,9 @@ void arange( case float32: arange(start, start + step, out, out.size()); break; + case float64: + arange(start, start + step, out, out.size()); + break; case bfloat16: arange(start, start + step, out, out.size()); break; diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index 38eff29a1..1eccf1012 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -103,6 +103,9 @@ void ArgReduce::eval_cpu(const std::vector& inputs, array& out) { case bfloat16: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; + case float64: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; case complex64: arg_reduce_dispatch(in, out, reduce_type_, axis_); break; diff --git a/mlx/backend/cpu/binary.cpp b/mlx/backend/cpu/binary.cpp index 3fd0e63a5..acb988c4b 100644 --- a/mlx/backend/cpu/binary.cpp +++ b/mlx/backend/cpu/binary.cpp @@ -51,6 +51,9 @@ void comparison_op(const array& a, const array& b, array& out, Op op) { case float32: binary_op(a, b, out, op); break; + case float64: + binary_op(a, b, out, op); + break; case bfloat16: binary_op(a, b, out, op); break; @@ -114,6 +117,9 @@ void DivMod::eval_cpu( case float32: binary_op(a, b, outputs, float_op); break; + case float64: + binary_op(a, b, outputs, float_op); + break; case bfloat16: binary_op(a, b, outputs, float_op); break; @@ -150,6 +156,9 @@ void Equal::eval_cpu(const std::vector& inputs, array& out) { case float32: binary_op(a, b, out, detail::NaNEqual()); break; + case float64: + binary_op(a, b, out, detail::NaNEqual()); + break; case bfloat16: binary_op(a, b, out, detail::NaNEqual()); break; @@ -189,20 +198,22 @@ void LogAddExp::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); auto& a = inputs[0]; auto& b = inputs[1]; - if (out.dtype() == float32) { - binary_op(a, b, out, detail::LogAddExp()); - } else if (out.dtype() == float16) { - binary_op(a, b, out, 
detail::LogAddExp()); - } else if (out.dtype() == bfloat16) { - binary_op(a, b, out, detail::LogAddExp()); - } else if (issubdtype(out.dtype(), inexact)) { - std::ostringstream err; - err << "[logaddexp] Does not support " << out.dtype(); - throw std::invalid_argument(err.str()); - } else { - throw std::invalid_argument( - "[logaddexp] Cannot compute logaddexp for arrays with" - " non floating point type."); + switch (out.dtype()) { + case float16: + binary_op(a, b, out, detail::LogAddExp()); + break; + case float32: + binary_op(a, b, out, detail::LogAddExp()); + break; + case float64: + binary_op(a, b, out, detail::LogAddExp()); + break; + case bfloat16: + binary_op(a, b, out, detail::LogAddExp()); + break; + default: + throw std::runtime_error( + "[LogAddExp::eval_cpu] Only supports non-complex floating point types."); } } @@ -321,20 +332,22 @@ void ArcTan2::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); const auto& a = inputs[0]; const auto& b = inputs[1]; - if (out.dtype() == float32) { - binary_op(a, b, out, detail::ArcTan2()); - } else if (out.dtype() == float16) { - binary_op(a, b, out, detail::ArcTan2()); - } else if (out.dtype() == bfloat16) { - binary_op(a, b, out, detail::ArcTan2()); - } else if (issubdtype(out.dtype(), inexact)) { - std::ostringstream err; - err << "[arctan2] Does not support " << out.dtype(); - throw std::invalid_argument(err.str()); - } else { - throw std::invalid_argument( - "[arctan2] Cannot compute inverse tangent for arrays" - " with non floating point type."); + switch (out.dtype()) { + case float16: + binary_op(a, b, out, detail::ArcTan2()); + break; + case float32: + binary_op(a, b, out, detail::ArcTan2()); + break; + case float64: + binary_op(a, b, out, detail::ArcTan2()); + break; + case bfloat16: + binary_op(a, b, out, detail::ArcTan2()); + break; + default: + throw std::runtime_error( + "[ArcTan2::eval_cpu] Only supports non-complex floating point types."); } } diff --git a/mlx/backend/cpu/binary.h b/mlx/backend/cpu/binary.h index ab9fb2486..c85e6ca1a 100644 --- a/mlx/backend/cpu/binary.h +++ b/mlx/backend/cpu/binary.h @@ -358,6 +358,9 @@ void binary(const array& a, const array& b, array& out, Op op) { case float32: binary_op(a, b, out, op); break; + case float64: + binary_op(a, b, out, op); + break; case bfloat16: binary_op(a, b, out, op); break; diff --git a/mlx/backend/cpu/binary_two.h b/mlx/backend/cpu/binary_two.h index 6c106b904..4b6bf25cd 100644 --- a/mlx/backend/cpu/binary_two.h +++ b/mlx/backend/cpu/binary_two.h @@ -205,6 +205,9 @@ void binary( case float32: binary_op(a, b, outputs, op); break; + case float64: + binary_op(a, b, outputs, op); + break; case bfloat16: binary_op(a, b, outputs, op); break; diff --git a/mlx/backend/cpu/copy.cpp b/mlx/backend/cpu/copy.cpp index 66c27e745..6c14d0fa4 100644 --- a/mlx/backend/cpu/copy.cpp +++ b/mlx/backend/cpu/copy.cpp @@ -193,6 +193,9 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... 
args) { case float32: copy(src, dst, ctype, std::forward(args)...); break; + case float64: + copy(src, dst, ctype, std::forward(args)...); + break; case bfloat16: copy(src, dst, ctype, std::forward(args)...); break; @@ -242,6 +245,9 @@ inline void copy_inplace_dispatch( case float32: copy(src, dst, ctype, std::forward(args)...); break; + case float64: + copy(src, dst, ctype, std::forward(args)...); + break; case bfloat16: copy(src, dst, ctype, std::forward(args)...); break; diff --git a/mlx/backend/cpu/gemms/cblas.cpp b/mlx/backend/cpu/gemms/cblas.cpp index fef63b3e9..912098a1a 100644 --- a/mlx/backend/cpu/gemms/cblas.cpp +++ b/mlx/backend/cpu/gemms/cblas.cpp @@ -41,4 +41,39 @@ void matmul( } } +template <> +void matmul( + const array& a, + const array& b, + array& out, + bool a_transposed, + bool b_transposed, + size_t lda, + size_t ldb, + float alpha, + float beta) { + size_t M = a.shape(-2); + size_t N = b.shape(-1); + size_t K = a.shape(-1); + + for (int i = 0; i < (a.size() / (M * K)); ++i) { + cblas_dgemm( + CblasRowMajor, + a_transposed ? CblasTrans : CblasNoTrans, // transA + b_transposed ? CblasTrans : CblasNoTrans, // transB + M, + N, + K, + alpha, // alpha + a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()), + lda, + b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()), + ldb, + beta, // beta + out.data() + M * N * i, + out.shape(-1) // ldc + ); + } +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 4eb48b921..3a0b63f34 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -148,6 +148,9 @@ void dispatch_gather( case float32: gather(src, inds, out, axes, size); break; + case float64: + gather(src, inds, out, axes, size); + break; case bfloat16: gather(src, inds, out, axes, size); break; @@ -288,6 +291,9 @@ void dispatch_gather_axis( case float32: gather_axis(src, inds, out, axis); break; + case float64: + gather_axis(src, inds, out, axis); + break; case bfloat16: gather_axis(src, inds, out, axis); break; @@ -499,6 +505,9 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { case float32: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; + case float64: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; case bfloat16: dispatch_scatter(out, inds, updates, axes_, reduce_type_); break; @@ -661,6 +670,9 @@ void ScatterAxis::eval_cpu(const std::vector& inputs, array& out) { case float32: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; + case float64: + dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); + break; case bfloat16: dispatch_scatter_axis(out, idx, updates, axis_, reduce_type_); break; diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp index 05989c328..0712ea2d3 100644 --- a/mlx/backend/cpu/matmul.cpp +++ b/mlx/backend/cpu/matmul.cpp @@ -46,6 +46,9 @@ void matmul_general( } else if (out.dtype() == bfloat16) { matmul( a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); + } else if (out.dtype() == float64) { + matmul( + a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta); } else { throw std::runtime_error("[Matmul::eval_cpu] Invalid type."); } diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index 11f27ea06..1c93caaf1 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -42,6 +42,7 @@ instantiate_default_limit(int64_t); instantiate_float_limit(float16_t); instantiate_float_limit(bfloat16_t); instantiate_float_limit(float); 
+instantiate_float_limit(double); instantiate_float_limit(complex64_t); template <> @@ -59,6 +60,8 @@ const bfloat16_t Limits::min = const float16_t Limits::max = std::numeric_limits::infinity(); const float16_t Limits::min = -std::numeric_limits::infinity(); +const double Limits::max = std::numeric_limits::infinity(); +const double Limits::min = -std::numeric_limits::infinity(); const complex64_t Limits::max = std::numeric_limits::infinity(); const complex64_t Limits::min = @@ -460,6 +463,7 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { break; case uint64: case int64: + case float64: case complex64: reduce_dispatch_and_or(in, out, reduce_type_, axes_); break; @@ -495,6 +499,9 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { case float32: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; + case float64: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; case complex64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; @@ -537,6 +544,9 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { case float32: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; + case float64: + reduce_dispatch_min_max(in, out, reduce_type_, axes_); + break; case bfloat16: reduce_dispatch_min_max(in, out, reduce_type_, axes_); break; diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp index 0c231baab..28c929beb 100644 --- a/mlx/backend/cpu/scan.cpp +++ b/mlx/backend/cpu/scan.cpp @@ -299,6 +299,10 @@ void Scan::eval_cpu(const std::vector& inputs, array& out) { scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); break; + case float64: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; case bfloat16: scan_dispatch( reduce_type_, in, out, axis_, reverse_, inclusive_); diff --git a/mlx/backend/cpu/select.cpp b/mlx/backend/cpu/select.cpp index a08805893..1382a8ff6 100644 --- a/mlx/backend/cpu/select.cpp +++ b/mlx/backend/cpu/select.cpp @@ -51,6 +51,9 @@ void select_op( case float32: ternary_op(a, b, c, out, op); break; + case float64: + ternary_op(a, b, c, out, op); + break; case bfloat16: ternary_op(a, b, c, out, op); break; diff --git a/mlx/backend/cpu/softmax.cpp b/mlx/backend/cpu/softmax.cpp index 3c80d7f28..43b5a55e2 100644 --- a/mlx/backend/cpu/softmax.cpp +++ b/mlx/backend/cpu/softmax.cpp @@ -6,6 +6,7 @@ #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/simd/simd.h" #include "mlx/primitives.h" +#include "mlx/types/limits.h" namespace mlx::core { @@ -28,7 +29,7 @@ void softmax(const array& in, array& out) { for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) { // Find the maximum current_in_ptr = in_ptr; - Simd vmaximum(-std::numeric_limits::infinity()); + Simd vmaximum(-numeric_limits::infinity()); size_t s = M; while (s >= N) { Simd vals = load(current_in_ptr); @@ -163,6 +164,9 @@ void Softmax::eval_cpu(const std::vector& inputs, array& out) { softmax(in, out); } break; + case float64: + softmax(in, out); + break; case complex64: throw std::invalid_argument( "[Softmax] Not yet implemented for complex64"); diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp index 078b68ade..a6db5dcef 100644 --- a/mlx/backend/cpu/sort.cpp +++ b/mlx/backend/cpu/sort.cpp @@ -312,6 +312,8 @@ void ArgSort::eval_cpu(const std::vector& inputs, array& out) { return argsort(in, out, axis_); case float32: return argsort(in, out, axis_); + case float64: + return argsort(in, out, axis_); case float16: return argsort(in, out, axis_); case bfloat16: @@ -346,6 
+348,8 @@ void Sort::eval_cpu(const std::vector& inputs, array& out) { return sort(in, out, axis_); case float32: return sort(in, out, axis_); + case float64: + return sort(in, out, axis_); case float16: return sort(in, out, axis_); case bfloat16: @@ -380,6 +384,8 @@ void ArgPartition::eval_cpu(const std::vector& inputs, array& out) { return argpartition(in, out, axis_, kth_); case float32: return argpartition(in, out, axis_, kth_); + case float64: + return argpartition(in, out, axis_, kth_); case float16: return argpartition(in, out, axis_, kth_); case bfloat16: @@ -414,6 +420,8 @@ void Partition::eval_cpu(const std::vector& inputs, array& out) { return partition(in, out, axis_, kth_); case float32: return partition(in, out, axis_, kth_); + case float64: + return partition(in, out, axis_, kth_); case float16: return partition(in, out, axis_, kth_); case bfloat16: diff --git a/mlx/backend/cpu/unary.cpp b/mlx/backend/cpu/unary.cpp index c6431baec..0ac4daf27 100644 --- a/mlx/backend/cpu/unary.cpp +++ b/mlx/backend/cpu/unary.cpp @@ -34,6 +34,9 @@ void Abs::eval_cpu(const std::vector& inputs, array& out) { case float32: unary_op(in, out, op); break; + case float64: + unary_op(in, out, op); + break; case bfloat16: unary_op(in, out, op); break; @@ -120,6 +123,9 @@ void Erf::eval_cpu(const std::vector& inputs, array& out) { case float16: unary_op(in, out, detail::Erf()); break; + case float64: + unary_op(in, out, detail::Erf()); + break; case bfloat16: unary_op(in, out, detail::Erf()); break; @@ -140,6 +146,9 @@ void ErfInv::eval_cpu(const std::vector& inputs, array& out) { case float16: unary_op(in, out, detail::ErfInv()); break; + case float64: + unary_op(in, out, detail::ErfInv()); + break; case bfloat16: unary_op(in, out, detail::ErfInv()); break; diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h index 6dccaf615..edfe2a7b4 100644 --- a/mlx/backend/cpu/unary.h +++ b/mlx/backend/cpu/unary.h @@ -104,6 +104,9 @@ void unary(const array& a, array& out, Op op) { case float32: unary_op(a, out, op); break; + case float64: + unary_op(a, out, op); + break; case bfloat16: unary_op(a, out, op); break; @@ -125,6 +128,9 @@ void unary_fp(const array& a, array& out, Op op) { case float32: unary_op(a, out, op); break; + case float64: + unary_op(a, out, op); + break; case complex64: unary_op(a, out, op); break; diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 627f30478..9aefd3f44 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -151,8 +151,8 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { case bfloat16: arange_set_scalars(start_, start_ + step_, compute_encoder); break; - case complex64: - throw std::runtime_error("[Arange::eval_gpu] Does not support complex64"); + default: + throw std::runtime_error("[Arange::eval_gpu] Does not support type."); } compute_encoder.set_output_array(out, 2); diff --git a/mlx/backend/metal/utils.cpp b/mlx/backend/metal/utils.cpp index 2eaacba4c..329d250df 100644 --- a/mlx/backend/metal/utils.cpp +++ b/mlx/backend/metal/utils.cpp @@ -42,6 +42,9 @@ std::string type_to_name(const Dtype& t) { case float32: tname = "float32"; break; + case float64: + tname = "double"; + break; case bfloat16: tname = "bfloat16"; break; diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 41ef03d97..e532cd771 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -95,6 +95,7 @@ struct MPIWrapper { LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_); 
LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_); LOAD_SYMBOL(ompi_mpi_float, mpi_float_); + LOAD_SYMBOL(ompi_mpi_double, mpi_double_); LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_); } @@ -164,6 +165,8 @@ struct MPIWrapper { return mpi_float16_; case bfloat16: return mpi_bfloat16_; + case float64: + return mpi_double_; } } @@ -218,6 +221,7 @@ struct MPIWrapper { MPI_Datatype mpi_int64_; MPI_Datatype mpi_uint64_; MPI_Datatype mpi_float_; + MPI_Datatype mpi_double_; MPI_Datatype mpi_complex_; MPI_Datatype mpi_float16_; MPI_Datatype mpi_bfloat16_; diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index ad3f2e0a5..3f1586021 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -68,6 +68,10 @@ using T = float; \ __VA_ARGS__; \ } break; \ + case float64: { \ + using T = double; \ + __VA_ARGS__; \ + } break; \ case complex64: { \ using T = complex64_t; \ __VA_ARGS__; \ diff --git a/mlx/dtype.cpp b/mlx/dtype.cpp index 665512f04..5fd0415ba 100644 --- a/mlx/dtype.cpp +++ b/mlx/dtype.cpp @@ -8,7 +8,7 @@ namespace mlx::core { namespace { -constexpr int num_types = 13; +constexpr int num_types = 14; constexpr int num_cats = 8; constexpr Dtype::Kind type_kinds[num_types] = { @@ -23,6 +23,7 @@ constexpr Dtype::Kind type_kinds[num_types] = { Dtype::Kind::i, // int64, Dtype::Kind::f, // float16, Dtype::Kind::f, // float32, + Dtype::Kind::f, // float64, Dtype::Kind::V, // bfloat16, Dtype::Kind::c // complex64, }; @@ -31,20 +32,21 @@ constexpr Dtype::Kind type_kinds[num_types] = { // https://jax.readthedocs.io/en/latest/type_promotion.html // clang-format off constexpr Dtype type_rules[num_types][num_types] = { -// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64 - {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // bool - {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // uint8 - {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // uint16 - {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // uint32 - {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, bfloat16, complex64}, // uint64 - {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // int8 - {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // int16 - {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // int32 - {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // int64 - {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float32, complex64}, // float16 - {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, complex64}, // float32 - {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, bfloat16, complex64}, // bfloat16 - {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64 +// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 float64 bfloat16 complex64 + {bool_, uint8, uint16, uint32, uint64, int8, 
int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // bool
+   {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint8
+   {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint16
+   {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // uint32
+   {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, float64, bfloat16, complex64}, // uint64
+   {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int8
+   {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int16
+   {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // int32
+   {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64
+   {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16
+   {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32
+   {float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, complex64}, // float64
+   {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16
+   {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
 };
@@ -72,6 +74,7 @@ constexpr Dtype::Category type_to_category[num_types] = {
     Dtype::Category::signedinteger, // int64,
     Dtype::Category::floating, // float16,
     Dtype::Category::floating, // float32,
+    Dtype::Category::floating, // float64,
     Dtype::Category::floating, // bfloat16,
     Dtype::Category::complexfloating, // complex64,
 };
diff --git a/mlx/dtype.h b/mlx/dtype.h
index 11d61e378..e02b6ca35 100644
--- a/mlx/dtype.h
+++ b/mlx/dtype.h
@@ -23,6 +23,7 @@ struct Dtype {
     int64,
     float16,
     float32,
+    float64,
     bfloat16,
     complex64,
   };
@@ -78,6 +79,7 @@ inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
 inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
 inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
+inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)};
 inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
 inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
diff --git a/mlx/types/limits.h b/mlx/types/limits.h
index 6f2668a5f..7e0de15bc 100644
--- a/mlx/types/limits.h
+++ b/mlx/types/limits.h
@@ -12,6 +12,9 @@ struct numeric_limits;
 template <>
 struct numeric_limits : public std::numeric_limits {};
 
+template <>
+struct numeric_limits : public std::numeric_limits {};
+
 template <>
 struct numeric_limits {
  private:
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 2848af2a3..11a9488b3 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -54,6 +54,9 @@ inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) {
 inline void PrintFormatter::print(std::ostream& os, float val) {
   os << val;
 }
+inline void PrintFormatter::print(std::ostream& os, double val) {
+  os << 
val; +} inline void PrintFormatter::print(std::ostream& os, complex64_t val) { os << val; } @@ -234,6 +237,8 @@ std::ostream& operator<<(std::ostream& os, const Dtype& dtype) { return os << "float16"; case float32: return os << "float32"; + case float64: + return os << "float64"; case bfloat16: return os << "bfloat16"; case complex64: @@ -299,6 +304,9 @@ std::ostream& operator<<(std::ostream& os, array a) { case float32: print_array(os, a); break; + case float64: + print_array(os, a); + break; case complex64: print_array(os, a); break; @@ -337,7 +345,7 @@ int get_var(const char* name, int default_value) { } // namespace env template -void set_finfo_limits(float& min, float& max) { +void set_finfo_limits(double& min, double& max) { min = numeric_limits::lowest(); max = numeric_limits::max(); } @@ -354,6 +362,8 @@ finfo::finfo(Dtype dtype) : dtype(dtype) { set_finfo_limits(min, max); } else if (dtype == bfloat16) { set_finfo_limits(min, max); + } else if (dtype == float64) { + set_finfo_limits(min, max); } else if (dtype == complex64) { this->dtype = float32; set_finfo_limits(min, max); diff --git a/mlx/utils.h b/mlx/utils.h index 94b0974d7..df6bc0ec6 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -47,6 +47,7 @@ struct PrintFormatter { inline void print(std::ostream& os, float16_t val); inline void print(std::ostream& os, bfloat16_t val); inline void print(std::ostream& os, float val); + inline void print(std::ostream& os, double val); inline void print(std::ostream& os, complex64_t val); bool capitalize_bool{false}; @@ -61,8 +62,8 @@ void abort_with_exception(const std::exception& error); struct finfo { explicit finfo(Dtype dtype); Dtype dtype; - float min; - float max; + double min; + double max; }; /** The type from promoting the arrays' types with one another. */ diff --git a/python/src/array.cpp b/python/src/array.cpp index 58193b2ea..8f0040b1e 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -128,6 +128,7 @@ void init_array(nb::module_& m) { m.attr("int64") = nb::cast(mx::int64); m.attr("float16") = nb::cast(mx::float16); m.attr("float32") = nb::cast(mx::float32); + m.attr("float64") = nb::cast(mx::float64); m.attr("bfloat16") = nb::cast(mx::bfloat16); m.attr("complex64") = nb::cast(mx::complex64); nb::enum_( @@ -163,6 +164,7 @@ void init_array(nb::module_& m) { * :ref:`float16 ` * :ref:`bfloat16 ` * :ref:`float32 ` + * :ref:`float64 ` * :attr:`~mlx.core.complexfloating` diff --git a/python/tests/test_double.py b/python/tests/test_double.py new file mode 100644 index 000000000..00d8c9639 --- /dev/null +++ b/python/tests/test_double.py @@ -0,0 +1,178 @@ +# Copyright © 2024 Apple Inc. 
+
+import math
+import os
+import unittest
+
+import mlx.core as mx
+import mlx_tests
+import numpy as np
+
+
+class TestDouble(mlx_tests.MLXTestCase):
+    def test_unary_ops(self):
+        shape = (3, 3)
+        x = mx.random.normal(shape=shape)
+
+        if mx.default_device() == mx.gpu:
+            with self.assertRaises(ValueError):
+                x.astype(mx.float64)
+
+        x_double = x.astype(mx.float64, stream=mx.cpu)
+
+        ops = [
+            mx.abs,
+            mx.arccos,
+            mx.arccosh,
+            mx.arcsin,
+            mx.arcsinh,
+            mx.arctan,
+            mx.arctanh,
+            mx.ceil,
+            mx.erf,
+            mx.erfinv,
+            mx.exp,
+            mx.expm1,
+            mx.floor,
+            mx.log,
+            mx.logical_not,
+            mx.negative,
+            mx.round,
+            mx.sin,
+            mx.sinh,
+            mx.sqrt,
+            mx.rsqrt,
+            mx.tan,
+            mx.tanh,
+        ]
+        for op in ops:
+            if mx.default_device() == mx.gpu:
+                with self.assertRaises(ValueError):
+                    op(x_double)
+                continue
+            y = op(x)
+            y_double = op(x_double)
+            self.assertTrue(
+                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
+            )
+
+    def test_binary_ops(self):
+        shape = (3, 3)
+        a = mx.random.normal(shape=shape)
+        b = mx.random.normal(shape=shape)
+
+        a_double = a.astype(mx.float64, stream=mx.cpu)
+        b_double = b.astype(mx.float64, stream=mx.cpu)
+
+        ops = [
+            mx.add,
+            mx.arctan2,
+            mx.divide,
+            mx.multiply,
+            mx.subtract,
+            mx.logical_and,
+            mx.logical_or,
+            mx.remainder,
+            mx.maximum,
+            mx.minimum,
+            mx.power,
+            mx.equal,
+            mx.greater,
+            mx.greater_equal,
+            mx.less,
+            mx.less_equal,
+            mx.not_equal,
+            mx.logaddexp,
+        ]
+        for op in ops:
+            if mx.default_device() == mx.gpu:
+                with self.assertRaises(ValueError):
+                    op(a_double, b_double)
+                continue
+            y = op(a, b)
+            y_double = op(a_double, b_double)
+            self.assertTrue(
+                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
+            )
+
+    def test_where(self):
+        shape = (3, 3)
+        cond = mx.random.uniform(shape=shape) > 0.5
+        a = mx.random.normal(shape=shape)
+        b = mx.random.normal(shape=shape)
+
+        a_double = a.astype(mx.float64, stream=mx.cpu)
+        b_double = b.astype(mx.float64, stream=mx.cpu)
+
+        if mx.default_device() == mx.gpu:
+            with self.assertRaises(ValueError):
+                mx.where(cond, a_double, b_double)
+            return
+        y = mx.where(cond, a, b)
+        y_double = mx.where(cond, a_double, b_double)
+        self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))
+
+    def test_reductions(self):
+        shape = (32, 32)
+        a = mx.random.normal(shape=shape)
+        a_double = a.astype(mx.float64, stream=mx.cpu)
+
+        axes = [0, 1, (0, 1)]
+        ops = [mx.sum, mx.prod, mx.min, mx.max, mx.any, mx.all]
+
+        for op in ops:
+            for ax in axes:
+                if mx.default_device() == mx.gpu:
+                    with self.assertRaises(ValueError):
+                        op(a_double, axis=ax)
+                    continue
+                y = op(a, axis=ax)
+                y_double = op(a_double, axis=ax)
+                self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))
+
+    def test_get_and_set_item(self):
+        shape = (3, 3)
+        a = mx.random.normal(shape=shape)
+        b = mx.random.normal(shape=(2,))
+        a_double = a.astype(mx.float64, stream=mx.cpu)
+        b_double = b.astype(mx.float64, stream=mx.cpu)
+        idx_i = mx.array([0, 2])
+        idx_j = mx.array([0, 2])
+
+        if mx.default_device() == mx.gpu:
+            with self.assertRaises(ValueError):
+                a_double[idx_i, idx_j]
+        else:
+            y = a[idx_i, idx_j]
+            y_double = a_double[idx_i, idx_j]
+            self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))
+
+        if mx.default_device() == mx.gpu:
+            with self.assertRaises(ValueError):
+                a_double[idx_i, idx_j] = b_double
+        else:
+            a[idx_i, idx_j] = b
+            a_double[idx_i, idx_j] = b_double
+            self.assertTrue(mx.allclose(a, a_double.astype(mx.float32, mx.cpu)))
+
+    def test_gemm(self):
+        shape = (8, 8)
+        a = mx.random.normal(shape=shape)
+        b = mx.random.normal(shape=shape)
+
+        a_double = a.astype(mx.float64, stream=mx.cpu)
+        b_double = b.astype(mx.float64, stream=mx.cpu)
+
+        if mx.default_device() == mx.gpu:
+            with self.assertRaises(ValueError):
+                a_double @ b_double
+            return
+        y = a @ b
+        y_double = a_double @ b_double
+        self.assertTrue(
+            mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
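
For reference, below is a minimal usage sketch of the workflow the new tests exercise, assuming an MLX build that includes this patch. It is not part of the patch itself: the variable names and printed values are illustrative, and the ``mx.metal.is_available()`` guard is only there so the GPU branch is skipped on CPU-only builds.

# Illustrative only: float64 math runs on the CPU stream; GPU streams reject it.
import mlx.core as mx

a = mx.random.normal(shape=(4, 4))  # float32

# Cast on the CPU stream, as test_double.py does.
a64 = a.astype(mx.float64, stream=mx.cpu)

# CPU primitives now have double-precision paths (matmul goes through cblas_dgemm).
y64 = mx.matmul(a64, a64, stream=mx.cpu)
print(y64.dtype)  # float64

# Mixed float32/float64 inputs promote to float64 per the type_rules table.
c = mx.add(a, a64, stream=mx.cpu)
print(c.dtype)  # float64

# On a GPU stream the same operation raises, since there are no fp64 Metal kernels.
if mx.metal.is_available():
    try:
        mx.add(a, a64, stream=mx.gpu)
    except ValueError as e:
        print(e)  # "float64 is not supported on the GPU"

Because the check lives in the array constructor, the error surfaces as soon as a float64 operation is recorded on a GPU stream, not later at evaluation time.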