
Commit

Merge pull request #48 from Farama-Foundation/feature/add-uniform-weights

Feature/add uniform weights generation to MOMAPPO
ffelten authored Mar 12, 2024
2 parents 9a287c6 + 8d46452 commit 7c3c1fa
Showing 2 changed files with 32 additions and 10 deletions.
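
Both scripts previously picked training weights exclusively through optimistic linear support (OLS); this change adds a second mode that trains on a fixed, uniformly spread set of weight vectors instead. As a reading aid (not part of the diff), a minimal sketch of what the new "uniform" mode iterates over, using the same morl_baselines helper the patch imports; reward_dim = 3 and num_weights = 5 are illustrative values only:

    from morl_baselines.common.weights import equally_spaced_weights

    reward_dim, num_weights = 3, 5  # illustrative; the scripts take these from the environment and --num-weights
    all_weights = equally_spaced_weights(reward_dim, num_weights)
    for weight_number, w in enumerate(all_weights, start=1):
        # each w is a length-reward_dim weight vector on the simplex (non-negative entries summing to 1)
        print(weight_number, w)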
21 changes: 16 additions & 5 deletions momaland/learning/cooperative_momappo/continuous_momappo.py
@@ -14,6 +14,7 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import morl_baselines.common.weights
 import numpy as np
 import optax
 import wandb
@@ -68,9 +69,10 @@ def parse_args():
     )
 
     # Algorithm specific arguments
-    parser.add_argument("--num-weights", type=int, default=5, help="Number of different weights to train on")
+    parser.add_argument("--weights-generation", type=str, default="OLS", help="The method to generate the weights - 'OLS' or 'uniform'")
     parser.add_argument("--num-steps-per-epoch", type=int, default=1280, help="the number of steps per epoch (higher batch size should be better)")
     parser.add_argument("--timesteps-per-weight", type=int, default=10e5, help="timesteps per weight vector")
+    parser.add_argument("--num-weights", type=int, default=5, help="Number of different weights to train on")
     parser.add_argument("--update-epochs", type=int, default=2, help="the number epochs to update the policy")
     parser.add_argument("--num-minibatches", type=int, default=2, help="the number of minibatches (keep small in MARL)")
     parser.add_argument("--gamma", type=float, default=0.99,
@@ -557,7 +559,7 @@ def _env_step(runner_state):
     exp_name = args.exp_name
     args_dict = vars(args)
     args_dict["algo"] = exp_name
-    run_name = f"{args.env_id}__{exp_name}__{args.seed}__{int(time.time())}"
+    run_name = f"{args.env_id}__{exp_name}({args.weights_generation})__{args.seed}__{int(time.time())}"
     wandb.init(
         project=args.wandb_project,
         entity=args.wandb_entity,
@@ -569,8 +571,14 @@ def _env_step(runner_state):
     ols = LinearSupport(num_objectives=reward_dim, epsilon=0.0, verbose=args.debug)
     weight_number = 1
     value = []
-    w = ols.next_weight()
-    while not ols.ended() and weight_number <= args.num_weights:
+    if args.weights_generation == "OLS":
+        w = ols.next_weight()
+    elif args.weights_generation == "uniform":
+        all_weights = morl_baselines.common.weights.equally_spaced_weights(reward_dim, args.num_weights)
+        w = all_weights[weight_number - 1]
+    else:
+        raise ValueError("Weights generation method not recognized")
+    while (args.weights_generation != "OLS" or not ols.ended()) and weight_number <= args.num_weights:
         out = train(args, env, w, rng)
         actor_state = out["runner_state"][0]
         _, disc_vec_return = policy_evaluation_mo(
@@ -589,8 +597,11 @@
         )
         if args.save_policies:
             save_actor(actor_state, w, args)
-        w = ols.next_weight()
         weight_number += 1
+        if args.weights_generation == "OLS":
+            w = ols.next_weight()
+        elif args.weights_generation == "uniform":
+            w = all_weights[weight_number - 1]
 
     env.close()
     wandb.finish()
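
The reworked while-condition above makes the loop mode-aware: with --weights-generation uniform the OLS stopping test is bypassed and the loop runs for exactly --num-weights training runs, whereas with OLS it can also stop early once the linear support has no corner weight left to propose. A small sketch of just that predicate (names are stand-ins, not part of the diff; ols_ended mirrors ols.ended()):

    def keep_training(weights_generation, ols_ended, weight_number, num_weights):
        # Mirrors the new loop condition: uniform mode ignores the OLS termination flag,
        # OLS mode stops once the support is exhausted; both respect the weight budget.
        return (weights_generation != "OLS" or not ols_ended) and weight_number <= num_weights

    assert keep_training("uniform", True, 3, 5)       # uniform: OLS state is irrelevant
    assert not keep_training("OLS", True, 3, 5)       # OLS: stop when the support is exhausted
    assert not keep_training("uniform", False, 6, 5)  # both modes: stop after num_weights weights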
21 changes: 16 additions & 5 deletions momaland/learning/cooperative_momappo/discrete_momappo.py
@@ -14,6 +14,7 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import morl_baselines.common.weights
 import numpy as np
 import optax
 import wandb
@@ -68,10 +69,11 @@ def parse_args():
     )
 
     # Algorithm specific arguments
-    parser.add_argument("--num-weights", type=int, default=10, help="the number of different weights to train on")
+    parser.add_argument("--weights-generation", type=str, default="OLS", help="The method to generate the weights - 'OLS' or 'uniform'")
     parser.add_argument("--num-steps-per-epoch", type=int, default=128, help="the number of steps per epoch (higher batch size should be better)")
     parser.add_argument("--timesteps-per-weight", type=int, default=2e3,
                         help="timesteps per weight vector")
+    parser.add_argument("--num-weights", type=int, default=10, help="the number of different weights to train on")
     parser.add_argument("--update-epochs", type=int, default=2, help="the number epochs to update the policy")
     parser.add_argument("--num-minibatches", type=int, default=2, help="the number of minibatches (keep small in MARL)")
     parser.add_argument("--gamma", type=float, default=0.99,
@@ -553,7 +555,7 @@ def _env_step(runner_state):
     exp_name = args.exp_name
     args_dict = vars(args)
     args_dict["algo"] = exp_name
-    run_name = f"{args.env_id}__{exp_name}__{args.seed}__{int(time.time())}"
+    run_name = f"{args.env_id}__{exp_name}({args.weights_generation})__{args.seed}__{int(time.time())}"
     wandb.init(
         project=args.wandb_project,
         entity=args.wandb_entity,
@@ -565,8 +567,14 @@ def _env_step(runner_state):
     ols = LinearSupport(num_objectives=reward_dim, epsilon=0.0, verbose=args.debug)
     weight_number = 1
     value = []
-    w = ols.next_weight()
-    while not ols.ended() and weight_number <= args.num_weights:
+    if args.weights_generation == "OLS":
+        w = ols.next_weight()
+    elif args.weights_generation == "uniform":
+        all_weights = morl_baselines.common.weights.equally_spaced_weights(reward_dim, args.num_weights)
+        w = all_weights[weight_number - 1]
+    else:
+        raise ValueError("Weights generation method not recognized")
+    while (args.weights_generation != "OLS" or not ols.ended()) and weight_number <= args.num_weights:
         out = train(args, env, w, rng)
         actor_state = out["runner_state"][0]
         _, disc_vec_return = policy_evaluation_mo(
@@ -585,8 +593,11 @@
         )
         if args.save_policies:
             save_actor(actor_state, w, args)
-        w = ols.next_weight()
         weight_number += 1
+        if args.weights_generation == "OLS":
+            w = ols.next_weight()
+        elif args.weights_generation == "uniform":
+            w = all_weights[weight_number - 1]
 
     env.close()
     wandb.finish()
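
For reference, the new flag would be exercised roughly as follows, e.g. on the discrete variant (a hypothetical invocation: the environment id and all remaining arguments are assumed to keep the script's defaults, so treat this as a sketch rather than a tested command line):

    python momaland/learning/cooperative_momappo/discrete_momappo.py --weights-generation uniform --num-weights 10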
