目 录CONTENT

文章目录

语言模型微调的温和介绍

Administrator
2026-01-07 / 0 评论 / 0 点赞 / 0 阅读 / 0 字

📢 转载信息

原文链接:https://machinelearningmastery.com/a-gentle-introduction-to-language-model-fine-tuning/

原文作者:Adrian Tam


在预训练之后,语言模型已经学习了关于人类语言的知识。您可以通过在附加数据上对其进行训练,来增强模型在特定领域的理解能力。此外,当您提供特定的指令时,您也可以训练模型来执行特定的任务。这些在预训练之后的附加训练被称为微调(fine-tuning)。在本文中,您将学习如何微调语言模型。具体来说,您将学习:

  • 不同微调的示例及其目标是什么
  • 如何将预训练脚本转换为执行微调

让我们开始吧!

语言模型微调的温和介绍
图片来源:Nick Night。保留部分权利。

概览

本文分为四个部分;它们是:

  • 微调模型的理由
  • 用于微调的数据集
  • 微调过程
  • 其他微调技术

微调模型的理由

一旦您训练了您的仅解码器(decoder-only)的Transformer模型,您就拥有了一个文本生成器。您可以提供任何提示(prompt),模型将生成一些文本。它生成的内容取决于您拥有的模型。

让我们考虑一个非常简单的生成算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
...
def apply_repetition_penalty(logits: Tensor, tokens: list[int], penalty: float) -> Tensor:
    """Apply repetition penalty to the logits."""
    for tok in tokens:
        if logits[tok] > 0:
            logits[tok] /= penalty
        else:
            logits[tok] *= penalty
    return logits
 
 
@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0, repetition_penalty=1.0,
             repetition_penalty_range=10, top_k=50, device=None) -> str:
    """Generate text autoregressively from a prompt.
 
    Args:
        model: The trained LlamaForPretraining model
        tokenizer: The tokenizer
        prompt: Input text prompt
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        repetition_penalty: Penalty for repeating tokens
        repetition_penalty_range: Number of previous tokens to consider for repetition penalty
        top_k: Only sample from top k most likely tokens
        device: Device the model is loaded on
 
    Returns:
        Generated text
    """
    # Turn model to evaluation mode: Norm layer will work differently
    model.eval()
 
    # Get special token IDs
    bot_id = tokenizer.token_to_id("[BOT]")
    eot_id = tokenizer.token_to_id("[EOT]")
 
    # Tokenize the prompt into integer tensor
    prompt_tokens = [bot_id] + tokenizer.encode(" " + prompt).ids
    input_ids = torch.tensor(prompt_tokens, dtype=torch.int64, device=device).unsqueeze(0)
 
    # Recursively generate tokens
    generated_tokens = []
    for _step in range(max_tokens):
        # Forward pass through model
        logits = model(input_ids)
 
        # Get logits for the last token
        next_token_logits = logits[0, -1, :] / temperature
 
        # Apply repetition penalty
        if repetition_penalty != 1.0 and len(generated_tokens) > 0:
            next_token_logits = apply_repetition_penalty(
                next_token_logits,
                generated_tokens[-repetition_penalty_range:],
                repetition_penalty,
            )
 
        # Apply top-k filtering
        if top_k > 0:
            top_k_logits = torch.topk(next_token_logits, top_k)[0]
            indices_to_remove = next_token_logits < top_k_logits[-1]
            next_token_logits[indices_to_remove] = float("-inf")
 
        # Sample from the filtered distribution
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
 
        # Early stop if EOT token is generated
        if next_token.item() == eot_id:
            break
 
        # Append the new token to input_ids for next iteration
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        generated_tokens.append(next_token.item())
 
    # Decode all generated tokens
    return tokenizer.decode(generated_tokens)

上面函数 generate() 是一种效率低下但简单的基于采样的文本生成方法。您的模型接收一个提示,并为下一个 token 生成一个 logits(对数几率)张量。之所以称为 logits,是因为它们与下一个 token 概率的对数成正比。模型处理的是 tokens。要生成一个 token,需要对 logits 执行几个步骤:

  1. 使用 temperature(温度)参数来缩放 logits。这会影响下一个 token 选择的概率分布。
  2. 操作 logits。在上述代码中,您应用了重复惩罚(repetition penalty)来惩罚已存在于生成的 token 序列中的 token。您还应用了 top-k 过滤,将选择限制在最有可能的 $k$ 个 token 中。
  3. 将 logits 转换为概率,然后使用多项式采样算法(multinomial sampling algorithm)来选择下一个 token。

您可以通过始终使用 torch.argmax() 来选择下一个 token,使这个过程更简单。这被称为贪婪解码(greedy decoding)。它通常不被推荐,因为它产生的输出看起来不自然,并且不允许任何变化。

您可以尝试使用前一篇文章中训练的模型进行测试。下面是一个使用简单提示生成文本的完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
13... [内容被截断]
0

评论区