From 589cd06acb00faf331e084e1d62901f3d7f545a4 Mon Sep 17 00:00:00 2001
From: Hendrik Vater
Date: Wed, 10 Jul 2024 15:30:33 +0200
Subject: [PATCH 1/5] added fluid tank implementation as a new environment

---
Review note: the line structure of this series was rebuilt from a copy whose
newlines and indentation had been collapsed. Added lines are reproduced
token-for-token; indentation and the position of blank context lines in hunks
against pre-existing files are inferred from the hunk line counts. Run
`git apply --check` against the target tree before relying on it.

 exciting_environments/fluid_tank/__init__.py  |   1 +
 .../fluid_tank/fluid_tank_env.py              | 163 ++++++++++++++++++
 exciting_environments/registration.py         |   4 +
 3 files changed, 168 insertions(+)
 create mode 100644 exciting_environments/fluid_tank/__init__.py
 create mode 100644 exciting_environments/fluid_tank/fluid_tank_env.py

diff --git a/exciting_environments/fluid_tank/__init__.py b/exciting_environments/fluid_tank/__init__.py
new file mode 100644
index 0000000..9a39fbc
--- /dev/null
+++ b/exciting_environments/fluid_tank/__init__.py
@@ -0,0 +1 @@
+from .fluid_tank_env import FluidTank
diff --git a/exciting_environments/fluid_tank/fluid_tank_env.py b/exciting_environments/fluid_tank/fluid_tank_env.py
new file mode 100644
index 0000000..ed325f8
--- /dev/null
+++ b/exciting_environments/fluid_tank/fluid_tank_env.py
@@ -0,0 +1,163 @@
+from functools import partial
+from typing import Callable
+
+import numpy as np
+import jax
+import jax.numpy as jnp
+from jax.tree_util import tree_flatten, tree_structure
+import jax_dataclasses as jdc
+import chex
+import diffrax
+
+from exciting_environments import core_env
+
+
+class FluidTank(core_env.CoreEnvironment):
+    """Fluid tank based on torricelli's principle.
+
+    Based on ex. 7.3.2 on p. 355 of "System Dynamics" from Palm, William III.
+    """
+
+    def __init__(
+        self,
+        batch_size: float = 1,
+        physical_constraints: dict = None,
+        action_constraints: dict = None,
+        static_params: dict = None,
+        solver=diffrax.Euler(),
+        reward_func: Callable = None,
+        tau: float = 1e-3,
+    ):
+        if not physical_constraints:
+            physical_constraints = {"height": 3}
+
+        if not action_constraints:
+            action_constraints = {"inflow": 0.2}
+
+        if not static_params:
+            # c_d = 0.6 typical value for water [Palm2010]
+            static_params = {"base_area": jnp.pi, "orifice_area": jnp.pi * 0.1**2, "c_d": 0.6, "g": 9.81}
+
+        physical_constraints = self.PhysicalState(**physical_constraints)
+        action_constraints = self.Action(**action_constraints)
+        static_params = self.StaticParams(**static_params)
+
+        super().__init__(
+            batch_size,
+            physical_constraints,
+            action_constraints,
+            static_params,
+            tau=tau,
+            solver=solver,
+            reward_func=reward_func,
+        )
+
+    @jdc.pytree_dataclass
+    class PhysicalState:
+        """Dataclass containing the physical state of the environment."""
+
+        height: jax.Array
+
+    @jdc.pytree_dataclass
+    class Optional:
+        """Dataclass containing additional information for simulation."""
+
+        something: jax.Array
+
+    @jdc.pytree_dataclass
+    class StaticParams:
+        """Dataclass containing the static parameters of the environment."""
+
+        base_area: jax.Array
+        orifice_area: jax.Array
+        c_d: jax.Array
+        g: jax.Array
+
+    @jdc.pytree_dataclass
+    class Action:
+        """Dataclass containing the action, that can be applied to the environment."""
+
+        inflow: jax.Array
+
+    @partial(jax.jit, static_argnums=0)
+    def _ode_solver_step(self, state, action, static_params):
+        physical_state = state.physical_state
+
+        action = action / jnp.array(tree_flatten(self.env_properties.action_constraints)[0]).T
+        action = (action + 1) / 2
+        action = action * jnp.array(tree_flatten(self.env_properties.action_constraints)[0]).T
+
+        args = (action, static_params)
+
+        def vector_field(t, y, args):
+            h = y[0]
+            inflow, params = args
+
+            dh_dt = inflow[0] / params.base_area - params.c_d * params.orifice_area / params.base_area * jnp.sqrt(
+                2 * params.g * h
+            )
+            return (dh_dt,)
+
+        term = diffrax.ODETerm(vector_field)
+        t0 = 0
+        t1 = self.tau
+        y0 = (physical_state.height,)
+
+        env_state = self._solver.init(term, t0, t1, y0, args)
+        y, _, _, env_state, _ = self._solver.step(term, t0, t1, y0, args, env_state, made_jump=False)
+
+        h_k1 = y[0]
+        phys = self.PhysicalState(height=h_k1)
+        opt = None  # Optional(something=...)
+        return self.State(physical_state=phys, PRNGKey=None, optional=None)
+
+    @partial(jax.jit, static_argnums=0)
+    def default_reward_func(self, obs, action, action_constraints):
+        return 0
+
+    @partial(jax.jit, static_argnums=0)
+    def generate_observation(self, states, physical_constraints):
+        return (states.physical_state.height - physical_constraints.height / 2) / (physical_constraints.height / 2)
+
+    @partial(jax.jit, static_argnums=0)
+    def generate_truncated(self, states, physical_constraints):
+        return 0
+
+    @partial(jax.jit, static_argnums=0)
+    def generate_terminated(self, states, reward):
+        return False
+
+    @property
+    def obs_description(self):
+        return self.states_description
+
+    @property
+    def states_description(self):
+        return np.array(["fluid height"])
+
+    @property
+    def action_description(self):
+        return np.array(["inflow"])
+
+    @partial(jax.jit, static_argnums=0)
+    def init_state(self):
+        phys = self.PhysicalState(height=jnp.full(self.batch_size, self.env_properties.physical_constraints.height / 2))
+        opt = None  # self.Optional(something=jnp.zeros(self.batch_size))
+        return self.State(physical_state=phys, PRNGKey=None, optional=opt)
+
+    def reset(self, rng: jax.random.PRNGKey = None, initial_state: jdc.pytree_dataclass = None):
+        if initial_state is not None:
+            assert tree_structure(self.init_state()) == tree_structure(
+                initial_state
+            ), f"initial_state should have the same dataclass structure as self.init_state()"
+            state = initial_state
+        else:
+            state = self.init_state()
+
+        obs = jax.vmap(
+            self.generate_observation,
+            in_axes=(0, self.in_axes_env_properties.physical_constraints),
+        )(state, self.env_properties.physical_constraints)
+
+        # TODO: this [None] looks off -> investigate
+        return obs[None], state
diff --git a/exciting_environments/registration.py b/exciting_environments/registration.py
index 185eb76..b307860 100644
--- a/exciting_environments/registration.py
+++ b/exciting_environments/registration.py
@@ -1,6 +1,7 @@
 from .cart_pole import CartPole
 from .mass_spring_damper import MassSpringDamper
 from .pendulum import Pendulum
