torch.mps.compile_shader¶
- torch.mps.compile_shader(source)[源][源]¶
從原始碼編譯計算著色器(compute shader),並允許在 Python 執行時方便地呼叫其中定義的核心。示例:
>>> lib = torch.mps.compile_shader( ... "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }" ... ) >>> x = torch.zeros(16, device="mps") >>> lib.full(x, 3.14)