readwrite.h
// Copyright © 2024 Apple Inc.

#include <metal_common>

#include "mlx/backend/metal/kernels/fft/radix.h"

/* FFT helpers for reading and writing from/to device memory.

For many sizes, GPU FFTs are memory bandwidth bound so
read/write performance is important.

Where possible, we read 128 bits sequentially in each thread,
coalesced with accesses from adjacent threads for optimal performance.

We implement specialized reading/writing for:
  - FFT
  - RFFT
  - IRFFT

Each with support for:
  - Contiguous reads
  - Padded reads
  - Strided reads
*/
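// A float2 holds one complex64 value (2 x 32-bit floats = 64 bits), so
// the 128-bit sequential accesses mentioned above are pairs of adjacent
// float2 loads/stores.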

#define MAX_RADIX 13

using namespace metal;

template <
    typename in_T,
    typename out_T,
    int step = 0,
    bool four_step_real = false>
struct ReadWriter {
  const device in_T* in;
  threadgroup float2* buf;
  device out_T* out;
  int n;
  int batch_size;
  int elems_per_thread;
  uint3 elem;
  uint3 grid;
  int threads_per_tg;
  bool inv;

  // Used for strided access
  int strided_device_idx;
  int strided_shared_idx;

  METAL_FUNC ReadWriter(
      const device in_T* in_,
      threadgroup float2* buf_,
      device out_T* out_,
      const short n_,
      const int batch_size_,
      const short elems_per_thread_,
      const uint3 elem_,
      const uint3 grid_,
      const bool inv_)
      : in(in_),
        buf(buf_),
        out(out_),
        n(n_),
        batch_size(batch_size_),
        elems_per_thread(elems_per_thread_),
        elem(elem_),
        grid(grid_),
        inv(inv_) {
    // Account for padding on last threadgroup
    threads_per_tg = elem.x == grid.x - 1
        ? (batch_size - (grid.x - 1) * grid.y) * grid.z
        : grid.y * grid.z;
  }

  // ifft(x) = 1/n * conj(fft(conj(x)))
  METAL_FUNC float2 post_in(float2 elem) const {
    return inv ? float2(elem.x, -elem.y) : elem;
  }

  // Handle float case for generic RFFT alg
  METAL_FUNC float2 post_in(float elem) const {
    return float2(elem, 0);
  }

  METAL_FUNC float2 pre_out(float2 elem) const {
    return inv ? float2(elem.x / n, -elem.y / n) : elem;
  }

  METAL_FUNC float2 pre_out(float2 elem, int length) const {
    return inv ? float2(elem.x / length, -elem.y / length) : elem;
  }
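
  // Why the conjugation trick works: for a length-n sequence,
  //   ifft(x)_k = 1/n * sum_m x_m * exp(+2*pi*j*m*k/n)
  //             = 1/n * conj(sum_m conj(x_m) * exp(-2*pi*j*m*k/n))
  //             = 1/n * conj(fft(conj(x))_k),
  // so post_in applies the inner conjugate on the way in, and pre_out
  // applies the outer conjugate together with the 1/n scale.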

  METAL_FUNC bool out_of_bounds() const {
    // Account for possible extra threadgroups
    int grid_index = elem.x * grid.y + elem.y;
    return grid_index >= batch_size;
  }

  METAL_FUNC void load() const {
    int batch_idx = elem.x * grid.y * n;
    short tg_idx = elem.y * grid.z + elem.z;
    short max_index = grid.y * n - 2;

    // 2 complex64s = 128 bits
    constexpr int read_width = 2;
    for (short e = 0; e < (elems_per_thread / read_width); e++) {
      short index = read_width * tg_idx + read_width * threads_per_tg * e;
      index = metal::min(index, max_index);
      // vectorized reads
      buf[index] = post_in(in[batch_idx + index]);
      buf[index + 1] = post_in(in[batch_idx + index + 1]);
    }
    max_index += 1;
    if (elems_per_thread % 2 != 0) {
      short index = tg_idx +
          read_width * threads_per_tg * (elems_per_thread / read_width);
      index = metal::min(index, max_index);
      buf[index] = post_in(in[batch_idx + index]);
    }
  }
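
  // Worked example: with elems_per_thread = 4, each thread makes two
  // trips through the loop above, reading the complex pair at index
  // 2 * tg_idx and the pair at 2 * tg_idx + 2 * threads_per_tg, so
  // adjacent threads touch adjacent 128-bit chunks and stay coalesced.
  // The metal::min clamp makes threads past the end of the batch
  // re-read a valid element instead of running out of bounds.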

  METAL_FUNC void write() const {
    int batch_idx = elem.x * grid.y * n;
    short tg_idx = elem.y * grid.z + elem.z;
    short max_index = grid.y * n - 2;

    constexpr int read_width = 2;
    for (short e = 0; e < (elems_per_thread / read_width); e++) {
      short index = read_width * tg_idx + read_width * threads_per_tg * e;
      index = metal::min(index, max_index);
      // vectorized writes
      out[batch_idx + index] = pre_out(buf[index]);
      out[batch_idx + index + 1] = pre_out(buf[index + 1]);
    }
    max_index += 1;
    if (elems_per_thread % 2 != 0) {
      short index = tg_idx +
          read_width * threads_per_tg * (elems_per_thread / read_width);
      index = metal::min(index, max_index);
      out[batch_idx + index] = pre_out(buf[index]);
    }
  }

  // Padded IO for Bluestein's algorithm
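  //
  // Bluestein turns an arbitrary-length DFT into a convolution: with
  // the chirp w_k = exp(-j * pi * k^2 / length), the transform factors
  // as
  //   X_k = w_k * sum_m (x_m * w_m) * conj(w)_(k - m),
  // so the input is pre-multiplied by w_k here, the convolution runs
  // as a zero-padded FFT of size n, and write_padded applies the final
  // w_k factor along with the usual inverse scaling.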
  METAL_FUNC void load_padded(int length, const device float2* w_k) const {
    int batch_idx = elem.x * grid.y * length + elem.y * length;
    int fft_idx = elem.z;
    int m = grid.z;

    threadgroup float2* seq_buf = buf + elem.y * n;
    for (int e = 0; e < elems_per_thread; e++) {
      int index = metal::min(fft_idx + e * m, n - 1);
      if (index < length) {
        float2 elem = post_in(in[batch_idx + index]);
        seq_buf[index] = complex_mul(elem, w_k[index]);
      } else {
        seq_buf[index] = 0.0;
      }
    }
  }

  METAL_FUNC void write_padded(int length, const device float2* w_k) const {
    int batch_idx = elem.x * grid.y * length + elem.y * length;
    int fft_idx = elem.z;
    int m = grid.z;
    float2 inv_factor = {1.0f / n, -1.0f / n};

    threadgroup float2* seq_buf = buf + elem.y * n;
    for (int e = 0; e < elems_per_thread; e++) {
      int index = metal::min(fft_idx + e * m, n - 1);
      if (index < length) {
        float2 elem = seq_buf[index + length - 1] * inv_factor;
        out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length);
      }
    }
  }

  // Strided IO for four step FFT
  METAL_FUNC void compute_strided_indices(int stride, int overall_n) {
    // Use the batch threadgroup dimension to coalesce memory accesses:
    // e.g. stride = 12
    // device       | shared mem
    // 0  1  2  3   | 0 12 -  -
    // -  -  -  -   | 1 13 -  -
    // -  -  -  -   | 2 14 -  -
    // 12 13 14 15  | 3 15 -  -
    int coalesce_width = grid.y;
    int tg_idx = elem.y * grid.z + elem.z;
    int outer_batch_size = stride / coalesce_width;

    int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
        overall_n * (elem.x / outer_batch_size);
    strided_device_idx = strided_batch_idx +
        tg_idx / coalesce_width * elems_per_thread * stride +
        tg_idx % coalesce_width;
    strided_shared_idx = (tg_idx % coalesce_width) * n +
        tg_idx / coalesce_width * elems_per_thread;
  }
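
  // Worked example: stride = 12, coalesce_width = grid.y = 4,
  // elems_per_thread = 3. Thread tg_idx = 5 is assigned shared row
  // 5 % 4 = 1 starting at column 5 / 4 * 3 = 3, and device index
  // strided_batch_idx + 1 * 3 * 12 + 1; threads 4..7 then hit
  // consecutive device addresses, so the stride-12 gather still
  // coalesces across the threadgroup.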

  // Four Step FFT First Step
  METAL_FUNC void load_strided(int stride, int overall_n) {
    compute_strided_indices(stride, overall_n);
    for (int e = 0; e < elems_per_thread; e++) {
      buf[strided_shared_idx + e] =
          post_in(in[strided_device_idx + e * stride]);
    }
  }

  METAL_FUNC void write_strided(int stride, int overall_n) {
    for (int e = 0; e < elems_per_thread; e++) {
      float2 output = buf[strided_shared_idx + e];
      int combined_idx = (strided_device_idx + e * stride) % overall_n;
      int ij = (combined_idx / stride) * (combined_idx % stride);
      // Apply four step twiddles at end of first step
      float2 twiddle = get_twiddle(ij, overall_n);
      out[strided_device_idx + e * stride] = complex_mul(output, twiddle);
    }
  }
};
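
// Usage sketch (illustrative only; the kernel shape, `shared` buffer,
// and the in-threadgroup FFT passes are assumptions, not part of this
// header). A contiguous FFT kernel would typically drive a ReadWriter
// as follows:
//
//   ReadWriter<float2, float2> read_writer(
//       in, &shared[0], out, n, batch_size, elems_per_thread,
//       elem, grid, inv);
//   if (read_writer.out_of_bounds()) {
//     return;
//   }
//   read_writer.load();
//   threadgroup_barrier(mem_flags::mem_threadgroup);
//   // ... radix butterfly passes over `shared` ...
//   threadgroup_barrier(mem_flags::mem_threadgroup);
//   read_writer.write();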

// Four Step FFT Second Step
template <>
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::load_strided(
    int stride,
    int overall_n) {
  // Silence compiler warnings
  (void)stride;
  (void)overall_n;
  // Don't invert between steps
  bool default_inv = inv;
  inv = false;
  load();
  inv = default_inv;
}

template <>
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::write_strided(
    int stride,
    int overall_n) {
  compute_strided_indices(stride, overall_n);
  for (int e = 0; e < elems_per_thread; e++) {
    float2 output = buf[strided_shared_idx + e];
    out[strided_device_idx + e * stride] = pre_out(output, overall_n);
  }
}
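
// Four step recap: an FFT of composite size overall_n = n1 * n2 is
// computed as a batch of size-n1 FFTs over strided data (step 0), a
// pointwise multiply of element (i1, i2) by the twiddle w^(i1 * i2)
// with w = exp(-2 * pi * j / overall_n) (fused into step 0's
// write_strided via get_twiddle), and finally a batch of size-n2 FFTs
// over contiguous data (step 1 above). Splitting large transforms this
// way keeps each sub-FFT small enough for a single threadgroup.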

// For RFFT, we interleave batches of two real sequences into one complex one:
//
// z_k = x_k + j.y_k
// X_k = (Z_k + Z_(N-k)*) / 2
// Y_k = -j * ((Z_k - Z_(N-k)*) / 2)
//
// This roughly doubles the throughput over the regular FFT.
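//
// This works because x and y are real, so their spectra are conjugate
// symmetric: X_(N-k)* = X_k and Y_(N-k)* = Y_k. Hence
//   Z_(N-k)* = X_(N-k)* - j.Y_(N-k)* = X_k - j.Y_k,
// and adding/subtracting Z_k = X_k + j.Y_k recovers X_k and Y_k.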
template <>
METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
  int grid_index = elem.x * grid.y + elem.y;
  // We pack two sequences into one for RFFTs
  return grid_index * 2 >= batch_size;
}

template <>
METAL_FUNC void ReadWriter<float, float2>::load() const {
  int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread; e++) {
    int index = metal::min(fft_idx + e * m, n - 1);
    seq_buf[index].x = in[batch_idx + index];
    seq_buf[index].y = in[batch_idx + index + next_in];
  }
}

template <>
METAL_FUNC void ReadWriter<float, float2>::write() const {
  short n_over_2 = (n / 2) + 1;

  int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;

  float2 conj = {1, -1};
  float2 minus_j = {0, -1};

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
    int index = metal::min(fft_idx + e * m, n_over_2 - 1);
    // x_0 = z_0.real
    // y_0 = z_0.imag
    if (index == 0) {
      out[batch_idx + index] = {seq_buf[index].x, 0};
      out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
    } else {
      float2 x_k = seq_buf[index];
      float2 x_n_minus_k = seq_buf[n - index] * conj;
      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
      out[batch_idx + index + next_out] =
          complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
    }
  }
}

template <>
METAL_FUNC void ReadWriter<float, float2>::load_padded(
    int length,
    const device float2* w_k) const {
  int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread; e++) {
    int index = metal::min(fft_idx + e * m, n - 1);
    if (index < length) {
      float2 elem =
          float2(in[batch_idx + index], in[batch_idx + index + next_in]);
      seq_buf[index] = complex_mul(elem, w_k[index]);
    } else {
      seq_buf[index] = 0;
    }
  }
}

template <>
METAL_FUNC void ReadWriter<float, float2>::write_padded(
    int length,
    const device float2* w_k) const {
  int length_over_2 = (length / 2) + 1;
  int batch_idx =
      elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n + length - 1;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
      ? 0
      : length_over_2;

  float2 conj = {1, -1};
  float2 inv_factor = {1.0f / n, -1.0f / n};
  float2 minus_j = {0, -1};

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
    int index = metal::min(fft_idx + e * m, length_over_2 - 1);
    // x_0 = z_0.real
    // y_0 = z_0.imag
    if (index == 0) {
      float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor);
      out[batch_idx + index] = float2(elem.x, 0);
      out[batch_idx + index + next_out] = float2(elem.y, 0);
    } else {
      float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor);
      float2 x_n_minus_k = complex_mul(
          w_k[length - index], seq_buf[length - index] * inv_factor);
      x_n_minus_k *= conj;
      // The multiplication by w_k must happen before this extraction
      out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
      out[batch_idx + index + next_out] =
          complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
    }
  }
}

// For IRFFT, we do the opposite:
//
// Z_k = X_k + j.Y_k
// x_k = Re(z_k)
// y_k = Im(z_k)
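//
// Since X and Y are the conjugate-symmetric spectra of real signals,
// ifft(X) and ifft(Y) are real, and by linearity
//   z = ifft(Z) = ifft(X) + j * ifft(Y),
// so the two inverse real FFTs are recovered as Re(z) and Im(z).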
template <>
METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
  int grid_index = elem.x * grid.y + elem.y;
  // We pack two sequences into one for IRFFTs
  return grid_index * 2 >= batch_size;
}

template <>
METAL_FUNC void ReadWriter<float2, float>::load() const {
  short n_over_2 = (n / 2) + 1;
  int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;

  short m = grid.z;
  short fft_idx = elem.z;

  float2 conj = {1, -1};
  float2 plus_j = {0, 1};

  for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
    int index = metal::min(fft_idx + t * m, n_over_2 - 1);
    float2 x = in[batch_idx + index];
    float2 y = in[batch_idx + index + next_in];
    // NumPy forces first input to be real
    bool first_val = index == 0;
    // NumPy forces last input on even irffts to be real
    bool last_val = n % 2 == 0 && index == n_over_2 - 1;
    if (first_val || last_val) {
      x = float2(x.x, 0);
      y = float2(y.x, 0);
    }
    seq_buf[index] = x + complex_mul(y, plus_j);
    seq_buf[index].y = -seq_buf[index].y;
    if (index > 0 && !last_val) {
      seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j);
      seq_buf[n - index].y = -seq_buf[n - index].y;
    }
  }
}

template <>
METAL_FUNC void ReadWriter<float2, float>::write() const {
  int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;

  short m = grid.z;
  short fft_idx = elem.z;

  for (int e = 0; e < elems_per_thread; e++) {
    int index = metal::min(fft_idx + e * m, n - 1);
    out[batch_idx + index] = seq_buf[index].x / n;
    out[batch_idx + index + next_out] = seq_buf[index].y / -n;
  }
}

template <>
METAL_FUNC void ReadWriter<float2, float>::load_padded(
    int length,
    const device float2* w_k) const {
  int n_over_2 = (n / 2) + 1;
  int length_over_2 = (length / 2) + 1;

  int batch_idx =
      elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
  threadgroup float2* seq_buf = buf + elem.y * n;

  // No out of bounds accesses on odd batch sizes
  int grid_index = elem.x * grid.y + elem.y;
  short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
      ? 0
      : length_over_2;

  short m = grid.z;
  short fft_idx = elem.z;

  float2 conj = {1, -1};
  float2 plus_j = {0, 1};

  for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
    int index = metal::min(fft_idx + t * m, n_over_2 - 1);
    float2 x = in[batch_idx + index];
    float2 y = in[batch_idx + index + next_in];
    if (index < length_over_2) {
      bool last_val = length % 2 == 0 && index == length_over_2 - 1;
      if (last_val) {
        x = float2(x.x, 0);
        y = float2(y.x, 0);
      }
      float2 elem1 = x + complex_mul(y, plus_j);
      seq_buf[index] = complex_mul(elem1 * conj, w_k[index]);
      if (index > 0 && !last_val) {
        float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j);
        seq_buf[length - index] =
            complex_mul(elem2 * conj, w_k[length - index]);
      }
    } else {
      short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2);
      seq_buf[pad_index] = 0;
      seq_buf[pad_index + 1] = 0;
    }
  }
}

template <>
METAL_FUNC void ReadWriter<float2, float>::write_padded(
    int length,
    const device float2* w_k) const {
  int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
  threadgroup float2* seq_buf = buf + elem.y * n + length - 1;

  int grid_index = elem.x * grid.y + elem.y;
  short next_out =
      batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;

  short m = grid.z;
  short fft_idx = elem.z;

  float2 inv_factor = {1.0f / n, -1.0f / n};
  for (int e = 0; e < elems_per_thread; e++) {
    int index = fft_idx + e * m;
    if (index < length) {
      float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]);
      out[batch_idx + index] = output.x / length;
      out[batch_idx + index + next_out] = output.y / -length;
    }
  }
}

// Four Step RFFT
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::load_strided(
    int stride,
    int overall_n) {
  // Silence compiler warnings
  (void)stride;
  (void)overall_n;
  // Don't invert between steps
  bool default_inv = inv;
  inv = false;
  load();
  inv = default_inv;
}

template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::write_strided(
    int stride,
    int overall_n) {
  int overall_n_over_2 = overall_n / 2 + 1;
  int coalesce_width = grid.y;
  int tg_idx = elem.y * grid.z + elem.z;
  int outer_batch_size = stride / coalesce_width;

  int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
      overall_n_over_2 * (elem.x / outer_batch_size);
  strided_device_idx = strided_batch_idx +
      tg_idx / coalesce_width * elems_per_thread / 2 * stride +
      tg_idx % coalesce_width;
  strided_shared_idx = (tg_idx % coalesce_width) * n +
      tg_idx / coalesce_width * elems_per_thread / 2;
  for (int e = 0; e < elems_per_thread / 2; e++) {
    float2 output = buf[strided_shared_idx + e];
    out[strided_device_idx + e * stride] = output;
  }

  // Add on n/2 + 1 element
  if (tg_idx == 0 && elem.x % outer_batch_size == 0) {
    out[strided_batch_idx + overall_n / 2] = buf[n / 2];
  }
}
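
// A real FFT of length overall_n is fully determined by its first
// overall_n / 2 + 1 bins (the remainder follow by conjugate symmetry),
// so the step above stores only half of the strided outputs per thread
// plus the final Nyquist element.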

// Four Step IRFFT
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/0, /*real=*/true>::load_strided(
    int stride,
    int overall_n) {
  int overall_n_over_2 = overall_n / 2 + 1;
  auto conj = float2(1, -1);

  compute_strided_indices(stride, overall_n);
  // Translate indices in terms of N - k
  for (int e = 0; e < elems_per_thread; e++) {
    int device_idx = strided_device_idx + e * stride;
    int overall_batch = device_idx / overall_n;
    int overall_index = device_idx % overall_n;
    if (overall_index < overall_n_over_2) {
      device_idx -= overall_batch * (overall_n - overall_n_over_2);
      buf[strided_shared_idx + e] = in[device_idx] * conj;
    } else {
      int conj_idx = overall_n - overall_index;
      device_idx = overall_batch * overall_n_over_2 + conj_idx;
      buf[strided_shared_idx + e] = in[device_idx];
    }
  }
}
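
// The input here holds only the first overall_n / 2 + 1 bins of each
// spectrum, so the mirrored half is rebuilt on the fly: a bin with
// overall_index past the midpoint is read from bin
// overall_n - overall_index, using X_(N-k) = conj(X_k) for real
// signals. That conjugate cancels the inverse-FFT conjugation, which
// is why only the first branch multiplies by conj.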

template <>
METAL_FUNC void
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::load_strided(
    int stride,
    int overall_n) {
  // Silence compiler warnings
  (void)stride;
  (void)overall_n;
  bool default_inv = inv;
  inv = false;
  load();
  inv = default_inv;
}

template <>
METAL_FUNC void
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::write_strided(
    int stride,
    int overall_n) {
  compute_strided_indices(stride, overall_n);

  for (int e = 0; e < elems_per_thread; e++) {
    out[strided_device_idx + e * stride] =
        pre_out(buf[strided_shared_idx + e], overall_n).x;
  }
}