mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
rebase
This commit is contained in:
422
docs/build/html/dev/custom_metal_kernels.html
vendored
422
docs/build/html/dev/custom_metal_kernels.html
vendored
@@ -8,7 +8,7 @@
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<title>Custom Metal Kernels — MLX 0.26.1 documentation</title>
|
||||
<title>Custom Metal Kernels — MLX 0.26.2 documentation</title>
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
<link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" />
|
||||
<script src="../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script>
|
||||
|
||||
<script src="../_static/documentation_options.js?v=3724ff34"></script>
|
||||
<script src="../_static/documentation_options.js?v=20507f52"></script>
|
||||
<script src="../_static/doctools.js?v=9a2dae69"></script>
|
||||
<script src="../_static/sphinx_highlight.js?v=dc90522c"></script>
|
||||
<script src="../_static/scripts/sphinx-book-theme.js?v=887ef09a"></script>
|
||||
@@ -137,8 +137,8 @@
|
||||
|
||||
|
||||
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.26.1 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.26.1 documentation - Home"/>`);</script>
|
||||
<img src="../_static/mlx_logo.png" class="logo__image only-light" alt="MLX 0.26.2 documentation - Home"/>
|
||||
<script>document.write(`<img src="../_static/mlx_logo_dark.png" class="logo__image only-dark" alt="MLX 0.26.2 documentation - Home"/>`);</script>
|
||||
|
||||
|
||||
</a></div>
|
||||
@@ -926,19 +926,20 @@ document.write(`
|
||||
<section id="simple-example">
|
||||
<h2>Simple Example<a class="headerlink" href="#simple-example" title="Link to this heading">#</a></h2>
|
||||
<p>Let’s write a custom kernel that computes <code class="docutils literal notranslate"><span class="pre">exp</span></code> elementwise:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
|
||||
<span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> T tmp = inp[elem];</span>
|
||||
<span class="s2"> out[elem] = metal::exp(tmp);</span>
|
||||
<span class="s2"> """</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> T tmp = inp[elem];</span>
|
||||
<span class="s2"> out[elem] = metal::exp(tmp);</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"myexp"</span><span class="p">,</span>
|
||||
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"inp"</span><span class="p">],</span>
|
||||
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"out"</span><span class="p">],</span>
|
||||
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"myexp"</span><span class="p">,</span>
|
||||
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"inp"</span><span class="p">],</span>
|
||||
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"out"</span><span class="p">],</span>
|
||||
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
|
||||
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
|
||||
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="p">],</span>
|
||||
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">"T"</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">)],</span>
|
||||
@@ -954,9 +955,13 @@ document.write(`
|
||||
<span class="k">assert</span> <span class="n">mx</span><span class="o">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">a</span><span class="p">))</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Every time you make a kernel, a new Metal library is created and possibly
|
||||
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||
<a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a> and then use it many times.</p>
|
||||
<div class="admonition note">
|
||||
<p class="admonition-title">Note</p>
|
||||
<p>We are only required to pass the body of the Metal kernel in <code class="docutils literal notranslate"><span class="pre">source</span></code>.</p>
|
||||
<p>Only pass the body of the Metal kernel in <code class="docutils literal notranslate"><span class="pre">source</span></code>. The function
|
||||
signature is generated automatically.</p>
|
||||
</div>
|
||||
<p>The full function signature will be generated using:</p>
|
||||
<ul class="simple">
|
||||
@@ -1004,37 +1009,43 @@ All the attributes defined in Table 5.8 of the <a class="reference external" hre
|
||||
<span class="k">template</span><span class="w"> </span><span class="p">[[</span><span class="n">host_name</span><span class="p">(</span><span class="s">"custom_kernel_myexp_float"</span><span class="p">)]]</span><span class="w"> </span><span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="k">decltype</span><span class="p">(</span><span class="n">custom_kernel_myexp_float</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">)</span><span class="w"> </span><span class="n">custom_kernel_myexp_float</span><span class="o"><</span><span class="kt">float</span><span class="o">></span><span class="p">;</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Note: <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> are parameters to the Metal <a class="reference external" href="https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads">dispatchThreads</a> function.
|
||||
This means we will launch <code class="docutils literal notranslate"><span class="pre">mx.prod(grid)</span></code> threads, subdivided into <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.</p>
|
||||
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel.__call__</span></code> will print the generated code for debugging purposes.</p>
|
||||
<p>Note: <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> are parameters to the Metal <a class="reference external" href="https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads">dispatchThreads</a>
|
||||
function. This means we will launch <code class="docutils literal notranslate"><span class="pre">mx.prod(grid)</span></code> threads, subdivided into
|
||||
<code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> size threadgroups. For optimal performance, each thread group
|
||||
dimension should be less than or equal to the corresponding grid dimension.</p>
|
||||
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="xref py py-func docutils literal notranslate"><span class="pre">ast.metal_kernel.__call__()</span></code> will print the
|
||||
generated code for debugging purposes.</p>
|
||||
</section>
|
||||
<section id="using-shape-strides">
|
||||
<h2>Using Shape/Strides<a class="headerlink" href="#using-shape-strides" title="Link to this heading">#</a></h2>
|
||||
<p><code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code> supports an argument <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code> which is <code class="docutils literal notranslate"><span class="pre">True</span></code> by default.
|
||||
This will copy the <code class="docutils literal notranslate"><span class="pre">mx.array</span></code> inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don’t have to worry about gaps or the ordering of the dims
|
||||
when indexing.</p>
|
||||
<p>If we want to avoid this copy, <code class="docutils literal notranslate"><span class="pre">metal_kernel</span></code> automatically passes <code class="docutils literal notranslate"><span class="pre">a_shape</span></code>, <code class="docutils literal notranslate"><span class="pre">a_strides</span></code> and <code class="docutils literal notranslate"><span class="pre">a_ndim</span></code> for each
|
||||
input array <code class="docutils literal notranslate"><span class="pre">a</span></code> if any are present in <code class="docutils literal notranslate"><span class="pre">source</span></code>.
|
||||
We can then use MLX’s built in indexing utils to fetch the right elements for each thread.</p>
|
||||
<p>Let’s convert <code class="docutils literal notranslate"><span class="pre">myexp</span></code> above to support arbitrarily strided arrays without relying on a copy from <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code>:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
|
||||
<span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included</span>
|
||||
<span class="s2"> uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);</span>
|
||||
<span class="s2"> T tmp = inp[loc];</span>
|
||||
<span class="s2"> // Output arrays are always row contiguous</span>
|
||||
<span class="s2"> out[elem] = metal::exp(tmp);</span>
|
||||
<span class="s2"> """</span>
|
||||
<p><a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a> supports an argument <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code> which
|
||||
is <code class="docutils literal notranslate"><span class="pre">True</span></code> by default. This will copy the array inputs if needed
|
||||
before the kernel is launched to ensure that the memory layout is row
|
||||
contiguous. Generally this makes writing the kernel easier, since we don’t
|
||||
have to worry about gaps or the ordering of the dims when indexing.</p>
|
||||
<p>If we want to avoid this copy, <a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a> automatically passes
|
||||
<code class="docutils literal notranslate"><span class="pre">a_shape</span></code>, <code class="docutils literal notranslate"><span class="pre">a_strides</span></code> and <code class="docutils literal notranslate"><span class="pre">a_ndim</span></code> for each input array <code class="docutils literal notranslate"><span class="pre">a</span></code> if any are
|
||||
present in <code class="docutils literal notranslate"><span class="pre">source</span></code>. We can then use MLX’s built in indexing utils to fetch
|
||||
the right elements for each thread.</p>
|
||||
<p>Let’s convert <code class="docutils literal notranslate"><span class="pre">myexp</span></code> above to support arbitrarily strided arrays without
|
||||
relying on a copy from <code class="docutils literal notranslate"><span class="pre">ensure_row_contiguous</span></code>:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included</span>
|
||||
<span class="s2"> uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);</span>
|
||||
<span class="s2"> T tmp = inp[loc];</span>
|
||||
<span class="s2"> // Output arrays are always row contiguous</span>
|
||||
<span class="s2"> out[elem] = metal::exp(tmp);</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"myexp_strided"</span><span class="p">,</span>
|
||||
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"inp"</span><span class="p">],</span>
|
||||
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"out"</span><span class="p">],</span>
|
||||
<span class="n">source</span><span class="o">=</span><span class="n">source</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"myexp_strided"</span><span class="p">,</span>
|
||||
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"inp"</span><span class="p">],</span>
|
||||
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"out"</span><span class="p">],</span>
|
||||
<span class="n">source</span><span class="o">=</span><span class="n">source</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">exp_elementwise</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">mx</span><span class="o">.</span><span class="n">array</span><span class="p">):</span>
|
||||
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
|
||||
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">a</span><span class="p">],</span>
|
||||
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">"T"</span><span class="p">,</span> <span class="n">mx</span><span class="o">.</span><span class="n">float32</span><span class="p">)],</span>
|
||||
@@ -1100,10 +1111,67 @@ We can then use MLX’s built in indexing utils to fetch the right elements for
|
||||
<span class="k">return</span> <span class="n">output</span>
|
||||
</pre></div>
|
||||
</div>
|
||||
<p>Now let’s use <code class="docutils literal notranslate"><span class="pre">mx.custom_function</span></code> together with <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code>
|
||||
<p>Now let’s use <a class="reference internal" href="../python/_autosummary/mlx.core.custom_function.html#mlx.core.custom_function" title="mlx.core.custom_function"><code class="xref py py-func docutils literal notranslate"><span class="pre">custom_function()</span></code></a> together with <a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a>
|
||||
to write a fast GPU kernel for both the forward and backward passes.</p>
|
||||
<p>First we’ll implement the forward pass as a fused kernel:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@mx</span><span class="o">.</span><span class="n">custom_function</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> int H = x_shape[1];</span>
|
||||
<span class="s2"> int W = x_shape[2];</span>
|
||||
<span class="s2"> int C = x_shape[3];</span>
|
||||
<span class="s2"> int gH = grid_shape[1];</span>
|
||||
<span class="s2"> int gW = grid_shape[2];</span>
|
||||
|
||||
<span class="s2"> int w_stride = C;</span>
|
||||
<span class="s2"> int h_stride = W * w_stride;</span>
|
||||
<span class="s2"> int b_stride = H * h_stride;</span>
|
||||
|
||||
<span class="s2"> uint grid_idx = elem / C * 2;</span>
|
||||
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
|
||||
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
|
||||
|
||||
<span class="s2"> int ix_nw = floor(ix);</span>
|
||||
<span class="s2"> int iy_nw = floor(iy);</span>
|
||||
|
||||
<span class="s2"> int ix_ne = ix_nw + 1;</span>
|
||||
<span class="s2"> int iy_ne = iy_nw;</span>
|
||||
|
||||
<span class="s2"> int ix_sw = ix_nw;</span>
|
||||
<span class="s2"> int iy_sw = iy_nw + 1;</span>
|
||||
|
||||
<span class="s2"> int ix_se = ix_nw + 1;</span>
|
||||
<span class="s2"> int iy_se = iy_nw + 1;</span>
|
||||
|
||||
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
|
||||
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
|
||||
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
|
||||
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
|
||||
|
||||
<span class="s2"> int batch_idx = elem / C / gH / gW * b_stride;</span>
|
||||
<span class="s2"> int channel_idx = elem % C;</span>
|
||||
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
|
||||
|
||||
<span class="s2"> T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];</span>
|
||||
<span class="s2"> T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];</span>
|
||||
<span class="s2"> T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];</span>
|
||||
<span class="s2"> T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];</span>
|
||||
|
||||
<span class="s2"> I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;</span>
|
||||
<span class="s2"> I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;</span>
|
||||
<span class="s2"> I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;</span>
|
||||
<span class="s2"> I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;</span>
|
||||
|
||||
<span class="s2"> out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;</span>
|
||||
<span class="s2">"""</span>
|
||||
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"grid_sample"</span><span class="p">,</span>
|
||||
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"x"</span><span class="p">,</span> <span class="s2">"grid"</span><span class="p">],</span>
|
||||
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"out"</span><span class="p">],</span>
|
||||
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="nd">@mx</span><span class="o">.</span><span class="n">custom_function</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">grid_sample</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">):</span>
|
||||
|
||||
<span class="k">assert</span> <span class="n">x</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">"`x` must be 4D."</span>
|
||||
@@ -1115,61 +1183,6 @@ to write a fast GPU kernel for both the forward and backward passes.</p>
|
||||
|
||||
<span class="k">assert</span> <span class="n">D</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">"Last dim of `grid` must be size 2."</span>
|
||||
|
||||
<span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> int H = x_shape[1];</span>
|
||||
<span class="s2"> int W = x_shape[2];</span>
|
||||
<span class="s2"> int C = x_shape[3];</span>
|
||||
<span class="s2"> int gH = grid_shape[1];</span>
|
||||
<span class="s2"> int gW = grid_shape[2];</span>
|
||||
|
||||
<span class="s2"> int w_stride = C;</span>
|
||||
<span class="s2"> int h_stride = W * w_stride;</span>
|
||||
<span class="s2"> int b_stride = H * h_stride;</span>
|
||||
|
||||
<span class="s2"> uint grid_idx = elem / C * 2;</span>
|
||||
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
|
||||
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
|
||||
|
||||
<span class="s2"> int ix_nw = floor(ix);</span>
|
||||
<span class="s2"> int iy_nw = floor(iy);</span>
|
||||
|
||||
<span class="s2"> int ix_ne = ix_nw + 1;</span>
|
||||
<span class="s2"> int iy_ne = iy_nw;</span>
|
||||
|
||||
<span class="s2"> int ix_sw = ix_nw;</span>
|
||||
<span class="s2"> int iy_sw = iy_nw + 1;</span>
|
||||
|
||||
<span class="s2"> int ix_se = ix_nw + 1;</span>
|
||||
<span class="s2"> int iy_se = iy_nw + 1;</span>
|
||||
|
||||
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
|
||||
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
|
||||
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
|
||||
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
|
||||
|
||||
<span class="s2"> int batch_idx = elem / C / gH / gW * b_stride;</span>
|
||||
<span class="s2"> int channel_idx = elem % C;</span>
|
||||
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
|
||||
|
||||
<span class="s2"> T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];</span>
|
||||
<span class="s2"> T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];</span>
|
||||
<span class="s2"> T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];</span>
|
||||
<span class="s2"> T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];</span>
|
||||
|
||||
<span class="s2"> I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;</span>
|
||||
<span class="s2"> I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;</span>
|
||||
<span class="s2"> I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;</span>
|
||||
<span class="s2"> I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;</span>
|
||||
|
||||
<span class="s2"> out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;</span>
|
||||
<span class="s2"> """</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"grid_sample"</span><span class="p">,</span>
|
||||
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"x"</span><span class="p">,</span> <span class="s2">"grid"</span><span class="p">],</span>
|
||||
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"out"</span><span class="p">],</span>
|
||||
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="n">outputs</span> <span class="o">=</span> <span class="n">kernel</span><span class="p">(</span>
|
||||
<span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">,</span> <span class="n">grid</span><span class="p">],</span>
|
||||
<span class="n">template</span><span class="o">=</span><span class="p">[(</span><span class="s2">"T"</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)],</span>
|
||||
@@ -1191,10 +1204,10 @@ to write a fast GPU kernel for both the forward and backward passes.</p>
|
||||
</section>
|
||||
<section id="grid-sample-vjp">
|
||||
<h2>Grid Sample VJP<a class="headerlink" href="#grid-sample-vjp" title="Link to this heading">#</a></h2>
|
||||
<p>Since we decorated <code class="docutils literal notranslate"><span class="pre">grid_sample</span></code> with <code class="docutils literal notranslate"><span class="pre">mx.custom_function</span></code>, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.</p>
|
||||
<p>Since we decorated <code class="docutils literal notranslate"><span class="pre">grid_sample</span></code> with <a class="reference internal" href="../python/_autosummary/mlx.core.custom_function.html#mlx.core.custom_function" title="mlx.core.custom_function"><code class="xref py py-func docutils literal notranslate"><span class="pre">custom_function()</span></code></a>, we can now
|
||||
define its custom vjp transform so MLX can differentiate it.</p>
|
||||
<p>The backwards pass requires atomically updating <code class="docutils literal notranslate"><span class="pre">x_grad</span></code>/<code class="docutils literal notranslate"><span class="pre">grid_grad</span></code> and so
|
||||
requires a few extra <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel</span></code> features:</p>
|
||||
requires a few extra <a class="reference internal" href="../python/_autosummary/mlx.core.fast.metal_kernel.html#mlx.core.fast.metal_kernel" title="mlx.core.fast.metal_kernel"><code class="xref py py-func docutils literal notranslate"><span class="pre">fast.metal_kernel()</span></code></a> features:</p>
|
||||
<ul class="simple">
|
||||
<li><dl class="simple">
|
||||
<dt><code class="docutils literal notranslate"><span class="pre">init_value=0</span></code></dt><dd><p>Initialize all of the kernel’s outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.</p>
|
||||
@@ -1210,7 +1223,107 @@ See section 6.15 of the <a class="reference external" href="https://developer.ap
|
||||
</li>
|
||||
</ul>
|
||||
<p>We can then implement the backwards pass as follows:</p>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@grid_sample</span><span class="o">.</span><span class="n">vjp</span>
|
||||
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> int H = x_shape[1];</span>
|
||||
<span class="s2"> int W = x_shape[2];</span>
|
||||
<span class="s2"> int C = x_shape[3];</span>
|
||||
<span class="s2"> // Pad C to the nearest larger simdgroup size multiple</span>
|
||||
<span class="s2"> int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;</span>
|
||||
|
||||
<span class="s2"> int gH = grid_shape[1];</span>
|
||||
<span class="s2"> int gW = grid_shape[2];</span>
|
||||
|
||||
<span class="s2"> int w_stride = C;</span>
|
||||
<span class="s2"> int h_stride = W * w_stride;</span>
|
||||
<span class="s2"> int b_stride = H * h_stride;</span>
|
||||
|
||||
<span class="s2"> uint grid_idx = elem / C_padded * 2;</span>
|
||||
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
|
||||
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
|
||||
|
||||
<span class="s2"> int ix_nw = floor(ix);</span>
|
||||
<span class="s2"> int iy_nw = floor(iy);</span>
|
||||
|
||||
<span class="s2"> int ix_ne = ix_nw + 1;</span>
|
||||
<span class="s2"> int iy_ne = iy_nw;</span>
|
||||
|
||||
<span class="s2"> int ix_sw = ix_nw;</span>
|
||||
<span class="s2"> int iy_sw = iy_nw + 1;</span>
|
||||
|
||||
<span class="s2"> int ix_se = ix_nw + 1;</span>
|
||||
<span class="s2"> int iy_se = iy_nw + 1;</span>
|
||||
|
||||
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
|
||||
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
|
||||
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
|
||||
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
|
||||
|
||||
<span class="s2"> int batch_idx = elem / C_padded / gH / gW * b_stride;</span>
|
||||
<span class="s2"> int channel_idx = elem % C_padded;</span>
|
||||
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
|
||||
|
||||
<span class="s2"> T gix = T(0);</span>
|
||||
<span class="s2"> T giy = T(0);</span>
|
||||
<span class="s2"> if (channel_idx < C) {</span>
|
||||
<span class="s2"> int cot_index = elem / C_padded * C + channel_idx;</span>
|
||||
<span class="s2"> T cot = cotangent[cot_index];</span>
|
||||
<span class="s2"> if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {</span>
|
||||
<span class="s2"> int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);</span>
|
||||
|
||||
<span class="s2"> T I_nw = x[offset];</span>
|
||||
<span class="s2"> gix -= I_nw * (iy_se - iy) * cot;</span>
|
||||
<span class="s2"> giy -= I_nw * (ix_se - ix) * cot;</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {</span>
|
||||
<span class="s2"> int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);</span>
|
||||
|
||||
<span class="s2"> T I_ne = x[offset];</span>
|
||||
<span class="s2"> gix += I_ne * (iy_sw - iy) * cot;</span>
|
||||
<span class="s2"> giy -= I_ne * (ix - ix_sw) * cot;</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {</span>
|
||||
<span class="s2"> int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);</span>
|
||||
|
||||
<span class="s2"> T I_sw = x[offset];</span>
|
||||
<span class="s2"> gix -= I_sw * (iy - iy_ne) * cot;</span>
|
||||
<span class="s2"> giy += I_sw * (ix_ne - ix) * cot;</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {</span>
|
||||
<span class="s2"> int offset = base_idx + iy_se * h_stride + ix_se * w_stride;</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);</span>
|
||||
|
||||
<span class="s2"> T I_se = x[offset];</span>
|
||||
<span class="s2"> gix += I_se * (iy - iy_nw) * cot;</span>
|
||||
<span class="s2"> giy += I_se * (ix - ix_nw) * cot;</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> }</span>
|
||||
|
||||
<span class="s2"> T gix_mult = W / 2;</span>
|
||||
<span class="s2"> T giy_mult = H / 2;</span>
|
||||
|
||||
<span class="s2"> // Reduce across each simdgroup first.</span>
|
||||
<span class="s2"> // This is much faster than relying purely on atomics.</span>
|
||||
<span class="s2"> gix = simd_sum(gix);</span>
|
||||
<span class="s2"> giy = simd_sum(giy);</span>
|
||||
|
||||
<span class="s2"> if (thread_index_in_simdgroup == 0) {</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2">"""</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"grid_sample_grad"</span><span class="p">,</span>
|
||||
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"x"</span><span class="p">,</span> <span class="s2">"grid"</span><span class="p">,</span> <span class="s2">"cotangent"</span><span class="p">],</span>
|
||||
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"x_grad"</span><span class="p">,</span> <span class="s2">"grid_grad"</span><span class="p">],</span>
|
||||
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
|
||||
<span class="n">atomic_outputs</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
|
||||
<span class="nd">@grid_sample</span><span class="o">.</span><span class="n">vjp</span>
|
||||
<span class="k">def</span><span class="w"> </span><span class="nf">grid_sample_vjp</span><span class="p">(</span><span class="n">primals</span><span class="p">,</span> <span class="n">cotangent</span><span class="p">,</span> <span class="n">_</span><span class="p">):</span>
|
||||
<span class="n">x</span><span class="p">,</span> <span class="n">grid</span> <span class="o">=</span> <span class="n">primals</span>
|
||||
<span class="n">B</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
|
||||
@@ -1218,105 +1331,6 @@ See section 6.15 of the <a class="reference external" href="https://developer.ap
|
||||
|
||||
<span class="k">assert</span> <span class="n">D</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">"Last dim of `grid` must be size 2."</span>
|
||||
|
||||
<span class="n">source</span> <span class="o">=</span> <span class="s2">"""</span>
|
||||
<span class="s2"> uint elem = thread_position_in_grid.x;</span>
|
||||
<span class="s2"> int H = x_shape[1];</span>
|
||||
<span class="s2"> int W = x_shape[2];</span>
|
||||
<span class="s2"> int C = x_shape[3];</span>
|
||||
<span class="s2"> // Pad C to the nearest larger simdgroup size multiple</span>
|
||||
<span class="s2"> int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;</span>
|
||||
|
||||
<span class="s2"> int gH = grid_shape[1];</span>
|
||||
<span class="s2"> int gW = grid_shape[2];</span>
|
||||
|
||||
<span class="s2"> int w_stride = C;</span>
|
||||
<span class="s2"> int h_stride = W * w_stride;</span>
|
||||
<span class="s2"> int b_stride = H * h_stride;</span>
|
||||
|
||||
<span class="s2"> uint grid_idx = elem / C_padded * 2;</span>
|
||||
<span class="s2"> float ix = ((grid[grid_idx] + 1) * W - 1) / 2;</span>
|
||||
<span class="s2"> float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;</span>
|
||||
|
||||
<span class="s2"> int ix_nw = floor(ix);</span>
|
||||
<span class="s2"> int iy_nw = floor(iy);</span>
|
||||
|
||||
<span class="s2"> int ix_ne = ix_nw + 1;</span>
|
||||
<span class="s2"> int iy_ne = iy_nw;</span>
|
||||
|
||||
<span class="s2"> int ix_sw = ix_nw;</span>
|
||||
<span class="s2"> int iy_sw = iy_nw + 1;</span>
|
||||
|
||||
<span class="s2"> int ix_se = ix_nw + 1;</span>
|
||||
<span class="s2"> int iy_se = iy_nw + 1;</span>
|
||||
|
||||
<span class="s2"> T nw = (ix_se - ix) * (iy_se - iy);</span>
|
||||
<span class="s2"> T ne = (ix - ix_sw) * (iy_sw - iy);</span>
|
||||
<span class="s2"> T sw = (ix_ne - ix) * (iy - iy_ne);</span>
|
||||
<span class="s2"> T se = (ix - ix_nw) * (iy - iy_nw);</span>
|
||||
|
||||
<span class="s2"> int batch_idx = elem / C_padded / gH / gW * b_stride;</span>
|
||||
<span class="s2"> int channel_idx = elem % C_padded;</span>
|
||||
<span class="s2"> int base_idx = batch_idx + channel_idx;</span>
|
||||
|
||||
<span class="s2"> T gix = T(0);</span>
|
||||
<span class="s2"> T giy = T(0);</span>
|
||||
<span class="s2"> if (channel_idx < C) {</span>
|
||||
<span class="s2"> int cot_index = elem / C_padded * C + channel_idx;</span>
|
||||
<span class="s2"> T cot = cotangent[cot_index];</span>
|
||||
<span class="s2"> if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {</span>
|
||||
<span class="s2"> int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);</span>
|
||||
|
||||
<span class="s2"> T I_nw = x[offset];</span>
|
||||
<span class="s2"> gix -= I_nw * (iy_se - iy) * cot;</span>
|
||||
<span class="s2"> giy -= I_nw * (ix_se - ix) * cot;</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {</span>
|
||||
<span class="s2"> int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);</span>
|
||||
|
||||
<span class="s2"> T I_ne = x[offset];</span>
|
||||
<span class="s2"> gix += I_ne * (iy_sw - iy) * cot;</span>
|
||||
<span class="s2"> giy -= I_ne * (ix - ix_sw) * cot;</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {</span>
|
||||
<span class="s2"> int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);</span>
|
||||
|
||||
<span class="s2"> T I_sw = x[offset];</span>
|
||||
<span class="s2"> gix -= I_sw * (iy - iy_ne) * cot;</span>
|
||||
<span class="s2"> giy += I_sw * (ix_ne - ix) * cot;</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {</span>
|
||||
<span class="s2"> int offset = base_idx + iy_se * h_stride + ix_se * w_stride;</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);</span>
|
||||
|
||||
<span class="s2"> T I_se = x[offset];</span>
|
||||
<span class="s2"> gix += I_se * (iy - iy_nw) * cot;</span>
|
||||
<span class="s2"> giy += I_se * (ix - ix_nw) * cot;</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> }</span>
|
||||
|
||||
<span class="s2"> T gix_mult = W / 2;</span>
|
||||
<span class="s2"> T giy_mult = H / 2;</span>
|
||||
|
||||
<span class="s2"> // Reduce across each simdgroup first.</span>
|
||||
<span class="s2"> // This is much faster than relying purely on atomics.</span>
|
||||
<span class="s2"> gix = simd_sum(gix);</span>
|
||||
<span class="s2"> giy = simd_sum(giy);</span>
|
||||
|
||||
<span class="s2"> if (thread_index_in_simdgroup == 0) {</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);</span>
|
||||
<span class="s2"> atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);</span>
|
||||
<span class="s2"> }</span>
|
||||
<span class="s2"> """</span>
|
||||
<span class="n">kernel</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">fast</span><span class="o">.</span><span class="n">metal_kernel</span><span class="p">(</span>
|
||||
<span class="n">name</span><span class="o">=</span><span class="s2">"grid_sample_grad"</span><span class="p">,</span>
|
||||
<span class="n">input_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"x"</span><span class="p">,</span> <span class="s2">"grid"</span><span class="p">,</span> <span class="s2">"cotangent"</span><span class="p">],</span>
|
||||
<span class="n">output_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"x_grad"</span><span class="p">,</span> <span class="s2">"grid_grad"</span><span class="p">],</span>
|
||||
<span class="n">source</span><span class="o">=</span><span class="n">source</span><span class="p">,</span>
|
||||
<span class="n">atomic_outputs</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
||||
<span class="p">)</span>
|
||||
<span class="c1"># pad the output channels to simd group size</span>
|
||||
<span class="c1"># so that our `simd_sum`s don't overlap.</span>
|
||||
<span class="n">simdgroup_size</span> <span class="o">=</span> <span class="mi">32</span>
|
||||
|
Reference in New Issue
Block a user