Mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 09:21:16 +08:00
Fp64 on the CPU (#1843)
* add fp64 data type
* clean build
* update docs
* fix bug
This commit is contained in:
parent 1a1b2108ec
commit 1c0c118f7c
@@ -51,11 +51,20 @@ The default floating point type is ``float32`` and the default integer type is
   * - ``float32``
     - 4
     - 32-bit float
   * - ``float64``
     - 8
     - 64-bit double
   * - ``complex64``
     - 8
     - 64-bit complex float

.. note::

   Arrays with type ``float64`` only work with CPU operations. Using
   ``float64`` arrays on the GPU will result in an exception.

Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
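For reference, a minimal sketch of the new behavior from Python (illustrative, not part of the diff; the error case assumes a GPU default device):

.. code-block:: python

   import mlx.core as mx

   # float64 arrays are CPU-only: cast and compute on the CPU stream.
   x = mx.array([1.0, 2.0, 3.0]).astype(mx.float64, stream=mx.cpu)
   print(x.dtype)                              # mlx.core.float64
   print(mx.issubdtype(x.dtype, mx.floating))  # True
   print(mx.sum(x, stream=mx.cpu))             # array(6, dtype=float64)

   # Dispatching a float64 operation to the GPU raises an exception.
   try:
       mx.sum(x, stream=mx.gpu)
   except ValueError as e:
       print("float64 is not supported on the GPU:", e)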
@@ -21,11 +21,13 @@ Let's convert an array to NumPy and back.

.. note::

-  Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
-  ``np.array(a.astype(mx.float32))``.
-  Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
+  Since NumPy does not support ``bfloat16`` arrays, you will need to convert
+  to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
+  Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
+  buffer format string does not match the dtype V item size 0.``
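A short sketch of the recommended conversion (illustrative):

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.zeros((4,), dtype=mx.bfloat16)

   # np.array(a) would fail since NumPy has no bfloat16 dtype;
   # cast to a supported float type first.
   b = np.array(a.astype(mx.float32))
   print(b.dtype)  # float32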
-By default, NumPy copies data to a new array. This can be prevented by creating an array view:
+By default, NumPy copies data to a new array. This can be prevented by creating
+an array view:

.. code-block:: python
@@ -35,10 +37,16 @@ By default, NumPy copies data to a new array. This can be prevented by creating

  a_view[0] = 1
  print(a[0].item()) # 1

-A NumPy array view is a normal NumPy array, except that it does not own its memory.
-This means writing to the view is reflected in the original array.
-While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.
+.. note::
+
+  NumPy arrays with type ``float64`` will by default be converted to MLX arrays
+  with type ``float32``.
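A hedged sketch of the default conversion described in the note (the explicit ``dtype=mx.float64`` request is an assumption and is CPU-only):

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   x = np.random.rand(3)               # NumPy defaults to float64

   a = mx.array(x)
   print(a.dtype)                      # mlx.core.float32 (default conversion)

   b = mx.array(x, dtype=mx.float64)   # keep full precision, CPU only
   print(b.dtype)                      # mlx.core.float64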
+A NumPy array view is a normal NumPy array, except that it does not own its
+memory. This means writing to the view is reflected in the original array.
+
+While this is quite powerful to prevent copying arrays, it should be noted that
+external changes to the memory of arrays cannot be reflected in gradients.

Let's demonstrate this in an example:
@@ -56,11 +64,12 @@ Let's demonstrate this in an example:

The function ``f`` indirectly modifies the array ``x`` through a memory view.
-However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
-representing the gradient of the sum operation alone.
-The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
-It's important to note that a similar issue arises during array conversion and copying.
-For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
+However, this modification is not reflected in the gradient, as seen in the
+last line outputting ``1.0``, representing the gradient of the sum operation
+alone. The squaring of ``x`` occurs externally to MLX, meaning that no
+gradient is incorporated. It's important to note that a similar issue arises
+during array conversion and copying. For instance, a function defined as
+``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.
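The example referred to above is outside this hunk; a minimal sketch of the kind of function being described (names are illustrative) is:

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   def f(x):
       # Square x through a NumPy view; MLX never sees this operation,
       # so it cannot contribute to the gradient.
       x_view = np.array(x, copy=False)
       x_view **= 2
       return x.sum()

   x = mx.array([3.0])
   value, grad = mx.value_and_grad(f)(x)
   print(value.item())  # 9.0: the external squaring did happen
   print(grad.item())   # 1.0: gradient of the sum alone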
PyTorch

@@ -71,7 +80,8 @@ PyTorch

   PyTorch support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.

-PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+PyTorch supports the buffer protocol, but it requires an explicit
+:obj:`memoryview`.

.. code-block:: python
@@ -82,7 +92,8 @@ PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

  b = torch.tensor(memoryview(a))
  c = mx.array(b.numpy())

-Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
+Conversion from PyTorch tensors back to arrays must be done via intermediate
+NumPy arrays with ``numpy()``.

JAX
---
@@ -100,7 +111,8 @@ JAX fully supports the buffer protocol.

TensorFlow
----------

-TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+TensorFlow supports the buffer protocol, but it requires an explicit
+:obj:`memoryview`.

.. code-block:: python
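The TensorFlow snippet itself is elided here; a short sketch of the pattern being described (assuming ``tensorflow`` is installed) is:

.. code-block:: python

   import mlx.core as mx
   import tensorflow as tf

   a = mx.arange(3.0)

   # TensorFlow needs an explicit memoryview over the MLX array.
   b = tf.constant(memoryview(a))
   c = mx.array(b.numpy())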
@@ -25,7 +25,18 @@ array::array(
          std::move(shape),
          dtype,
          std::move(primitive),
-         std::move(inputs))) {}
+         std::move(inputs))) {
+  if (has_primitive() && this->primitive().stream().device == Device::gpu) {
+    for (auto& in : this->inputs()) {
+      if (in.dtype() == float64) {
+        throw std::invalid_argument("float64 is not supported on the GPU");
+      }
+    }
+    if (this->dtype() == float64) {
+      throw std::invalid_argument("float64 is not supported on the GPU");
+    }
+  }
+}

std::vector<array> array::make_arrays(
    std::vector<Shape> shapes,
@@ -594,6 +594,9 @@ void array::init(It src) {
    case float32:
      std::copy(src, src + size(), data<float>());
      break;
    case float64:
      std::copy(src, src + size(), data<double>());
      break;
    case bfloat16:
      std::copy(src, src + size(), data<bfloat16_t>());
      break;

@@ -151,6 +151,9 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
    case bfloat16:
      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
      break;
    case float64:
      *out.data<double>() = static_cast<double>(numel);
      break;
    case complex64:
      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
      break;

@@ -62,6 +62,9 @@ void arange(
    case float32:
      arange<float>(start, start + step, out, out.size());
      break;
    case float64:
      arange<double>(start, start + step, out, out.size());
      break;
    case bfloat16:
      arange<bfloat16_t>(start, start + step, out, out.size());
      break;

@@ -103,6 +103,9 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
    case bfloat16:
      arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
      break;
    case float64:
      arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
      break;
    case complex64:
      arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
      break;

@@ -51,6 +51,9 @@ void comparison_op(const array& a, const array& b, array& out, Op op) {
    case float32:
      binary_op<float, bool>(a, b, out, op);
      break;
    case float64:
      binary_op<double, bool>(a, b, out, op);
      break;
    case bfloat16:
      binary_op<bfloat16_t, bool>(a, b, out, op);
      break;

@@ -114,6 +117,9 @@ void DivMod::eval_cpu(
    case float32:
      binary_op<float>(a, b, outputs, float_op);
      break;
    case float64:
      binary_op<double>(a, b, outputs, float_op);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, outputs, float_op);
      break;

@@ -150,6 +156,9 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
    case float32:
      binary_op<float, bool>(a, b, out, detail::NaNEqual());
      break;
    case float64:
      binary_op<double, bool>(a, b, out, detail::NaNEqual());
      break;
    case bfloat16:
      binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
      break;
@@ -189,20 +198,22 @@ void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
- if (out.dtype() == float32) {
-   binary_op<float>(a, b, out, detail::LogAddExp());
- } else if (out.dtype() == float16) {
+ switch (out.dtype()) {
+   case float16:
      binary_op<float16_t>(a, b, out, detail::LogAddExp());
- } else if (out.dtype() == bfloat16) {
+     break;
+   case float32:
+     binary_op<float>(a, b, out, detail::LogAddExp());
+     break;
+   case float64:
+     binary_op<double>(a, b, out, detail::LogAddExp());
+     break;
+   case bfloat16:
      binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
- } else if (issubdtype(out.dtype(), inexact)) {
-   std::ostringstream err;
-   err << "[logaddexp] Does not support " << out.dtype();
-   throw std::invalid_argument(err.str());
- } else {
-   throw std::invalid_argument(
-       "[logaddexp] Cannot compute logaddexp for arrays with"
-       " non floating point type.");
+     break;
+   default:
+     throw std::runtime_error(
+         "[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
  }
}

@@ -321,20 +332,22 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
- if (out.dtype() == float32) {
-   binary_op<float>(a, b, out, detail::ArcTan2());
- } else if (out.dtype() == float16) {
+ switch (out.dtype()) {
+   case float16:
      binary_op<float16_t>(a, b, out, detail::ArcTan2());
- } else if (out.dtype() == bfloat16) {
+     break;
+   case float32:
+     binary_op<float>(a, b, out, detail::ArcTan2());
+     break;
+   case float64:
+     binary_op<double>(a, b, out, detail::ArcTan2());
+     break;
+   case bfloat16:
      binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
- } else if (issubdtype(out.dtype(), inexact)) {
-   std::ostringstream err;
-   err << "[arctan2] Does not support " << out.dtype();
-   throw std::invalid_argument(err.str());
- } else {
-   throw std::invalid_argument(
-       "[arctan2] Cannot compute inverse tangent for arrays"
-       " with non floating point type.");
+     break;
+   default:
+     throw std::runtime_error(
+         "[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
  }
}
@@ -358,6 +358,9 @@ void binary(const array& a, const array& b, array& out, Op op) {
    case float32:
      binary_op<float>(a, b, out, op);
      break;
    case float64:
      binary_op<double>(a, b, out, op);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, op);
      break;

@@ -205,6 +205,9 @@ void binary(
    case float32:
      binary_op<float>(a, b, outputs, op);
      break;
    case float64:
      binary_op<double>(a, b, outputs, op);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, outputs, op);
      break;

@@ -193,6 +193,9 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
    case float32:
      copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float64:
      copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case bfloat16:
      copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;

@@ -242,6 +245,9 @@ inline void copy_inplace_dispatch(
    case float32:
      copy<float>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float64:
      copy<double>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case bfloat16:
      copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
@@ -41,4 +41,39 @@ void matmul<float>(
  }
}

template <>
void matmul<double>(
    const array& a,
    const array& b,
    array& out,
    bool a_transposed,
    bool b_transposed,
    size_t lda,
    size_t ldb,
    float alpha,
    float beta) {
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
    cblas_dgemm(
        CblasRowMajor,
        a_transposed ? CblasTrans : CblasNoTrans, // transA
        b_transposed ? CblasTrans : CblasNoTrans, // transB
        M,
        N,
        K,
        alpha, // alpha
        a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
        lda,
        b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
        ldb,
        beta, // beta
        out.data<double>() + M * N * i,
        out.shape(-1) // ldc
    );
  }
}

} // namespace mlx::core
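For context, a hedged sketch of how this double-precision GEMM path is reached from Python (illustrative, not part of the diff):

.. code-block:: python

   import mlx.core as mx

   a = mx.random.normal((4, 4)).astype(mx.float64, stream=mx.cpu)
   b = mx.random.normal((4, 4)).astype(mx.float64, stream=mx.cpu)

   # float64 matmul runs on the CPU and uses the double specialization above.
   c = mx.matmul(a, b, stream=mx.cpu)
   print(c.dtype)  # mlx.core.float64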
@@ -148,6 +148,9 @@ void dispatch_gather(
    case float32:
      gather<float, IdxT>(src, inds, out, axes, size);
      break;
    case float64:
      gather<double, IdxT>(src, inds, out, axes, size);
      break;
    case bfloat16:
      gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
      break;

@@ -288,6 +291,9 @@ void dispatch_gather_axis(
    case float32:
      gather_axis<float, IdxT>(src, inds, out, axis);
      break;
    case float64:
      gather_axis<double, IdxT>(src, inds, out, axis);
      break;
    case bfloat16:
      gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
      break;

@@ -499,6 +505,9 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
    case float32:
      dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
      break;
    case float64:
      dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
      break;
    case bfloat16:
      dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
      break;

@@ -661,6 +670,9 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
    case float32:
      dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
      break;
    case float64:
      dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
      break;
    case bfloat16:
      dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
      break;

@@ -46,6 +46,9 @@ void matmul_general(
  } else if (out.dtype() == bfloat16) {
    matmul<bfloat16_t>(
        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
  } else if (out.dtype() == float64) {
    matmul<double>(
        a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
  } else {
    throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
  }
@@ -42,6 +42,7 @@ instantiate_default_limit(int64_t);
instantiate_float_limit(float16_t);
instantiate_float_limit(bfloat16_t);
instantiate_float_limit(float);
instantiate_float_limit(double);
instantiate_float_limit(complex64_t);

template <>

@@ -59,6 +60,8 @@ const bfloat16_t Limits<bfloat16_t>::min =
const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
const float16_t Limits<float16_t>::min =
    -std::numeric_limits<float>::infinity();
const double Limits<double>::max = std::numeric_limits<double>::infinity();
const double Limits<double>::min = -std::numeric_limits<double>::infinity();
const complex64_t Limits<complex64_t>::max =
    std::numeric_limits<float>::infinity();
const complex64_t Limits<complex64_t>::min =

@@ -460,6 +463,7 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
      break;
    case uint64:
    case int64:
    case float64:
    case complex64:
      reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
      break;

@@ -495,6 +499,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
    case float32:
      reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
      break;
    case float64:
      reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
      break;
    case complex64:
      reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
      break;

@@ -537,6 +544,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
    case float32:
      reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
      break;
    case float64:
      reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
      break;
    case bfloat16:
      reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
      break;
@@ -299,6 +299,10 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
      scan_dispatch<float, float>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
      break;
    case float64:
      scan_dispatch<double, double>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);
      break;
    case bfloat16:
      scan_dispatch<bfloat16_t, bfloat16_t>(
          reduce_type_, in, out, axis_, reverse_, inclusive_);

@@ -51,6 +51,9 @@ void select_op(
    case float32:
      ternary_op<bool, float, float, float>(a, b, c, out, op);
      break;
    case float64:
      ternary_op<bool, double, double, double>(a, b, c, out, op);
      break;
    case bfloat16:
      ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
      break;

@@ -6,6 +6,7 @@
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/primitives.h"
#include "mlx/types/limits.h"

namespace mlx::core {

@@ -28,7 +29,7 @@ void softmax(const array& in, array& out) {
  for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
    // Find the maximum
    current_in_ptr = in_ptr;
-   Simd<AccT, N> vmaximum(-std::numeric_limits<float>::infinity());
+   Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
    size_t s = M;
    while (s >= N) {
      Simd<AccT, N> vals = load<T, N>(current_in_ptr);

@@ -163,6 +164,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
        softmax<bfloat16_t, bfloat16_t>(in, out);
      }
      break;
    case float64:
      softmax<double, double>(in, out);
      break;
    case complex64:
      throw std::invalid_argument(
          "[Softmax] Not yet implemented for complex64");
@@ -312,6 +312,8 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
      return argsort<int64_t>(in, out, axis_);
    case float32:
      return argsort<float>(in, out, axis_);
    case float64:
      return argsort<double>(in, out, axis_);
    case float16:
      return argsort<float16_t>(in, out, axis_);
    case bfloat16:

@@ -346,6 +348,8 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
      return sort<int64_t>(in, out, axis_);
    case float32:
      return sort<float>(in, out, axis_);
    case float64:
      return sort<double>(in, out, axis_);
    case float16:
      return sort<float16_t>(in, out, axis_);
    case bfloat16:

@@ -380,6 +384,8 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
      return argpartition<int64_t>(in, out, axis_, kth_);
    case float32:
      return argpartition<float>(in, out, axis_, kth_);
    case float64:
      return argpartition<double>(in, out, axis_, kth_);
    case float16:
      return argpartition<float16_t>(in, out, axis_, kth_);
    case bfloat16:

@@ -414,6 +420,8 @@ void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
      return partition<int64_t>(in, out, axis_, kth_);
    case float32:
      return partition<float>(in, out, axis_, kth_);
    case float64:
      return partition<double>(in, out, axis_, kth_);
    case float16:
      return partition<float16_t>(in, out, axis_, kth_);
    case bfloat16:
@@ -34,6 +34,9 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
    case float32:
      unary_op<float>(in, out, op);
      break;
    case float64:
      unary_op<double>(in, out, op);
      break;
    case bfloat16:
      unary_op<bfloat16_t>(in, out, op);
      break;

@@ -120,6 +123,9 @@ void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
    case float16:
      unary_op<float16_t>(in, out, detail::Erf());
      break;
    case float64:
      unary_op<double>(in, out, detail::Erf());
      break;
    case bfloat16:
      unary_op<bfloat16_t>(in, out, detail::Erf());
      break;

@@ -140,6 +146,9 @@ void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
    case float16:
      unary_op<float16_t>(in, out, detail::ErfInv());
      break;
    case float64:
      unary_op<double>(in, out, detail::ErfInv());
      break;
    case bfloat16:
      unary_op<bfloat16_t>(in, out, detail::ErfInv());
      break;

@@ -104,6 +104,9 @@ void unary(const array& a, array& out, Op op) {
    case float32:
      unary_op<float>(a, out, op);
      break;
    case float64:
      unary_op<double>(a, out, op);
      break;
    case bfloat16:
      unary_op<bfloat16_t>(a, out, op);
      break;

@@ -125,6 +128,9 @@ void unary_fp(const array& a, array& out, Op op) {
    case float32:
      unary_op<float>(a, out, op);
      break;
    case float64:
      unary_op<double>(a, out, op);
      break;
    case complex64:
      unary_op<complex64_t>(a, out, op);
      break;
@@ -151,8 +151,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
    case bfloat16:
      arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);
      break;
-   case complex64:
-     throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
+   default:
+     throw std::runtime_error("[Arange::eval_gpu] Does not support type.");
  }

  compute_encoder.set_output_array(out, 2);

@@ -42,6 +42,9 @@ std::string type_to_name(const Dtype& t) {
    case float32:
      tname = "float32";
      break;
    case float64:
      tname = "double";
      break;
    case bfloat16:
      tname = "bfloat16";
      break;

@@ -95,6 +95,7 @@ struct MPIWrapper {
    LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_);
    LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_);
    LOAD_SYMBOL(ompi_mpi_float, mpi_float_);
    LOAD_SYMBOL(ompi_mpi_double, mpi_double_);
    LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_);
  }

@@ -164,6 +165,8 @@ struct MPIWrapper {
        return mpi_float16_;
      case bfloat16:
        return mpi_bfloat16_;
      case float64:
        return mpi_double_;
    }
  }

@@ -218,6 +221,7 @@ struct MPIWrapper {
  MPI_Datatype mpi_int64_;
  MPI_Datatype mpi_uint64_;
  MPI_Datatype mpi_float_;
  MPI_Datatype mpi_double_;
  MPI_Datatype mpi_complex_;
  MPI_Datatype mpi_float16_;
  MPI_Datatype mpi_bfloat16_;

@@ -68,6 +68,10 @@
      using T = float;       \
      __VA_ARGS__;           \
    } break;                 \
    case float64: {          \
      using T = double;      \
      __VA_ARGS__;           \
    } break;                 \
    case complex64: {        \
      using T = complex64_t; \
      __VA_ARGS__;           \
@@ -8,7 +8,7 @@ namespace mlx::core {

namespace {

-constexpr int num_types = 13;
+constexpr int num_types = 14;
constexpr int num_cats = 8;

constexpr Dtype::Kind type_kinds[num_types] = {

@@ -23,6 +23,7 @@ constexpr Dtype::Kind type_kinds[num_types] = {
    Dtype::Kind::i, // int64,
    Dtype::Kind::f, // float16,
    Dtype::Kind::f, // float32,
    Dtype::Kind::f, // float64,
    Dtype::Kind::V, // bfloat16,
    Dtype::Kind::c // complex64,
};

@@ -31,20 +32,21 @@ constexpr Dtype::Kind type_kinds[num_types] = {
// https://jax.readthedocs.io/en/latest/type_promotion.html
// clang-format off
constexpr Dtype type_rules[num_types][num_types] = {
-// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64
- {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // bool
- {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // uint8
- {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // uint16
- {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // uint32
- {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, bfloat16, complex64}, // uint64
- {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // int8
- {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // int16
- {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // int32
- {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // int64
- {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float32, complex64}, // float16
- {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, complex64}, // float32
- {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, bfloat16, complex64}, // bfloat16
- {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
+// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 float64 bfloat16 complex64
+ {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // bool
+ {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint8
+ {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint16
+ {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // uint32
+ {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, float64, bfloat16, complex64}, // uint64
+ {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int8
+ {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int16
+ {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // int32
+ {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64
+ {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16
+ {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32
+ {float64, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float64
+ {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16
+ {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
};
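A small sketch of what the new ``float64`` column means in practice (illustrative; run on the CPU stream since ``float64`` is CPU-only):

.. code-block:: python

   import mlx.core as mx

   a = mx.array([1.0, 2.0])                                    # float32
   b = mx.array([1.0, 2.0]).astype(mx.float64, stream=mx.cpu)  # float64

   # Per the promotion table, float32 combined with float64 gives float64.
   c = mx.add(a, b, stream=mx.cpu)
   print(c.dtype)  # mlx.core.float64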
@@ -72,6 +74,7 @@ constexpr Dtype::Category type_to_category[num_types] = {
    Dtype::Category::signedinteger, // int64,
    Dtype::Category::floating, // float16,
    Dtype::Category::floating, // float32,
    Dtype::Category::floating, // float64,
    Dtype::Category::floating, // bfloat16,
    Dtype::Category::complexfloating, // complex64,
};

@@ -23,6 +23,7 @@ struct Dtype {
    int64,
    float16,
    float32,
    float64,
    bfloat16,
    complex64,
  };

@@ -78,6 +79,7 @@ inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};

inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)};
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};

@@ -12,6 +12,9 @@ struct numeric_limits;

template <>
struct numeric_limits<float> : public std::numeric_limits<float> {};

template <>
struct numeric_limits<double> : public std::numeric_limits<double> {};

template <>
struct numeric_limits<float16_t> {
 private:

@@ -54,6 +54,9 @@ inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) {
inline void PrintFormatter::print(std::ostream& os, float val) {
  os << val;
}
inline void PrintFormatter::print(std::ostream& os, double val) {
  os << val;
}
inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
  os << val;
}
@@ -234,6 +237,8 @@ std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
      return os << "float16";
    case float32:
      return os << "float32";
    case float64:
      return os << "float64";
    case bfloat16:
      return os << "bfloat16";
    case complex64:

@@ -299,6 +304,9 @@ std::ostream& operator<<(std::ostream& os, array a) {
    case float32:
      print_array<float>(os, a);
      break;
    case float64:
      print_array<double>(os, a);
      break;
    case complex64:
      print_array<complex64_t>(os, a);
      break;

@@ -337,7 +345,7 @@ int get_var(const char* name, int default_value) {
} // namespace env

template <typename T>
-void set_finfo_limits(float& min, float& max) {
+void set_finfo_limits(double& min, double& max) {
  min = numeric_limits<T>::lowest();
  max = numeric_limits<T>::max();
}

@@ -354,6 +362,8 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
    set_finfo_limits<float16_t>(min, max);
  } else if (dtype == bfloat16) {
    set_finfo_limits<bfloat16_t>(min, max);
  } else if (dtype == float64) {
    set_finfo_limits<double>(min, max);
  } else if (dtype == complex64) {
    this->dtype = float32;
    set_finfo_limits<float>(min, max);
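Assuming the Python binding ``mx.finfo`` is available, the widened limits are visible from Python (a sketch, not part of the diff):

.. code-block:: python

   import mlx.core as mx

   # finfo now stores its limits as doubles, so float64 limits are not
   # squashed into float32 range.
   info = mx.finfo(mx.float64)
   print(info.min, info.max)  # roughly -1.8e308 and 1.8e308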
@@ -47,6 +47,7 @@ struct PrintFormatter {
  inline void print(std::ostream& os, float16_t val);
  inline void print(std::ostream& os, bfloat16_t val);
  inline void print(std::ostream& os, float val);
  inline void print(std::ostream& os, double val);
  inline void print(std::ostream& os, complex64_t val);

  bool capitalize_bool{false};

@@ -61,8 +62,8 @@ void abort_with_exception(const std::exception& error);
struct finfo {
  explicit finfo(Dtype dtype);
  Dtype dtype;
- float min;
- float max;
+ double min;
+ double max;
};

/** The type from promoting the arrays' types with one another. */
@@ -128,6 +128,7 @@ void init_array(nb::module_& m) {
  m.attr("int64") = nb::cast(mx::int64);
  m.attr("float16") = nb::cast(mx::float16);
  m.attr("float32") = nb::cast(mx::float32);
  m.attr("float64") = nb::cast(mx::float64);
  m.attr("bfloat16") = nb::cast(mx::bfloat16);
  m.attr("complex64") = nb::cast(mx::complex64);
  nb::enum_<mx::Dtype::Category>(

@@ -163,6 +164,7 @@ void init_array(nb::module_& m) {
      * :ref:`float16 <data_types>`
      * :ref:`bfloat16 <data_types>`
      * :ref:`float32 <data_types>`
      * :ref:`float64 <data_types>`

    * :attr:`~mlx.core.complexfloating`
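A short sketch of how the new dtype appears through these bindings (illustrative):

.. code-block:: python

   import mlx.core as mx

   print(mx.float64)                                      # mlx.core.float64
   print(mx.issubdtype(mx.float64, mx.floating))          # True
   print(mx.issubdtype(mx.float64, mx.inexact))           # True
   print(mx.issubdtype(mx.float64, mx.complexfloating))   # False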
178
python/tests/test_double.py
Normal file
@@ -0,0 +1,178 @@
# Copyright © 2024 Apple Inc.

import math
import os
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


class TestDouble(mlx_tests.MLXTestCase):
    def test_unary_ops(self):
        shape = (3, 3)
        x = mx.random.normal(shape=shape)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                x.astype(mx.float64)

        x_double = x.astype(mx.float64, stream=mx.cpu)

        ops = [
            mx.abs,
            mx.arccos,
            mx.arccosh,
            mx.arcsin,
            mx.arcsinh,
            mx.arctan,
            mx.arctanh,
            mx.ceil,
            mx.erf,
            mx.erfinv,
            mx.exp,
            mx.expm1,
            mx.floor,
            mx.log,
            mx.logical_not,
            mx.negative,
            mx.round,
            mx.sin,
            mx.sinh,
            mx.sqrt,
            mx.rsqrt,
            mx.tan,
            mx.tanh,
        ]
        for op in ops:
            if mx.default_device() == mx.gpu:
                with self.assertRaises(ValueError):
                    op(x_double)
                continue
            y = op(x)
            y_double = op(x_double)
            self.assertTrue(
                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
            )

    def test_binary_ops(self):
        shape = (3, 3)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)

        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        ops = [
            mx.add,
            mx.arctan2,
            mx.divide,
            mx.multiply,
            mx.subtract,
            mx.logical_and,
            mx.logical_or,
            mx.remainder,
            mx.maximum,
            mx.minimum,
            mx.power,
            mx.equal,
            mx.greater,
            mx.greater_equal,
            mx.less,
            mx.less_equal,
            mx.not_equal,
            mx.logaddexp,
        ]
        for op in ops:
            if mx.default_device() == mx.gpu:
                with self.assertRaises(ValueError):
                    op(a_double, b_double)
                continue
            y = op(a, b)
            y_double = op(a_double, b_double)
            self.assertTrue(
                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
            )

    def test_where(self):
        shape = (3, 3)
        cond = mx.random.uniform(shape=shape) > 0.5
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)

        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                mx.where(cond, a_double, b_double)
            return
        y = mx.where(cond, a, b)
        y_double = mx.where(cond, a_double, b_double)
        self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))

    def test_reductions(self):
        shape = (32, 32)
        a = mx.random.normal(shape=shape)
        a_double = a.astype(mx.float64, stream=mx.cpu)

        axes = [0, 1, (0, 1)]
        ops = [mx.sum, mx.prod, mx.min, mx.max, mx.any, mx.all]

        for op in ops:
            for ax in axes:
                if mx.default_device() == mx.gpu:
                    with self.assertRaises(ValueError):
                        op(a_double, axis=ax)
                    continue
                y = op(a)
                y_double = op(a_double)
                self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))

    def test_get_and_set_item(self):
        shape = (3, 3)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=(2,))
        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)
        idx_i = mx.array([0, 2])
        idx_j = mx.array([0, 2])

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double[idx_i, idx_j]
        else:
            y = a[idx_i, idx_j]
            y_double = a_double[idx_i, idx_j]
            self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double[idx_i, idx_j] = b_double
        else:
            a[idx_i, idx_j] = b
            a_double[idx_i, idx_j] = b_double
            self.assertTrue(mx.allclose(a, a_double.astype(mx.float32, mx.cpu)))

    def test_gemm(self):
        shape = (8, 8)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)

        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double @ b_double
            return
        y = a @ b
        y_double = a_double @ b_double
        self.assertTrue(
            mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
        )


if __name__ == "__main__":
    unittest.main()