A three-period consumption-savings model with two regimes:
Working life (ages 25 and 45): The agent chooses whether to work and how much to consume. A simple tax-and-transfer system guarantees a consumption floor. Savings earn interest.
Retirement (age 65): Terminal regime. The agent consumes out of remaining wealth.
Model
An agent lives for three periods (ages 25, 45, and 65). In the first two periods (working life), the agent chooses whether to work, $d_t \in \{0, 1\}$, and how much to consume, $c_t$. In the final period (retirement), the agent consumes out of remaining wealth.
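Working life (ages 25 and 45):

$$V_t(w_t) = \max_{d_t \in \{0, 1\},\; c_t} \left\{ \frac{c_t^{1-\sigma}}{1-\sigma} - \phi \, d_t + \beta \, V_{t+1}(w_{t+1}) \right\}$$

subject to

$$e_t = d_t \cdot \omega$$

$$\tau(e_t, w_t) = \begin{cases} \theta \, (e_t - \underline{c}) & \text{if } e_t \geq \underline{c} \\ \min(0,\; w_t + e_t - \underline{c}) & \text{otherwise} \end{cases}$$

$$w_t^{\text{end}} = w_t + e_t - \tau(e_t, w_t) - c_t \geq 0$$

$$w_{t+1} = (1 + r) \, w_t^{\text{end}}$$

where $w_t$ is wealth, $e_t$ earnings, $\omega$ the wage, $\underline{c}$ a consumption floor guaranteed by transfers, $\theta$ the tax rate, and $w_t^{\text{end}}$ end-of-period wealth. Further, $\sigma$ is the coefficient of relative risk aversion, $\phi$ the disutility of work, $\beta$ the discount factor, and $r$ the interest rate. The transfer only kicks in when the agent’s resources ($w_t + e_t$) fall below the consumption floor.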
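Retirement (age 65, terminal):

$$V_T(w_T) = \frac{w_T^{1-\sigma}}{1-\sigma}$$

where $w_T$ is wealth at retirement and $\sigma$ is the coefficient of relative risk aversion.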
from pprint import pprint
import jax.numpy as jnp
import numpy as np
import pandas as pd
import plotly.express as px
from lcm import (
AgeGrid,
DiscreteGrid,
LinSpacedGrid,
LogSpacedGrid,
Model,
Regime,
categorical,
)
from lcm.typing import (
BoolND,
ContinuousAction,
ContinuousState,
DiscreteAction,
FloatND,
ScalarInt,
)Detecting C-based callable <method-wrapper '__call__' of jaxlib._jax.PjitFunction object at 0x71cb9b881f90> isomorphism...
Categorical Variables
@categorical(ordered=True)
class LaborSupply:
    """Ordered categorical for the discrete labor-supply action.

    Used as the grid of the ``labor_supply`` action in the working-life regime
    and compared against in ``utility`` and ``earnings``.
    """

    # The agent does not work this period (earnings are zero).
    do_not_work: ScalarInt
    # The agent works this period (earns the wage, incurs disutility of work).
    work: ScalarInt
@categorical(ordered=False)
class RegimeId:
    """Unordered categorical enumerating the model's regimes.

    ``next_regime`` returns one of these members, and the class is handed to
    ``Model`` as ``regime_id_class``.
    """

    # Regime in which the agent chooses labor supply and consumption.
    working_life: ScalarInt
    # Terminal regime in which the agent consumes out of remaining wealth.
    retirement: ScalarInt


# Model Functions
# Utility
def utility(
    consumption: ContinuousAction,
    labor_supply: DiscreteAction,
    disutility_of_work: float,
    risk_aversion: float,
) -> FloatND:
    """Return working-life flow utility.

    The result is the power utility of consumption,
    ``consumption ** (1 - risk_aversion) / (1 - risk_aversion)``, minus
    ``disutility_of_work`` whenever ``labor_supply`` equals ``LaborSupply.work``.

    Args:
        consumption: Consumption chosen this period.
        labor_supply: Labor-supply choice, compared against ``LaborSupply.work``.
        disutility_of_work: Utility cost subtracted when the agent works.
        risk_aversion: Curvature parameter of the power utility function.

    Returns:
        Flow utility of the (consumption, labor supply) choice.
    """
    curvature = 1 - risk_aversion
    consumption_utility = consumption**curvature / curvature
    is_working = labor_supply == LaborSupply.work
    return consumption_utility - disutility_of_work * is_working
def utility_retirement(wealth: ContinuousState, risk_aversion: float) -> FloatND:
    """Return utility in the terminal retirement regime.

    Applies the same power-utility form as working life, but directly to
    wealth: ``wealth ** (1 - risk_aversion) / (1 - risk_aversion)``.

    Args:
        wealth: Wealth held at retirement.
        risk_aversion: Curvature parameter of the power utility function.

    Returns:
        Utility derived from ``wealth``.
    """
    curvature = 1 - risk_aversion
    return wealth**curvature / curvature
# Auxiliary functions
def earnings(labor_supply: DiscreteAction, wage: float) -> FloatND:
    """Return labor earnings: ``wage`` if the agent works, ``0.0`` otherwise.

    Args:
        labor_supply: Labor-supply choice, compared against ``LaborSupply.work``.
        wage: Earnings received when working.

    Returns:
        Earnings implied by the labor-supply choice.
    """
    is_working = labor_supply == LaborSupply.work
    return jnp.where(is_working, wage, 0.0)
def taxes_transfers(
    earnings: FloatND,
    wealth: ContinuousState,
    consumption_floor: float,
    tax_rate: float,
) -> FloatND:
    """Return net taxes (positive) or transfers (non-positive).

    If earnings reach the consumption floor, the agent pays ``tax_rate`` on the
    part of earnings above the floor. Otherwise the result is
    ``min(0, wealth + earnings - consumption_floor)``: zero when resources
    already cover the floor, and the (negative) shortfall when they do not.

    Args:
        earnings: Labor earnings this period.
        wealth: Wealth at the start of the period.
        consumption_floor: Consumption level guaranteed by transfers.
        tax_rate: Rate applied to earnings above the consumption floor.

    Returns:
        Tax owed, or the non-positive shortfall of resources below the floor.
    """
    earns_at_least_floor = earnings >= consumption_floor
    tax_owed = tax_rate * (earnings - consumption_floor)
    resource_shortfall = jnp.minimum(0.0, wealth + earnings - consumption_floor)
    return jnp.where(earns_at_least_floor, tax_owed, resource_shortfall)
def end_of_period_wealth(
    wealth: ContinuousState,
    earnings: FloatND,
    taxes_transfers: FloatND,
    consumption: ContinuousAction,
) -> FloatND:
    """Return wealth left after taxes/transfers and consumption.

    Args:
        wealth: Wealth at the start of the period.
        earnings: Labor earnings this period.
        taxes_transfers: Net taxes; subtracted, so a negative value adds
            resources.
        consumption: Consumption chosen this period.

    Returns:
        ``wealth + earnings - taxes_transfers - consumption``.
    """
    resources_after_taxes = wealth + earnings - taxes_transfers
    return resources_after_taxes - consumption
# State transition
def next_wealth(end_of_period_wealth: FloatND, interest_rate: float) -> ContinuousState:
    """Return next period's wealth: savings grown by the interest rate.

    Args:
        end_of_period_wealth: Wealth remaining at the end of this period.
        interest_rate: Net interest rate earned on savings.

    Returns:
        ``(1 + interest_rate) * end_of_period_wealth``.
    """
    gross_return = 1 + interest_rate
    return gross_return * end_of_period_wealth
# Constraints
def borrowing_constraint_working(end_of_period_wealth: FloatND) -> BoolND:
    """Return whether the no-borrowing constraint holds.

    Args:
        end_of_period_wealth: Wealth remaining at the end of the period.

    Returns:
        True where end-of-period wealth is non-negative.
    """
    return 0 <= end_of_period_wealth
# Regime transition
def next_regime(age: float, last_working_age: float) -> ScalarInt:
    """Return the regime the agent is in next period.

    Args:
        age: Current age.
        last_working_age: Age of the final working-life period.

    Returns:
        ``RegimeId.retirement`` once ``age`` has reached ``last_working_age``,
        ``RegimeId.working_life`` before that.
    """
    is_last_working_period = age >= last_working_age
    return jnp.where(
        is_last_working_period, RegimeId.retirement, RegimeId.working_life
    )


# Regimes and Model
# Ages 25, 45 and 65 in 20-year steps; the last grid point is retirement.
age_grid = AgeGrid(start=25, stop=65, step="20Y")
retirement_age = age_grid.exact_values[-1]

# Working life: active at every age before retirement. The agent chooses
# labor supply and consumption; wealth evolves via ``next_wealth`` and
# end-of-period wealth must stay non-negative.
working_life = Regime(
    transition=next_regime,
    active=lambda age: age < retirement_age,
    states={
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25),
    },
    state_transitions={
        "wealth": next_wealth,
    },
    actions={
        "labor_supply": DiscreteGrid(LaborSupply),
        "consumption": LogSpacedGrid(start=4, stop=50, n_points=100),
    },
    functions={
        "utility": utility,
        "earnings": earnings,
        "taxes_transfers": taxes_transfers,
        "end_of_period_wealth": end_of_period_wealth,
    },
    constraints={
        "borrowing_constraint_working": borrowing_constraint_working,
    },
)

# Retirement: terminal regime (no transition, no actions); utility depends on
# wealth only.
retirement = Regime(
    transition=None,
    active=lambda age: age >= retirement_age,
    states={
        "wealth": LinSpacedGrid(start=0, stop=50, n_points=25),
    },
    functions={"utility": utility_retirement},
)

model = Model(
    regimes={
        "working_life": working_life,
        "retirement": retirement,
    },
    ages=age_grid,
    regime_id_class=RegimeId,
    description="A tiny three-period consumption-savings model.",
)


# Parameters
Use model.get_params_template() to see what parameters the model expects, organized
by regime and function.
# Show the parameters the model expects, organized by regime and function.
pprint(dict(model.get_params_template()))
# Output:
# {'retirement': {'utility': {'risk_aversion': 'float'}},
#  'working_life': {'H': {'discount_factor': 'FloatND'},
#                   'borrowing_constraint_working': {},
#                   'earnings': {'wage': 'float'},
#                   'end_of_period_wealth': {},
#                   'next_regime': {'last_working_age': 'float'},
#                   'next_wealth': {'interest_rate': 'float'},
#                   'taxes_transfers': {'consumption_floor': 'float',
#                                       'tax_rate': 'float'},
#                   'utility': {'disutility_of_work': 'float',
#                               'risk_aversion': 'float'}}}
Parameters shared across regimes (risk_aversion, discount_factor,
interest_rate) can be specified at the model level. Parameters unique to one
regime go under the regime name.
params = {
    # Specified at the model level rather than under a regime name.
    "discount_factor": 0.95,
    "risk_aversion": 1.5,
    "interest_rate": 0.03,
    # Parameters unique to the working-life regime, keyed by function name.
    "working_life": {
        "utility": {"disutility_of_work": 1.0},
        "earnings": {"wage": 20.0},
        "taxes_transfers": {"consumption_floor": 2.0, "tax_rate": 0.2},
        # The second-to-last age on the grid is the final working period.
        "next_regime": {"last_working_age": age_grid.exact_values[-2]},
    },
}


# Simulate
n_agents = 100

# Every agent starts in the working-life regime at the first age on the grid,
# with initial wealth spread evenly between 1 and 20.
#
# The starting regime must be given in a column named "regime_name"; pylcm
# raises ``ValueError: DataFrame must contain a 'regime_name' column.``
# otherwise. The traceback recorded below was produced while this column was
# still called "regime".
initial_df = pd.DataFrame(
    {
        "regime_name": "working_life",
        "age": float(age_grid.exact_values[0]),
        "wealth": np.linspace(1, 20, n_agents),
    }
)

result = model.simulate(
    params=params,
    initial_conditions=initial_df,
    period_to_regime_to_V_arr=None,
    log_level="debug",
)
# ---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[7], line 11
7 "wealth": np.linspace(1, 20, n_agents),
8 }
9 )
10
---> 11 result = model.simulate(
12 params=params,
13 initial_conditions=initial_df,
14 period_to_regime_to_V_arr=None,
File <@beartype(lcm.model.Model.simulate) at 0x71cb9ad231c0>:242, in simulate(__beartype_object_104192909090976, __beartype_object_125119301232192, __beartype_get_violation, __beartype_conf, __beartype_object_104192959655360, __beartype_object_125119302468928, __beartype_object_125119382526576, __beartype_object_125119301027776, __beartype_object_125120974336816, __beartype_object_125120982585952, __beartype_object_104192634617408, __beartype_object_104192634602840, __beartype_object_125119299442368, __beartype_object_125119289434624, __beartype_object_104192972598160, __beartype_check_meta, __beartype_func, *args, **kwargs)
File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/361/src/lcm/model.py:482, in Model.simulate(self, params, initial_conditions, period_to_regime_to_V_arr, log_level, seed, log_path, log_keep_n_latest, max_compilation_workers)
480 log = get_logger(log_level=log_level)
481 if isinstance(initial_conditions, pd.DataFrame):
--> 482 initial_conditions = initial_conditions_from_dataframe(
483 df=initial_conditions,
484 user_regimes=self.user_regimes,
485 regime_names_to_ids=self.regime_names_to_ids,
486 )
487 initial_conditions = canonicalize_initial_conditions(
488 initial_conditions=initial_conditions,
489 regimes=self._regimes,
490 )
491 flat_params = self._process_params(params)
File <@beartype(_lcm.pandas_utils.initial_conditions_from_dataframe) at 0x71cb9b82dc70>:74, in initial_conditions_from_dataframe(__beartype_object_104192959655360, __beartype_get_violation, __beartype_conf, __beartype_object_104192909090976, __beartype_object_104192971091968, __beartype_object_104192634362016, __beartype_object_125119305598848, __beartype_object_125119382526576, __beartype_check_meta, __beartype_func, *args, **kwargs)
File ~/checkouts/readthedocs.org/user_builds/pylcm/checkouts/361/src/_lcm/pandas_utils.py:79, in initial_conditions_from_dataframe(df, user_regimes, regime_names_to_ids)
77 if "regime_name" not in df.columns:
78 msg = "DataFrame must contain a 'regime_name' column."
---> 79 raise ValueError(msg)
81 if len(df) == 0:
82 msg = "DataFrame must not be empty."
# (final line of the traceback above)
# ValueError: DataFrame must contain a 'regime_name' column.

# Collect the simulation output, including all additional targets.
df = result.to_dataframe(additional_targets="all")
df["age"] = df["age"].astype(int)

# The retirement regime has no consumption action: the agent consumes out of
# remaining wealth, so report wealth as consumption in that period.
is_retired = df["age"] == retirement_age
df.loc[is_retired, "consumption"] = df.loc[is_retired, "wealth"]

# NOTE(review): the simulate input requires a 'regime_name' column; confirm
# that the output column is still called 'regime'.
columns = [
    "regime",
    "labor_supply",
    "consumption",
    "wealth",
    "earnings",
    "taxes_transfers",
    "end_of_period_wealth",
    "value",
]
df.set_index(["subject_id", "age"])[columns].head(20).style.format(
    precision=1,
    na_rep="",
)
# Classify agents by work pattern across the two working-life periods
first_working_age = age_grid.exact_values[0]
last_working_age = age_grid.exact_values[-2]

# One row per agent and one column per working-life age, holding the
# labor-supply choice made at that age.
df_working = df[df["regime"] == "working_life"]
work_by_age = df_working.pivot_table(
    index="subject_id", columns="age", values="labor_supply", aggfunc="first"
)

# Join the two choices into a single "<early>, <late>" label per agent.
early_choice = work_by_age[first_working_age].astype(str)
late_choice = work_by_age[last_working_age].astype(str)
work_pattern = early_choice + ", " + late_choice

assert "work, work" not in work_pattern.to_numpy(), (
    "Plotting assumes that no agent works in both periods of working life."
)

# Each remaining work pattern is labelled with an initial-wealth group.
label_map = {
    "work, do_not_work": "low",  # work early, not later
    "do_not_work, work": "medium",  # coast early, work later
    "do_not_work, do_not_work": "high",  # never work
}
groups = work_pattern.map(label_map).rename("initial_wealth")
# Combined descriptives and work decisions table
# Range of initial wealth within each group.
initial_wealth = df[df["age"] == first_working_age].set_index("subject_id")["wealth"]
group_desc = initial_wealth.groupby(groups).agg(["min", "max"]).round(1)

# Attach each agent's group to every row, then average numeric columns by
# group and age (``df_mean`` is reused by the plots below).
df_groups = df.copy()
df_groups["initial_wealth"] = df_groups["subject_id"].map(groups)
df_mean = df_groups.groupby(["initial_wealth", "age"], as_index=False).mean(
    numeric_only=True,
)

# A group "works" at an age when its mean earnings there are positive.
work_table = df_mean[df_mean["age"] < retirement_age].pivot_table(
    index="initial_wealth",
    columns="age",
    values="earnings",
)
work_table = (work_table > 0).astype(int)
work_table.columns = [f"works {c}" for c in work_table.columns]

summary = pd.concat([group_desc, work_table], axis=1)
summary.index.name = "initial_wealth"
summary.loc[["low", "medium", "high"]].style.format(precision=1, na_rep="")
# Mean consumption over the life cycle, one line per initial-wealth group.
fig = px.line(
    df_mean,
    x="age",
    y="consumption",
    color="initial_wealth",
    title="Consumption by Age",
    template="plotly_dark",
)
fig.show()
# Mean wealth over the life cycle, one line per initial-wealth group.
fig = px.line(
    df_mean,
    x="age",
    y="wealth",
    color="initial_wealth",
    title="Wealth by Age",
    template="plotly_dark",
)
fig.show()