Batch normalization: нормализация по батчу шаг за шагом

Глубокая сеть учится тяжело, когда распределение входов каждого слоя плавает от батча к батчу: предыдущие слои меняют веса, и следующий слой каждый раз видит «новую» статистику. Batch normalization (нормализация по батчу) решает это просто - приводит активации внутри мини-батча к нулевому среднему и единичной дисперсии, а затем разрешает сети самой подобрать нужный масштаб и сдвиг. Ниже разберём все четыре шага прямого прохода, роль обучаемых параметров и тонкость с поведением на инференсе. Чтобы пощупать механику на числах - покрутите слайдеры калькулятора чуть ниже.
Что делает batch normalization
Batch normalization - это слой, который вставляют между линейным (или свёрточным) преобразованием и функцией активации. Он берёт активации одного признака по всему мини-батчу, центрирует и масштабирует их. Идея, предложенная Иоффе и Сегеди в 2015 году, в том, что стабильная статистика входов позволяет учиться с большими шагами и меньше зависеть от удачной инициализации весов.
Ключевое слово здесь - «по батчу»: среднее и дисперсия считаются не по одному примеру, а по всем примерам мини-батча для каждого нейрона отдельно. Поэтому размер батча напрямую влияет на качество оценки статистики.
Четыре шага прямого прохода
Пусть в мини-батче примеров, и мы смотрим на один нейрон с активациями . Слой делает ровно четыре операции:
Сначала считается среднее по батчу , затем дисперсия (со знаменателем , это смещённая оценка). На третьем шаге каждое значение центрируется и делится на стандартное отклонение - получаем с нулевым средним и единичной дисперсией. Четвёртый шаг - аффинное преобразование с обучаемыми и .

Зачем нужен eps
В знаменателе стоит не просто , а . Малая константа (обычно ) защищает от деления на ноль: если все активации в батче одинаковы, дисперсия равна нулю, и без нормировка взорвалась бы. Заодно слегка сглаживает численную нестабильность при очень маленькой дисперсии.
Значение eps стоит держать одинаковым на обучении и на инференсе. Разные eps дадут чуть разные выходы и могут незаметно ухудшить метрики при переносе модели.
Зачем обучаемые gamma и beta
После третьего шага активации жёстко зажаты в стандартное распределение. Но это не всегда полезно: например, перед сигмоидой нулевое среднее и единичная дисперсия загоняют вход в почти линейную область, и сеть теряет нелинейность. Параметры (масштаб) и (сдвиг) возвращают слою свободу.
Важное свойство: при и преобразование становится тождественным - слой может полностью «отменить» нормализацию, если она вредна. То есть BN ничего не отнимает у сети: в худшем случае она научится игнорировать слой. Эти параметры обучаются обычным обратным распространением ошибки наравне с весами.
Обучение против инференса
Это самая частая ловушка. На обучении и считаются по текущему мини-батчу. Но на инференсе примеры часто приходят по одному, и статистики батча просто нет - а даже если есть, ответ не должен зависеть от того, какие соседи случайно попали в батч.
Поэтому во время обучения слой накапливает скользящие средние:
На инференсе вместо статистики батча подставляются именно эти накопленные и - они фиксированы, и выход для каждого примера детерминирован. Здесь - momentum (в PyTorch по умолчанию 0.1). Забыть переключить модель в режим инференса (model.eval()) - классическая причина «модель отлично училась, но на тесте бред».
Обратный проход: как течёт градиент
BN - дифференцируемая операция, поэтому градиент проходит через неё насквозь, но с одной особенностью: и сами зависят от всех батча, значит производная по конкретному учитывает вклад этого примера и в среднее, и в дисперсию. Из-за этого градиент по входу не локален - он «перемешивает» примеры внутри батча.
Сами параметры обновляются просто. Градиенты по обучаемым и - это суммы по батчу:
Здесь видно, зачем BN усредняет именно по батчу: и прямой, и обратный проход связывают примеры вместе. Это же объясняет лёгкий регуляризующий эффект - для каждого примера нормировка зависит от случайных соседей по батчу, то есть в обучение подмешивается шум.

