Fp64 on the CPU (#1843)

* add fp64 data type

* clean build

* update docs

* fix bug
Awni Hannun 2025-02-07 15:52:22 -08:00 committed by GitHub
parent 1a1b2108ec
commit 1c0c118f7c
32 changed files with 438 additions and 65 deletions

View File

@ -51,11 +51,20 @@ The default floating point type is ``float32`` and the default integer type is
* - ``float32``
- 4
- 32-bit float
* - ``float64``
- 8
- 64-bit double
* - ``complex64``
- 8
- 64-bit complex float
.. note::
Arrays with type ``float64`` only work with CPU operations. Using
``float64`` arrays on the GPU will result in an exception.
Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
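The short sketch below exercises both points from the note and the hierarchy text; it assumes a machine whose default device is the GPU, and the ``ValueError`` matches what this commit's Python tests expect:

.. code-block:: python

    import mlx.core as mx

    # float64 arrays are created and evaluated on the CPU stream.
    a = mx.array([1.0, 2.0, 3.0]).astype(mx.float64, stream=mx.cpu)
    print(mx.sum(a, stream=mx.cpu).item())  # 6.0

    # float64 sits in the floating-point part of the dtype hierarchy.
    assert mx.issubdtype(mx.float64, mx.floating)
    assert mx.issubdtype(mx.float64, mx.inexact)

    # Running a float64 operation on the GPU raises an exception.
    try:
        mx.sum(a, stream=mx.gpu)
    except ValueError as e:
        print(e)  # float64 is not supported on the GPU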

View File

@ -21,11 +21,13 @@ Let's convert an array to NumPy and back.
.. note::
Since NumPy does not support ``bfloat16`` arrays, you will need to convert
to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
buffer format string does not match the dtype V item size 0.``
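For instance, this small sketch (using a random ``bfloat16`` array) shows the conversion that avoids the buffer error:

.. code-block:: python

    import mlx.core as mx
    import numpy as np

    a = mx.random.normal((4,)).astype(mx.bfloat16)
    # np.array(a) would fail with the PEP 3118 buffer error quoted above.
    b = np.array(a.astype(mx.float32))  # convert to float32 first
    print(b.dtype)  # float32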
By default, NumPy copies data to a new array. This can be prevented by creating
an array view:
.. code-block:: python
@ -35,10 +37,16 @@ By default, NumPy copies data to a new array. This can be prevented by creating
a_view[0] = 1
print(a[0].item()) # 1
.. note::
NumPy arrays with type ``float64`` will by default be converted to MLX arrays
with type ``float32``.
A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.
While this is quite powerful to prevent copying arrays, it should be noted that
external changes to the memory of arrays cannot be reflected in gradients.
Let's demonstrate this in an example:
@ -56,11 +64,12 @@ Let's demonstrate this in an example:
The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, representing the gradient of the sum operation
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated. It's important to note that a similar issue arises
during array conversion and copying. For instance, a function defined as
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
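The sketch below reconstructs the memory-view example this passage describes; the printed values follow the text (the external squaring shows up in the value, while the reported gradient is that of the sum alone):

.. code-block:: python

    import mlx.core as mx
    import numpy as np

    def f(x):
        x_view = np.array(x, copy=False)  # a view into the MLX array's memory
        x_view[:] = x_view * x_view       # squaring happens outside of MLX
        return x.sum()

    x = mx.array([3.0])
    y, dfdx = mx.value_and_grad(f)(x)
    print(y.item())     # 9.0 -- the in-place change is visible in the value
    print(dfdx.item())  # 1.0 -- but the gradient only reflects the sum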
PyTorch
@ -71,7 +80,8 @@ PyTorch
PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.
PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
.. code-block:: python
@ -82,7 +92,8 @@ PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryvi
b = torch.tensor(memoryview(a))
c = mx.array(b.numpy())
Conversion from PyTorch tensors back to arrays must be done via intermediate
NumPy arrays with ``numpy()``.
JAX
---
@ -100,7 +111,8 @@ JAX fully supports the buffer protocol.
TensorFlow
----------
TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
.. code-block:: python

View File

@ -25,7 +25,18 @@ array::array(
std::move(shape),
dtype,
std::move(primitive),
std::move(inputs))) {
if (has_primitive() && this->primitive().stream().device == Device::gpu) {
for (auto& in : this->inputs()) {
if (in.dtype() == float64) {
throw std::invalid_argument("float64 is not supported on the GPU");
}
}
if (this->dtype() == float64) {
throw std::invalid_argument("float64 is not supported on the GPU");
}
}
}
std::vector<array> array::make_arrays(
std::vector<Shape> shapes,

View File

@ -594,6 +594,9 @@ void array::init(It src) {
case float32:
std::copy(src, src + size(), data<float>());
break;
case float64:
std::copy(src, src + size(), data<double>());
break;
case bfloat16:
std::copy(src, src + size(), data<bfloat16_t>());
break;

View File

@ -151,6 +151,9 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case float64:
*out.data<double>() = static_cast<double>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;

View File

@ -62,6 +62,9 @@ void arange(
case float32:
arange<float>(start, start + step, out, out.size());
break;
case float64:
arange<double>(start, start + step, out, out.size());
break;
case bfloat16:
arange<bfloat16_t>(start, start + step, out, out.size());
break;

View File

@ -103,6 +103,9 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case bfloat16:
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
break;
case float64:
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
break;
case complex64:
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
break;

View File

@ -51,6 +51,9 @@ void comparison_op(const array& a, const array& b, array& out, Op op) {
case float32:
binary_op<float, bool>(a, b, out, op);
break;
case float64:
binary_op<double, bool>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, op);
break;
@ -114,6 +117,9 @@ void DivMod::eval_cpu(
case float32:
binary_op<float>(a, b, outputs, float_op);
break;
case float64:
binary_op<double>(a, b, outputs, float_op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, float_op);
break;
@ -150,6 +156,9 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
binary_op<float, bool>(a, b, out, detail::NaNEqual());
break;
case float64:
binary_op<double, bool>(a, b, out, detail::NaNEqual());
break;
case bfloat16:
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
break;
@ -189,20 +198,22 @@ void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::LogAddExp());
break;
case float32:
binary_op<float>(a, b, out, detail::LogAddExp());
break;
case float64:
binary_op<double>(a, b, out, detail::LogAddExp());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
break;
default:
throw std::runtime_error(
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
}
}
@ -321,20 +332,22 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
switch (out.dtype()) {
case float16:
binary_op<float16_t>(a, b, out, detail::ArcTan2());
break;
case float32:
binary_op<float>(a, b, out, detail::ArcTan2());
break;
case float64:
binary_op<double>(a, b, out, detail::ArcTan2());
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
break;
default:
throw std::runtime_error(
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
}
}

View File

@ -358,6 +358,9 @@ void binary(const array& a, const array& b, array& out, Op op) {
case float32:
binary_op<float>(a, b, out, op);
break;
case float64:
binary_op<double>(a, b, out, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, out, op);
break;

View File

@ -205,6 +205,9 @@ void binary(
case float32:
binary_op<float>(a, b, outputs, op);
break;
case float64:
binary_op<double>(a, b, outputs, op);
break;
case bfloat16:
binary_op<bfloat16_t>(a, b, outputs, op);
break;

View File

@ -193,6 +193,9 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
case float32:
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float64:
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;
@ -242,6 +245,9 @@ inline void copy_inplace_dispatch(
case float32:
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
break;
case float64:
copy<double>(src, dst, ctype, std::forward<Args>(args)...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
break;

View File

@ -41,4 +41,39 @@ void matmul<float>(
}
}
template <>
void matmul<double>(
const array& a,
const array& b,
array& out,
bool a_transposed,
bool b_transposed,
size_t lda,
size_t ldb,
float alpha,
float beta) {
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
for (int i = 0; i < (a.size() / (M * K)); ++i) {
cblas_dgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
alpha, // alpha
a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
lda,
b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
ldb,
beta, // beta
out.data<double>() + M * N * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core

View File

@ -148,6 +148,9 @@ void dispatch_gather(
case float32:
gather<float, IdxT>(src, inds, out, axes, size);
break;
case float64:
gather<double, IdxT>(src, inds, out, axes, size);
break;
case bfloat16:
gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
break;
@ -288,6 +291,9 @@ void dispatch_gather_axis(
case float32:
gather_axis<float, IdxT>(src, inds, out, axis);
break;
case float64:
gather_axis<double, IdxT>(src, inds, out, axis);
break;
case bfloat16:
gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
break;
@ -499,6 +505,9 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
break;
case float64:
dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
break;
case bfloat16:
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
break;
@ -661,6 +670,9 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
break;
case float64:
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
break;
case bfloat16:
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
break;

View File

@ -46,6 +46,9 @@ void matmul_general(
} else if (out.dtype() == bfloat16) {
matmul<bfloat16_t>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else if (out.dtype() == float64) {
matmul<double>(
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
} else {
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
}

View File

@ -42,6 +42,7 @@ instantiate_default_limit(int64_t);
instantiate_float_limit(float16_t);
instantiate_float_limit(bfloat16_t);
instantiate_float_limit(float);
instantiate_float_limit(double);
instantiate_float_limit(complex64_t);
template <>
@ -59,6 +60,8 @@ const bfloat16_t Limits<bfloat16_t>::min =
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::min =
-std::numeric_limits<float>::infinity();
const double Limits<double>::max = std::numeric_limits<double>::infinity();
const double Limits<double>::min = -std::numeric_limits<double>::infinity();
const complex64_t Limits<complex64_t>::max =
std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::min =
@ -460,6 +463,7 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
break;
case uint64:
case int64:
case float64:
case complex64:
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
break;
@ -495,6 +499,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
break;
case complex64:
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
break;
@ -537,6 +544,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
break;
case float64:
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
break;
case bfloat16:
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
break;

View File

@ -299,6 +299,10 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
scan_dispatch<float, float>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case float64:
scan_dispatch<double, double>(
reduce_type_, in, out, axis_, reverse_, inclusive_);
break;
case bfloat16:
scan_dispatch<bfloat16_t, bfloat16_t>(
reduce_type_, in, out, axis_, reverse_, inclusive_);

View File

@ -51,6 +51,9 @@ void select_op(
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case float64:
ternary_op<bool, double, double, double>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;

View File

@ -6,6 +6,7 @@
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"
namespace mlx::core {
@ -28,7 +29,7 @@ void softmax(const array& in, array& out) {
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
// Find the maximum
current_in_ptr = in_ptr;
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
size_t s = M;
while (s >= N) {
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
@ -163,6 +164,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
softmax<bfloat16_t, bfloat16_t>(in, out);
}
break;
case float64:
softmax<double, double>(in, out);
break;
case complex64:
throw std::invalid_argument(
"[Softmax] Not yet implemented for complex64");

View File

@ -312,6 +312,8 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
return argsort<int64_t>(in, out, axis_);
case float32:
return argsort<float>(in, out, axis_);
case float64:
return argsort<double>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
@ -346,6 +348,8 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
return sort<int64_t>(in, out, axis_);
case float32:
return sort<float>(in, out, axis_);
case float64:
return sort<double>(in, out, axis_);
case float16:
return sort<float16_t>(in, out, axis_);
case bfloat16:
@ -380,6 +384,8 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
return argpartition<int64_t>(in, out, axis_, kth_);
case float32:
return argpartition<float>(in, out, axis_, kth_);
case float64:
return argpartition<double>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
@ -414,6 +420,8 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
return partition<int64_t>(in, out, axis_, kth_);
case float32:
return partition<float>(in, out, axis_, kth_);
case float64:
return partition<double>(in, out, axis_, kth_);
case float16:
return partition<float16_t>(in, out, axis_, kth_);
case bfloat16:

View File

@ -34,6 +34,9 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
case float32:
unary_op<float>(in, out, op);
break;
case float64:
unary_op<double>(in, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, op);
break;
@ -120,6 +123,9 @@ void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
case float16:
unary_op<float16_t>(in, out, detail::Erf());
break;
case float64:
unary_op<double>(in, out, detail::Erf());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::Erf());
break;
@ -140,6 +146,9 @@ void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
case float16:
unary_op<float16_t>(in, out, detail::ErfInv());
break;
case float64:
unary_op<double>(in, out, detail::ErfInv());
break;
case bfloat16:
unary_op<bfloat16_t>(in, out, detail::ErfInv());
break;

View File

@ -104,6 +104,9 @@ void unary(const array& a, array& out, Op op) {
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case bfloat16:
unary_op<bfloat16_t>(a, out, op);
break;
@ -125,6 +128,9 @@ void unary_fp(const array& a, array& out, Op op) {
case float32:
unary_op<float>(a, out, op);
break;
case float64:
unary_op<double>(a, out, op);
break;
case complex64:
unary_op<complex64_t>(a, out, op);
break;

View File

@ -151,8 +151,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
case bfloat16:
arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);
break;
default:
throw std::runtime_error("[Arange::eval_gpu] Does not support type.");
}
compute_encoder.set_output_array(out, 2);

View File

@ -42,6 +42,9 @@ std::string type_to_name(const Dtype& t) {
case float32:
tname = "float32";
break;
case float64:
tname = "double";
break;
case bfloat16:
tname = "bfloat16";
break;

View File

@ -95,6 +95,7 @@ struct MPIWrapper {
LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_);
LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_);
LOAD_SYMBOL(ompi_mpi_float, mpi_float_);
LOAD_SYMBOL(ompi_mpi_double, mpi_double_);
LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_);
}
@ -164,6 +165,8 @@ struct MPIWrapper {
return mpi_float16_;
case bfloat16:
return mpi_bfloat16_;
case float64:
return mpi_double_;
}
}
@ -218,6 +221,7 @@ struct MPIWrapper {
MPI_Datatype mpi_int64_;
MPI_Datatype mpi_uint64_;
MPI_Datatype mpi_float_;
MPI_Datatype mpi_double_;
MPI_Datatype mpi_complex_;
MPI_Datatype mpi_float16_;
MPI_Datatype mpi_bfloat16_;

View File

@ -68,6 +68,10 @@
using T = float; \
__VA_ARGS__; \
} break; \
case float64: { \
using T = double; \
__VA_ARGS__; \
} break; \
case complex64: { \
using T = complex64_t; \
__VA_ARGS__; \

View File

@ -8,7 +8,7 @@ namespace mlx::core {
namespace {
constexpr int num_types = 14;
constexpr int num_cats = 8;
constexpr Dtype::Kind type_kinds[num_types] = {
@ -23,6 +23,7 @@ constexpr Dtype::Kind type_kinds[num_types] = {
Dtype::Kind::i, // int64,
Dtype::Kind::f, // float16,
Dtype::Kind::f, // float32,
Dtype::Kind::f, // float64,
Dtype::Kind::V, // bfloat16,
Dtype::Kind::c // complex64,
};
@ -31,20 +32,21 @@ constexpr Dtype::Kind type_kinds[num_types] = {
// https://jax.readthedocs.io/en/latest/type_promotion.html
// clang-format off
constexpr Dtype type_rules[num_types][num_types] = {
// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 float64 bfloat16 complex64
{bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // bool
{uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint8
{uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint16
{uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // uint32
{uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, float64, bfloat16, complex64}, // uint64
{int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int8
{int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int16
{int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // int32
{int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64
{float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16
{float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32
{float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, complex64}, // float64
{bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16
{complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
};
@ -72,6 +74,7 @@ constexpr Dtype::Category type_to_category[num_types] = {
Dtype::Category::signedinteger, // int64,
Dtype::Category::floating, // float16,
Dtype::Category::floating, // float32,
Dtype::Category::floating, // float64,
Dtype::Category::floating, // bfloat16,
Dtype::Category::complexfloating, // complex64,
};

View File

@ -23,6 +23,7 @@ struct Dtype {
int64,
float16,
float32,
float64,
bfloat16,
complex64,
};
@ -78,6 +79,7 @@ inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)};
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};

View File

@ -12,6 +12,9 @@ struct numeric_limits;
template <>
struct numeric_limits<float> : public std::numeric_limits<float> {};
template <>
struct numeric_limits<double> : public std::numeric_limits<double> {};
template <>
struct numeric_limits<float16_t> {
private:

View File

@ -54,6 +54,9 @@ inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) {
inline void PrintFormatter::print(std::ostream& os, float val) {
os << val;
}
inline void PrintFormatter::print(std::ostream& os, double val) {
os << val;
}
inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
os << val;
}
@ -234,6 +237,8 @@ std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
return os << "float16";
case float32:
return os << "float32";
case float64:
return os << "float64";
case bfloat16:
return os << "bfloat16";
case complex64:
@ -299,6 +304,9 @@ std::ostream& operator<<(std::ostream& os, array a) {
case float32:
print_array<float>(os, a);
break;
case float64:
print_array<double>(os, a);
break;
case complex64:
print_array<complex64_t>(os, a);
break;
@ -337,7 +345,7 @@ int get_var(const char* name, int default_value) {
} // namespace env
template <typename T>
void set_finfo_limits(double& min, double& max) {
min = numeric_limits<T>::lowest();
max = numeric_limits<T>::max();
}
@ -354,6 +362,8 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
set_finfo_limits<float16_t>(min, max);
} else if (dtype == bfloat16) {
set_finfo_limits<bfloat16_t>(min, max);
} else if (dtype == float64) {
set_finfo_limits<double>(min, max);
} else if (dtype == complex64) {
this->dtype = float32;
set_finfo_limits<float>(min, max);

View File

@ -47,6 +47,7 @@ struct PrintFormatter {
inline void print(std::ostream& os, float16_t val);
inline void print(std::ostream& os, bfloat16_t val);
inline void print(std::ostream& os, float val);
inline void print(std::ostream& os, double val);
inline void print(std::ostream& os, complex64_t val);
bool capitalize_bool{false};
@ -61,8 +62,8 @@ void abort_with_exception(const std::exception& error);
struct finfo {
explicit finfo(Dtype dtype);
Dtype dtype;
double min;
double max;
};
/** The type from promoting the arrays' types with one another. */

View File

@ -128,6 +128,7 @@ void init_array(nb::module_& m) {
m.attr("int64") = nb::cast(mx::int64);
m.attr("float16") = nb::cast(mx::float16);
m.attr("float32") = nb::cast(mx::float32);
m.attr("float64") = nb::cast(mx::float64);
m.attr("bfloat16") = nb::cast(mx::bfloat16);
m.attr("complex64") = nb::cast(mx::complex64);
nb::enum_<mx::Dtype::Category>(
@ -163,6 +164,7 @@ void init_array(nb::module_& m) {
* :ref:`float16 <data_types>`
* :ref:`bfloat16 <data_types>`
* :ref:`float32 <data_types>`
* :ref:`float64 <data_types>`
* :attr:`~mlx.core.complexfloating`

python/tests/test_double.py (new file, 178 lines)
View File

@ -0,0 +1,178 @@
# Copyright © 2024 Apple Inc.
import math
import os
import unittest
import mlx.core as mx
import mlx_tests
import numpy as np
class TestDouble(mlx_tests.MLXTestCase):
def test_unary_ops(self):
shape = (3, 3)
x = mx.random.normal(shape=shape)
if mx.default_device() == mx.gpu:
with self.assertRaises(ValueError):
x.astype(mx.float64)
x_double = x.astype(mx.float64, stream=mx.cpu)
ops = [
mx.abs,
mx.arccos,
mx.arccosh,
mx.arcsin,
mx.arcsinh,
mx.arctan,
mx.arctanh,
mx.ceil,
mx.erf,
mx.erfinv,
mx.exp,
mx.expm1,
mx.floor,
mx.log,
mx.logical_not,
mx.negative,
mx.round,
mx.sin,
mx.sinh,
mx.sqrt,
mx.rsqrt,
mx.tan,
mx.tanh,
]
for op in ops:
if mx.default_device() == mx.gpu:
with self.assertRaises(ValueError):
op(x_double)
continue
y = op(x)
y_double = op(x_double)
self.assertTrue(
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
)
def test_binary_ops(self):
shape = (3, 3)
a = mx.random.normal(shape=shape)
b = mx.random.normal(shape=shape)
a_double = a.astype(mx.float64, stream=mx.cpu)
b_double = b.astype(mx.float64, stream=mx.cpu)
ops = [
mx.add,
mx.arctan2,
mx.divide,
mx.multiply,
mx.subtract,
mx.logical_and,
mx.logical_or,
mx.remainder,
mx.maximum,
mx.minimum,
mx.power,
mx.equal,
mx.greater,
mx.greater_equal,
mx.less,
mx.less_equal,
mx.not_equal,
mx.logaddexp,
]
for op in ops:
if mx.default_device() == mx.gpu:
with self.assertRaises(ValueError):
op(a_double, b_double)
continue
y = op(a, b)
y_double = op(a_double, b_double)
self.assertTrue(
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
)
def test_where(self):
shape = (3, 3)
cond = mx.random.uniform(shape=shape) > 0.5
a = mx.random.normal(shape=shape)
b = mx.random.normal(shape=shape)
a_double = a.astype(mx.float64, stream=mx.cpu)
b_double = b.astype(mx.float64, stream=mx.cpu)
if mx.default_device() == mx.gpu:
with self.assertRaises(ValueError):
mx.where(cond, a_double, b_double)
return
y = mx.where(cond, a, b)
y_double = mx.where(cond, a_double, b_double)
self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))
def test_reductions(self):
shape = (32, 32)
a = mx.random.normal(shape=shape)
a_double = a.astype(mx.float64, stream=mx.cpu)
axes = [0, 1, (0, 1)]
ops = [mx.sum, mx.prod, mx.min, mx.max, mx.any, mx.all]
for op in ops:
for ax in axes:
if mx.default_device() == mx.gpu:
with self.assertRaises(ValueError):
op(a_double, axis=ax)
continue
y = op(a)
y_double = op(a_double)
self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))
def test_get_and_set_item(self):
shape = (3, 3)
a = mx.random.normal(shape=shape)
b = mx.random.normal(shape=(2,))
a_double = a.astype(mx.float64, stream=mx.cpu)
b_double = b.astype(mx.float64, stream=mx.cpu)
idx_i = mx.array([0, 2])
idx_j = mx.array([0, 2])
if mx.default_device() == mx.gpu:
with self.assertRaises(ValueError):
a_double[idx_i, idx_j]
else:
y = a[idx_i, idx_j]
y_double = a_double[idx_i, idx_j]
self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))
if mx.default_device() == mx.gpu:
with self.assertRaises(ValueError):
a_double[idx_i, idx_j] = b_double
else:
a[idx_i, idx_j] = b
a_double[idx_i, idx_j] = b_double
self.assertTrue(mx.allclose(a, a_double.astype(mx.float32, mx.cpu)))
def test_gemm(self):
shape = (8, 8)
a = mx.random.normal(shape=shape)
b = mx.random.normal(shape=shape)
a_double = a.astype(mx.float64, stream=mx.cpu)
b_double = b.astype(mx.float64, stream=mx.cpu)
if mx.default_device() == mx.gpu:
with self.assertRaises(ValueError):
a_double @ b_double
return
y = a @ b
y_double = a_double @ b_double
self.assertTrue(
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
)
if __name__ == "__main__":
unittest.main()