
A Gentle Introduction to Language Model Fine-Tuning


📢 Repost Information

Original article: https://machinelearningmastery.com/a-gentle-introduction-to-language-model-fine-tuning/

Original author: Adrian Tam


After pretraining, a language model has learned knowledge about human language. You can strengthen the model's understanding of a particular domain by training it on additional data. You can also train the model to perform a specific task when you provide a particular instruction. This additional training after pretraining is called fine-tuning. In this article, you will learn how to fine-tune a language model. Specifically, you will learn:

  • Different examples of fine-tuning and their goals
  • How to convert a pretraining script to perform fine-tuning

Let's get started!

A Gentle Introduction to Language Model Fine-Tuning
Photo by Nick Night. Some rights reserved.

Overview

This article is divided into four parts; they are:

  • Why fine-tune a model
  • Datasets for fine-tuning
  • The fine-tuning process
  • Other fine-tuning techniques

Why Fine-Tune a Model

Once you have trained your decoder-only Transformer model, you have a text generator. You can provide any prompt, and the model will generate some text. What it generates depends on the model you use.

Let's consider a very simple generation algorithm:

...
def apply_repetition_penalty(logits: Tensor, tokens: list[int], penalty: float) -> Tensor:
    """Apply repetition penalty to the logits."""
    for tok in tokens:
        if logits[tok] > 0:
            logits[tok] /= penalty
        else:
            logits[tok] *= penalty
    return logits
 
 
@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0, repetition_penalty=1.0,
             repetition_penalty_range=10, top_k=50, device=None) -> str:
    """Generate text autoregressively from a prompt.
 
    Args:
        model: The trained LlamaForPretraining model
        tokenizer: The tokenizer
        prompt: Input text prompt
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        repetition_penalty: Penalty for repeating tokens
        repetition_penalty_range: Number of previous tokens to consider for repetition penalty
        top_k: Only sample from top k most likely tokens
        device: Device the model is loaded on
 
    Returns:
        Generated text
    """
    # Turn model to evaluation mode: Norm layer will work differently
    model.eval()
 
    # Get special token IDs
    bot_id = tokenizer.token_to_id("[BOT]")
    eot_id = tokenizer.token_to_id("[EOT]")
 
    # Tokenize the prompt into integer tensor
    prompt_tokens = [bot_id] + tokenizer.encode(" " + prompt).ids
    input_ids = torch.tensor(prompt_tokens, dtype=torch.int64, device=device).unsqueeze(0)
 
    # Autoregressively generate tokens, one at a time
    generated_tokens = []
    for _step in range(max_tokens):
        # Forward pass through model
        logits = model(input_ids)
 
        # Get logits for the last token
        next_token_logits = logits[0, -1, :] / temperature
 
        # Apply repetition penalty
        if repetition_penalty != 1.0 and len(generated_tokens) > 0:
            next_token_logits = apply_repetition_penalty(
                next_token_logits,
                generated_tokens[-repetition_penalty_range:],
                repetition_penalty,
            )
 
        # Apply top-k filtering
        if top_k > 0:
            top_k_logits = torch.topk(next_token_logits, top_k)[0]
            indices_to_remove = next_token_logits < top_k_logits[-1]
            next_token_logits[indices_to_remove] = float("-inf")
 
        # Sample from the filtered distribution
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
 
        # Early stop if EOT token is generated
        if next_token.item() == eot_id:
            break
 
        # Append the new token to input_ids for next iteration
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        generated_tokens.append(next_token.item())
 
    # Decode all generated tokens
    return tokenizer.decode(generated_tokens)

The generate() function in the code above is an inefficient but simple sampling-based method of text generation. Your model takes a prompt and produces a tensor of logits representing the next token. They are called logits because they are proportional to the logarithm of the next-token probabilities. The model works with tokens. To produce a token, several steps are applied to the logits:

  1. Scale the logits with the temperature parameter. This changes the probability distribution from which the next token is picked.
  2. Manipulate the logits. In the code above, you apply a repetition penalty to penalize tokens that already appear in the generated token sequence, and top-k filtering to limit the choice to the $k$ most likely tokens.
  3. Convert the logits into probabilities, then use a multinomial sampling algorithm to pick the next token. A short toy sketch after this list walks through these three steps.
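
To make these steps concrete, here is a minimal, self-contained sketch (not part of the original listing) that applies temperature scaling, top-k filtering, and multinomial sampling to a toy logits vector of six tokens:

import torch
import torch.nn.functional as F

# Toy logits for a vocabulary of six tokens (values chosen only for illustration)
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -2.0, 0.0])
temperature = 0.8
top_k = 3

# Step 1: scale the logits by the temperature
scaled = logits / temperature

# Step 2: keep only the top-k logits and mask the rest with -inf
kth_value = torch.topk(scaled, top_k).values[-1]
scaled[scaled < kth_value] = float("-inf")

# Step 3: convert to probabilities and sample one token id
probs = F.softmax(scaled, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.item())  # always one of the three most likely token ids: 0, 1, or 2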

You can simplify this process by always using torch.argmax() to pick the next token. This is called greedy decoding. It is usually not recommended because the output it produces looks unnatural and does not allow any variation; a one-line sketch of the change follows.
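
If you want to try greedy decoding, one possible change (a sketch, not part of the original listing) is to replace the softmax and multinomial lines in generate() with an argmax:

# Greedy decoding: always pick the single most likely token instead of sampling
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)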

You can experiment with the model trained in the previous article. Below is a complete code listing for generating text from a simple prompt:

...
def apply_repetition_penalty(logits: Tensor, tokens: list[int], penalty: float) -> Tensor:
    """Apply repetition penalty to the logits."""
    for tok in tokens:
        if logits[tok] > 0:
            logits[tok] /= penalty
        else:
            logits[tok] *= penalty
    return logits
 
 
@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0, repetition_penalty=1.0,
             repetition_penalty_range=10, top_k=50, device=None) -> str:
    """Generate text autoregressively from a prompt.
 
    Args:
        model: The trained LlamaForPretraining model
        tokenizer: The tokenizer
        prompt: Input text prompt
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        repetition_penalty: Penalty for repeating tokens
        repetition_penalty_range: Number of previous tokens to consider for repetition penalty
        top_k: Only sample from top k most likely tokens
        device: Device the model is loaded on
 
    Returns:
        Generated text
    """
    # Turn model to evaluation mode: Norm layer will work differently
    model.eval()
 
    # Get special token IDs
    bot_id = tokenizer.token_to_id("[BOT]")
    eot_id = tokenizer.token_to_id("[EOT]")
 
    # Tokenize the prompt into integer tensor
    prompt_tokens = [bot_id] + tokenizer.encode(" " + prompt).ids
    input_ids = torch.tensor(prompt_tokens, dtype=torch.int64, device=device).unsqueeze(0)
 
    # Autoregressively generate tokens, one at a time
    generated_tokens = []
    for _step in range(max_tokens):
        # Forward pass through model
        logits = model(input_ids)
 
        # Get logits for the last token
        next_token_logits = logits[0, -1, :] / temperature
 
        # Apply repetition penalty
        if repetition_penalty != 1.0 and len(generated_tokens) > 0:
            next_token_logits = apply_repetition_penalty(
                next_token_logits,
                generated_tokens[-repetition_penalty_range:],
                repetition_penalty,
            )
 
        # Apply top-k filtering
        if top_k > 0:
            top_k_logits = torch.topk(next_token_logits, top_k)[0]
            indices_to_remove = next_token_logits < top_k_logits[-1]
            next_token_logits[indices_to_remove] = float("-inf")
 
        # Sample from the filtered distribution
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
 
        # Early stop if EOT token is generated
        if next_token.item() == eot_id:
            break
 
        # Append the new token to input_ids for next iteration
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        generated_tokens.append(next_token.item())
 
    # Decode all generated tokens
    return tokenizer.decode(generated_tokens)
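
The listing above is truncated and does not show the part of the script that loads the trained checkpoint and calls generate(). As a rough sketch only, assuming a tokenizer file, a checkpoint file, and the LlamaForPretraining constructor from the previous article (the file names, the model construction, and model_config are all assumptions, not the original code), a driver section might look like this:

# Hypothetical driver: file names, model construction, and config below are assumptions
import torch
from tokenizers import Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and the pretrained model produced by the previous article's script
tokenizer = Tokenizer.from_file("tokenizer.json")                       # assumed path
model = LlamaForPretraining(model_config)                               # assumed constructor and config
model.load_state_dict(torch.load("pretrained_model.pth", map_location=device))
model.to(device)

# Generate a continuation for a simple prompt
text = generate(model, tokenizer, "The quick brown fox", max_tokens=50,
                temperature=0.8, repetition_penalty=1.2, top_k=50, device=device)
print(text)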