BN в свёрточных сетях
В полносвязном слое нормализация идёт по каждому нейрону отдельно. В свёрточной сети логика та же, но усреднение учитывает пространственную структуру. Для тензора активаций размера среднее и дисперсия считаются по осям , , - то есть по всем примерам и всем пространственным позициям, но отдельно для каждого канала .
В итоге обучается по одному и одному на канал, а не на каждый пиксель. Это согласуется с идеей свёртки: один фильтр применяется ко всей карте, значит и нормировать его выход естественно одинаково по всему полю.
Куда ставить слой: до или после активации
Исходная статья Иоффе и Сегеди располагала BN между линейным преобразованием и нелинейностью: сначала , потом BN, потом ReLU. Логика - нормализовать именно вход нелинейности, чтобы он попадал в её «рабочую» зону, а не насыщался.
Со временем появился альтернативный порядок «активация, затем нормализация», и в разных архитектурах побеждали разные варианты - однозначного ответа нет, это вопрос эксперимента. Зато есть устойчивое практическое правило: если перед BN стоит линейный слой, его смещение можно убрать. Параметр в BN всё равно задаёт сдвиг, поэтому отдельное смещение избыточно и просто тратит параметры. Во многих фреймворках при связке «линейный слой плюс BN» bias отключают явно.
Что даёт нормализация по батчу
Практические эффекты, ради которых BN стал почти стандартом:
- Большая скорость обучения. Стабильная статистика входов позволяет увеличить learning rate без расходимости.
- Меньшая чувствительность к инициализации. Сеть прощает не идеально подобранные начальные веса.
- Лёгкая регуляризация. Шум от того, что статистика считается по случайному батчу, работает как слабый dropout - иногда BN позволяет уменьшить явную регуляризацию.
- Сглаживание ландшафта потерь. Более поздние работы показали, что главный вклад BN - не столько борьба с internal covariate shift, сколько то, что он делает поверхность функции потерь более гладкой и предсказуемой.
Частые ошибки
- Маленький батч. При или оценка и слишком шумная, и BN скорее вредит. Для таких случаев используют layer norm или group norm.
- Забытый
eval(). На инференсе модель продолжает считать статистику по батчу - ответы плавают и зависят от состава батча. - BN перед линейным слоем без активации. Два линейных преобразования подряд схлопываются в одно; смысл нормализации между ними теряется. BN ставят перед нелинейностью.
- Утечка статистики теста в обучение. Если считать по объединённым train и test данным - это утечка, метрики окажутся завышенными.
- Двойная регуляризация. BN уже даёт лёгкий шумовой эффект; агрессивный dropout поверх него иногда мешает сходимости.
FAQ
Чем batch normalization отличается от layer normalization? BN усредняет по примерам батча для каждого признака, поэтому зависит от размера батча. Layer normalization усредняет по признакам внутри одного примера и от батча не зависит - поэтому её предпочитают в трансформерах и при батче размера 1.
Можно ли применять BN к входным данным вместо обычной нормировки? Технически да, но обычно вход нормируют один раз статически (вычитают среднее по обучающей выборке). BN полезнее во внутренних слоях, где статистика активаций меняется во время обучения.
Почему дисперсия делится на m, а не на m минус 1? В формуле BN используется смещённая оценка дисперсии (знаменатель ). Это сознательный выбор авторов: на больших батчах разница ничтожна, а формула проще и согласована со скользящими статистиками.
Коротко
Batch normalization приводит активации мини-батча к нулевому среднему и единичной дисперсии за четыре шага: среднее , дисперсия , нормировка и аффинный сдвиг . Обучаемые и возвращают сети свободу масштаба и при нужных значениях делают слой тождественным. На инференсе вместо статистики батча используются накопленные за обучение скользящие среднее и дисперсия - поэтому критично переключать модель в режим оценки. BN ускоряет обучение, ослабляет зависимость от инициализации и сглаживает ландшафт потерь.
Читайте также

Алгоритм обратного распространения ошибки: как учится сеть
Backpropagation простыми словами: как обратное распространение ошибки считает градиенты по цепному правилу, обновляет веса нейросети и при чём тут исчезающий градиент. С формулами и разбором.

Self-attention механизм: как токен смотрит на контекст
Self-attention механизм простыми словами: почему Q, K, V берутся из одной последовательности, как слой собирает контекст для каждого токена, зачем позиционное кодирование и где ошибаются.

Механизм внимания attention: формула и примеры
Механизм внимания attention: как работает scaled dot-product attention, формула softmax(QK^T/sqrt(d_k))V, матрица весов, multi-head и где чаще всего ошибаются.