目 录CONTENT

文章目录

机器学习数据增强的完整指南

Administrator
2026-01-16 / 0 评论 / 0 点赞 / 0 阅读 / 0 字

📢 转载信息

原文链接:https://machinelearningmastery.com/the-complete-guide-to-data-augmentation-for-machine-learning/

原文作者:Kanwal Mehreen


在本文中,您将学习使用数据增强来减少过拟合和提高图像、文本、音频和表格数据集泛化能力的实用、安全的方法。


我们将涵盖的主题包括:

  • 增强的工作原理以及何时有帮助。
  • 在线与离线增强策略。
  • 图像(TensorFlow/Keras)、文本(NLTK)、音频(librosa)和表格数据(NumPy/Pandas)的实操示例,以及数据泄露的关键陷阱。

好了,我们开始吧。

The Complete Guide to Data Augmentation for Machine Learning

机器学习数据增强的完整指南
图片作者提供

假设您已经构建了机器学习模型,运行了实验,并对结果感到困惑,不确定哪里出了问题。训练准确率看起来不错,甚至可能令人印象深刻,但当您检查验证准确率时……效果却不理想。您可以通过获取更多数据来解决这个问题。但这既耗时、昂贵,有时甚至是不可能的。


这并非是凭空捏造假数据。而是通过对已有的数据进行细微修改,同时不改变其含义或标签,来创建新的训练样本。您以多种形式向模型展示同一个概念。您是在教导模型什么重要,什么可以忽略。增强有助于模型泛化,而不是简单地记忆训练集。在本文中,您将了解数据增强在实践中如何工作以及何时使用它。具体来说,我们将涵盖:

  • 什么是数据增强以及它如何帮助减少过拟合
  • 离线数据增强与在线数据增强的区别
  • 如何使用TensorFlow对图像数据应用增强
  • 文本数据的简单安全增强技术
  • 音频和表格数据集的常见增强方法
  • 为什么增强过程中的数据泄露会悄无声息地破坏您的模型

离线与在线数据增强

增强可以在训练前或训练期间发生。离线增强仅扩展数据集一次并保存。在线增强则在每个周期(epoch)生成新的变体。深度学习管道通常偏爱在线增强,因为它能让模型接触到有效不受限制的变化,而无需增加存储空间。

图像数据增强

图像数据增强是最直观的入门点。一只狗即使经过轻微的旋转、缩放或在不同光照条件下观察,它仍然是狗。您的模型需要在训练期间看到这些变化。一些常见的图像增强技术包括:

  • 旋转 (Rotation)
  • 翻转 (Flipping)
  • 调整大小 (Resizing)
  • 裁剪 (Cropping)
  • 缩放 (Zooming)
  • 平移 (Shifting)
  • 剪切 (Shearing)
  • 亮度和对比度变化

这些变换不会改变标签——只改变外观。让我们使用TensorFlowKeras来演示一个简单的例子:

1. 导入库

1
2
3
4
5
6
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential

2. 加载 MNIST 数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
(X_train, y_train), (X_test, y_test) = mnist.load_data()
 
# Normalize pixel values
X_train = X_train / 255.0
X_test = X_test / 255.0
 
# Reshape to (samples, height, width, channels)
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)
 
# One-hot encode labels
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

输出:

1
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

3. 定义用于增强的ImageDataGenerator

1
2
3
4
5
6
7
8
9
datagen = ImageDataGenerator(
   rotation_range=15,       # rotate images by ±15 degrees
   width_shift_range=0.1,   # 10% horizontal shift
   height_shift_range=0.1,      # 10% vertical shift
   zoom_range=0.1,          # zoom in/out by 10%
   shear_range=0.1,         # apply shear transformation
   horizontal_flip=False,   # not needed for digits
   fill_mode='nearest'      # fill missing pixels after transformations
)

4. 构建一个简单的CNN模型

1
2
3
4
5
6
7
8
9
10
11
12
model = Sequential([
   Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
   MaxPooling2D((2, 2)),
   Conv2D(64, (3, 3), activation='relu'),
   MaxPooling2D((2, 2)),
   Flatten(),
   Dropout(0.3),
   Dense(64, activation='relu'),
   Dense(10, activation='softmax')
])
 
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

5. 训练模型

1
2
3
4
5
6
7
8
9
batch_size = 64
epochs = 5
 
history = model.fit(
   datagen.flow(X_train, y_train, batch_size=batch_size, shuffle=True),
   steps_per_epoch=len(X_train)//batch_size,
   epochs=epochs,
   validation_data=(X_test, y_test)
)

输出:

Output of training

6. 可视化增强后的图像

1
2
3
4
5
6
7
8
9
10
11
import matplotlib.pyplot as plt
 
# Visualize five augmented variants of the first training sample
plt.figure(figsize=(10, 2))
for i, batch in enumerate(datagen.flow(X_train[:1], batch_size=1)):
   plt.subplot(1, 5, i + 1)
   plt.imshow(batch[0].reshape(28, 28), cmap='gray')
   plt.axis('off')
   if i == 4:
       break
plt.show()

输出:

Output of augmentation

文本数据增强

文本更加微妙。您不能随意替换单词而不考虑其含义。但是,小的、受控的更改可以帮助模型泛化。这是一个使用同义词替换(结合NLTK)的简单示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import nltk
from nltk.corpus import wordnet
import random
 
nltk.download("wordnet")
nltk.download("omw-1.4")
 
def synonym_replacement(sentence):
    words = sentence.split()
    if not words:
        return sentence
    idx = random.randint(0, len(words) - 1)
    synsets = wordnet.synsets(words[idx])
    if synsets and synsets[0].lemmas():
        replacement = synsets[0].lemmas()[0].name().replace("_", " ")
        words[idx] = replacement
    return " ".join(words)
 
text = "The movie was really good"
print(synonym_replacement(text))

输出:

1
2
[nltk_data] Downloading package wordnet to /root/nltk_data...
The movie was truly good

含义相同。新的训练样本。在实践中,像nlpaug这样的库或反向翻译API常被用于获得更可靠的结果。

音频数据增强

音频数据也从增强中受益良多。一些常见的音频增强技术包括:

  • 添加背景噪音
  • 时间拉伸 (Time stretching)
  • 音高偏移 (Pitch shifting)
  • 音量缩放 (Volume scaling)

最简单和最常用的音频增强之一是添加背景噪音和时间拉伸。这有助于语音和声音模型在嘈杂的现实环境中表现得更好。让我们通过一个简单的例子(使用librosa)来理解:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import librosa
import numpy as np
 
# Load built-in trumpet audio from librosa
audio_path = librosa.ex("trumpet")
audio, sr = librosa.load(audio_path, sr=None)
 
# Add background noise
noise = np.random.randn(len(audio))
audio_noisy = audio + 0.005 * noise
 
# Time stretching
audio_stretched = librosa.effects.time_stretch(audio, rate=1.1)
 
print("Sample rate:", sr)
print("Original length:", len(audio))
print("Noisy length:", len(audio_noisy))
print("Stretched length:", len(audio_stretched))

输出:

1
2
3
4
5
Downloading file 'sorohanro_-_solo-trumpet-06.ogg' from 'https://librosa.org/data/audio/sorohanro_-_solo-trumpet-06.ogg' to '/root/.cache/librosa'.
Sample rate: 22050
Original length: 117601... [内容被截断]



🚀 想要体验更好更全面的AI调用?

欢迎使用青云聚合API,约为官网价格的十分之一,支持300+全球最新模型,以及全球各种生图生视频模型,无需翻墙高速稳定,文档丰富,小白也可以简单操作。

0

评论区