📢 Reposted article
Original link: https://machinelearningmastery.com/a-gentle-introduction-to-language-model-fine-tuning/
Original author: Adrian Tam
After pretraining, a language model has learned knowledge about human language. You can strengthen the model's understanding of a particular domain by training it on additional data. You can also train the model to perform specific tasks when you give it particular instructions. This additional training after pretraining is called fine-tuning. In this post, you will learn how to fine-tune a language model. In particular, you will learn:
- Different examples of fine-tuning and their goals
- How to convert a pretraining script to perform fine-tuning
Let's get started!
A Gentle Introduction to Language Model Fine-Tuning
Photo by Nick Night. Some rights reserved.
Overview
This post is divided into four parts; they are:
- Why Fine-Tune a Model
- Datasets for Fine-Tuning
- The Fine-Tuning Process
- Other Fine-Tuning Techniques
Why Fine-Tune a Model
Once you have trained a decoder-only transformer model, you have a text generator. You can provide any prompt, and the model will generate some text. What it generates depends on the model you have.
Let's consider a very simple generation algorithm:
```python
...

def apply_repetition_penalty(logits: Tensor, tokens: list[int], penalty: float) -> Tensor:
    """Apply repetition penalty to the logits."""
    for tok in tokens:
        if logits[tok] > 0:
            logits[tok] /= penalty
        else:
            logits[tok] *= penalty
    return logits


@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0, repetition_penalty=1.0,
             repetition_penalty_range=10, top_k=50, device=None) -> str:
    """Generate text autoregressively from a prompt.

    Args:
        model: The trained LlamaForPretraining model
        tokenizer: The tokenizer
        prompt: Input text prompt
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        repetition_penalty: Penalty for repeating tokens
        repetition_penalty_range: Number of previous tokens to consider for repetition penalty
        top_k: Only sample from top k most likely tokens
        device: Device the model is loaded on

    Returns:
        Generated text
    """
    # Turn model to evaluation mode: Norm layer will work differently
    model.eval()

    # Get special token IDs
    bot_id = tokenizer.token_to_id("[BOT]")
    eot_id = tokenizer.token_to_id("[EOT]")

    # Tokenize the prompt into integer tensor
    prompt_tokens = [bot_id] + tokenizer.encode(" " + prompt).ids
    input_ids = torch.tensor(prompt_tokens, dtype=torch.int64, device=device).unsqueeze(0)

    # Recursively generate tokens
    generated_tokens = []
    for _step in range(max_tokens):
        # Forward pass through model
        logits = model(input_ids)
        # Get logits for the last token
        next_token_logits = logits[0, -1, :] / temperature
        # Apply repetition penalty
        if repetition_penalty != 1.0 and len(generated_tokens) > 0:
            next_token_logits = apply_repetition_penalty(
                next_token_logits,
                generated_tokens[-repetition_penalty_range:],
                repetition_penalty,
            )
        # Apply top-k filtering
        if top_k > 0:
            top_k_logits = torch.topk(next_token_logits, top_k)[0]
            indices_to_remove = next_token_logits < top_k_logits[-1]
            next_token_logits[indices_to_remove] = float("-inf")
        # Sample from the filtered distribution
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        # Early stop if EOT token is generated
        if next_token.item() == eot_id:
            break
        # Append the new token to input_ids for next iteration
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        generated_tokens.append(next_token.item())

    # Decode all generated tokens
    return tokenizer.decode(generated_tokens)
```
The generate() function above is an inefficient but simple method of sampling-based text generation. Your model takes a prompt and produces a tensor of logits for the next token. They are called logits because they are proportional to the logarithm of the next token's probability. The model operates on tokens. To generate one token, the logits go through several processing steps, illustrated by the short sketch after this list:
- Scale the logits with the temperature parameter. This affects the probability distribution from which the next token is picked.
- Manipulate the logits. In the code above, you applied a repetition penalty to penalize tokens that already appear in the generated token sequence. You also applied top-$k$ filtering to restrict the choice to the $k$ most likely tokens.
- Convert the logits into probabilities, then use the multinomial sampling algorithm to pick the next token.
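The following is a minimal, self-contained sketch of these three steps on a made-up logits vector; the vocabulary size, temperature, and `k` below are illustrative values and are not taken from the article's model.

```python
import torch
import torch.nn.functional as F

# Made-up logits for a toy 6-token vocabulary (illustration only)
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -2.0, 0.0])
temperature = 0.9   # < 1 sharpens the distribution, > 1 flattens it
k = 3               # keep only the 3 most likely tokens

# Step 1: temperature scaling
scaled = logits / temperature

# Step 2: top-k filtering -- mask out everything below the k-th largest logit
kth_largest = torch.topk(scaled, k).values[-1]
scaled[scaled < kth_largest] = float("-inf")

# Step 3: convert to probabilities and sample one token id
probs = F.softmax(scaled, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(probs)               # masked tokens have probability 0
print(next_token.item())   # sampled token id, varies from run to run
```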
You could make this process even simpler by always picking the next token with torch.argmax(). This is called greedy decoding. It is usually not recommended because the output tends to look unnatural and allows no variation.
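For comparison, here is a minimal sketch of greedy decoding on the same made-up toy logits; it replaces the sampling step with a deterministic argmax.

```python
import torch

# Same made-up logits as in the sketch above (illustration only)
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -2.0, 0.0])

# Greedy decoding: always take the single most likely token,
# so the same prompt always yields the same continuation.
next_token = torch.argmax(logits, dim=-1)
print(next_token.item())  # always 0, the index of the largest logit
```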
You may try it out with the model you trained in the previous post. Below is the complete code for generating text from a simple prompt:
```python
import dataclasses

import tokenizers
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


# Model architecture same as training script
@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""
    vocab_size: int = 50000
    max_position_embeddings: int = 2048
    hidden_size: int = 768
    intermediate_size: int = 4*4096
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    num_key_value_heads: int = 3


class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding."""

    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        N = 10_000.0
        inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2) / dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(max_position_embeddings)
        sinusoid_inp = torch.outer(position, inv_freq)
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x: Tensor) -> Tensor:
        batch_size, seq_len, num_heads, head_dim = x.shape
        device = x.device
        dtype = x.dtype
        cos = self.cos.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        # rotate-half trick: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (rotated * sin)


class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads
        assert (self.head_dim * self.num_heads) == self.hidden_size
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:
        bs, seq_len, dim = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        attn_output = F.scaled_dot_product_attention(
            rope(query_states).transpose(1, 2),
            rope(key_states).transpose(1, 2),
            value_states.transpose(1, 2),
            is_causal=True,
            dropout_p=0.0,
            enable_gqa=True,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
        return self.o_proj(attn_output)


class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(hidden_states, rope=rope)
        hidden_states = attn_outputs + residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        return self.mlp(hidden_states) + residual


class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def forward(self, input_ids: Tensor) -> Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, rope=self.rotary_emb)
        return self.norm(hidden_states)


class LlamaForPretraining(nn.Module):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: Tensor) -> Tensor:
        hidden_states = self.base_model(input_ids)
        return self.lm_head(hidden_states)


def apply_repetition_penalty(logits: Tensor, tokens: list[int], penalty: float) -> Tensor:
    """Apply repetition penalty to the logits."""
    for tok in tokens:
        if logits[tok] > 0:
            logits[tok] /= penalty
        else:
            logits[tok] *= penalty
    return logits


@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0, repetition_penalty=1.0,
             repetition_penalty_range=10, top_k=50, device=None) -> str:
    """Generate text autoregressively from a prompt.

    Args:
        model: The trained LlamaForPretraining model
        tokenizer: The tokenizer
        prompt: Input text prompt
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        repetition_penalty: Penalty for repeating tokens
        repetition_penalty_range: Number of previous tokens to consider for repetition penalty
        top_k: Only sample from top k most likely tokens
        device: Device the model is loaded on

    Returns:
        Generated text
    """
    # Turn model to evaluation mode: Norm layer will work differently
    model.eval()
    # Get special token IDs
    bot_id = tokenizer.token_to_id("[BOT]")
    eot_id = tokenizer.token_to_id("[EOT]")
    # Tokenize the prompt into integer tensor
    prompt_tokens = [bot_id] + tokenizer.encode(" " + prompt).ids
    input_ids = torch.tensor([prompt_tokens], dtype=torch.int64, device=device)
    # Recursively generate tokens
    generated_tokens = []
    for _step in range(max_tokens):
        # Forward pass through model
        logits = model(input_ids)
        # Get logits for the last token
        next_token_logits = logits[0, -1, :] / temperature
        # Apply repetition penalty
        if repetition_penalty != 1.0 and len(generated_tokens) > 0:
            next_token_logits = apply_repetition_penalty(
                next_token_logits,
                generated_tokens[-repetition_penalty_range:],
                repetition_penalty,
            )
        # Apply top-k filtering
        if top_k > 0:
            top_k_logits = torch.topk(next_token_logits, top_k)[0]
            indices_to_remove = next_token_logits < top_k_logits[-1]
            next_token_logits[indices_to_remove] = float("-inf")
        # Sample from the filtered distribution
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        # Early stop if EOT token is generated
        if next_token.item() == eot_id:
            break
        # Append the new token to input_ids for next iteration
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        generated_tokens.append(next_token.item())
    # Decode all generated tokens
    return tokenizer.decode(generated_tokens)
checkpoint = "llama_model_final.pth"   # saved model checkpoint
tokenizer = "bpe_50K.json"             # saved tokenizer
max_tokens = 100
temperature = 0.9
top_k = 50
penalty = 1.1
penalty_range = 10

# Load tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = tokenizers.Tokenizer.from_file(tokenizer)
config = LlamaConfig()
model = LlamaForPretraining(config).to(device)
model.load_state_dict(torch.load(checkpoint, map_location=device))

prompt = "Once upon a time, there was"
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=max_tokens,
    temperature=temperature,
    top_k=top_k,
    repetition_penalty=penalty,
    repetition_penalty_range=penalty_range,
    device=device,
)
print(prompt)
print("-" * 20)
print(response)
```