Когда я впервые начал свою аспирантуру в 2021 году, я узнал о байесовских нейронных сетях и MCMC. Это был трудный процесс, и мне удалось написать рукопись на эту тему. Однако, поскольку я хочу исследовать другую сторону байесовских методов вывода, я встретил своего нынешнего научного руководителя, ученого, которого очень уважаю. Он дал мне сильное руководство по изучению вариационного байесовского подхода.

Все более сложные модели машинного обучения, которые включают большое количество параметров и предъявляют высокие требования к вычислительным ресурсам с точки зрения обучения, используются во многих статистических приложениях для решения проблем с большими данными.

Байесовские методы обеспечивают естественный и принципиальный инструмент для статистического вывода с использованием апостериорного распределения вероятностей.

Апостериорное распределение по параметрам θ модели обновляется с использованием априорного распределения вероятностей и функции правдоподобия после наблюдения за данными.

Поскольку апостериорное распределение часто имеет большую размерность и сильно невыпукло, задача вычисления и выборки из него может быть решена с помощью некоторых подходов, таких как цепь Маркова Монте-Карло (MCMC) и вариационный вывод.

Что такое вариационный байесовский анализ (VB)?

В классе методов вариационного вывода методы вариационного Байеса (VB) представляют собой семейство методов, которые аппроксимируют сложное и дорогостоящее апостериорное распределение, предоставляя аналитическое решение q(θ), которое принадлежит некоторому поддающемуся решению семейству распределений Q. Локально оптимизированное приближение q* в Q получается путем минимизации расхождения Кульбака-Лейбера (KL) между P(θ | данные) и q(θ).

VB в кодах

Чтобы реализовать VB для оценки параметров моделей, я использовал Python и JAX для ускорения кода. JAX — это NumPy на ЦП, ГП и ТПУ с отличной автоматической дифференциацией для высокопроизводительных исследований в области машинного обучения.

JAX требовал глубокого обучения, так как мне пришлось переписать большую часть кода, чтобы он был совместим с фреймворком. Самый сложный аспект JAX заключается в том, что, в отличие от массивов NumPy, массивы JAX всегда неизменяемы. Этот пример ниже взят из документации JAX:

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10  1  2  3  4  5  6  7  8  9]

Вы можете изменять элементы на месте в массиве Python. Однако выполнение этого в массивах JAX приведет к ошибкам.

# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-7-6b90817377fe> in <module>()
      1 # JAX: immutable arrays
      2 x = jnp.arange(10)
----> 3 x[0] = 10

TypeError: '<class 'jax.interpreters.xla._DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

Я представлю полный код Python для Gaussian VB. Код доказал свою работоспособность во многих примерах, от смоделированных до реальных.

Во-первых, мы обсудим, какие библиотеки мы используем для выполнения этого кода.

import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
from jax import random
from functools import partial

import time
import math
from tqdm import tqdm
import sklearn.mixture
import pandas as pd
import matplotlib.pyplot as plt

from jax.scipy.stats import multivariate_normal
from scipy.stats import multivariate_normal as ss_multivariate_normal
from sklearn.neighbors import KernelDensity
import optax

Во-вторых, мы выбираем, с какой проблемой мы хотим работать: регрессией или классификацией. В этом примере мы используем данные об абалоне, чтобы предсказать возраст морского ушка на основе физических измерений. Возраст морского ушка определяют, разрезая раковину через конус, окрашивая ее и подсчитывая количество колец под микроскопом — скучная и трудоемкая задача. Другие измерения, которые легче получить, используются для предсказания возраста. Дополнительную информацию о наборе данных можно найти по адресу https://archive.ics.uci.edu/dataset/1/abalone.

Затем мы определяем модель, используемую для прогнозирования целевой переменной. Выбранная модель представляет собой нейронную сеть прямого распространения с заданным количеством скрытых слоев. Ковариаты для входного слоя включают пол, длину, размер самой длинной скорлупы, диаметр, высоту, общий вес, вес очищенной скорлупы, вес внутренностей, вес скорлупы и кольца.

Мы используем Flax для формулировки нашей модели. Flax обеспечивает сквозной и гибкий пользовательский интерфейс для исследователей, использующих JAX с нейронными сетями.

from flax import linen as nn
from typing import Sequence, Tuple

class Model(nn.Module):
    features: Sequence[int] = (10,10,5,1)

    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x

model = Model()

Для GVB необходимо будет сгладить структуру параметров, к которой мы привыкли в нейронных сетях. Следовательно, нам нужно сгенерировать некоторые начальные параметры из модели в виде «pytree», а затем «распутать» pytree.

# Initialize the key
init_rng = jax.random.split(rng, 1)
# Initialize the model
params = model.init(init_rng, X_train)
# Unravel the pytree
fp, unravel_fn = jax.flatten_util.ravel_pytree(params)

Для интереса посмотрим на количество коэффициентов модели.

num_coeffs = len(jax.flatten_util.ravel_pytree(params)[0])

Важной частью GVB является формула Байеса, которая позволяет нам вычислять апостериорное распределение, логарифм которого представляет собой сумму logpdf априорного распределения и вероятности. Мы применяем многомерную нормальную априорную оценку параметров.

def prior(theta):
    prior = multivariate_normal.logpdf(theta, mean = jnp.array([0] * num_coeffs), cov = jnp.identity(num_coeffs))
    return prior

Функция правдоподобия выглядит следующим образом.

def likelihood(theta, data_input, actual_data, tausq):
    params = unravel_fn(theta)
    # Obtain the logits and predictions of the model for the input data
    preds = model.apply(params, data_input)

    n = actual_data.shape[0]  # will change for multiple outputs (y.shape[0]*y.shape[1])
    log_lhood = -n/2 * jnp.log(2 * jnp.pi * tausq) - (1/(2*tausq)) * jnp.sum(jnp.square(actual_data - preds))
    return log_lhood

Основная идея GVB состоит в том, чтобы оптимизировать q*, чтобы свести к минимуму расхождение KL. q* принадлежит гауссову семейству; следовательно, его pdf представляет собой многомерное распределение Гаусса N(mu, sigma) со следующим logpdf.

def fun_log_q(theta, mu, l):
    log_q = multivariate_normal.logpdf(theta, mean = mu, cov= jnp.linalg.inv(l @ l.T))

Как видно из вышеизложенного, q* зависит от λ = (mu, L), где L — нижняя треугольная матрица сигмы.