Getting Started#

We’re going to demonstrate how to estimate recruitment curves using a hierarchical Bayesian model and the rectified-logistic function.

A Simple Example#

Tip

If you have trouble running the commands in this tutorial, please copy the command and its output, then open an issue on our GitHub repository. We’ll do our best to help you!

Begin by reading the mock_data.csv file:

import pandas as pd

url = "https://raw.githubusercontent.com/hbmep/hbmep/refs/heads/docs-data/data/mock_data.csv"
df = pd.read_csv(url)

print(f"df shape:\t\t{df.shape}")
print(f"df columns:\t\t{', '.join(df.columns)}")
print(f"All participants:\t{', '.join(df['participant'].unique())}")
print("\nFirst 5 rows:")
print(df.head())
df shape:		(245, 4)
df columns:		TMSIntensity, participant, PKPK_ECR, PKPK_FCR
All participants:	P1, P2, P3

First 5 rows:

   TMSIntensity participant  PKPK_ECR  PKPK_FCR
0         43.79          P1     0.197     0.048
1         55.00          P1     0.224     0.068
2         41.00          P1     0.112     0.110
3         43.00          P1     0.149     0.058
4         14.00          P1     0.014     0.011

This dataset contains TMS responses (peak-to-peak amplitude, in mV) for three participants (P1, P2, P3), recorded from two muscles (ECR and FCR). The column TMSIntensity represents stimulation intensity in percent maximum stimulator output (0–100% MSO).
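Before building the model, it can be useful to check how many observations each participant contributes and the range of intensities tested. A minimal pandas sketch, using only the columns shown above:

# Per-participant observation count and tested intensity range
summary = df.groupby("participant")["TMSIntensity"].agg(["count", "min", "max"])
print(summary)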

Build the model#

Next, we initialize a standard hierarchical Bayesian model. This step typically consists of assigning the model’s attributes to the appropriate dataframe columns, setting the sampling parameters, and choosing the recruitment curve function.

from hbmep.model.standard import HB

model = HB()

# Point to the respective columns in dataframe
model.intensity = "TMSIntensity"
model.features = ["participant"]
model.response = ["PKPK_ECR", "PKPK_FCR"]

# Specify the sampling parameters
model.mcmc_params = {
    "num_chains": 4,
    "thinning": 1,
    "num_warmup": 1000,
    "num_samples": 1000,
}

# Set the function
model._model = model.rectified_logistic

Alternatively, these settings can be specified in an hbmep.toml configuration file without having to change the code directly. See Working with TOML configuration file for details.

Running the model#

Before fitting the model, we can visualize the dataset. Since the plot is saved as a PDF, we need to specify an output path.

import os

current_directory = os.getcwd()
output_path = os.path.join(current_directory, "dataset.pdf")

# Plot dataset and save it as a PDF
model.plot(df, output_path=output_path)

The plot shows rows as participants and columns as muscles. The x-axis is TMS intensity (% MSO), and the y-axis is MEP peak-to-peak amplitude (mV).

Next, we process the dataframe by encoding categorical feature columns. This returns the same dataframe with encoded values, plus an encoder dictionary for mapping back to original labels.

# Process the dataframe
df, encoder = model.load(df)
Encoded participants:	0, 1, 2
Participant mapping:	0 -> P1, 1 -> P2, 2 -> P3
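If you later need to map encoded values back to the original labels, the encoder dictionary can be used. A minimal sketch, assuming each entry behaves like a fitted scikit-learn LabelEncoder with an inverse_transform method (consistent with the mapping printed above):

# Hypothetical: recover original participant labels from encoded values
# (assumes encoder["participant"] exposes sklearn-style inverse_transform)
decoded = encoder["participant"].inverse_transform([0, 1, 2])
print(decoded)  # expected: ['P1' 'P2' 'P3']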

Now we run the model to estimate curves.

# Run
mcmc, posterior = model.run(df=df)

# Check convergence diagnostics
summary_df = model.summary(posterior)
print(summary_df.to_string())
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
a_loc         35.418  5.576    24.891     47.705      0.157    0.299    2321.0     869.0    1.0
a_scale       11.170  7.101     3.772     25.594      0.231    0.376    1611.0    1203.0    1.0
b_scale        0.189  0.115     0.061      0.381      0.003    0.013    1563.0    1594.0    1.0
g_scale        0.013  0.005     0.006      0.024      0.000    0.000    1047.0     975.0    1.0
h_scale        0.282  0.129     0.120      0.524      0.004    0.007    1054.0    1250.0    1.0
v_scale        4.547  2.913     0.246     10.230      0.050    0.040    2554.0    1966.0    1.0
c₁_scale       3.008  2.633     0.075      8.490      0.051    0.045    2077.0    2303.0    1.0
c₂_scale       0.280  0.102     0.125      0.493      0.003    0.003    1194.0    1728.0    1.0
a[0, 0]       31.900  0.463    30.731     32.637      0.016    0.021    1815.0     754.0    1.0
a[0, 1]       31.609  0.538    30.537     32.577      0.011    0.010    2845.0    2277.0    1.0
a[1, 0]       45.521  0.870    43.650     46.427      0.034    0.069    1385.0     650.0    1.0
a[1, 1]       44.842  0.815    43.216     46.156      0.017    0.022    2870.0    2747.0    1.0
a[2, 0]       30.066  0.431    29.251     30.719      0.010    0.022    2563.0    1600.0    1.0
a[2, 1]       31.175  0.673    29.987     32.482      0.012    0.020    3614.0    2677.0    1.0
b_raw[0, 0]    1.118  0.479     0.330      2.112      0.011    0.007    1767.0    1753.0    1.0
b_raw[0, 1]    1.372  0.550     0.419      2.458      0.010    0.009    2899.0    2533.0    1.0
b_raw[1, 0]    0.607  0.356     0.078      1.335      0.010    0.009    1399.0    2050.0    1.0
b_raw[1, 1]    0.495  0.312     0.040      1.110      0.006    0.006    1912.0    2230.0    1.0
b_raw[2, 0]    0.570  0.305     0.101      1.193      0.007    0.007    1773.0    2025.0    1.0
b_raw[2, 1]    0.321  0.215     0.024      0.743      0.005    0.005    1879.0    2145.0    1.0
g_raw[0, 0]    1.187  0.389     0.453      1.915      0.012    0.007    1048.0     983.0    1.0
g_raw[0, 1]    1.127  0.372     0.457      1.859      0.011    0.007    1038.0    1049.0    1.0
g_raw[1, 0]    0.746  0.246     0.303      1.226      0.007    0.005    1056.0     957.0    1.0
g_raw[1, 1]    0.640  0.215     0.254      1.063      0.006    0.004    1092.0    1102.0    1.0
g_raw[2, 0]    0.569  0.202     0.202      0.955      0.006    0.003    1123.0    1105.0    1.0
g_raw[2, 1]    0.692  0.243     0.274      1.198      0.007    0.004    1128.0    1148.0    1.0
h_raw[0, 0]    0.981  0.359     0.372      1.716      0.011    0.006    1053.0    1225.0    1.0
h_raw[0, 1]    0.665  0.244     0.244      1.165      0.007    0.004    1056.0    1246.0    1.0
h_raw[1, 0]    0.769  0.304     0.251      1.373      0.009    0.006    1076.0    1308.0    1.0
h_raw[1, 1]    0.649  0.313     0.170      1.260      0.007    0.008    1423.0    1490.0    1.0
h_raw[2, 0]    0.941  0.369     0.295      1.657      0.010    0.006    1191.0    1458.0    1.0
h_raw[2, 1]    0.991  0.464     0.234      1.933      0.010    0.008    1952.0    2093.0    1.0
v_raw[0, 0]    0.975  0.597     0.077      2.158      0.008    0.010    3906.0    2311.0    1.0
v_raw[0, 1]    0.936  0.606     0.039      2.113      0.008    0.009    3781.0    2116.0    1.0
v_raw[1, 0]    0.646  0.598     0.000      1.819      0.014    0.009    1059.0     727.0    1.0
v_raw[1, 1]    0.755  0.603     0.001      1.946      0.010    0.009    2017.0    1373.0    1.0
v_raw[2, 0]    0.786  0.608     0.001      1.977      0.010    0.009    2269.0    1195.0    1.0
v_raw[2, 1]    0.780  0.602     0.001      1.962      0.009    0.009    2906.0    2006.0    1.0
c₁_raw[0, 0]   0.932  0.592     0.045      2.082      0.009    0.009    3525.0    2086.0    1.0
c₁_raw[0, 1]   0.927  0.598     0.031      2.074      0.008    0.009    3902.0    2410.0    1.0
c₁_raw[1, 0]   0.704  0.621     0.002      1.893      0.011    0.010    1953.0    1767.0    1.0
c₁_raw[1, 1]   0.822  0.581     0.006      1.919      0.009    0.009    3242.0    2062.0    1.0
c₁_raw[2, 0]   0.318  0.490     0.002      1.418      0.011    0.011    1553.0    3387.0    1.0
c₁_raw[2, 1]   0.799  0.600     0.009      2.003      0.008    0.009    3473.0    2772.0    1.0
c₂_raw[0, 0]   0.300  0.112     0.116      0.526      0.003    0.002    1364.0    1744.0    1.0
c₂_raw[0, 1]   0.503  0.187     0.186      0.869      0.005    0.004    1330.0    1477.0    1.0
c₂_raw[1, 0]   0.335  0.131     0.115      0.577      0.004    0.003    1304.0    2088.0    1.0
c₂_raw[1, 1]   0.917  0.332     0.341      1.555      0.009    0.006    1383.0    1902.0    1.0
c₂_raw[2, 0]   1.239  0.464     0.462      2.167      0.010    0.007    1947.0    2264.0    1.0
c₂_raw[2, 1]   1.351  0.469     0.523      2.266      0.012    0.008    1368.0    1682.0    1.0
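All r_hat values are 1.0 and the effective sample sizes are well into the hundreds or thousands, indicating good mixing. Since model.summary returns a pandas dataframe with the columns shown above, we can also check this programmatically; the 1.01 and 400 cutoffs below are common rules of thumb, not hbmep requirements:

# Flag parameters with potentially poor convergence
bad = summary_df[(summary_df["r_hat"] > 1.01) | (summary_df["ess_bulk"] < 400)]
print(f"{len(bad)} parameter(s) flagged by r_hat > 1.01 or ess_bulk < 400")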

Visualizing the curves#

Before plotting the curves, we have to generate predictions using the posterior.

# Create prediction dataframe
prediction_df = model.make_prediction_dataset(df=df, num_points=100)

# Use the model to predict on the prediction dataframe
predictive = model.predict(df=prediction_df, posterior=posterior)

This returns the posterior predictive distribution. We can use it to plot the estimated curves. Again, we specify the path where the generated PDF will be stored.

output_path = os.path.join(current_directory, "curves.pdf")

# Plot recruitment curves
model.plot_curves(
    df=df,
    prediction_df=prediction_df,
    predictive=predictive,
    posterior=posterior,
    encoder=encoder,
    output_path=output_path
)

In each panel above, the top plot shows the estimated curve overlaid on data, and the bottom plot shows the posterior distribution of the threshold parameter.

Mixture extension#

The curves look good overall, except for participant P1 and muscle FCR, where the growth rate seems to be biased by a few data points. This can be addressed with a mixture model. To enable it, we can add the following line at the end of the model-building code:

# Enable mixture model
model.use_mixture = True

Accessing parameters#

Each (participant, muscle) combination is assigned a tuple index, which can be used to access the curve parameters stored in the posterior dictionary as NumPy arrays. Here, we show how to access the threshold parameter.

from hbmep.util import site

# Threshold parameter
a = posterior[site.a]
print(f"Shape of a:\t{a.shape}")
Shape of a:	(4000, 3, 2)

First dimension corresponds to the number of samples:		4000
Second dimension corresponds to the number of participants:	3
Last dimension corresponds to the number of muscles:		2
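These arrays can be summarized with standard NumPy reductions over the sample axis. For example, to get the posterior mean and a 95% interval of the threshold for every participant and muscle:

import numpy as np

# Posterior mean and 95% interval of the threshold, per (participant, muscle)
a_mean = a.mean(axis=0)                            # shape (3, 2)
a_low, a_high = np.percentile(a, [2.5, 97.5], axis=0)
print(a_mean)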

With the settings above, we have 4000 posterior samples (4 chains, 1000 samples each). We can increase the number of chains or samples by updating the model-building code:

# Use 10 chains, 2000 samples each, for a total of 20,000 samples
model.mcmc_params = {
    "num_chains": 10,
    "thinning": 1,
    "num_warmup": 2000,
    "num_samples": 2000,
}

The other curve parameters can be accessed similarly using their keys.

print(f"{site.b} controlds the growth rate")
print(f"{site.g} is the offset")
print(f"({site.g} + {site.h}) is the saturation")
b controlds the growth rate
g is the offset
(g + h) is the saturation
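Because these parameters share the same (samples, participants, muscles) layout as a, derived quantities such as the saturation can be computed elementwise across posterior samples. A minimal sketch, assuming posterior[site.g] and posterior[site.h] follow that layout:

# Posterior of the saturation level (g + h), per participant and muscle
saturation = posterior[site.g] + posterior[site.h]
print(saturation.mean(axis=0))  # posterior mean, shape (3, 2)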

Saving the model#

We can save the model, posterior samples, and other objects using pickle for later analysis.

import pickle

# Save the model
output_path = os.path.join(current_directory, "model.pkl")
with open(output_path, "wb") as f:
    pickle.dump((model,), f)
print(f"Saved model to {output_path}")

# Save the dataframe, encoder, and posterior samples
output_path = os.path.join(current_directory, "inference.pkl")
with open(output_path, "wb") as f:
    pickle.dump((df, encoder, posterior,), f)
print(f"Saved samples to {output_path}")

# Save the MCMC object
output_path = os.path.join(current_directory, "mcmc.pkl")
with open(output_path, "wb") as f:
    pickle.dump((mcmc,), f)
print(f"Saved MCMC object to {output_path}")

Using other functions#

Alternatively, we can use other functions to estimate the recruitment curves. The following choices are available:

  • logistic-4, also known as the Boltzmann sigmoid, is the most common function used to estimate recruitment curves

  • logistic-5 is a more general version of logistic-4

  • rectified-linear

If estimating the threshold is not important, we recommend logistic-5 over logistic-4, since it has much better predictive performance.

For example, to use the logistic-5 function, update the model-building code to point to it; the rest of the tutorial remains the same.

# Set the function to logistic-5
model._model = model.logistic5