
RLHF - Reinforcement Learning from Human Feedback

Master Reinforcement Learning from Human Feedback (RLHF) to align language models with human preferences, create safer AI systems, and improve response quality through human feedback integration.

👥

Human Alignment

Align language models with human preferences and values through iterative feedback and reinforcement learning.

🛡️

Safety & Control

Build safer AI systems by incorporating human judgment and reducing harmful or inappropriate outputs.

✨

Quality Improvement

Significantly improve response quality, relevance, and engagement through preference-based optimization.

🔄

Iterative Learning

Continuously improve model behavior through ongoing human feedback collection and model updates.

Understanding RLHF

Reinforcement Learning from Human Feedback (RLHF) is a powerful technique for aligning language models with human preferences and values. It's the key technology behind models like ChatGPT, Claude, and other successful conversational AI systems.

What is RLHF?

RLHF is a multi-stage training process that uses human feedback to teach language models to generate responses that are more helpful, harmless, and honest. Instead of only optimizing for next-token prediction, RLHF optimizes for human-preferred outcomes.

The Three-Stage RLHF Process:

Stage 1: Supervised Fine-Tuning (SFT)
- Start with a pre-trained language model
- Fine-tune on high-quality human demonstrations
- Creates an instruction-following base model
- Typically uses thousands of human-written examples

Stage 2: Reward Model Training
- Collect human preference data (comparisons between outputs)
- Train a reward model to predict human preferences
- The reward model learns to score outputs based on human values
- Uses techniques such as the Bradley-Terry model for preference learning

Stage 3: Reinforcement Learning
- Use the reward model to train the language model via RL
- Typically uses Proximal Policy Optimization (PPO)
- Balances reward maximization with staying close to the original model
- Includes a KL-divergence penalty that keeps the policy close to the reference model and avoids degenerate (mode-collapsed) outputs
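
To make the KL-penalty idea concrete, the sketch below (not tied to any particular RL library) shapes the scalar reward handed to PPO by subtracting a penalty proportional to an estimate of the KL divergence between the policy and the frozen reference model. The function name, the toy numbers, and the beta value are illustrative assumptions, not a prescribed API.
python
# Minimal sketch: shaping the PPO reward with a KL penalty (illustrative values).
import torch

def kl_shaped_reward(reward_score: torch.Tensor,
                     policy_logprobs: torch.Tensor,
                     ref_logprobs: torch.Tensor,
                     beta: float = 0.2) -> torch.Tensor:
    """Combine the reward model score with a KL penalty.

    reward_score:    scalar score from the reward model for the full response
    policy_logprobs: log-probs of the generated tokens under the current policy
    ref_logprobs:    log-probs of the same tokens under the frozen reference model
    beta:            KL penalty coefficient (assumed value; often adapted during training)
    """
    # Per-token KL estimate: log pi_theta(token) - log pi_ref(token)
    kl_per_token = policy_logprobs - ref_logprobs
    # Penalize the summed KL so the policy stays close to the reference model
    return reward_score - beta * kl_per_token.sum()

# Toy usage with made-up numbers
reward = torch.tensor(1.3)                      # reward model score
pi_lp = torch.tensor([-2.1, -0.9, -1.4])        # policy log-probs per token
ref_lp = torch.tensor([-2.3, -1.0, -1.2])       # reference log-probs per token
print(kl_shaped_reward(reward, pi_lp, ref_lp))  # ≈ tensor(1.2800)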

Why RLHF is Important:

Alignment with Human Values:
- Models learn what humans actually want, not just what they say
- Captures nuanced preferences that are hard to specify in rules
- Enables teaching complex behaviors like helpfulness and safety

Quality Improvement:
- Significantly improves response quality and relevance
- Reduces harmful, biased, or inappropriate outputs
- Creates more engaging and useful conversational AI

Safety and Control:
- Provides a framework for building safer AI systems
- Allows incorporation of human judgment and oversight
- Enables iterative improvement based on real-world feedback

Key Challenges:

Reward Hacking:
- Models may exploit reward model weaknesses
- Can lead to outputs that score high but aren't actually good
- Requires careful reward model design and regularization

Human Feedback Quality:
- Inconsistent human preferences
- Annotator bias and subjectivity
- Need for clear guidelines and quality control

Computational Cost:
- RL training is computationally expensive
- Requires multiple model copies during training
- Significantly more complex than supervised learning

Distribution Shift:
- Reward model may not generalize to new situations
- Policy may discover out-of-distribution behaviors
- Requires ongoing monitoring and updates
python
# RLHF implementation overview using TRL (Transformer Reinforcement Learning)
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSequenceClassification
)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from datasets import Dataset
import numpy as np
from typing import List, Dict

class RLHFPipeline:
    def __init__(self, model_name: str, reward_model_name: str = None):
        self.model_name = model_name
        self.reward_model_name = reward_model_name
        self.tokenizer = None
        self.sft_model = None
        self.reward_model = None
        self.ppo_model = None

    def stage1_supervised_fine_tuning(self, sft_dataset: Dataset):
        """Stage 1: Supervised Fine-Tuning on human demonstrations."""

        print("Stage 1: Supervised Fine-Tuning")
        print("=" * 40)

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.sft_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Prepare SFT dataset
        def format_sft_example(example):
            """Format example for SFT training."""
            prompt = f"Human: {example['prompt']}\n\nAssistant: "
            response = example['chosen']  # Use human-preferred response
            full_text = prompt + response + self.tokenizer.eos_token
            return {"text": full_text}

        formatted_dataset = sft_dataset.map(format_sft_example)

        # SFT training (simplified - use the full Trainer in practice)
        print(f"SFT dataset size: {len(formatted_dataset)}")
        print("SFT training would happen here...")
        print("✓ Stage 1 completed: SFT model ready")

        # Save SFT model
        self.sft_model.save_pretrained("./sft_model")
        self.tokenizer.save_pretrained("./sft_model")

    def stage2_reward_model_training(self, preference_dataset: Dataset):
        """Stage 2: Train reward model on human preferences."""

        print("\nStage 2: Reward Model Training")
        print("=" * 40)

        # Load model for reward modeling (typically smaller than the main model)
        reward_model_name = self.reward_model_name or self.model_name

        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
            reward_model_name,
            num_labels=1,  # Single scalar reward
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Prepare preference dataset
        def format_preference_example(example):
            """Format preference comparison for reward model training."""
            prompt = example['prompt']
            chosen = example['chosen']
            rejected = example['rejected']

            # Create prompt + response pairs
            chosen_text = f"Human: {prompt}\n\nAssistant: {chosen}"
            rejected_text = f"Human: {prompt}\n\nAssistant: {rejected}"

            return {
                'chosen': chosen_text,
                'rejected': rejected_text
            }

        formatted_preferences = preference_dataset.map(format_preference_example)

        # Reward model training (simplified)
        print(f"Preference dataset size: {len(formatted_preferences)}")
        print("Reward model training would happen here...")
        print("Training on preference pairs (chosen > rejected)")
        print("✓ Stage 2 completed: Reward model ready")

        # Save reward model
        self.reward_model.save_pretrained("./reward_model")

    def stage3_ppo_training(self, prompts_dataset: Dataset, ppo_config: PPOConfig = None):
        """Stage 3: PPO training using the reward model."""

        print("\nStage 3: PPO Training")
        print("=" * 40)

        # Default PPO configuration
        if ppo_config is None:
            ppo_config = PPOConfig(
                model_name="./sft_model",
                learning_rate=1.41e-5,
                batch_size=16,
                mini_batch_size=4,
                gradient_accumulation_steps=1,
                optimize_cuda_cache=True,
                early_stopping=False,
                target_kl=0.1,
                ppo_epochs=4,
                seed=0,
                init_kl_coef=0.2,
                adap_kl_ctrl=True,
            )

        # Load model with value head for PPO
        self.ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            "./sft_model",
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Load reward model
        reward_model = AutoModelForSequenceClassification.from_pretrained(
            "./reward_model",
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Create PPO trainer
        ppo_trainer = PPOTrainer(
            config=ppo_config,
            model=self.ppo_model,
            ref_model=None,  # TRL creates a frozen copy as the reference model
            tokenizer=self.tokenizer,
        )

        # Prepare prompts for PPO training
        prompts = [f"Human: {prompt}\n\nAssistant: " for prompt in prompts_dataset['prompt']]

        print(f"PPO training on {len(prompts)} prompts")

        # PPO training loop (simplified)
        for epoch in range(3):  # Limited epochs for demonstration
            print(f"\nPPO Epoch {epoch + 1}")

            for batch_idx in range(0, len(prompts), ppo_config.batch_size):
                batch_prompts = prompts[batch_idx:batch_idx + ppo_config.batch_size]

                # Tokenize prompts
                prompt_tensors = [
                    self.tokenizer.encode(prompt, return_tensors="pt")[0]
                    for prompt in batch_prompts
                ]

                # Generate responses from the current policy
                response_tensors = []
                for prompt_tensor in prompt_tensors:
                    response = self.ppo_model.generate(
                        prompt_tensor.unsqueeze(0),
                        max_new_tokens=50,
                        do_sample=True,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id
                    )
                    response_tensors.append(response[0])

                # Calculate rewards using the reward model
                rewards = []
                for prompt_tensor, response_tensor in zip(prompt_tensors, response_tensors):
                    # Combine prompt and response
                    full_text = self.tokenizer.decode(response_tensor, skip_special_tokens=True)

                    # Get reward score
                    inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True, max_length=512)
                    inputs = {k: v.to(reward_model.device) for k, v in inputs.items()}

                    with torch.no_grad():
                        reward_score = reward_model(**inputs).logits[0, 0].item()

                    rewards.append(reward_score)

                # Convert to tensors
                rewards = [torch.tensor(r) for r in rewards]

                # PPO training step
                stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)

                if (batch_idx // ppo_config.batch_size) % 4 == 0:  # Log every few batches
                    print(f"  Batch {batch_idx // ppo_config.batch_size + 1}: "
                          f"Mean reward: {np.mean([r.item() for r in rewards]):.3f}")

        print("✓ Stage 3 completed: RLHF training finished")

        # Save final model
        self.ppo_model.save_pretrained("./rlhf_model")

        return ppo_trainer

    def evaluate_rlhf_model(self, test_prompts: List[str]):
        """Evaluate the RLHF-trained model."""

        print("\nEvaluating RLHF Model")
        print("=" * 40)

        if self.ppo_model is None:
            # Load the trained model
            self.ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("./rlhf_model")

        self.ppo_model.eval()

        for i, prompt in enumerate(test_prompts, 1):
            formatted_prompt = f"Human: {prompt}\n\nAssistant: "

            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            inputs = {k: v.to(self.ppo_model.device) for k, v in inputs.items()}

            # Generate response
            with torch.no_grad():
                outputs = self.ppo_model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            response = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )

            print(f"\nTest {i}:")
            print(f"Human: {prompt}")
            print(f"Assistant: {response}")
            print("-" * 40)

# Sample data preparation functions
def create_sft_dataset():
    """Create sample SFT dataset with human demonstrations."""

    sft_examples = [
        {
            "prompt": "Explain quantum computing in simple terms.",
            "chosen": "Quantum computing uses quantum mechanical phenomena like superposition and entanglement to process information in ways that classical computers cannot. Unlike classical bits that are either 0 or 1, quantum bits (qubits) can exist in multiple states simultaneously, potentially allowing quantum computers to solve certain problems exponentially faster than classical computers."
        },
        {
            "prompt": "How can I improve my sleep quality?",
            "chosen": "Here are some evidence-based strategies to improve sleep quality: 1) Maintain a consistent sleep schedule, 2) Create a relaxing bedtime routine, 3) Ensure your bedroom is cool, dark, and quiet, 4) Avoid caffeine and screens before bedtime, 5) Get regular exercise during the day, and 6) Consider relaxation techniques like meditation or deep breathing."
        },
        {
            "prompt": "Write a Python function to reverse a string.",
            "chosen": "Here's a simple Python function to reverse a string:\n\ndef reverse_string(s):\n    return s[::-1]\n\n# Example usage:\noriginal = 'hello'\nreversed_str = reverse_string(original)\nprint(reversed_str)  # Output: 'olleh'\n\nThis uses Python's slice notation with a step of -1 to reverse the string efficiently."
        }
    ]

    return Dataset.from_list(sft_examples)

def create_preference_dataset():
    """Create sample preference dataset for reward model training."""

    preference_examples = [
        {
            "prompt": "What's the capital of France?",
            "chosen": "The capital of France is Paris. It's a beautiful city known for its art, culture, cuisine, and iconic landmarks like the Eiffel Tower and Louvre Museum.",
            "rejected": "Paris is the capital. It's in France and has some buildings and stuff."
        },
        {
            "prompt": "How do I bake a chocolate cake?",
            "chosen": "To bake a chocolate cake, you'll need: flour, sugar, cocoa powder, eggs, butter, baking powder, and milk. Mix dry ingredients, cream butter and sugar, add eggs, then alternate adding dry ingredients and milk. Bake at 350°F for 25-30 minutes. Let me know if you'd like a detailed recipe!",
            "rejected": "Just mix some chocolate stuff together and put it in the oven until it looks done. Should work fine."
        },
        {
            "prompt": "Is it safe to eat raw eggs?",
            "chosen": "Eating raw eggs carries some risk of Salmonella infection, though the risk is relatively low (about 1 in 20,000 eggs). Pasteurized eggs are safer for raw consumption. If you're pregnant, elderly, or immunocompromised, it's best to avoid raw eggs. For recipes requiring raw eggs, consider pasteurized alternatives.",
            "rejected": "Raw eggs are totally fine to eat, there's no risk at all. Eat as many as you want!"
        }
    ]

    return Dataset.from_list(preference_examples)

def create_prompts_dataset():
    """Create sample prompts for PPO training."""

    prompts = [
        "Explain the importance of exercise.",
        "What's the best way to learn a new language?",
        "How does photosynthesis work?",
        "Give me tips for public speaking.",
        "What are the benefits of meditation?",
        "How do I start investing in stocks?",
        "Explain machine learning to a beginner.",
        "What's the difference between weather and climate?",
    ]

    return Dataset.from_dict({"prompt": prompts})

# Main execution example
if __name__ == "__main__":
    # Initialize RLHF pipeline
    rlhf = RLHFPipeline(
        model_name="microsoft/DialoGPT-medium",  # Use a smaller model for the demo
        reward_model_name="microsoft/DialoGPT-small"
    )

    # Create sample datasets
    sft_data = create_sft_dataset()
    preference_data = create_preference_dataset()
    prompts_data = create_prompts_dataset()

    # Run RLHF pipeline
    print("Starting RLHF Pipeline")
    print("=" * 50)

    # Stage 1: SFT
    rlhf.stage1_supervised_fine_tuning(sft_data)

    # Stage 2: Reward Model
    rlhf.stage2_reward_model_training(preference_data)

    # Stage 3: PPO Training
    ppo_config = PPOConfig(
        model_name="./sft_model",
        learning_rate=1.41e-5,
        batch_size=4,  # Small batch for the demo
        mini_batch_size=2,
        ppo_epochs=2,  # Fewer epochs for the demo
        target_kl=0.1,
    )

    rlhf.stage3_ppo_training(prompts_data, ppo_config)

    # Evaluate final model
    test_prompts = [
        "What's the best way to stay healthy?",
        "Explain artificial intelligence.",
        "How do I write a good email?"
    ]

    rlhf.evaluate_rlhf_model(test_prompts)

    print("\nRLHF Pipeline completed successfully!")

Reward Model Design

The reward model is the heart of RLHF, translating human preferences into scalar rewards that guide the language model's training. Designing effective reward models is crucial for successful RLHF.

Reward Model Architecture:

Base Model Selection:
- Often uses the same architecture as the policy model but smaller
- Can be initialized from the SFT model or trained from scratch
- Typical sizes: 350M-6B parameters for reward models

Output Layer Design:
- Single scalar output representing reward/preference score
- Linear layer with no activation (raw logits)
- Some approaches use multiple output heads for different criteria

Training Objective:

Bradley-Terry Model:
The most common approach uses the Bradley-Terry model for preference learning:

$$P(y_1 \succ y_2 | x) = \frac{\exp(r(x, y_1))}{\exp(r(x, y_1)) + \exp(r(x, y_2))}$$

Where:
- $x$ is the input prompt
- $y_1$ and $y_2$ are two responses
- $r(x, y)$ is the reward model's score
- $y_1 \succ y_2$ means $y_1$ is preferred over $y_2$

Loss Function:
$$L = -\log P(y_{chosen} \succ y_{rejected} | x)$$
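
In code, this loss reduces to a log-sigmoid of the reward difference. The snippet below is a minimal PyTorch sketch of that objective, mirroring the compute_pairwise_loss used in the reward model implementation later on this page; the toy reward values are illustrative.
python
# Bradley-Terry pairwise loss: -log sigmoid(r_chosen - r_rejected)
import torch
import torch.nn.functional as F

def bradley_terry_loss(chosen_rewards: torch.Tensor,
                       rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Average negative log-likelihood that the chosen response beats the rejected one."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

# Toy example: the larger the reward margin, the smaller the loss
chosen = torch.tensor([2.0, 0.5])
rejected = torch.tensor([1.0, 0.4])
print(bradley_terry_loss(chosen, rejected))  # ≈ tensor(0.4788)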

Data Collection Strategies:

Human Preference Collection:
- Present pairs of model outputs to human annotators
- Ask "Which response is better?" with clear criteria
- Collect rankings rather than just binary preferences
- Include confidence scores from annotators

Quality Control:
- Multiple annotators per comparison
- Inter-annotator agreement metrics (see the sketch after this list)
- Clear annotation guidelines
- Regular calibration sessions

Active Learning:
- Select most informative comparisons
- Focus on areas where the model is uncertain
- Iteratively improve the reward model
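
A simple starting point for the agreement metric mentioned above is the raw pairwise agreement rate between annotators who labeled the same comparison. The sketch below assumes a hypothetical vote format (1 if the annotator preferred response A, 0 otherwise); real pipelines often also report chance-corrected measures such as Cohen's kappa.
python
# Minimal sketch: raw inter-annotator agreement on preference comparisons.
# Assumed vote format: one entry per annotator, 1 = prefers response A, 0 = prefers B.
from typing import List

def agreement_rate(votes_per_item: List[List[int]]) -> float:
    """Fraction of annotator pairs that agree, pooled over all items."""
    total_pairs = 0
    agreeing_pairs = 0
    for votes in votes_per_item:
        for i in range(len(votes)):
            for j in range(i + 1, len(votes)):
                total_pairs += 1
                agreeing_pairs += int(votes[i] == votes[j])
    return agreeing_pairs / total_pairs if total_pairs else 0.0

# Three comparisons, three annotators each
votes = [
    [1, 1, 1],  # unanimous
    [1, 0, 1],  # one dissenting annotator
    [0, 0, 1],  # one dissenting annotator
]
print(f"Pairwise agreement: {agreement_rate(votes):.2f}")  # 0.56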

Reward Model Training Best Practices:

Data Preprocessing:
- Balance positive and negative examples
- Handle ties and close comparisons appropriately
- Filter out low-confidence annotations
- Ensure diverse prompt coverage

Training Techniques:
- Use appropriate learning rates (typically lower than SFT)
- Implement dropout for regularization
- Monitor validation accuracy carefully
- Use early stopping to prevent overfitting

Evaluation Metrics:
- Accuracy on held-out preference data
- Correlation with human judgments
- Consistency across different prompt types
- Robustness to adversarial examples

Common Pitfalls and Solutions:

Reward Hacking:
- Problem: Model exploits reward model weaknesses
- Solution: Regularization, ensemble methods, adversarial training

Length Bias:
- Problem: Reward model prefers longer responses
- Solution: Length-normalized scoring and explicit length control (see the sketch after this list)

Distribution Shift:
- Problem: Reward model fails on new data distributions
- Solution: Continual learning, diverse training data, domain adaptation

Annotation Bias:
- Problem: Human preferences may be biased or inconsistent
- Solution: Diverse annotator pool, bias detection, demographic analysis
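
As a concrete illustration of the length-bias mitigation mentioned above, the sketch below subtracts a small per-token penalty from the raw reward before it is used for ranking or training; the penalty weight is an assumed, tunable hyperparameter rather than a standard value.
python
# Minimal sketch: length-penalized reward scoring to mitigate length bias.
def length_adjusted_reward(raw_reward: float,
                           num_response_tokens: int,
                           penalty_per_token: float = 0.002) -> float:
    """Subtract a small per-token penalty so responses are not rewarded
    merely for being longer (penalty weight is an assumed hyperparameter)."""
    return raw_reward - penalty_per_token * num_response_tokens

# Toy example: the verbose response needs a genuinely higher raw score to win
concise = length_adjusted_reward(raw_reward=1.10, num_response_tokens=60)
verbose = length_adjusted_reward(raw_reward=1.15, num_response_tokens=220)
print(f"concise: {concise:.3f}, verbose: {verbose:.3f}")  # concise: 0.980, verbose: 0.710
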
python
# Comprehensive reward model implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    AutoConfig, Trainer, TrainingArguments
)
from datasets import Dataset
import numpy as np
from typing import Dict, List, Tuple, Optional

class RewardModelTrainer:
    def __init__(self, base_model_name: str, max_length: int = 512):
        self.base_model_name = base_model_name
        self.max_length = max_length
        self.tokenizer = None
        self.model = None

    def setup_reward_model(self, dropout_rate: float = 0.1):
        """Setup reward model architecture."""

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load config and modify for reward modeling
        config = AutoConfig.from_pretrained(self.base_model_name)
        config.num_labels = 1  # Single scalar reward
        config.hidden_dropout_prob = dropout_rate
        config.attention_probs_dropout_prob = dropout_rate
        config.pad_token_id = self.tokenizer.pad_token_id  # Needed for padded batches with decoder-only models

        # Load model for sequence classification
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.base_model_name,
            config=config,
            torch_dtype=torch.bfloat16
        )

        # Build a small reward head that outputs a single score.
        # Encoder-style models expose the head as `.classifier`; GPT-2-style
        # models (e.g. DialoGPT) expose it as `.score`.
        reward_head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(config.hidden_size, 1)  # Single reward score
        )
        # Match the new head's dtype to the base model to avoid bf16/fp32 mismatches
        reward_head = reward_head.to(dtype=self.model.dtype)
        if hasattr(self.model, "classifier"):
            self.model.classifier = reward_head
        else:
            self.model.score = reward_head

        print("Reward model setup completed:")
        print(f"- Base model: {self.base_model_name}")
        print(f"- Parameters: {self.model.num_parameters():,}")
        print(f"- Dropout rate: {dropout_rate}")

    def prepare_preference_dataset(self, preference_data: List[Dict]) -> Dataset:
        """Prepare preference comparison dataset."""

        def tokenize_pair(example):
            """Tokenize chosen and rejected responses."""

            prompt = example['prompt']
            chosen = example['chosen']
            rejected = example['rejected']

            # Create full texts
            chosen_text = f"{prompt}\n\n{chosen}"
            rejected_text = f"{prompt}\n\n{rejected}"

            # Tokenize both
            chosen_tokens = self.tokenizer(
                chosen_text,
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors="pt"
            )

            rejected_tokens = self.tokenizer(
                rejected_text,
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors="pt"
            )

            return {
                'chosen_input_ids': chosen_tokens['input_ids'].squeeze(),
                'chosen_attention_mask': chosen_tokens['attention_mask'].squeeze(),
                'rejected_input_ids': rejected_tokens['input_ids'].squeeze(),
                'rejected_attention_mask': rejected_tokens['attention_mask'].squeeze(),
            }

        # Convert to dataset and tokenize
        dataset = Dataset.from_list(preference_data)
        tokenized_dataset = dataset.map(tokenize_pair, remove_columns=dataset.column_names)

        print(f"Prepared preference dataset with {len(tokenized_dataset)} pairs")
        return tokenized_dataset

    def create_pairwise_dataset(self, tokenized_dataset: Dataset) -> Dataset:
        """Create pairwise dataset for Bradley-Terry training."""

        pairwise_examples = []

        for example in tokenized_dataset:
            # Chosen example (label = 1)
            pairwise_examples.append({
                'input_ids': example['chosen_input_ids'],
                'attention_mask': example['chosen_attention_mask'],
                'labels': 1.0  # Chosen is better
            })

            # Rejected example (label = 0)
            pairwise_examples.append({
                'input_ids': example['rejected_input_ids'],
                'attention_mask': example['rejected_attention_mask'],
                'labels': 0.0  # Rejected is worse
            })

        return Dataset.from_list(pairwise_examples)

    def compute_pairwise_loss(self, chosen_rewards, rejected_rewards):
        """Compute Bradley-Terry pairwise ranking loss."""

        # Bradley-Terry loss: -log(sigmoid(chosen - rejected))
        diff = chosen_rewards - rejected_rewards
        loss = -F.logsigmoid(diff).mean()

        return loss

    def train_reward_model(self,
                           train_dataset: Dataset,
                           eval_dataset: Dataset = None,
                           output_dir: str = "./reward_model",
                           num_epochs: int = 3,
                           batch_size: int = 8,
                           learning_rate: float = 2e-5,
                           warmup_ratio: float = 0.1):
        """Train the reward model on preference data."""

        # Custom trainer for pairwise ranking
        class RewardTrainer(Trainer):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.prediction_step_count = 0

            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                """Compute pairwise ranking loss.

                NOTE: this assumes each batch holds chosen examples in its first
                half and the matching rejected examples in its second half, which
                requires a paired (non-shuffling) sampler/collator in practice.
                """

                # Split batch into chosen and rejected
                batch_size = inputs['input_ids'].size(0) // 2

                chosen_inputs = {
                    'input_ids': inputs['input_ids'][:batch_size],
                    'attention_mask': inputs['attention_mask'][:batch_size]
                }

                rejected_inputs = {
                    'input_ids': inputs['input_ids'][batch_size:],
                    'attention_mask': inputs['attention_mask'][batch_size:]
                }

                # Get reward scores
                chosen_outputs = model(**chosen_inputs)
                rejected_outputs = model(**rejected_inputs)

                chosen_rewards = chosen_outputs.logits.squeeze(-1)
                rejected_rewards = rejected_outputs.logits.squeeze(-1)

                # Compute pairwise loss
                loss = self.compute_pairwise_loss(chosen_rewards, rejected_rewards)

                return (loss, {'chosen_rewards': chosen_rewards, 'rejected_rewards': rejected_rewards}) if return_outputs else loss

            def compute_pairwise_loss(self, chosen_rewards, rejected_rewards):
                """Bradley-Terry loss implementation."""
                diff = chosen_rewards - rejected_rewards
                return -F.logsigmoid(diff).mean()

            def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
                """Custom prediction step for evaluation."""

                inputs = self._prepare_inputs(inputs)

                with torch.no_grad():
                    outputs = model(**inputs)
                    rewards = outputs.logits.squeeze(-1)

                # For evaluation, we compute accuracy of preference predictions
                batch_size = rewards.size(0) // 2
                chosen_rewards = rewards[:batch_size]
                rejected_rewards = rewards[batch_size:]

                # Accuracy: how often chosen > rejected
                correct = (chosen_rewards > rejected_rewards).float()
                accuracy = correct.mean()

                loss = self.compute_pairwise_loss(chosen_rewards, rejected_rewards)

                return (loss, accuracy, accuracy)  # Return accuracy as both predictions and labels

        # Prepare dataset for pairwise training
        pairwise_train = self.create_pairwise_dataset(train_dataset)
        pairwise_eval = self.create_pairwise_dataset(eval_dataset) if eval_dataset else None

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_ratio=warmup_ratio,
            logging_steps=50,
            save_steps=500,
            eval_steps=500 if pairwise_eval else None,
            evaluation_strategy="steps" if pairwise_eval else "no",
            save_strategy="steps",
            load_best_model_at_end=True if pairwise_eval else False,
            metric_for_best_model="eval_loss" if pairwise_eval else None,
            greater_is_better=False,
            report_to="none",
            bf16=True,
            dataloader_pin_memory=False,
            remove_unused_columns=False,
        )

        # Create trainer
        trainer = RewardTrainer(
            model=self.model,
            args=training_args,
            train_dataset=pairwise_train,
            eval_dataset=pairwise_eval,
            tokenizer=self.tokenizer,
        )

        # Train
        print("Starting reward model training...")
        print(f"Training examples: {len(pairwise_train)}")
        if pairwise_eval:
            print(f"Evaluation examples: {len(pairwise_eval)}")

        trainer.train()

        # Save model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        print(f"Reward model training completed! Saved to {output_dir}")

        return trainer

    def evaluate_reward_model(self, test_data: List[Dict]) -> Dict:
        """Evaluate reward model on test data."""

        self.model.eval()
        results = {
            'accuracy': 0.0,
            'mean_chosen_reward': 0.0,
            'mean_rejected_reward': 0.0,
            'reward_difference': 0.0
        }

        correct_predictions = 0
        chosen_rewards = []
        rejected_rewards = []

        for example in test_data:
            prompt = example['prompt']
            chosen = example['chosen']
            rejected = example['rejected']

            # Score chosen response
            chosen_text = f"{prompt}\n\n{chosen}"
            chosen_inputs = self.tokenizer(
                chosen_text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length
            )
            chosen_inputs = {k: v.to(self.model.device) for k, v in chosen_inputs.items()}

            with torch.no_grad():
                chosen_score = self.model(**chosen_inputs).logits[0, 0].item()

            # Score rejected response
            rejected_text = f"{prompt}\n\n{rejected}"
            rejected_inputs = self.tokenizer(
                rejected_text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length
            )
            rejected_inputs = {k: v.to(self.model.device) for k, v in rejected_inputs.items()}

            with torch.no_grad():
                rejected_score = self.model(**rejected_inputs).logits[0, 0].item()

            # Check if model correctly prefers chosen over rejected
            if chosen_score > rejected_score:
                correct_predictions += 1

            chosen_rewards.append(chosen_score)
            rejected_rewards.append(rejected_score)

        # Calculate metrics
        results['accuracy'] = correct_predictions / len(test_data)
        results['mean_chosen_reward'] = np.mean(chosen_rewards)
        results['mean_rejected_reward'] = np.mean(rejected_rewards)
        results['reward_difference'] = results['mean_chosen_reward'] - results['mean_rejected_reward']

        return results

    def get_reward_score(self, prompt: str, response: str) -> float:
        """Get reward score for a prompt-response pair."""

        self.model.eval()

        text = f"{prompt}\n\n{response}"
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            score = self.model(**inputs).logits[0, 0].item()

        return score

# Example usage and testing
def create_sample_preference_data():
    """Create sample preference data for training."""

    preference_data = [
        {
            "prompt": "Explain the concept of gravity.",
            "chosen": "Gravity is a fundamental force of nature that causes objects with mass to attract each other. According to Einstein's theory of general relativity, gravity is not actually a force, but rather the curvature of spacetime caused by mass and energy. This curvature guides the motion of objects, making them appear to be attracted to each other.",
            "rejected": "Gravity is when things fall down because they're heavy."
        },
        {
            "prompt": "How do I cook pasta?",
            "chosen": "To cook pasta: 1) Bring a large pot of salted water to boil, 2) Add pasta and stir occasionally, 3) Cook according to package directions (usually 8-12 minutes) until al dente, 4) Drain and serve immediately. The key is using plenty of water and not overcooking.",
            "rejected": "Put pasta in water and heat it until it's soft. Should be fine."
        },
        {
            "prompt": "What causes climate change?",
            "chosen": "Climate change is primarily caused by increased concentrations of greenhouse gases in the atmosphere, mainly from human activities like burning fossil fuels, deforestation, and industrial processes. These gases trap heat from the sun, leading to global warming and associated climate impacts like sea level rise, extreme weather, and ecosystem disruption.",
            "rejected": "The sun gets hotter sometimes and that changes the climate. It's natural."
        }
    ]

    return preference_data

if __name__ == "__main__":
    # Initialize reward model trainer
    trainer = RewardModelTrainer("microsoft/DialoGPT-small", max_length=256)

    # Setup model
    trainer.setup_reward_model(dropout_rate=0.1)

    # Create sample data
    preference_data = create_sample_preference_data()

    # Prepare dataset
    dataset = trainer.prepare_preference_dataset(preference_data)

    # Split for training and evaluation
    train_size = int(0.8 * len(dataset))
    train_dataset = dataset.select(range(train_size))
    eval_dataset = dataset.select(range(train_size, len(dataset)))

    # Train reward model
    reward_trainer = trainer.train_reward_model(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./sample_reward_model",
        num_epochs=2,
        batch_size=2,  # Small batch for the demo
        learning_rate=5e-5
    )

    # Evaluate model
    results = trainer.evaluate_reward_model(preference_data)
    print("\nReward Model Evaluation Results:")
    for metric, value in results.items():
        print(f"{metric}: {value:.4f}")

    # Test individual scoring
    print("\nTesting individual reward scoring:")
    test_prompt = "What's the best way to learn programming?"
    good_response = "Start with a beginner-friendly language like Python, practice regularly with small projects, and don't be afraid to make mistakes - they're part of learning!"
    bad_response = "Just read some books about it."

    good_score = trainer.get_reward_score(test_prompt, good_response)
    bad_score = trainer.get_reward_score(test_prompt, bad_response)

    print(f"Good response score: {good_score:.4f}")
    print(f"Bad response score: {bad_score:.4f}")
    print(f"Difference: {good_score - bad_score:.4f}")

    print("\nReward model training completed!")

PPO Implementation

Proximal Policy Optimization (PPO) is the most commonly used reinforcement learning algorithm for RLHF. It provides a stable and efficient way to update the language model policy using rewards from the reward model.

PPO Algorithm Overview:

Core Idea:
PPO constrains policy updates to prevent large changes that could destabilize training. It uses a clipped objective function that limits how much the policy can change in a single update.

Key Components:

1. Policy Network:
- The language model being trained
- Generates responses given prompts
- Updated based on reward signals

2. Value Network:
- Estimates expected future rewards
- Often shares parameters with policy network
- Used for variance reduction in policy gradients

3. Reference Model:
- Copy of the original policy (frozen)
- Used to compute KL divergence penalty
- Prevents the policy from deviating too much

4. Reward Model:
- Trained in previous stage
- Provides scalar rewards for generated responses
- Guides the policy optimization

PPO Objective Function:

The PPO objective combines several terms:

$$L^{PPO}(\theta) = \mathbb{E}[\min(r_t(\theta)A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t)]$$

Where:
- $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ is the probability ratio
- $A_t$ is the advantage estimate
- $\epsilon$ is the clipping parameter (typically 0.2)
- The clip function prevents large policy updates

Additional Terms:
- KL Penalty: $\beta \cdot D_{KL}[\pi_\theta || \pi_{ref}]$ to stay close to reference
- Value Loss: $(V_\theta(s_t) - R_t)^2$ for value function training
- Entropy Bonus: $\alpha \cdot H[\pi_\theta]$ to encourage exploration
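
To see what the clipping does in practice, the sketch below evaluates the clipped surrogate for a few probability ratios with a positive advantage, using ε = 0.2 as above; once the ratio leaves the [1-ε, 1+ε] band, the objective stops rewarding further movement in that direction. The numbers are illustrative.
python
# Minimal sketch: PPO clipped surrogate for sample ratios and advantages.
import torch

def clipped_surrogate(ratio: torch.Tensor,
                      advantage: torch.Tensor,
                      eps: float = 0.2) -> torch.Tensor:
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage
    return torch.min(unclipped, clipped)

ratios = torch.tensor([0.7, 1.0, 1.5])        # pi_new / pi_old
advantages = torch.tensor([1.0, 1.0, 1.0])    # positive advantage: the action was good
print(clipped_surrogate(ratios, advantages))  # tensor([0.7000, 1.0000, 1.2000])
# With a positive advantage, ratios above 1 + eps contribute at most (1 + eps) * A,
# so the policy gains nothing from pushing beyond the clip range in one update.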

PPO Training Process:

1. Data Collection:
- Generate responses using current policy
- Compute rewards using reward model
- Calculate advantages using value estimates

2. Policy Update:
- Multiple epochs of gradient updates
- Reuse the collected batch across several optimization epochs
- Clip gradients to prevent instability

3. KL Monitoring:
- Monitor KL divergence from reference model
- Adjust KL penalty coefficient if needed
- Early stopping if KL gets too large

Implementation Challenges:

Memory Management:
- Need multiple model copies (policy, value, reference)
- Large sequence lengths for language modeling
- Gradient accumulation across sequences

Reward Engineering:
- Reward hacking and exploitation
- Balancing different reward components
- Handling sparse or delayed rewards

Training Stability:
- Careful hyperparameter tuning
- Gradient clipping and normalization
- Learning rate scheduling

Hyperparameter Guidelines:

Learning Rate:
- Typically 1e-5 to 5e-5 for language models
- Lower than supervised learning
- May need learning rate decay

Clipping Parameter (ε):
- Standard value: 0.2
- Lower values (0.1) for more conservative updates
- Higher values (0.3) for faster learning

KL Penalty (β):
- Start with 0.1-0.2
- Adaptive adjustment based on KL divergence
- Higher values for more conservative training

Batch Size:
- Larger batches (64-256) for stability
- Limited by memory constraints
- May use gradient accumulation
python
# Comprehensive PPO implementation for RLHF
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
)
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass

@dataclass
class PPOConfig:
    """Configuration for PPO training."""
    model_name: str = "gpt2"
    learning_rate: float = 1.41e-5
    batch_size: int = 64
    mini_batch_size: int = 16
    gradient_accumulation_steps: int = 1
    ppo_epochs: int = 4
    max_grad_norm: float = 1.0
    clip_range: float = 0.2
    clip_range_vf: Optional[float] = None
    vf_coef: float = 0.1
    target_kl: float = 0.1
    init_kl_coef: float = 0.2
    adap_kl_ctrl: bool = True
    gamma: float = 1.0
    lam: float = 0.95
    use_score_scaling: bool = False
    use_score_norm: bool = False
    score_clip: Optional[float] = None

class PPOTrainer:
    def __init__(self,
                 config: PPOConfig,
                 model: nn.Module,
                 ref_model: nn.Module,
                 reward_model: nn.Module,
                 tokenizer):

        self.config = config
        self.model = model            # Policy model
        self.ref_model = ref_model    # Reference model (frozen)
        self.reward_model = reward_model
        self.tokenizer = tokenizer

        # Freeze reference model
        for param in self.ref_model.parameters():
            param.requires_grad = False

        # Setup optimizer
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            eps=1e-8,
            weight_decay=0.01
        )

        # KL controller for adaptive penalty
        self.kl_ctl = AdaptiveKLController(config.init_kl_coef, config.target_kl)

        # Training statistics (keys match the stats returned by train_step)
        self.stats = {
            'policy_loss': [],
            'value_loss': [],
            'kl_penalty': [],
            'kl_coef': [],
            'mean_reward': [],
            'mean_advantage': [],
        }

    def generate_responses(self,
                           prompts: List[str],
                           max_new_tokens: int = 50,
                           temperature: float = 0.7,
                           top_p: float = 0.9) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]:
        """Generate responses from the current policy."""

        self.model.eval()

        all_responses = []
        all_response_tensors = []
        all_log_probs = []

        for prompt in prompts:
            # Tokenize prompt
            prompt_tokens = self.tokenizer.encode(prompt, return_tensors="pt")
            prompt_tokens = prompt_tokens.to(self.model.device)

            # Generate response
            with torch.no_grad():
                response_tokens = self.model.generate(
                    prompt_tokens,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    pad_token_id=self.tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True
                )

            # Extract generated tokens (without the prompt)
            generated_tokens = response_tokens.sequences[0][prompt_tokens.shape[1]:]

            # Calculate log probabilities of the sampled tokens
            log_probs = []
            for i, token_id in enumerate(generated_tokens):
                if i < len(response_tokens.scores):
                    scores = response_tokens.scores[i][0]  # [vocab_size]
                    log_prob = F.log_softmax(scores, dim=-1)[token_id].item()
                    log_probs.append(log_prob)

            # Decode response
            response_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

            all_responses.append(response_text)
            all_response_tensors.append(generated_tokens)
            all_log_probs.append(torch.tensor(log_probs))

        return all_responses, all_response_tensors, all_log_probs

    def compute_rewards(self, prompts: List[str], responses: List[str]) -> List[float]:
        """Compute rewards using the reward model."""

        self.reward_model.eval()
        rewards = []

        for prompt, response in zip(prompts, responses):
            # Create full text
            full_text = f"{prompt}\n\n{response}"

            # Tokenize and get reward
            inputs = self.tokenizer(
                full_text,
                return_tensors="pt",
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.reward_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                reward = self.reward_model(**inputs).logits[0, 0].item()

            rewards.append(reward)

        return rewards

    def compute_advantages(self,
                           rewards: torch.Tensor,
                           values: torch.Tensor,
                           masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute GAE (Generalized Advantage Estimation) advantages."""

        # Add terminal value of 0
        values = torch.cat([values, torch.zeros(1).to(values.device)])

        advantages = torch.zeros_like(rewards)
        last_gae_lam = 0

        # Compute advantages using GAE
        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_non_terminal = 0
                next_values = 0
            else:
                next_non_terminal = masks[t + 1]
                next_values = values[t + 1]

            delta = rewards[t] + self.config.gamma * next_values * next_non_terminal - values[t]
            advantages[t] = last_gae_lam = delta + self.config.gamma * self.config.lam * next_non_terminal * last_gae_lam

        # Compute returns
        returns = advantages + values[:-1]

        return advantages, returns

    def compute_policy_loss(self,
                            log_probs: torch.Tensor,
                            old_log_probs: torch.Tensor,
                            advantages: torch.Tensor,
                            masks: torch.Tensor) -> torch.Tensor:
        """Compute clipped PPO policy loss."""

        # Compute probability ratio
        log_ratio = log_probs - old_log_probs
        ratio = torch.exp(log_ratio)

        # Compute clipped surrogate loss
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.config.clip_range, 1 + self.config.clip_range) * advantages

        policy_loss = -torch.min(surr1, surr2)

        # Apply mask and average
        policy_loss = (policy_loss * masks).sum() / masks.sum()

        return policy_loss

    def compute_value_loss(self,
                           values: torch.Tensor,
                           old_values: torch.Tensor,
                           returns: torch.Tensor,
                           masks: torch.Tensor) -> torch.Tensor:
        """Compute value function loss."""

        if self.config.clip_range_vf is not None:
            # Clipped value loss
            values_clipped = old_values + torch.clamp(
                values - old_values,
                -self.config.clip_range_vf,
                self.config.clip_range_vf
            )

            vf_loss1 = (values - returns) ** 2
            vf_loss2 = (values_clipped - returns) ** 2
            vf_loss = torch.max(vf_loss1, vf_loss2)
        else:
            # Standard MSE loss
            vf_loss = (values - returns) ** 2

        # Apply mask and average
        vf_loss = (vf_loss * masks).sum() / masks.sum()

        return vf_loss

    def compute_kl_penalty(self,
                           log_probs: torch.Tensor,
                           ref_log_probs: torch.Tensor,
                           masks: torch.Tensor) -> torch.Tensor:
        """Compute KL divergence penalty."""

        kl_div = ref_log_probs - log_probs
        kl_penalty = (kl_div * masks).sum() / masks.sum()

        return kl_penalty

    def get_policy_log_probs(self,
                             prompts: List[str],
                             response_tensors: List[torch.Tensor]) -> torch.Tensor:
        """Recompute differentiable mean log-probs of the sampled responses under the current policy."""

        seq_log_probs = []
        for prompt, response_tokens in zip(prompts, response_tensors):
            prompt_tokens = self.tokenizer.encode(prompt, return_tensors="pt")[0].to(self.model.device)
            response_tokens = response_tokens.to(self.model.device)
            input_ids = torch.cat([prompt_tokens, response_tokens]).unsqueeze(0)

            logits = self.model(input_ids).logits[0]
            # Logits at position i predict the token at position i + 1
            response_logits = logits[prompt_tokens.shape[0] - 1:-1]
            log_probs = F.log_softmax(response_logits, dim=-1)
            token_log_probs = log_probs.gather(1, response_tokens.unsqueeze(1)).squeeze(1)
            seq_log_probs.append(token_log_probs.mean())

        return torch.stack(seq_log_probs)

    def train_step(self, batch_data: Dict) -> Dict:
        """Perform one PPO training step."""

        self.model.train()

        # Extract batch data
        prompts = batch_data['prompts']
        responses = batch_data['responses']
        response_tensors = batch_data['response_tensors']
        old_log_probs = batch_data['log_probs']
        rewards = batch_data['rewards']

        # Convert to tensors
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.model.device)

        # Compute values (simplified - in practice, use a separate value head)
        values = torch.zeros_like(rewards)  # Placeholder

        # Compute advantages
        masks = torch.ones_like(rewards)  # Simplified masking
        advantages, returns = self.compute_advantages(rewards, values, masks)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Old (sampling-time) log-probs, one mean value per sequence
        old_lp = torch.stack([lp.mean() for lp in old_log_probs]).to(self.model.device)

        # PPO training loop
        total_policy_loss = 0
        total_value_loss = 0
        total_kl_penalty = 0

        for ppo_epoch in range(self.config.ppo_epochs):
            # Differentiable log-probs of the sampled responses under the current policy
            curr_log_probs = self.get_policy_log_probs(prompts, response_tensors)

            # Compute losses
            policy_loss = self.compute_policy_loss(curr_log_probs, old_lp, advantages, masks)
            value_loss = self.compute_value_loss(values, values, returns, masks)  # Simplified

            # Get reference model log probs
            ref_log_probs = self.get_ref_log_probs(prompts, responses)
            kl_penalty = self.compute_kl_penalty(curr_log_probs, ref_log_probs, masks)

            # Total loss
            total_loss = (
                policy_loss +
                self.config.vf_coef * value_loss +
                self.kl_ctl.value * kl_penalty
            )

            # Backward pass
            self.optimizer.zero_grad()
            total_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

            self.optimizer.step()

            # Accumulate losses
            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
            total_kl_penalty += kl_penalty.item()

        # Update KL controller
        mean_kl = total_kl_penalty / self.config.ppo_epochs
        self.kl_ctl.update(mean_kl, batch_data['batch_size'])

        # Record statistics
        stats = {
            'policy_loss': total_policy_loss / self.config.ppo_epochs,
            'value_loss': total_value_loss / self.config.ppo_epochs,
            'kl_penalty': mean_kl,
            'kl_coef': self.kl_ctl.value,
            'mean_reward': rewards.mean().item(),
            'mean_advantage': advantages.mean().item(),
        }

        return stats

    def get_ref_log_probs(self, prompts: List[str], responses: List[str]) -> torch.Tensor:
        """Get log probabilities from the reference model."""

        self.ref_model.eval()
        ref_log_probs = []

        with torch.no_grad():
            for prompt, response in zip(prompts, responses):
                # Simplified calculation
                ref_log_prob = torch.tensor(0.0)  # Placeholder
                ref_log_probs.append(ref_log_prob)

        return torch.stack(ref_log_probs).to(self.model.device)

    def train(self, prompts: List[str], num_steps: int = 1000):
        """Main training loop."""

        print(f"Starting PPO training for {num_steps} steps...")

        for step in range(num_steps):
            # Sample batch of prompts
            batch_prompts = np.random.choice(prompts, size=self.config.batch_size, replace=True).tolist()

            # Generate responses
            responses, response_tensors, log_probs = self.generate_responses(batch_prompts)

            # Compute rewards
            rewards = self.compute_rewards(batch_prompts, responses)

            # Prepare batch data
            batch_data = {
                'prompts': batch_prompts,
                'responses': responses,
                'response_tensors': response_tensors,
                'log_probs': log_probs,
                'rewards': rewards,
                'batch_size': len(batch_prompts)
            }

            # Training step
            stats = self.train_step(batch_data)

            # Log statistics
            if step % 10 == 0:
                print(f"Step {step}:")
                for key, value in stats.items():
                    print(f"  {key}: {value:.4f}")
                print()

            # Record stats
            for key, value in stats.items():
                if key in self.stats:
                    self.stats[key].append(value)

        print("PPO training completed!")

class AdaptiveKLController:
    """Adaptive KL divergence controller."""

    def __init__(self, init_kl_coef: float, target_kl: float):
        self.value = init_kl_coef
        self.target = target_kl

    def update(self, current_kl: float, n_steps: int):
        """Update KL coefficient based on the current KL divergence."""

        if current_kl < self.target / 1.5:
            # KL too low, decrease penalty
            self.value *= 0.98
        elif current_kl > self.target * 1.5:
            # KL too high, increase penalty
            self.value *= 1.02

        # Clamp to a reasonable range
        self.value = max(0.01, min(2.0, self.value))

# Example usage
if __name__ == "__main__":
    # Initialize models
    model_name = "microsoft/DialoGPT-small"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Policy model (trainable)
    policy_model = AutoModelForCausalLM.from_pretrained(model_name)

    # Reference model (frozen copy)
    ref_model = AutoModelForCausalLM.from_pretrained(model_name)

    # Reward model (placeholder - substitute the reward model trained in the previous stage)
    reward_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

    # PPO configuration
    config = PPOConfig(
        model_name=model_name,
        learning_rate=1.41e-5,
        batch_size=8,  # Small for the demo
        mini_batch_size=4,
        ppo_epochs=2,
        target_kl=0.1,
    )

    # Initialize PPO trainer
    ppo_trainer = PPOTrainer(
        config=config,
        model=policy_model,
        ref_model=ref_model,
        reward_model=reward_model,
        tokenizer=tokenizer
    )

    # Sample prompts for training
    prompts = [
        "Human: What's the best way to learn programming?\n\nAssistant:",
        "Human: Explain climate change in simple terms.\n\nAssistant:",
        "Human: How do I make a good first impression?\n\nAssistant:",
        "Human: What are the benefits of exercise?\n\nAssistant:",
    ]

    # Train with PPO
    ppo_trainer.train(prompts, num_steps=50)  # Short training for the demo

    # Save trained model
    policy_model.save_pretrained("./ppo_trained_model")
    tokenizer.save_pretrained("./ppo_trained_model")

    print("PPO training completed and model saved!")
