
Python · Jax | Install jax on python 3.8, run offline RL's IQL

2025-01-23 11:44:06

Thanks to my senior labmate for this jax environment; it is configured exactly according to his conda_env.yml.

(For how to export the conda_env.yml of another environment, see: Conda | How to copy the conda environment of the old server (on the new server), Linux server)

Table of contents
  • 01 Install various libraries
  • 02 Install jax
  • 03 Install dm_control metaworld d4rl
  • 04 test
  • 05 Reference versions of various libraries


First, create a new conda environment:

conda create -n jax_env python=3.8
conda activate jax_env
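Before installing any packages, it can help to confirm that the new environment is really the active one. Here is a minimal Python sketch. It assumes the environment name jax_env from the command above and relies on the CONDA_DEFAULT_ENV variable that conda activate sets:

```python
"""Check that the jax_env conda environment is active (a sketch)."""
import os
import sys

EXPECTED_ENV = "jax_env"   # name used in the conda create command above
EXPECTED_PYTHON = (3, 8)   # Python version requested from conda


def env_problems(env_name, version_info):
    """Return a list of human-readable problems (empty if all is well)."""
    problems = []
    if env_name != EXPECTED_ENV:
        problems.append(f"active conda env is {env_name!r}, expected {EXPECTED_ENV!r}")
    if tuple(version_info[:2]) != EXPECTED_PYTHON:
        problems.append(f"python is {version_info[0]}.{version_info[1]}, expected 3.8")
    return problems


if __name__ == "__main__":
    # `conda activate` exports CONDA_DEFAULT_ENV with the active env's name.
    found = env_problems(os.environ.get("CONDA_DEFAULT_ENV"), sys.version_info)
    print("\n".join(found) if found else "OK: jax_env with Python 3.8 is active")
```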

(For how to set up conda, see: Conda | How to install conda on a Linux server)

01 Install various libraries

Install them directly with pip:

pip install numpy==1.21.6 torch==1.13.1 wandb==0.15.10 \
transformers==4.30.2 typing-extensions==4.7.1 optax==0.1.4 \
jax==0.3.24 flax==0.6.0 cloudpickle==2.2.1 distrax==0.1.3 \
glfw==2.6.2 gym==0.15.7
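Mismatched versions are the most common problem with this setup (see the notes in section 02, including the one about cloudpickle). A small script can compare what pip actually installed against these pins. This sketch uses only the standard library's importlib.metadata (available from Python 3.8). It does an exact string comparison, so a version carrying a local label such as +cu117 is reported as a mismatch:

```python
"""Compare installed package versions with the pins used above (a sketch)."""
from importlib import metadata  # standard library since Python 3.8

# Versions from the pip install command above
# (typing-extensions is looked up by its distribution name typing_extensions).
PINS = {
    "numpy": "1.21.6", "torch": "1.13.1", "wandb": "0.15.10",
    "transformers": "4.30.2", "typing_extensions": "4.7.1", "optax": "0.1.4",
    "jax": "0.3.24", "flax": "0.6.0", "cloudpickle": "2.2.1",
    "distrax": "0.1.3", "glfw": "2.6.2", "gym": "0.15.7",
}


def mismatches(pins, get_version=metadata.version):
    """Map package -> installed version (None if missing) for every pin that differs."""
    bad = {}
    for name, wanted in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            bad[name] = installed
    return bad


if __name__ == "__main__":
    for name, installed in mismatches(PINS).items():
        print(f"{name}: installed {installed}, expected {PINS[name]}")
```

Run it once more after section 03. Editable installs can pull in different versions of packages such as gym or cloudpickle.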

02 Install jax

JAX publishes its own wheels on its website:

  • /jax-releases/jax_releases.html
  • /jax-releases/jax_cuda_releases.html

To install jax 0.3.24, run:

pip install "jax[cuda11_cudnn82]==0.3.24" \
-f /jax-releases/jax_cuda_releases.html
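One way to check which jaxlib build actually ended up installed is to look at its version string. The CUDA wheels from the jax_cuda_releases page carry a local version label such as +cuda11.cudnn82 (the same string that appears in the reference environment at the end of this post). The CPU wheel from PyPI does not. A sketch:

```python
"""Check whether the installed jaxlib is a CUDA build (a sketch)."""
from importlib import metadata


def jaxlib_build(version):
    """Return the local build tag of a jaxlib version, or 'cpu' if it has none."""
    # CUDA wheels from jax_cuda_releases.html carry a local version label,
    # e.g. "0.3.24+cuda11.cudnn82"; the plain PyPI wheel is just "0.3.24".
    _, plus, tag = version.partition("+")
    return tag if plus else "cpu"


if __name__ == "__main__":
    try:
        version = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        print("jaxlib is not installed")
    else:
        print("jaxlib", version, "build:", jaxlib_build(version))
        import jax  # imported only after jaxlib is known to be present
        print("devices:", jax.devices())
```

If the GPU build is working, jax.devices() should list GPU devices rather than only a CPU device.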

Points to note:

  1. jax, jaxlib, optax, flax, etc. must have mutually compatible versions; you can install the reference versions listed at the end of this post;
  2. Install cloudpickle explicitly with pip install cloudpickle==2.2.1. It seems to easily end up at version 1.2.2 instead, so check the installed version at the end;
  3. If compilation fails with an error saying the ptxas version is too old, run which ptxas to see which ptxas is being used. If it comes from an old CUDA version, change the path: edit ~/.bashrc and add
    export PATH="/usr/local/cuda-{version number}/bin:$PATH"
    export LD_LIBRARY_PATH="/usr/local/cuda-{version number}/lib64:$LD_LIBRARY_PATH"
    # The installed CUDA versions can be found under /usr/local. I am using 11.7.
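The which ptxas check can also be scripted. This is handy for confirming that the ~/.bashrc change took effect (open a new shell or run source ~/.bashrc first). The sketch assumes that ptxas --version prints a line of the form "Cuda compilation tools, release X.Y, ...", as the CUDA toolkit's ptxas does:

```python
"""Report which ptxas is first on PATH and its CUDA release (a sketch)."""
import re
import shutil
import subprocess


def ptxas_release(version_output):
    """Parse e.g. 'Cuda compilation tools, release 11.7, V11.7.64' -> (11, 7)."""
    match = re.search(r"release (\d+)\.(\d+)", version_output)
    return (int(match.group(1)), int(match.group(2))) if match else None


if __name__ == "__main__":
    path = shutil.which("ptxas")  # same PATH lookup that `which ptxas` does
    if path is None:
        print("ptxas not found on PATH")
    else:
        out = subprocess.run([path, "--version"], capture_output=True, text=True).stdout
        print(path, "-> CUDA release", ptxas_release(out))
```

The reported path should point into the /usr/local/cuda-{version number}/bin directory that you added to PATH.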

03 Install dm_control metaworld d4rl

MuJoCo must be installed first; see this article: Python · MuJoCo | MuJoCo corresponds to the version of mujoco_py, and installs Cython

First, clone the three libraries dm_control, metaworld, and d4rl:

git clone git@:Farama-Foundation/
git clone git@:Farama-Foundation/
git clone git@:denisyarats/

Then cd into each cloned directory and run pip install -e . there.
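Once the three editable installs finish, a quick import check shows which of them (if any) still fails. A typical cause is mujoco_py not finding MuJoCo. In this sketch, the module names dm_control, metaworld and d4rl are assumed to be the import names of the three cloned packages:

```python
"""Try importing the freshly installed environment libraries (a sketch)."""
import importlib


def try_imports(module_names):
    """Map each module name to 'ok' or the error raised while importing it."""
    results = {}
    for name in module_names:
        try:
            importlib.import_module(name)
            results[name] = "ok"
        except Exception as exc:  # mujoco_py-based imports can fail with non-ImportError errors
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    # Import names assumed for the three packages cloned above.
    for name, status in try_imports(["dm_control", "metaworld", "d4rl"]).items():
        print(f"{name:12s} {status}")
```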

04 test

What I ran was /csmile-1006/PreferenceTransformer. This repository also contains a jax implementation of IQL, so this environment should also be able to run the jax version of IQL.
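Before launching IQL training, a tiny jit-compiled function is a cheap way to check three things: jax imports, it finds the expected backend, and it can compile. On a GPU, this compilation step is where a too-old ptxas tends to show up. This is a sketch of my own, not part of the PreferenceTransformer code:

```python
"""Minimal jax smoke test before running IQL (a sketch)."""


def smoke_test():
    """JIT-compile a tiny function and return (backend name, result)."""
    import jax
    import jax.numpy as jnp

    @jax.jit
    def sum_of_squares(x):
        return jnp.sum(x ** 2)

    # On a GPU backend, XLA compiles this through ptxas, so a ptxas problem
    # surfaces here rather than deep inside a training run.
    value = float(sum_of_squares(jnp.arange(3.0)))
    return jax.default_backend(), value


if __name__ == "__main__":
    try:
        backend, value = smoke_test()
    except ImportError as exc:
        print("jax is not importable:", exc)
    else:
        print("backend:", backend, "| sum of squares of [0, 1, 2]:", value)
```

On a working GPU setup the backend should be reported as gpu. If it reports cpu, jax fell back to the CPU.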

05 Reference versions of various libraries

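The following is the reference environment (the exported conda_env.yml). Note that it was exported from a Python 3.7 environment (python=3.7.13 and py37 builds), not Python 3.8: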

name: jax_env
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - ca-certificates=2023.08.22=h06a4308_0
  - certifi=2022.12.7=py37h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - ncurses=6.3=h7f8727e_2
  - openssl=1.1.1w=h7f8727e_0
  - pip=22.3.1=py37h06a4308_0
  - python=3.7.13=h12debd9_0
  - readline=8.1.2=h7f8727e_1
  - setuptools=65.6.3=py37h06a4308_0
  - sqlite=3.38.5=hc218d9a_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.38.4=py37h06a4308_0
  - xz=5.2.5=h7f8727e_1
  - zlib=1.2.12=h7f8727e_2
  - pip:
    - absl-py==1.4.0
    - appdirs==1.4.4
    - beautifulsoup4==4.12.2
    - cffi==1.15.1
    - charset-normalizer==3.2.0
    - chex==0.1.5
    - click==8.1.7
    - cloudpickle==2.2.1
    - colorama==0.4.6
    - commonmark==0.9.1
    - contextlib2==21.6.0
    - cycler==0.11.0
    - cython==3.0.2
    - decorator==5.1.1
    - distrax==0.1.3
    - dm-control==1.0.13
    - dm-env==1.6
    - dm-tree==0.1.8
    - docker-pycreds==0.4.0
    - etils==0.9.0
    - fasteners==0.18
    - filelock==3.12.2
    - flax==0.6.0
    - fonttools==4.38.0
    - fsspec==2023.1.0
    - future==0.18.3
    - gast==0.5.4
    - gdown==4.7.1
    - gitdb==4.0.10
    - gitpython==3.1.36
    - glfw==2.6.2
    - gym==0.15.7
    - gym-notices==0.0.8
    - h5py==3.8.0
    - huggingface-hub==0.16.4
    - idna==3.4
    - imageio==2.31.2
    - imageio-ffmpeg==0.4.9
    - importlib-metadata==6.7.0
    - importlib-resources==5.12.0
    - jax==0.3.24
    - jaxlib==0.3.24+cuda11.cudnn82
    - joblib==1.3.2
    - kiwisolver==1.4.5
    - labmaze==1.0.6
    - lxml==4.9.3
    - matplotlib==3.5.3
    - ml-collections==0.1.1
    - msgpack==1.0.5
    - mujoco==2.3.6
    - mujoco-py==2.0.2.13
    - numpy==1.21.6
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - opt-einsum==3.3.0
    - optax==0.1.4
    - packaging==23.1
    - pathtools==0.1.2
    - pillow==9.5.0
    - protobuf==3.20.1
    - psutil==5.9.5
    - pybullet==3.2.5
    - pycparser==2.21
    - pyglet==1.5.0
    - pygments==2.16.1
    - pyopengl==3.1.7
    - pyparsing==3.1.1
    - pysocks==1.7.1
    - python-dateutil==2.8.2
    - pyyaml==6.0.1
    - regex==2023.8.8
    - requests==2.31.0
    - rich==11.2.0
    - safetensors==0.3.3
    - scikit-learn==1.0.2
    - scipy==1.7.3
    - sentry-sdk==1.31.0
    - setproctitle==1.3.2
    - six==1.16.0
    - smmap==5.0.1
    - soupsieve==2.4.1
    - tensorboardx==2.1
    - tensorflow-probability==0.19.0
    - termcolor==2.3.0
    - threadpoolctl==3.1.0
    - tokenizers==0.13.3
    - toolz==0.12.0
    - torch==1.13.1
    - tqdm==4.66.1
    - transformers==4.30.2
    - typing-extensions==4.7.1
    - ujson==5.7.0
    - urllib3==2.0.4
    - wandb==0.15.10
    - zipp==3.15.0
prefix: /home/user_name/miniconda3/envs/jax
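To reinstall only the pip part of this reference environment, the pinned entries under pip: can be turned into requirements lines. To stay dependency-free, this sketch parses the file with a regular expression rather than a YAML library. Note that jaxlib==0.3.24+cuda11.cudnn82 is not on PyPI. Installing the resulting list therefore still needs the -f jax_cuda_releases.html option from section 02:

```python
"""Turn the pip: section of an exported conda_env.yml into requirements lines (a sketch)."""
import re
import sys

# Matches pip entries such as "    - jaxlib==0.3.24+cuda11.cudnn82".
PIN = re.compile(r"^\s*-\s+([A-Za-z0-9_.\-]+==\S+)\s*$")


def pip_requirements(yml_text):
    """Return the 'name==version' strings listed under the '- pip:' key."""
    reqs, in_pip = [], False
    for line in yml_text.splitlines():
        stripped = line.strip()
        if stripped == "- pip:":
            in_pip = True
        elif in_pip:
            match = PIN.match(line)
            if match:
                reqs.append(match.group(1))
            elif stripped and not stripped.startswith("-"):
                in_pip = False  # e.g. the trailing "prefix:" line ends the section
    return reqs


if __name__ == "__main__" and len(sys.argv) > 1:
    with open(sys.argv[1]) as f:  # path to the exported yml file
        print("\n".join(pip_requirements(f.read())))
```

Usage (the script name is hypothetical): python yml_to_requirements.py conda_env.yml > requirements.txt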