Skip to content

Commit

Permalink
Merge pull request #45 from mila-iqia/feature/configured-clubs
Browse files Browse the repository at this point in the history
made clubs part of config
  • Loading branch information
tianyu-z authored Aug 16, 2024
2 parents ee22aa6 + 742a77d commit f743956
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
7 changes: 7 additions & 0 deletions rice.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def __init__(
temperature_calibration="base",
prescribed_emissions=None,
pct_reward=False,
clubs_enabled = False,
club_members = []
):
self.action_space_type = action_space_type
self.num_discrete_action_levels = num_discrete_action_levels
Expand All @@ -67,6 +69,11 @@ def __init__(
self.pct_reward = pct_reward
self.global_state = {}

#clubs
self.clubs_enabled = clubs_enabled
if self.clubs_enabled:
self.club_members = club_members

self.set_dtypes()

self.set_all_region_params()
Expand Down
7 changes: 5 additions & 2 deletions scripts/create_submission_zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ def prepare_submission(results_dir=None):
if file.endswith(".state_dict")
]
sorted_policy_models = sorted(policy_models, key=os.path.getmtime)
# Delete all but the last policy model file
for policy_model in sorted_policy_models[:-1]:

#in the case of multi-model, there will be multiple state dictionaries per model.
policy_prefixes = set([model_name.split("/")[-1].split("_")[0]for model_name in sorted_policy_models])
# Delete all but the last policy model file of each unique prefix
for policy_model in sorted_policy_models[:-len(policy_prefixes)]:
os.remove(os.path.join(results_dir_copy, policy_model.split("/")[-1]))

shutil.make_archive(submission_file, "zip", results_dir_copy)
Expand Down
3 changes: 3 additions & 0 deletions scripts/rice_rllib_discrete.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ env:
carbon_model: "base"
temperature_calibration: "base"
pct_reward: False
clubs_enabled: True
club_members: [1]

regions:
num_agents: 3 #can be either {3,7,20,27}
Expand All @@ -52,6 +54,7 @@ logging:

# Policy network settings
policy:
multi_model: True #only active if club_enabled also set to True
regions:
vf_loss_coeff: 0.1 # loss coefficient schedule for the value function loss
entropy_coeff_schedule: # loss coefficient schedule for the entropy loss
Expand Down

0 comments on commit f743956

Please sign in to comment.