# SMORe

Official JAX code base for the ICLR 2024 paper *SMORe: Score Models for Offline Goal-Conditioned Reinforcement Learning*.

Harshit Sikchi¹, Rohan Chitnis², Ahmed Touati², Alborz Geramifard², Amy Zhang¹,², Scott Niekum³

¹UT Austin, ²Meta AI, ³UMass Amherst


Paper

## How to run the code

### Install dependencies

Create an empty conda environment and follow the commands below.

```shell
conda create -n smore python=3.9
conda activate smore
conda install -c conda-forge cudnn

pip install --upgrade pip

# Install ONE of the JAX builds below, depending on your CUDA version.
## 1. CUDA 12 installation
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
## 2. CUDA 11 installation
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

pip install -r requirements.txt
```
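After installing, a quick sanity check confirms that JAX found the intended accelerator and that XLA compilation works end to end. This snippet is a convenience sketch, not part of the repo:

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see; expect a GPU entry if the CUDA wheels
# installed correctly (JAX falls back to CPU otherwise).
print(jax.devices())

# Jit-compile a tiny function to confirm XLA works end to end.
@jax.jit
def squared_sum(x):
    return jnp.sum(x * x)

print(squared_sum(jnp.arange(4.0)))  # 0 + 1 + 4 + 9 = 14.0
```

If `jax.devices()` shows only CPU devices despite a working CUDA install, re-check that the `jax[cuda...]` wheel matches your local CUDA version.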

## Offline data

The offline datasets can be downloaded from the Google Drive link WGCSL offline data. These datasets are provided by the prior work WGCSL. Extract the offline data into `root-folder/offline_data/*`.
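A small helper can verify that the datasets were extracted to the expected location before launching training. This is a hypothetical convenience function, not part of the repo; only the `offline_data` directory name comes from the instructions above:

```python
from pathlib import Path


def check_offline_data(root="."):
    """Return the offline_data directory under `root`, or raise with a hint.

    Hypothetical helper: the directory name matches the README, but the
    function itself is not part of the code base.
    """
    data_dir = Path(root) / "offline_data"
    if not data_dir.is_dir():
        raise FileNotFoundError(
            f"Expected extracted datasets under {data_dir}; "
            "download the WGCSL offline data and extract it there."
        )
    return data_dir
```

Calling `check_offline_data()` from the repo root fails fast with an actionable message instead of a mid-training file error.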

## Example training code

### Locomotion

```shell
python train_offline_smore.py --double=True --env_name=halfcheetah-medium-v2 --config=configs/gcrl_config.py --eval_episodes=10 --eval_interval=5000 --beta=0.8 --loss_type=<'smore_stable'/'smore'> --exp_name=<exp_name>
```

### Manipulation

```shell
python train_offline_smore.py --double=True --env_name=SawyerReach --config=configs/gcrl_config.py --eval_episodes=10 --eval_interval=5000 --beta=0.8 --loss_type=<'smore_stable'/'smore'> --exp_name=<exp_name>
```
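When sweeping over the two loss variants (`smore_stable` and `smore`) or multiple environments, it can help to assemble the command line programmatically. The flags below match the examples above; the helper function itself is a hypothetical sketch, not part of the repo:

```python
import shlex


def build_train_cmd(env_name, loss_type, exp_name, beta=0.8):
    """Assemble a training command with the flags used in the examples above.

    Hypothetical convenience helper; it only builds the command string and
    does not launch training.
    """
    args = [
        "python", "train_offline_smore.py",
        "--double=True",
        f"--env_name={env_name}",
        "--config=configs/gcrl_config.py",
        "--eval_episodes=10",
        "--eval_interval=5000",
        f"--beta={beta}",
        f"--loss_type={loss_type}",
        f"--exp_name={exp_name}",
    ]
    return shlex.join(args)


# Print the commands for both loss variants on the manipulation task.
for loss in ("smore_stable", "smore"):
    print(build_train_cmd("SawyerReach", loss, f"sawyer_{loss}"))
```

The printed strings can be pasted into a shell or fed to a job scheduler; `shlex.join` quotes any argument that needs it.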

## Acknowledgement and Reference

This code base builds upon the following code bases: Extreme Q-learning and Implicit Q-Learning.