Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)

Fp64 on the CPU (#1843)

* add fp64 data type
* clean build
* update docs
* fix bug

This commit is contained in:
parent 1a1b2108ec
commit 1c0c118f7c
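
The change adds a CPU-only ``float64`` data type. A minimal sketch of the intended usage from Python (names follow the MLX API shown in this commit; exact behavior on a given build is an assumption):

    import mlx.core as mx

    # float64 is a CPU-only dtype, so pin the computation to the CPU stream.
    a = mx.array([1.0, 2.0, 3.0], dtype=mx.float64)
    b = mx.sum(mx.multiply(a, a, stream=mx.cpu), stream=mx.cpu)
    print(b.item())  # 14.0, computed in double precision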
@@ -51,11 +51,20 @@ The default floating point type is ``float32`` and the default integer type is
   * - ``float32``
     - 4
     - 32-bit float
+  * - ``float64``
+    - 4
+    - 64-bit double
   * - ``complex64``
     - 8
     - 64-bit complex float
 
+
+.. note::
+
+  Arrays with type ``float64`` only work with CPU operations. Using
+  ``float64`` arrays on the GPU will result in an exception.
+
 
 Data type are aranged in a hierarchy. See the :obj:`DtypeCategory` object
 documentation for more information. Use :func:`issubdtype` to determine if one
 ``dtype`` (or category) is a subtype of another category.
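As a hedged illustration of the note above (the exception type matches what the new Python tests in this commit expect; running it assumes a machine with a GPU backend):

    import mlx.core as mx

    x = mx.random.normal(shape=(3, 3))
    x64 = x.astype(mx.float64, stream=mx.cpu)   # casting on the CPU is fine

    try:
        mx.exp(x64, stream=mx.gpu)              # scheduling a float64 op on the GPU
    except ValueError as e:
        print(e)                                # float64 is not supported on the GPU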
@@ -21,11 +21,13 @@ Let's convert an array to NumPy and back.
 
 .. note::
 
-    Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
-    ``np.array(a.astype(mx.float32))``.
-    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
+    Since NumPy does not support ``bfloat16`` arrays, you will need to convert
+    to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
+    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
+    buffer format string does not match the dtype V item size 0.``
 
-By default, NumPy copies data to a new array. This can be prevented by creating an array view:
+By default, NumPy copies data to a new array. This can be prevented by creating
+an array view:
 
 .. code-block:: python
 
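A short sketch of the conversion that the note describes (the ``bfloat16`` input array is illustrative):

    import mlx.core as mx
    import numpy as np

    a = mx.zeros((4,), dtype=mx.bfloat16)
    # np.array(a) would fail with the PEP 3118 buffer error quoted above,
    # so cast to a NumPy-supported float type first.
    b = np.array(a.astype(mx.float32))
    print(b.dtype)  # float32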
@@ -35,10 +37,16 @@ By default, NumPy copies data to a new array. This can be prevented by creating
   a_view[0] = 1
   print(a[0].item()) # 1
 
-A NumPy array view is a normal NumPy array, except that it does not own its memory.
-This means writing to the view is reflected in the original array.
+.. note::
+
+    NumPy arrays with type ``float64`` will be default converted to MLX arrays
+    with type ``float32``.
 
-While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.
+A NumPy array view is a normal NumPy array, except that it does not own its
+memory. This means writing to the view is reflected in the original array.
+
+While this is quite powerful to prevent copying arrays, it should be noted that
+external changes to the memory of arrays cannot be reflected in gradients.
 
 Let's demonstrate this in an example:
 
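A hedged sketch of the new default-conversion note for ``float64`` NumPy inputs (passing an explicit ``dtype`` to keep double precision is an assumption about the constructor):

    import mlx.core as mx
    import numpy as np

    x = np.arange(4, dtype=np.float64)
    a = mx.array(x)                      # converted to an mx.float32 array by default
    b = mx.array(x, dtype=mx.float64)    # keep double precision explicitly (CPU only)
    print(a.dtype, b.dtype)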
@@ -56,11 +64,12 @@ Let's demonstrate this in an example:
 
 
 The function ``f`` indirectly modifies the array ``x`` through a memory view.
-However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
-representing the gradient of the sum operation alone.
-The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
-It's important to note that a similar issue arises during array conversion and copying.
-For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
+However, this modification is not reflected in the gradient, as seen in the
+last line outputting ``1.0``, representing the gradient of the sum operation
+alone. The squaring of ``x`` occurs externally to MLX, meaning that no
+gradient is incorporated. It's important to note that a similar issue arises
+during array conversion and copying. For instance, a function defined as
+``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
 even though no in-place operations on MLX memory are executed.
 
 PyTorch
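The example the paragraph refers to is not shown in this hunk; a minimal self-contained sketch of the same pitfall, mirroring the view-based workflow from the section above:

    import mlx.core as mx
    import numpy as np

    def f(x):
        x_view = np.array(x, copy=False)
        x_view **= 2          # modifies MLX memory outside of the graph
        return x.sum()

    x = mx.array([3.0])
    grad = mx.grad(f)(x)
    print(x)     # array([9], dtype=float32): the external square was applied
    print(grad)  # array([1], dtype=float32): gradient of the sum alone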
@@ -71,7 +80,8 @@ PyTorch
    PyTorch Support for :obj:`memoryview` is experimental and can break for
    multi-dimensional arrays. Casting to NumPy first is advised for now.
 
-PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+PyTorch supports the buffer protocol, but it requires an explicit
+:obj:`memoryview`.
 
 .. code-block:: python
 
@@ -82,7 +92,8 @@ PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryvi
   b = torch.tensor(memoryview(a))
   c = mx.array(b.numpy())
 
-Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.
+Conversion from PyTorch tensors back to arrays must be done via intermediate
+NumPy arrays with ``numpy()``.
 
 JAX
 ---
@@ -100,7 +111,8 @@ JAX fully supports the buffer protocol.
 TensorFlow
 ----------
 
-TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+TensorFlow supports the buffer protocol, but it requires an explicit
+:obj:`memoryview`.
 
 .. code-block:: python
 
@@ -25,7 +25,18 @@ array::array(
         std::move(shape),
         dtype,
         std::move(primitive),
-        std::move(inputs))) {}
+        std::move(inputs))) {
+  if (has_primitive() && this->primitive().stream().device == Device::gpu) {
+    for (auto& in : this->inputs()) {
+      if (in.dtype() == float64) {
+        throw std::invalid_argument("float64 is not supported on the GPU");
+      }
+    }
+    if (this->dtype() == float64) {
+      throw std::invalid_argument("float64 is not supported on the GPU");
+    }
+  }
+}
 
 std::vector<array> array::make_arrays(
     std::vector<Shape> shapes,
@@ -594,6 +594,9 @@ void array::init(It src) {
     case float32:
      std::copy(src, src + size(), data<float>());
      break;
+    case float64:
+      std::copy(src, src + size(), data<double>());
+      break;
     case bfloat16:
       std::copy(src, src + size(), data<bfloat16_t>());
       break;
@@ -151,6 +151,9 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
     case bfloat16:
       *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
       break;
+    case float64:
+      *out.data<double>() = static_cast<double>(numel);
+      break;
     case complex64:
       *out.data<complex64_t>() = static_cast<complex64_t>(numel);
       break;
@@ -62,6 +62,9 @@ void arange(
     case float32:
       arange<float>(start, start + step, out, out.size());
       break;
+    case float64:
+      arange<double>(start, start + step, out, out.size());
+      break;
     case bfloat16:
       arange<bfloat16_t>(start, start + step, out, out.size());
       break;
@@ -103,6 +103,9 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
     case bfloat16:
       arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
       break;
+    case float64:
+      arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
+      break;
     case complex64:
       arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
       break;
@@ -51,6 +51,9 @@ void comparison_op(const array& a, const array& b, array& out, Op op) {
     case float32:
       binary_op<float, bool>(a, b, out, op);
       break;
+    case float64:
+      binary_op<double, bool>(a, b, out, op);
+      break;
     case bfloat16:
       binary_op<bfloat16_t, bool>(a, b, out, op);
       break;
@@ -114,6 +117,9 @@ void DivMod::eval_cpu(
     case float32:
       binary_op<float>(a, b, outputs, float_op);
       break;
+    case float64:
+      binary_op<double>(a, b, outputs, float_op);
+      break;
     case bfloat16:
       binary_op<bfloat16_t>(a, b, outputs, float_op);
       break;
@@ -150,6 +156,9 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float32:
       binary_op<float, bool>(a, b, out, detail::NaNEqual());
       break;
+    case float64:
+      binary_op<double, bool>(a, b, out, detail::NaNEqual());
+      break;
     case bfloat16:
       binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
       break;
@@ -189,20 +198,22 @@ void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   auto& a = inputs[0];
   auto& b = inputs[1];
-  if (out.dtype() == float32) {
-    binary_op<float>(a, b, out, detail::LogAddExp());
-  } else if (out.dtype() == float16) {
-    binary_op<float16_t>(a, b, out, detail::LogAddExp());
-  } else if (out.dtype() == bfloat16) {
-    binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
-  } else if (issubdtype(out.dtype(), inexact)) {
-    std::ostringstream err;
-    err << "[logaddexp] Does not support " << out.dtype();
-    throw std::invalid_argument(err.str());
-  } else {
-    throw std::invalid_argument(
-        "[logaddexp] Cannot compute logaddexp for arrays with"
-        " non floating point type.");
+  switch (out.dtype()) {
+    case float16:
+      binary_op<float16_t>(a, b, out, detail::LogAddExp());
+      break;
+    case float32:
+      binary_op<float>(a, b, out, detail::LogAddExp());
+      break;
+    case float64:
+      binary_op<double>(a, b, out, detail::LogAddExp());
+      break;
+    case bfloat16:
+      binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
+      break;
+    default:
+      throw std::runtime_error(
+          "[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
   }
 }
 
@@ -321,20 +332,22 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   const auto& a = inputs[0];
   const auto& b = inputs[1];
-  if (out.dtype() == float32) {
-    binary_op<float>(a, b, out, detail::ArcTan2());
-  } else if (out.dtype() == float16) {
-    binary_op<float16_t>(a, b, out, detail::ArcTan2());
-  } else if (out.dtype() == bfloat16) {
-    binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
-  } else if (issubdtype(out.dtype(), inexact)) {
-    std::ostringstream err;
-    err << "[arctan2] Does not support " << out.dtype();
-    throw std::invalid_argument(err.str());
-  } else {
-    throw std::invalid_argument(
-        "[arctan2] Cannot compute inverse tangent for arrays"
-        " with non floating point type.");
+  switch (out.dtype()) {
+    case float16:
+      binary_op<float16_t>(a, b, out, detail::ArcTan2());
+      break;
+    case float32:
+      binary_op<float>(a, b, out, detail::ArcTan2());
+      break;
+    case float64:
+      binary_op<double>(a, b, out, detail::ArcTan2());
+      break;
+    case bfloat16:
+      binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
+      break;
+    default:
+      throw std::runtime_error(
+          "[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
   }
 }
 
@@ -358,6 +358,9 @@ void binary(const array& a, const array& b, array& out, Op op) {
     case float32:
       binary_op<float>(a, b, out, op);
       break;
+    case float64:
+      binary_op<double>(a, b, out, op);
+      break;
     case bfloat16:
       binary_op<bfloat16_t>(a, b, out, op);
       break;
@@ -205,6 +205,9 @@ void binary(
     case float32:
       binary_op<float>(a, b, outputs, op);
       break;
+    case float64:
+      binary_op<double>(a, b, outputs, op);
+      break;
     case bfloat16:
       binary_op<bfloat16_t>(a, b, outputs, op);
       break;
@@ -193,6 +193,9 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
     case float32:
       copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
       break;
+    case float64:
+      copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
+      break;
     case bfloat16:
       copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
       break;
@@ -242,6 +245,9 @@ inline void copy_inplace_dispatch(
     case float32:
       copy<float>(src, dst, ctype, std::forward<Args>(args)...);
       break;
+    case float64:
+      copy<double>(src, dst, ctype, std::forward<Args>(args)...);
+      break;
     case bfloat16:
       copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
       break;
@@ -41,4 +41,39 @@ void matmul<float>(
   }
 }
 
+template <>
+void matmul<double>(
+    const array& a,
+    const array& b,
+    array& out,
+    bool a_transposed,
+    bool b_transposed,
+    size_t lda,
+    size_t ldb,
+    float alpha,
+    float beta) {
+  size_t M = a.shape(-2);
+  size_t N = b.shape(-1);
+  size_t K = a.shape(-1);
+
+  for (int i = 0; i < (a.size() / (M * K)); ++i) {
+    cblas_dgemm(
+        CblasRowMajor,
+        a_transposed ? CblasTrans : CblasNoTrans, // transA
+        b_transposed ? CblasTrans : CblasNoTrans, // transB
+        M,
+        N,
+        K,
+        alpha, // alpha
+        a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
+        lda,
+        b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
+        ldb,
+        beta, // beta
+        out.data<double>() + M * N * i,
+        out.shape(-1) // ldc
+    );
+  }
+}
+
 } // namespace mlx::core
@@ -148,6 +148,9 @@ void dispatch_gather(
     case float32:
       gather<float, IdxT>(src, inds, out, axes, size);
       break;
+    case float64:
+      gather<double, IdxT>(src, inds, out, axes, size);
+      break;
     case bfloat16:
       gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
       break;
@@ -288,6 +291,9 @@ void dispatch_gather_axis(
     case float32:
       gather_axis<float, IdxT>(src, inds, out, axis);
       break;
+    case float64:
+      gather_axis<double, IdxT>(src, inds, out, axis);
+      break;
     case bfloat16:
       gather_axis<bfloat16_t, IdxT>(src, inds, out, axis);
       break;
@@ -499,6 +505,9 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float32:
       dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
       break;
+    case float64:
+      dispatch_scatter<double>(out, inds, updates, axes_, reduce_type_);
+      break;
     case bfloat16:
       dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
       break;
@@ -661,6 +670,9 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float32:
       dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
       break;
+    case float64:
+      dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
+      break;
     case bfloat16:
       dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
       break;
|
@ -46,6 +46,9 @@ void matmul_general(
|
|||||||
} else if (out.dtype() == bfloat16) {
|
} else if (out.dtype() == bfloat16) {
|
||||||
matmul<bfloat16_t>(
|
matmul<bfloat16_t>(
|
||||||
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
||||||
|
} else if (out.dtype() == float64) {
|
||||||
|
matmul<double>(
|
||||||
|
a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
|
throw std::runtime_error("[Matmul::eval_cpu] Invalid type.");
|
||||||
}
|
}
|
||||||
|
@@ -42,6 +42,7 @@ instantiate_default_limit(int64_t);
 instantiate_float_limit(float16_t);
 instantiate_float_limit(bfloat16_t);
 instantiate_float_limit(float);
+instantiate_float_limit(double);
 instantiate_float_limit(complex64_t);
 
 template <>
@@ -59,6 +60,8 @@ const bfloat16_t Limits<bfloat16_t>::min =
 const float16_t Limits<float16_t>::max = std::numeric_limits<float>::infinity();
 const float16_t Limits<float16_t>::min =
     -std::numeric_limits<float>::infinity();
+const double Limits<double>::max = std::numeric_limits<double>::infinity();
+const double Limits<double>::min = -std::numeric_limits<double>::infinity();
 const complex64_t Limits<complex64_t>::max =
     std::numeric_limits<float>::infinity();
 const complex64_t Limits<complex64_t>::min =
@@ -460,6 +463,7 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
       break;
     case uint64:
     case int64:
+    case float64:
     case complex64:
       reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
       break;
@@ -495,6 +499,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float32:
       reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
       break;
+    case float64:
+      reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
+      break;
     case complex64:
       reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
       break;
@@ -537,6 +544,9 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float32:
       reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
       break;
+    case float64:
+      reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
+      break;
     case bfloat16:
       reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
       break;
@@ -299,6 +299,10 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
       scan_dispatch<float, float>(
           reduce_type_, in, out, axis_, reverse_, inclusive_);
       break;
+    case float64:
+      scan_dispatch<double, double>(
+          reduce_type_, in, out, axis_, reverse_, inclusive_);
+      break;
     case bfloat16:
       scan_dispatch<bfloat16_t, bfloat16_t>(
           reduce_type_, in, out, axis_, reverse_, inclusive_);
@@ -51,6 +51,9 @@ void select_op(
     case float32:
       ternary_op<bool, float, float, float>(a, b, c, out, op);
       break;
+    case float64:
+      ternary_op<bool, double, double, double>(a, b, c, out, op);
+      break;
     case bfloat16:
       ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
       break;
@@ -6,6 +6,7 @@
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/simd/simd.h"
 #include "mlx/primitives.h"
+#include "mlx/types/limits.h"
 
 namespace mlx::core {
 
@@ -28,7 +29,7 @@ void softmax(const array& in, array& out) {
   for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
     // Find the maximum
     current_in_ptr = in_ptr;
-    Simd<AccT, N> vmaximum(-std::numeric_limits<float>::infinity());
+    Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
     size_t s = M;
     while (s >= N) {
       Simd<AccT, N> vals = load<T, N>(current_in_ptr);
@@ -163,6 +164,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
         softmax<bfloat16_t, bfloat16_t>(in, out);
       }
       break;
+    case float64:
+      softmax<double, double>(in, out);
+      break;
     case complex64:
       throw std::invalid_argument(
           "[Softmax] Not yet implemented for complex64");
@@ -312,6 +312,8 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
       return argsort<int64_t>(in, out, axis_);
     case float32:
       return argsort<float>(in, out, axis_);
+    case float64:
+      return argsort<double>(in, out, axis_);
     case float16:
       return argsort<float16_t>(in, out, axis_);
     case bfloat16:
@@ -346,6 +348,8 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
       return sort<int64_t>(in, out, axis_);
     case float32:
       return sort<float>(in, out, axis_);
+    case float64:
+      return sort<double>(in, out, axis_);
     case float16:
       return sort<float16_t>(in, out, axis_);
     case bfloat16:
@@ -380,6 +384,8 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
       return argpartition<int64_t>(in, out, axis_, kth_);
     case float32:
       return argpartition<float>(in, out, axis_, kth_);
+    case float64:
+      return argpartition<double>(in, out, axis_, kth_);
     case float16:
       return argpartition<float16_t>(in, out, axis_, kth_);
     case bfloat16:
|
|||||||
return partition<int64_t>(in, out, axis_, kth_);
|
return partition<int64_t>(in, out, axis_, kth_);
|
||||||
case float32:
|
case float32:
|
||||||
return partition<float>(in, out, axis_, kth_);
|
return partition<float>(in, out, axis_, kth_);
|
||||||
|
case float64:
|
||||||
|
return partition<double>(in, out, axis_, kth_);
|
||||||
case float16:
|
case float16:
|
||||||
return partition<float16_t>(in, out, axis_, kth_);
|
return partition<float16_t>(in, out, axis_, kth_);
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
|
@@ -34,6 +34,9 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float32:
       unary_op<float>(in, out, op);
       break;
+    case float64:
+      unary_op<double>(in, out, op);
+      break;
     case bfloat16:
       unary_op<bfloat16_t>(in, out, op);
       break;
@@ -120,6 +123,9 @@ void Erf::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float16:
       unary_op<float16_t>(in, out, detail::Erf());
       break;
+    case float64:
+      unary_op<double>(in, out, detail::Erf());
+      break;
     case bfloat16:
       unary_op<bfloat16_t>(in, out, detail::Erf());
       break;
@@ -140,6 +146,9 @@ void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
     case float16:
       unary_op<float16_t>(in, out, detail::ErfInv());
       break;
+    case float64:
+      unary_op<double>(in, out, detail::ErfInv());
+      break;
     case bfloat16:
       unary_op<bfloat16_t>(in, out, detail::ErfInv());
       break;
@@ -104,6 +104,9 @@ void unary(const array& a, array& out, Op op) {
     case float32:
       unary_op<float>(a, out, op);
       break;
+    case float64:
+      unary_op<double>(a, out, op);
+      break;
     case bfloat16:
       unary_op<bfloat16_t>(a, out, op);
       break;
@@ -125,6 +128,9 @@ void unary_fp(const array& a, array& out, Op op) {
     case float32:
       unary_op<float>(a, out, op);
       break;
+    case float64:
+      unary_op<double>(a, out, op);
+      break;
     case complex64:
       unary_op<complex64_t>(a, out, op);
       break;
@@ -151,8 +151,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
     case bfloat16:
       arange_set_scalars<bfloat16_t>(start_, start_ + step_, compute_encoder);
       break;
-    case complex64:
-      throw std::runtime_error("[Arange::eval_gpu] Does not support complex64");
+    default:
+      throw std::runtime_error("[Arange::eval_gpu] Does not support type.");
   }
 
   compute_encoder.set_output_array(out, 2);
@@ -42,6 +42,9 @@ std::string type_to_name(const Dtype& t) {
     case float32:
       tname = "float32";
       break;
+    case float64:
+      tname = "double";
+      break;
     case bfloat16:
       tname = "bfloat16";
       break;
@@ -95,6 +95,7 @@ struct MPIWrapper {
     LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_);
     LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_);
     LOAD_SYMBOL(ompi_mpi_float, mpi_float_);
+    LOAD_SYMBOL(ompi_mpi_double, mpi_double_);
     LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_);
   }
 
@@ -164,6 +165,8 @@ struct MPIWrapper {
       return mpi_float16_;
     case bfloat16:
       return mpi_bfloat16_;
+    case float64:
+      return mpi_double_;
   }
 }
 
@@ -218,6 +221,7 @@ struct MPIWrapper {
   MPI_Datatype mpi_int64_;
   MPI_Datatype mpi_uint64_;
   MPI_Datatype mpi_float_;
+  MPI_Datatype mpi_double_;
   MPI_Datatype mpi_complex_;
   MPI_Datatype mpi_float16_;
   MPI_Datatype mpi_bfloat16_;
@@ -68,6 +68,10 @@
       using T = float; \
       __VA_ARGS__; \
     } break; \
+    case float64: { \
+      using T = double; \
+      __VA_ARGS__; \
+    } break; \
     case complex64: { \
       using T = complex64_t; \
       __VA_ARGS__; \
@@ -8,7 +8,7 @@ namespace mlx::core {
 
 namespace {
 
-constexpr int num_types = 13;
+constexpr int num_types = 14;
 constexpr int num_cats = 8;
 
 constexpr Dtype::Kind type_kinds[num_types] = {
@@ -23,6 +23,7 @@ constexpr Dtype::Kind type_kinds[num_types] = {
     Dtype::Kind::i, // int64,
     Dtype::Kind::f, // float16,
     Dtype::Kind::f, // float32,
+    Dtype::Kind::f, // float64,
     Dtype::Kind::V, // bfloat16,
     Dtype::Kind::c // complex64,
 };
@@ -31,20 +32,21 @@ constexpr Dtype::Kind type_kinds[num_types] = {
 // https://jax.readthedocs.io/en/latest/type_promotion.html
 // clang-format off
 constexpr Dtype type_rules[num_types][num_types] = {
-// bool       uint8      uint16     uint32     uint64     int8       int16      int32      int64      float16    float32    bfloat16   complex64
-    {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // bool
-    {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // uint8
-    {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // uint16
-    {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // uint32
-    {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, bfloat16, complex64}, // uint64
-    {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // int8
-    {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // int16
-    {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // int32
-    {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // int64
-    {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float32, complex64}, // float16
-    {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, complex64}, // float32
-    {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, bfloat16, complex64}, // bfloat16
-    {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64
+// bool       uint8      uint16     uint32     uint64     int8       int16      int32      int64      float16    float32    float64    bfloat16   complex64
+    {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // bool
+    {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint8
+    {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint16
+    {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // uint32
+    {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, float64, bfloat16, complex64}, // uint64
+    {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int8
+    {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int16
+    {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // int32
+    {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64
+    {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16
+    {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32
+    {float64, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float64
+    {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16
+    {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64,complex64, complex64, complex64}, // complex64
 };
 
 
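Reading a result type off the new ``float64`` column above, a small hedged check from Python (the op is run on the CPU stream since ``float64`` is CPU-only):

    import mlx.core as mx

    a = mx.zeros((2,), dtype=mx.float64)
    b = mx.zeros((2,), dtype=mx.float16)
    c = mx.add(a, b, stream=mx.cpu)   # float16 x float64 promotes to float64
    print(c.dtype)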
@@ -72,6 +74,7 @@ constexpr Dtype::Category type_to_category[num_types] = {
     Dtype::Category::signedinteger, // int64,
     Dtype::Category::floating, // float16,
     Dtype::Category::floating, // float32,
+    Dtype::Category::floating, // float64,
     Dtype::Category::floating, // bfloat16,
     Dtype::Category::complexfloating, // complex64,
 };
@@ -23,6 +23,7 @@ struct Dtype {
     int64,
     float16,
     float32,
+    float64,
     bfloat16,
     complex64,
   };
@@ -78,6 +79,7 @@ inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
 
 inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
 inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
+inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)};
 inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
 inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
 
@@ -12,6 +12,9 @@ struct numeric_limits;
 template <>
 struct numeric_limits<float> : public std::numeric_limits<float> {};
 
+template <>
+struct numeric_limits<double> : public std::numeric_limits<double> {};
+
 template <>
 struct numeric_limits<float16_t> {
  private:
@@ -54,6 +54,9 @@ inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) {
 inline void PrintFormatter::print(std::ostream& os, float val) {
   os << val;
 }
+inline void PrintFormatter::print(std::ostream& os, double val) {
+  os << val;
+}
 inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
   os << val;
 }
@@ -234,6 +237,8 @@ std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
       return os << "float16";
     case float32:
       return os << "float32";
+    case float64:
+      return os << "float64";
     case bfloat16:
       return os << "bfloat16";
     case complex64:
@@ -299,6 +304,9 @@ std::ostream& operator<<(std::ostream& os, array a) {
     case float32:
       print_array<float>(os, a);
       break;
+    case float64:
+      print_array<double>(os, a);
+      break;
     case complex64:
       print_array<complex64_t>(os, a);
       break;
@@ -337,7 +345,7 @@ int get_var(const char* name, int default_value) {
 } // namespace env
 
 template <typename T>
-void set_finfo_limits(float& min, float& max) {
+void set_finfo_limits(double& min, double& max) {
   min = numeric_limits<T>::lowest();
   max = numeric_limits<T>::max();
 }
@@ -354,6 +362,8 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
     set_finfo_limits<float16_t>(min, max);
   } else if (dtype == bfloat16) {
     set_finfo_limits<bfloat16_t>(min, max);
+  } else if (dtype == float64) {
+    set_finfo_limits<double>(min, max);
   } else if (dtype == complex64) {
     this->dtype = float32;
     set_finfo_limits<float>(min, max);
@@ -47,6 +47,7 @@ struct PrintFormatter {
   inline void print(std::ostream& os, float16_t val);
   inline void print(std::ostream& os, bfloat16_t val);
   inline void print(std::ostream& os, float val);
+  inline void print(std::ostream& os, double val);
   inline void print(std::ostream& os, complex64_t val);
 
   bool capitalize_bool{false};
@@ -61,8 +62,8 @@ void abort_with_exception(const std::exception& error);
 struct finfo {
   explicit finfo(Dtype dtype);
   Dtype dtype;
-  float min;
-  float max;
+  double min;
+  double max;
 };
 
 /** The type from promoting the arrays' types with one another. */
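With ``finfo`` widened to double precision it can represent ``float64`` limits; a hedged check, assuming the Python ``mx.finfo`` binding reports the new fields:

    import mlx.core as mx

    info = mx.finfo(mx.float64)
    print(info.min, info.max)   # approximately -1.8e308 and 1.8e308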
|
@ -128,6 +128,7 @@ void init_array(nb::module_& m) {
|
|||||||
m.attr("int64") = nb::cast(mx::int64);
|
m.attr("int64") = nb::cast(mx::int64);
|
||||||
m.attr("float16") = nb::cast(mx::float16);
|
m.attr("float16") = nb::cast(mx::float16);
|
||||||
m.attr("float32") = nb::cast(mx::float32);
|
m.attr("float32") = nb::cast(mx::float32);
|
||||||
|
m.attr("float64") = nb::cast(mx::float64);
|
||||||
m.attr("bfloat16") = nb::cast(mx::bfloat16);
|
m.attr("bfloat16") = nb::cast(mx::bfloat16);
|
||||||
m.attr("complex64") = nb::cast(mx::complex64);
|
m.attr("complex64") = nb::cast(mx::complex64);
|
||||||
nb::enum_<mx::Dtype::Category>(
|
nb::enum_<mx::Dtype::Category>(
|
||||||
@@ -163,6 +164,7 @@ void init_array(nb::module_& m) {
       * :ref:`float16 <data_types>`
       * :ref:`bfloat16 <data_types>`
       * :ref:`float32 <data_types>`
+      * :ref:`float64 <data_types>`
 
       * :attr:`~mlx.core.complexfloating`
 
|
178
python/tests/test_double.py
Normal file
178
python/tests/test_double.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class TestDouble(mlx_tests.MLXTestCase):
|
||||||
|
def test_unary_ops(self):
|
||||||
|
shape = (3, 3)
|
||||||
|
x = mx.random.normal(shape=shape)
|
||||||
|
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
x.astype(mx.float64)
|
||||||
|
|
||||||
|
x_double = x.astype(mx.float64, stream=mx.cpu)
|
||||||
|
|
||||||
|
ops = [
|
||||||
|
mx.abs,
|
||||||
|
mx.arccos,
|
||||||
|
mx.arccosh,
|
||||||
|
mx.arcsin,
|
||||||
|
mx.arcsinh,
|
||||||
|
mx.arctan,
|
||||||
|
mx.arctanh,
|
||||||
|
mx.ceil,
|
||||||
|
mx.erf,
|
||||||
|
mx.erfinv,
|
||||||
|
mx.exp,
|
||||||
|
mx.expm1,
|
||||||
|
mx.floor,
|
||||||
|
mx.log,
|
||||||
|
mx.logical_not,
|
||||||
|
mx.negative,
|
||||||
|
mx.round,
|
||||||
|
mx.sin,
|
||||||
|
mx.sinh,
|
||||||
|
mx.sqrt,
|
||||||
|
mx.rsqrt,
|
||||||
|
mx.tan,
|
||||||
|
mx.tanh,
|
||||||
|
]
|
||||||
|
for op in ops:
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
op(x_double)
|
||||||
|
continue
|
||||||
|
y = op(x)
|
||||||
|
y_double = op(x_double)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_binary_ops(self):
|
||||||
|
shape = (3, 3)
|
||||||
|
a = mx.random.normal(shape=shape)
|
||||||
|
b = mx.random.normal(shape=shape)
|
||||||
|
|
||||||
|
a_double = a.astype(mx.float64, stream=mx.cpu)
|
||||||
|
b_double = b.astype(mx.float64, stream=mx.cpu)
|
||||||
|
|
||||||
|
ops = [
|
||||||
|
mx.add,
|
||||||
|
mx.arctan2,
|
||||||
|
mx.divide,
|
||||||
|
mx.multiply,
|
||||||
|
mx.subtract,
|
||||||
|
mx.logical_and,
|
||||||
|
mx.logical_or,
|
||||||
|
mx.remainder,
|
||||||
|
mx.maximum,
|
||||||
|
mx.minimum,
|
||||||
|
mx.power,
|
||||||
|
mx.equal,
|
||||||
|
mx.greater,
|
||||||
|
mx.greater_equal,
|
||||||
|
mx.less,
|
||||||
|
mx.less_equal,
|
||||||
|
mx.not_equal,
|
||||||
|
mx.logaddexp,
|
||||||
|
]
|
||||||
|
for op in ops:
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
op(a_double, b_double)
|
||||||
|
continue
|
||||||
|
y = op(a, b)
|
||||||
|
y_double = op(a_double, b_double)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_where(self):
|
||||||
|
shape = (3, 3)
|
||||||
|
cond = mx.random.uniform(shape=shape) > 0.5
|
||||||
|
a = mx.random.normal(shape=shape)
|
||||||
|
b = mx.random.normal(shape=shape)
|
||||||
|
|
||||||
|
a_double = a.astype(mx.float64, stream=mx.cpu)
|
||||||
|
b_double = b.astype(mx.float64, stream=mx.cpu)
|
||||||
|
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.where(cond, a_double, b_double)
|
||||||
|
return
|
||||||
|
y = mx.where(cond, a, b)
|
||||||
|
y_double = mx.where(cond, a_double, b_double)
|
||||||
|
self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))
|
||||||
|
|
||||||
|
def test_reductions(self):
|
||||||
|
shape = (32, 32)
|
||||||
|
a = mx.random.normal(shape=shape)
|
||||||
|
a_double = a.astype(mx.float64, stream=mx.cpu)
|
||||||
|
|
||||||
|
axes = [0, 1, (0, 1)]
|
||||||
|
ops = [mx.sum, mx.prod, mx.min, mx.max, mx.any, mx.all]
|
||||||
|
|
||||||
|
for op in ops:
|
||||||
|
for ax in axes:
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
op(a_double, axis=ax)
|
||||||
|
continue
|
||||||
|
y = op(a)
|
||||||
|
y_double = op(a_double)
|
||||||
|
self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))
|
||||||
|
|
||||||
|
def test_get_and_set_item(self):
|
||||||
|
shape = (3, 3)
|
||||||
|
a = mx.random.normal(shape=shape)
|
||||||
|
b = mx.random.normal(shape=(2,))
|
||||||
|
a_double = a.astype(mx.float64, stream=mx.cpu)
|
||||||
|
b_double = b.astype(mx.float64, stream=mx.cpu)
|
||||||
|
idx_i = mx.array([0, 2])
|
||||||
|
idx_j = mx.array([0, 2])
|
||||||
|
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a_double[idx_i, idx_j]
|
||||||
|
else:
|
||||||
|
y = a[idx_i, idx_j]
|
||||||
|
y_double = a_double[idx_i, idx_j]
|
||||||
|
self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))
|
||||||
|
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a_double[idx_i, idx_j] = b_double
|
||||||
|
else:
|
||||||
|
a[idx_i, idx_j] = b
|
||||||
|
a_double[idx_i, idx_j] = b_double
|
||||||
|
self.assertTrue(mx.allclose(a, a_double.astype(mx.float32, mx.cpu)))
|
||||||
|
|
||||||
|
def test_gemm(self):
|
||||||
|
shape = (8, 8)
|
||||||
|
a = mx.random.normal(shape=shape)
|
||||||
|
b = mx.random.normal(shape=shape)
|
||||||
|
|
||||||
|
a_double = a.astype(mx.float64, stream=mx.cpu)
|
||||||
|
b_double = b.astype(mx.float64, stream=mx.cpu)
|
||||||
|
|
||||||
|
if mx.default_device() == mx.gpu:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
a_double @ b_double
|
||||||
|
return
|
||||||
|
y = a @ b
|
||||||
|
y_double = a_double @ b_double
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|