Master Supervised Fine-Tuning (SFT) techniques to adapt open-source large language models to specific tasks, domains, and use cases. Learn how to prepare data, optimize training, and improve performance on targeted applications.
With SFT you can:
- Adapt general-purpose models to excel at specific tasks and domains with targeted training data.
- Achieve excellent results with thousands of examples rather than billions, making customization accessible.
- Shape model outputs to follow specific formats, styles, and behavioral patterns for your use case.
- Incorporate specialized knowledge and terminology from medical, legal, financial, and other domains.
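To make the data requirements concrete, here is a minimal sketch of a single training record in the Alpaca-style instruction format used throughout the code below, together with the prompt string it is rendered into during training. The field values are illustrative, and the template is one common convention rather than the only option.

# One Alpaca-style SFT record and the training string it becomes (illustrative values)
record = {
    "instruction": "Summarize the following text.",
    "input": "Supervised fine-tuning adapts a pretrained model using labeled input-output pairs.",
    "output": "SFT teaches a pretrained model a task from labeled examples."
}

prompt = (
    "### Instruction:\n" + record["instruction"] + "\n\n"
    "### Input:\n" + record["input"] + "\n\n"
    "### Response:\n"
)
full_text = prompt + record["output"]  # an EOS token is appended before tokenization
print(full_text)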
# Complete SFT implementation with Hugging Face Transformers
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    DataCollatorForLanguageModeling,
    get_linear_schedule_with_warmup
)
from datasets import Dataset, load_dataset
import json
from typing import Dict, List
import numpy as np

class SupervisedFineTuner:
    def __init__(self, model_name: str, max_length: int = 2048):
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = None
        self.model = None

    def setup_model_and_tokenizer(self):
        """Initialize model and tokenizer for SFT."""

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            use_fast=True
        )

        # Set special tokens
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            use_cache=False  # Disable cache for training
        )

        # Enable gradient checkpointing for memory efficiency
        self.model.gradient_checkpointing_enable()

        print(f"Model loaded: {self.model_name}")
        print(f"Vocabulary size: {len(self.tokenizer)}")
        print(f"Model parameters: {self.model.num_parameters():,}")

    def prepare_instruction_dataset(self, data: List[Dict]) -> Dataset:
        """Prepare instruction-following dataset for SFT."""

        def format_instruction_example(example):
            """Format a single instruction example."""

            instruction = example.get('instruction', '')
            input_text = example.get('input', '')
            output = example.get('output', '')

            # Create formatted prompt
            if input_text:
                prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
            else:
                prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"

            # Full text for training
            full_text = prompt + output + self.tokenizer.eos_token

            return {
                'text': full_text,
                'prompt': prompt,
                'response': output
            }

        # Format all examples
        formatted_data = [format_instruction_example(item) for item in data]

        return Dataset.from_list(formatted_data)

    def prepare_conversation_dataset(self, data: List[Dict]) -> Dataset:
        """Prepare conversational dataset for SFT."""

        def format_conversation(example):
            """Format a conversation example."""

            conversation = example.get('conversation', [])
            formatted_text = ""

            for turn in conversation:
                role = turn.get('role', 'user')
                content = turn.get('content', '')

                if role == 'user':
                    formatted_text += f"Human: {content}\n\n"
                elif role == 'assistant':
                    formatted_text += f"Assistant: {content}\n\n"

            # Add EOS token
            formatted_text += self.tokenizer.eos_token

            return {'text': formatted_text}

        # Format all conversations
        formatted_data = [format_conversation(item) for item in data]

        return Dataset.from_list(formatted_data)

    def tokenize_dataset(self, dataset: Dataset) -> Dataset:
        """Tokenize dataset for training."""

        def tokenize_function(examples):
            # Tokenize texts
            tokenized = self.tokenizer(
                examples["text"],
                truncation=True,
                padding=False,
                max_length=self.max_length,
                return_overflowing_tokens=False,
            )

            # For causal LM, labels are the same as input_ids
            tokenized["labels"] = tokenized["input_ids"].copy()

            return tokenized

        # Apply tokenization
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset"
        )

        # Filter out examples that are too long
        original_size = len(tokenized_dataset)
        tokenized_dataset = tokenized_dataset.filter(
            lambda x: len(x["input_ids"]) <= self.max_length
        )
        final_size = len(tokenized_dataset)

        print(f"Dataset size: {original_size} -> {final_size} examples")

        return tokenized_dataset

    def create_data_collator(self):
        """Create data collator for training."""

        return DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # Not masked language modeling
            pad_to_multiple_of=8,  # For efficiency on modern GPUs
        )

    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset = None,
        output_dir: str = "./sft_results",
        num_epochs: int = 3,
        batch_size: int = 4,
        learning_rate: float = 5e-5,
        warmup_ratio: float = 0.03,
        save_steps: int = 500,
        logging_steps: int = 10,
        eval_steps: int = 500,
    ):
        """Train the model with SFT."""

        # Calculate total training steps
        total_steps = (len(train_dataset) // batch_size) * num_epochs
        warmup_steps = int(total_steps * warmup_ratio)

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            adam_beta1=0.9,
            adam_beta2=0.999,
            adam_epsilon=1e-8,
            max_grad_norm=1.0,
            warmup_steps=warmup_steps,
            lr_scheduler_type="linear",
            logging_steps=logging_steps,
            save_steps=save_steps,
            eval_steps=eval_steps if eval_dataset else None,
            evaluation_strategy="steps" if eval_dataset else "no",
            save_strategy="steps",
            load_best_model_at_end=True if eval_dataset else False,
            metric_for_best_model="eval_loss" if eval_dataset else None,
            greater_is_better=False,
            report_to="none",  # Disable wandb/tensorboard
            dataloader_pin_memory=False,
            gradient_checkpointing=True,
            bf16=True,  # Use bfloat16 for stability
            remove_unused_columns=False,
            push_to_hub=False,
        )

        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=self.create_data_collator(),
            tokenizer=self.tokenizer,
        )

        # Custom callback for monitoring; callbacks must subclass TrainerCallback
        # and receive args/state/control in their hooks.
        class MemoryMonitorCallback(TrainerCallback):
            def on_step_end(self, args, state, control, **kwargs):
                if state.global_step % 100 == 0 and torch.cuda.is_available():
                    memory_allocated = torch.cuda.memory_allocated() / 1024**3
                    memory_reserved = torch.cuda.memory_reserved() / 1024**3
                    print(f"Step {state.global_step}: "
                          f"Memory Allocated: {memory_allocated:.2f}GB, "
                          f"Reserved: {memory_reserved:.2f}GB")

        trainer.add_callback(MemoryMonitorCallback())

        # Start training
        print("Starting SFT training...")
        print(f"Training examples: {len(train_dataset)}")
        if eval_dataset:
            print(f"Evaluation examples: {len(eval_dataset)}")
        print(f"Total training steps: {total_steps}")
        print(f"Warmup steps: {warmup_steps}")

        trainer.train()

        # Save final model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        print(f"Training completed! Model saved to {output_dir}")

        return trainer

    def evaluate_model(self, test_prompts: List[str], max_new_tokens: int = 200):
        """Evaluate the fine-tuned model on test prompts."""

        print("\nEvaluating fine-tuned model:")
        print("=" * 60)

        self.model.eval()

        for i, prompt in enumerate(test_prompts, 1):
            print(f"\nTest {i}:")
            print(f"Prompt: {prompt}")

            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode response
            response = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )

            print(f"Response: {response}")
            print("-" * 40)

# Example usage and data preparation
def prepare_sample_instruction_data():
    """Prepare sample instruction-following data."""

    sample_data = [
        {
            "instruction": "Explain the concept of machine learning in simple terms.",
            "input": "",
            "output": "Machine learning is a type of artificial intelligence where computers learn to make predictions or decisions by analyzing patterns in data, rather than being explicitly programmed for each task. It's like teaching a computer to recognize patterns the same way humans learn from experience."
        },
        {
            "instruction": "Write a Python function to calculate the factorial of a number.",
            "input": "",
            "output": "def factorial(n):\n    if n == 0 or n == 1:\n        return 1\n    else:\n        return n * factorial(n - 1)\n\n# Example usage:\n# print(factorial(5))  # Output: 120"
        },
        {
            "instruction": "Summarize the following text.",
            "input": "Artificial intelligence (AI) is intelligence demonstrated by machines, in contrast to the natural intelligence displayed by humans and animals. Leading AI textbooks define the field as the study of intelligent agents: any device that perceives its environment and takes actions that maximize its chance of successfully achieving its goals.",
            "output": "AI refers to machine intelligence that enables devices to perceive their environment and take goal-oriented actions, distinguishing it from natural intelligence found in humans and animals."
        },
        {
            "instruction": "Translate the following English text to French.",
            "input": "Hello, how are you today?",
            "output": "Bonjour, comment allez-vous aujourd'hui ?"
        },
        {
            "instruction": "Generate a creative story beginning with the given sentence.",
            "input": "The old lighthouse stood alone on the rocky cliff.",
            "output": "The old lighthouse stood alone on the rocky cliff, its weathered walls holding secrets of countless storms. Sarah climbed the spiral staircase, each step echoing with memories of the lighthouse keeper who had vanished mysteriously fifty years ago. At the top, she discovered a hidden journal that would change everything she thought she knew about her grandfather's disappearance."
        }
    ]

    return sample_data

# Main training script
if __name__ == "__main__":
    # Initialize fine-tuner
    model_name = "mistralai/Mistral-7B-v0.1"  # or "meta-llama/Llama-2-7b-hf"
    fine_tuner = SupervisedFineTuner(model_name, max_length=2048)

    # Setup model and tokenizer
    fine_tuner.setup_model_and_tokenizer()

    # Prepare training data
    print("Preparing training data...")
    sample_data = prepare_sample_instruction_data()

    # Create dataset
    dataset = fine_tuner.prepare_instruction_dataset(sample_data)
    tokenized_dataset = fine_tuner.tokenize_dataset(dataset)

    # Split into train/eval (80/20)
    train_size = int(0.8 * len(tokenized_dataset))
    eval_size = len(tokenized_dataset) - train_size

    train_dataset = tokenized_dataset.select(range(train_size))
    eval_dataset = tokenized_dataset.select(range(train_size, train_size + eval_size))

    print(f"Training examples: {len(train_dataset)}")
    print(f"Evaluation examples: {len(eval_dataset)}")

    # Start training
    trainer = fine_tuner.train(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./sft_mistral_7b",
        num_epochs=3,
        batch_size=2,  # Adjust based on your GPU memory
        learning_rate=5e-5,
        warmup_ratio=0.03,
        save_steps=100,
        logging_steps=10,
        eval_steps=50,
    )

    # Test the fine-tuned model
    test_prompts = [
        "### Instruction:\nExplain quantum computing in simple terms.\n\n### Response:\n",
        "### Instruction:\nWrite a Python function to find the largest number in a list.\n\n### Response:\n",
        "### Instruction:\nWhat are the benefits of renewable energy?\n\n### Response:\n"
    ]

    fine_tuner.evaluate_model(test_prompts, max_new_tokens=150)

    print("\nSFT training completed successfully!")
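After training, the saved checkpoint can be reloaded for inference without the training class. A minimal sketch, assuming the "./sft_mistral_7b" output directory written by the script above:

# Reload the SFT checkpoint for inference (assumes the output directory from the example above)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint_dir = "./sft_mistral_7b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_dir, torch_dtype=torch.bfloat16, device_map="auto"
)
model.eval()

prompt = "### Instruction:\nExplain quantum computing in simple terms.\n\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
# Decode only the newly generated tokens
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))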
# Comprehensive data preparation utilities for SFT
import json
import re
import random
from typing import List, Dict, Tuple
from collections import Counter
import pandas as pd
from datasets import Dataset, load_dataset
import hashlib

class SFTDataProcessor:
    def __init__(self):
        self.processed_data = []
        self.stats = {}

    def load_alpaca_format(self, file_path: str) -> List[Dict]:
        """Load data in Alpaca format."""

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Validate format
        required_keys = ['instruction', 'output']
        valid_data = []

        for item in data:
            if all(key in item for key in required_keys):
                valid_data.append({
                    'instruction': item['instruction'].strip(),
                    'input': item.get('input', '').strip(),
                    'output': item['output'].strip()
                })

        print(f"Loaded {len(valid_data)}/{len(data)} valid examples")
        return valid_data

    def load_conversational_format(self, file_path: str) -> List[Dict]:
        """Load conversational data (ChatML-like format)."""

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        processed_conversations = []

        for conversation in data:
            if 'messages' in conversation:
                messages = conversation['messages']
                formatted_conversation = []

                for message in messages:
                    if 'role' in message and 'content' in message:
                        formatted_conversation.append({
                            'role': message['role'],
                            'content': message['content'].strip()
                        })

                if len(formatted_conversation) >= 2:  # At least one exchange
                    processed_conversations.append({
                        'conversation': formatted_conversation
                    })

        print(f"Loaded {len(processed_conversations)} conversations")
        return processed_conversations

    def deduplicate_data(self, data: List[Dict], method: str = 'exact') -> List[Dict]:
        """Remove duplicate examples from dataset."""

        if method == 'exact':
            # Exact string matching
            seen = set()
            deduplicated = []

            for item in data:
                # Create hash of instruction + input + output
                content = item['instruction'] + item.get('input', '') + item['output']
                content_hash = hashlib.md5(content.encode()).hexdigest()

                if content_hash not in seen:
                    seen.add(content_hash)
                    deduplicated.append(item)

        elif method == 'fuzzy':
            # Fuzzy matching based on similarity
            from difflib import SequenceMatcher

            deduplicated = []
            threshold = 0.9

            for item in data:
                is_duplicate = False
                content = item['instruction'] + ' ' + item.get('input', '') + ' ' + item['output']

                for existing in deduplicated:
                    existing_content = existing['instruction'] + ' ' + existing.get('input', '') + ' ' + existing['output']
                    similarity = SequenceMatcher(None, content, existing_content).ratio()

                    if similarity > threshold:
                        is_duplicate = True
                        break

                if not is_duplicate:
                    deduplicated.append(item)

        print(f"Deduplication: {len(data)} -> {len(deduplicated)} examples")
        return deduplicated

    def filter_by_quality(self, data: List[Dict]) -> List[Dict]:
        """Filter data based on quality criteria."""

        filtered_data = []

        for item in data:
            instruction = item['instruction']
            output = item['output']

            # Quality checks
            checks = [
                len(instruction.strip()) >= 10,  # Minimum instruction length
                len(output.strip()) >= 5,        # Minimum output length
                len(output.split()) <= 500,      # Maximum output length
                not self._contains_placeholder(instruction, output),
                not self._contains_inappropriate_content(instruction, output),
                self._is_coherent_response(instruction, output)
            ]

            if all(checks):
                filtered_data.append(item)

        print(f"Quality filtering: {len(data)} -> {len(filtered_data)} examples")
        return filtered_data

    def _contains_placeholder(self, instruction: str, output: str) -> bool:
        """Check if text contains placeholder content."""
        placeholders = ['[PLACEHOLDER]', 'TODO', 'FIXME', '...', 'Lorem ipsum']
        text = (instruction + ' ' + output).lower()
        return any(placeholder.lower() in text for placeholder in placeholders)

    def _contains_inappropriate_content(self, instruction: str, output: str) -> bool:
        """Basic check for inappropriate content."""
        # Simple keyword-based filtering (expand as needed)
        inappropriate_keywords = ['hate', 'violence', 'explicit']  # Simplified list
        text = (instruction + ' ' + output).lower()
        return any(keyword in text for keyword in inappropriate_keywords)

    def _is_coherent_response(self, instruction: str, output: str) -> bool:
        """Check if the response is coherent with the instruction."""
        # Simple heuristics (can be improved with more sophisticated methods)

        # Check if output is not just repeating the instruction
        if instruction.lower() in output.lower() and len(output) < len(instruction) * 1.5:
            return False

        # Check for minimum complexity
        if len(output.split()) < 3:
            return False

        return True

    def augment_data(self, data: List[Dict], augmentation_factor: float = 0.2) -> List[Dict]:
        """Augment dataset with variations."""

        augmented_data = data.copy()
        num_to_augment = int(len(data) * augmentation_factor)

        # Simple paraphrasing (in practice, use more sophisticated methods)
        paraphrasing_patterns = [
            (r"Explain (.+)", r"Describe \1"),
            (r"What is (.+)?", r"Can you explain \1?"),
            (r"How do I (.+)?", r"What's the way to \1?"),
            (r"Write (.+)", r"Create \1"),
        ]

        for _ in range(num_to_augment):
            original = random.choice(data)
            augmented = original.copy()

            # Try to paraphrase the instruction
            for pattern, replacement in paraphrasing_patterns:
                if re.search(pattern, augmented['instruction'], re.IGNORECASE):
                    augmented['instruction'] = re.sub(
                        pattern, replacement, augmented['instruction'], flags=re.IGNORECASE
                    )
                    break

            augmented_data.append(augmented)

        print(f"Data augmentation: {len(data)} -> {len(augmented_data)} examples")
        return augmented_data

    def analyze_dataset(self, data: List[Dict]) -> Dict:
        """Analyze dataset characteristics."""

        analysis = {
            'total_examples': len(data),
            'avg_instruction_length': 0,
            'avg_output_length': 0,
            'instruction_length_distribution': [],
            'output_length_distribution': [],
            'common_instruction_patterns': [],
        }

        instruction_lengths = []
        output_lengths = []
        instruction_starts = []

        for item in data:
            inst_len = len(item['instruction'].split())
            out_len = len(item['output'].split())

            instruction_lengths.append(inst_len)
            output_lengths.append(out_len)

            # Extract instruction patterns
            first_words = ' '.join(item['instruction'].split()[:3]).lower()
            instruction_starts.append(first_words)

        analysis['avg_instruction_length'] = sum(instruction_lengths) / len(instruction_lengths)
        analysis['avg_output_length'] = sum(output_lengths) / len(output_lengths)

        # Length distributions
        analysis['instruction_length_distribution'] = {
            'min': min(instruction_lengths),
            'max': max(instruction_lengths),
            'median': sorted(instruction_lengths)[len(instruction_lengths)//2]
        }

        analysis['output_length_distribution'] = {
            'min': min(output_lengths),
            'max': max(output_lengths),
            'median': sorted(output_lengths)[len(output_lengths)//2]
        }

        # Common patterns
        pattern_counts = Counter(instruction_starts)
        analysis['common_instruction_patterns'] = pattern_counts.most_common(10)

        return analysis

    def create_balanced_dataset(self, data: List[Dict], categories: List[str] = None) -> List[Dict]:
        """Create a balanced dataset across different categories."""

        if categories is None:
            # Auto-detect categories based on instruction patterns
            categories = self._auto_detect_categories(data)

        # If no categories are available, return the data unchanged
        if not categories:
            print(f"Balanced dataset: no categories detected, keeping all {len(data)} examples")
            return data

        # Categorize examples
        categorized_data = {cat: [] for cat in categories}
        uncategorized = []

        for item in data:
            instruction = item['instruction'].lower()
            categorized = False

            for category in categories:
                if category.lower() in instruction:
                    categorized_data[category].append(item)
                    categorized = True
                    break

            if not categorized:
                uncategorized.append(item)

        # Balance categories
        min_category_size = min(len(examples) for examples in categorized_data.values() if examples)
        balanced_data = []

        for category, examples in categorized_data.items():
            if examples:
                # Sample from each category
                selected = random.sample(examples, min(len(examples), min_category_size))
                balanced_data.extend(selected)

        # Add some uncategorized examples
        if uncategorized:
            additional_size = len(balanced_data) // 4  # 25% uncategorized
            selected_uncategorized = random.sample(
                uncategorized,
                min(len(uncategorized), additional_size)
            )
            balanced_data.extend(selected_uncategorized)

        print(f"Balanced dataset: {len(data)} -> {len(balanced_data)} examples")
        return balanced_data

    def _auto_detect_categories(self, data: List[Dict]) -> List[str]:
        """Auto-detect common categories in the dataset."""

        # Common instruction types
        patterns = [
            'explain', 'describe', 'write', 'create', 'generate',
            'translate', 'summarize', 'analyze', 'compare', 'define'
        ]

        detected_categories = []

        for pattern in patterns:
            count = sum(1 for item in data if pattern in item['instruction'].lower())
            if count >= 5:  # Minimum threshold
                detected_categories.append(pattern)

        return detected_categories[:10]  # Limit to top 10 categories

    def export_processed_data(self, data: List[Dict], output_path: str, format: str = 'json'):
        """Export processed data in specified format."""

        if format == 'json':
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)

        elif format == 'jsonl':
            with open(output_path, 'w', encoding='utf-8') as f:
                for item in data:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')

        elif format == 'csv':
            df = pd.DataFrame(data)
            df.to_csv(output_path, index=False)

        print(f"Data exported to {output_path} in {format} format")

# Example usage
if __name__ == "__main__":
    processor = SFTDataProcessor()

    # Sample data for demonstration
    sample_data = [
        {
            "instruction": "Explain machine learning",
            "input": "",
            "output": "Machine learning is a subset of AI that enables computers to learn from data."
        },
        {
            "instruction": "Write a Python function to add two numbers",
            "input": "",
            "output": "def add(a, b):\n    return a + b"
        },
        # Add more examples...
    ]

    print("Original dataset analysis:")
    analysis = processor.analyze_dataset(sample_data)
    for key, value in analysis.items():
        if key != 'common_instruction_patterns':
            print(f"{key}: {value}")

    # Process the data
    print("\nProcessing data...")

    # Deduplicate
    deduplicated = processor.deduplicate_data(sample_data)

    # Filter by quality
    filtered = processor.filter_by_quality(deduplicated)

    # Augment data
    augmented = processor.augment_data(filtered, augmentation_factor=0.3)

    # Create balanced dataset
    balanced = processor.create_balanced_dataset(augmented)

    print("\nFinal dataset analysis:")
    final_analysis = processor.analyze_dataset(balanced)
    for key, value in final_analysis.items():
        if key != 'common_instruction_patterns':
            print(f"{key}: {value}")

    # Export processed data
    processor.export_processed_data(balanced, "processed_sft_data.json")

    print("\nData processing completed!")
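The cleaned file written by export_processed_data can be fed straight into the SupervisedFineTuner from the first example. A rough sketch of the glue code, assuming the trainer class lives in a module named sft_trainer.py (a hypothetical file name; adjust the import to your layout):

# Hypothetical glue script: train on the cleaned data produced by SFTDataProcessor
import json
from sft_trainer import SupervisedFineTuner  # module name assumed for illustration

with open("processed_sft_data.json", "r", encoding="utf-8") as f:
    cleaned_data = json.load(f)

fine_tuner = SupervisedFineTuner("mistralai/Mistral-7B-v0.1", max_length=2048)
fine_tuner.setup_model_and_tokenizer()

dataset = fine_tuner.prepare_instruction_dataset(cleaned_data)
tokenized = fine_tuner.tokenize_dataset(dataset)
fine_tuner.train(train_dataset=tokenized, output_dir="./sft_from_cleaned_data")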
# Advanced SFT techniques implementation
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
)
from torch.optim import AdamW
import numpy as np
from typing import Dict, List, Optional
import wandb
from torch.cuda.amp import GradScaler, autocast

class AdvancedSFTTrainer:
    def __init__(self, model_name: str, use_deepspeed: bool = False):
        self.model_name = model_name
        self.use_deepspeed = use_deepspeed
        self.model = None
        self.tokenizer = None
        self.scaler = GradScaler() if torch.cuda.is_available() else None

    def setup_model_with_optimizations(self,
                                       gradient_checkpointing: bool = True,
                                       mixed_precision: bool = True):
        """Setup model with memory and training optimizations."""

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model with optimizations
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16 if mixed_precision else torch.float32,
            device_map="auto" if not self.use_deepspeed else None,
            use_cache=False,  # Disable for training
        )

        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        print("Model loaded with optimizations:")
        print(f"- Gradient checkpointing: {gradient_checkpointing}")
        print(f"- Mixed precision: {mixed_precision}")
        print(f"- DeepSpeed: {self.use_deepspeed}")

    def create_custom_optimizer(self,
                                learning_rate: float = 5e-5,
                                weight_decay: float = 0.01,
                                use_layer_wise_lr: bool = False) -> AdamW:
        """Create optimized AdamW optimizer with optional layer-wise learning rates."""

        if use_layer_wise_lr:
            # Different learning rates for different layers
            parameter_groups = []

            # Embedding layers - lower LR
            embedding_params = []
            for name, param in self.model.named_parameters():
                if 'embed' in name or 'wte' in name or 'wpe' in name:
                    embedding_params.append(param)

            if embedding_params:
                parameter_groups.append({
                    'params': embedding_params,
                    'lr': learning_rate * 0.1,  # 10x lower
                    'weight_decay': weight_decay
                })

            # Output layers - higher LR
            output_params = []
            for name, param in self.model.named_parameters():
                if 'lm_head' in name or 'output' in name:
                    output_params.append(param)

            if output_params:
                parameter_groups.append({
                    'params': output_params,
                    'lr': learning_rate * 2.0,  # 2x higher
                    'weight_decay': weight_decay
                })

            # All other parameters - standard LR
            other_params = []
            embedding_ids = {id(p) for p in embedding_params}
            output_ids = {id(p) for p in output_params}

            for param in self.model.parameters():
                if id(param) not in embedding_ids and id(param) not in output_ids:
                    other_params.append(param)

            if other_params:
                parameter_groups.append({
                    'params': other_params,
                    'lr': learning_rate,
                    'weight_decay': weight_decay
                })

            optimizer = AdamW(parameter_groups, betas=(0.9, 0.999), eps=1e-8)

        else:
            # Standard optimizer
            optimizer = AdamW(
                self.model.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
                betas=(0.9, 0.999),
                eps=1e-8
            )

        return optimizer

    def create_curriculum_dataset(self, dataset, difficulty_metric: str = 'length'):
        """Create curriculum learning dataset ordered by difficulty."""

        def calculate_difficulty(example):
            if difficulty_metric == 'length':
                return len(example['input_ids'])
            elif difficulty_metric == 'vocab_complexity':
                # Simple vocabulary complexity metric
                unique_tokens = len(set(example['input_ids']))
                total_tokens = len(example['input_ids'])
                return unique_tokens / total_tokens
            else:
                return 0.5  # Default neutral difficulty

        # Calculate difficulty scores
        difficulties = [calculate_difficulty(example) for example in dataset]

        # Sort by difficulty (easy to hard)
        sorted_indices = sorted(range(len(dataset)), key=lambda i: difficulties[i])

        # Create curriculum dataset
        curriculum_dataset = dataset.select(sorted_indices)

        print(f"Created curriculum dataset with {len(curriculum_dataset)} examples")
        return curriculum_dataset

    def train_with_advanced_techniques(self,
                                       train_dataset,
                                       eval_dataset=None,
                                       output_dir="./advanced_sft",
                                       num_epochs=3,
                                       batch_size=4,
                                       learning_rate=5e-5,
                                       use_curriculum=True,
                                       use_label_smoothing=True,
                                       label_smoothing_factor=0.1,
                                       use_cosine_schedule=True,
                                       warmup_ratio=0.03):
        """Train with advanced techniques."""

        # Setup curriculum learning
        if use_curriculum:
            train_dataset = self.create_curriculum_dataset(train_dataset)

        # Calculate training steps
        total_steps = (len(train_dataset) // batch_size) * num_epochs
        warmup_steps = int(total_steps * warmup_ratio)

        # Create custom optimizer
        optimizer = self.create_custom_optimizer(
            learning_rate=learning_rate,
            use_layer_wise_lr=True
        )

        # Create learning rate scheduler (cosine with warmup if requested, otherwise linear)
        if use_cosine_schedule:
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps
            )
        else:
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps
            )

        # Custom loss function with label smoothing
        class LabelSmoothingLoss(nn.Module):
            def __init__(self, smoothing=0.1, vocab_size=None):
                super().__init__()
                self.smoothing = smoothing
                self.vocab_size = vocab_size  # must be provided by the caller

            def forward(self, pred, target):
                # Reshape predictions and targets
                pred = pred.view(-1, pred.size(-1))
                target = target.view(-1)

                # Mask padding tokens (label -100) and use a safe index for scatter
                mask = (target != -100).float()
                safe_target = target.clamp(min=0)

                # Create smoothed targets
                confidence = 1.0 - self.smoothing
                smooth_value = self.smoothing / (self.vocab_size - 1)

                # One-hot encode targets with smoothing
                one_hot = torch.zeros_like(pred).scatter(1, safe_target.unsqueeze(1), confidence)
                one_hot += smooth_value

                # Compute cross entropy with smoothed labels
                log_probs = torch.log_softmax(pred, dim=1)
                loss = -torch.sum(one_hot * log_probs, dim=1)
                loss = loss * mask

                return loss.sum() / mask.sum()

        # Custom trainer with advanced features
        class AdvancedTrainer(Trainer):
            def __init__(self, *args, label_smoothing_loss=None, **kwargs):
                super().__init__(*args, **kwargs)
                self.label_smoothing_loss = label_smoothing_loss
                # Counter renamed so it does not shadow the training_step() method below
                self.custom_step_count = 0

            def compute_loss(self, model, inputs, return_outputs=False):
                labels = inputs.get("labels")
                outputs = model(**inputs)
                logits = outputs.get("logits")

                if self.label_smoothing_loss is not None and labels is not None:
                    # Shift so that each position predicts the next token (causal LM objective)
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    loss = self.label_smoothing_loss(shift_logits, shift_labels)
                else:
                    loss = outputs.loss

                return (loss, outputs) if return_outputs else loss

            # Note: the two overrides below assume older Trainer internals
            # (self.use_amp / self.scaler); with bf16 the stock Trainer loop
            # already handles loss scaling and gradient clipping on its own.
            def training_step(self, model, inputs):
                """Custom training step with mixed precision."""
                model.train()
                inputs = self._prepare_inputs(inputs)

                if self.use_amp:
                    with autocast():
                        loss = self.compute_loss(model, inputs)
                else:
                    loss = self.compute_loss(model, inputs)

                if self.args.n_gpu > 1:
                    loss = loss.mean()

                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                if self.use_amp:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()

                self.custom_step_count += 1

                return loss.detach()

            def optimizer_step(self, optimizer):
                """Custom optimizer step with gradient clipping."""
                if self.use_amp:
                    self.scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.scaler.step(optimizer)
                    self.scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    optimizer.step()

                optimizer.zero_grad()

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            max_grad_norm=1.0,
            warmup_steps=warmup_steps,
            logging_steps=10,
            save_steps=500,
            eval_steps=500 if eval_dataset else None,
            evaluation_strategy="steps" if eval_dataset else "no",
            save_strategy="steps",
            load_best_model_at_end=True if eval_dataset else False,
            metric_for_best_model="eval_loss" if eval_dataset else None,
            greater_is_better=False,
            report_to="wandb",
            run_name="advanced_sft",
            bf16=True,
            dataloader_pin_memory=False,
            gradient_checkpointing=True,
            remove_unused_columns=False,
        )

        # Create loss function
        loss_fn = LabelSmoothingLoss(
            smoothing=label_smoothing_factor,
            vocab_size=len(self.tokenizer)
        ) if use_label_smoothing else None

        # Create trainer
        trainer = AdvancedTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            optimizers=(optimizer, scheduler),
            label_smoothing_loss=loss_fn,
        )

        # Initialize wandb
        wandb.init(
            project="advanced-sft",
            config={
                "model_name": self.model_name,
                "num_epochs": num_epochs,
                "batch_size": batch_size,
                "learning_rate": learning_rate,
                "use_curriculum": use_curriculum,
                "use_label_smoothing": use_label_smoothing,
                "label_smoothing_factor": label_smoothing_factor,
            }
        )

        # Training
        print("Starting advanced SFT training...")
        trainer.train()

        # Save model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        wandb.finish()
        print(f"Advanced SFT completed! Model saved to {output_dir}")

        return trainer

    def evaluate_with_multiple_metrics(self, test_dataset, metrics=['perplexity', 'bleu']):
        """Evaluate model with multiple metrics."""

        import sacrebleu

        self.model.eval()
        results = {}

        if 'perplexity' in metrics:
            # Calculate perplexity
            total_loss = 0
            total_tokens = 0

            for example in test_dataset:
                inputs = {k: torch.tensor(v).unsqueeze(0).to(self.model.device)
                          for k, v in example.items() if k in ['input_ids', 'attention_mask']}

                with torch.no_grad():
                    outputs = self.model(**inputs, labels=inputs['input_ids'])
                    loss = outputs.loss.item()
                    num_tokens = inputs['input_ids'].numel()

                total_loss += loss * num_tokens
                total_tokens += num_tokens

            perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
            results['perplexity'] = perplexity.item()

        if 'bleu' in metrics:
            # Calculate BLEU score (simplified example)
            references = []
            predictions = []

            # Sample a subset for efficiency
            for example in test_dataset.select(range(min(100, len(test_dataset)))):
                # This is a simplified example - adapt based on your data format
                input_ids = example['input_ids'][:50]   # First 50 tokens as input
                target_ids = example['input_ids'][50:]  # Rest as target

                inputs = torch.tensor(input_ids).unsqueeze(0).to(self.model.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs,
                        max_new_tokens=len(target_ids),
                        do_sample=False,
                        pad_token_id=self.tokenizer.eos_token_id
                    )

                pred_text = self.tokenizer.decode(outputs[0][len(input_ids):], skip_special_tokens=True)
                ref_text = self.tokenizer.decode(target_ids, skip_special_tokens=True)

                predictions.append(pred_text)
                references.append(ref_text)

            if predictions and references:
                # sacrebleu expects a list of reference streams (one stream per reference set)
                bleu_score = sacrebleu.corpus_bleu(predictions, [references])
                results['bleu'] = bleu_score.score

        return results

# Example usage
if __name__ == "__main__":
    # Initialize advanced trainer
    trainer = AdvancedSFTTrainer("mistralai/Mistral-7B-v0.1")

    # Setup model with optimizations
    trainer.setup_model_with_optimizations(
        gradient_checkpointing=True,
        mixed_precision=True
    )

    # Prepare sample dataset (replace with your actual data)
    from datasets import Dataset

    # Labels mirror input_ids so the causal LM loss can be computed
    sample_data = [
        {"input_ids": [1, 2, 3, 4, 5] * 100, "attention_mask": [1] * 500, "labels": [1, 2, 3, 4, 5] * 100},
        {"input_ids": [6, 7, 8, 9, 10] * 80, "attention_mask": [1] * 400, "labels": [6, 7, 8, 9, 10] * 80},
        # Add more examples...
    ]

    train_dataset = Dataset.from_list(sample_data)
    eval_dataset = Dataset.from_list(sample_data[:2])  # Small eval set

    # Train with advanced techniques
    trainer.train_with_advanced_techniques(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./advanced_sft_model",
        num_epochs=2,
        batch_size=2,
        learning_rate=5e-5,
        use_curriculum=True,
        use_label_smoothing=True,
        label_smoothing_factor=0.1,
        use_cosine_schedule=True,
    )

    # Evaluate with multiple metrics
    results = trainer.evaluate_with_multiple_metrics(eval_dataset)
    print("Evaluation results:", results)