DM_POST_UPDATE

This function performs conjugate Bayesian updating for categorical data. Because the Dirichlet distribution is conjugate to the multinomial likelihood, a Dirichlet prior combined with multinomial counts yields a Dirichlet posterior: each category's hyperparameter is incremented by that category's observed count.

If the prior hyperparameters are \boldsymbol{\alpha} = (\alpha_1, \ldots, \alpha_K) and the observed counts are \mathbf{n} = (n_1, \ldots, n_K) for K categories, then for each category i:

\alpha_i' = \alpha_i + n_i
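The update rule follows from Bayes' theorem. Writing \boldsymbol{\theta} = (\theta_1, \ldots, \theta_K) for the unknown category probabilities over K categories (notation introduced here for the derivation) and dropping every factor that does not depend on \boldsymbol{\theta}, such as the multinomial coefficient and the Dirichlet normalizing constant:

```latex
p(\boldsymbol{\theta} \mid \mathbf{n})
  \propto \underbrace{\prod_{i=1}^{K} \theta_i^{n_i}}_{\text{multinomial likelihood}}
  \times \underbrace{\prod_{i=1}^{K} \theta_i^{\alpha_i - 1}}_{\text{Dirichlet prior}}
  = \prod_{i=1}^{K} \theta_i^{(\alpha_i + n_i) - 1}
```

The right-hand side is the kernel of a Dirichlet distribution with hyperparameters \alpha_i + n_i, which is exactly the update \alpha_i' = \alpha_i + n_i.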

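The function returns both the posterior hyperparameters and the posterior predictive means \alpha_i' / \sum_j \alpha_j'. Each mean is the posterior expectation of category i's probability, which also equals the posterior predictive probability that the next observation falls in category i.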
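Both outputs can be reproduced in a few lines of NumPy. This is a minimal sketch, not the DM_POST_UPDATE code itself: dirichlet_posterior is a hypothetical helper name, it skips all input validation, and the Monte Carlo check (using NumPy's Generator.dirichlet) only illustrates that \alpha_i' / \sum_j \alpha_j' is the mean of the posterior Dirichlet.

```python
import numpy as np

def dirichlet_posterior(alpha_prior, counts):
    """Hypothetical helper: return (posterior hyperparameters, posterior means)."""
    alpha = np.asarray(alpha_prior, dtype=float)
    n = np.asarray(counts, dtype=float)
    alpha_post = alpha + n                   # alpha_i' = alpha_i + n_i
    means = alpha_post / alpha_post.sum()    # alpha_i' / sum_j alpha_j'
    return alpha_post, means

# Example 1 from this page.
alpha_post, means = dirichlet_posterior([1, 1, 1], [5, 2, 3])
print(alpha_post)  # [6. 3. 4.]

# Monte Carlo check: averaging category probabilities drawn from
# Dirichlet(alpha_post) should approach the closed-form means (6/13, 3/13, 4/13).
rng = np.random.default_rng(seed=0)
draws = rng.dirichlet(alpha_post, size=100_000)
print(np.abs(draws.mean(axis=0) - means).max() < 0.01)  # True
```

With 100,000 draws the sampling error of each average is well under 0.01, so the check is a loose sanity test rather than a precise estimate.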

Excel Usage

=DM_POST_UPDATE(alpha_prior, counts)
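  • alpha_prior (list[list], required): Prior Dirichlet hyperparameters as a 2D range of positive values; at least two categories are required.
  • counts (list[list], required): Observed category counts as a 2D range of nonnegative integer values, with the same number of values as alpha_prior. Both ranges are read row by row, so their shapes may differ as long as the number of values matches.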

Returns (list[list]): 2D array with posterior hyperparameters in the first row and posterior predictive means in the second row, or an error message string if the inputs are invalid.

Example 1: Three-category posterior update with moderate counts

Inputs:

alpha_prior      counts
1 1 1            5 2 3

Excel formula:

=DM_POST_UPDATE({1,1,1}, {5,2,3})

Expected output:

Result
6 3 4
0.461538 0.230769 0.307692
Example 2: Informative prior combined with larger sample counts

Inputs:

alpha_prior      counts
10 4 6           12 8 5

Excel formula:

=DM_POST_UPDATE({10,4,6}, {12,8,5})

Expected output:

Result
22 12 11
0.488889 0.266667 0.244444
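Because the update only adds counts, the posterior after one batch of data can serve as the prior for the next batch, and the grouping of the data does not change the result. The sketch below checks this with Example 2's numbers. It is illustrative only: update is a hypothetical helper, and the split of the counts into batch_1 and batch_2 is an arbitrary hypothetical choice.

```python
import numpy as np

def update(alpha, counts):
    """Hypothetical helper: conjugate Dirichlet update (add counts to hyperparameters)."""
    return np.asarray(alpha, dtype=float) + np.asarray(counts, dtype=float)

prior = [10, 4, 6]
all_counts = [12, 8, 5]

# Hypothetical split of Example 2's counts into two batches; any split that
# sums to [12, 8, 5] gives the same posterior.
batch_1 = [5, 3, 2]
batch_2 = [7, 5, 3]

one_shot = update(prior, all_counts)
# The posterior from batch 1 becomes the prior for batch 2.
sequential = update(update(prior, batch_1), batch_2)

print(one_shot)                               # [22. 12. 11.]
print(np.array_equal(one_shot, sequential))   # True
```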
Example 3: Matrix-shaped prior and counts are flattened consistently

Inputs:

alpha_prior      counts
2 3              1 0
4 5              6 2

Excel formula:

=DM_POST_UPDATE({2,3;4,5}, {1,0;6,2})

Expected output:

Result
3 3 10 7
0.130435 0.130435 0.434783 0.304348
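Example 3 pairs the two matrices cell by cell after reading each one row by row (row-major order). A minimal sketch of that pairing, using a hypothetical flatten_row_major helper that mirrors the row-then-value loop in the Python code on this page:

```python
from itertools import chain

alpha_prior = [[2, 3], [4, 5]]
counts = [[1, 0], [6, 2]]

def flatten_row_major(mat):
    """Hypothetical helper: flatten a list of rows left to right, top to bottom."""
    return [float(v) for v in chain.from_iterable(mat)]

alpha_flat = flatten_row_major(alpha_prior)   # [2.0, 3.0, 4.0, 5.0]
counts_flat = flatten_row_major(counts)       # [1.0, 0.0, 6.0, 2.0]

# Add matching cells, then normalize by the posterior total (23).
posterior = [a + n for a, n in zip(alpha_flat, counts_flat)]
total = sum(posterior)
means = [p / total for p in posterior]

print(posterior)                       # [3.0, 3.0, 10.0, 7.0]
print([round(m, 6) for m in means])    # [0.130435, 0.130435, 0.434783, 0.304348]
```

Since only the flattened lengths are compared, a 1×4 counts range such as {1,0,6,2} would pair with the 2×2 prior in the same order.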
Example 4: Sparse counts preserve influence of prior

Inputs:

alpha_prior      counts
0.8 1.5 2 0.7    0 1 0 2

Excel formula:

=DM_POST_UPDATE({0.8,1.5,2,0.7}, {0,1,0,2})

Expected output:

Result
0.8 2.5 2 2.7
0.1 0.3125 0.25 0.3375
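Example 4 shows why sparse counts leave the prior visible. Writing A = \sum_j \alpha_j and N = \sum_j n_j (symbols introduced here), the posterior mean is a weighted average of the prior mean \alpha_i / A and the observed frequency n_i / N, with weight A / (A + N) on the prior: (\alpha_i + n_i)/(A + N) = \frac{A}{A+N} \cdot \frac{\alpha_i}{A} + \frac{N}{A+N} \cdot \frac{n_i}{N}. Here A = 5 and N = 3, so the prior keeps 5/8 of the weight. A short check of that identity (requires N > 0):

```python
import numpy as np

alpha = np.array([0.8, 1.5, 2.0, 0.7])    # Example 4 prior
n = np.array([0, 1, 0, 2], dtype=float)   # Example 4 counts

A, N = alpha.sum(), n.sum()               # prior pseudo-count total and sample size
w = A / (A + N)                           # weight on the prior mean

prior_mean = alpha / A
empirical = n / N                         # observed category frequencies
posterior_mean = (alpha + n) / (A + N)

# The posterior mean equals the weighted blend of prior mean and data.
blend = w * prior_mean + (1 - w) * empirical
print(round(w, 6))                          # 0.625
print(np.allclose(posterior_mean, blend))   # True
```

As N grows relative to A, the weight on the prior shrinks and the posterior means approach the observed frequencies.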

Python Code

import numpy as np

def dm_post_update(alpha_prior, counts):
    """
    Update Dirichlet posterior parameters from prior hyperparameters and observed counts.

    See: https://en.wikipedia.org/wiki/Dirichlet_distribution#Conjugate_to_categorical_or_multinomial

    This example function is provided as-is without any representation of accuracy.

    Args:
        alpha_prior (list[list]): Prior Dirichlet hyperparameters as a 2D range of positive values.
        counts (list[list]): Observed category counts as a 2D range of nonnegative integer values.

    Returns:
        list[list]: 2D array with posterior hyperparameters in the first row and posterior predictive means in the second row.
    """
    try:
        def to2d(v):
            # Excel passes a single-cell range as a scalar; wrap it as a 1x1 range.
            return [[v]] if not isinstance(v, list) else v

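        def flatten_numeric(mat):
            # Flatten row by row (row-major order). Values that cannot be
            # converted to float (e.g. blank or text cells) are skipped, so
            # blanks in different positions of the two ranges shift categories.
            if not isinstance(mat, list) or not all(isinstance(row, list) for row in mat):
                return None
            flat = []
            for row in mat:
                for val in row:
                    try:
                        flat.append(float(val))
                    except (TypeError, ValueError):
                        continue
            return flat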

        alpha_prior = to2d(alpha_prior)
        counts = to2d(counts)

        alpha_flat = flatten_numeric(alpha_prior)
        counts_flat = flatten_numeric(counts)

        if alpha_flat is None or counts_flat is None:
            return "Error: alpha_prior and counts must be 2D lists"
        if len(alpha_flat) < 2:
            return "Error: alpha_prior must contain at least two positive values"
        if len(alpha_flat) != len(counts_flat):
            return "Error: alpha_prior and counts must have the same number of elements"
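        # Reject zero, negative, NaN, and infinite hyperparameters; NaN would
        # otherwise slip past the `a <= 0` comparison.
        if any(not np.isfinite(a) or a <= 0 for a in alpha_flat):
            return "Error: alpha_prior values must be positive and finite"

        # Counts must be finite, whole, and nonnegative (e.g. 3.0 is accepted as 3).
        clean_counts = []
        for c in counts_flat:
            if not np.isfinite(c):
                return "Error: counts must contain nonnegative integers"
            c_int = int(round(c))
            if abs(c - c_int) > 1e-9 or c_int < 0:
                return "Error: counts must contain nonnegative integers"
            clean_counts.append(c_int)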

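        # Conjugate update: alpha_i' = alpha_i + n_i for each category.
        posterior = (np.asarray(alpha_flat, dtype=float) + np.asarray(clean_counts, dtype=float)).tolist()
        # Posterior predictive means: alpha_i' / sum_j alpha_j'.
        total_post = float(sum(posterior))
        means = [p / total_post for p in posterior]

        return [posterior, means]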
    except Exception as e:
        return f"Error: {str(e)}"
