Skip to content

Commit

Permalink
Merge pull request #6 from ExcitingSystems/fluid-tank-environment
Browse files Browse the repository at this point in the history
Fluid tank environment
  • Loading branch information
hvater authored Aug 29, 2024
2 parents ad2b3aa + a84815c commit 974eb52
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 15 deletions.
36 changes: 21 additions & 15 deletions exciting_environments/cart_pole/cart_pole_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
static_params: dict = None,
solver=diffrax.Euler(),
reward_func: Callable = None,
tau: float = 1e-4,
tau: float = 2e-2,
):
"""
Args:
Expand Down Expand Up @@ -81,23 +81,25 @@ def __init__(

if not physical_constraints:
physical_constraints = {
"deflection": 10,
"velocity": 10,
"deflection": 2.4,
"velocity": 8,
"theta": jnp.pi,
"omega": 10,
"omega": 8,
}
if not action_constraints:
action_constraints = {"force": 20}

if not static_params:
static_params = {
"mu_p": 0,
"mu_c": 0,
"l": 1,
"m_p": 1,
"m_c": 1,
"g": 9.81,
}
# Default static parameters of the cart-pole.
# Typical values from Source with DOI: 10.1109/TSMC.1983.6313077
# NOTE: this must be a plain dict. The previous form wrapped the literal in
# parentheses with a trailing comma, which made it a one-element tuple
# containing the dict, so it could no longer be unpacked with ``**``.
static_params = {
    "mu_p": 0.000002,
    "mu_c": 0.0005,
    "l": 0.5,
    "m_p": 0.1,
    "m_c": 1,
    "g": 9.81,
}

physical_constraints = self.PhysicalState(**physical_constraints)
action_constraints = self.Action(**action_constraints)
Expand Down Expand Up @@ -149,6 +151,8 @@ class Action:
def _ode_solver_step(self, state, action, static_params):
"""Computes state by simulating one step.
Source DOI: 10.1109/TSMC.1983.6313077
Args:
state: The state from which to calculate state for the next step.
action: The action to apply to the environment.
Expand Down Expand Up @@ -271,9 +275,7 @@ def vector_field(t, y, args):
# keep theta between -pi and pi
theta_t = ((theta_t + jnp.pi) % (2 * jnp.pi)) - jnp.pi

physical_states = self.PhysicalState(
deflection=deflection_t, velocity=velocity_t, theta=theta_t, omega=omega_t
)
physical_states = self.PhysicalState(deflection=deflection_t, velocity=velocity_t, theta=theta_t, omega=omega_t)
opt = None
PRNGKey = None
return self.State(physical_state=physical_states, PRNGKey=PRNGKey, optional=opt)
Expand Down Expand Up @@ -315,6 +317,10 @@ def generate_observation(self, state, physical_constraints):
)
return obs

@property
def action_description(self):
    """Return the names of the action entries as a numpy array of strings."""
    action_names = ["force"]
    return np.array(action_names)

@property
def obs_description(self):
    """Return the names of the observation entries as a numpy array of strings."""
    observation_names = ["deflection", "velocity", "theta", "omega"]
    return np.array(observation_names)
Expand Down
1 change: 1 addition & 0 deletions exciting_environments/fluid_tank/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .fluid_tank_env import FluidTank
170 changes: 170 additions & 0 deletions exciting_environments/fluid_tank/fluid_tank_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from functools import partial
from typing import Callable

import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_structure
import jax_dataclasses as jdc
import chex
import diffrax

from exciting_environments import core_env


