diff --git a/pettingzoo/sisl/multiwalker/multiwalker.py b/pettingzoo/sisl/multiwalker/multiwalker.py index 839b6152e..8edf250d1 100644 --- a/pettingzoo/sisl/multiwalker/multiwalker.py +++ b/pettingzoo/sisl/multiwalker/multiwalker.py @@ -160,6 +160,7 @@ def __init__(self, *args, **kwargs): # spaces self.action_spaces = dict(zip(self.agents, self.env.action_space)) self.observation_spaces = dict(zip(self.agents, self.env.observation_space)) + self.state_space = self.env.state_space self.steps = 0 def observation_space(self, agent): @@ -191,6 +192,9 @@ def close(self): def render(self): return self.env.render() + def state(self): + return self.env.state() + def observe(self, agent): return self.env.observe(self.agent_name_mapping[agent]) diff --git a/pettingzoo/sisl/multiwalker/multiwalker_base.py b/pettingzoo/sisl/multiwalker/multiwalker_base.py index 3da9a1260..476b3950c 100644 --- a/pettingzoo/sisl/multiwalker/multiwalker_base.py +++ b/pettingzoo/sisl/multiwalker/multiwalker_base.py @@ -360,6 +360,14 @@ def setup(self): ] self.observation_space = [agent.observation_space for agent in self.walkers] self.action_space = [agent.action_space for agent in self.walkers] + self.state_space = spaces.Box( + low=-np.float32(np.inf), + high=+np.float32(np.inf), + shape=( + self.n_walkers * 24 + 3, + ), # 24 is the observation space of each walker, 3 is the package observation space + dtype=np.float32, + ) self.package_scale = self.n_walkers / 1.75 self.package_length = PACKAGE_LENGTH / SCALE * self.package_scale @@ -545,6 +553,20 @@ def observe(self, agent): o = np.array(o, dtype=np.float32) return o + def state(self): + all_walker_obs = self.get_last_obs() + all_walker_obs = np.array(list(all_walker_obs.values())).flatten() + package_obs = np.array( + [ + self.package.position.x, + self.package.position.y, + self.package.angle, + ] + ) + global_state = np.concatenate((all_walker_obs, package_obs)).astype(np.float32) + + return global_state + def render(self, close=False): if close: self.close()