DM_DIRICHLET_SUM
This function summarizes a Dirichlet distribution for Bayesian categorical modeling. It computes the Dirichlet mean and variance for each category and evaluates both the density and log-density at a supplied probability vector on the simplex.
For concentration parameters \boldsymbol{\alpha}=(\alpha_1,\ldots,\alpha_K), the mean and variance of each category probability are:
\mathbb{E}[\theta_i]=\frac{\alpha_i}{\alpha_0}, \qquad \mathrm{Var}(\theta_i)=\frac{\alpha_i(\alpha_0-\alpha_i)}{\alpha_0^2(\alpha_0+1)}
where \alpha_0=\sum_{i=1}^K \alpha_i. The input probability vector must satisfy \sum_i x_i=1 with each x_i\in(0,1).
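You can check the moment formulas above, and the standard Dirichlet density f(\mathbf{x};\boldsymbol{\alpha})=\frac{\Gamma(\alpha_0)}{\prod_i\Gamma(\alpha_i)}\prod_i x_i^{\alpha_i-1}, without SciPy. The following is a minimal sketch in plain Python. It assumes the inputs have already been validated as described. `dirichlet_moments` and `dirichlet_logpdf` are hypothetical helper names and are not part of DM_DIRICHLET_SUM. The sketch works on the log scale with `math.lgamma` so that large concentrations do not overflow the gamma function.

```python
import math
from typing import List, Tuple


def dirichlet_moments(alpha: List[float]) -> Tuple[List[float], List[float]]:
    """Closed-form per-category mean and variance (hypothetical helper)."""
    a0 = sum(alpha)  # alpha_0, the total concentration
    mean = [a / a0 for a in alpha]
    var = [a * (a0 - a) / (a0 ** 2 * (a0 + 1)) for a in alpha]
    return mean, var


def dirichlet_logpdf(x: List[float], alpha: List[float]) -> float:
    """Dirichlet log-density, assuming x lies strictly inside the simplex (hypothetical helper)."""
    a0 = sum(alpha)
    # log of the normalizing constant Gamma(alpha_0) / prod_i Gamma(alpha_i)
    log_norm = math.lgamma(a0) - sum(math.lgamma(a) for a in alpha)
    # log of the kernel prod_i x_i^(alpha_i - 1)
    log_kernel = sum((a - 1.0) * math.log(xi) for xi, a in zip(x, alpha))
    return log_norm + log_kernel


if __name__ == "__main__":
    alpha = [2.0, 2.0, 2.0]
    x = [0.3, 0.4, 0.3]
    mean, var = dirichlet_moments(alpha)
    logpdf = dirichlet_logpdf(x, alpha)
    print([round(m, 6) for m in mean])  # [0.333333, 0.333333, 0.333333]
    print([round(v, 6) for v in var])   # [0.031746, 0.031746, 0.031746]
    print(round(math.exp(logpdf), 6), round(logpdf, 5))  # 4.32 1.46326
```

Exponentiating the log-density recovers the density. The printed values agree with Example 1.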
Excel Usage
=DM_DIRICHLET_SUM(alpha, x)
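alpha (list[list], required): Dirichlet concentration parameters as a 2D range of positive values. The range is read row by row, and it must contain at least two values.
x (list[list], required): Category probability vector as a 2D range with entries in (0, 1) that sum to 1. It is read row by row and must contain the same number of values as alpha.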
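Returns (list[list]): 2D array with one column per category. The first row holds the means and the second row the variances. The third row holds the density and log-density, padded with blank cells to the same width.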
Example 1: Symmetric three-category Dirichlet summary
Inputs:
| alpha | | | x | | |
|---|---|---|---|---|---|
| 2 | 2 | 2 | 0.3 | 0.4 | 0.3 |
Excel formula:
=DM_DIRICHLET_SUM({2,2,2}, {0.3,0.4,0.3})
Expected output:
| Result | ||
|---|---|---|
| 0.333333 | 0.333333 | 0.333333 |
| 0.031746 | 0.031746 | 0.031746 |
| 4.32 | 1.46326 | |
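Here is a hand check of the third row for Example 1. It evaluates the Dirichlet density f(\mathbf{x};\boldsymbol{\alpha})=\frac{\Gamma(\alpha_0)}{\prod_i\Gamma(\alpha_i)}\prod_i x_i^{\alpha_i-1} with \alpha_0=6. Each exponent \alpha_i-1 equals 1, so the kernel is simply the product of the x_i:

```latex
f(\mathbf{x};\boldsymbol{\alpha})
  = \frac{\Gamma(6)}{\Gamma(2)^3}\,(0.3)(0.4)(0.3)
  = 120 \times 0.036
  = 4.32,
\qquad \ln 4.32 \approx 1.46326
```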
Example 2: Skewed four-category concentration with valid simplex point
Inputs:
| alpha | | | | x | | | |
|---|---|---|---|---|---|---|---|
| 5 | 2 | 3 | 4 | 0.4 | 0.1 | 0.2 | 0.3 |
Excel formula:
=DM_DIRICHLET_SUM({5,2,3,4}, {0.4,0.1,0.2,0.3})
Expected output:
| Result | |||
|---|---|---|---|
| 0.357143 | 0.142857 | 0.214286 | 0.285714 |
| 0.0153061 | 0.00816327 | 0.0112245 | 0.0136054 |
| 59.7794 | 4.09066 | | |
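The variance row depends on each \alpha_i and on the total \alpha_0. Here is a worked instance of the variance formula for the first category in Example 2:

```latex
\alpha_0 = 5 + 2 + 3 + 4 = 14, \qquad
\mathrm{Var}(\theta_1) = \frac{5\,(14 - 5)}{14^2\,(14 + 1)} = \frac{45}{2940} \approx 0.0153061
```

The other three entries follow the same pattern with numerators 2 \cdot 12, 3 \cdot 11 and 4 \cdot 10 over the same denominator 2940.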
Example 3: Matrix-shaped input ranges are flattened correctly
Inputs:
| alpha | | x | |
|---|---|---|---|
| 3 | 1 | 0.25 | 0.1 |
| 2 | 4 | 0.15 | 0.5 |
Excel formula:
=DM_DIRICHLET_SUM({3,1;2,4}, {0.25,0.1;0.15,0.5})
Expected output:
| Result | |||
|---|---|---|---|
| 0.3 | 0.1 | 0.2 | 0.4 |
| 0.0190909 | 0.00818182 | 0.0145455 | 0.0218182 |
| 35.4375 | 3.56777 | | |
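Example 3 works because both ranges are read in the same cell order, so each \alpha_i stays paired with the x_i from the same position. The following is a minimal sketch of row-major flattening. It assumes the ranges arrive as nested Python lists, as they do in the implementation. `flatten_row_major` is a hypothetical helper name.

```python
from typing import List, Sequence


def flatten_row_major(rng: Sequence[Sequence[float]]) -> List[float]:
    """Read a 2D range left to right within a row, then top to bottom."""
    return [float(v) for row in rng for v in row]


alpha = flatten_row_major([[3, 1], [2, 4]])
x = flatten_row_major([[0.25, 0.1], [0.15, 0.5]])

# Each alpha_i is paired with the x_i taken from the same cell position.
pairs = list(zip(alpha, x))
print(pairs)  # [(3.0, 0.25), (1.0, 0.1), (2.0, 0.15), (4.0, 0.5)]

alpha_0 = sum(alpha)
print([a / alpha_0 for a in alpha])  # [0.3, 0.1, 0.2, 0.4]
```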
Example 4: Concentrated prior with dominant category probability
Inputs:
| alpha | | | x | | |
|---|---|---|---|---|---|
| 10 | 2 | 1 | 0.75 | 0.15 | 0.1 |
Excel formula:
=DM_DIRICHLET_SUM({10,2,1}, {0.75,0.15,0.1})
Expected output:
| Result | ||
|---|---|---|
| 0.769231 | 0.153846 | 0.0769231 |
| 0.0126796 | 0.00929839 | 0.00507185 |
| 14.8668 | 2.69913 | |
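In Example 4, \alpha_3=1, so the third kernel factor is x_3^{0}=1. The density therefore depends on x_3 only through the simplex constraint. Here is a hand check of the density row with \alpha_0=13:

```latex
\frac{\Gamma(13)}{\Gamma(10)\,\Gamma(2)\,\Gamma(1)} = \frac{12!}{9!} = 1320, \qquad
f(\mathbf{x};\boldsymbol{\alpha}) = 1320 \times 0.75^{9} \times 0.15^{1} \times 0.1^{0}
  \approx 1320 \times 0.0112627 \approx 14.8668
```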
Python Code
```python
import math

import numpy as np
from scipy.stats import dirichlet as scipy_dirichlet


def dm_dirichlet_sum(alpha, x):
    """
    Compute Dirichlet density and moments for a category-probability vector.

    See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet.html

    This example function is provided as-is without any representation of accuracy.

    Args:
        alpha (list[list]): Dirichlet concentration parameters as a 2D range of positive values.
        x (list[list]): Category probability vector as a 2D range with entries in (0, 1) that sum to 1.

    Returns:
        list[list]: 2D array with mean row, variance row, and a row containing density and log-density.
    """
    try:
        def to2d(v):
            # A single cell arrives as a scalar; wrap it so it looks like a 1x1 range.
            return [[v]] if not isinstance(v, list) else v

        def flatten_numeric(mat):
            # Flatten a 2D range row by row, skipping non-numeric cells such as blanks.
            flat = []
            for row in mat:
                if not isinstance(row, list):
                    return None
                for val in row:
                    try:
                        flat.append(float(val))
                    except (TypeError, ValueError):
                        continue
            return flat

        alpha = to2d(alpha)
        x = to2d(x)
        alpha_flat = flatten_numeric(alpha)
        x_flat = flatten_numeric(x)

        if alpha_flat is None or x_flat is None:
            return "Error: alpha and x must be 2D lists"
        if len(alpha_flat) < 2:
            return "Error: alpha must contain at least two positive values"
        if len(alpha_flat) != len(x_flat):
            return "Error: alpha and x must have the same number of elements"
        # NaN fails every comparison, so it would slip past the range checks below.
        if not all(math.isfinite(v) for v in alpha_flat + x_flat):
            return "Error: alpha and x must contain only finite numbers"
        if any(a <= 0 for a in alpha_flat):
            return "Error: alpha values must be positive"
        if any((xi <= 0) or (xi >= 1) for xi in x_flat):
            return "Error: x values must be strictly between 0 and 1"
        x_sum = float(sum(x_flat))
        if abs(x_sum - 1.0) > 1e-8:
            return "Error: x values must sum to 1"

        alpha_arr = np.asarray(alpha_flat, dtype=float)
        # Renormalize so x also passes SciPy's stricter internal simplex tolerance.
        x_arr = np.asarray(x_flat, dtype=float) / x_sum

        mean = scipy_dirichlet.mean(alpha_arr).tolist()
        var = scipy_dirichlet.var(alpha_arr).tolist()
        pdf_value = float(scipy_dirichlet.pdf(x_arr, alpha_arr))
        logpdf_value = float(scipy_dirichlet.logpdf(x_arr, alpha_arr))

        # Pad the density row with blanks so all three rows have the same width.
        width = len(alpha_flat)
        summary_row = [pdf_value, logpdf_value] + [""] * max(0, width - 2)
        return [mean, var, summary_row]
    except Exception as e:
        return f"Error: {str(e)}"
```