Skip to content

Commit

Permalink
add experimental exponential support
Browse files Browse the repository at this point in the history
  • Loading branch information
gboehl committed Oct 17, 2024
1 parent 958d700 commit a5c797d
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 21 deletions.
9 changes: 6 additions & 3 deletions econpizza/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .solvers.solve_linear_state_space import solve_linear_state_space, find_path_linear_state_space
from .solvers.shooting import find_path_shooting
from .parser import parse, load
from .config import config


# set number of cores for XLA
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
Expand Down Expand Up @@ -59,8 +59,11 @@ def get_distributions(self, trajectory, init_dist=None, shock=None, pars=None):

dist0 = jnp.array(init_dist) if init_dist is not None else jnp.array(
self['steady_state'].get('distributions'))
pars = jnp.array(list(self['pars'].values())
) if pars is None else pars
if self.get('exp_all'):
pars = jnp.log(jnp.array(list(self['pars'].values())) if pars is None else pars)
trajectory = jnp.log(trajectory)
else:
pars = jnp.array(list(self['pars'].values())) if pars is None else pars
shocks = self.get("shocks") or ()
dist_names = list(self['distributions'].keys())
decisions_inputs = self['decisions']['inputs']
Expand Down
10 changes: 9 additions & 1 deletion econpizza/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,10 @@ def load(
evars, par_names, shocks, decisions_inputs, decisions_outputs, model['decisions']['calls'])
_define_function(model['func_strings']
['func_backw'], model['context'])
if model.get('exp_all'):
model['context']['func_backw'] = lambda xl,xc,xp,XSS,WFPrime,shocks,pars: model['context']['func_backw_raw'](jnp.exp(xl), jnp.exp(xc), jnp.exp(xp), jnp.exp(XSS), WFPrime, shocks, jnp.exp(pars))
else:
model['context']['func_backw'] = model['context']['func_backw_raw']
else:
decisions_outputs = []
decisions_inputs = []
Expand All @@ -352,7 +356,11 @@ def load(
'aux_equations'), shocks=shocks, distributions=dist_names, decisions_outputs=decisions_outputs)

# writing to tempfiles helps to get nice debug traces if the model does not work
_define_function(model['func_strings']["func_eqns"], model['context'])
_define_function(model['func_strings']['func_eqns'], model['context'])
if model.get('exp_all'):
model['context']['func_eqns'] = lambda xl,xc,xp,XSS,shocks,pars,distributions,decisions_outputs: model['context']['func_eqns_raw'](jnp.exp(xl), jnp.exp(xc), jnp.exp(xp), jnp.exp(XSS), shocks, jnp.exp(pars), distributions, decisions_outputs)
else:
model['context']['func_eqns'] = model['context']['func_eqns_raw']
# compile fixed and initial values
stst_inputs = compile_stst_inputs(model)
# try if function works on initvals
Expand Down
2 changes: 1 addition & 1 deletion econpizza/parser/build_generic_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def build_aggr_het_agent_funcs(model, zpars, nvars, stst, zshocks, horizon):

shocks = model.get("shocks") or ()
# get functions
func_eqns = model['context']["func_eqns"]
func_eqns = model['context']['func_eqns']
func_backw = model['context'].get('func_backw')
func_forw = model['context'].get('func_forw')

Expand Down
4 changes: 2 additions & 2 deletions econpizza/parser/compile_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def compile_backw_func_str(evars, par, shocks, inputs, outputs, calls):
if isinstance(calls, str):
calls = calls.splitlines()

