ssms

basic_simulators special

boundary_functions

Define a collection of boundary functions for the simulators in the package.

angle(t=1, theta=1)

Angle boundary function.

Arguments
t (float or np.ndarray, optional): Time point(s). Defaults to 1.
theta (float, optional): Angle in radians. Defaults to 1.
Returns
np.ndarray: Array of boundary values, same shape as t
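A minimal sketch of a linearly collapsing bound of this kind, assuming the collapse rate is tan(theta) (the package's actual angle() may differ in sign or offset):

```python
import math

def angle_sketch(t, theta=1.0):
    # Hypothetical linear collapse: bound decreases at rate tan(theta).
    # (Sketch only; not necessarily the package's exact implementation.)
    return [-ti * math.tan(theta) for ti in t]

vals = angle_sketch([0.0, 0.5, 1.0], theta=0.5)
```

The bound starts at 0 and decreases linearly in t.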

conflict_gamma(t=np.arange(0, 20, 0.1), theta=0.5, scale=1, alpha_gamma=1.01, scale_gamma=0.3)

Conflict bound that allows initial divergence then collapse.

Arguments
!!! t "(float, np.ndarray)"
    Time points (with arbitrary measure, but in HDDM it is used as seconds),
    at which to evaluate the bound. Defaults to np.arange(0, 20, 0.1).
!!! theta "float"
    Collapse angle. Defaults to 0.5.
!!! scale "float"
    Scaling the gamma distribution of the boundary
    (since bound does not have to integrate to one). Defaults to 1.0.
!!! alpha_gamma "float"
    Alpha (shape) parameter for a gamma distribution in shape-scale
    parameterization. Defaults to 1.01.
!!! scale_gamma "float"
    Scale parameter of the gamma distribution. Defaults to 0.3.

constant(t=0)

Constant boundary function.

Arguments
t (float or np.ndarray, optional): Time point(s). Defaults to 0.
Returns
float or np.ndarray: Constant boundary value(s), same shape as t

generalized_logistic(t=1, B=2.0, M=3.0, v=0.5)

Generalized logistic bound.

Arguments
t (float or np.ndarray, optional): Time point(s). Defaults to 1.
B (float, optional): Growth rate. Defaults to 2.0.
M (float, optional): Time of maximum growth. Defaults to 3.0.
v (float, optional): Affects near which asymptote maximum growth occurs.
Defaults to 0.5.
Returns
np.ndarray: Array of boundary values, same shape as t
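One common parameterization of a generalized-logistic collapse is sketched below (an assumption; the package's exact normalization may differ). The bound starts near 1 and collapses toward 0 around t = M:

```python
import math

def generalized_logistic_sketch(t, B=2.0, M=3.0, v=0.5):
    # Generalized-logistic collapse: B is the growth rate, M the time of
    # maximum growth, and v skews toward one of the asymptotes.
    return [1.0 - (1.0 / (1.0 + math.exp(-B * (ti - M)))) ** (1.0 / v) for ti in t]

b = generalized_logistic_sketch([0.0, 3.0, 10.0])
```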

weibull_cdf(t=1, alpha=1, beta=1)

Boundary based on the Weibull survival function.

Arguments
t (float or np.ndarray, optional): Time point(s). Defaults to 1.
alpha (float, optional): Shape parameter. Defaults to 1.
beta (float, optional): Scale parameter. Defaults to 1.
Returns
np.ndarray: Array of boundary values, same shape as t
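The Weibull survival function is S(t) = exp(-(t/beta)**alpha), which starts at 1 and decays toward 0; a minimal sketch (the package may additionally scale or shift the bound):

```python
import math

def weibull_bound_sketch(t, alpha=1.0, beta=1.0):
    # Weibull survival function: alpha is the shape, beta the scale.
    return [math.exp(-((ti / beta) ** alpha)) for ti in t]

b = weibull_bound_sketch([0.0, 1.0, 5.0])
```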

drift_functions

Define a collection of drift functions for the simulators in the package.

attend_drift(t=np.arange(0, 20, 0.1), ptarget=-0.3, pouter=-0.3, pinner=0.3, r=0.5, sda=2)

Drift function for the shrinking spotlight model: a time-varying drift that depends on a linearly decreasing standard deviation of attention.

Arguments
!!! t "np.ndarray"
    Timepoints at which to evaluate the drift.
    Usually np.arange() of some sort.
!!! pouter "float"
    Perceptual input for the outer flankers
!!! pinner "float"
    Perceptual input for the inner flankers
!!! ptarget "float"
    Perceptual input for the target
!!! r "float"
    Rate parameter for the sda decrease
!!! sda "float"
    Initial width of the attentional spotlight
Return

np.ndarray Drift evaluated at timepoints t
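The idea can be sketched as follows, under an assumed flanker layout (target centered at 0, inner flankers at +-1, outer flankers at +-2, unit-width regions): attention is a Gaussian whose width shrinks linearly at rate r, and the drift is the attention-weighted sum of the perceptual inputs. This is an illustration, not the package's exact code:

```python
import math

def norm_cdf(x, sd):
    # CDF of a zero-mean normal with standard deviation sd.
    return 0.5 * (1.0 + math.erf(x / (sd * math.sqrt(2.0))))

def attend_drift_sketch(t, ptarget=-0.3, pouter=-0.3, pinner=0.3, r=0.5, sda=2.0):
    out = []
    for ti in t:
        sd = max(sda - r * ti, 1e-3)  # linearly shrinking attention width
        a_target = norm_cdf(0.5, sd) - norm_cdf(-0.5, sd)
        a_inner = 2.0 * (norm_cdf(1.5, sd) - norm_cdf(0.5, sd))
        a_outer = 2.0 * (1.0 - norm_cdf(1.5, sd))
        out.append(ptarget * a_target + pinner * a_inner + pouter * a_outer)
    return out

d = attend_drift_sketch([0.0, 3.9])
```

As the spotlight narrows, the drift converges to the target input alone.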

attend_drift_simple(t=np.arange(0, 20, 0.1), ptarget=-0.3, pouter=-0.3, r=0.5, sda=2)

Drift function for shrinking spotlight model, which involves a time varying function dependent on a linearly decreasing standard deviation of attention.

Arguments
!!! t "np.ndarray"
    Timepoints at which to evaluate the drift.
    Usually np.arange() of some sort.
!!! pouter "float"
    Perceptual input for the outer flankers
!!! ptarget "float"
    Perceptual input for the target
!!! r "float"
    Rate parameter for the sda decrease
!!! sda "float"
    Initial width of the attentional spotlight
Return

np.ndarray Drift evaluated at timepoints t

constant(t=np.arange(0, 20, 0.1))

Constant drift function.

Arguments
!!! t "np.ndarray, optional"
    Timepoints at which to evaluate the drift. Defaults to
    np.arange(0, 20, 0.1).
Returns
np.ndarray: Array of drift values, same length as t

ds_conflict_drift(t=np.arange(0, 10, 0.001), tinit=0, dinit=0, tslope=1, dslope=1, tfixedp=1, tcoh=1.5, dcoh=1.5)

This drift is inspired by a conflict task which involves a target and a distractor stimulus, both presented simultaneously.

Two drift timecourses are linearly combined weighted by the coherence in the respective target and distractor stimuli. Each timecourse follows a dynamical system as described in the ds_support_analytic() function.

Arguments
!!! t "np.ndarray"
    Timepoints at which to evaluate the drift.
    Usually np.arange() of some sort.
!!! tinit "float"
    Initial condition of target drift timecourse
!!! dinit "float"
    Initial condition of distractor drift timecourse
!!! tslope "float"
    Slope parameter for target drift timecourse
!!! dslope "float"
    Slope parameter for distractor drift timecourse
!!! tfixedp "float"
    Fixed point for target drift timecourse
!!! tcoh "float"
    Coefficient for the target drift timecourse
!!! dcoh "float"
    Coefficient for the distractor drift timecourse
Return

np.ndarray The full drift timecourse evaluated at the supplied timepoints t.

ds_support_analytic(t=np.arange(0, 10, 0.001), init_p=0, fix_point=1, slope=2)

Solve a first-order linear DE.

The DE is of the form: x' = slope * (fix_point - x), with initial condition init_p. The solution takes the form: x(t) = (init_p - fix_point) * exp(-slope * t) + fix_point

Arguments
!!! t "np.ndarray"
    Timepoints at which to evaluate the drift. Usually np.arange() of some sort.
!!! init_p "float"
    Initial condition of dynamical system
!!! fix_point "float"
    Fixed point of dynamical system
!!! slope "float"
    Coefficient in exponent of the solution.
Return

np.ndarray The DE solution evaluated at the supplied timepoints t.
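Since the closed-form solution is stated in the docstring, it can be written down directly; the trajectory starts at init_p and converges exponentially to fix_point:

```python
import math

def ds_support_analytic(t, init_p=0.0, fix_point=1.0, slope=2.0):
    # Closed-form solution of x' = slope * (fix_point - x), x(0) = init_p.
    return [(init_p - fix_point) * math.exp(-slope * ti) + fix_point for ti in t]

x = ds_support_analytic([0.0, 10.0])
```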

gamma_drift(t=np.arange(0, 20, 0.1), shape=2, scale=0.01, c=1.5)

Drift function that follows a scaled gamma distribution.

Arguments
!!! t "np.ndarray"
    Timepoints at which to evaluate the drift.
    Usually np.arange() of some sort.
!!! shape "float"
    Shape parameter of the gamma distribution
!!! scale "float"
    Scale parameter of the gamma distribution
!!! c "float"
    Scalar parameter that scales the peak of
    the gamma distribution.
    (Note this function follows a gamma distribution
    but does not integrate to 1)
Return
np.ndarray
    The gamma drift evaluated at the supplied timepoints t.
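A sketch of the idea, assuming the gamma kernel is rescaled so that its peak equals c (the package's normalization may differ, but the docstring notes the curve need not integrate to 1):

```python
import math

def gamma_drift_sketch(t, shape=2.0, scale=0.01, c=1.5):
    # Gamma-shaped kernel t**(shape-1) * exp(-t/scale), rescaled so the
    # value at the kernel's mode equals c. Requires shape > 1.
    mode = (shape - 1.0) * scale               # location of the peak
    peak = mode ** (shape - 1.0) * math.exp(-mode / scale)
    return [c * (ti ** (shape - 1.0)) * math.exp(-ti / scale) / peak for ti in t]

vals = gamma_drift_sketch([0.0, 0.01, 0.1])
```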

simulator

This module defines the basic simulator function which is the main workhorse of the package. In addition, some utility functions are provided that help with preprocessing the output of the simulator function.

bin_arbitrary_fptd(out=None, bin_dt=0.04, nbins=256, nchoices=2, choice_codes=[-1.0, 1.0], max_t=10.0)

Takes in simulator output and returns a histogram of bin counts.

Arguments
!!! out "np.ndarray"
    Output of the 'simulator' function
!!! bin_dt "float"
    If nbins is 0, this determines the desired bin size,
    which in turn automatically determines the resulting number of bins.
!!! nbins "int"
    Number of bins to bin reaction time data into.
    If supplied as 0, bin_dt instead determines the number of
    bins automatically.
!!! nchoices "int <default=2>"
    Number of choices allowed by the simulator.
!!! choice_codes "list[float] <default=[-1.0, 1.0]>"
    Choice labels to be used.
!!! max_t "float"
    Maximum RT to consider.
Returns
2d array (nbins, nchoices): A histogram of bin counts

bin_simulator_output(out=None, bin_dt=0.04, nbins=0, max_t=-1, freq_cnt=False)

Bins the RT part of the simulator output into a histogram.

Arguments
!!! out "dict"
    Output of the 'simulator' function
!!! bin_dt "float"
    If nbins is 0, this determines the desired bin size,
    which in turn automatically determines the resulting number of bins.
!!! nbins "int"
    Number of bins to bin reaction time data into.
    If supplied as 0, bin_dt instead determines the number of
    bins automatically.
!!! max_t "float <default=-1>"
    Override the 'max_t' metadata as part of the simulator output.
    Sometimes useful, but usually the default will do the job.
!!! freq_cnt "bool <default=False>"
    Decide whether to return proportions (default) or counts in bins.
Returns
A histogram of counts or proportions.

bin_simulator_output_pointwise(out=(array([0]), array([0])), bin_dt=0.04, nbins=0)

Turns RT part of simulator output into bin-identifier by trial

Arguments
!!! out "tuple"
    Output of the 'simulator' function
!!! bin_dt "float"
    If nbins is 0, this determines the desired
    bin size which in turn automatically
    determines the resulting number of bins.
!!! nbins "int"
    Number of bins to bin reaction time data into.
    If supplied as 0, bin_dt instead determines the
    number of bins automatically.
Returns
2d array. The first column collects bin-identifiers
by trial; the second column lists the corresponding choices.
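The pointwise binning idea can be sketched as mapping each RT onto a bin identifier given a fixed bin width (an illustration; the package derives the number of bins from the simulator metadata):

```python
def bin_rts_pointwise_sketch(rts, bin_dt=0.04, max_t=10.0):
    # Map each RT to the index of the bin it falls into,
    # clipping to the last bin at max_t.
    nbins = int(max_t / bin_dt)
    return [min(int(rt / bin_dt), nbins - 1) for rt in rts]

bins = bin_rts_pointwise_sketch([0.0, 0.05, 9.99])
```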

make_boundary_dict(config, theta)

Create a dictionary containing boundary-related parameters and functions.

This function extracts boundary-related parameters from the input theta dictionary, based on the boundary configuration specified in the config. It also retrieves the appropriate boundary function and multiplicative flag from the boundary_config.

Arguments
!!! config "dict"
    A dictionary containing model configuration, including the boundary name.
!!! theta "dict"
    A dictionary of parameter values, potentially including boundary-related parameters.
Returns

dict
    A dictionary containing:
    - boundary_params (dict): Extracted boundary-related parameters.
    - boundary_fun (callable): The boundary function corresponding to the specified boundary name.
    - boundary_multiplicative (bool): Flag indicating if the boundary is multiplicative.
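The extraction step can be illustrated with a hypothetical helper (the real function looks up the parameter names and boundary function via the package's boundary_config, keyed on the boundary name in config):

```python
def make_boundary_dict_sketch(theta, boundary_param_names, boundary_fun, multiplicative):
    # Split theta into boundary-related parameters and bundle them
    # with the boundary function and multiplicative flag.
    return {
        "boundary_params": {k: v for k, v in theta.items() if k in boundary_param_names},
        "boundary_fun": boundary_fun,
        "boundary_multiplicative": multiplicative,
    }

bd = make_boundary_dict_sketch(
    {"v": 0.5, "a": 1.0, "theta": 0.4},        # 'theta' here is a collapse angle
    boundary_param_names=["theta"],
    boundary_fun=lambda t, theta: -t * theta,  # stand-in boundary function
    multiplicative=False,
)
```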

make_drift_dict(config, theta)

Create a dictionary containing drift-related parameters and functions.

This function extracts drift-related parameters from the input theta dictionary, based on the drift configuration specified in the config. It also retrieves the appropriate drift function from the drift_config.

Arguments
!!! config "dict"
    A dictionary containing model configuration, including the drift name.
!!! theta "dict"
    A dictionary of parameter values, potentially including drift-related parameters.
Returns

dict
    A dictionary containing:
    - drift_fun (callable): The drift function corresponding to the specified drift name.
    - drift_params (dict): Extracted drift-related parameters.
    If no drift name is specified in config, returns an empty dictionary.

simulator(theta, model='angle', n_samples=1000, delta_t=0.001, max_t=20, no_noise=False, bin_dim=None, bin_pointwise=False, sigma_noise=None, smooth_unif=True, return_option='full', random_state=None)

Basic data simulator for the models included in HDDM.

Arguments
!!! theta "list, numpy.array, dict or pd.DataFrame"
    Parameters of the simulator. If a 2d array, each row is treated as a 'trial'
    and the function runs n_samples * n_trials simulations.
!!! deadline "numpy.array <default=None>"
    If supplied, the simulator will run a deadline model. RTs that exceed the
    deadline are returned as -999.
!!! model "str <default='angle'>"
    Determines the model that will be simulated.
!!! n_samples "int <default=1000>"
    Number of simulation runs for each row in the theta argument.
!!! delta_t "float"
    Size of timesteps in the simulator (conceptually measured in seconds)
!!! max_t "float"
    Maximum reaction time the simulator can reach
!!! no_noise "bool <default=False>"
    Turn noise off (useful mostly for plotting purposes)
!!! bin_dim "int | None <default=None>"
    Number of bins to use (in case the simulator output is
    supposed to come out as a count histogram)
!!! bin_pointwise "bool <default=False>"
    Whether or not to bin the output data pointwise.
    If true, the 'RT' part of the data specifies the
    'bin-number' of a given trial instead of the 'RT' directly.
    You need to specify bin_dim as some number for this to work.
!!! sigma_noise "float | None <default=None>"
    Standard deviation of noise in the diffusion process. If None, defaults to 1.0 for most models
    and 0.1 for LBA models. If no_noise is True, sigma_noise will be set to 0.0.
    If 'sd' or 's' is passed via theta dictionary, sigma_noise must be None.
!!! smooth_unif "bool <default=True>"
    Whether to add uniform random noise to RTs to smooth the distributions.
!!! return_option "str <default='full'>"
    Determines what the function returns. Can be either
    'full' or 'minimal'. If 'full' the function returns
    a dictionary with keys 'rts', 'responses' and 'metadata', and
    metadata contains the model parameters and some additional
    information. 'metadata' is a simpler dictionary with less information
    if 'minimal' is chosen.
!!! random_state "int | None <default=None>"
    Integer passed to random_seed function in the simulator.
    Can be used for reproducibility.
Return

Dictionary where keys can be ('rts', 'responses', 'metadata'), or ('rt-response histogram', 'metadata'), or ('rts binned pointwise', 'responses', 'metadata').
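Under the hood, simulators of this family typically run an Euler-Maruyama discretization of the diffusion with step size delta_t until a bound or max_t is hit. A self-contained sketch for one trial with constant bounds (illustration only; the package's simulator supports many more models, and -999 coding for a breached max_t follows the convention described elsewhere in these docs):

```python
import math, random

def ddm_sim_sketch(v=0.5, a=1.0, z=0.5, delta_t=1e-3, max_t=20.0, sigma=1.0, seed=0):
    # One diffusion trial: drift v, bound separation a, relative start z.
    rng = random.Random(seed)
    x, t = z * a, 0.0
    while 0.0 < x < a and t < max_t:
        x += v * delta_t + sigma * math.sqrt(delta_t) * rng.gauss(0.0, 1.0)
        t += delta_t
    choice = 1 if x >= a else -1
    rt = t if t < max_t else -999.0  # -999 codes a breached max_t
    return rt, choice

rt, choice = ddm_sim_sketch()
```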

validate_ssm_parameters(model, theta)

Validate the parameters for Sequential Sampling Models (SSM).

This function checks the validity of parameters for different SSM models. It performs specific checks based on the model type.

Arguments
!!! model "str"
    The name of the SSM model.
!!! theta "dict"
    A dictionary containing the model parameters.
Raises

ValueError: If any of the parameter validations fail.

theta_processor

Define the AbstractThetaProcessor and its concrete implementation,
SimpleThetaProcessor, for processing theta parameters based on model configurations.

Classes

  • AbstractThetaProcessor: An abstract base class that defines the interface for processing theta parameters.
  • SimpleThetaProcessor: A concrete implementation of AbstractThetaProcessor that processes theta parameters based on various model configurations.

The SimpleThetaProcessor class includes methods to handle different models such as single particle models, multi-particle models, LBA-based models, and various choice models. It modifies the theta parameters according to the specified model configuration and number of trials.

AbstractThetaProcessor (ABC)

Abstract base class for theta processors.

process_theta(self, theta, model_config, n_trials)

Abstract method to process theta parameters.

Arguments
!!! theta "Dict[str, Any]"
    Dictionary of theta parameters.
!!! model_config "Dict[str, Any]"
    Dictionary of model configuration.
!!! n_trials "int"
    Number of trials.
Returns
Dict[str, Any]: Processed theta parameters.

SimpleThetaProcessor (AbstractThetaProcessor)

Simple implementation of the AbstractThetaProcessor.

This class collects functions (for now very simple) that build the bridge between the model_config level specification of the model and the theta parameters that are used in the simulator.

process_theta(self, theta, model_config, n_trials)

Process theta parameters based on the model configuration.

Arguments
!!! theta "Dict[str, Any]"
    Dictionary of theta parameters.
!!! model_config "Dict[str, Any]"
    Dictionary of model configuration.
!!! n_trials "int"
    Number of trials.
Returns
Dict[str, Any]: Processed theta parameters.

config special

config

Configuration dictionary for simulators.

Variables:

  • dict: Dictionary containing all the information about the models.
  • dict: Dictionary containing the filters for the KDE simulations.
  • dict: Dictionary containing information for data generator settings. Supposed to serve as a starting point and example, which the user then modifies to their needs.

boundary_config_to_function_params(boundary_config)

Convert boundary configuration to function parameters.

Arguments
!!! boundary_config "dict"
    Dictionary containing the boundary configuration.
Returns

dict
    Dictionary with adjusted key names so that they match function parameter names directly.

dataset_generators special

lan_mlp

This module defines a data generator class for use with LANs. The class defined below can be used to generate training data compatible with the expectations of LANs.

data_generator

The data_generator() class is used to generate training data for various likelihood approximators.

Attributes
!!! generator_config "dict"
    Configuation dictionary for the data generator.
    (For an example load ssms.config.data_generator_config['lan'])
!!! model_config "dict"
    Configuration dictionary for the model to be simulated.
    (For an example load ssms.config.model_config['ddm'])
Methods
generate_data_training_uniform(save=False, verbose=True, cpn_only=False)
    Generates training data for LANs.
get_simulations(theta=None, random_seed=None)
    Generates simulations for a given parameter set.
_filter_simulations(simulations=None)
    Filters simulations according to the criteria
    specified in the generator_config.
_make_kde_data(simulations=None, theta=None)
    Generates KDE data from simulations.
_mlp_get_processed_data_for_theta(random_seed_tuple)
    Helper function for generating training data for MLPs.
_cpn_get_processed_data_for_theta(random_seed_tuple)
    Helper function for generating training data for CPNs.
_get_rejected_parameter_setups(random_seed_tuple)
    Helper function that collects parameter sets which were rejected
    by the filter used in the _filter_simulations() method.
_make_save_file_name(unique_tag=None)
    Helper function for generating save file names.
_build_simulator()
    Builds simulator function for LANs.
_get_ncpus()
    Helper function for determining the number of
    cpus to use for parallelization.
Returns
data_generator object
__init__(self, generator_config=None, model_config=None) special

Initialize data generator class.

Arguments
!!! generator_config "dict"
    Configuration dictionary for the data generator.
    (For an example load ssms.config.data_generator_config['lan'])
!!! model_config "dict"
    Configuration dictionary for the model to be simulated.
    (For an example load ssms.config.model_config['ddm'])
Raises

ValueError: If no generator_config or model_config is specified.

Returns

data_generator object

generate_data_training_uniform(self, save=False, verbose=True, cpn_only=False)

Generates training data for LANs.

Arguments
!!! save "bool"
    If True, the generated data is saved to disk.
!!! verbose "bool"
    If True, progress is printed to the console.
!!! cpn_only "bool"
    If True, only choice probabilities are computed.
    This is useful for training CPNs.
Returns
!!! data "dict"
    Dictionary containing the generated data.
generate_rejected_parameterizations(self, save=False)

Generates parameterizations that are rejected by the filter.

Arguments
!!! save "bool"
    If True, the generated data is saved to disk.
Returns
!!! rejected_parameterization_list "np.array"
    Array containing the rejected parameterizations.
get_simulations(self, theta=None, random_seed=None)

Generates simulations for a given parameter set.

parameter_transform_for_data_gen(self, theta)

Function to impose constraints on the parameters for data generation.

Arguments
!!! theta "dict"
    Dictionary containing the parameters.
Returns
!!! theta "dict"
    Dictionary containing the transformed parameters.

snpe

data_generator_snpe (data_generator)

Class for generating data for SNPE.

Attributes
!!! generator_config "dict"
    Configuration for data generation
!!! model_config "dict"
    Configuration for the model

Methods

generate_data_training_uniform(save=False)
    Generates data for training SNPE.
_snpe_get_processed_data_for_theta(random_seed)
    Helper function for generating data for SNPE.
_build_simulator()
    Builds simulator function for SNPE.

generate_data_training_uniform(self, save=False)

Generates training data for SNPE.

Arguments
!!! save "bool"
    If True, the generated data is saved to disk.
Returns
!!! data "dict"
    Dictionary containing the generated data.

support_utils special

kde_class

LogKDE

Class for generating kdes from (rt, choice) data. Works for any number of choices.

Attributes
!!! simulator_data "dict <default=None>"
    Dictionary of the type {'rts':[], 'choices':[], 'metadata':{}}.
    Follows the format of simulator returns in this package.
!!! bandwidth_type "string"
    type of bandwidth to use, default is 'silverman'
!!! auto_bandwidth "boolean"
    whether to compute bandwidths automatically, default is True
Methods
compute_bandwidths(type='silverman')
    Computes bandwidths for each choice from rt data.
generate_base_kdes(auto_bandwidth=True, bandwidth_type='silverman')
    Generates kdes from rt data.
kde_eval(data=([], []), log_eval=True)
    Evaluates kde log likelihood at chosen points.
kde_sample(n_samples=2000, use_empirical_choice_p=True, alternate_choice_p=0)
    Samples from a given kde.
attach_data_from_simulator(simulator_data={'rts':[0, 2, 4], 'choices':[-1, 1, -1], 'metadata':{}})
    Helper function to transform ddm simulator output
    to dataset suitable for the kde function class.


__init__(self, simulator_data, bandwidth_type='silverman', auto_bandwidth=True, displace_t=False) special

Initialize LogKDE class.


Arguments
!!! simulator_data "dict"
    Dictionary containing simulation data with keys 'rts', 'choices', and 'metadata'.
    Follows the format returned by simulator functions in this package.
!!! bandwidth_type "str"
    Type of bandwidth to use for KDE. Currently only 'silverman' is supported.
    Defaults to 'silverman'.
!!! auto_bandwidth "bool"
    Whether to automatically compute bandwidths based on the data.
    If False, bandwidths must be set manually. Defaults to True.
!!! displace_t "bool"
    Whether to shift RTs by the t parameter from metadata.
    Only works if all trials have the same t value. Defaults to False.
Raises

AssertionError: If displace_t is True but metadata contains multiple t values.
attach_data_from_simulator(self, simulator_data=([0, 2, 4], [-1, 1, -1]), filter_rts=-999)

Helper function to transform ddm simulator output to dataset suitable for the kde function class.


Arguments
!!! simulator_data "tuple"
    Tuple of (rts, choices, simulator_info) as returned by the simulator function.
!!! filter_rts "float"
    Value to filter rts by, default is -999. -999 is the number returned by the
    simulators if we breach max_t or deadline.

compute_bandwidths(self, bandwidth_type='silverman')

Computes bandwidths for each choice from rt data.


Arguments
!!! bandwidth_type "str"
    Type of bandwidth to use, default is 'silverman', which follows the Silverman rule.
Returns

list
    List of bandwidths for each choice.

generate_base_kdes(self, auto_bandwidth=True, bandwidth_type='silverman')

Generates kdes from rt data. We apply gaussian kernels to the log of the rts.


Arguments
!!! auto_bandwidth "bool"
    Whether to compute bandwidths automatically, default is True.
!!! bandwidth_type "str"
    Type of bandwidth to use, default is 'silverman', which follows the Silverman rule.
Returns

list
    List of kdes for each choice. (These get attached to the base_kdes attribute
    of the class, not returned.)

kde_eval(self, data={}, log_eval=True, lb=-66.774, eps=0.0001, filter_rts=-999)

Evaluates kde log likelihood at chosen points.


Arguments
!!! data "dict"
    Dictionary with keys 'rts' and/or 'log_rts' and 'choices' at which to evaluate
    the kde. If 'rts' is provided, 'log_rts' is ignored.
!!! log_eval "bool"
    Whether to return log likelihood or likelihood, default is True.
!!! lb "float"
    Lower bound for log likelihoods, default is -66.774. (This is the log of 1e-29.)
!!! eps "float"
    Epsilon value to use for lower bounds on rts.
!!! filter_rts "float"
    Value to filter rts by, default is -999. -999 is the number returned by the
    simulators if we breach max_t or deadline.
Returns

array
    Array of log likelihoods for each (rt, choice) pair.
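The core of a log-space KDE evaluation can be sketched for a single choice and a single query point (an illustration of the LogKDE idea, not the class's exact code; the 1/rt factor is the change-of-variables Jacobian for evaluating a density fit on log-RTs back on the RT scale):

```python
import math

def log_kde_eval_sketch(train_log_rts, query_rt, bandwidth=0.1, lb=-66.774):
    # Gaussian KDE fit on log-RTs, evaluated at one query RT; the returned
    # log-likelihood is clamped below at lb.
    q = math.log(query_rt)
    n = len(train_log_rts)
    dens = sum(
        math.exp(-0.5 * ((q - x) / bandwidth) ** 2) / (bandwidth * math.sqrt(2 * math.pi))
        for x in train_log_rts
    ) / n
    dens /= query_rt  # Jacobian: density on the RT scale
    return max(math.log(dens) if dens > 0 else lb, lb)

ll = log_kde_eval_sketch([0.0, 0.1, -0.1], 1.0)      # near the training mass
ll_far = log_kde_eval_sketch([0.0], 1000.0)          # far away: clamped at lb
```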

kde_sample(self, n_samples=2000, use_empirical_choice_p=True, alternate_choice_p=0)

Samples from a given kde.


Arguments
!!! n_samples "int"
    Number of samples to draw.
!!! use_empirical_choice_p "bool"
    Whether to use empirical choice proportions, default is True. (Note 'empirical'
    here refers to the originally attached dataset that served as the basis to
    generate the choice-wise kdes.)
!!! alternate_choice_p "array"
    Array of choice proportions to use, default is 0. (Note 'alternate' here refers
    to an 'alternative' to the 'empirical' choice proportions.)

bandwidth_silverman(sample=[0, 0, 0], std_cutoff=0.001, std_proc='restrict', std_n_1=10)

Computes silverman bandwidth for an array of samples (rts in our context, but general).


Arguments
!!! sample "array"
    Array of samples to compute the bandwidth for.
!!! std_cutoff "float"
    Cutoff for the std, default is 1e-3. (If the sample std is smaller than this,
    we either kill it or restrict it to this value.)
!!! std_proc "str"
    How to deal with small stds, default is 'restrict'. (Options: 'kill', 'restrict')
!!! std_n_1 "float"
    Value to use if n = 1, default is 10. (Not clear if the default is sensible here.)
Returns

float
    Silverman bandwidth for the given sample. This is applied as the bandwidth
    parameter when generating gaussian-based kdes in the LogKDE class.
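The classic Silverman rule-of-thumb is h = 0.9 * std * n**(-1/5); a minimal sketch with the small-std 'restrict' behavior and the n = 1 fallback described above (the package may add further refinements, e.g. IQR-based terms):

```python
import statistics

def silverman_bandwidth_sketch(sample, std_cutoff=1e-3, std_n_1=10.0):
    # Rule-of-thumb bandwidth; tiny stds are restricted to std_cutoff,
    # and a single-element sample falls back to std_n_1.
    n = len(sample)
    if n <= 1:
        return std_n_1
    std = max(statistics.stdev(sample), std_cutoff)
    return 0.9 * std * n ** (-0.2)

h = silverman_bandwidth_sketch([1.0, 2.0, 3.0])
```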

utils

This module provides utility functions for handling parameter dependencies and sampling parameters within specified constraints.

Functions

parse_bounds(bounds: Tuple[Any, Any]) -> Set[str]
    Parse the bounds of a parameter and extract any dependencies.

build_dependency_graph(param_dict: Dict[str, Tuple[Any, Any]]) -> Dict[str, Set[str]]
    Build a dependency graph based on parameter bounds.

topological_sort_util(node: str, visited: Set[str], stack: List[str], graph: Dict[str, Set[str]], temp_marks: Set[str]) -> None
    Helper function for performing a depth-first search in the topological sort.

topological_sort(graph: Dict[str, Set[str]]) -> List[str]
    Perform a topological sort on the dependency graph to determine the sampling order.

sample_parameters_from_constraints(param_dict: Dict[str, Tuple[Any, Any]], sample_size: int) -> Dict[str, np.ndarray]
    Sample parameters uniformly within specified bounds, respecting any dependencies.

build_dependency_graph(param_dict)

Build a dependency graph based on parameter bounds.

Parameters
param_dict (Dict[str, Tuple[Any, Any]]): A dictionary mapping parameter names
    to their bounds.
Returns
Dict[str, Set[str]]: A dictionary representing the dependency graph where keys
    are parameter names and values are sets of parameter names they depend on.

parse_bounds(bounds)

Parse the bounds of a parameter and extract any dependencies.

Parameters
bounds (Tuple[Any, Any]): A tuple containing the lower and upper bounds,
                          numeric or strings, indicating dependencies.
Returns
Set[str]: A set of parameter names that the bounds depend on.
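A sketch of the idea, assuming the convention that a string-valued bound is simply the name of the parameter it depends on (the package's parser may support richer expressions):

```python
def parse_bounds_sketch(bounds):
    # Collect parameter names appearing as string-valued bounds.
    return {b for b in bounds if isinstance(b, str)}
```

For example, a bound like (0.0, "a") says the upper bound is the value of parameter "a".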

sample_parameters_from_constraints(param_dict, sample_size)

Sample parameters uniformly within specified bounds, respecting any dependencies.

Parameters
param_dict (Dict[str, Tuple[Any, Any]]): Dictionary mapping parameter names to
    their bounds.
sample_size (int): Number of samples to generate.
Returns
Dict[str, np.ndarray]: A dictionary mapping parameter names to arrays of sampled
    values.
Raises
ValueError: If dependencies cannot be resolved due to missing parameters or
    circular dependencies.

topological_sort(graph)

Perform a topological sort on the dependency graph to determine the sampling order.

Parameters
graph (Dict[str, Set[str]]): The dependency graph.
Returns
List[str]: A list of parameter names in the order they should be sampled.
Raises
ValueError: If a circular dependency is detected.
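A self-contained sketch of this depth-first topological sort with cycle detection (dependencies come before their dependents, so the returned order is a valid sampling order):

```python
def topological_sort_sketch(graph):
    # DFS-based topological sort; temp marks detect cycles.
    visited, temp, order = set(), set(), []

    def visit(node):
        if node in visited:
            return
        if node in temp:
            raise ValueError(f"circular dependency at {node!r}")
        temp.add(node)
        for dep in graph.get(node, ()):
            visit(dep)  # visit dependencies first
        temp.remove(node)
        visited.add(node)
        order.append(node)

    for node in graph:
        visit(node)
    return order

order = topological_sort_sketch({"a": set(), "z": {"a"}})
```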

topological_sort_util(node, visited, stack, graph, temp_marks)

Helper function for performing a depth-first search in the topological sort.

Parameters
node (str): The current node being visited.
visited (Set[str]): Set of nodes that have been permanently marked
(fully processed).
stack (List[str]): List representing the ordering of nodes.
graph (Dict[str, Set[str]]): The dependency graph.
temp_marks (Set[str]): Set of nodes that have been temporarily marked (currently
 being processed).
Raises
ValueError: If a circular dependency is detected.