mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add grid_sample
example to metal_kernel
docs (#1352)
* Add `zero_outputs` and `atomic_outputs` options to `metal_kernel` * add grid sample to docs * zero_outputs -> init_value * add missing header for linux
This commit is contained in:
parent
3b4d5484c7
commit
b96e105244
@ -43,7 +43,8 @@ The full function signature will be generated using:
|
||||
* The keys and shapes/dtypes of ``inputs``
|
||||
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
|
||||
so we will add ``const device float16_t* inp`` to the signature.
|
||||
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience.
|
||||
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
|
||||
in ``source``.
|
||||
* The keys and values of ``output_shapes`` and ``output_dtypes``
|
||||
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
|
||||
so we add ``device float16_t* out``.
|
||||
@ -73,7 +74,7 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
You can print the generated code for a ``mx.fast.metal_kernel`` by passing ``verbose=True`` when you call it.
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
@ -121,3 +122,292 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
|
||||
a = a[::2]
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
Complex Example
|
||||
-----------------------------
|
||||
|
||||
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
|
||||
|
||||
We'll start with the following MLX implementation using standard ops:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
|
||||
return output
|
||||
|
||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
||||
to write a fast GPU kernel for both the forward and backward passes.
|
||||
|
||||
First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"x": x, "grid": grid},
|
||||
template={"T": x.dtype},
|
||||
output_shapes={"out": out_shape},
|
||||
output_dtypes={"out": x.dtype},
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs["out"]
|
||||
|
||||
For a reasonably sized input such as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
|
||||
On an M1 Max, we see a big performance improvement:
|
||||
|
||||
``55.7ms -> 6.7ms => 8x speed up``
|
||||
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
|
||||
* ``atomic_outputs=True``
|
||||
Designate all of the kernel outputs as ``atomic`` in the function signature.
|
||||
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
|
||||
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
|
||||
|
||||
We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs={"x": x, "grid": grid, "cotangent": cotangent},
|
||||
template={"T": x.dtype},
|
||||
output_shapes={"x_grad": x.shape, "grid_grad": grid.shape},
|
||||
output_dtypes={"x_grad": x.dtype, "grid_grad": x.dtype},
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs["x_grad"], outputs["grid_grad"]
|
||||
|
||||
There's an even larger speed up for the vjp:
|
||||
|
||||
``676.4ms -> 16.7ms => 40x speed up``
|
||||
|
@ -12,12 +12,17 @@ void CustomKernel::eval_gpu(
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
for (auto& out : outputs) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (init_value_) {
|
||||
array init = array(init_value_.value(), out.dtype());
|
||||
copy_gpu(init, out, CopyType::Scalar, s);
|
||||
copies.push_back(init);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> copies;
|
||||
|
||||
auto check_input = [&copies, &s, this](const array& x) -> const array {
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (!ensure_row_contiguous_ || no_copy) {
|
||||
|
16
mlx/fast.cpp
16
mlx/fast.cpp
@ -949,6 +949,7 @@ void write_signature(
|
||||
std::map<std::string, Dtype>& output_dtypes,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
std::vector<CustomKernelShapeInfo>& shape_infos,
|
||||
bool atomic_outputs,
|
||||
std::ostringstream& kernel_source) {
|
||||
// Auto-generate a function signature based on `template_args`
|
||||
// and the dtype/shape of the arrays passed as `inputs`.
|
||||
@ -1042,8 +1043,14 @@ void write_signature(
|
||||
}
|
||||
// Add outputs
|
||||
for (const auto& [name, dtype] : output_dtypes) {
|
||||
kernel_source << " device " << get_type_string(dtype) << "* " << name
|
||||
<< " [[buffer(" << index << ")]]";
|
||||
kernel_source << " device ";
|
||||
auto type_string = get_type_string(dtype);
|
||||
if (atomic_outputs) {
|
||||
kernel_source << "atomic<" << type_string << ">";
|
||||
} else {
|
||||
kernel_source << type_string;
|
||||
}
|
||||
kernel_source << "* " << name << " [[buffer(" << index << ")]]";
|
||||
if (index < inputs.size() + output_shapes.size() - 1 || attrs.size() > 0) {
|
||||
kernel_source << "," << std::endl;
|
||||
} else {
|
||||
@ -1094,6 +1101,7 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args,
|
||||
std::optional<float> init_value,
|
||||
bool verbose,
|
||||
StreamOrDevice s_) {
|
||||
validate_output_shapes(output_shapes, output_dtypes);
|
||||
@ -1129,6 +1137,7 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
output_dtypes,
|
||||
template_args,
|
||||
shape_infos,
|
||||
atomic_outputs_,
|
||||
kernel_source);
|
||||
|
||||
if (needs_template) {
|
||||
@ -1174,7 +1183,8 @@ std::map<std::string, array> MetalKernel::operator()(
|
||||
grid,
|
||||
threadgroup,
|
||||
shape_infos,
|
||||
ensure_row_contiguous_),
|
||||
ensure_row_contiguous_,
|
||||
init_value),
|
||||
in_arrs);
|
||||
|
||||
int i = 0;
|
||||
|
@ -71,10 +71,12 @@ class MetalKernel {
|
||||
MetalKernel(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
bool ensure_row_contiguous)
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs)
|
||||
: name_(name),
|
||||
source_(source),
|
||||
ensure_row_contiguous_(ensure_row_contiguous) {}
|
||||
ensure_row_contiguous_(ensure_row_contiguous),
|
||||
atomic_outputs_(atomic_outputs) {}
|
||||
|
||||
std::map<std::string, array> operator()(
|
||||
std::map<std::string, array>& inputs,
|
||||
@ -84,6 +86,7 @@ class MetalKernel {
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, TemplateArg>> template_args =
|
||||
std::nullopt,
|
||||
std::optional<float> init_value = std::nullopt,
|
||||
bool verbose = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
@ -91,5 +94,6 @@ class MetalKernel {
|
||||
std::string name_;
|
||||
std::string source_;
|
||||
bool ensure_row_contiguous_ = true;
|
||||
bool atomic_outputs_ = false;
|
||||
};
|
||||
} // namespace mlx::core::fast
|
||||
|
@ -1,5 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
@ -257,14 +259,16 @@ class CustomKernel : public Primitive {
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::vector<CustomKernelShapeInfo> shape_infos,
|
||||
bool ensure_row_contiguous)
|
||||
bool ensure_row_contiguous,
|
||||
std::optional<float> init_value)
|
||||
: Primitive(stream),
|
||||
source_(source),
|
||||
name_(name),
|
||||
grid_(grid),
|
||||
threadgroup_(threadgroup),
|
||||
shape_infos_(shape_infos),
|
||||
ensure_row_contiguous_(ensure_row_contiguous) {}
|
||||
ensure_row_contiguous_(ensure_row_contiguous),
|
||||
init_value_(init_value) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
@ -283,6 +287,7 @@ class CustomKernel : public Primitive {
|
||||
std::tuple<int, int, int> threadgroup_;
|
||||
std::vector<CustomKernelShapeInfo> shape_infos_;
|
||||
bool ensure_row_contiguous_;
|
||||
std::optional<float> init_value_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::fast
|
||||
|
@ -200,10 +200,11 @@ void init_fast(nb::module_& parent_module) {
|
||||
A jit-compiled custom Metal kernel defined from a source string.
|
||||
)pbdoc")
|
||||
.def(
|
||||
nb::init<const std::string&, const std::string&, bool>(),
|
||||
nb::init<const std::string&, const std::string&, bool, bool>(),
|
||||
"name"_a,
|
||||
"source"_a,
|
||||
"ensure_row_contiguous"_a = true,
|
||||
"atomic_outputs"_a = false,
|
||||
R"pbdoc(
|
||||
Initialize a metal_kernel.
|
||||
|
||||
@ -215,6 +216,8 @@ void init_fast(nb::module_& parent_module) {
|
||||
used when the kernel is called.
|
||||
ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous
|
||||
before the kernel runs. Default: ``True``.
|
||||
atomic_outputs (bool): Whether to use atomic outputs in the function signature
|
||||
e.g. ``device atomic<float>``. Default: ``False``.
|
||||
Returns:
|
||||
Callable ``metal_kernel``.
|
||||
|
||||
@ -256,6 +259,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
std::optional<std::map<std::string, nb::handle>> template_args_,
|
||||
std::optional<float> init_value,
|
||||
bool verbose,
|
||||
StreamOrDevice s) {
|
||||
std::map<std::string, array> inputs;
|
||||
@ -289,6 +293,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
grid,
|
||||
threadgroup,
|
||||
template_args,
|
||||
init_value,
|
||||
verbose,
|
||||
s);
|
||||
},
|
||||
@ -299,10 +304,11 @@ void init_fast(nb::module_& parent_module) {
|
||||
"grid"_a,
|
||||
"threadgroup"_a,
|
||||
"template"_a = nb::none(),
|
||||
"init_value"_a = nb::none(),
|
||||
"verbose"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
|
||||
"def __call__(self, *, inputs: Mapping[str, Union[scalar, array]], output_shapes: Mapping[str, Sequence[int]], output_dtypes: Mapping[str, Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[Mapping[str, Union[bool, int, Dtype]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"),
|
||||
R"pbdoc(
|
||||
Run the kernel.
|
||||
|
||||
@ -316,9 +322,11 @@ void init_fast(nb::module_& parent_module) {
|
||||
grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with.
|
||||
threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use.
|
||||
template (Mapping[str, Union[bool, int, Dtype]], optional): Template arguments.
|
||||
These will be added as template arguments to the kernel definition.
|
||||
These will be added as template arguments to the kernel definition. Default: ``None``.
|
||||
init_value (float, optional): Optional value to use to initialize all of the output arrays.
|
||||
By default, output arrays are uninitialized. Default: ``None``.
|
||||
verbose (bool, optional): Whether to print the full generated source code of the kernel
|
||||
when it is run.
|
||||
when it is run. Default: ``False``.
|
||||
stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
|
Loading…
Reference in New Issue
Block a user