ssms
basic_simulators
boundary_functions
Define a collection of boundary functions for the simulators in the package.
angle(t=1, theta=1)
Angle boundary function.
Arguments
t (float or np.ndarray, optional): Time point(s). Defaults to 1.
theta (float, optional): Angle in radians. Defaults to 1.
Returns
np.ndarray: Array of boundary values, same shape as t
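A minimal sketch of what an angle bound computes. It assumes the bound is a straight line through the origin with slope -tan(theta), meant to be added to the initial boundary height. The name angle_sketch is hypothetical, and the package's own implementation may differ, for example in how the result is combined with the boundary.

```python
import numpy as np


def angle_sketch(t=1, theta=1):
    # Hypothetical stand-in for angle(): assumes a straight-line collapse
    # b(t) = -t * tan(theta), to be added to the initial boundary height.
    t = np.asarray(t, dtype=float)
    return -t * np.tan(theta)


t = np.arange(0, 2, 0.5)
print(angle_sketch(t, theta=np.pi / 4))
```

Under this assumption, theta = pi/4 lowers the bound by one unit per unit of time, and larger angles (below pi/2) collapse it faster.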
conflict_gamma(t=np.arange(0, 20, 0.1), theta=0.5, scale=1, alpha_gamma=1.01, scale_gamma=0.3)
Conflict bound that allows initial divergence then collapse.
Arguments
t (float or np.ndarray, optional): Time points (in arbitrary units, but HDDM uses them as seconds)
at which to evaluate the bound. Defaults to np.arange(0, 20, 0.1).
theta (float, optional): Collapse angle. Defaults to 0.5.
scale (float, optional): Scaling factor for the gamma distribution of the boundary
(the bound does not have to integrate to one). Defaults to 1.0.
alpha_gamma (float, optional): Alpha (shape) parameter of the gamma, in the shape-scale parameterization.
Defaults to 1.01.
scale_gamma (float, optional): Scale parameter of the gamma. Defaults to 0.3.
constant(t=0)
Constant boundary function.
Arguments
t (float or np.ndarray, optional): Time point(s). Defaults to 0.
Returns
float or np.ndarray: Constant boundary value(s), same shape as t
generalized_logistic(t=1, B=2.0, M=3.0, v=0.5)
Generalized logistic bound.
Arguments
t (float or np.ndarray, optional): Time point(s). Defaults to 1.
B (float, optional): Growth rate. Defaults to 2.0.
M (float, optional): Time of maximum growth. Defaults to 3.0.
v (float, optional): Affects near which asymptote maximum growth occurs.
Defaults to 0.5.
Returns
np.ndarray: Array of boundary values, same shape as t
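A minimal sketch of a decreasing generalized logistic bound. It assumes the form 1 - 1 / (1 + exp(-B (t - M)))^(1/v), i.e. one minus a Richards curve with asymptotes 0 and 1, so the bound starts near 1 and decays toward 0 around t = M. Both this form and the name generalized_logistic_sketch are assumptions, not the package's code.

```python
import numpy as np


def generalized_logistic_sketch(t=1, B=2.0, M=3.0, v=0.5):
    # Hypothetical form: one minus a Richards (generalized logistic) curve
    # with asymptotes 0 and 1; B sets the steepness, M the time of the drop,
    # and v shifts where along the curve the steepest change happens.
    t = np.asarray(t, dtype=float)
    return 1.0 - 1.0 / np.power(1.0 + np.exp(-B * (t - M)), 1.0 / v)


vals = generalized_logistic_sketch(np.arange(0, 10, 1.0))
print(vals.round(3))
```

Under this assumed form, the value at t = M is 1 - 2^(-1/v), which is 0.75 for the default v = 0.5.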
weibull_cdf(t=1, alpha=1, beta=1)
Boundary based on the Weibull survival function.
Arguments
t (float or np.ndarray, optional): Time point(s). Defaults to 1.
alpha (float, optional): Shape parameter. Defaults to 1.
beta (float, optional): Scale parameter. Defaults to 1.
Returns
np.ndarray: Array of boundary values, same shape as t
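Although the function is named weibull_cdf, its description refers to the Weibull survival function, i.e. one minus the CDF. A minimal sketch following that description, assuming the standard form S(t) = exp(-(t / beta)^alpha) with alpha as shape and beta as scale. The name weibull_survival_sketch is hypothetical.

```python
import numpy as np


def weibull_survival_sketch(t=1, alpha=1, beta=1):
    # Weibull survival function S(t) = exp(-(t / beta) ** alpha):
    # equals 1 at t = 0 and decays monotonically toward 0, which makes it
    # a natural factor for shrinking a boundary over time.
    t = np.asarray(t, dtype=float)
    return np.exp(-np.power(t / beta, alpha))


t = np.array([0.0, 0.5, 1.0, 2.0])
print(weibull_survival_sketch(t, alpha=2.0, beta=1.0))
```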
drift_functions
Define a collection of drift functions for the simulators in the package.
attend_drift(t=np.arange(0, 20, 0.1), ptarget=-0.3, pouter=-0.3, pinner=0.3, r=0.5, sda=2)
Drift function for the shrinking spotlight model, which involves a time-varying function dependent on a linearly decreasing standard deviation of attention.
Arguments
t (np.ndarray): Timepoints at which to evaluate the drift.
Usually np.arange() of some sort.
pouter (float): Perceptual input for the outer flankers.
pinner (float): Perceptual input for the inner flankers.
ptarget (float): Perceptual input for the target.
r (float): Rate parameter for the decrease of sda.
sda (float): Initial width (standard deviation) of the attentional spotlight.
Returns
np.ndarray: Drift evaluated at timepoints t.
attend_drift_simple(t=np.arange(0, 20, 0.1), ptarget=-0.3, pouter=-0.3, r=0.5, sda=2)
Drift function for the shrinking spotlight model, which involves a time-varying function dependent on a linearly decreasing standard deviation of attention.
Arguments
t (np.ndarray): Timepoints at which to evaluate the drift.
Usually np.arange() of some sort.
pouter (float): Perceptual input for the outer flankers.
ptarget (float): Perceptual input for the target.
r (float): Rate parameter for the decrease of sda.
sda (float): Initial width (standard deviation) of the attentional spotlight.
Returns
np.ndarray: Drift evaluated at timepoints t.
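A minimal sketch of the shrinking-spotlight idea behind attend_drift_simple. It assumes a common formulation in which attention is a zero-centered Gaussian whose standard deviation shrinks linearly (sda - r * t). The target occupies [-0.5, 0.5], one outer flanker lies beyond each side, and the drift weights each perceptual input by the attention mass falling on its stimulus. The geometry, the clipping value of 1e-3 and the name spotlight_drift_sketch are assumptions, not the package's code.

```python
import math

import numpy as np

# Element-wise error function built from the standard library.
_erf = np.vectorize(math.erf, otypes=[float])


def _normal_cdf(x, sd):
    # CDF of a zero-mean normal with standard deviation sd, evaluated at x.
    return 0.5 * (1.0 + _erf(x / (sd * math.sqrt(2.0))))


def spotlight_drift_sketch(t, ptarget=-0.3, pouter=-0.3, r=0.5, sda=2.0):
    t = np.asarray(t, dtype=float)
    # The spotlight's sd shrinks linearly with time, clipped to stay positive.
    sd_t = np.clip(sda - r * t, 1e-3, None)
    # Hypothetical geometry: target on [-0.5, 0.5], one flanker beyond each side.
    a_target = _normal_cdf(0.5, sd_t) - _normal_cdf(-0.5, sd_t)
    a_outer = 1.0 - _normal_cdf(0.5, sd_t)
    # Two outer flankers, hence the factor of 2.
    return 2.0 * pouter * a_outer + ptarget * a_target


t = np.arange(0, 10, 0.1)
print(spotlight_drift_sketch(t, ptarget=1.0, pouter=-1.0)[:5])
```

Under this geometry the target and both flankers always share the full attention mass. Congruent inputs (ptarget == pouter) therefore give a constant drift. With incongruent inputs, the early drift is dominated by the flankers and the late drift by the target.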
constant(t=np.arange(0, 20, 0.1))
Constant drift function.
Arguments
t (np.ndarray, optional): Timepoints at which to evaluate the drift. Defaults to
np.arange(0, 20, 0.1).
Returns
np.ndarray: Array of drift values, same length as t
ds_conflict_drift(t=np.arange(0, 10, 0.001), tinit=0, dinit=0, tslope=1, dslope=1, tfixedp=1, tcoh=1.5, dcoh=1.5)
This drift is inspired by a conflict task in which a target and a distractor stimulus are presented simultaneously.
Two drift timecourses are combined linearly, weighted by the coherence of the target and distractor stimuli respectively. Each timecourse follows a dynamical system as described in the ds_support_analytic() function.
Arguments
t (np.ndarray): Timepoints at which to evaluate the drift.
Usually np.arange() of some sort.
tinit (float): Initial condition of the target drift timecourse.
dinit (float): Initial condition of the distractor drift timecourse.
tslope (float): Slope parameter for the target drift timecourse.
dslope (float): Slope parameter for the distractor drift timecourse.
tfixedp (float): Fixed point of the target drift timecourse.
tcoh (float): Coefficient for the target drift timecourse.
dcoh (float): Coefficient for the distractor drift timecourse.
Returns
np.ndarray: The full drift timecourse evaluated at the supplied timepoints t.
ds_support_analytic(t=np.arange(0, 10, 0.001), init_p=0, fix_point=1, slope=2)
Solve the dynamical system analytically.
The differential equation has the form x'(t) = slope * (fix_point - x(t)), with initial condition x(0) = init_p. Its solution is x(t) = (init_p - fix_point) * exp(-slope * t) + fix_point.
Arguments
t (np.ndarray): Timepoints at which to evaluate the solution. Usually np.arange() of some sort.
init_p (float): Initial condition of the dynamical system.
fix_point (float): Fixed point of the dynamical system.
slope (float): Coefficient in the exponent of the solution.
Returns
np.ndarray: The solution of the dynamical system evaluated at the supplied timepoints t.
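The closed-form solution above can be checked numerically, and the conflict drift combines two such timecourses. In this minimal sketch, the distractor timecourse is assumed to relax toward a fixed point of 0, because the signature of ds_conflict_drift exposes no distractor fixed point. The function names are hypothetical.

```python
import numpy as np


def ds_support_sketch(t, init_p=0.0, fix_point=1.0, slope=2.0):
    # Closed-form solution of x'(t) = slope * (fix_point - x), x(0) = init_p.
    t = np.asarray(t, dtype=float)
    return (init_p - fix_point) * np.exp(-slope * t) + fix_point


def ds_conflict_sketch(t, tinit=0.0, dinit=0.0, tslope=1.0, dslope=1.0,
                       tfixedp=1.0, tcoh=1.5, dcoh=1.5):
    # Hypothetical: the distractor relaxes toward 0 (an assumption).
    w_target = ds_support_sketch(t, tinit, tfixedp, tslope)
    w_distractor = ds_support_sketch(t, dinit, 0.0, dslope)
    # Linear combination weighted by the two coherence coefficients.
    return tcoh * w_target + dcoh * w_distractor


t = np.arange(0, 10, 0.001)
x = ds_support_sketch(t)
# Finite-difference check that the closed form satisfies the ODE.
print(np.max(np.abs(np.gradient(x, t) - 2.0 * (1.0 - x))))
```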
gamma_drift(t=np.arange(0, 20, 0.1), shape=2, scale=0.01, c=1.5)
Drift function that follows a scaled gamma distribution.
Arguments
t (np.ndarray): Timepoints at which to evaluate the drift.
Usually np.arange() of some sort.
shape (float): Shape parameter of the gamma distribution.
scale (float): Scale parameter of the gamma distribution.
c (float): Scalar parameter that scales the peak of
the gamma distribution.
(Note that this function follows the shape of a gamma distribution
but does not integrate to 1.)
Returns
np.ndarray: The gamma drift evaluated at the supplied timepoints t.
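A minimal sketch of a gamma-shaped drift, assuming "scales the peak" means the curve is rescaled so its maximum equals c. The sketch drops the gamma normalizing constant, so the curve does not integrate to 1. It also requires shape > 1, so that the mode (shape - 1) * scale is positive. The package may normalize differently, and the name gamma_drift_sketch is hypothetical.

```python
import numpy as np


def gamma_drift_sketch(t, shape=2.0, scale=0.01, c=1.5):
    # Unnormalized gamma kernel t^(shape-1) * exp(-t/scale).
    t = np.asarray(t, dtype=float)
    kernel = np.power(t, shape - 1.0) * np.exp(-t / scale)
    # Assumption: rescale so the value at the mode equals c (needs shape > 1).
    mode = (shape - 1.0) * scale
    peak = mode ** (shape - 1.0) * np.exp(-mode / scale)
    return c * kernel / peak


t = np.linspace(0, 0.1, 10001)
d = gamma_drift_sketch(t)
print(t[np.argmax(d)], d.max())
```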
simulator
This module defines the basic simulator function, which is the main workhorse of the package. In addition, some utility functions are provided that help with preprocessing the output of the simulator function.
bin_arbitrary_fptd(out=None, bin_dt=0.04, nbins=256, nchoices=2, choice_codes=[-1.0, 1.0], max_t=10.0)
Takes in simulator output and returns a histogram of bin counts.
Arguments
out (np.ndarray): Output of the 'simulator' function.
bin_dt (float): If nbins is 0, this determines the desired bin size,
which in turn automatically determines the resulting number of bins.
nbins (int): Number of bins to bin reaction time data into.
If supplied as 0, bin_dt instead determines the number of
bins automatically.
nchoices (int <default=2>): Number of choices allowed by the simulator.
choice_codes (list[float] <default=[-1.0, 1.0]>): Choice labels to be used.
max_t (float): Maximum RT to consider.
Returns
2d array (nbins, nchoices): A histogram of bin counts.
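A minimal sketch of the binning described above. It assumes the simulator output has already been split into separate rts and choices arrays, and that bins have equal width over [0, max_t]. RTs outside that range, including the -999 sentinel, are dropped by np.histogram. The name bin_fptd_sketch and its argument layout are hypothetical.

```python
import numpy as np


def bin_fptd_sketch(rts, choices, bin_dt=0.04, nbins=256,
                    choice_codes=(-1.0, 1.0), max_t=10.0):
    rts = np.asarray(rts, dtype=float).ravel()
    choices = np.asarray(choices, dtype=float).ravel()
    # nbins == 0 means: derive the number of bins from bin_dt and max_t.
    if nbins == 0:
        nbins = int(round(max_t / bin_dt))
    edges = np.linspace(0.0, max_t, nbins + 1)
    hist = np.zeros((nbins, len(choice_codes)))
    for j, code in enumerate(choice_codes):
        # Count only the trials that ended in this choice.
        hist[:, j], _ = np.histogram(rts[choices == code], bins=edges)
    return hist


h = bin_fptd_sketch([0.1, 0.2, 0.3, 5.0], [1, 1, -1, 1], bin_dt=0.5, nbins=0)
print(h.shape, h.sum())
```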
bin_simulator_output(out=None, bin_dt=0.04, nbins=0, max_t=-1, freq_cnt=False)
Bins the simulator output into a histogram over RT bins and choices.
Arguments
out (dict): Output of the 'simulator' function.
bin_dt (float): If nbins is 0, this determines the desired
bin size, which in turn automatically
determines the resulting number of bins.
nbins (int): Number of bins to bin reaction time data into.
If supplied as 0, bin_dt instead determines the number of
bins automatically.
max_t (float <default=-1>): Override the 'max_t' metadata that is part of the simulator output.
Sometimes useful, but usually the default will do the job.
freq_cnt (bool <default=False>): If False (default), return proportions; if True, return counts in bins.
Returns
A histogram of counts or proportions.
bin_simulator_output_pointwise(out=(array([0]), array([0])), bin_dt=0.04, nbins=0)
Turns the RT part of simulator output into a bin identifier per trial.
Arguments
out (tuple): Output of the 'simulator' function.
bin_dt (float): If nbins is 0, this determines the desired
bin size, which in turn automatically
determines the resulting number of bins.
nbins (int): Number of bins to bin reaction time data into.
If supplied as 0, bin_dt instead determines the
number of bins automatically.
Returns
2d array: The first column collects bin identifiers
by trial; the second column lists the corresponding choices.
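A minimal sketch of pointwise binning. It assumes bins of width bin_dt when nbins is 0, and otherwise max_t split into nbins equal-width bins. The max_t argument is a hypothetical addition, since the real function takes only out, bin_dt and nbins. The name bin_pointwise_sketch is hypothetical as well.

```python
import numpy as np


def bin_pointwise_sketch(rts, choices, bin_dt=0.04, nbins=0, max_t=20.0):
    rts = np.asarray(rts, dtype=float).ravel()
    choices = np.asarray(choices).ravel()
    # Assumed bin width: bin_dt directly, or max_t split into nbins bins.
    width = bin_dt if nbins == 0 else max_t / nbins
    bin_ids = np.floor(rts / width).astype(int)
    # Column 0: bin identifier per trial; column 1: that trial's choice.
    return np.column_stack([bin_ids, choices])


print(bin_pointwise_sketch([0.01, 0.05, 0.39], [1, -1, 1]))
```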
make_boundary_dict(config, theta)
Create a dictionary containing boundary-related parameters and functions.
This function extracts boundary-related parameters from the input theta dictionary, based on the boundary configuration specified in the config. It also retrieves the appropriate boundary function and multiplicative flag from the boundary_config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | dict | A dictionary containing model configuration, including the boundary name. | required |
theta | dict | A dictionary of parameter values, potentially including boundary-related parameters. | required |
Returns:
Type | Description |
---|---|
dict | A dictionary containing: - boundary_params (dict): Extracted boundary-related parameters. - boundary_fun (callable): The boundary function corresponding to the specified boundary name. - boundary_multiplicative (bool): Flag indicating if the boundary is multiplicative. |
make_drift_dict(config, theta)
Create a dictionary containing drift-related parameters and functions.
This function extracts drift-related parameters from the input theta dictionary, based on the drift configuration specified in the config. It also retrieves the appropriate drift function from the drift_config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | dict | A dictionary containing model configuration, including the drift name. | required |
theta | dict | A dictionary of parameter values, potentially including drift-related parameters. | required |
Returns:
Type | Description |
---|---|
dict | A dictionary containing: - drift_fun (callable): The drift function corresponding to the specified drift name. - drift_params (dict): Extracted drift-related parameters. If no drift name is specified in config, returns an empty dictionary. |
simulator(theta, model='angle', n_samples=1000, delta_t=0.001, max_t=20, no_noise=False, bin_dim=None, bin_pointwise=False, sigma_noise=None, smooth_unif=True, return_option='full', random_state=None)
Basic data simulator for the models included in HDDM.
Arguments
theta (list, numpy.array, dict or pd.DataFrame): Parameters of the simulator. If a 2d array, each row is treated as a 'trial'
and the function runs n_samples * n_trials simulations.
deadline (numpy.array <default=None>): If supplied, the simulator will run a deadline model. RTs will be returned
as -999 for trials that do not reach a decision before the deadline.
model (str <default='angle'>): Determines the model that will be simulated.
n_samples (int <default=1000>): Number of simulation runs for each row in the theta argument.
delta_t (float): Size of the timesteps in the simulator (conceptually measured in seconds).
max_t (float): Maximum reaction time the simulator can reach.
no_noise (bool <default=False>): Turn noise off (useful mostly for plotting purposes).
bin_dim (int | None <default=None>): Number of bins to use (in case the simulator output is
supposed to come out as a count histogram).
bin_pointwise (bool <default=False>): Whether or not to bin the output data pointwise.
If True, the 'RT' part of the data specifies the
'bin-number' of a given trial instead of the 'RT' directly.
You need to set bin_dim to some number for this to work.
sigma_noise (float | None <default=None>): Standard deviation of noise in the diffusion process. If None, defaults to 1.0 for most models
and 0.1 for LBA models. If no_noise is True, sigma_noise will be set to 0.0.
If 'sd' or 's' is passed via the theta dictionary, sigma_noise must be None.
smooth_unif (bool <default=True>): Whether to add uniform random noise to RTs to smooth the distributions.
return_option (str <default='full'>): Determines what the function returns. Can be either
'full' or 'minimal'. If 'full', the function returns
a dictionary with keys 'rts', 'choices' and 'metadata', where
'metadata' contains the model parameters and some additional
information. If 'minimal' is chosen, 'metadata' is a simpler dictionary with less information.
random_state (int | None <default=None>): Integer passed to the random_seed function in the simulator.
Can be used for reproducibility.
Returns
dict: Keys can be (rts, choices, metadata), (rt-choice histogram, metadata) or (rts binned pointwise, choices, metadata).
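The following is not the package's implementation. It is a pure-NumPy illustration of how delta_t, max_t, sigma_noise and random_state typically interact in an Euler-Maruyama scheme for a basic DDM. Several details are assumptions: the parameter names v, a, z and t, the convention of bounds at 0 and a with a start at z * a, and the name ddm_euler_sketch. Trials that never reach a bound keep the -999 RT used elsewhere in this documentation.

```python
import numpy as np


def ddm_euler_sketch(v=0.5, a=1.0, z=0.5, t=0.3, n_samples=1000,
                     delta_t=0.001, max_t=20.0, sigma_noise=1.0,
                     random_state=None):
    # Hypothetical DDM: bounds at 0 and a, start at z * a, non-decision time t.
    rng = np.random.default_rng(random_state)
    rts = np.full(n_samples, -999.0)  # -999 marks trials that never hit a bound
    choices = np.zeros(n_samples)
    y = np.full(n_samples, z * a)
    active = np.ones(n_samples, dtype=bool)
    sqrt_dt = np.sqrt(delta_t)
    n_steps = int(round(max_t / delta_t))
    for step in range(1, n_steps + 1):
        idx = np.flatnonzero(active)
        if idx.size == 0:
            break
        # Euler-Maruyama update: drift * dt plus N(0, sigma^2 * dt) noise.
        y[idx] += v * delta_t + sigma_noise * sqrt_dt * rng.standard_normal(idx.size)
        upper = idx[y[idx] >= a]
        lower = idx[y[idx] <= 0.0]
        rts[upper] = t + step * delta_t
        rts[lower] = t + step * delta_t
        choices[upper] = 1.0
        choices[lower] = -1.0
        active[upper] = False
        active[lower] = False
    return {"rts": rts, "choices": choices,
            "metadata": {"delta_t": delta_t, "max_t": max_t}}


out = ddm_euler_sketch(n_samples=500, random_state=42)
print(out["rts"][:5], out["choices"][:5])
```

Setting sigma_noise to 0.0 mimics no_noise=True: every trial then follows the same deterministic path.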
validate_ssm_parameters(model, theta)
Validate the parameters for Sequential Sampling Models (SSM).
This function checks the validity of parameters for different SSM models. It performs specific checks based on the model type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | str | The name of the SSM model. | required |
theta | dict | A dictionary containing the model parameters. | required |
Exceptions:
Type | Description |
---|---|
ValueError | If any of the parameter validations fail. |
theta_processor
Define the AbstractThetaProcessor and its concrete implementation, SimpleThetaProcessor, for processing theta parameters based on model configurations.
Classes
- AbstractThetaProcessor: An abstract base class that defines the interface for processing theta parameters.
- SimpleThetaProcessor: A concrete implementation of AbstractThetaProcessor that processes theta parameters based on various model configurations.
The SimpleThetaProcessor class includes methods to handle different models such as
single-particle models, multi-particle models, LBA-based models, and various choice
models. It modifies the theta parameters according to the specified model configuration
and number of trials.
AbstractThetaProcessor (ABC)
Abstract base class for theta processors.
process_theta(self, theta, model_config, n_trials)
Abstract method to process theta parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
theta | Dict[str, Any] | Dictionary of theta parameters. | required |
model_config | Dict[str, Any] | Dictionary of model configuration. | required |
n_trials | int | Number of trials. | required |
Returns
Dict[str, Any]: Processed theta parameters.
SimpleThetaProcessor (AbstractThetaProcessor)
Simple implementation of the AbstractThetaProcessor.
This class collects functions (for now very simple) that build the bridge between the model_config level specification of the model and the theta parameters that are used in the simulator.
process_theta(self, theta, model_config, n_trials)
Process theta parameters based on the model configuration.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
theta | Dict[str, Any] | Dictionary of theta parameters. | required |
model_config | Dict[str, Any] | Dictionary of model configuration. | required |
n_trials | int | Number of trials. | required |
Returns
Dict[str, Any]: Processed theta parameters.
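The split between an abstract interface and a concrete processor follows Python's standard abc pattern. This minimal sketch assumes a processor's job is to fill in model-specific defaults and give every parameter one value per trial. The class names and the default_params key are hypothetical and do not reflect the package's actual model_config schema.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

import numpy as np


class ThetaProcessorSketch(ABC):
    """Hypothetical stand-in for the AbstractThetaProcessor interface."""

    @abstractmethod
    def process_theta(self, theta: Dict[str, Any], model_config: Dict[str, Any],
                      n_trials: int) -> Dict[str, Any]:
        """Return theta adapted to model_config and n_trials."""


class BroadcastThetaProcessor(ThetaProcessorSketch):
    """Hypothetical concrete processor: fill defaults, broadcast to n_trials."""

    def process_theta(self, theta: Dict[str, Any], model_config: Dict[str, Any],
                      n_trials: int) -> Dict[str, Any]:
        out = dict(theta)
        # "default_params" is a hypothetical config key, not the package's schema.
        for name, value in model_config.get("default_params", {}).items():
            out.setdefault(name, value)
        # Give every parameter exactly one float32 value per trial.
        return {
            name: np.broadcast_to(np.asarray(value, dtype=np.float32), (n_trials,)).copy()
            for name, value in out.items()
        }


processor = BroadcastThetaProcessor()
theta = processor.process_theta({"v": 0.5, "a": [1.0, 1.2, 1.4]},
                                {"default_params": {"z": 0.5}}, n_trials=3)
print(theta)
```

Because process_theta is abstract, the base class itself cannot be instantiated. Only concrete processors can be passed to a simulator.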
config
config
Configuration dictionary for simulators.
Variables:
model_config (dict): Dictionary containing all the information about the models.
dict: Dictionary containing the filters for the KDE simulations.
data_generator_config (dict): Dictionary containing information for data generator settings. It is meant to serve as a starting point and example, which users then modify to their needs.
boundary_config_to_function_params(boundary_config)
Convert boundary configuration to function parameters.
Parameters
boundary_config (dict): Dictionary containing the boundary configuration.
Returns
dict: Dictionary with adjusted key names so that they match the function parameter names directly.
dataset_generators
lan_mlp
This module defines a data generator class for use with LANs. The class defined below can be used to generate training data compatible with the expectations of LANs.
data_generator
The data_generator() class is used to generate training data for various likelihood approximators.
Attributes
generator_config (dict): Configuration dictionary for the data generator.
(For an example, load ssms.config.data_generator_config['lan'].)
model_config (dict): Configuration dictionary for the model to be simulated.
(For an example, load ssms.config.model_config['ddm'].)
Methods
generate_data_training_uniform(save=False, verbose=True, cpn_only=False)
Generates training data for LANs.
get_simulations(theta=None, random_seed=None)
Generates simulations for a given parameter set.
_filter_simulations(simulations=None)
Filters simulations according to the criteria
specified in the generator_config.
_make_kde_data(simulations=None, theta=None)
Generates KDE data from simulations.
_mlp_get_processed_data_for_theta(random_seed_tuple)
Helper function for generating training data for MLPs.
_cpn_get_processed_data_for_theta(random_seed_tuple)
Helper function for generating training data for CPNs.
_get_rejected_parameter_setups(random_seed_tuple)
Helper function that collects parameter sets which were rejected
by the filter used in the _filter_simulations() method.
_make_save_file_name(unique_tag=None)
Helper function for generating save file names.
_build_simulator()
Builds simulator function for LANs.
_get_ncpus()
Helper function for determining the number of
CPUs to use for parallelization.
Returns
data_generator object
__init__(self, generator_config=None, model_config=None)
Initialize data generator class.
Arguments
generator_config (dict): Configuration dictionary for the data generator. (For an example, load ssms.config.data_generator_config['lan'].)
model_config (dict): Configuration dictionary for the model to be simulated. (For an example, load ssms.config.model_config['ddm'].)
Raises
ValueError: If no generator_config or model_config is specified.
Returns
data_generator object
generate_data_training_uniform(self, save=False, verbose=True, cpn_only=False)
Generates training data for LANs.
Arguments
save (bool): If True, the generated data is saved to disk.
verbose (bool): If True, progress is printed to the console.
cpn_only (bool): If True, only choice probabilities are computed.
This is useful for training CPNs.
Returns
data (dict): Dictionary containing the generated data.
generate_rejected_parameterizations(self, save=False)
Generates parameterizations that are rejected by the filter.
Arguments
save (bool): If True, the generated data is saved to disk.
Returns
rejected_parameterization_list (np.array): Array containing the rejected parameterizations.
get_simulations(self, theta=None, random_seed=None)
Generates simulations for a given parameter set.
parameter_transform_for_data_gen(self, theta)
Function to impose constraints on the parameters for data generation.
Arguments
theta (dict): Dictionary containing the parameters.
Returns
theta (dict): Dictionary containing the transformed parameters.
snpe
data_generator_snpe (data_generator)
Class for generating data for SNPE.
Attributes
generator_config (dict): Configuration for data generation.
model_config (dict): Configuration for the model.
Methods
generate_data_training_uniform(save=False)
Generates data for training SNPE.
_snpe_get_processed_data_for_theta(random_seed)
Helper function for generating data for SNPE.
_build_simulator()
Builds simulator function for SNPE.
generate_data_training_uniform(self, save=False)
Generates training data for SNPE.
Arguments
save (bool): If True, the generated data is saved to disk.
Returns
data (dict): Dictionary containing the generated data.
support_utils
kde_class
LogKDE
Class for generating KDEs from (rt, choice) data. Works for any number of choices.
Attributes
simulator_data (dict, default None): Dictionary of the type {'rts': [], 'choices': [], 'metadata': {}}.
Follows the format of simulator returns in this package.
bandwidth_type (string): Type of bandwidth to use; default is 'silverman'.
auto_bandwidth (boolean): Whether to compute bandwidths automatically; default is True.
Methods
compute_bandwidths(bandwidth_type='silverman')
Computes bandwidths for each choice from RT data.
generate_base_kdes(auto_bandwidth=True, bandwidth_type='silverman')
Generates KDEs from RT data.
kde_eval(data=([], []), log_eval=True)
Evaluates the KDE log likelihood at chosen points.
kde_sample(n_samples=2000, use_empirical_choice_p=True, alternate_choice_p=0)
Samples from a given KDE.
attach_data_from_simulator(simulator_data={'rts': [0, 2, 4], 'choices': [-1, 1, -1], 'metadata': {}})
Helper function to transform ddm simulator output
to a dataset suitable for the KDE class.
__init__(self, simulator_data, bandwidth_type='silverman', auto_bandwidth=True, displace_t=False)
Initialize LogKDE class.
Arguments
simulator_data (dict): Dictionary containing simulation data with keys 'rts', 'choices', and 'metadata'.
Follows the format returned by simulator functions in this package.
bandwidth_type (str): Type of bandwidth to use for KDE. Currently only 'silverman' is supported.
Defaults to 'silverman'.
auto_bandwidth (bool): Whether to automatically compute bandwidths based on the data.
If False, bandwidths must be set manually. Defaults to True.
displace_t (bool): Whether to shift RTs by the t parameter from the metadata.
Only works if all trials have the same t value. Defaults to False.
Raises
AssertionError: If displace_t is True but the metadata contains multiple t values.
attach_data_from_simulator(self, simulator_data=([0, 2, 4], [-1, 1, -1]), filter_rts=-999)
Helper function to transform ddm simulator output to a dataset suitable for the KDE class.
Arguments
simulator_data (tuple): Tuple of (rts, choices, simulator_info) as returned by the simulator function.
filter_rts (float): Value to filter RTs by; default is -999. -999 is the value the simulators return if a trial breaches max_t or the deadline.
compute_bandwidths(self, bandwidth_type='silverman')
Computes bandwidths for each choice from RT data.
Arguments
bandwidth_type (string): Type of bandwidth to use; default is 'silverman', which follows Silverman's rule.
Returns
list: List of bandwidths for each choice.
generate_base_kdes(self, auto_bandwidth=True, bandwidth_type='silverman')
Generates KDEs from RT data. We apply Gaussian kernels to the log of the RTs.
Arguments
auto_bandwidth (boolean): Whether to compute bandwidths automatically; default is True.
bandwidth_type (string): Type of bandwidth to use; default is 'silverman', which follows Silverman's rule.
Returns
list: List of KDEs for each choice. (These get attached to the base_kdes attribute of the class, not returned.)
kde_eval(self, data={}, log_eval=True, lb=-66.774, eps=0.0001, filter_rts=-999)
Evaluates the KDE log likelihood at chosen points.
Arguments
data (dict): Dictionary with keys 'rts' and/or 'log_rts', and 'choices', to evaluate the KDE at. If 'rts' is provided, 'log_rts' is ignored.
log_eval (boolean): Whether to return the log likelihood or the likelihood; default is True.
lb (float): Lower bound for log likelihoods; default is -66.774 (the natural log of 1e-29).
eps (float): Epsilon value to use for lower bounds on RTs.
filter_rts (float): Value to filter RTs by; default is -999. -999 is the value the simulators return if a trial breaches max_t or the deadline.
Returns
array: Array of log likelihoods for each (rt, choice) pair.
kde_sample(self, n_samples=2000, use_empirical_choice_p=True, alternate_choice_p=0)
Samples from a given KDE.
Arguments
n_samples (int): Number of samples to draw.
use_empirical_choice_p (boolean): Whether to use empirical choice proportions; default is True. (Note that 'empirical' here refers to the originally attached datasets that served as the basis for generating the choice-wise KDEs.)
alternate_choice_p (array): Array of choice proportions to use; default is 0. (Note that 'alternate' here refers to an 'alternative' to the 'empirical' choice proportions.)
bandwidth_silverman(sample=[0, 0, 0], std_cutoff=0.001, std_proc='restrict', std_n_1=10)
Computes the Silverman bandwidth for an array of samples (RTs in our context, but the function is general).
Arguments
sample (array): Array of samples to compute the bandwidth for.
std_cutoff (float): Cutoff for the std; default is 1e-3. (If the sample std is smaller than this, we either kill it or restrict it to this value.)
std_proc (string): How to deal with small stds; default is 'restrict'. (Options: 'kill', 'restrict'.)
std_n_1 (float): Value to use if n = 1; default is 10. (It is not clear whether this default is sensible.)
Returns
float: Silverman bandwidth for the given sample. This is applied as the bandwidth parameter when generating Gaussian-based KDEs in the LogKDE class.
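LogKDE applies Gaussian kernels to log RTs with Silverman bandwidths. This minimal sketch shows both ideas under a few assumptions. It uses the textbook rule of thumb 0.9 * min(std, IQR / 1.34) * n^(-1/5), and omits the std_cutoff, std_proc and n = 1 handling described above. The density is floored at 1e-29, echoing the lb default of kde_eval. It also includes a -log(rt) Jacobian term to map the log-RT density back to an RT density; whether kde_eval does the same is an assumption. Both function names are hypothetical.

```python
import numpy as np


def silverman_bandwidth_sketch(sample):
    # Textbook Silverman rule: 0.9 * min(std, IQR / 1.34) * n^(-1/5).
    sample = np.asarray(sample, dtype=float)
    q75, q25 = np.percentile(sample, [75, 25])
    spread = min(np.std(sample), (q75 - q25) / 1.34)
    return 0.9 * spread * sample.size ** (-0.2)


def log_rt_kde_logpdf_sketch(rt_eval, rts, bandwidth):
    # Gaussian KDE fitted on log(RT), evaluated at log(rt_eval).
    log_eval = np.log(np.asarray(rt_eval, dtype=float))[:, None]
    log_data = np.log(np.asarray(rts, dtype=float))[None, :]
    z = (log_eval - log_data) / bandwidth
    dens = np.exp(-0.5 * z**2).mean(axis=1) / (bandwidth * np.sqrt(2 * np.pi))
    # Floor the density, then subtract log(rt): change of variables to RT space.
    return np.log(np.maximum(dens, 1e-29)) - log_eval[:, 0]


rng = np.random.default_rng(0)
rts = np.exp(rng.normal(-0.5, 0.3, 1000))
bw = silverman_bandwidth_sketch(np.log(rts))
print(bw, log_rt_kde_logpdf_sketch([0.4, 0.6, 1.0], rts, bw))
```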
utils
This module provides utility functions for handling parameter dependencies and sampling parameters within specified constraints.
Functions
parse_bounds(bounds: Tuple[Any, Any]) -> Set[str] Parse the bounds of a parameter and extract any dependencies.
build_dependency_graph(param_dict: Dict[str, Tuple[Any, Any]]) -> Dict[str, Set[str]] Build a dependency graph based on parameter bounds.
topological_sort_util(node: str, visited: Set[str], stack: List[str], graph: Dict[str, Set[str]], temp_marks: Set[str]) -> None Helper function for performing a depth-first search in the topological sort.
topological_sort(graph: Dict[str, Set[str]]) -> List[str] Perform a topological sort on the dependency graph to determine the sampling order.
sample_parameters_from_constraints(param_dict: Dict[str, Tuple[Any, Any]], sample_size: int) -> Dict[str, np.ndarray] Sample parameters uniformly within specified bounds, respecting any dependencies.
build_dependency_graph(param_dict)
Build a dependency graph based on parameter bounds.
Parameters
param_dict (Dict[str, Tuple[Any, Any]]): A dictionary mapping parameter names
to their bounds.
Returns
Dict[str, Set[str]]: A dictionary representing the dependency graph, where keys are parameter names
and values are sets of parameter names they depend on.
parse_bounds(bounds)
Parse the bounds of a parameter and extract any dependencies.
Parameters
bounds (Tuple[Any, Any]): A tuple containing the lower and upper bounds; each bound is either
numeric or a string naming a parameter it depends on.
Returns
Set[str]: A set of parameter names that the bounds depend on.
sample_parameters_from_constraints(param_dict, sample_size)
Sample parameters uniformly within specified bounds, respecting any dependencies.
Parameters
param_dict (Dict[str, Tuple[Any, Any]]): Dictionary mapping parameter names to
their bounds.
sample_size (int): Number of samples to generate.
Returns
Dict[str, np.ndarray]: A dictionary mapping parameter names to arrays of sampled
values.
Raises
ValueError: If dependencies cannot be resolved due to missing parameters or
circular dependencies.
topological_sort(graph)
Perform a topological sort on the dependency graph to determine the sampling order.
Parameters
graph (Dict[str, Set[str]]): The dependency graph.
Returns
List[str]: A list of parameter names in the order they should be sampled.
Raises
ValueError: If a circular dependency is detected.
topological_sort_util(node, visited, stack, graph, temp_marks)
Helper function for performing a depth-first search in the topological sort.
Parameters
node (str): The current node being visited.
visited (Set[str]): Set of nodes that have been permanently marked
(fully processed).
stack (List[str]): List representing the ordering of nodes.
graph (Dict[str, Set[str]]): The dependency graph.
temp_marks (Set[str]): Set of nodes that have been temporarily marked (currently
being processed).
Raises
ValueError: If a circular dependency is detected.
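Together, these functions amount to a depth-first topological sort, using permanent and temporary marks as in topological_sort_util, followed by uniform sampling in dependency order. This minimal sketch assumes that a string bound simply names another parameter, whose sampled values are then used element-wise as that bound. The functions deps_of, sampling_order and sample_with_constraints are hypothetical stand-ins for parse_bounds, topological_sort and sample_parameters_from_constraints.

```python
from typing import Any, Dict, List, Set, Tuple

import numpy as np


def deps_of(bounds: Tuple[Any, Any]) -> Set[str]:
    # Hypothetical parse_bounds: a string bound names a parameter it depends on.
    return {b for b in bounds if isinstance(b, str)}


def sampling_order(param_dict: Dict[str, Tuple[Any, Any]]) -> List[str]:
    graph = {name: deps_of(b) for name, b in param_dict.items()}
    visited: Set[str] = set()     # permanent marks (fully processed)
    temp_marks: Set[str] = set()  # nodes on the current DFS path
    order: List[str] = []

    def visit(node: str) -> None:
        if node in visited:
            return
        if node in temp_marks:
            raise ValueError(f"Circular dependency involving {node!r}")
        if node not in graph:
            raise ValueError(f"Unknown parameter {node!r}")
        temp_marks.add(node)
        for dep in graph[node]:
            visit(dep)
        temp_marks.discard(node)
        visited.add(node)
        order.append(node)  # appended only after all of its dependencies

    for name in graph:
        visit(name)
    return order


def sample_with_constraints(param_dict, sample_size, seed=None):
    rng = np.random.default_rng(seed)
    samples: Dict[str, np.ndarray] = {}
    for name in sampling_order(param_dict):
        # Resolve string bounds to the already-sampled arrays they name.
        lo, hi = (samples[b] if isinstance(b, str) else b for b in param_dict[name])
        samples[name] = rng.uniform(lo, hi, size=sample_size)
    return samples


params = {"a": (0.5, 2.0), "z": (0.1, "a"), "t": (0.0, "z")}
print(sampling_order(params))
```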