mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-23 14:08:12 +08:00
Some fixes to typing (#1371)
* some fixes to typing * fix module reference * comment
This commit is contained in:
@@ -229,34 +229,35 @@ void init_fast(nb::module_& parent_module) {
|
||||
Returns:
|
||||
Callable ``metal_kernel``.
|
||||
|
||||
.. code-block:: python
|
||||
Example:
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
.. code-block:: python
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"inp": a},
|
||||
template={"T": mx.float32},
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes={"out": a.shape},
|
||||
output_dtypes={"out": a.dtype},
|
||||
verbose=True,
|
||||
)
|
||||
return outputs["out"]
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = '''
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
'''
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"inp": a},
|
||||
template={"T": mx.float32},
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes={"out": a.shape},
|
||||
output_dtypes={"out": a.dtype},
|
||||
verbose=True,
|
||||
)
|
||||
return outputs["out"]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
)pbdoc")
|
||||
.def(
|
||||
"__call__",
|
||||
|
Reference in New Issue
Block a user