Thanks to my brother for the JAX environment. The setup below follows his conda_env.yml exactly.
(For how to export the conda_env.yml of another environment, see: Conda | How to copy the conda environment of the old server (on the new server), Linux server)
- 01 Install various libraries
- 02 Install jax
- 03 Install dm_control metaworld d4rl
- 04 test
- 05 Reference versions of various libraries
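First, create a new conda environment:
conda create -n jax_env python=3.8
conda activate jax_env
(Note: the reference environment in section 05 was exported with Python 3.7.13. Use python=3.7 instead if you want to match it exactly.)
(For how to set up conda, see: Conda | How to install conda on a Linux server)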
01 Install various libraries
Install these directly with pip:
pip install numpy==1.21.6 torch==1.13.1 wandb==0.15.10 \
transformers==4.30.2 typing-extensions==4.7.1 optax==0.1.4 \
jax==0.3.24 flax==0.6.0 cloudpickle==2.2.1 distrax==0.1.3 \
glfw==2.6.2 gym==0.15.7
02 Install jax
JAX publishes its release wheels on the following pages:
- /jax-releases/jax_releases.html
- /jax-releases/jax_cuda_releases.html
To install jax 0.3.24, run:
pip install "jax[cuda11_cudnn82]==0.3.24" \
-f /jax-releases/jax_cuda_releases.html
Things to watch out for:
- jax, jaxlib, optax, flax, and related libraries must be installed in mutually compatible versions. You can follow the reference versions listed in this post.
- Install cloudpickle explicitly with pip install cloudpickle==2.2.1. It seems to end up at version 1.2.2 easily, so check the installed version once everything else is in place.
- If compilation fails because the ptxas version is too old, run which ptxas to see which ptxas is actually being used. If it comes from an old CUDA version, fix the path by adding the following lines to ~/.bashrc:
export PATH="/usr/local/cuda-{version number}/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-{version number}/lib64:$LD_LIBRARY_PATH"
# The cuda version number can be found in the /usr/local directory. I am using 11.7.
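The ptxas check described above can also be done from Python, which is convenient when you want to confirm what the activated environment actually sees. This is only a minimal sketch: the helper name tool_version is made up for illustration, and it assumes that ptxas accepts the --version flag and that you are running inside the same shell session where PATH was updated.

```python
import shutil
import subprocess


def tool_version(tool="ptxas"):
    """Return (path, version_text) for a CLI tool found on PATH, or None if it is missing."""
    path = shutil.which(tool)  # same lookup that `which ptxas` performs
    if path is None:
        return None
    try:
        out = subprocess.run(
            [path, "--version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
            timeout=10,
        ).stdout
    except (OSError, subprocess.SubprocessError):
        out = ""  # the tool exists but could not be executed
    return path, out.strip()


if __name__ == "__main__":
    info = tool_version("ptxas")
    if info is None:
        print("ptxas not found on PATH; check the export lines in ~/.bashrc")
    else:
        print("ptxas path:", info[0])
        print(info[1])
```

If the printed path does not point into the CUDA directory you exported, the new PATH has probably not been picked up yet (for example, ~/.bashrc has not been re-sourced).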
03 Install dm_control metaworld d4rl
You need to install MuJoCo first. See this article: Python · MuJoCo | MuJoCo corresponds to the version of mujoco_py, and installs Cython
Then clone the three libraries dm_control, metaworld, and d4rl:
git clone git@:Farama-Foundation/
git clone git@:Farama-Foundation/
git clone git@:denisyarats/
Then enter each cloned directory in turn and run pip install -e .
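After the editable installs, a quick way to see whether the three packages are visible to the environment is to look them up without importing them. This is a minimal sketch, assuming the import names are dm_control, metaworld, and d4rl; the helper find_missing is hypothetical. It only checks that Python can locate the packages, not that MuJoCo itself works.

```python
import importlib.util


def find_missing(modules=("dm_control", "metaworld", "d4rl")):
    """Return the names from `modules` that Python cannot locate in this environment."""
    missing = []
    for name in modules:
        try:
            spec = importlib.util.find_spec(name)  # locates the package without importing it
        except (ImportError, ValueError):
            spec = None
        if spec is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    missing = find_missing()
    if missing:
        print("Not found:", ", ".join(missing))
    else:
        print("dm_control, metaworld and d4rl are all visible to this environment")
```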
04 test
I tested the environment by running /csmile-1006/PreferenceTransformer. (This library also contains a JAX implementation of IQL, so this environment should be able to run the JAX version of IQL as well.)
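Before launching a full training run, it can help to check that JAX itself imports and can compile something. The following is a small smoke test, not part of the original setup: the function name jax_smoke_test is made up, and it only uses jax.jit, jax.devices(), and jax.default_backend().

```python
def jax_smoke_test():
    """Run a tiny jitted computation and report which backend JAX selected."""
    try:
        import jax
        import jax.numpy as jnp
    except Exception as exc:  # jax missing, or jax/jaxlib versions do not match
        return {"ok": False, "error": str(exc)}

    @jax.jit
    def sum_of_squares(x):
        return jnp.sum(x * x)

    value = float(sum_of_squares(jnp.arange(4.0)))  # 0 + 1 + 4 + 9
    return {
        "ok": value == 14.0,
        "backend": jax.default_backend(),  # "gpu" is what you want with the CUDA build
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    print(jax_smoke_test())
```

If the reported backend is "cpu" even though the CUDA wheel was installed, the jaxlib build or the CUDA paths are the first things to recheck.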
05 Reference versions of various libraries
The following is the reference environment, including the version of every library:
name: jax_env
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - ca-certificates=2023.08.22=h06a4308_0
  - certifi=2022.12.7=py37h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - ncurses=6.3=h7f8727e_2
  - openssl=1.1.1w=h7f8727e_0
  - pip=22.3.1=py37h06a4308_0
  - python=3.7.13=h12debd9_0
  - readline=8.1.2=h7f8727e_1
  - setuptools=65.6.3=py37h06a4308_0
  - sqlite=3.38.5=hc218d9a_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.38.4=py37h06a4308_0
  - xz=5.2.5=h7f8727e_1
  - zlib=1.2.12=h7f8727e_2
  - pip:
    - absl-py==1.4.0
    - appdirs==1.4.4
    - beautifulsoup4==4.12.2
    - cffi==1.15.1
    - charset-normalizer==3.2.0
    - chex==0.1.5
    - click==8.1.7
    - cloudpickle==2.2.1
    - colorama==0.4.6
    - commonmark==0.9.1
    - contextlib2==21.6.0
    - cycler==0.11.0
    - cython==3.0.2
    - decorator==5.1.1
    - distrax==0.1.3
    - dm-control==1.0.13
    - dm-env==1.6
    - dm-tree==0.1.8
    - docker-pycreds==0.4.0
    - etils==0.9.0
    - fasteners==0.18
    - filelock==3.12.2
    - flax==0.6.0
    - fonttools==4.38.0
    - fsspec==2023.1.0
    - future==0.18.3
    - gast==0.5.4
    - gdown==4.7.1
    - gitdb==4.0.10
    - gitpython==3.1.36
    - glfw==2.6.2
    - gym==0.15.7
    - gym-notices==0.0.8
    - h5py==3.8.0
    - huggingface-hub==0.16.4
    - idna==3.4
    - imageio==2.31.2
    - imageio-ffmpeg==0.4.9
    - importlib-metadata==6.7.0
    - importlib-resources==5.12.0
    - jax==0.3.24
    - jaxlib==0.3.24+cuda11.cudnn82
    - joblib==1.3.2
    - kiwisolver==1.4.5
    - labmaze==1.0.6
    - lxml==4.9.3
    - matplotlib==3.5.3
    - ml-collections==0.1.1
    - msgpack==1.0.5
    - mujoco==2.3.6
    - mujoco-py==2.0.2.13
    - numpy==1.21.6
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - opt-einsum==3.3.0
    - optax==0.1.4
    - packaging==23.1
    - pathtools==0.1.2
    - pillow==9.5.0
    - protobuf==3.20.1
    - psutil==5.9.5
    - pybullet==3.2.5
    - pycparser==2.21
    - pyglet==1.5.0
    - pygments==2.16.1
    - pyopengl==3.1.7
    - pyparsing==3.1.1
    - pysocks==1.7.1
    - python-dateutil==2.8.2
    - pyyaml==6.0.1
    - regex==2023.8.8
    - requests==2.31.0
    - rich==11.2.0
    - safetensors==0.3.3
    - scikit-learn==1.0.2
    - scipy==1.7.3
    - sentry-sdk==1.31.0
    - setproctitle==1.3.2
    - six==1.16.0
    - smmap==5.0.1
    - soupsieve==2.4.1
    - tensorboardx==2.1
    - tensorflow-probability==0.19.0
    - termcolor==2.3.0
    - threadpoolctl==3.1.0
    - tokenizers==0.13.3
    - toolz==0.12.0
    - torch==1.13.1
    - tqdm==4.66.1
    - transformers==4.30.2
    - typing-extensions==4.7.1
    - ujson==5.7.0
    - urllib3==2.0.4
    - wandb==0.15.10
    - zipp==3.15.0
prefix: /home/user_name/miniconda3/envs/jax
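To compare what is actually installed against a few of the pins in the listing above (cloudpickle in particular), something like the following can be run inside the activated environment. It is a minimal sketch: the PINS dictionary is a hand-picked subset of the reference versions, the helper compare_versions is made up for illustration, and on Python 3.7 it assumes the importlib-metadata backport from the listing is installed.

```python
try:
    from importlib import metadata  # standard library on Python 3.8+
except ImportError:
    import importlib_metadata as metadata  # backport, needed on Python 3.7

# Hand-picked subset of the reference versions listed above.
PINS = {
    "jax": "0.3.24",
    "flax": "0.6.0",
    "optax": "0.1.4",
    "distrax": "0.1.3",
    "cloudpickle": "2.2.1",
    "gym": "0.15.7",
    "numpy": "1.21.6",
}


def compare_versions(pins):
    """Map each package name to (expected, installed or None, versions match?)."""
    report = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        report[name] = (expected, installed, installed == expected)
    return report


if __name__ == "__main__":
    for name, (expected, installed, ok) in compare_versions(PINS).items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name:12s} expected {expected:8s} installed {installed}  {status}")
```

Any line marked MISMATCH can be fixed with pip install name==version, using the version from the reference listing.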