func_str = f"""def func_backw(XLag, X, XPrime, XSS, WFPrime, shocks, pars):
func_str = f"""def func_backw_raw(XLag, X, XPrime, XSS, WFPrime, shocks, pars):
{compile_func_basics_str(evars, par, shocks)}
\n ({"".join(v + ", " for v in inputs)}) = WFPrime
\n %s
Expand Down Expand Up @@ -108,7 +108,7 @@ def compile_eqn_func_str(evars, eqns, par, eqns_aux, shocks, distributions, deci
eqns_stack = "\n ".join(eqns)

# compile the final function string
func_str = f"""def func_eqns(XLag, X, XPrime, XSS, shocks, pars, distributions=[], decisions_outputs=[]):
func_str = f"""def func_eqns_raw(XLag, X, XPrime, XSS, shocks, pars, distributions=[], decisions_outputs=[]):
{compile_func_basics_str(evars, par, shocks)}
\n ({"".join(d+', ' for d in distributions)}) = distributions
\n ({"".join(d+', ' for d in decisions_outputs)}) = decisions_outputs
Expand Down
18 changes: 14 additions & 4 deletions econpizza/solvers/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,20 @@ def find_path_stacking(
'jac_factorized') else False

# get variables
stst = d2jnp(self["stst"])
nvars = len(self["var_names"])
pars = d2jnp(pars if pars is not None else self["pars"])
if self.get('exp_all'):
stst = jnp.log(d2jnp(self["stst"]))
pars = jnp.log(d2jnp(pars if pars is not None else self["pars"]))
else:
stst = d2jnp(self["stst"])
pars = d2jnp(pars if pars is not None else self["pars"])
shocks = self.get("shocks") or ()

# get initial guess
x0 = jnp.array(list(init_state)) if init_state is not None else stst
if self.get('exp_all'):
x0 = jnp.log(jnp.array(list(init_state))) if init_state is not None else stst
else:
x0 = jnp.array(list(init_state)) if init_state is not None else stst
init_dist = init_dist if init_dist is not None else self['steady_state'].get(
'distributions')
dist0 = jnp.array(init_dist if init_dist is not None else jnp.nan)
Expand Down Expand Up @@ -154,4 +161,7 @@ def find_path_stacking(
elif verbose:
print(mess)

return x_out, (flag, f)
if self.get('exp_all'):
return jnp.exp(x_out), (flag, f)
else:
return x_out, (flag, f)
34 changes: 24 additions & 10 deletions econpizza/solvers/steady_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,24 @@ def solve_stst(self, tol=1e-8, maxit=15, tol_backwards=None, maxit_backwards=200
'decisions_output')

# get the actual steady state function
func_stst = get_func_stst(func_backw, func_forw_stst, func_eqns, shocks, wf_init, decisions_output_init, fixed_values=d2jnp(
fixed_vals), pre_stst_mapping=pre_stst_mapping, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards)
if self.get('exp_all'):
func_stst = get_func_stst(func_backw, func_forw_stst, func_eqns, shocks, wf_init, decisions_output_init, fixed_values=jnp.log(d2jnp(fixed_vals)), pre_stst_mapping=pre_stst_mapping, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards)
else:
func_stst = get_func_stst(func_backw, func_forw_stst, func_eqns, shocks, wf_init, decisions_output_init, fixed_values=d2jnp(fixed_vals), pre_stst_mapping=pre_stst_mapping, tol_backw=tol_backwards, maxit_backw=maxit_backwards, tol_forw=tol_forwards, maxit_forw=maxit_forwards)
# store jitted stst function that returns jacobian and func. value
self["context"]['func_stst'] = func_stst

if not self['steady_state'].get('skip'):
# actual root finding
res = newton_jax(func_stst, d2jnp(init_vals), maxit, tol,
solver=solver, verbose=verbose, **newton_kwargs)
if self.get('exp_all'):
res = newton_jax(func_stst, jnp.log(d2jnp(init_vals)), maxit, tol, solver=solver, verbose=verbose, **newton_kwargs)
else:
res = newton_jax(func_stst, d2jnp(init_vals), maxit, tol, solver=solver, verbose=verbose, **newton_kwargs)
else:
f, jac, aux = func_stst(d2jnp(init_vals))
if self.get('exp_all'):
f, jac, aux = func_stst(jnp.log(d2jnp(init_vals)))
else:
f, jac, aux = func_stst(d2jnp(init_vals))
res = {'x': d2jnp(init_vals),
'fun': f,
'jac': jac,
Expand All @@ -125,15 +132,22 @@ def solve_stst(self, tol=1e-8, maxit=15, tol_backwards=None, maxit_backwards=200
}

# exchange those values that are identified via stst_equations
stst_vals, par_vals = func_pre_stst(
res['x'], d2jnp(fixed_vals), pre_stst_mapping)
if self.get('exp_all'):
stst_vals, par_vals = func_pre_stst(res['x'], jnp.log(d2jnp(fixed_vals)), pre_stst_mapping)
else:
stst_vals, par_vals = func_pre_stst(res['x'], d2jnp(fixed_vals), pre_stst_mapping)
res['initial_values'] = {'guesses': init_vals, 'fixed': fixed_vals, 'value_functions': wf_init, 'decisions': decisions_output_init}

# store results
self['steady_state']['root_finding_result'] = res
self['steady_state']['found_values'] = dict(zip(init_vals.keys(),res['x']))
self['stst'] = self['steady_state']['all_values'] = dict(zip(evars, stst_vals))
self['pars'] = dict(zip(par_names, par_vals))
if self.get('exp_all'):
self['steady_state']['found_values'] = dict(zip(init_vals.keys(),jnp.exp(res['x'])))
self['stst'] = self['steady_state']['all_values'] = dict(zip(evars, jnp.exp(stst_vals)))
self['pars'] = dict(zip(par_names, jnp.exp(par_vals)))
else:
self['steady_state']['found_values'] = dict(zip(init_vals.keys(),res['x']))
self['stst'] = self['steady_state']['all_values'] = dict(zip(evars, stst_vals))
self['pars'] = dict(zip(par_names, par_vals))

# calculate dist objects and compile message
mess = _get_stst_dist_objs(self, res, maxit_backwards,
Expand Down

0 comments on commit a5c797d

Please sign in to comment.