+from .fluid_tank import FluidTank
 
 
 def make(env_id: str, **env_kwargs):
@@ -13,6 +14,9 @@ def make(env_id: str, **env_kwargs):
     elif env_id == "Pendulum-v0":
         env = Pendulum(**env_kwargs)
 
+    elif env_id == "FluidTank-v0":
+        env = FluidTank(**env_kwargs)
+
     else:
         print(f"No existing environments got env_id ={env_id}")
         env = None

From c85b33245ab0056bedb65e28e29997fb68b1d9e5 Mon Sep 17 00:00:00 2001
From: Hendrik Vater
Date: Mon, 15 Jul 2024 17:18:01 +0200
Subject: [PATCH 2/5] added state clipping because the env crashes when h<0

---
 exciting_environments/fluid_tank/fluid_tank_env.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/exciting_environments/fluid_tank/fluid_tank_env.py b/exciting_environments/fluid_tank/fluid_tank_env.py
index ed325f8..afb28cc 100644
--- a/exciting_environments/fluid_tank/fluid_tank_env.py
+++ b/exciting_environments/fluid_tank/fluid_tank_env.py
@@ -107,6 +107,11 @@ def vector_field(t, y, args):
         y, _, _, env_state, _ = self._solver.step(term, t0, t1, y0, args, env_state, made_jump=False)
 
         h_k1 = y[0]
+
+        # clip to 0 because tank cannot be more empty than empty
+        # necessary because of ODE solver approximation
+        h_k1 = jnp.clip(h_k1, min=0)
+
         phys = self.PhysicalState(height=h_k1)
         opt = None  # Optional(something=...)
         return self.State(physical_state=phys, PRNGKey=None, optional=None)

From 3663f4f410c3abc01b47dc0f099722b4a81a2841 Mon Sep 17 00:00:00 2001
From: Hendrik Vater
Date: Wed, 7 Aug 2024 11:16:16 +0200
Subject: [PATCH 3/5] fix for fluid tank nan issues

---
 exciting_environments/fluid_tank/fluid_tank_env.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/exciting_environments/fluid_tank/fluid_tank_env.py b/exciting_environments/fluid_tank/fluid_tank_env.py
index afb28cc..e56c067 100644
--- a/exciting_environments/fluid_tank/fluid_tank_env.py
+++ b/exciting_environments/fluid_tank/fluid_tank_env.py
@@ -93,6 +93,8 @@ def vector_field(t, y, args):
             h = y[0]
             inflow, params = args
 
