import logging
import numpy as np
import jax
import jax.numpy as jnp
import numpyro as pyro
import numpyro.distributions as dist
from hbmep import functional as F
from hbmep.model import BaseModel
from hbmep.util import site
logger = logging.getLogger(__name__)

# Small constant forwarded as the last positional argument of
# `F.rectified_logistic` in `StandardHB.rectified_logistic`; its exact role is
# defined by `hbmep.functional`.
EPS = 1e-3
class StandardHB(BaseModel):
"""
Standard hierarchical Bayesian model.
Features and responses are modeled as conditionally independent given
their model parameters. In particular, each feature and response combination
has its own set of curve and likelihood parameters, which are partially
pooled using shared hyperpriors.
This class implements models with three curve functions: rectified-logistic,
logistic5, and logistic4. Observations are modeled using a gamma likelihood.
Notes
-----
- The curve function can be selected as:
`model._model = model.rectified_logistic` (default),
`model._model = model.logistic5`, or
`model._model = model.logistic4`.
- The priors are specified for TMS data with
intensity in 0-100% MSO and response in millivolts (mV). To change the
prior specification, inherit from this class and override the
`_sample_priors` method.
- Setting `self.use_mixture = True` uses a mixture likelihood with an
additional outlier component.
"""
def __init__(self, *args, **kw):
super(StandardHB, self).__init__(*args, **kw)
self.use_mixture = False
def _model(self, *args, **kw):
return self.rectified_logistic(*args, **kw)
def _sample_priors(
self,
*,
num_features,
include_v: bool,
):
# Hyper-priors
a_loc = pyro.sample(
site.a.loc, dist.TruncatedNormal(50., 50., low=0)
)
a_scale = pyro.sample(site.a.scale, dist.HalfNormal(50.))
b_scale = pyro.sample(site.b.scale, dist.HalfNormal(5.))
g_scale = pyro.sample(site.g.scale, dist.HalfNormal(.1))
h_scale = pyro.sample(site.h.scale, dist.HalfNormal(5.))
if include_v:
v_scale = pyro.sample(site.v.scale, dist.HalfNormal(5.))
c1_scale = pyro.sample(site.c1.scale, dist.HalfNormal(5.))
c2_scale = pyro.sample(site.c2.scale, dist.HalfNormal(.5))
params = {}
with pyro.plate(site.num_response, self.num_response):
with pyro.plate_stack(
site.num_features, num_features, rightmost_dim=-2
):
params[site.a] = pyro.sample(
site.a, dist.TruncatedNormal(a_loc, a_scale, low=0)
)
b_raw = pyro.sample(site.b.raw, dist.HalfNormal(1))
params[site.b] = pyro.deterministic(site.b, b_scale * b_raw)
g_raw = pyro.sample(site.g.raw, dist.HalfNormal(1))
params[site.g] = pyro.deterministic(site.g, g_scale * g_raw)
h_raw = pyro.sample(site.h.raw, dist.HalfNormal(1))
params[site.h] = pyro.deterministic(site.h, h_scale * h_raw)
if include_v:
v_raw = pyro.sample(site.v.raw, dist.HalfNormal(1))
params[site.v] = pyro.deterministic(site.v, v_scale * v_raw)
c1_raw = pyro.sample(site.c1.raw, dist.HalfNormal(1))
params[site.c1] = pyro.deterministic(site.c1, c1_scale * c1_raw)
c2_raw = pyro.sample(site.c2.raw, dist.HalfNormal(1))
params[site.c2] = pyro.deterministic(site.c2, c2_scale * c2_raw)
return params
def _observe(
self,
*,
mu,
response,
mask_obs,
num_data,
params,
features,
):
c1 = params[site.c1]
c2 = params[site.c2]
g = params.get(site.g)
h = params.get(site.h)
if self.use_mixture:
if response is None:
q = 0.
else:
q = pyro.sample(site.outlier_prob, dist.Uniform(0., 0.01))
with pyro.handlers.mask(mask=mask_obs):
with pyro.plate(site.num_response, self.num_response):
with pyro.plate(site.num_data, num_data):
mu = pyro.deterministic(site.mu, mu)
alpha, beta = self.gamma_likelihood(
mu, c1[*features.T], c2[*features.T]
)
if self.use_mixture:
mixing_distribution = dist.Categorical(
probs=jnp.stack([1 - q, q], axis=-1)
)
component_distributions = [
dist.Gamma(concentration=alpha, rate=beta),
dist.HalfNormal(
scale=(g[*features.T] + h[*features.T])
),
]
dist_ = dist.MixtureGeneral(
mixing_distribution=mixing_distribution,
component_distributions=component_distributions,
)
else:
dist_ = dist.Gamma(concentration=alpha, rate=beta)
y_ = pyro.sample(site.obs, dist_, obs=response)
if self.use_mixture:
log_probs = dist_.component_log_probs(y_)
pyro.deterministic(
"p",
log_probs - jax.nn.logsumexp(
log_probs, axis=-1, keepdims=True
),
)
def rectified_logistic(self, intensity, features, response=None, **kw):
num_data = intensity.shape[0]
num_features = np.max(features, axis=0) + 1
mask_obs = True
if response is not None:
mask_obs = np.invert(np.isnan(response))
params = self._sample_priors(
num_features=num_features,
include_v=True,
)
mu = F.rectified_logistic(
intensity,
params[site.a][*features.T],
params[site.b][*features.T],
params[site.g][*features.T],
params[site.h][*features.T],
params[site.v][*features.T],
EPS,
)
self._observe(
mu=mu,
response=response,
mask_obs=mask_obs,
num_data=num_data,
params=params,
features=features,
)
def logistic5(self, intensity, features, response=None, **kw):
num_data = intensity.shape[0]
num_features = np.max(features, axis=0) + 1
mask_obs = True
if response is not None:
mask_obs = np.invert(np.isnan(response))
params = self._sample_priors(
num_features=num_features,
include_v=True,
)
mu = F.logistic5(
intensity,
params[site.a][*features.T],
params[site.b][*features.T],
params[site.g][*features.T],
params[site.h][*features.T],
params[site.v][*features.T],
)
self._observe(
mu=mu,
response=response,
mask_obs=mask_obs,
num_data=num_data,
params=params,
features=features,
)
def logistic4(self, intensity, features, response=None, **kw):
num_data = intensity.shape[0]
num_features = np.max(features, axis=0) + 1
mask_obs = True
if response is not None:
mask_obs = np.invert(np.isnan(response))
params = self._sample_priors(
num_features=num_features,
include_v=False,
)
mu = F.logistic4(
intensity,
params[site.a][*features.T],
params[site.b][*features.T],
params[site.g][*features.T],
params[site.h][*features.T],
)
self._observe(
mu=mu,
response=response,
mask_obs=mask_obs,
num_data=num_data,
params=params,
features=features,
)