Introducing a custom Reinforcement Learning - Sequential Sampling Model (RLSSM) into HSSM¶
Before proceeding with this tutorial, we recommend going through the main RLSSM tutorial (Tutorial for hierarchical Bayesian inference for Reinforcement Learning - Sequential Sampling Models) for general familiarity with RLSSM modeling in HSSM.
This tutorial demonstrates how to incorporate custom, user-defined Reinforcement Learning Sequential Sampling Models (RLSSM) into the HSSM framework by modifying the `rldm.py` file. The tutorial walks through the key steps needed to define, implement, and integrate your custom RLSSM likelihood functions.
More specifically, this tutorial shows how to add a custom model for a 2-armed bandit environment. Our model employs simple Rescorla-Wagner-style learning updates with two learning rates, one for positive and one for negative prediction errors. The decision process is a drift-diffusion model with collapsing bounds (the 'angle' model).
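Concretely, on each trial $t$ the value of the chosen option $c_t$ is updated with a learning rate that depends on the sign of the prediction error:

$$
\delta_t = r_t - Q_t(c_t), \qquad
Q_{t+1}(c_t) = Q_t(c_t) +
\begin{cases}
\alpha^{+}\,\delta_t & \text{if } \delta_t > 0 \\
\alpha^{-}\,\delta_t & \text{otherwise},
\end{cases}
$$

where $r_t$ is the feedback received on trial $t$ and $\alpha^{+}$, $\alpha^{-}$ are the two learning rates. The resulting trial-wise values then drive the decision process, via the `scaler` parameter that maps value differences onto drift rates.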
⚠️ Warning¶
You will have to edit the `rldm.py` file and re-install hssm for this tutorial to work. Running this notebook without changes to the `rldm.py` file will throw errors.
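For example, assuming you have a local clone of the HSSM repository (the path below is a placeholder), an editable install will pick up your edits to `rldm.py`:

!pip install -e /path/to/HSSM  # placeholder path -- point this at your local HSSM clone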
# Import necessary libraries
import numpy as np
import arviz as az
import matplotlib.pyplot as plt
from functools import partial
# Import HSSM and simulator package
import hssm
from hssm.utils import decorate_atomic_simulator
from hssm.likelihoods.rldm import make_rldm_logp_op
from hssm.distribution_utils.dist import make_hssm_rv
from ssms.basic_simulators.simulator import simulator
# Set the style for the plots
plt.style.use('seaborn-v0_8-dark-palette')
Section I: Introduce the likelihood function for your custom model in HSSM¶
In the existing implementation, `rldm.py` already defines the model config (Step 1) and the likelihood function (Step 3) for the model mentioned above. That model is called `rlssm2`. An easy way to follow this tutorial is simply to treat `rlssm2` as the custom model that you want to incorporate into HSSM, and make the changes outlined in Steps 2, 4, and 5 to readily use the implemented model.
Step 1: Define Your Custom RLSSM Model Configuration¶
Location: Add to the `rlssm_model_config_list` dictionary
Purpose: Define the meta-data and parameters of your custom model in a configuration dictionary.
Details:
- Create a new entry in the `rlssm_model_config_list` dictionary with a unique model name
- Specify the following required fields:
  - `name`: Name of your custom RLSSM model
  - `description`: Optional description of the model
  - `n_params`: Number of model parameters
  - `n_extra_fields`: Number of extra_fields columns passed to the likelihood function (e.g., trial, feedback)
  - `list_params`: List of parameter names in the order they'll be passed to the likelihood function
  - `extra_fields`: List of extra_fields columns
  - `decision_model`: Type of the likelihood for the decision process model (typically "LAN" - likelihood approximation networks)
  - `LAN`: Specific LAN model to use (e.g., "angle")
Example:
"my_custom_rlssm": {
"name": "my_custom_rlssm",
"description": "Custom RLSSM with special features",
"n_params": 8,
"n_extra_fields": 3,
"list_params": ["param1", "param2", "param3", ...],
"extra_fields": ["extra_fields1", "extra_fields2", "extra_fields3", ...],
"decision_model": "LAN",
"LAN": "angle",
}
Step 2: Specify Which Model Configuration to Use¶
Location: Update the `MODEL_NAME` variable
Purpose: Inform the HSSM package which model configuration to use from your defined list.
Details:
- Set `MODEL_NAME` to match one of the keys in your `rlssm_model_config_list` dictionary
- The system will automatically load the corresponding configuration
- This makes it easy to switch between different model variants
Example:
MODEL_NAME = "my_custom_rlssm" # Must match a key in rlssm_model_config_list
Step 3: Implement Your Custom Likelihood Function¶
Location: Create a new function following the naming pattern {model_name}_logp_inner_func
Purpose: Define the core computational logic for your RLSSM model. See the existing implementation in `rldm.py` for template details.
Details:
Function signature: Must follow the pattern:
def my_custom_rlssm_logp_inner_func(
    subj,          # Subject index
    ntrials_subj,  # Number of trials per subject
    data,          # RT and response data matrix
    *args,         # Model parameters first, then extra fields, in the
                   # order given by list_params and extra_fields
):
    ...
Input requirements:
- `subj`: Subject index
- `ntrials_subj`: Number of trials per subject
- `data`: RT and response data matrix
- Parameters must match the order specified in `list_params`
- Extra fields must match those specified in `extra_fields`
Output requirements:
- Return a 1D array of log likelihoods for each trial
- Must be differentiable for gradient-based sampling/optimization
Implementation notes:
- Use JAX operations for automatic differentiation
- Handle parameter slicing for individual subjects using `dynamic_slice`
- Implement your specific RL update rules and decision mechanisms (a sketch of the RL update portion follows this list)
- Structure the LAN matrix according to your decision model requirements
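To make these notes concrete, below is a minimal, illustrative sketch of the RL portion of such a function, using `jax.lax.scan` for the trial loop. All names and shapes here are hypothetical; consult the `rlssm2` implementation in `rldm.py` for the actual template, including the `dynamic_slice`-based subject slicing and LAN matrix construction.

import jax
import jax.numpy as jnp

def rw_q_values(responses, feedback, alpha_pos, alpha_neg):
    """Trial-wise Q-values for a 2-armed bandit with separate learning
    rates for positive and negative prediction errors (illustrative)."""

    def update(q, trial):
        choice, reward = trial
        pe = reward - q[choice]                       # prediction error for the chosen arm
        lr = jnp.where(pe > 0, alpha_pos, alpha_neg)  # sign-dependent learning rate
        q = q.at[choice].set(q[choice] + lr * pe)
        return q, q                                   # carry the values and record them

    q0 = jnp.full(2, 0.5)  # neutral initial values
    _, q_history = jax.lax.scan(update, q0, (responses.astype(jnp.int32), feedback))
    return q_history       # shape: (n_trials, 2)

The per-trial drift rate could then be computed as, e.g., `scaler * (q_history[:, 1] - q_history[:, 0])` and stacked with the remaining decision parameters into the input matrix expected by the LAN.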
Step 4: Update the Vectorized Function Reference¶
Location: Modify the `rldm_logp_inner_func_vmapped` assignment
Purpose: Enable parallel computation across multiple subjects.
Details:
- Update the function reference to point to your custom likelihood function
- The vectorization pattern remains the same - only the first argument (subject index) gets vectorized
- Ensure the `total_params` count includes all parameters plus extra fields plus data columns
Example:
rldm_logp_inner_func_vmapped = jax.vmap(
my_custom_rlssm_logp_inner_func, # Update this to your function name
in_axes=[0] + [None] * total_params,
)
Step 5: Update Parameter Unpacking in the Main Likelihood Function¶
Location: Modify the `logp` function within `make_logp_func`
Purpose: Ensure parameters are correctly extracted and passed to your custom function.
Details:
Parameter extraction: Adjust the indexing to match your model's parameter structure:
# Extract extra fields (adjust indices based on your model)
participant_id = dist_params[n_params]  # Usually after all model params
trial = dist_params[n_params + 1]
feedback = dist_params[n_params + 2]
# ... additional extra fields as needed

# Extract model parameters
param1, param2, ..., paramN = dist_params[:MODEL_CONFIG["n_params"]]
Function call: Update the `vec_logp` call to pass parameters in the correct order:

return vec_logp(
    subj,
    n_trials,
    data,
    param1,    # Your specific parameters
    param2,
    # ...
    paramN,
    trial,     # Extra fields
    feedback,
    # ...
)
Step 6: Verify VJP Function Compatibility¶
Location: Check the `vjp_logp` function within `make_vjp_logp_func`
Purpose: Ensure gradient computation works correctly with your parameter structure.
Details:
- The VJP (Vector-Jacobian Product) function should automatically work with your custom model
- The slicing `[1:MODEL_CONFIG["n_params"] + 1]` excludes the data and extra fields from gradient computation
- No changes are typically needed unless you have special gradient requirements (a sketch of the idea follows the notes below)
Notes:
- VJP is used for automatic differentiation during MCMC sampling
- The function computes gradients with respect to model parameters only
- Extra fields (trial, feedback, etc.) are not differentiated
- If your model has unusual parameter dependencies, you may need custom gradient handling
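For intuition, here is a self-contained sketch of how such a VJP wrapper restricts gradients to the model parameters. `toy_logp` and its argument layout are stand-ins, not the actual `rldm.py` code:

import jax
import jax.numpy as jnp

def toy_logp(data, alpha, scaler, trial, feedback):
    # Stand-in likelihood: any differentiable scalar function of the parameters
    return -jnp.sum((data[:, 0] - alpha * scaler) ** 2)

def vjp_logp(gz, data, alpha, scaler, trial, feedback, n_params=2):
    # jax.vjp returns the primal output and a function that maps an output
    # cotangent (gz) to gradients with respect to every input
    _, vjp_fn = jax.vjp(toy_logp, data, alpha, scaler, trial, feedback)
    grads = vjp_fn(gz)
    # Keep only the gradients w.r.t. model parameters (positions 1..n_params),
    # skipping data (position 0) and the trailing extra fields
    return grads[1 : n_params + 1]

Calling, e.g., `vjp_logp(1.0, data, 0.5, 2.0, trial, feedback)` returns gradients for `alpha` and `scaler` only, mirroring the slicing described above.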
Section II: Using the custom model with HSSM¶
Load and prepare the demo dataset¶
This data file contains (synthetic) data from a simulated 2-armed bandit task. We examine the dataset -- it contains the typical columns expected from a canonical instrumental learning task: `participant_id` identifies the subject, `trial` identifies the sequence of trials within each subject's data, `response` and `rt` are the data columns recorded for each trial, the `feedback` column shows the reward obtained on a given trial, and `correct` records whether the response was correct.
# Load synthetic RLSSM dataset containing both behavioral data and ground truth parameters
savefile = np.load("../../tests/fixtures/rldm_data.npy", allow_pickle=True).item()
dataset = savefile['data']
# Rename trial column to match HSSM conventions
dataset.rename(columns={'trial': 'trial_id'}, inplace=True)
# Examine the dataset structure
dataset.head()
|   | participant_id | trial_id | response | rt | feedback | correct |
|---|----------------|----------|----------|----|----------|---------|
| 0 | 0 | 0 | 0.0 | 0.935602 | 0.126686 | 0.0 |
| 1 | 0 | 1 | 0.0 | 1.114379 | 0.173100 | 0.0 |
| 2 | 0 | 2 | 0.0 | 0.564311 | 0.444935 | 0.0 |
| 3 | 0 | 3 | 0.0 | 2.885860 | 0.307207 | 0.0 |
| 4 | 0 | 4 | 0.0 | 0.532113 | 0.177911 | 0.0 |
# Validate data structure and extract dataset configuration
dataset, n_participants, n_trials = hssm.check_data_for_rl(dataset)
print(f"Number of participants: {n_participants}")
print(f"Number of trials: {n_trials}")
Number of participants: 20
Number of trials: 200
Construct HSSM-compatible PyMC distribution from a simulator and JAX likelihood callable¶
We now construct a custom model that is compatible with HSSM and PyMC. Note that HSSM internally constructs a PyMC object (used for sampling) based on the user-specified HSSM model. In other words, we are peeling back the abstraction layers conveniently afforded by HSSM to use its core machinery directly. This advanced HSSM tutorial explains how to use HSSM when starting from the very basics of a model -- a simulator and a JAX likelihood callable.
The simulator function is used for generating samples from the model (for posterior predictives, etc.), while the likelihood callable is employed for sampling/inference. This tutorial exposes the key flexibility of HSSM for fitting RLSSM models; the remainder of it focuses only on the sampling/inference aspect. We therefore create a dummy simulator function to bypass the need for defining the actual simulator.
Step 1: Define a pytensor RandomVariable¶
# Define parameters for the RLSSM model (RL + decision model parameters)
list_params = ['rl.alpha', 'rl.alpha_neg', 'scaler', 'a', 'z', 't', 'theta']
# Create a dummy simulator for generating synthetic data (used for posterior predictives)
# This bypasses the need for a full RLSSM simulator implementation
def create_dummy_simulator():
"""Create a dummy simulator function for RLSSM model."""
def sim_wrapper(simulator_fun, theta, model, n_samples, random_state, **kwargs):
# Generate random RT and choice data as placeholders
sim_rt = np.random.uniform(0.2, 0.6, n_samples)
sim_ch = np.random.randint(0, 2, n_samples)
return np.column_stack([sim_rt, sim_ch])
# Wrap the simulator function with required metadata
wrapped_simulator = partial(sim_wrapper, simulator_fun=simulator, model="custom", n_samples=1)
# Decorate the simulator to make it compatible with HSSM
return decorate_atomic_simulator(model_name="custom", choices=[0, 1], obs_dim=2)(wrapped_simulator)
# Create the simulator and RandomVariable
decorated_simulator = create_dummy_simulator()
# Create a PyTensor RandomVariable using `make_hssm_rv` for use in the PyMC model
CustomRV = make_hssm_rv(
simulator_fun=decorated_simulator, list_params=list_params
)
Step 2: Define a likelihood function¶
# Create a Pytensor Op for the likelihood function.
# The `make_rldm_logp_op` function is a utility that wraps the base JAX likelihood function into a HSSM/PyMC-compatible callable.
logp_jax_op = make_rldm_logp_op(
n_participants=n_participants,
n_trials=n_trials,
n_params=len(list_params),
)
# Test the likelihood function
def extract_data_columns(dataset):
"""Extract required data columns from dataset."""
return {
'participant_id': dataset["participant_id"].values,
'trial': dataset["trial_id"].values,
'response': dataset["response"].values,
'feedback': dataset["feedback"].values,
'rt': dataset["rt"].values
}
def create_test_parameters(n_trials):
"""Create dummy parameters for testing the likelihood function."""
return {
'rl_alpha': np.ones(n_trials) * 0.60,
'rl_alpha_neg': np.ones(n_trials) * 0.60,
'scaler': np.ones(n_trials) * 3.2,
'a': np.ones(n_trials) * 1.2,
'z': np.ones(n_trials) * 0.1,
't': np.ones(n_trials) * 0.1,
'theta': np.ones(n_trials) * 0.1
}
# Extract data and create test parameters
data_columns = extract_data_columns(dataset)
num_subj = len(np.unique(data_columns['participant_id']))
n_trials_total = num_subj * n_trials  # n_trials (200) comes from check_data_for_rl above
test_params = create_test_parameters(n_trials_total)
# Evaluate the likelihood function
test_logp_out = logp_jax_op(
np.column_stack((data_columns['rt'], data_columns['response'])),
test_params['rl_alpha'],
test_params['rl_alpha_neg'],
test_params['scaler'],
test_params['a'],
test_params['z'],
test_params['t'],
test_params['theta'],
data_columns['participant_id'],
data_columns['trial'],
data_columns['feedback'],
)
LL = test_logp_out.eval()
print(f"Log likelihood: {np.sum(LL):.4f}")
Log likelihood: -6879.1526
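As a quick sanity check (matching the output requirement from Step 3 in Section I), the evaluated likelihood should contain one value per trial:

print(LL.shape)  # expected: one log likelihood per trial, i.e. (n_participants * n_trials,)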
Step 3: Define a model config and HSSM model¶
# Step 3: Define the model config
# Configure the HSSM model
model_config = hssm.ModelConfig(
response=["rt", "response"], # Dependent variables (RT and choice)
list_params=[ # List of model parameters
'rl.alpha', 'rl.alpha_neg', 'scaler', 'a', 'z', 't', 'theta'
],
choices=[0, 1], # Possible choice options
default_priors={}, # Use custom priors (defined below)
bounds=dict( # Parameter bounds for optimization
rl_alpha=(0.01, 1), # Learning rate bounds
rl_alpha_neg=(0.01, 1), # Negative learning rate bounds
scaler=(1, 4), # Scaler bounds
a=(0.3, 2.5), # Boundary separation bounds
z=(0.1, 0.9), # Bias bounds
t=(0.1, 2.0), # Non-decision time bounds
theta=(0.0, 1.2) # Collapse rate bounds
),
rv=CustomRV, # Custom RandomVariable that we created earlier
extra_fields=[ # Additional data columns to be passed to the likelihood function as extra_fields
"participant_id",
"trial_id",
"feedback"],
backend="jax" # Use JAX for computation
)
# Create a hierarchical HSSM model with custom likelihood function
hssm_model = hssm.HSSM(
data=dataset, # Input dataset
model_config=model_config, # Model configuration
p_outlier=0, # No outlier modeling
lapse=None, # No lapse rate modeling
loglik=logp_jax_op, # Custom RLDM likelihood function
loglik_kind="approx_differentiable", # Use approximate gradients
noncentered=True, # Use non-centered parameterization
process_initvals=False, # Skip initial value processing in HSSM
include=[
# Define hierarchical priors: group-level intercepts + subject-level random effects
hssm.Param("rl.alpha",
formula="rl_alpha ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.01, upper=1, mu=0.3)}),
hssm.Param("rl.alpha_neg",
formula="rl_alpha_neg ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.01, upper=1, mu=0.3)}),
hssm.Param("scaler",
formula="scaler ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=1, upper=4, mu=1.5)}),
hssm.Param("a",
formula="a ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.3, upper=2.5, mu=1.0)}),
hssm.Param("z",
formula="z ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.1, upper=0.9, mu=0.2)}),
hssm.Param("t",
formula="t ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.01, upper=2, mu=0.2, initval=0.1)}),
hssm.Param("theta",
formula="theta ~ 1 + (1|participant_id)",
prior={"Intercept": hssm.Prior("TruncatedNormal", lower=0.00, upper=1.2, mu=0.3)}),
]
)
No common intercept. Bounds for parameter rl.alpha is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter rl.alpha_neg is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter scaler is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter a is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter z is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter t is not applied due to a current limitation of Bambi. This will change in the future.
No common intercept. Bounds for parameter theta is not applied due to a current limitation of Bambi. This will change in the future.
Model initialized successfully.
Sample using NUTS MCMC¶
# Run MCMC sampling using NUTS sampler with JAX backend
# Note: Using small number of samples for demonstration (increase for real analysis)
idata_mcmc = hssm_model.sample(
sampler='nuts_numpyro', # JAX-based NUTS sampler for efficiency
chains=1, # Number of parallel chains
draws=1000, # Number of posterior samples
tune=1000, # Number of tuning/warmup samples
)
Using default initvals.
sample: 100%|██████████| 2000/2000 [08:33<00:00, 3.89it/s, 29 steps of size 1.56e-01. acc. prob=0.92]
There were 1 divergences after tuning. Increase `target_accept` or reparameterize.
Only one chain was sampled, this makes it impossible to run some convergence checks
/Users/krishnbera/Documents/revert_rldm/HSSM/.venv/lib/python3.12/site-packages/pymc/pytensorf.py:958: FutureWarning: compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC
  warnings.warn(
100%|██████████| 1000/1000 [00:02<00:00, 366.61it/s]
idata_mcmc
Output (abridged): an arviz.InferenceData object with four groups:

- posterior: 35 variables (group-level intercepts such as `a_Intercept` and `scaler_Intercept`, plus subject-level offsets, means, and sigmas) with dimensions (chain: 1, draw: 1000, participant_id__factor_dim: 20)
- log_likelihood: `rt,response` with dimensions (chain: 1, draw: 1000, __obs__: 4000)
- sample_stats: `acceptance_rate`, `step_size`, `diverging`, `energy`, `n_steps`, `tree_depth`, `lp`
- observed_data: `rt,response` with dimensions (__obs__: 4000, rt,response_extra_dim_0: 2)
Assess the model fits¶
We assess the quality of the fits by comparing the recovered parameters with the ground-truth data-generating parameters of the simulated dataset, at both the group level and the subject level.
Examining group-level posteriors¶
# Define parameter names for analysis
list_group_mean_params = [
"rl.alpha_Intercept",
"rl.alpha_neg_Intercept",
"scaler_Intercept",
"a_Intercept",
"z_Intercept",
"t_Intercept",
"theta_Intercept",
]
list_group_sd_params = [
"rl.alpha_1|participant_id_sigma",
"rl.alpha_neg_1|participant_id_sigma",
"scaler_1|participant_id_sigma",
"a_1|participant_id_sigma",
"z_1|participant_id_sigma",
"t_1|participant_id_sigma",
"theta_1|participant_id_sigma",
]
# Create mapping from HSSM model parameter names to ground truth values.
def create_ground_truth_mapping(savefile):
return {
"rl.alpha_Intercept": savefile['params_true_group']['rl_alpha_mean'],
"scaler_Intercept": savefile['params_true_group']['scaler_mean'],
"a_Intercept": savefile['params_true_group']['a_mean'],
"z_Intercept": savefile['params_true_group']['z_mean'],
"t_Intercept": savefile['params_true_group']['t_mean'],
"theta_Intercept": savefile['params_true_group']['theta_mean'],
}
ground_truth_params = create_ground_truth_mapping(savefile)
print("Ground truth group means:\n")
for param, value in ground_truth_params.items():
print(f"{param}: {value:.3f}")
Ground truth group means:

rl.alpha_Intercept: 0.660
scaler_Intercept: 2.785
a_Intercept: 1.481
z_Intercept: 0.267
t_Intercept: 0.443
theta_Intercept: 0.312
# Plot posterior distributions and MCMC traces for group-level parameters
# Vertical lines show ground truth values for parameter recovery assessment
az.plot_trace(idata_mcmc, var_names=list_group_mean_params,
lines=[(key_, {}, ground_truth_params[key_]) for key_ in ground_truth_params])
plt.tight_layout()
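The same check can be extended to subject-level estimates; a forest plot of the subject-level random effects is one quick option (a sketch, using the variable names from the posterior shown above):

# Subject-level random-effect offsets for the positive learning rate
az.plot_forest(idata_mcmc, var_names=["rl.alpha_1|participant_id"], combined=True)
plt.tight_layout()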