+            h = jnp.clip(h, min=0)
+
             dh_dt = inflow[0] / params.base_area - params.c_d * params.orifice_area / params.base_area * jnp.sqrt(
                 2 * params.g * h
             )

From 7ac46294353c0693a6065df37e465fcd7edef863 Mon Sep 17 00:00:00 2001
From: Hendrik Vater
Date: Wed, 7 Aug 2024 11:16:31 +0200
Subject: [PATCH 4/5] updated default cartpole parameters

---
Review note: the original hunk assigned the new default as
`static_params = ({...},)`, i.e. a one-element tuple holding the dict. The
neighbouring constraint dicts are `**`-unpacked into their dataclasses, and a
tuple cannot be unpacked that way (presumably `static_params` is handled the
same just below the visible context; confirm). The wrapper is removed here.
Hunk counts, following hunk offsets and the diffstat are adjusted to match;
the post-image blob id on the `index` line predates this correction.

 .../cart_pole/cart_pole_env.py                | 34 ++++++++++--------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/exciting_environments/cart_pole/cart_pole_env.py b/exciting_environments/cart_pole/cart_pole_env.py
index a83e2e0..17a776c 100644
--- a/exciting_environments/cart_pole/cart_pole_env.py
+++ b/exciting_environments/cart_pole/cart_pole_env.py
@@ -52,7 +52,7 @@ def __init__(
         static_params: dict = None,
         solver=diffrax.Euler(),
         reward_func: Callable = None,
-        tau: float = 1e-4,
+        tau: float = 2e-2,
     ):
         """
         Args:
@@ -81,23 +81,23 @@ def __init__(
 
         if not physical_constraints:
             physical_constraints = {
-                "deflection": 10,
-                "velocity": 10,
+                "deflection": 2.4,
+                "velocity": 8,
                 "theta": jnp.pi,
-                "omega": 10,
+                "omega": 8,
             }
         if not action_constraints:
             action_constraints = {"force": 20}
 
         if not static_params:
-            static_params = {
-                "mu_p": 0,
-                "mu_c": 0,
-                "l": 1,
-                "m_p": 1,
-                "m_c": 1,
-                "g": 9.81,
-            }
+            static_params = {  # typical values from Source with DOI: 10.1109/TSMC.1983.6313077
+                "mu_p": 0.000002,
+                "mu_c": 0.0005,
+                "l": 0.5,
+                "m_p": 0.1,
+                "m_c": 1,
+                "g": 9.81,
+            }
 
         physical_constraints = self.PhysicalState(**physical_constraints)
         action_constraints = self.Action(**action_constraints)
@@ -149,6 +149,8 @@ class Action:
     def _ode_solver_step(self, state, action, static_params):
         """Computes state by simulating one step.
 
+        Source DOI: 10.1109/TSMC.1983.6313077
+
         Args:
             state: The state from which to calculate state for the next step.
             action: The action to apply to the environment.
@@ -271,9 +273,7 @@ def vector_field(t, y, args):
         # keep theta between -pi and pi
         theta_t = ((theta_t + jnp.pi) % (2 * jnp.pi)) - jnp.pi
 
-        physical_states = self.PhysicalState(
-            deflection=deflection_t, velocity=velocity_t, theta=theta_t, omega=omega_t
-        )
+        physical_states = self.PhysicalState(deflection=deflection_t, velocity=velocity_t, theta=theta_t, omega=omega_t)
         opt = None
         PRNGKey = None
         return self.State(physical_state=physical_states, PRNGKey=PRNGKey, optional=opt)
@@ -315,6 +315,10 @@ def generate_observation(self, state, physical_constraints):
         )
         return obs
 
+    @property
+    def action_description(self):
+        return np.array(["force"])
+
     @property
     def obs_description(self):
         return np.array(["deflection", "velocity", "theta", "omega"])

From a84815c121b6fafbada3560e9b321d417208177c Mon Sep 17 00:00:00 2001
From: Hendrik Vater
Date: Wed, 7 Aug 2024 11:16:48 +0200
Subject: [PATCH 5/5] added action description to pendulum env

---
 exciting_environments/pendulum/pendulum_env.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/exciting_environments/pendulum/pendulum_env.py b/exciting_environments/pendulum/pendulum_env.py
index 69ed2cd..556430b 100644
--- a/exciting_environments/pendulum/pendulum_env.py
+++ b/exciting_environments/pendulum/pendulum_env.py
@@ -232,6 +232,10 @@ def generate_observation(self, state, physical_constraints):
     def obs_description(self):
         return np.array(["theta", "omega"])
 
+    @property
+    def action_description(self):
+        return np.array(["torque"])
+
     @partial(jax.jit, static_argnums=0)
     def generate_truncated(self, state, physical_constraints):
         """Returns truncated information for one batch."""