class FluidTank(core_env.CoreEnvironment):
    """Fluid tank based on Torricelli's principle.

    Based on ex. 7.3.2 on p. 355 of "System Dynamics" from Palm, William III.

    The single physical state is the fluid height in the tank and the single
    action is the inflow. The simulated dynamics are

        dh/dt = inflow / base_area - c_d * orifice_area / base_area * sqrt(2 * g * h)
    """

    def __init__(
        self,
        batch_size: int = 1,
        physical_constraints: dict = None,
        action_constraints: dict = None,
        static_params: dict = None,
        solver=diffrax.Euler(),
        reward_func: Callable = None,
        tau: float = 1e-3,
    ):
        """
        Args:
            batch_size: Number of parallel environments; used as the length of
                the initial height array in ``init_state``.
            physical_constraints: Dict with the key "height". Defaults to
                ``{"height": 3}`` when empty or None.
            action_constraints: Dict with the key "inflow". Defaults to
                ``{"inflow": 0.2}`` when empty or None.
            static_params: Dict with the keys "base_area", "orifice_area",
                "c_d" and "g". A default parameter set is used when empty or None.
            solver: Solver object forwarded to the base class.
            reward_func: Reward function forwarded to the base class.
            tau: Step length forwarded to the base class; used as the end time
                of a single solver step.
        """
        if not physical_constraints:
            physical_constraints = {"height": 3}

        if not action_constraints:
            action_constraints = {"inflow": 0.2}

        if not static_params:
            # c_d = 0.6 typical value for water [Palm2010]
            static_params = {
                "base_area": jnp.pi,
                "orifice_area": jnp.pi * 0.1**2,
                "c_d": 0.6,
                "g": 9.81,
            }

        # Convert the plain dicts into the pytree dataclasses declared below.
        physical_constraints = self.PhysicalState(**physical_constraints)
        action_constraints = self.Action(**action_constraints)
        static_params = self.StaticParams(**static_params)

        super().__init__(
            batch_size,
            physical_constraints,
            action_constraints,
            static_params,
            tau=tau,
            solver=solver,
            reward_func=reward_func,
        )

    @jdc.pytree_dataclass
    class PhysicalState:
        """Dataclass containing the physical state of the environment."""

        # Fluid height in the tank.
        height: jax.Array

    @jdc.pytree_dataclass
    class Optional:
        """Dataclass containing additional information for simulation."""

        # Placeholder field; this class is not instantiated anywhere in the environment.
        something: jax.Array

    @jdc.pytree_dataclass
    class StaticParams:
        """Dataclass containing the static parameters of the environment."""

        base_area: jax.Array
        orifice_area: jax.Array
        c_d: jax.Array
        g: jax.Array

    @jdc.pytree_dataclass
    class Action:
        """Dataclass containing the action, that can be applied to the environment."""

        inflow: jax.Array

    @partial(jax.jit, static_argnums=0)
    def _ode_solver_step(self, state, action, static_params):
        """Computes the state after simulating one step of length ``self.tau``.

        Args:
            state: The state from which to calculate the state for the next step;
                only ``state.physical_state.height`` is read.
            action: The action to apply to the environment. Its first entry is
                used as the inflow after the remapping described below.
            static_params: Parameter object providing ``base_area``,
                ``orifice_area``, ``c_d`` and ``g``.

        Returns:
            A ``self.State`` holding the new height; ``PRNGKey`` and ``optional``
            are set to None.
        """
        physical_state = state.physical_state

        # Affine remapping of the action: divide by the action limit, shift with
        # (a + 1) / 2 and multiply by the limit again. An input inside
        # [-limit, limit] therefore becomes an inflow inside [0, limit].
        action_limits = jnp.array(tree_flatten(self.env_properties.action_constraints)[0]).T
        action = action / action_limits
        action = (action + 1) / 2
        action = action * action_limits

        args = (action, static_params)

        def vector_field(t, y, args):
            h = y[0]
            inflow, params = args

            # keep the argument of the square root non-negative
            h = jnp.clip(h, min=0)

            dh_dt = inflow[0] / params.base_area - params.c_d * params.orifice_area / params.base_area * jnp.sqrt(
                2 * params.g * h
            )
            return (dh_dt,)

        term = diffrax.ODETerm(vector_field)
        t0 = 0
        t1 = self.tau
        y0 = (physical_state.height,)

        solver_state = self._solver.init(term, t0, t1, y0, args)
        y, _, _, solver_state, _ = self._solver.step(term, t0, t1, y0, args, solver_state, made_jump=False)

        h_k1 = y[0]

        # clip to 0 because tank cannot be more empty than empty
        # necessary because of ODE solver approximation
        h_k1 = jnp.clip(h_k1, min=0)

        phys = self.PhysicalState(height=h_k1)
        return self.State(physical_state=phys, PRNGKey=None, optional=None)

    @partial(jax.jit, static_argnums=0)
    def default_reward_func(self, obs, action, action_constraints):
        """Returns a constant reward of 0; no task-specific reward is defined here."""
        return 0

    @partial(jax.jit, static_argnums=0)
    def generate_observation(self, states, physical_constraints):
        """Returns the height rescaled so that [0, height constraint] maps to [-1, 1]."""
        return (states.physical_state.height - physical_constraints.height / 2) / (physical_constraints.height / 2)

    @partial(jax.jit, static_argnums=0)
    def generate_truncated(self, states, physical_constraints):
        """Returns a constant 0; the inputs are not inspected."""
        return 0

    @partial(jax.jit, static_argnums=0)
    def generate_terminated(self, states, reward):
        """Returns a constant False; the inputs are not inspected."""
        return False

    @property
    def obs_description(self):
        """Names of the observation entries; identical to ``states_description``."""
        return self.states_description

    @property
    def states_description(self):
        """Names of the physical state entries."""
        return np.array(["fluid height"])

    @property
    def action_description(self):
        """Names of the action entries."""
        return np.array(["inflow"])

    @partial(jax.jit, static_argnums=0)
    def init_state(self):
        """Returns the default initial state: every tank is filled to half the height constraint."""
        phys = self.PhysicalState(height=jnp.full(self.batch_size, self.env_properties.physical_constraints.height / 2))
        opt = None  # self.Optional(something=jnp.zeros(self.batch_size))
        return self.State(physical_state=phys, PRNGKey=None, optional=opt)

    def reset(self, rng: jax.random.PRNGKey = None, initial_state: jdc.pytree_dataclass = None):
        """Resets the environment and returns the initial observation and state.

        Args:
            rng: Accepted but not used by this implementation.
            initial_state: Optional state to start from. It must have the same
                pytree structure as ``self.init_state()``. When None, the result
                of ``self.init_state()`` is used.

        Returns:
            Tuple ``(obs, state)`` where ``obs`` is the batched observation with
            an additional leading axis of length one.

        Raises:
            AssertionError: If ``initial_state`` has a different pytree structure
                than ``self.init_state()``.
        """
        if initial_state is not None:
            assert tree_structure(self.init_state()) == tree_structure(
                initial_state
            ), "initial_state should have the same dataclass structure as self.init_state()"
            state = initial_state
        else:
            state = self.init_state()

        obs = jax.vmap(
            self.generate_observation,
            in_axes=(0, self.in_axes_env_properties.physical_constraints),
        )(state, self.env_properties.physical_constraints)

        # TODO: this [None] looks off -> investigate
        return obs[None], state
4 changes: 4 additions & 0 deletions exciting_environments/pendulum/pendulum_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def generate_observation(self, state, physical_constraints):
def obs_description(self):
return np.array(["theta", "omega"])

@property
def action_description(self):
    """Return the names of the action entries as a numpy array of strings."""
    action_names = ["torque"]
    return np.array(action_names)

@partial(jax.jit, static_argnums=0)
def generate_truncated(self, state, physical_constraints):
"""Returns truncated information for one batch."""
Expand Down
4 changes: 4 additions & 0 deletions exciting_environments/registration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .cart_pole import CartPole
from .mass_spring_damper import MassSpringDamper
from .pendulum import Pendulum
from .fluid_tank import FluidTank


def make(env_id: str, **env_kwargs):
Expand All @@ -13,6 +14,9 @@ def make(env_id: str, **env_kwargs):
elif env_id == "Pendulum-v0":
env = Pendulum(**env_kwargs)

elif env_id == "FluidTank-v0":
env = FluidTank(**env_kwargs)

else:
print(f"No existing environments got env_id ={env_id}")
env = None
Expand Down

0 comments on commit 974eb52

Please sign in to comment.