diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 76fe389d4..abf46a7d5 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -5,6 +5,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp diff --git a/mlx/dtype_utils.cpp b/mlx/dtype_utils.cpp new file mode 100644 index 000000000..a4448536d --- /dev/null +++ b/mlx/dtype_utils.cpp @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/dtype_utils.h" + +namespace mlx::core { + +const char* dtype_to_string(Dtype arg) { + if (arg == bool_) { + return "bool"; + } +#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ + if (DTYPE == arg) { \ + return #DTYPE; \ + } + MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) +#undef SPECIALIZE_DtypeToString + return "(unknown)"; +} + +} // namespace mlx::core diff --git a/mlx/dtype_utils.h b/mlx/dtype_utils.h new file mode 100644 index 000000000..55de890f2 --- /dev/null +++ b/mlx/dtype_utils.h @@ -0,0 +1,207 @@ +// Copyright © 2025 Apple Inc. +// Copyright © Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the BSD-style license found in +// https://github.com/pytorch/executorch/blob/main/LICENSE +// +// Forked from +// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h + +#pragma once + +#include "mlx/dtype.h" + +#include + +namespace mlx::core { + +// Return string representation of dtype. +const char* dtype_to_string(Dtype arg); + +// Macros that iterate across different subsets of Dtypes. +// +// For all of these macros, the final `_` parameter is the name of another macro +// that takes two parameters: the name of a C type, and the name of the +// corresponding Dtype enumerator. +// +// Note that these macros should use fully-qualified namespaces (starting with +// `::`) to ensure that they can be called safely in any arbitrary namespace. +#define MLX_FORALL_INT_TYPES(_) \ + _(uint8_t, uint8) \ + _(uint16_t, uint16) \ + _(uint32_t, uint32) \ + _(uint64_t, uint64) \ + _(int8_t, int8) \ + _(int16_t, int16) \ + _(int32_t, int32) \ + _(int64_t, int64) + +#define MLX_FORALL_FLOAT_TYPES(_) \ + _(float16_t, float16) \ + _(float, float32) \ + _(double, float64) \ + _(bfloat16_t, bfloat16) + +// Calls the provided macro on every Dtype, providing the C type and the +// Dtype name to each call. +// +// @param _ A macro that takes two parameters: the name of a C type, and the +// name of the corresponding Dtype enumerator. +#define MLX_FORALL_DTYPES(_) \ + MLX_FORALL_INT_TYPES(_) \ + MLX_FORALL_FLOAT_TYPES(_) \ + _(bool, bool_) \ + _(complex64_t, complex64) + +// Maps Dtypes to C++ types. +template +struct DtypeToCppType; + +#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \ + template <> \ + struct DtypeToCppType { \ + using type = CPP_TYPE; \ + }; + +MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType) + +#undef SPECIALIZE_DtypeToCppType + +// Maps C++ types to Dtypes. +template +struct CppTypeToDtype; + +#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \ + template <> \ + struct CppTypeToDtype \ + : std::integral_constant {}; + +MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype) + +#undef SPECIALIZE_CppTypeToDtype + +// Helper macros for switch case macros (see below) +// +// These macros are not meant to be used directly. They provide an easy way to +// generate a switch statement that can handle subsets of Dtypes supported. + +#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \ + case enum_type: { \ + using CTYPE_ALIAS = ::mlx::core::DtypeToCppType::type; \ + __VA_ARGS__; \ + break; \ + } + +#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \ + switch (TYPE) { \ + __VA_ARGS__ \ + default: \ + throw std::invalid_argument(fmt::format( \ + "Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \ + } + +#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE( \ + ::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__) + +#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__) + +// Switch case macros +// +// These macros provide an easy way to generate switch statements that apply a +// common lambda function to subsets of Dtypes supported by MLX. +// The lambda function can type specialize to the ctype associated with the +// Dtype being handled through an alias passed as the CTYPE_ALIAS argument. +// +// Arguments: +// - ADDITIONAL: Additional Dtype case to add +// - TYPE: The Dtype to handle through the switch statement +// - NAME: A name for this operation which will be used in error messages +// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype. +// - ...: A statement to be applied to each Dtype case +// +// An example usage is: +// +// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, { +// output.data[0] = input.data[0]; +// }); +// +// Note that these can be nested as well: +// +// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, { +// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, { +// output.data[0] = input.data[0]; +// }); +// }); +// +// These macros are adapted from Dispatch.h in the ATen library. The primary +// difference is that the CTYPE_ALIAS argument is exposed to users, which is +// used to alias the ctype associated with the Dtype that is being handled. + +#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \ + switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) } + +#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ + MLX_INTERNAL_SWITCH_CHECKED( \ + TYPE, \ + NAME, \ + MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + +} // namespace mlx::core diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 5197e516f..188584174 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -5,6 +5,7 @@ #include #include +#include "mlx/dtype_utils.h" #include "mlx/types/limits.h" #include "mlx/utils.h" @@ -224,37 +225,7 @@ void print_array(std::ostream& os, const array& a) { } // namespace std::ostream& operator<<(std::ostream& os, const Dtype& dtype) { - switch (dtype) { - case bool_: - return os << "bool"; - case uint8: - return os << "uint8"; - case uint16: - return os << "uint16"; - case uint32: - return os << "uint32"; - case uint64: - return os << "uint64"; - case int8: - return os << "int8"; - case int16: - return os << "int16"; - case int32: - return os << "int32"; - case int64: - return os << "int64"; - case float16: - return os << "float16"; - case float32: - return os << "float32"; - case float64: - return os << "float64"; - case bfloat16: - return os << "bfloat16"; - case complex64: - return os << "complex64"; - } - return os; + return os << dtype_to_string(dtype); } std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { @@ -277,50 +248,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { std::ostream& operator<<(std::ostream& os, array a) { a.eval(); - switch (a.dtype()) { - case bool_: - print_array(os, a); - break; - case uint8: - print_array(os, a); - break; - case uint16: - print_array(os, a); - break; - case uint32: - print_array(os, a); - break; - case uint64: - print_array(os, a); - break; - case int8: - print_array(os, a); - break; - case int16: - print_array(os, a); - break; - case int32: - print_array(os, a); - break; - case int64: - print_array(os, a); - break; - case float16: - print_array(os, a); - break; - case bfloat16: - print_array(os, a); - break; - case float32: - print_array(os, a); - break; - case float64: - print_array(os, a); - break; - case complex64: - print_array(os, a); - break; - } + MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array(os, a)); return os; } @@ -387,36 +315,8 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) { } iinfo::iinfo(Dtype dtype) : dtype(dtype) { - switch (dtype) { - case int8: - set_iinfo_limits(min, max); - break; - case uint8: - set_iinfo_limits(min, max); - break; - case int16: - set_iinfo_limits(min, max); - break; - case uint16: - set_iinfo_limits(min, max); - break; - case int32: - set_iinfo_limits(min, max); - break; - case uint32: - set_iinfo_limits(min, max); - break; - case int64: - set_iinfo_limits(min, max); - break; - case uint64: - set_iinfo_limits(min, max); - break; - default: - std::ostringstream msg; - msg << "[iinfo] dtype " << dtype << " is not integral."; - throw std::invalid_argument(msg.str()); - } + MLX_SWITCH_INT_TYPES_CHECKED( + dtype, "[iinfo]", CTYPE, set_iinfo_limits(min, max)); } } // namespace mlx::core