https://stackoverflow.com/questions/78210393/cannot-import-name-linear-util-from-jax
To fix your issue, you’ll either need to install an older version of JAX which still has jax.linear_util, or update to a newer version of flax which is compatible with more recent JAX versions.
评论区说,要么降级jax,要么升级flax。笔者这边升级flax:
pip install -U flax
通常来说问题到这里就解决了。
作者的特殊情况
但是flax与jax、jaxlib有版本兼容,因此安装flax会同步替换jax的版本。而在笔者情况下,jax不能安装过高版本的flax,最后选择了如下版本:
pip install flax==0.5.3
然后再安装:
pip install jax[cuda12]==0.4.30