Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
Add optional headers to `mx.fast.metal_kernel` (#1358)
This commit is contained in:
parent 5f7d19d1f5
commit 1d94ac3f90
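In practice, the new `header` argument lets helper functions and includes be emitted ahead of the generated kernel signature. A minimal Python sketch, mirroring the `test_custom_kernel_helper` test added in this diff (requires a Metal-capable GPU):

import mlx.core as mx

# The helper lives in `header`, outside the generated kernel body,
# so it can be called from `source`.
kernel = mx.fast.metal_kernel(
    name="helper",
    header="""
        template <typename T>
        T do_exp(T x) {
            return metal::precise::exp(x);
        }
    """,
    source="""
        uint elem = thread_position_in_grid.x;
        out1[elem] = do_exp(a[elem]);
    """,
)

a = mx.random.normal(shape=(2, 2))
out = kernel(
    inputs={"a": a},
    grid=(4, 1, 1),
    threadgroup=(2, 1, 1),
    output_shapes={"out1": (2, 2)},
    output_dtypes={"out1": mx.float32},
    stream=mx.gpu,
)
# out["out1"] matches mx.exp(a).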
mlx/fast.cpp

@@ -1112,7 +1112,6 @@ std::map<std::string, array> MetalKernel::operator()(
         "[metal_kernel] MetalKernel only works on GPU.");
   }
 
-  std::ostringstream kernel_source;
   std::ostringstream func_name;
 
   std::string template_def = "";
@@ -1128,6 +1127,9 @@ std::map<std::string, array> MetalKernel::operator()(
   func_name << "custom_kernel_" << name_ << hash_key;
   std::string kernel_name = func_name.str();
 
+  std::ostringstream kernel_source;
+  kernel_source << header_ << std::endl;
+
   std::vector<CustomKernelShapeInfo> shape_infos;
   write_signature(
       func_name.str(),
mlx/fast.h (11 changed lines)

@@ -71,10 +71,12 @@ class MetalKernel {
   MetalKernel(
       const std::string& name,
       const std::string& source,
-      bool ensure_row_contiguous,
-      bool atomic_outputs)
+      const std::string& header = "",
+      bool ensure_row_contiguous = true,
+      bool atomic_outputs = false)
       : name_(name),
         source_(source),
+        header_(header),
         ensure_row_contiguous_(ensure_row_contiguous),
         atomic_outputs_(atomic_outputs) {}
 
@@ -93,7 +95,8 @@ class MetalKernel {
  private:
   std::string name_;
   std::string source_;
-  bool ensure_row_contiguous_ = true;
-  bool atomic_outputs_ = false;
+  std::string header_;
+  bool ensure_row_contiguous_;
+  bool atomic_outputs_;
 };
 } // namespace mlx::core::fast
python/src/fast.cpp

@@ -200,9 +200,15 @@ void init_fast(nb::module_& parent_module) {
       A jit-compiled custom Metal kernel defined from a source string.
     )pbdoc")
       .def(
-          nb::init<const std::string&, const std::string&, bool, bool>(),
+          nb::init<
+              const std::string&,
+              const std::string&,
+              const std::string&,
+              bool,
+              bool>(),
           "name"_a,
           "source"_a,
+          "header"_a = "",
           "ensure_row_contiguous"_a = true,
           "atomic_outputs"_a = false,
           R"pbdoc(
@@ -214,6 +220,8 @@ void init_fast(nb::module_& parent_module) {
         the function signature will be generated for you. The names of the inputs/outputs
         are determined by the ``inputs`` and ``output_shapes``/``output_dtypes``
         used when the kernel is called.
+        header (str): Header source code to include before the main function.
+          Useful for helper functions or includes that should live outside of the main function body.
         ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
           before the kernel runs. Default: ``True``.
         atomic_outputs (bool): Whether to use atomic outputs in the function signature
python/tests/test_fast.py

@@ -551,7 +551,7 @@ class TestFast(mlx_tests.MLXTestCase):
     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_basic(self):
         mx.random.seed(7)
-        a = mx.random.normal(shape=(3, 6))
+        a = mx.random.normal(shape=(2, 2))
         kernel = mx.fast.metal_kernel(
             name="basic",
             source="""
@@ -567,7 +567,7 @@ class TestFast(mlx_tests.MLXTestCase):
             output_dtypes={"out1": mx.float32},
             stream=mx.gpu,
         )
-        mx.allclose(out["out1"], a[:2, :2])
+        self.assertTrue(mx.allclose(out["out1"], a))
 
     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
     def test_custom_kernel_args(self):
@@ -618,12 +618,12 @@ class TestFast(mlx_tests.MLXTestCase):
                 uint elem = thread_position_in_grid.x;
                 uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
                 T tmp = inp[loc];
-                out[elem] = metal::exp(tmp) * threads_per_simdgroup;
+                out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
             """
         source_contig = """
                 uint elem = thread_position_in_grid.x;
                 T tmp = inp[elem];
-                out[elem] = metal::exp(tmp) * threads_per_simdgroup;
+                out[elem] = metal::precise::exp(tmp) * threads_per_simdgroup;
             """
 
         # non contiguous
@@ -646,6 +646,33 @@ class TestFast(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(mx.exp(a) * 32, outputs["out"]))
 
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    def test_custom_kernel_helper(self):
+        mx.random.seed(7)
+        a = mx.random.normal(shape=(2, 2))
+        kernel = mx.fast.metal_kernel(
+            name="helper",
+            header="""
+            template <typename T>
+            T do_exp(T x) {
+                return metal::precise::exp(x);
+            }
+            """,
+            source="""
+                uint elem = thread_position_in_grid.x;
+                out1[elem] = do_exp(a[elem]);
+            """,
+        )
+        out = kernel(
+            inputs={"a": a},
+            grid=(4, 1, 1),
+            threadgroup=(2, 1, 1),
+            output_shapes={"out1": (2, 2)},
+            output_dtypes={"out1": mx.float32},
+            stream=mx.gpu,
+        )
+        self.assertTrue(mx.allclose(out["out1"], mx.exp(a)))
+
 
 if __name__ == "__main__":
     unittest.main()