📢 Repost information
Original article: https://machinelearningmastery.com/a-gentle-introduction-to-language-model-fine-tuning/
Original author: Adrian Tam
After pretraining, a language model has learned a great deal about human language. You can strengthen the model's understanding of a particular domain by training it on additional data, and you can train it to perform a specific task when given particular instructions. This additional training after pretraining is called fine-tuning. In this article, you will learn how to fine-tune a language model. In particular, you will learn:
- Different examples of fine-tuning and their goals
- How to convert a pretraining script to perform fine-tuning
Let's get started!
A Gentle Introduction to Language Model Fine-Tuning
Photo by Nick Night. Some rights reserved.
Overview
This post is divided into four parts; they are:
- Why fine-tune a model
- Datasets for fine-tuning
- The fine-tuning workflow
- Other fine-tuning techniques
Why fine-tune a model
Once you have trained a decoder-only transformer model, you have a text generator. You can feed it any prompt, and the model will produce some text. What it generates depends on the model you have.
Let's consider a very simple generation algorithm:
```python
...

def apply_repetition_penalty(logits: Tensor, tokens: list[int], penalty: float) -> Tensor:
    """Apply repetition penalty to the logits."""
    for tok in tokens:
        if logits[tok] > 0:
            logits[tok] /= penalty
        else:
            logits[tok] *= penalty
    return logits

@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0,
             repetition_penalty=1.0, repetition_penalty_range=10, top_k=50,
             device=None) -> str:
    """Generate text autoregressively from a prompt.

    Args:
        model: The trained LlamaForPretraining model
        tokenizer: The tokenizer
        prompt: Input text prompt
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        repetition_penalty: Penalty for repeating tokens
        repetition_penalty_range: Number of previous tokens to consider for repetition penalty
        top_k: Only sample from top k most likely tokens
        device: Device the model is loaded on

    Returns:
        Generated text
    """
    # Turn model to evaluation mode: Norm layer will work differently
    model.eval()

    # Get special token IDs
    bot_id = tokenizer.token_to_id("[BOT]")
    eot_id = tokenizer.token_to_id("[EOT]")

    # Tokenize the prompt into integer tensor
    prompt_tokens = [bot_id] + tokenizer.encode(" " + prompt).ids
    input_ids = torch.tensor(prompt_tokens, dtype=torch.int64, device=device).unsqueeze(0)

    # Recursively generate tokens
    generated_tokens = []
    for _step in range(max_tokens):
        # Forward pass through model
        logits = model(input_ids)
        # Get logits for the last token
        next_token_logits = logits[0, -1, :] / temperature
        # Apply repetition penalty
        if repetition_penalty != 1.0 and len(generated_tokens) > 0:
            next_token_logits = apply_repetition_penalty(
                next_token_logits,
                generated_tokens[-repetition_penalty_range:],
                repetition_penalty,
            )
        # Apply top-k filtering
        if top_k > 0:
            top_k_logits = torch.topk(next_token_logits, top_k)[0]
            indices_to_remove = next_token_logits < top_k_logits[-1]
            next_token_logits[indices_to_remove] = float("-inf")
        # Sample from the filtered distribution
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        # Early stop if EOT token is generated
        if next_token.item() == eot_id:
            break
        # Append the new token to input_ids for next iteration
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        generated_tokens.append(next_token.item())

    # Decode all generated tokens
    return tokenizer.decode(generated_tokens)
```
The generate() function above is a less efficient but simple sampling-based method of text generation. Your model takes a prompt and produces a tensor of logits for the next token. They are called logits because they are proportional to the log of the probability of the next token. The model operates on tokens. To produce one token, several steps are applied to the logits:
- Scale the logits with a temperature parameter. This skews the probability of picking the next token.
- Manipulate the logits. Above, you applied a repetition penalty, which penalizes tokens that already appear in the generated sequence. You also applied top-k filtering, which restricts the choice to the $k$ most likely tokens.
- Convert the logits into probabilities, then use a multinomial sampling algorithm to pick the next token (see the standalone sketch just after this list).
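To make these steps concrete, here is a minimal, self-contained sketch (not part of the post's code) that walks a toy logits vector through temperature scaling, top-k filtering, and multinomial sampling. Converting logits $z$ to probabilities at temperature $T$ is the softmax $p_i = \exp(z_i/T) / \sum_j \exp(z_j/T)$, which is exactly what F.softmax computes on the scaled logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy logits over a vocabulary of 8 tokens
logits = torch.tensor([2.0, 1.5, 0.5, 0.2, -0.3, -1.0, -1.5, -2.0])

# Step 1: scale by temperature (T < 1 sharpens, T > 1 flattens the distribution)
temperature = 0.8
scaled = logits / temperature

# Step 2: top-k filtering -- mask everything outside the k most likely tokens
k = 3
kth_best = torch.topk(scaled, k)[0][-1]    # smallest logit that survives
scaled[scaled < kth_best] = float("-inf")  # -inf becomes probability 0 after softmax

# Step 3: softmax to probabilities, then multinomial sampling
probs = F.softmax(scaled, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(probs)       # only 3 non-zero entries remain
print(next_token)  # index of the sampled token
```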
You can make this much simpler by always picking the next token with torch.argmax(). This is called greedy decoding. It is usually not recommended because the output tends to look unnatural and allows no variation.
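For comparison, greedy decoding is a one-line substitution in the loop of generate() above; a minimal sketch, assuming the next_token_logits variable from that loop:

```python
# Greedy decoding: deterministically pick the single most likely token.
# This replaces the softmax/multinomial lines in generate() above.
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
```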
You can try this with the model you trained in the previous post. Below is the complete code to generate text from a simple prompt:
```python
import dataclasses

import tokenizers
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


# Model architecture same as training script

@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""
    vocab_size: int = 50000
    max_position_embeddings: int = 2048
    hidden_size: int = 768
    intermediate_size: int = 4*768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    num_key_value_heads: int = 3

class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding."""
    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        N = 10_000.0
        inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2) / dim))
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(max_position_embeddings)
        sinusoid_inp = torch.outer(position, inv_freq)
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x: Tensor) -> Tensor:
        batch_size, seq_len, num_heads, head_dim = x.shape
        device = x.device
        dtype = x.dtype
        cos = self.cos.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin.to(device, dtype)[:seq_len].view(1, seq_len, 1, -1)
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (rotated * sin)

class LlamaAttention(nn.Module):
    """Grouped-query attention with rotary embeddings."""
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_heads = config.num_key_value_heads
        assert (self.head_dim * self.num_heads) == self.hidden_size
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:
        bs, seq_len, dim = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)
        attn_output = F.scaled_dot_product_attention(
            rope(query_states).transpose(1, 2),
            rope(key_states).transpose(1, 2),
            value_states.transpose(1, 2),
            is_causal=True,
            dropout_p=0.0,
            enable_gqa=True,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)
        return self.o_proj(attn_output)

class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model."""
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(hidden_states, rope=rope)
        hidden_states = attn_outputs + residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        return self.mlp(hidden_states) + residual

class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def forward(self, input_ids: Tensor) -> Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, rope=self.rotary_emb)
        return self.norm(hidden_states)

class LlamaForPretraining(nn.Module):
    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: Tensor) -> Tensor:
        hidden_states = self.base_model(input_ids)
        return self.lm_head(hidden_states)

def apply_repetition_penalty(logits: Tensor, tokens: list[int], penalty: float) -> Tensor:
    """Apply repetition penalty to the logits."""
    for tok in tokens:
        if logits[tok] > 0:
            logits[tok] /= penalty
        else:
            logits[tok] *= penalty
    return logits

@torch.no_grad()
def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0,
             repetition_penalty=1.0, repetition_penalty_range=10, top_k=50,
             device=None) -> str:
    """Generate text autoregressively from a prompt.

    Args:
        model: The trained LlamaForPretraining model
        tokenizer: The tokenizer
        prompt: Input text prompt
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        repetition_penalty: Penalty for repeating tokens
        repetition_penalty_range: Number of previous tokens to consider for repetition penalty
        top_k: Only sample from top k most likely tokens
        device: Device the model is loaded on

    Returns:
        Generated text
    """
    # Turn model to evaluation mode: Norm layer will work differently
    model.eval()

    # Get special token IDs
    bot_id = tokenizer.token_to_id("[BOT]")
    eot_id = tokenizer.token_to_id("[EOT]")

    # Tokenize the prompt into integer tensor
    prompt_tokens = [bot_id] + tokenizer.encode(" " + prompt).ids
    input_ids = torch.tensor([prompt_tokens], dtype=torch.int64, device=device)

    # Recursively generate tokens
    generated_tokens = []
    for _step in range(max_tokens):
        # Forward pass through model
        logits = model(input_ids)
        # Get logits for the last token
        next_token_logits = logits[0, -1, :] / temperature
        # Apply repetition penalty
        if repetition_penalty != 1.0 and len(generated_tokens) > 0:
            next_token_logits = apply_repetition_penalty(
                next_token_logits,
                generated_tokens[-repetition_penalty_range:],
                repetition_penalty,
            )
        # Apply top-k filtering
        if top_k > 0:
            top_k_logits = torch.topk(next_token_logits, top_k)[0]
            indices_to_remove = next_token_logits < top_k_logits[-1]
            next_token_logits[indices_to_remove] = float("-inf")
        # Sample from the filtered distribution
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        # Early stop if EOT token is generated
        if next_token.item() == eot_id:
            break
        # Append the new token to input_ids for next iteration
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        generated_tokens.append(next_token.item())

    # Decode all generated tokens
    return tokenizer.decode(generated_tokens)

checkpoint = "llama_model_final.pth"  # saved model checkpoint
tokenizer = "bpe_50K.json"            # saved tokenizer
max_tokens = 100
temperature = 0.9
top_k = 50
penalty = 1.1
penalty_range = 10

# Load tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = tokenizers.Tokenizer.from_file(tokenizer)
config = LlamaConfig()
model = LlamaForPretraining(config).to(device)
model.load_state_dict(torch.load(checkpoint, map_location=device))

prompt = "Once upon a time, there was"
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=max_tokens,
    temperature=temperature,
    top_k=top_k,
    repetition_penalty=penalty,
    repetition_penalty_range=penalty_range,
    device=device,
)
print(prompt)
print("-" * 20)
print(response)
```
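Because sampling is stochastic, each run produces different text. A quick way to see how the sampling settings shape the output is to sweep one of them with the same prompt; a small sketch reusing generate() and the variables defined in the script above:

```python
# Try a few temperatures with the same prompt: lower values make the
# distribution sharper (closer to greedy), higher values add randomness.
for temperature in (0.5, 0.9, 1.5):
    response = generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        repetition_penalty=penalty,
        repetition_penalty_range=penalty_range,
        device=device,
    )
    print(f"temperature={temperature}:")
    print(response)
    print("-" * 20)
```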