mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
* add fp8 e4m3 converters * add cuda * default saturate to min/max * fix for older OS * fix no gpu/cpu * fix saturate * fix compile
20 lines
542 B
Plaintext
20 lines
542 B
Plaintext
// Copyright © 2025 Apple Inc.
|
|
#include "mlx/backend/cuda/unary/unary.cuh"
|
|
#include "mlx/fast_primitives.h"
|
|
|
|
namespace mlx::core {

// GPU evaluation of the FP8 conversion primitive.
//
// Chooses the conversion direction from `to_fp8_` and hands the work to the
// backend's unary-op GPU path (`unary_op_gpu` from unary/unary.cuh):
//   - to_fp8_ == true  -> cu::ToFP8   (convert into FP8)
//   - to_fp8_ == false -> cu::FromFP8 (convert out of FP8)
// The result is written to outputs[0]. The work is issued on the stream owned
// by the output's primitive.
//
// NOTE(review): only outputs[0] is used here, and `inputs` is forwarded
// whole. This presumably assumes a single input and a single output;
// confirm against the primitive's construction.
void fast::ConvertFP8::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("ConvertFP8::eval_gpu");
  auto& out = outputs[0];
  auto& s = out.primitive().stream();
  if (to_fp8_) {
    unary_op_gpu<cu::ToFP8>(inputs, out, name(), s);
  } else {
    unary_op_gpu<cu::FromFP8>(inputs, out, name(), s);
  }
}

} // namespace mlx::core
|