mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Introduce macros for dispatching dynamic dtypes as static types (#2073)
This commit is contained in:
parent
5f04c0f818
commit
b13f2aed16
@ -5,6 +5,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
|
20
mlx/dtype_utils.cpp
Normal file
20
mlx/dtype_utils.cpp
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
const char* dtype_to_string(Dtype arg) {
|
||||
if (arg == bool_) {
|
||||
return "bool";
|
||||
}
|
||||
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
|
||||
if (DTYPE == arg) { \
|
||||
return #DTYPE; \
|
||||
}
|
||||
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
|
||||
#undef SPECIALIZE_DtypeToString
|
||||
return "(unknown)";
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
207
mlx/dtype_utils.h
Normal file
207
mlx/dtype_utils.h
Normal file
@ -0,0 +1,207 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
// Copyright © Meta Platforms, Inc. and affiliates.
|
||||
//
|
||||
// This source code is licensed under the BSD-style license found in
|
||||
// https://github.com/pytorch/executorch/blob/main/LICENSE
|
||||
//
|
||||
// Forked from
|
||||
// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/dtype.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Return string representation of dtype.
|
||||
const char* dtype_to_string(Dtype arg);
|
||||
|
||||
// Macros that iterate across different subsets of Dtypes.
|
||||
//
|
||||
// For all of these macros, the final `_` parameter is the name of another macro
|
||||
// that takes two parameters: the name of a C type, and the name of the
|
||||
// corresponding Dtype enumerator.
|
||||
//
|
||||
// Note that these macros should use fully-qualified namespaces (starting with
|
||||
// `::`) to ensure that they can be called safely in any arbitrary namespace.
|
||||
#define MLX_FORALL_INT_TYPES(_) \
|
||||
_(uint8_t, uint8) \
|
||||
_(uint16_t, uint16) \
|
||||
_(uint32_t, uint32) \
|
||||
_(uint64_t, uint64) \
|
||||
_(int8_t, int8) \
|
||||
_(int16_t, int16) \
|
||||
_(int32_t, int32) \
|
||||
_(int64_t, int64)
|
||||
|
||||
#define MLX_FORALL_FLOAT_TYPES(_) \
|
||||
_(float16_t, float16) \
|
||||
_(float, float32) \
|
||||
_(double, float64) \
|
||||
_(bfloat16_t, bfloat16)
|
||||
|
||||
// Calls the provided macro on every Dtype, providing the C type and the
|
||||
// Dtype name to each call.
|
||||
//
|
||||
// @param _ A macro that takes two parameters: the name of a C type, and the
|
||||
// name of the corresponding Dtype enumerator.
|
||||
#define MLX_FORALL_DTYPES(_) \
|
||||
MLX_FORALL_INT_TYPES(_) \
|
||||
MLX_FORALL_FLOAT_TYPES(_) \
|
||||
_(bool, bool_) \
|
||||
_(complex64_t, complex64)
|
||||
|
||||
// Maps Dtypes to C++ types.
|
||||
template <Dtype::Val N>
|
||||
struct DtypeToCppType;
|
||||
|
||||
#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \
|
||||
template <> \
|
||||
struct DtypeToCppType<Dtype::Val::DTYPE> { \
|
||||
using type = CPP_TYPE; \
|
||||
};
|
||||
|
||||
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType)
|
||||
|
||||
#undef SPECIALIZE_DtypeToCppType
|
||||
|
||||
// Maps C++ types to Dtypes.
|
||||
template <typename T>
|
||||
struct CppTypeToDtype;
|
||||
|
||||
#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \
|
||||
template <> \
|
||||
struct CppTypeToDtype<CPP_TYPE> \
|
||||
: std::integral_constant<Dtype::Val, Dtype::Val::DTYPE> {};
|
||||
|
||||
MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype)
|
||||
|
||||
#undef SPECIALIZE_CppTypeToDtype
|
||||
|
||||
// Helper macros for switch case macros (see below)
|
||||
//
|
||||
// These macros are not meant to be used directly. They provide an easy way to
|
||||
// generate a switch statement that can handle subsets of Dtypes supported.
|
||||
|
||||
#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
|
||||
case enum_type: { \
|
||||
using CTYPE_ALIAS = ::mlx::core::DtypeToCppType<enum_type>::type; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
throw std::invalid_argument(fmt::format( \
|
||||
"Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \
|
||||
}
|
||||
|
||||
#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__)
|
||||
|
||||
#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__)
|
||||
|
||||
#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)
|
||||
|
||||
#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__)
|
||||
|
||||
#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CASE( \
|
||||
::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__)
|
||||
|
||||
#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
|
||||
MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)
|
||||
|
||||
// Switch case macros
|
||||
//
|
||||
// These macros provide an easy way to generate switch statements that apply a
|
||||
// common lambda function to subsets of Dtypes supported by MLX.
|
||||
// The lambda function can type specialize to the ctype associated with the
|
||||
// Dtype being handled through an alias passed as the CTYPE_ALIAS argument.
|
||||
//
|
||||
// Arguments:
|
||||
// - ADDITIONAL: Additional Dtype case to add
|
||||
// - TYPE: The Dtype to handle through the switch statement
|
||||
// - NAME: A name for this operation which will be used in error messages
|
||||
// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype.
|
||||
// - ...: A statement to be applied to each Dtype case
|
||||
//
|
||||
// An example usage is:
|
||||
//
|
||||
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, {
|
||||
// output.data<CTYPE>[0] = input.data<CTYPE>[0];
|
||||
// });
|
||||
//
|
||||
// Note that these can be nested as well:
|
||||
//
|
||||
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, {
|
||||
// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, {
|
||||
// output.data<CTYPE_OUT>[0] = input.data<CTYPE_IN>[0];
|
||||
// });
|
||||
// });
|
||||
//
|
||||
// These macros are adapted from Dispatch.h in the ATen library. The primary
|
||||
// difference is that the CTYPE_ALIAS argument is exposed to users, which is
|
||||
// used to alias the ctype associated with the Dtype that is being handled.
|
||||
|
||||
#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \
|
||||
switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) }
|
||||
|
||||
#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CHECKED( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
|
||||
|
||||
#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CHECKED( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
|
||||
|
||||
#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CHECKED( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
|
||||
|
||||
#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
|
||||
MLX_INTERNAL_SWITCH_CHECKED( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
|
||||
|
||||
} // namespace mlx::core
|
110
mlx/utils.cpp
110
mlx/utils.cpp
@ -5,6 +5,7 @@
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/types/limits.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@ -224,37 +225,7 @@ void print_array(std::ostream& os, const array& a) {
|
||||
} // namespace
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
|
||||
switch (dtype) {
|
||||
case bool_:
|
||||
return os << "bool";
|
||||
case uint8:
|
||||
return os << "uint8";
|
||||
case uint16:
|
||||
return os << "uint16";
|
||||
case uint32:
|
||||
return os << "uint32";
|
||||
case uint64:
|
||||
return os << "uint64";
|
||||
case int8:
|
||||
return os << "int8";
|
||||
case int16:
|
||||
return os << "int16";
|
||||
case int32:
|
||||
return os << "int32";
|
||||
case int64:
|
||||
return os << "int64";
|
||||
case float16:
|
||||
return os << "float16";
|
||||
case float32:
|
||||
return os << "float32";
|
||||
case float64:
|
||||
return os << "float64";
|
||||
case bfloat16:
|
||||
return os << "bfloat16";
|
||||
case complex64:
|
||||
return os << "complex64";
|
||||
}
|
||||
return os;
|
||||
return os << dtype_to_string(dtype);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
|
||||
@ -277,50 +248,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, array a) {
|
||||
a.eval();
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
print_array<bool>(os, a);
|
||||
break;
|
||||
case uint8:
|
||||
print_array<uint8_t>(os, a);
|
||||
break;
|
||||
case uint16:
|
||||
print_array<uint16_t>(os, a);
|
||||
break;
|
||||
case uint32:
|
||||
print_array<uint32_t>(os, a);
|
||||
break;
|
||||
case uint64:
|
||||
print_array<uint64_t>(os, a);
|
||||
break;
|
||||
case int8:
|
||||
print_array<int8_t>(os, a);
|
||||
break;
|
||||
case int16:
|
||||
print_array<int16_t>(os, a);
|
||||
break;
|
||||
case int32:
|
||||
print_array<int32_t>(os, a);
|
||||
break;
|
||||
case int64:
|
||||
print_array<int64_t>(os, a);
|
||||
break;
|
||||
case float16:
|
||||
print_array<float16_t>(os, a);
|
||||
break;
|
||||
case bfloat16:
|
||||
print_array<bfloat16_t>(os, a);
|
||||
break;
|
||||
case float32:
|
||||
print_array<float>(os, a);
|
||||
break;
|
||||
case float64:
|
||||
print_array<double>(os, a);
|
||||
break;
|
||||
case complex64:
|
||||
print_array<complex64_t>(os, a);
|
||||
break;
|
||||
}
|
||||
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
|
||||
return os;
|
||||
}
|
||||
|
||||
@ -387,36 +315,8 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
|
||||
}
|
||||
|
||||
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
set_iinfo_limits<int8_t>(min, max);
|
||||
break;
|
||||
case uint8:
|
||||
set_iinfo_limits<uint8_t>(min, max);
|
||||
break;
|
||||
case int16:
|
||||
set_iinfo_limits<int16_t>(min, max);
|
||||
break;
|
||||
case uint16:
|
||||
set_iinfo_limits<uint16_t>(min, max);
|
||||
break;
|
||||
case int32:
|
||||
set_iinfo_limits<int32_t>(min, max);
|
||||
break;
|
||||
case uint32:
|
||||
set_iinfo_limits<uint32_t>(min, max);
|
||||
break;
|
||||
case int64:
|
||||
set_iinfo_limits<int64_t>(min, max);
|
||||
break;
|
||||
case uint64:
|
||||
set_iinfo_limits<uint64_t>(min, max);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream msg;
|
||||
msg << "[iinfo] dtype " << dtype << " is not integral.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
MLX_SWITCH_INT_TYPES_CHECKED(
|
||||
dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user