Looking for a tool to predict the runtime of an XLA-HLO computational graph
Question
I'm looking for a tool that prints the runtime of a given XLA-HLO computational graph.
I know there is an HLO cost model (an analytical model) that prints the FLOPs of each operator node in the computational graph.
But is there any tool that prints the expected runtime, or any runtime-related value, for an XLA-HLO computational graph?
I need its source code or a sample of how to use it. Thanks.
Answer 1
Score: 1
If you are using JAX, you can use the ahead-of-time lowering and compilation APIs to get a sense of how resource-heavy a computation is. For example:
import jax
import numpy as np

def f(M, x):
    # Apply the matrix-vector product 10 times.
    for i in range(10):
        x = M @ x
    return x

M = np.random.randn(1000, 1000)
x = np.random.randn(1000)

# Lower and compile ahead of time, then print XLA's cost analysis.
print(jax.jit(f).lower(M, x).compile().cost_analysis())

Output:

[{'bytes accessed': 40080000.0,
  'bytes accessed operand 0 {}': 40000000.0,
  'bytes accessed operand 1 {}': 40000.0,
  'bytes accessed output {}': 40000.0,
  'flops': 20000000.0,
  'optimal_seconds': 0.0,
  'utilization operand 0 {}': 10.0,
  'utilization operand 1 {}': 10.0}]
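
Note that 'optimal_seconds' is 0.0 in the output above, so the cost analysis here gives FLOPs and bytes accessed rather than a runtime. One rough way to turn those two counts into a lower-bound time estimate is a roofline-style calculation. The following is only a sketch under stated assumptions: `roofline_seconds` is a hypothetical helper (not part of JAX or XLA), and the peak compute and bandwidth figures are placeholders that you would replace with your own device's specifications.

```python
def roofline_seconds(cost, peak_flops_per_s, peak_bytes_per_s):
    """Rough lower bound on runtime from a cost_analysis()-style dict.

    Hypothetical helper, not part of JAX. Assumes the computation is limited
    either by arithmetic throughput or by memory bandwidth, whichever is slower.
    """
    compute_time = cost["flops"] / peak_flops_per_s
    memory_time = cost["bytes accessed"] / peak_bytes_per_s
    return max(compute_time, memory_time)


# Values copied from the cost analysis output shown above.
cost = {"flops": 20000000.0, "bytes accessed": 40080000.0}

# Assumed device (placeholder numbers): 1 TFLOP/s compute, 100 GB/s bandwidth.
estimate = roofline_seconds(cost, peak_flops_per_s=1e12, peak_bytes_per_s=1e11)
print(f"estimated lower bound: {estimate:.6f} s")
```

In the output shown above the result is a list holding one dict, so with real data you would pass `cost_analysis()[0]`; some JAX versions may return the dict directly, so check what yours returns. The estimate ignores kernel launch overhead, caching, and fusion effects, so treat it as a floor and compare it against a measured run.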