# Pinned Python dependencies; install with `pip install -r <this file>`.
# The requirements-file format expects one requirement specifier per line.

# Additional page of wheel links searched alongside the default index.
# --find-links is a global option, so it goes on its own line: options that
# trail a requirement on the same line are parsed as per-requirement options
# and are not applied to the package finder.
# NOTE(review): presumably this page is what provides the CUDA builds pulled
# in by the jax[cuda12-pip] extra below — re-check if the jax pin changes.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

dm_haiku==0.0.12
jax[cuda12-pip]==0.4.25
numpy==1.26.4
sentencepiece==0.2.0