Master Reinforcement Learning from Human Feedback (RLHF) to align language models with human preferences, create safer AI systems, and improve response quality through human feedback integration.
Align language models with human preferences and values through iterative feedback and reinforcement learning.
Build safer AI systems by incorporating human judgment and reducing harmful or inappropriate outputs.
Improve response quality, relevance, and engagement through preference-based optimization.
Continuously improve model behavior through ongoing human feedback collection and model updates.
# RLHF implementation overview using TRL (Transformer Reinforcement Learning)
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSequenceClassification
)
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from datasets import Dataset
import numpy as np
from typing import List, Dict

class RLHFPipeline:
    def __init__(self, model_name: str, reward_model_name: str = None):
        self.model_name = model_name
        self.reward_model_name = reward_model_name
        self.tokenizer = None
        self.sft_model = None
        self.reward_model = None
        self.ppo_model = None

    def stage1_supervised_fine_tuning(self, sft_dataset: Dataset):
        """Stage 1: Supervised Fine-Tuning on human demonstrations."""

        print("Stage 1: Supervised Fine-Tuning")
        print("=" * 40)

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.sft_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Prepare SFT dataset
        def format_sft_example(example):
            """Format example for SFT training."""
            prompt = f"Human: {example['prompt']}\n\nAssistant: "
            response = example['chosen']  # Use human-preferred response
            full_text = prompt + response + self.tokenizer.eos_token
            return {"text": full_text}

        formatted_dataset = sft_dataset.map(format_sft_example)

        # SFT training (simplified - use the full Trainer in practice;
        # a hedged sketch of this step follows this listing)
        print(f"SFT dataset size: {len(formatted_dataset)}")
        print("SFT training would happen here...")
        print("✓ Stage 1 completed: SFT model ready")

        # Save SFT model
        self.sft_model.save_pretrained("./sft_model")
        self.tokenizer.save_pretrained("./sft_model")

    def stage2_reward_model_training(self, preference_dataset: Dataset):
        """Stage 2: Train reward model on human preferences."""

        print("\nStage 2: Reward Model Training")
        print("=" * 40)

        # Load model for reward modeling (typically smaller than the main model)
        reward_model_name = self.reward_model_name or self.model_name

        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
            reward_model_name,
            num_labels=1,  # Single scalar reward
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Prepare preference dataset
        def format_preference_example(example):
            """Format preference comparison for reward model training."""
            prompt = example['prompt']
            chosen = example['chosen']
            rejected = example['rejected']

            # Create prompt + response pairs
            chosen_text = f"Human: {prompt}\n\nAssistant: {chosen}"
            rejected_text = f"Human: {prompt}\n\nAssistant: {rejected}"

            return {
                'chosen': chosen_text,
                'rejected': rejected_text
            }

        formatted_preferences = preference_dataset.map(format_preference_example)

        # Reward model training (simplified)
        print(f"Preference dataset size: {len(formatted_preferences)}")
        print("Reward model training would happen here...")
        print("Training on preference pairs (chosen > rejected)")
        print("✓ Stage 2 completed: Reward model ready")

        # Save reward model
        self.reward_model.save_pretrained("./reward_model")

    def stage3_ppo_training(self, prompts_dataset: Dataset, ppo_config: PPOConfig = None):
        """Stage 3: PPO training using the reward model."""

        print("\nStage 3: PPO Training")
        print("=" * 40)

        # Default PPO configuration
        if ppo_config is None:
            ppo_config = PPOConfig(
                model_name="./sft_model",
                learning_rate=1.41e-5,
                batch_size=16,
                mini_batch_size=4,
                gradient_accumulation_steps=1,
                optimize_cuda_cache=True,
                early_stopping=False,
                target_kl=0.1,
                ppo_epochs=4,
                seed=0,
                init_kl_coef=0.2,
                adap_kl_ctrl=True,
            )

        # Load model with value head for PPO
        self.ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(
            "./sft_model",
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Load reward model
        reward_model = AutoModelForSequenceClassification.from_pretrained(
            "./reward_model",
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Create PPO trainer
        ppo_trainer = PPOTrainer(
            config=ppo_config,
            model=self.ppo_model,
            ref_model=None,  # TRL creates a frozen copy as the reference model
            tokenizer=self.tokenizer,
        )

        # Prepare prompts for PPO training
        prompts = [f"Human: {prompt}\n\nAssistant: " for prompt in prompts_dataset['prompt']]

        print(f"PPO training on {len(prompts)} prompts")

        device = self.ppo_model.pretrained_model.device

        # PPO training loop (simplified)
        for epoch in range(3):  # Limited epochs for demonstration
            print(f"\nPPO Epoch {epoch + 1}")

            for batch_idx in range(0, len(prompts), ppo_config.batch_size):
                batch_prompts = prompts[batch_idx:batch_idx + ppo_config.batch_size]

                # Tokenize prompts (queries)
                prompt_tensors = [
                    self.tokenizer.encode(prompt, return_tensors="pt")[0].to(device)
                    for prompt in batch_prompts
                ]

                # Generate responses from the current policy
                response_tensors = []
                for prompt_tensor in prompt_tensors:
                    output = self.ppo_model.generate(
                        prompt_tensor.unsqueeze(0),
                        max_new_tokens=50,
                        do_sample=True,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id
                    )
                    # Keep only the newly generated tokens: PPOTrainer.step expects
                    # the query and the response as separate tensors.
                    response_tensors.append(output[0][prompt_tensor.shape[0]:])

                # Calculate rewards using the reward model
                rewards = []
                for prompt_tensor, response_tensor in zip(prompt_tensors, response_tensors):
                    # Combine prompt and response
                    full_text = self.tokenizer.decode(
                        torch.cat([prompt_tensor, response_tensor]),
                        skip_special_tokens=True
                    )

                    # Get reward score
                    inputs = self.tokenizer(full_text, return_tensors="pt", truncation=True, max_length=512)
                    inputs = {k: v.to(reward_model.device) for k, v in inputs.items()}

                    with torch.no_grad():
                        reward_score = reward_model(**inputs).logits[0, 0].item()

                    rewards.append(reward_score)

                # Convert to tensors
                rewards = [torch.tensor(r) for r in rewards]

                # PPO training step
                stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)

                if batch_idx % 4 == 0:  # Log every few batches
                    print(f"  Batch {batch_idx // ppo_config.batch_size + 1}: "
                          f"Mean reward: {np.mean([r.item() for r in rewards]):.3f}")

        print("✓ Stage 3 completed: RLHF training finished")

        # Save final model
        self.ppo_model.save_pretrained("./rlhf_model")

        return ppo_trainer

    def evaluate_rlhf_model(self, test_prompts: List[str]):
        """Evaluate the RLHF-trained model."""

        print("\nEvaluating RLHF Model")
        print("=" * 40)

        if self.ppo_model is None:
            # Load the trained model
            self.ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained("./rlhf_model")

        self.ppo_model.eval()
        device = self.ppo_model.pretrained_model.device

        for i, prompt in enumerate(test_prompts, 1):
            formatted_prompt = f"Human: {prompt}\n\nAssistant: "

            inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Generate response
            with torch.no_grad():
                outputs = self.ppo_model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            response = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )

            print(f"\nTest {i}:")
            print(f"Human: {prompt}")
            print(f"Assistant: {response}")
            print("-" * 40)

# Sample data preparation functions
def create_sft_dataset():
    """Create a sample SFT dataset with human demonstrations."""

    sft_examples = [
        {
            "prompt": "Explain quantum computing in simple terms.",
            "chosen": "Quantum computing uses quantum mechanical phenomena like superposition and entanglement to process information in ways that classical computers cannot. Unlike classical bits that are either 0 or 1, quantum bits (qubits) can exist in multiple states simultaneously, potentially allowing quantum computers to solve certain problems exponentially faster than classical computers."
        },
        {
            "prompt": "How can I improve my sleep quality?",
            "chosen": "Here are some evidence-based strategies to improve sleep quality: 1) Maintain a consistent sleep schedule, 2) Create a relaxing bedtime routine, 3) Ensure your bedroom is cool, dark, and quiet, 4) Avoid caffeine and screens before bedtime, 5) Get regular exercise during the day, and 6) Consider relaxation techniques like meditation or deep breathing."
        },
        {
            "prompt": "Write a Python function to reverse a string.",
            "chosen": "Here's a simple Python function to reverse a string:\n\ndef reverse_string(s):\n    return s[::-1]\n\n# Example usage:\noriginal = 'hello'\nreversed_str = reverse_string(original)\nprint(reversed_str)  # Output: 'olleh'\n\nThis uses Python's slice notation with a step of -1 to reverse the string efficiently."
        }
    ]

    return Dataset.from_list(sft_examples)

def create_preference_dataset():
    """Create a sample preference dataset for reward model training."""

    preference_examples = [
        {
            "prompt": "What's the capital of France?",
            "chosen": "The capital of France is Paris. It's a beautiful city known for its art, culture, cuisine, and iconic landmarks like the Eiffel Tower and Louvre Museum.",
            "rejected": "Paris is the capital. It's in France and has some buildings and stuff."
        },
        {
            "prompt": "How do I bake a chocolate cake?",
            "chosen": "To bake a chocolate cake, you'll need: flour, sugar, cocoa powder, eggs, butter, baking powder, and milk. Mix dry ingredients, cream butter and sugar, add eggs, then alternate adding dry ingredients and milk. Bake at 350°F for 25-30 minutes. Let me know if you'd like a detailed recipe!",
            "rejected": "Just mix some chocolate stuff together and put it in the oven until it looks done. Should work fine."
        },
        {
            "prompt": "Is it safe to eat raw eggs?",
            "chosen": "Eating raw eggs carries some risk of Salmonella infection, though the risk is relatively low (about 1 in 20,000 eggs). Pasteurized eggs are safer for raw consumption. If you're pregnant, elderly, or immunocompromised, it's best to avoid raw eggs. For recipes requiring raw eggs, consider pasteurized alternatives.",
            "rejected": "Raw eggs are totally fine to eat, there's no risk at all. Eat as many as you want!"
        }
    ]

    return Dataset.from_list(preference_examples)

def create_prompts_dataset():
    """Create sample prompts for PPO training."""

    prompts = [
        "Explain the importance of exercise.",
        "What's the best way to learn a new language?",
        "How does photosynthesis work?",
        "Give me tips for public speaking.",
        "What are the benefits of meditation?",
        "How do I start investing in stocks?",
        "Explain machine learning to a beginner.",
        "What's the difference between weather and climate?",
    ]

    return Dataset.from_dict({"prompt": prompts})

# Main execution example
if __name__ == "__main__":
    # Initialize RLHF pipeline
    rlhf = RLHFPipeline(
        model_name="microsoft/DialoGPT-medium",  # Use a smaller model for the demo
        reward_model_name="microsoft/DialoGPT-small"
    )

    # Create sample datasets
    sft_data = create_sft_dataset()
    preference_data = create_preference_dataset()
    prompts_data = create_prompts_dataset()

    # Run RLHF pipeline
    print("Starting RLHF Pipeline")
    print("=" * 50)

    # Stage 1: SFT
    rlhf.stage1_supervised_fine_tuning(sft_data)

    # Stage 2: Reward Model
    rlhf.stage2_reward_model_training(preference_data)

    # Stage 3: PPO Training
    ppo_config = PPOConfig(
        model_name="./sft_model",
        learning_rate=1.41e-5,
        batch_size=4,  # Small batch for demo
        mini_batch_size=2,
        ppo_epochs=2,  # Fewer epochs for demo
        target_kl=0.1,
    )

    rlhf.stage3_ppo_training(prompts_data, ppo_config)

    # Evaluate final model
    test_prompts = [
        "What's the best way to stay healthy?",
        "Explain artificial intelligence.",
        "How do I write a good email?"
    ]

    rlhf.evaluate_rlhf_model(test_prompts)

    print("\nRLHF Pipeline completed successfully!")
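The Stage 1 method above only prints a placeholder where supervised fine-tuning would run. As a rough illustration, the sketch below fills in that step with a plain Hugging Face Trainer run over the formatted "text" field. It is a minimal sketch, not the pipeline's actual implementation: the helper name run_sft_training is hypothetical, it assumes the sft_model, tokenizer, and formatted_dataset produced inside stage1_supervised_fine_tuning, and the hyperparameters are illustrative only.

# Hedged sketch: filling in the Stage 1 "SFT training would happen here..." placeholder
# with a standard Hugging Face Trainer run (causal-LM objective on the formatted text).
# `run_sft_training` is a hypothetical helper; hyperparameters are illustrative only.
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

def run_sft_training(sft_model, tokenizer, formatted_dataset, output_dir="./sft_model"):
    # Tokenize the "text" column produced by format_sft_example
    def tokenize_fn(example):
        return tokenizer(example["text"], truncation=True, max_length=512)

    tokenized = formatted_dataset.map(tokenize_fn, remove_columns=formatted_dataset.column_names)

    args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        learning_rate=2e-5,
        logging_steps=10,
        report_to="none",
    )

    trainer = Trainer(
        model=sft_model,
        args=args,
        train_dataset=tokenized,
        # mlm=False gives the standard next-token (causal LM) objective
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    trainer.train()
    trainer.save_model(output_dir)
    return trainer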
# Comprehensive reward model implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    AutoConfig, Trainer, TrainingArguments
)
from datasets import Dataset
import numpy as np
from typing import Dict, List

class RewardModelTrainer:
    def __init__(self, base_model_name: str, max_length: int = 512):
        self.base_model_name = base_model_name
        self.max_length = max_length
        self.tokenizer = None
        self.model = None

    def setup_reward_model(self, dropout_rate: float = 0.1):
        """Set up the reward model architecture."""

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load config and modify it for reward modeling
        config = AutoConfig.from_pretrained(self.base_model_name)
        config.num_labels = 1  # Single scalar reward
        config.hidden_dropout_prob = dropout_rate
        config.attention_probs_dropout_prob = dropout_rate
        # Needed so GPT-2-style classification models can locate the last non-pad token
        config.pad_token_id = self.tokenizer.pad_token_id

        # Load model for sequence classification
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.base_model_name,
            config=config,
            torch_dtype=torch.bfloat16
        )

        # Replace the classification head with a small MLP reward head.
        # The head attribute name depends on the architecture
        # ("classifier" for BERT-style models, "score" for GPT-2-style models).
        reward_head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(config.hidden_size, 1)  # Single reward score
        ).to(self.model.dtype)
        head_name = "classifier" if hasattr(self.model, "classifier") else "score"
        setattr(self.model, head_name, reward_head)

        print("Reward model setup completed:")
        print(f"- Base model: {self.base_model_name}")
        print(f"- Parameters: {self.model.num_parameters():,}")
        print(f"- Dropout rate: {dropout_rate}")

    def prepare_preference_dataset(self, preference_data: List[Dict]) -> Dataset:
        """Prepare the preference comparison dataset."""

        def tokenize_pair(example):
            """Tokenize chosen and rejected responses."""

            prompt = example['prompt']
            chosen = example['chosen']
            rejected = example['rejected']

            # Create full texts
            chosen_text = f"{prompt}\n\n{chosen}"
            rejected_text = f"{prompt}\n\n{rejected}"

            # Tokenize both
            chosen_tokens = self.tokenizer(
                chosen_text,
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors="pt"
            )

            rejected_tokens = self.tokenizer(
                rejected_text,
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors="pt"
            )

            return {
                'chosen_input_ids': chosen_tokens['input_ids'].squeeze(),
                'chosen_attention_mask': chosen_tokens['attention_mask'].squeeze(),
                'rejected_input_ids': rejected_tokens['input_ids'].squeeze(),
                'rejected_attention_mask': rejected_tokens['attention_mask'].squeeze(),
            }

        # Convert to a dataset and tokenize
        dataset = Dataset.from_list(preference_data)
        tokenized_dataset = dataset.map(tokenize_pair, remove_columns=dataset.column_names)

        print(f"Prepared preference dataset with {len(tokenized_dataset)} pairs")
        return tokenized_dataset

    def create_pairwise_dataset(self, tokenized_dataset: Dataset) -> Dataset:
        """Create a pairwise dataset for Bradley-Terry training.

        Each preference pair becomes two consecutive rows: the chosen response
        (even index) immediately followed by its rejected counterpart (odd index).
        """

        pairwise_examples = []

        for example in tokenized_dataset:
            # Chosen example (label = 1)
            pairwise_examples.append({
                'input_ids': example['chosen_input_ids'],
                'attention_mask': example['chosen_attention_mask'],
                'labels': 1.0  # Chosen is better
            })

            # Rejected example (label = 0)
            pairwise_examples.append({
                'input_ids': example['rejected_input_ids'],
                'attention_mask': example['rejected_attention_mask'],
                'labels': 0.0  # Rejected is worse
            })

        return Dataset.from_list(pairwise_examples)

    def compute_pairwise_loss(self, chosen_rewards, rejected_rewards):
        """Compute the Bradley-Terry pairwise ranking loss."""

        # Bradley-Terry loss: -log(sigmoid(chosen - rejected))
        diff = chosen_rewards - rejected_rewards
        loss = -F.logsigmoid(diff).mean()

        return loss

    def train_reward_model(self,
                           train_dataset: Dataset,
                           eval_dataset: Dataset = None,
                           output_dir: str = "./reward_model",
                           num_epochs: int = 3,
                           batch_size: int = 8,
                           learning_rate: float = 2e-5,
                           warmup_ratio: float = 0.1):
        """Train the reward model on preference data."""

        # Custom trainer for pairwise ranking
        class RewardTrainer(Trainer):

            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                """Compute the pairwise ranking loss.

                Simplified: assumes each batch preserves the chosen/rejected
                interleaving produced by create_pairwise_dataset (chosen at even
                indices, rejected at odd indices). A production implementation
                would keep each pair in a single row so that shuffling cannot
                break the pairing. **kwargs absorbs extra arguments passed by
                newer Trainer versions.
                """

                chosen_inputs = {
                    'input_ids': inputs['input_ids'][0::2],
                    'attention_mask': inputs['attention_mask'][0::2]
                }

                rejected_inputs = {
                    'input_ids': inputs['input_ids'][1::2],
                    'attention_mask': inputs['attention_mask'][1::2]
                }

                # Get reward scores
                chosen_outputs = model(**chosen_inputs)
                rejected_outputs = model(**rejected_inputs)

                chosen_rewards = chosen_outputs.logits.squeeze(-1)
                rejected_rewards = rejected_outputs.logits.squeeze(-1)

                # Compute pairwise loss
                loss = self.compute_pairwise_loss(chosen_rewards, rejected_rewards)

                outputs = {'chosen_rewards': chosen_rewards, 'rejected_rewards': rejected_rewards}
                return (loss, outputs) if return_outputs else loss

            def compute_pairwise_loss(self, chosen_rewards, rejected_rewards):
                """Bradley-Terry loss implementation."""
                diff = chosen_rewards - rejected_rewards
                return -F.logsigmoid(diff).mean()

            def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
                """Custom prediction step for evaluation."""

                inputs = self._prepare_inputs(inputs)

                with torch.no_grad():
                    outputs = model(**{k: inputs[k] for k in ('input_ids', 'attention_mask')})
                    rewards = outputs.logits.squeeze(-1)

                # For evaluation, compute the accuracy of preference predictions
                # (same even/odd pairing assumption as in compute_loss)
                chosen_rewards = rewards[0::2]
                rejected_rewards = rewards[1::2]

                # Accuracy: how often chosen > rejected
                correct = (chosen_rewards > rejected_rewards).float()
                accuracy = correct.mean()

                loss = self.compute_pairwise_loss(chosen_rewards, rejected_rewards)

                return (loss, accuracy, accuracy)  # Report accuracy as both predictions and labels

        # Prepare datasets for pairwise training
        pairwise_train = self.create_pairwise_dataset(train_dataset)
        pairwise_eval = self.create_pairwise_dataset(eval_dataset) if eval_dataset else None

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_ratio=warmup_ratio,
            logging_steps=50,
            save_steps=500,
            eval_steps=500 if pairwise_eval else None,
            evaluation_strategy="steps" if pairwise_eval else "no",
            save_strategy="steps",
            load_best_model_at_end=True if pairwise_eval else False,
            metric_for_best_model="eval_loss" if pairwise_eval else None,
            greater_is_better=False,
            report_to="none",
            bf16=True,
            dataloader_pin_memory=False,
            remove_unused_columns=False,
        )

        # Create trainer
        trainer = RewardTrainer(
            model=self.model,
            args=training_args,
            train_dataset=pairwise_train,
            eval_dataset=pairwise_eval,
            tokenizer=self.tokenizer,
        )

        # Train
        print("Starting reward model training...")
        print(f"Training examples: {len(pairwise_train)}")
        if pairwise_eval:
            print(f"Evaluation examples: {len(pairwise_eval)}")

        trainer.train()

        # Save model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        print(f"Reward model training completed! Saved to {output_dir}")

        return trainer

    def evaluate_reward_model(self, test_data: List[Dict]) -> Dict:
        """Evaluate the reward model on test data."""

        self.model.eval()
        results = {
            'accuracy': 0.0,
            'mean_chosen_reward': 0.0,
            'mean_rejected_reward': 0.0,
            'reward_difference': 0.0
        }

        correct_predictions = 0
        chosen_rewards = []
        rejected_rewards = []

        for example in test_data:
            prompt = example['prompt']
            chosen = example['chosen']
            rejected = example['rejected']

            # Score chosen response
            chosen_text = f"{prompt}\n\n{chosen}"
            chosen_inputs = self.tokenizer(
                chosen_text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length
            )
            chosen_inputs = {k: v.to(self.model.device) for k, v in chosen_inputs.items()}

            with torch.no_grad():
                chosen_score = self.model(**chosen_inputs).logits[0, 0].item()

            # Score rejected response
            rejected_text = f"{prompt}\n\n{rejected}"
            rejected_inputs = self.tokenizer(
                rejected_text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length
            )
            rejected_inputs = {k: v.to(self.model.device) for k, v in rejected_inputs.items()}

            with torch.no_grad():
                rejected_score = self.model(**rejected_inputs).logits[0, 0].item()

            # Check whether the model correctly prefers chosen over rejected
            if chosen_score > rejected_score:
                correct_predictions += 1

            chosen_rewards.append(chosen_score)
            rejected_rewards.append(rejected_score)

        # Calculate metrics
        results['accuracy'] = correct_predictions / len(test_data)
        results['mean_chosen_reward'] = np.mean(chosen_rewards)
        results['mean_rejected_reward'] = np.mean(rejected_rewards)
        results['reward_difference'] = results['mean_chosen_reward'] - results['mean_rejected_reward']

        return results

    def get_reward_score(self, prompt: str, response: str) -> float:
        """Get the reward score for a prompt-response pair."""

        self.model.eval()

        text = f"{prompt}\n\n{response}"
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            score = self.model(**inputs).logits[0, 0].item()

        return score

# Example usage and testing
def create_sample_preference_data():
    """Create sample preference data for training."""

    preference_data = [
        {
            "prompt": "Explain the concept of gravity.",
            "chosen": "Gravity is a fundamental force of nature that causes objects with mass to attract each other. According to Einstein's theory of general relativity, gravity is not actually a force, but rather the curvature of spacetime caused by mass and energy. This curvature guides the motion of objects, making them appear to be attracted to each other.",
            "rejected": "Gravity is when things fall down because they're heavy."
        },
        {
            "prompt": "How do I cook pasta?",
            "chosen": "To cook pasta: 1) Bring a large pot of salted water to boil, 2) Add pasta and stir occasionally, 3) Cook according to package directions (usually 8-12 minutes) until al dente, 4) Drain and serve immediately. The key is using plenty of water and not overcooking.",
            "rejected": "Put pasta in water and heat it until it's soft. Should be fine."
        },
        {
            "prompt": "What causes climate change?",
            "chosen": "Climate change is primarily caused by increased concentrations of greenhouse gases in the atmosphere, mainly from human activities like burning fossil fuels, deforestation, and industrial processes. These gases trap heat from the sun, leading to global warming and associated climate impacts like sea level rise, extreme weather, and ecosystem disruption.",
            "rejected": "The sun gets hotter sometimes and that changes the climate. It's natural."
        }
    ]

    return preference_data

if __name__ == "__main__":
    # Initialize reward model trainer
    trainer = RewardModelTrainer("microsoft/DialoGPT-small", max_length=256)

    # Setup model
    trainer.setup_reward_model(dropout_rate=0.1)

    # Create sample data
    preference_data = create_sample_preference_data()

    # Prepare dataset
    dataset = trainer.prepare_preference_dataset(preference_data)

    # Split for training and evaluation
    train_size = int(0.8 * len(dataset))
    train_dataset = dataset.select(range(train_size))
    eval_dataset = dataset.select(range(train_size, len(dataset)))

    # Train reward model
    reward_trainer = trainer.train_reward_model(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./sample_reward_model",
        num_epochs=2,
        batch_size=2,  # Small batch for demo
        learning_rate=5e-5
    )

    # Evaluate model
    results = trainer.evaluate_reward_model(preference_data)
    print("\nReward Model Evaluation Results:")
    for metric, value in results.items():
        print(f"{metric}: {value:.4f}")

    # Test individual scoring
    print("\nTesting individual reward scoring:")
    test_prompt = "What's the best way to learn programming?"
    good_response = "Start with a beginner-friendly language like Python, practice regularly with small projects, and don't be afraid to make mistakes - they're part of learning!"
    bad_response = "Just read some books about it."

    good_score = trainer.get_reward_score(test_prompt, good_response)
    bad_score = trainer.get_reward_score(test_prompt, bad_response)

    print(f"Good response score: {good_score:.4f}")
    print(f"Bad response score: {bad_score:.4f}")
    print(f"Difference: {good_score - bad_score:.4f}")

    print("\nReward model training completed!")
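To make the objective in compute_pairwise_loss concrete: the Bradley-Terry loss is -log(sigmoid(r_chosen - r_rejected)), so pairs the reward model already ranks correctly (chosen well above rejected) contribute little loss, while inverted pairs are penalized heavily. The standalone check below uses made-up reward values purely for illustration.

# Standalone check of the Bradley-Terry pairwise loss used above:
# loss = -log(sigmoid(r_chosen - r_rejected)); the loss shrinks as the reward
# model scores the chosen response increasingly above the rejected one.
# The reward values here are made up for illustration.
import torch
import torch.nn.functional as F

chosen_rewards = torch.tensor([2.0, 0.5, 1.0])    # hypothetical scores for chosen responses
rejected_rewards = torch.tensor([0.0, 0.4, 3.0])  # hypothetical scores for rejected responses

per_pair_loss = -F.logsigmoid(chosen_rewards - rejected_rewards)
mean_loss = per_pair_loss.mean()

print(per_pair_loss)  # small when chosen >> rejected, large when the ranking is inverted
print(mean_loss)      # mean over the three pairs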
# Comprehensive PPO implementation for RLHF
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass

@dataclass
class PPOConfig:
    """Configuration for PPO training."""
    model_name: str = "gpt2"
    learning_rate: float = 1.41e-5
    batch_size: int = 64
    mini_batch_size: int = 16
    gradient_accumulation_steps: int = 1
    ppo_epochs: int = 4
    max_grad_norm: float = 1.0
    clip_range: float = 0.2
    clip_range_vf: Optional[float] = None
    vf_coef: float = 0.1
    target_kl: float = 0.1
    init_kl_coef: float = 0.2
    adap_kl_ctrl: bool = True
    gamma: float = 1.0
    lam: float = 0.95
    use_score_scaling: bool = False
    use_score_norm: bool = False
    score_clip: Optional[float] = None

class PPOTrainer:
    def __init__(self,
                 config: PPOConfig,
                 model: nn.Module,
                 ref_model: nn.Module,
                 reward_model: nn.Module,
                 tokenizer):

        self.config = config
        self.model = model            # Policy model
        self.ref_model = ref_model    # Reference model (frozen)
        self.reward_model = reward_model
        self.tokenizer = tokenizer

        # Freeze the reference model
        for param in self.ref_model.parameters():
            param.requires_grad = False

        # Set up the optimizer
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            eps=1e-8,
            weight_decay=0.01
        )

        # KL controller for the adaptive penalty
        self.kl_ctl = AdaptiveKLController(config.init_kl_coef, config.target_kl)

        # Training statistics
        self.stats = {
            'policy_loss': [],
            'value_loss': [],
            'total_loss': [],
            'kl_divergence': [],
            'rewards': [],
            'advantages': [],
            'approx_kl': [],
        }

    def generate_responses(self,
                           prompts: List[str],
                           max_new_tokens: int = 50,
                           temperature: float = 0.7,
                           top_p: float = 0.9) -> Tuple[List[str], List[torch.Tensor], List[torch.Tensor]]:
        """Generate responses from the current policy."""

        self.model.eval()

        all_responses = []
        all_response_tensors = []
        all_log_probs = []

        for prompt in prompts:
            # Tokenize prompt
            prompt_tokens = self.tokenizer.encode(prompt, return_tensors="pt")
            prompt_tokens = prompt_tokens.to(self.model.device)

            # Generate response
            with torch.no_grad():
                response_tokens = self.model.generate(
                    prompt_tokens,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    pad_token_id=self.tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True
                )

            # Extract generated tokens (without the prompt)
            generated_tokens = response_tokens.sequences[0][prompt_tokens.shape[1]:]

            # Calculate rollout log probabilities.
            # Note: `scores` are the processed logits (after temperature/top-p warping).
            log_probs = []
            for i, token_id in enumerate(generated_tokens):
                if i < len(response_tokens.scores):
                    scores = response_tokens.scores[i][0]  # [vocab_size]
                    log_prob = F.log_softmax(scores, dim=-1)[token_id].item()
                    log_probs.append(log_prob)

            # Decode response
            response_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

            all_responses.append(response_text)
            all_response_tensors.append(generated_tokens)
            all_log_probs.append(torch.tensor(log_probs))

        return all_responses, all_response_tensors, all_log_probs

    def compute_rewards(self, prompts: List[str], responses: List[str]) -> List[float]:
        """Compute rewards using the reward model."""

        self.reward_model.eval()
        rewards = []

        for prompt, response in zip(prompts, responses):
            # Create the full text
            full_text = f"{prompt}\n\n{response}"

            # Tokenize and get the reward
            inputs = self.tokenizer(
                full_text,
                return_tensors="pt",
                truncation=True,
                max_length=512
            )
            inputs = {k: v.to(self.reward_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                reward = self.reward_model(**inputs).logits[0, 0].item()

            rewards.append(reward)

        return rewards

    def sequence_log_probs(self,
                           model: nn.Module,
                           prompts: List[str],
                           response_tensors: List[torch.Tensor]) -> torch.Tensor:
        """Mean per-token log probability of each sampled response under `model`."""

        seq_log_probs = []
        for prompt, response_tokens in zip(prompts, response_tensors):
            prompt_tokens = self.tokenizer.encode(prompt, return_tensors="pt")[0].to(model.device)
            response_tokens = response_tokens.to(model.device)
            full_sequence = torch.cat([prompt_tokens, response_tokens]).unsqueeze(0)

            logits = model(full_sequence).logits[0]  # [seq_len, vocab_size]
            # Logits at position i predict token i + 1, so the response tokens are
            # predicted by the slice starting one position before the response.
            pred_logits = logits[prompt_tokens.shape[0] - 1:-1]
            token_log_probs = F.log_softmax(pred_logits.float(), dim=-1)
            token_log_probs = token_log_probs.gather(1, response_tokens.unsqueeze(1)).squeeze(1)

            seq_log_probs.append(token_log_probs.mean())

        return torch.stack(seq_log_probs)

    def compute_advantages(self,
                           rewards: torch.Tensor,
                           values: torch.Tensor,
                           masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute GAE (Generalized Advantage Estimation) advantages."""

        # Add a terminal value of 0
        values = torch.cat([values, torch.zeros(1).to(values.device)])

        advantages = torch.zeros_like(rewards)
        last_gae_lam = 0

        # Compute advantages using GAE
        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_non_terminal = 0
                next_values = 0
            else:
                next_non_terminal = masks[t + 1]
                next_values = values[t + 1]

            delta = rewards[t] + self.config.gamma * next_values * next_non_terminal - values[t]
            advantages[t] = last_gae_lam = delta + self.config.gamma * self.config.lam * next_non_terminal * last_gae_lam

        # Compute returns
        returns = advantages + values[:-1]

        return advantages, returns

    def compute_policy_loss(self,
                            log_probs: torch.Tensor,
                            old_log_probs: torch.Tensor,
                            advantages: torch.Tensor,
                            masks: torch.Tensor) -> torch.Tensor:
        """Compute the clipped PPO policy loss."""

        # Compute the probability ratio
        log_ratio = log_probs - old_log_probs
        ratio = torch.exp(log_ratio)

        # Compute the clipped surrogate loss
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.config.clip_range, 1 + self.config.clip_range) * advantages

        policy_loss = -torch.min(surr1, surr2)

        # Apply the mask and average
        policy_loss = (policy_loss * masks).sum() / masks.sum()

        return policy_loss

    def compute_value_loss(self,
                           values: torch.Tensor,
                           old_values: torch.Tensor,
                           returns: torch.Tensor,
                           masks: torch.Tensor) -> torch.Tensor:
        """Compute the value function loss."""

        if self.config.clip_range_vf is not None:
            # Clipped value loss
            values_clipped = old_values + torch.clamp(
                values - old_values,
                -self.config.clip_range_vf,
                self.config.clip_range_vf
            )

            vf_loss1 = (values - returns) ** 2
            vf_loss2 = (values_clipped - returns) ** 2
            vf_loss = torch.max(vf_loss1, vf_loss2)
        else:
            # Standard MSE loss
            vf_loss = (values - returns) ** 2

        # Apply the mask and average
        vf_loss = (vf_loss * masks).sum() / masks.sum()

        return vf_loss

    def compute_kl_penalty(self,
                           log_probs: torch.Tensor,
                           ref_log_probs: torch.Tensor,
                           masks: torch.Tensor) -> torch.Tensor:
        """Compute the KL divergence penalty."""

        # Monte-Carlo estimate of KL(policy || reference) on the sampled responses
        kl_div = log_probs - ref_log_probs
        kl_penalty = (kl_div * masks).sum() / masks.sum()

        return kl_penalty

    def train_step(self, batch_data: Dict) -> Dict:
        """Perform one PPO training step."""

        self.model.train()

        # Extract batch data (rollouts collected in `train`)
        prompts = batch_data['prompts']
        response_tensors = batch_data['response_tensors']
        old_log_probs = batch_data['log_probs']
        rewards = batch_data['rewards']

        # Convert to tensors
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.model.device)

        # Compute values (simplified - in practice, use a separate value head)
        values = torch.zeros_like(rewards)  # Placeholder

        # Compute advantages
        # Simplified: the batch of single-turn episodes is treated as one trajectory.
        masks = torch.ones_like(rewards)
        advantages, returns = self.compute_advantages(rewards, values, masks)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Rollout ("old") log probs: mean per-token log prob of each sampled response
        old_lp = torch.stack([lp.mean() for lp in old_log_probs]).to(self.model.device)

        # Reference-model log probs for the KL penalty (fixed for this batch)
        ref_log_probs = self.get_ref_log_probs(prompts, response_tensors).to(self.model.device)

        # PPO training loop
        total_policy_loss = 0
        total_value_loss = 0
        total_kl_penalty = 0

        for ppo_epoch in range(self.config.ppo_epochs):
            # Re-evaluate the sampled responses under the current policy (with gradients)
            curr_log_probs = self.sequence_log_probs(self.model, prompts, response_tensors)

            # Compute losses
            policy_loss = self.compute_policy_loss(curr_log_probs, old_lp, advantages, masks)
            value_loss = self.compute_value_loss(values, values, returns, masks)  # Simplified
            kl_penalty = self.compute_kl_penalty(curr_log_probs, ref_log_probs, masks)

            # Total loss
            total_loss = (
                policy_loss +
                self.config.vf_coef * value_loss +
                self.kl_ctl.value * kl_penalty
            )

            # Backward pass
            self.optimizer.zero_grad()
            total_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

            self.optimizer.step()

            # Accumulate losses
            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
            total_kl_penalty += kl_penalty.item()

        # Update the KL controller
        mean_kl = total_kl_penalty / self.config.ppo_epochs
        self.kl_ctl.update(mean_kl, batch_data['batch_size'])

        # Record statistics
        stats = {
            'policy_loss': total_policy_loss / self.config.ppo_epochs,
            'value_loss': total_value_loss / self.config.ppo_epochs,
            'kl_penalty': mean_kl,
            'kl_coef': self.kl_ctl.value,
            'mean_reward': rewards.mean().item(),
            'mean_advantage': advantages.mean().item(),
        }

        return stats

    def get_ref_log_probs(self, prompts: List[str], response_tensors: List[torch.Tensor]) -> torch.Tensor:
        """Log probabilities of the sampled responses under the frozen reference model."""

        self.ref_model.eval()

        with torch.no_grad():
            return self.sequence_log_probs(self.ref_model, prompts, response_tensors)

    def train(self, prompts: List[str], num_steps: int = 1000):
        """Main training loop."""

        print(f"Starting PPO training for {num_steps} steps...")

        for step in range(num_steps):
            # Sample a batch of prompts
            batch_prompts = np.random.choice(prompts, size=self.config.batch_size, replace=True).tolist()

            # Generate responses (rollout)
            responses, response_tensors, log_probs = self.generate_responses(batch_prompts)

            # Compute rewards
            rewards = self.compute_rewards(batch_prompts, responses)

            # Prepare batch data
            batch_data = {
                'prompts': batch_prompts,
                'responses': responses,
                'response_tensors': response_tensors,
                'log_probs': log_probs,
                'rewards': rewards,
                'batch_size': len(batch_prompts)
            }

            # Training step
            stats = self.train_step(batch_data)

            # Log statistics
            if step % 10 == 0:
                print(f"Step {step}:")
                for key, value in stats.items():
                    print(f"  {key}: {value:.4f}")
                print()

            # Record stats
            for key, value in stats.items():
                if key in self.stats:
                    self.stats[key].append(value)

        print("PPO training completed!")

class AdaptiveKLController:
    """Adaptive KL divergence controller."""

    def __init__(self, init_kl_coef: float, target_kl: float):
        self.value = init_kl_coef
        self.target = target_kl

    def update(self, current_kl: float, n_steps: int):
        """Update the KL coefficient based on the current KL divergence."""

        if current_kl < self.target / 1.5:
            # KL too low, decrease the penalty
            self.value *= 0.98
        elif current_kl > self.target * 1.5:
            # KL too high, increase the penalty
            self.value *= 1.02

        # Clamp to a reasonable range
        self.value = max(0.01, min(2.0, self.value))

# Example usage
if __name__ == "__main__":
    # Initialize models
    model_name = "microsoft/DialoGPT-small"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Policy model (trainable)
    policy_model = AutoModelForCausalLM.from_pretrained(model_name)

    # Reference model (frozen copy)
    ref_model = AutoModelForCausalLM.from_pretrained(model_name)

    # Reward model (untrained placeholder head - load an actual trained reward model in practice)
    reward_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
    reward_model.config.pad_token_id = tokenizer.pad_token_id

    # PPO configuration
    config = PPOConfig(
        model_name=model_name,
        learning_rate=1.41e-5,
        batch_size=8,  # Small for demo
        mini_batch_size=4,
        ppo_epochs=2,
        target_kl=0.1,
    )

    # Initialize the PPO trainer
    ppo_trainer = PPOTrainer(
        config=config,
        model=policy_model,
        ref_model=ref_model,
        reward_model=reward_model,
        tokenizer=tokenizer
    )

    # Sample prompts for training
    prompts = [
        "Human: What's the best way to learn programming?\n\nAssistant:",
        "Human: Explain climate change in simple terms.\n\nAssistant:",
        "Human: How do I make a good first impression?\n\nAssistant:",
        "Human: What are the benefits of exercise?\n\nAssistant:",
    ]

    # Train with PPO
    ppo_trainer.train(prompts, num_steps=50)  # Short training for demo

    # Save the trained model
    policy_model.save_pretrained("./ppo_trained_model")
    tokenizer.save_pretrained("./ppo_trained_model")

    print("PPO training completed and model saved!")
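The PPOTrainer above applies the KL penalty as an extra loss term weighted by the adaptive coefficient. A common alternative formulation (used in InstructGPT-style setups and in TRL's PPOTrainer) folds the penalty into the per-token reward instead: every generated token pays -beta * KL against the reference model, and the scalar reward-model score is added on the final token of the response. The sketch below shows only that reward shaping; all tensors are made-up stand-ins for one generated response and are not produced by the code above.

# Hedged sketch: KL-shaped per-token rewards (InstructGPT/TRL-style formulation).
# The values below are hypothetical stand-ins for a single generated response.
import torch

kl_coef = 0.2                                          # penalty weight (beta)
reward_model_score = torch.tensor(1.3)                 # scalar score for the whole response
policy_log_probs = torch.tensor([-2.1, -0.8, -1.5])    # log pi(token | context) per response token
ref_log_probs = torch.tensor([-2.0, -1.0, -1.4])       # log pi_ref(token | context) per response token

# Per-token KL estimate and shaped rewards: each token is penalized by -beta * KL,
# and the reward-model score is added on the final token of the response.
per_token_kl = policy_log_probs - ref_log_probs
shaped_rewards = -kl_coef * per_token_kl
shaped_rewards[-1] += reward_model_score

print(shaped_rewards)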