XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed
Description
The output:
2023-07-31 01:53:45.016563: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
XlaRuntimeError Traceback (most recent call last)
Cell In[4], line 29
26 model = trainer.make_model(nmask)
28 lr_fn, opt = trainer.make_optimizer(steps_per_epoch=len(train_dl))
---> 29 state = trainer.create_train_state(jax.random.PRNGKey(0), model, opt)
30 state = checkpoints.restore_checkpoint(ckpt.parent, state)
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/random.py:137, in PRNGKey(seed)
134 if np.ndim(seed):
135 raise TypeError("PRNGKey accepts a scalar seed, but was given an array of"
136 f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 137 key = prng.seed_with_impl(impl, seed)
138 return _return_prng_keys(True, key)
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:320, in seed_with_impl(impl, seed)
319 def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
--> 320 return random_seed(seed, impl=impl)
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:734, in random_seed(seeds, impl)
732 else:
733 seeds_arr = jnp.asarray(seeds)
--> 734 return random_seed_p.bind(seeds_arr, impl=impl)
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
377 def bind(self, *args, **params):
378 assert (not config.jax_enable_checks or
379 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380 return self.bind_with_trace(find_top_trace(args), args, params)
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
382 def bind_with_trace(self, trace, args, params):
--> 383 out = trace.process_primitive(self, map(trace.full_raise, args), params)
384 return map(full_lower, out) if self.multiple_results else full_lower(out)
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:790, in EvalTrace.process_primitive(self, primitive, tracers, params)
789 def process_primitive(self, primitive, tracers, params):
--> 790 return primitive.impl(*tracers, **params)
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:746, in random_seed_impl(seeds, impl)
744 @random_seed_p.def_impl
745 def random_seed_impl(seeds, *, impl):
--> 746 base_arr = random_seed_impl_base(seeds, impl=impl)
747 return PRNGKeyArrayImpl(impl, base_arr)
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:751, in random_seed_impl_base(seeds, impl)
749 def random_seed_impl_base(seeds, *, impl):
750 seed = iterated_vmap_unary(seeds.ndim, impl.seed)
--> 751 return seed(seeds)
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:980, in threefry_seed(seed)
968 def threefry_seed(seed: typing.Array) -> typing.Array:
969 """Create a single raw threefry PRNG key from an integer seed.
970
971 Args:
(...)
978 first padding out with zeros).
979 """
--> 980 return _threefry_seed(seed)
[... skipping hidden 12 frame]
File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/dispatch.py:463, in backend_compile(backend, module, options, host_callbacks)
458 return backend.compile(built_c, compile_options=options,
459 host_callbacks=host_callbacks)
460 # Some backends don't have host_callbacks option yet
461 # TODO(sharadmv): remove this fallback when all backends allow compile
462 # to take in host_callbacks
--> 463 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
What jax/jaxlib version are you using?
jax 0.4.10, jaxlib 0.4.10+cuda11.cudnn86
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.11.4, Ubuntu 22.04, CUDA 11.7, cudnn86
NVIDIA GPU info
How can I fix this error?
Answer 1
Score: 2
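The error message tells you what's wrong:
Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.
If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.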
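To make the rule in that message concrete, here is a minimal sketch that extracts both versions from the log line and applies the check the log describes (same major version, runtime minor version equal to or higher than the compiled one). The helper names `parse_cudnn_mismatch` and `is_compatible` are hypothetical, and the regular expression assumes the exact wording of the XLA log line quoted in the question.

```python
import re

# Assumption: matches the wording of the XLA log line quoted in the question.
_CUDNN_LOG_RE = re.compile(
    r"Loaded runtime CuDNN library: (\d+)\.(\d+)\.(\d+) "
    r"but source was compiled with: (\d+)\.(\d+)\.(\d+)"
)


def parse_cudnn_mismatch(log_line):
    """Return ((runtime major, minor, patch), (compiled major, minor, patch)),
    or None if the line does not contain the expected message."""
    match = _CUDNN_LOG_RE.search(log_line)
    if match is None:
        return None
    nums = tuple(int(g) for g in match.groups())
    return nums[:3], nums[3:]


def is_compatible(runtime, compiled):
    """Rule stated in the log: matching major version and an equal or
    higher minor version at runtime. Patch versions are ignored here."""
    return runtime[0] == compiled[0] and runtime[1] >= compiled[1]


if __name__ == "__main__":
    line = ("Loaded runtime CuDNN library: 8.5.0 but source was compiled "
            "with: 8.6.0. CuDNN library needs to have matching major version "
            "and equal or higher minor version.")
    runtime, compiled = parse_cudnn_mismatch(line)
    print(runtime, compiled, is_compatible(runtime, compiled))
    # prints: (8, 5, 0) (8, 6, 0) False
```

With the versions from the question the check fails because 5 < 6; a runtime cuDNN 8.6 or any later 8.x release would pass it.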
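You are using a jaxlib package that was compiled with CuDNN 8.6.0, but at runtime the system is finding CuDNN 8.5.0. The runtime library must have the same major version and an equal or higher minor version than the one jaxlib was built with, and 8.5 is older than 8.6, so initialization fails. In general, the way to fix this is to either install a jaxlib compiled with CuDNN 8.5.0, or to upgrade your system CuDNN to 8.6.0 (or a later 8.x release).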
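To confirm which side is out of date on your own machine, a rough diagnostic like the one below can print both versions. It is a sketch under several assumptions: it reads the build-time cuDNN version from the `cudnn86`-style suffix of the installed jaxlib wheel's version string via `importlib.metadata`, and it asks whichever cuDNN the dynamic loader finds for its version through `ctypes` and the cuDNN C function `cudnnGetVersion`. The library name `libcudnn.so.8` and all helper names are assumptions; run it in the same environment (including the same `LD_LIBRARY_PATH`) as the failing notebook so the loader resolves the same copy.

```python
import ctypes
import importlib.metadata
import re


def build_cudnn_from_jaxlib_version(version):
    """Parse the cuDNN tag of a CUDA jaxlib wheel version, e.g.
    '0.4.10+cuda11.cudnn86' -> (8, 6). Assumes a 'cudnnXY' tag with a
    single-digit major version; returns None if no tag is present."""
    match = re.search(r"cudnn(\d)(\d+)", version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def decode_cudnn_version(value):
    """Decode the integer returned by cudnnGetVersion().

    cuDNN 8.x encodes major*1000 + minor*100 + patch (8500 -> 8.5.0);
    cuDNN 9.x switched to major*10000 + minor*100 + patch."""
    if value >= 10000:
        return value // 10000, (value % 10000) // 100, value % 100
    return value // 1000, (value % 1000) // 100, value % 100


def runtime_cudnn_version(libname="libcudnn.so.8"):
    """Load cuDNN via the dynamic loader and ask it for its version.
    Returns None if the library cannot be found."""
    try:
        lib = ctypes.CDLL(libname)
    except OSError:
        return None
    lib.cudnnGetVersion.restype = ctypes.c_size_t
    lib.cudnnGetVersion.argtypes = []
    return decode_cudnn_version(lib.cudnnGetVersion())


if __name__ == "__main__":
    try:
        jaxlib_version = importlib.metadata.version("jaxlib")
    except importlib.metadata.PackageNotFoundError:
        jaxlib_version = ""
    print("jaxlib", jaxlib_version or "(not installed)",
          "built against cuDNN", build_cudnn_from_jaxlib_version(jaxlib_version))
    print("runtime cuDNN:", runtime_cudnn_version())
```

In the situation described in the question, one would expect the first line to report (8, 6) and the second a version starting with 8.5.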
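An easy way to achieve this is to reinstall jax in a fresh environment using the cuda11_pip or cuda12_pip options (see the installation instructions). Those options install the matching CUDA and cuDNN libraries from pip alongside jaxlib, which ensures that your package versions match.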
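After reinstalling, a quick way to check the new environment is to rerun the call that failed in the traceback. The sketch below (the function name `jax_smoke_test` is hypothetical) calls `jax.random.PRNGKey(0)`, draws a small sample to force an XLA compilation on the default device, and reports the backend instead of raising. If it reports `cpu`, the computation did not run on the GPU, so the reinstall did not yet give you a working GPU setup.

```python
import importlib.util


def jax_smoke_test():
    """Rerun the failing call and return a status string (never raises)."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    try:
        import jax

        # jax.random.PRNGKey(0) is the call that raised XlaRuntimeError above.
        key = jax.random.PRNGKey(0)
        # Drawing a sample triggers another small XLA compilation and run.
        jax.random.normal(key, (2,)).block_until_ready()
        return f"ok: backend={jax.default_backend()}, devices={jax.devices()}"
    except Exception as exc:  # report any backend/compile failure as text
        return f"failed: {type(exc).__name__}: {exc}"


if __name__ == "__main__":
    print(jax_smoke_test())
```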