Skip to content
This repository was archived by the owner on Sep 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ passing `experiment=your-name`. The logger will also save a file called

Beyond the override defaults, You can also change other configuration options,
such as the type of dynamics model
(e.g., `dynamics_model=basic_ensemble`), or the number of models in the ensemble
(e.g., `dynamics_model.model.ensemble_size=some-number`). To learn more about
(e.g., `dynamics_model=basic_ensemble`), or the number of models in the ensemble via configuration override
(e.g., `++dynamics_model.model.ensemble_size=some-number`). To learn more about
all the available options, take a look at the provided
[configuration files](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/examples/conf).

Expand Down
5 changes: 4 additions & 1 deletion mbrl/algorithms/mbpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ def train(

mbrl.planning.complete_agent_cfg(env, cfg.algorithm.agent)
agent = SACAgent(
cast(pytorch_sac_pranz24.SAC, hydra.utils.instantiate(cfg.algorithm.agent))
cast(
pytorch_sac_pranz24.SAC,
hydra.utils.instantiate(cfg.algorithm.agent, _recursive_=False),
)
)

work_dir = work_dir or os.getcwd()
Expand Down
2 changes: 1 addition & 1 deletion mbrl/algorithms/planet.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def train(

# Create PlaNet model
cfg.dynamics_model.action_size = env.action_space.shape[0]
planet = hydra.utils.instantiate(cfg.dynamics_model)
planet = hydra.utils.instantiate(cfg.dynamics_model, _recursive_=False)
assert isinstance(planet, mbrl.models.PlaNetModel)
model_env = ModelEnv(env, planet, no_termination, generator=rng)
trainer = ModelTrainer(planet, logger=logger, optim_lr=1e-3, optim_eps=1e-4)
Expand Down
6 changes: 3 additions & 3 deletions mbrl/env/pets_reacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ def get_EE_pos(self, states):
z = -np.sin(hinge) * np.cos(roll) * perp_all_axis
new_rot_axis = x + y + z
new_rot_perp_axis = np.cross(new_rot_axis, rot_axis)
new_rot_perp_axis[
np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30
] = rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30]
new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = (
rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30]
)
new_rot_perp_axis /= np.linalg.norm(
new_rot_perp_axis, axis=1, keepdims=True
)
Expand Down
13 changes: 7 additions & 6 deletions mbrl/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@ The config files are generally structured in 4 groups:
* `action_optimizer`: describes possible optimizers to use for action selections. Some algorithms,
like MBPO, ignore this.

For example, to run MBPO on `gym`'s Hopper environment using the standard ensemble version of
For example, to run MBPO on `gym`'s cartpole environment using the standard ensemble version of
[GaussianMLP](https://github.com/facebookresearch/mbrl-lib/blob/main/mbrl/models/gaussian_mlp.py),
you can type

```bash
python -m mbrl.examples.main \
algorithm=mbpo \
overrides=mbpo_hopper \
overrides=mbpo_cartpole \
dynamics_model=gaussian_mlp_ensemble \
algorithm.agent.batch_size=256 \
overrides.validation_ratio=0.2 \
dynamics_model.activation_fn_cfg._target_=torch.nn.ReLU
++device=cpu \
++overrides.sac_batch_size=256 \
++overrides.validation_ratio=0.2 \
++dynamics_model.activation_fn_cfg._target_=torch.nn.ReLU
```
where we have re-written some defaults, just to show how `hydra` command line syntax
works. The number of possible options is extensive, and the best way to explore would be to
Expand All @@ -52,4 +53,4 @@ inside a folder whose path looks like
you can change the root directory (`./exp`) by passing
`root_dir=path-to-your-dir`, and the experiment sub-folder (`default`) by
passing `experiment=your-name`. The logger will also save a file called
`model_train.csv` with training information for the dynamics model.
`model_train.csv` with training information for the dynamics model.
2 changes: 1 addition & 1 deletion mbrl/examples/conf/action_optimizer/cem.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_


_target_: mbrl.planning.CEMOptimizer
num_iterations: ${overrides.cem_num_iters}
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/action_optimizer/icem.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_


_target_: mbrl.planning.ICEMOptimizer
num_iterations: ${overrides.cem_num_iters}
Expand Down
4 changes: 2 additions & 2 deletions mbrl/examples/conf/action_optimizer/mppi.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_


_target_: mbrl.planning.MPPIOptimizer
num_iterations: ${overrides.mppi_num_iters}
Expand All @@ -8,4 +8,4 @@ sigma: ${overrides.mppi_sigma}
beta: ${overrides.mppi_beta}
lower_bound: ???
upper_bound: ???
device: ${device}
device: ${device}
4 changes: 2 additions & 2 deletions mbrl/examples/conf/algorithm/mbpo.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

name: "mbpo"

normalize: true
Expand Down Expand Up @@ -34,4 +34,4 @@ agent:
target_entropy: ${overrides.sac_target_entropy}
hidden_size: ${overrides.sac_hidden_size}
device: ${device}
lr: ${overrides.sac_lr}
lr: ${overrides.sac_lr}
2 changes: 1 addition & 1 deletion mbrl/examples/conf/algorithm/pets.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

name: "pets"

agent:
Expand Down
4 changes: 2 additions & 2 deletions mbrl/examples/conf/algorithm/planet.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

name: "planet"

agent:
Expand All @@ -15,4 +15,4 @@ num_initial_trajectories: 5
action_noise_std: 0.3
test_frequency: 25
num_episodes: 1000
dataset_size: 1000000
dataset_size: 1000000
2 changes: 1 addition & 1 deletion mbrl/examples/conf/dynamics_model/basic_ensemble.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

_target_: mbrl.models.BasicEnsemble
ensemble_size: 5
device: ${device}
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/dynamics_model/gaussian_mlp.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

_target_: mbrl.models.GaussianMLP
device: ${device}
num_layers: 4
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

_target_: mbrl.models.GaussianMLP
device: ${device}
num_layers: 4
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/dynamics_model/planet.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

_target_: mbrl.models.PlaNetModel
obs_shape: [3, 64, 64]
obs_encoding_size: 1024
Expand Down
1 change: 1 addition & 0 deletions mbrl/examples/conf/main.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
defaults:
- _self_
- algorithm: pets
- dynamics_model: gaussian_mlp_ensemble
- overrides: pets_cartpole
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/mbpo_ant.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "ant_truncated_obs"
# term_fn is set automatically by mbrl.util.env.EnvHandler.make_env

Expand Down
4 changes: 2 additions & 2 deletions mbrl/examples/conf/overrides/mbpo_cartpole.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "cartpole_continuous"
trial_length: 200

Expand Down Expand Up @@ -26,4 +26,4 @@ sac_automatic_entropy_tuning: true
sac_target_entropy: -0.05
sac_hidden_size: 256
sac_lr: 0.0003
sac_batch_size: 256
sac_batch_size: 256
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/mbpo_halfcheetah.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "gym___HalfCheetah-v2"
term_fn: "no_termination"

Expand Down
4 changes: 2 additions & 2 deletions mbrl/examples/conf/overrides/mbpo_hopper.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "gym___Hopper-v4"
term_fn: "hopper"

Expand Down Expand Up @@ -26,4 +26,4 @@ sac_automatic_entropy_tuning: false
sac_target_entropy: 1 # ignored, since entropy tuning is false
sac_hidden_size: 512
sac_lr: 0.0003
sac_batch_size: 256
sac_batch_size: 256
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/mbpo_humanoid.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "humanoid_truncated_obs"
# term_fn is set automatically by mbrl.util.env.EnvHandler.make_env

Expand Down
4 changes: 2 additions & 2 deletions mbrl/examples/conf/overrides/mbpo_inv_pendulum.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "gym___InvertedPendulum-v4"
term_fn: "inverted_pendulum"

Expand Down Expand Up @@ -26,4 +26,4 @@ sac_automatic_entropy_tuning: true
sac_hidden_size: 256
sac_lr: 0.0003
sac_batch_size: 256
sac_target_entropy: -1
sac_target_entropy: -1
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/mbpo_pusher.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "pets_pusher"
term_fn: "no_termination"
trial_length: 150
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/mbpo_walker.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "gym___Walker2d-v4"
term_fn: "walker2d"

Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/pb_mbpo_inv_pendulum.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "pybulletgym___InvertedPendulumMuJoCoEnv-v0"

term_fn: "inverted_pendulum"
Expand Down
4 changes: 2 additions & 2 deletions mbrl/examples/conf/overrides/pets_cartpole.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "cartpole_continuous"
learned_rewards: false
trial_length: 200
Expand All @@ -18,4 +18,4 @@ cem_num_iters: 5
cem_elite_ratio: 0.1
cem_population_size: 350
cem_alpha: 0.1
cem_clipped_normal: false
cem_clipped_normal: false
4 changes: 2 additions & 2 deletions mbrl/examples/conf/overrides/pets_cartpole_paper_version.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "cartpole_pets_version"

# Note: This pre-process function requires setting model input manually
Expand All @@ -23,4 +23,4 @@ cem_num_iters: 5
cem_elite_ratio: 0.1
cem_population_size: 500
cem_alpha: 0.1
cem_clipped_normal: false
cem_clipped_normal: false
4 changes: 2 additions & 2 deletions mbrl/examples/conf/overrides/pets_halfcheetah.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "pets_halfcheetah"
term_fn: "no_termination"
obs_process_fn: mbrl.env.pets_halfcheetah.HalfCheetahEnv.preprocess_fn
Expand All @@ -21,4 +21,4 @@ cem_num_iters: 5
cem_elite_ratio: 0.16
cem_population_size: 400
cem_alpha: 0.12
cem_clipped_normal: false
cem_clipped_normal: false
4 changes: 2 additions & 2 deletions mbrl/examples/conf/overrides/pets_hopper.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "gym___Hopper-v4"
term_fn: "hopper"
learned_rewards: true
Expand All @@ -19,4 +19,4 @@ cem_num_iters: 5
cem_elite_ratio: 0.1
cem_population_size: 350
cem_alpha: 0.1
cem_clipped_normal: false
cem_clipped_normal: false
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/pets_icem_cartpole.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "cartpole_continuous"
learned_rewards: false
trial_length: 200
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/pets_inv_pendulum.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "gym___InvertedPendulum-v4"
term_fn: "inverted_pendulum"
learned_rewards: true
Expand Down
4 changes: 2 additions & 2 deletions mbrl/examples/conf/overrides/pets_mppi_halfcheetah.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "pets_halfcheetah"
term_fn: "no_termination"
obs_process_fn: mbrl.env.pets_halfcheetah.HalfCheetahEnv.preprocess_fn
Expand All @@ -21,4 +21,4 @@ mppi_num_iters: 5
mppi_population_size: 350
mppi_gamma: 0.9
mppi_sigma: 1.0
mppi_beta: 0.9
mppi_beta: 0.9
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/pets_pusher.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "pets_pusher"
term_fn: "no_termination"
learned_rewards: true
Expand Down
4 changes: 2 additions & 2 deletions mbrl/examples/conf/overrides/pets_reacher.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "pets_reacher"
learned_rewards: true
num_steps: 15000
Expand All @@ -19,4 +19,4 @@ cem_num_iters: 5
cem_elite_ratio: 0.1
cem_population_size: 350
cem_alpha: 0.1
cem_clipped_normal: false
cem_clipped_normal: false
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/planet_cartpole_balance.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "dmcontrol_cartpole_balance" # used to set the hydra dir, ignored otherwise

env_cfg:
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/planet_cartpole_swingup.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "dmcontrol_cartpole_swingup" # used to set the hydra dir, ignored otherwise

env_cfg:
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/planet_cheetah_run.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "dmcontrol_cheetah_run" # used to set the hydra dir, ignored otherwise

env_cfg:
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/planet_cup_catch.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "dmcontrol_ball_in_cup_catch" # used to set the hydra dir, ignored otherwise

env_cfg:
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/planet_finger_spin.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "dmcontrol_finger_spin" # used to set the hydra dir, ignored otherwise

env_cfg:
Expand Down
2 changes: 1 addition & 1 deletion mbrl/examples/conf/overrides/planet_walker_walk.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @package _group_

env: "dmcontrol_walker_walk" # used to set the hydra dir, ignored otherwise

env_cfg:
Expand Down
4 changes: 2 additions & 2 deletions mbrl/models/basic_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class BasicEnsemble(Ensemble):
device (str or torch.device): the device to use for the model.
member_cfg (omegaconf.DictConfig): the configuration needed to instantiate the models
in the ensemble. They will be instantiated using
`hydra.utils.instantiate(member_cfg)`.
`hydra.utils.instantiate(member_cfg, _recursive_=False)`.
propagation_method (str, optional): the uncertainty propagation method to use (see
above). Defaults to ``None``.
"""
Expand All @@ -71,7 +71,7 @@ def __init__(
)
self.members = []
for i in range(ensemble_size):
model = hydra.utils.instantiate(member_cfg)
model = hydra.utils.instantiate(member_cfg, _recursive_=False)
self.members.append(model)
self.deterministic = self.members[0].deterministic
self.in_size = getattr(self.members[0], "in_size", None)
Expand Down
2 changes: 1 addition & 1 deletion mbrl/models/gaussian_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def create_activation():
else:
# Handle the case where activation_fn_cfg is a dict
cfg = omegaconf.OmegaConf.create(activation_fn_cfg)
activation_func = hydra.utils.instantiate(cfg)
activation_func = hydra.utils.instantiate(cfg, _recursive_=False)
return activation_func

def create_linear_layer(l_in, l_out):
Expand Down
Loading