
Supervised Fine-Tuning (SFT)

Master Supervised Fine-Tuning (SFT) techniques to adapt open-source Large Language Models for specific tasks, domains, and use cases. Learn how to prepare data, optimize training, and achieve superior performance on targeted applications.

🎯

Task Specialization

Adapt general-purpose models to excel at specific tasks and domains with targeted training data.

📊

Data Efficiency

Achieve excellent results with thousands of examples rather than billions, making customization accessible.

🎛️

Behavioral Control

Shape model outputs to follow specific formats, styles, and behavioral patterns for your use case.

🏥

Domain Expertise

Incorporate specialized knowledge and terminology from medical, legal, financial, and other domains.

Understanding SFT

Supervised Fine-Tuning (SFT) is the process of adapting a pre-trained language model to perform specific tasks by training it on labeled task-specific data. Unlike unsupervised pre-training, SFT uses input-output pairs to teach the model desired behaviors.

Core Concepts:

What is SFT?
SFT takes a pre-trained language model and continues training it on a curated dataset of input-output examples. This process adapts the model's general language understanding to specific tasks, domains, or behaviors.

SFT vs. Pre-training:
●Pre-training: Learns general language patterns from massive unlabeled text

●SFT: Learns specific behaviors from labeled examples

●Data Volume: Pre-training uses trillions of tokens, SFT uses thousands to millions

●Objective: Pre-training optimizes next-token prediction, SFT optimizes task-specific performance (see the loss-masking sketch below)
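
To make the objective difference concrete, here is a minimal sketch (an illustrative assumption, not taken from the training script later on this page, which trains on the full formatted text) of masking prompt tokens so the loss is computed only on the response:

# Sketch: build labels so the SFT loss covers only response tokens.
# -100 is the label value ignored by PyTorch cross-entropy and by
# Hugging Face causal-LM losses.
from transformers import AutoTokenizer

def build_masked_labels(tokenizer, prompt: str, response: str, max_length: int = 2048):
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    response_ids = tokenizer(response + tokenizer.eos_token, add_special_tokens=False)["input_ids"]

    input_ids = (prompt_ids + response_ids)[:max_length]
    labels = ([-100] * len(prompt_ids) + response_ids)[:max_length]  # loss only on the response
    return {"input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids)}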


Types of SFT:

1. Instruction Tuning:
●Teaching models to follow instructions

●Format: Instruction → Response

●Examples: "Summarize this text" → Summary

●Results in general-purpose instruction-following models


2. Task-Specific Fine-tuning:
●Adapting models for specific tasks

●Examples: Question answering, sentiment analysis, code generation

●Highly optimized for single use cases


3. Domain Adaptation:
●Specializing models for specific domains

●Examples: Medical, legal, financial, scientific domains

●Incorporates domain-specific knowledge and terminology


4. Behavioral Alignment:
●Training models to exhibit desired behaviors

●Examples: Being helpful, harmless, and honest

●Often combined with reinforcement learning techniques


Key Benefits:
●Task Performance: Dramatically improves performance on target tasks

●Efficiency: Requires less data than training from scratch

●Customization: Allows tailoring to specific requirements

●Control: Better control over model outputs and behavior

●Domain Expertise: Incorporates specialized knowledge


When to Use SFT:
●Adapting general models to specific domains

●Improving performance on targeted tasks

●Teaching new formats or behaviors

●Incorporating proprietary or domain-specific data

●Creating specialized AI assistants

The end-to-end script below puts these pieces together: it loads a base model, formats and tokenizes instruction data, trains with the Hugging Face Trainer, and spot-checks the result with a few generations.

# Complete SFT implementation with Hugging Face Transformers
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    TrainerCallback,
)
from datasets import Dataset
from typing import Dict, List

class SupervisedFineTuner:
    def __init__(self, model_name: str, max_length: int = 2048):
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = None
        self.model = None

    def setup_model_and_tokenizer(self):
        """Initialize model and tokenizer for SFT."""

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            use_fast=True
        )

        # Set special tokens
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            use_cache=False  # Disable cache for training
        )

        # Enable gradient checkpointing for memory efficiency
        self.model.gradient_checkpointing_enable()

        print(f"Model loaded: {self.model_name}")
        print(f"Vocabulary size: {len(self.tokenizer)}")
        print(f"Model parameters: {self.model.num_parameters():,}")

    def prepare_instruction_dataset(self, data: List[Dict]) -> Dataset:
        """Prepare instruction-following dataset for SFT."""

        def format_instruction_example(example):
            """Format a single instruction example."""

            instruction = example.get('instruction', '')
            input_text = example.get('input', '')
            output = example.get('output', '')

            # Create formatted prompt
            if input_text:
                prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
            else:
                prompt = f"### Instruction:\n{instruction}\n\n### Response:\n"

            # Full text for training
            full_text = prompt + output + self.tokenizer.eos_token

            return {
                'text': full_text,
                'prompt': prompt,
                'response': output
            }

        # Format all examples
        formatted_data = [format_instruction_example(item) for item in data]

        return Dataset.from_list(formatted_data)

    def prepare_conversation_dataset(self, data: List[Dict]) -> Dataset:
        """Prepare conversational dataset for SFT."""

        def format_conversation(example):
            """Format a conversation example."""

            conversation = example.get('conversation', [])
            formatted_text = ""

            for turn in conversation:
                role = turn.get('role', 'user')
                content = turn.get('content', '')

                if role == 'user':
                    formatted_text += f"Human: {content}\n\n"
                elif role == 'assistant':
                    formatted_text += f"Assistant: {content}\n\n"

            # Add EOS token
            formatted_text += self.tokenizer.eos_token

            return {'text': formatted_text}

        # Format all conversations
        formatted_data = [format_conversation(item) for item in data]

        return Dataset.from_list(formatted_data)

    def tokenize_dataset(self, dataset: Dataset) -> Dataset:
        """Tokenize dataset for training."""

        def tokenize_function(examples):
            # Tokenize texts; labels are created later by the data collator
            # (mlm=False copies input_ids and masks padding with -100)
            return self.tokenizer(
                examples["text"],
                truncation=True,
                padding=False,
                max_length=self.max_length,
                return_overflowing_tokens=False,
            )

        # Apply tokenization
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset"
        )

        # Filter out examples that are too long
        original_size = len(tokenized_dataset)
        tokenized_dataset = tokenized_dataset.filter(
            lambda x: len(x["input_ids"]) <= self.max_length
        )
        final_size = len(tokenized_dataset)

        print(f"Dataset size: {original_size} -> {final_size} examples")

        return tokenized_dataset

    def create_data_collator(self):
        """Create data collator for training."""

        return DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # Not masked language modeling
            pad_to_multiple_of=8,  # For efficiency on modern GPUs
        )

    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset = None,
        output_dir: str = "./sft_results",
        num_epochs: int = 3,
        batch_size: int = 4,
        learning_rate: float = 5e-5,
        warmup_ratio: float = 0.03,
        save_steps: int = 500,
        logging_steps: int = 10,
        eval_steps: int = 500,
    ):
        """Train the model with SFT."""

        # Calculate total training steps
        total_steps = (len(train_dataset) // batch_size) * num_epochs
        warmup_steps = int(total_steps * warmup_ratio)

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            adam_beta1=0.9,
            adam_beta2=0.999,
            adam_epsilon=1e-8,
            max_grad_norm=1.0,
            warmup_steps=warmup_steps,
            lr_scheduler_type="linear",
            logging_steps=logging_steps,
            save_steps=save_steps,
            eval_steps=eval_steps if eval_dataset else None,
            evaluation_strategy="steps" if eval_dataset else "no",
            save_strategy="steps",
            load_best_model_at_end=True if eval_dataset else False,
            metric_for_best_model="eval_loss" if eval_dataset else None,
            greater_is_better=False,
            report_to="none",  # Disable wandb/tensorboard
            dataloader_pin_memory=False,
            gradient_checkpointing=True,
            bf16=True,  # Use bfloat16 for stability
            remove_unused_columns=False,
            push_to_hub=False,
        )

        # Create trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=self.create_data_collator(),
            tokenizer=self.tokenizer,
        )

        # Add a custom callback for monitoring GPU memory
        class MemoryMonitorCallback(TrainerCallback):
            def on_step_end(self, args, state, control, **kwargs):
                if state.global_step % 100 == 0 and torch.cuda.is_available():
                    memory_allocated = torch.cuda.memory_allocated() / 1024**3
                    memory_reserved = torch.cuda.memory_reserved() / 1024**3
                    print(f"Step {state.global_step}: "
                          f"Memory Allocated: {memory_allocated:.2f}GB, "
                          f"Reserved: {memory_reserved:.2f}GB")

        trainer.add_callback(MemoryMonitorCallback())

        # Start training
        print("Starting SFT training...")
        print(f"Training examples: {len(train_dataset)}")
        if eval_dataset:
            print(f"Evaluation examples: {len(eval_dataset)}")
        print(f"Total training steps: {total_steps}")
        print(f"Warmup steps: {warmup_steps}")

        trainer.train()

        # Save final model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        print(f"Training completed! Model saved to {output_dir}")

        return trainer

    def evaluate_model(self, test_prompts: List[str], max_new_tokens: int = 200):
        """Evaluate the fine-tuned model on test prompts."""

        print("\nEvaluating fine-tuned model:")
        print("=" * 60)

        self.model.eval()

        for i, prompt in enumerate(test_prompts, 1):
            print(f"\nTest {i}:")
            print(f"Prompt: {prompt}")

            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode response
            response = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )

            print(f"Response: {response}")
            print("-" * 40)

# Example usage and data preparation
def prepare_sample_instruction_data():
    """Prepare sample instruction-following data."""

    sample_data = [
        {
            "instruction": "Explain the concept of machine learning in simple terms.",
            "input": "",
            "output": "Machine learning is a type of artificial intelligence where computers learn to make predictions or decisions by analyzing patterns in data, rather than being explicitly programmed for each task. It's like teaching a computer to recognize patterns the same way humans learn from experience."
        },
        {
            "instruction": "Write a Python function to calculate the factorial of a number.",
            "input": "",
            "output": "def factorial(n):\n    if n == 0 or n == 1:\n        return 1\n    else:\n        return n * factorial(n - 1)\n\n# Example usage:\n# print(factorial(5))  # Output: 120"
        },
        {
            "instruction": "Summarize the following text.",
            "input": "Artificial intelligence (AI) is intelligence demonstrated by machines, in contrast to the natural intelligence displayed by humans and animals. Leading AI textbooks define the field as the study of intelligent agents: any device that perceives its environment and takes actions that maximize its chance of successfully achieving its goals.",
            "output": "AI refers to machine intelligence that enables devices to perceive their environment and take goal-oriented actions, distinguishing it from natural intelligence found in humans and animals."
        },
        {
            "instruction": "Translate the following English text to French.",
            "input": "Hello, how are you today?",
            "output": "Bonjour, comment allez-vous aujourd'hui ?"
        },
        {
            "instruction": "Generate a creative story beginning with the given sentence.",
            "input": "The old lighthouse stood alone on the rocky cliff.",
            "output": "The old lighthouse stood alone on the rocky cliff, its weathered walls holding secrets of countless storms. Sarah climbed the spiral staircase, each step echoing with memories of the lighthouse keeper who had vanished mysteriously fifty years ago. At the top, she discovered a hidden journal that would change everything she thought she knew about her grandfather's disappearance."
        }
    ]

    return sample_data

# Main training script
if __name__ == "__main__":
    # Initialize fine-tuner
    model_name = "mistralai/Mistral-7B-v0.1"  # or "meta-llama/Llama-2-7b-hf"
    fine_tuner = SupervisedFineTuner(model_name, max_length=2048)

    # Setup model and tokenizer
    fine_tuner.setup_model_and_tokenizer()

    # Prepare training data
    print("Preparing training data...")
    sample_data = prepare_sample_instruction_data()

    # Create dataset
    dataset = fine_tuner.prepare_instruction_dataset(sample_data)
    tokenized_dataset = fine_tuner.tokenize_dataset(dataset)

    # Split into train/eval (80/20)
    train_size = int(0.8 * len(tokenized_dataset))
    eval_size = len(tokenized_dataset) - train_size

    train_dataset = tokenized_dataset.select(range(train_size))
    eval_dataset = tokenized_dataset.select(range(train_size, train_size + eval_size))

    print(f"Training examples: {len(train_dataset)}")
    print(f"Evaluation examples: {len(eval_dataset)}")

    # Start training
    trainer = fine_tuner.train(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./sft_mistral_7b",
        num_epochs=3,
        batch_size=2,  # Adjust based on your GPU memory
        learning_rate=5e-5,
        warmup_ratio=0.03,
        save_steps=100,
        logging_steps=10,
        eval_steps=50,
    )

    # Test the fine-tuned model
    test_prompts = [
        "### Instruction:\nExplain quantum computing in simple terms.\n\n### Response:\n",
        "### Instruction:\nWrite a Python function to find the largest number in a list.\n\n### Response:\n",
        "### Instruction:\nWhat are the benefits of renewable energy?\n\n### Response:\n"
    ]

    fine_tuner.evaluate_model(test_prompts, max_new_tokens=150)

    print("\nSFT training completed successfully!")

Data Preparation for SFT

The quality and format of your training data are crucial for successful SFT. Here's how to prepare different types of datasets:

1. Instruction-Following Datasets:

Alpaca Format (Recommended):
●Structure: Instruction + Optional Input + Response

●Use Case: General instruction following

●Benefits: Standardized format, widely supported


ChatML Format:
●Structure: Role-based conversation format

●Use Case: Multi-turn conversations

●Benefits: Natural conversation flow (example records for both formats appear below)
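
For reference, here is what a single record might look like in each format (field names follow the common Alpaca and OpenAI-style chat conventions; adapt them to your own pipeline):

# One Alpaca-style record: instruction + optional input + response
alpaca_example = {
    "instruction": "Summarize the following text.",
    "input": "Large language models are trained on vast text corpora...",
    "output": "LLMs learn language patterns from large text collections."
}

# One ChatML-style record: a list of role-tagged messages
chatml_style_example = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize: Large language models are trained on vast text corpora..."},
        {"role": "assistant", "content": "LLMs learn language patterns from large text collections."}
    ]
}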


2. Domain-Specific Datasets:

Medical Domain:
●Data Sources: Medical literature, Q&A pairs, case studies

●Format: Question-answer or case-diagnosis pairs

●Considerations: Accuracy, regulatory compliance, bias reduction


Legal Domain:
●Data Sources: Legal documents, case law, regulations

●Format: Legal query-response or document analysis

●Considerations: Jurisdiction accuracy, ethical guidelines


Code Generation:
●Data Sources: GitHub repositories, coding competitions, documentation

●Format: Natural language description to code

●Considerations: Code quality, security, best practices


3. Data Quality Guidelines:

High-Quality Characteristics:
●Accuracy: Factually correct information

●Relevance: Aligned with target use case

●Diversity: Covers various scenarios and edge cases

●Consistency: Uniform formatting and style

●Completeness: Complete responses without truncation


Data Preprocessing Steps:
1. Deduplication: Remove duplicate or near-duplicate examples
2. Quality Filtering: Remove low-quality, inappropriate, or biased content
3. Length Filtering: Remove examples that are too short or too long
4. Format Validation: Ensure consistent formatting
5. Language Detection: Filter for target language(s)
6. Content Filtering: Remove harmful, toxic, or inappropriate content

4. Data Augmentation Techniques:

Paraphrasing:
●Rewrite instructions and responses in different ways

●Increases dataset diversity

●Helps model generalize better


Back-Translation:
●Translate to another language and back

●Creates natural variations

●Useful for multilingual applications


Synthetic Data Generation:
●Use existing LLMs to generate training examples

●Helpful for expanding small datasets

●Requires careful quality control (see the sampling sketch below)
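
As a rough sketch of the synthetic-data idea (the model name and prompt here are assumptions; any capable instruction-tuned model works), one can sample candidate examples from an existing model and then pass them through the same quality filters as human-written data:

# Sketch: sample candidate instructions from an existing instruction-tuned model
from transformers import pipeline

generator = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2")
seed = "Write one new instruction similar to: 'Summarize the following text.'"
candidate = generator(seed, max_new_tokens=64, do_sample=True, temperature=0.9)
print(candidate[0]["generated_text"])  # review and filter before adding to the dataset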


5. Dataset Size Recommendations:

Task Complexity vs. Dataset Size:
●Simple Tasks: 500-2,000 examples

●Moderate Tasks: 2,000-10,000 examples

●Complex Tasks: 10,000-50,000+ examples

●Domain Adaptation: 1,000-5,000 high-quality examples


Quality vs. Quantity:
●Prefer 1,000 high-quality examples over 10,000 low-quality ones

●Focus on covering diverse scenarios

●Ensure balanced representation of different use cases

The utilities below implement the loading, deduplication, quality filtering, augmentation, analysis, and export steps described above.

# Comprehensive data preparation utilities for SFT
import json
import re
import random
from typing import List, Dict
from collections import Counter
import pandas as pd
import hashlib

class SFTDataProcessor:
    def __init__(self):
        self.processed_data = []
        self.stats = {}

    def load_alpaca_format(self, file_path: str) -> List[Dict]:
        """Load data in Alpaca format."""

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Validate format
        required_keys = ['instruction', 'output']
        valid_data = []

        for item in data:
            if all(key in item for key in required_keys):
                valid_data.append({
                    'instruction': item['instruction'].strip(),
                    'input': item.get('input', '').strip(),
                    'output': item['output'].strip()
                })

        print(f"Loaded {len(valid_data)}/{len(data)} valid examples")
        return valid_data

    def load_conversational_format(self, file_path: str) -> List[Dict]:
        """Load conversational data (ChatML-like format)."""

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        processed_conversations = []

        for conversation in data:
            if 'messages' in conversation:
                messages = conversation['messages']
                formatted_conversation = []

                for message in messages:
                    if 'role' in message and 'content' in message:
                        formatted_conversation.append({
                            'role': message['role'],
                            'content': message['content'].strip()
                        })

                if len(formatted_conversation) >= 2:  # At least one exchange
                    processed_conversations.append({
                        'conversation': formatted_conversation
                    })

        print(f"Loaded {len(processed_conversations)} conversations")
        return processed_conversations

    def deduplicate_data(self, data: List[Dict], method: str = 'exact') -> List[Dict]:
        """Remove duplicate examples from dataset."""

        if method == 'exact':
            # Exact string matching
            seen = set()
            deduplicated = []

            for item in data:
                # Create hash of instruction + input + output
                content = item['instruction'] + item.get('input', '') + item['output']
                content_hash = hashlib.md5(content.encode()).hexdigest()

                if content_hash not in seen:
                    seen.add(content_hash)
                    deduplicated.append(item)

        elif method == 'fuzzy':
            # Fuzzy matching based on similarity
            from difflib import SequenceMatcher

            deduplicated = []
            threshold = 0.9

            for item in data:
                is_duplicate = False
                content = item['instruction'] + ' ' + item.get('input', '') + ' ' + item['output']

                for existing in deduplicated:
                    existing_content = existing['instruction'] + ' ' + existing.get('input', '') + ' ' + existing['output']
                    similarity = SequenceMatcher(None, content, existing_content).ratio()

                    if similarity > threshold:
                        is_duplicate = True
                        break

                if not is_duplicate:
                    deduplicated.append(item)

        print(f"Deduplication: {len(data)} -> {len(deduplicated)} examples")
        return deduplicated

    def filter_by_quality(self, data: List[Dict]) -> List[Dict]:
        """Filter data based on quality criteria."""

        filtered_data = []

        for item in data:
            instruction = item['instruction']
            output = item['output']

            # Quality checks
            checks = [
                len(instruction.strip()) >= 10,  # Minimum instruction length
                len(output.strip()) >= 5,        # Minimum output length
                len(output.split()) <= 500,      # Maximum output length
                not self._contains_placeholder(instruction, output),
                not self._contains_inappropriate_content(instruction, output),
                self._is_coherent_response(instruction, output)
            ]

            if all(checks):
                filtered_data.append(item)

        print(f"Quality filtering: {len(data)} -> {len(filtered_data)} examples")
        return filtered_data

    def _contains_placeholder(self, instruction: str, output: str) -> bool:
        """Check if text contains placeholder content."""
        placeholders = ['[PLACEHOLDER]', 'TODO', 'FIXME', '...', 'Lorem ipsum']
        text = (instruction + ' ' + output).lower()
        return any(placeholder.lower() in text for placeholder in placeholders)

    def _contains_inappropriate_content(self, instruction: str, output: str) -> bool:
        """Basic check for inappropriate content."""
        # Simple keyword-based filtering (expand as needed)
        inappropriate_keywords = ['hate', 'violence', 'explicit']  # Simplified list
        text = (instruction + ' ' + output).lower()
        return any(keyword in text for keyword in inappropriate_keywords)

    def _is_coherent_response(self, instruction: str, output: str) -> bool:
        """Check if the response is coherent with the instruction."""
        # Simple heuristics (can be improved with more sophisticated methods)

        # Check if output is not just repeating the instruction
        if instruction.lower() in output.lower() and len(output) < len(instruction) * 1.5:
            return False

        # Check for minimum complexity
        if len(output.split()) < 3:
            return False

        return True

    def augment_data(self, data: List[Dict], augmentation_factor: float = 0.2) -> List[Dict]:
        """Augment dataset with variations."""

        augmented_data = data.copy()
        num_to_augment = int(len(data) * augmentation_factor)

        # Simple paraphrasing (in practice, use more sophisticated methods)
        paraphrasing_patterns = [
            (r"Explain (.+)", r"Describe \1"),
            (r"What is (.+)\?", r"Can you explain \1?"),
            (r"How do I (.+)\?", r"What's the way to \1?"),
            (r"Write (.+)", r"Create \1"),
        ]

        for _ in range(num_to_augment):
            original = random.choice(data)
            augmented = original.copy()

            # Try to paraphrase the instruction
            for pattern, replacement in paraphrasing_patterns:
                if re.search(pattern, augmented['instruction'], re.IGNORECASE):
                    augmented['instruction'] = re.sub(
                        pattern, replacement, augmented['instruction'], flags=re.IGNORECASE
                    )
                    break

            augmented_data.append(augmented)

        print(f"Data augmentation: {len(data)} -> {len(augmented_data)} examples")
        return augmented_data

    def analyze_dataset(self, data: List[Dict]) -> Dict:
        """Analyze dataset characteristics."""

        analysis = {
            'total_examples': len(data),
            'avg_instruction_length': 0,
            'avg_output_length': 0,
            'instruction_length_distribution': [],
            'output_length_distribution': [],
            'common_instruction_patterns': [],
        }

        instruction_lengths = []
        output_lengths = []
        instruction_starts = []

        for item in data:
            inst_len = len(item['instruction'].split())
            out_len = len(item['output'].split())

            instruction_lengths.append(inst_len)
            output_lengths.append(out_len)

            # Extract instruction patterns
            first_words = ' '.join(item['instruction'].split()[:3]).lower()
            instruction_starts.append(first_words)

        analysis['avg_instruction_length'] = sum(instruction_lengths) / len(instruction_lengths)
        analysis['avg_output_length'] = sum(output_lengths) / len(output_lengths)

        # Length distributions
        analysis['instruction_length_distribution'] = {
            'min': min(instruction_lengths),
            'max': max(instruction_lengths),
            'median': sorted(instruction_lengths)[len(instruction_lengths) // 2]
        }

        analysis['output_length_distribution'] = {
            'min': min(output_lengths),
            'max': max(output_lengths),
            'median': sorted(output_lengths)[len(output_lengths) // 2]
        }

        # Common patterns
        pattern_counts = Counter(instruction_starts)
        analysis['common_instruction_patterns'] = pattern_counts.most_common(10)

        return analysis

    def create_balanced_dataset(self, data: List[Dict], categories: List[str] = None) -> List[Dict]:
        """Create a balanced dataset across different categories."""

        if categories is None:
            # Auto-detect categories based on instruction patterns
            categories = self._auto_detect_categories(data)

        # Categorize examples
        categorized_data = {cat: [] for cat in categories}
        uncategorized = []

        for item in data:
            instruction = item['instruction'].lower()
            categorized = False

            for category in categories:
                if category.lower() in instruction:
                    categorized_data[category].append(item)
                    categorized = True
                    break

            if not categorized:
                uncategorized.append(item)

        # If no category received any examples, return the data unchanged
        if not any(categorized_data.values()):
            print("No categories detected; returning data unchanged")
            return data

        # Balance categories
        min_category_size = min(len(examples) for examples in categorized_data.values() if examples)
        balanced_data = []

        for category, examples in categorized_data.items():
            if examples:
                # Sample from each category
                selected = random.sample(examples, min(len(examples), min_category_size))
                balanced_data.extend(selected)

        # Add some uncategorized examples
        if uncategorized:
            additional_size = len(balanced_data) // 4  # 25% uncategorized
            selected_uncategorized = random.sample(
                uncategorized,
                min(len(uncategorized), additional_size)
            )
            balanced_data.extend(selected_uncategorized)

        print(f"Balanced dataset: {len(data)} -> {len(balanced_data)} examples")
        return balanced_data

    def _auto_detect_categories(self, data: List[Dict]) -> List[str]:
        """Auto-detect common categories in the dataset."""

        # Common instruction types
        patterns = [
            'explain', 'describe', 'write', 'create', 'generate',
            'translate', 'summarize', 'analyze', 'compare', 'define'
        ]

        detected_categories = []

        for pattern in patterns:
            count = sum(1 for item in data if pattern in item['instruction'].lower())
            if count >= 5:  # Minimum threshold
                detected_categories.append(pattern)

        return detected_categories[:10]  # Limit to top 10 categories

    def export_processed_data(self, data: List[Dict], output_path: str, format: str = 'json'):
        """Export processed data in specified format."""

        if format == 'json':
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)

        elif format == 'jsonl':
            with open(output_path, 'w', encoding='utf-8') as f:
                for item in data:
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')

        elif format == 'csv':
            df = pd.DataFrame(data)
            df.to_csv(output_path, index=False)

        print(f"Data exported to {output_path} in {format} format")

# Example usage
if __name__ == "__main__":
    processor = SFTDataProcessor()

    # Sample data for demonstration
    sample_data = [
        {
            "instruction": "Explain machine learning",
            "input": "",
            "output": "Machine learning is a subset of AI that enables computers to learn from data."
        },
        {
            "instruction": "Write a Python function to add two numbers",
            "input": "",
            "output": "def add(a, b):\n    return a + b"
        },
        # Add more examples...
    ]

    print("Original dataset analysis:")
    analysis = processor.analyze_dataset(sample_data)
    for key, value in analysis.items():
        if key != 'common_instruction_patterns':
            print(f"{key}: {value}")

    # Process the data
    print("\nProcessing data...")

    # Deduplicate
    deduplicated = processor.deduplicate_data(sample_data)

    # Filter by quality
    filtered = processor.filter_by_quality(deduplicated)

    # Augment data
    augmented = processor.augment_data(filtered, augmentation_factor=0.3)

    # Create balanced dataset
    balanced = processor.create_balanced_dataset(augmented)

    print("\nFinal dataset analysis:")
    final_analysis = processor.analyze_dataset(balanced)
    for key, value in final_analysis.items():
        if key != 'common_instruction_patterns':
            print(f"{key}: {value}")

    # Export processed data
    processor.export_processed_data(balanced, "processed_sft_data.json")

    print("\nData processing completed!")

Advanced SFT Techniques

Beyond basic supervised fine-tuning, several advanced techniques can improve training efficiency and model performance:

1. Gradient Checkpointing and Memory Optimization:

Gradient Checkpointing:
●Trades computation for memory

●Enables training larger models on limited hardware

●Typically reduces memory usage by 50-70%


Mixed Precision Training:
●Uses both 16-bit and 32-bit floating point numbers

●Speeds up training while maintaining stability

●PyTorch's Automatic Mixed Precision (AMP) manages the precision switching for you


DeepSpeed Integration:
●ZeRO optimizer for distributed training

●Significantly reduces memory usage

●Enables training models that wouldn't otherwise fit on a single GPU (see the config sketch below)
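
A minimal ZeRO stage-2 configuration might look like the sketch below (values are illustrative; the dict can be passed to TrainingArguments via its deepspeed parameter or saved as a JSON file):

# Sketch of a DeepSpeed ZeRO stage-2 config for use with the Hugging Face Trainer
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,                               # shard optimizer states and gradients
        "offload_optimizer": {"device": "cpu"},   # optional CPU offload for more headroom
    },
}
# training_args = TrainingArguments(..., deepspeed=ds_config)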


2. Learning Rate Scheduling:

Warmup Strategies:
●Linear warmup: Gradually increase learning rate

●Cosine warmup: Smooth increase with cosine function

●Prevents early training instability


Decay Strategies:
●Linear decay: Steady decrease over time

●Cosine decay: Smooth decrease following cosine curve

●Step decay: Discrete steps down at intervals


Adaptive Learning Rates:
●Different rates for different parameter groups

●Higher rates for task-specific layers

●Lower rates for pre-trained parameters (a warmup + cosine scheduler sketch follows this list)
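
A warmup phase followed by cosine decay can be set up directly with the scheduler helpers in transformers; the sketch below uses a small stand-in module and an assumed step count:

# Sketch: linear warmup followed by cosine decay
import torch.nn as nn
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup

model = nn.Linear(16, 16)  # stand-in for a loaded language model
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

total_steps = 1000  # assumed; compute from dataset size, batch size, and epochs
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.03 * total_steps),  # ~3% warmup
    num_training_steps=total_steps,
)
# In a manual loop: optimizer.step(); scheduler.step(); optimizer.zero_grad()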


3. Regularization Techniques:

Dropout:
●Apply to attention layers and feedforward networks

●Helps prevent overfitting on small datasets

●Typical values: 0.1-0.3


Weight Decay:
●L2 regularization to prevent large weights

●Helps with generalization

●Typical values: 0.01-0.1


Label Smoothing:
●Softens hard targets

●Improves calibration and generalization

●Typical epsilon: 0.1


4. Data Efficiency Techniques:

Curriculum Learning:
●Start with easier examples

●Gradually increase difficulty

●Can improve convergence and final performance


Active Learning:
●Iteratively select most informative examples

●Maximizes learning from limited data

●Particularly useful for domain-specific applications


Few-Shot Learning:
●Learn from very few examples per task

●Useful when data is scarce

●Can be combined with meta-learning approaches


5. Multi-Task and Transfer Learning:

Multi-Task Fine-Tuning:
●Train on multiple related tasks simultaneously

●Shared representations benefit all tasks

●Requires careful task balancing


Sequential Fine-Tuning:
●Fine-tune on related tasks first

●Then fine-tune on target task

●Can improve performance on low-resource tasks


Domain Adaptation:
●Gradual adaptation from source to target domain

●Useful when target domain data is limited

●Can use domain adversarial training


6. Evaluation and Monitoring:

Perplexity Tracking:
●Monitor language modeling performance

●Lower perplexity generally indicates better fit

●Watch for overfitting patterns


Task-Specific Metrics:
●ROUGE for summarization

●BLEU for translation

●Exact match for QA

●Custom metrics for domain tasks


Validation Strategies:
●Hold-out validation set

●Cross-validation for small datasets

●Temporal splits for time-sensitive data


7. Hyperparameter Optimization:

Grid Search:
●Systematic exploration of hyperparameter space

●Good for small number of parameters

●Computationally expensive


Random Search:
●Random sampling of hyperparameter combinations

●Often more efficient than grid search

●Good for high-dimensional spaces


Bayesian Optimization:
●Uses previous results to guide search

●More sample-efficient than random/grid search

●Tools: Optuna, Hyperopt, Ray Tune (see the Optuna sketch below)
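
A sketch of a Bayesian-style search with Optuna follows; train_and_evaluate is a hypothetical helper that runs a short SFT job and returns the validation loss:

# Sketch: hyperparameter search with Optuna (train_and_evaluate is hypothetical)
import optuna

def objective(trial):
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    epochs = trial.suggest_int("num_epochs", 1, 4)
    warmup = trial.suggest_float("warmup_ratio", 0.0, 0.1)
    return train_and_evaluate(lr=lr, num_epochs=epochs, warmup_ratio=warmup)  # validation loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)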


8. Model Distillation:

Knowledge Distillation:
●Train smaller student model to mimic larger teacher

●Maintains much of the performance with less compute

●Useful for deployment constraints


Progressive Distillation:
●Gradually reduce model size through multiple stages

●Can achieve better size/performance trade-offs

●Particularly effective for transformer models (a distillation-loss sketch follows)
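
A common formulation of the distillation objective blends hard-label cross-entropy with a temperature-scaled KL term against the teacher's logits; the sketch below is illustrative and is not part of the training script that follows:

# Sketch: knowledge-distillation loss for a student language model
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    # Hard-label term on the student only (-100 labels are ignored)
    ce = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )

    # Soft-label term: match the teacher's softened distribution
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)

    return alpha * ce + (1.0 - alpha) * kl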

The script below combines several of these techniques, including layer-wise learning rates, curriculum ordering, label smoothing, and a cosine schedule with warmup.

# Advanced SFT techniques implementation
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, get_cosine_schedule_with_warmup
)
from torch.optim import AdamW
import wandb

class AdvancedSFTTrainer:
    def __init__(self, model_name: str, use_deepspeed: bool = False):
        self.model_name = model_name
        self.use_deepspeed = use_deepspeed
        self.model = None
        self.tokenizer = None

    def setup_model_with_optimizations(self,
                                       gradient_checkpointing: bool = True,
                                       mixed_precision: bool = True):
        """Setup model with memory and training optimizations."""

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model with optimizations
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16 if mixed_precision else torch.float32,
            device_map="auto" if not self.use_deepspeed else None,
            use_cache=False,  # Disable for training
        )

        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        print("Model loaded with optimizations:")
        print(f"- Gradient checkpointing: {gradient_checkpointing}")
        print(f"- Mixed precision: {mixed_precision}")
        print(f"- DeepSpeed: {self.use_deepspeed}")

    def create_custom_optimizer(self,
                                learning_rate: float = 5e-5,
                                weight_decay: float = 0.01,
                                use_layer_wise_lr: bool = False) -> AdamW:
        """Create an AdamW optimizer with optional layer-wise learning rates."""

        if use_layer_wise_lr:
            # Different learning rates for different layers
            parameter_groups = []

            # Embedding layers - lower LR
            embedding_params = []
            for name, param in self.model.named_parameters():
                if 'embed' in name or 'wte' in name or 'wpe' in name:
                    embedding_params.append(param)

            if embedding_params:
                parameter_groups.append({
                    'params': embedding_params,
                    'lr': learning_rate * 0.1,  # 10x lower
                    'weight_decay': weight_decay
                })

            # Output layers - higher LR
            output_params = []
            for name, param in self.model.named_parameters():
                if 'lm_head' in name or 'output' in name:
                    output_params.append(param)

            if output_params:
                parameter_groups.append({
                    'params': output_params,
                    'lr': learning_rate * 2.0,  # 2x higher
                    'weight_decay': weight_decay
                })

            # All other parameters - standard LR
            other_params = []
            embedding_ids = {id(p) for p in embedding_params}
            output_ids = {id(p) for p in output_params}

            for param in self.model.parameters():
                if id(param) not in embedding_ids and id(param) not in output_ids:
                    other_params.append(param)

            if other_params:
                parameter_groups.append({
                    'params': other_params,
                    'lr': learning_rate,
                    'weight_decay': weight_decay
                })

            optimizer = AdamW(parameter_groups, betas=(0.9, 0.999), eps=1e-8)

        else:
            # Standard optimizer
            optimizer = AdamW(
                self.model.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
                betas=(0.9, 0.999),
                eps=1e-8
            )

        return optimizer

    def create_curriculum_dataset(self, dataset, difficulty_metric: str = 'length'):
        """Create curriculum learning dataset ordered by difficulty."""

        def calculate_difficulty(example):
            if difficulty_metric == 'length':
                return len(example['input_ids'])
            elif difficulty_metric == 'vocab_complexity':
                # Simple vocabulary complexity metric
                unique_tokens = len(set(example['input_ids']))
                total_tokens = len(example['input_ids'])
                return unique_tokens / total_tokens
            else:
                return 0.5  # Default neutral difficulty

        # Calculate difficulty scores
        difficulties = [calculate_difficulty(example) for example in dataset]

        # Sort by difficulty (easy to hard)
        sorted_indices = sorted(range(len(dataset)), key=lambda i: difficulties[i])

        # Create curriculum dataset
        curriculum_dataset = dataset.select(sorted_indices)

        print(f"Created curriculum dataset with {len(curriculum_dataset)} examples")
        return curriculum_dataset

    def train_with_advanced_techniques(self,
                                       train_dataset,
                                       eval_dataset=None,
                                       output_dir="./advanced_sft",
                                       num_epochs=3,
                                       batch_size=4,
                                       learning_rate=5e-5,
                                       use_curriculum=True,
                                       use_label_smoothing=True,
                                       label_smoothing_factor=0.1,
                                       use_cosine_schedule=True,
                                       warmup_ratio=0.03):
        """Train with advanced techniques."""

        # Setup curriculum learning
        if use_curriculum:
            train_dataset = self.create_curriculum_dataset(train_dataset)

        # Calculate training steps
        total_steps = (len(train_dataset) // batch_size) * num_epochs
        warmup_steps = int(total_steps * warmup_ratio)

        # Create custom optimizer
        optimizer = self.create_custom_optimizer(
            learning_rate=learning_rate,
            use_layer_wise_lr=True
        )

        # Create learning rate scheduler
        if use_cosine_schedule:
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps
            )
        else:
            scheduler = None  # Trainer will create one from the arguments

        # Custom loss function with label smoothing
        class LabelSmoothingLoss(nn.Module):
            def __init__(self, smoothing: float = 0.1, vocab_size: int = 32000):
                super().__init__()
                self.smoothing = smoothing
                self.vocab_size = vocab_size  # always passed explicitly below

            def forward(self, pred, target):
                # Shift so that each position predicts the next token
                pred = pred[..., :-1, :].contiguous().view(-1, pred.size(-1))
                target = target[..., 1:].contiguous().view(-1)

                # Mask padding tokens (label -100) and make indices safe for scatter
                mask = (target != -100).float()
                safe_target = target.clamp(min=0)

                # Create smoothed targets
                confidence = 1.0 - self.smoothing
                smooth_value = self.smoothing / (self.vocab_size - 1)
                one_hot = torch.full_like(pred, smooth_value)
                one_hot.scatter_(1, safe_target.unsqueeze(1), confidence)

                # Compute cross entropy with smoothed labels
                log_probs = torch.log_softmax(pred, dim=-1)
                loss = -torch.sum(one_hot * log_probs, dim=-1)

                loss = loss * mask
                return loss.sum() / mask.sum().clamp(min=1.0)

        # Custom trainer with label smoothing; mixed precision (bf16=True) and
        # gradient clipping (max_grad_norm) are handled by the built-in Trainer,
        # so no custom training_step or optimizer_step overrides are needed.
        class AdvancedTrainer(Trainer):
            def __init__(self, *args, label_smoothing_loss=None, **kwargs):
                super().__init__(*args, **kwargs)
                self.label_smoothing_loss = label_smoothing_loss

            def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                labels = inputs.get("labels")
                outputs = model(**inputs)
                logits = outputs.get("logits")

                if self.label_smoothing_loss is not None and labels is not None:
                    loss = self.label_smoothing_loss(logits, labels)
                else:
                    loss = outputs.loss

                return (loss, outputs) if return_outputs else loss

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=1,
            learning_rate=learning_rate,
            weight_decay=0.01,
            max_grad_norm=1.0,
            warmup_steps=warmup_steps,
            logging_steps=10,
            save_steps=500,
            eval_steps=500 if eval_dataset else None,
            evaluation_strategy="steps" if eval_dataset else "no",
            save_strategy="steps",
            load_best_model_at_end=True if eval_dataset else False,
            metric_for_best_model="eval_loss" if eval_dataset else None,
            greater_is_better=False,
            report_to="wandb",
            run_name="advanced_sft",
            bf16=True,
            dataloader_pin_memory=False,
            gradient_checkpointing=True,
            remove_unused_columns=False,
        )

        # Create loss function
        loss_fn = LabelSmoothingLoss(
            smoothing=label_smoothing_factor,
            vocab_size=len(self.tokenizer)
        ) if use_label_smoothing else None

        # Create trainer
        trainer = AdvancedTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            optimizers=(optimizer, scheduler),
            label_smoothing_loss=loss_fn,
        )

        # Initialize wandb
        wandb.init(
            project="advanced-sft",
            config={
                "model_name": self.model_name,
                "num_epochs": num_epochs,
                "batch_size": batch_size,
                "learning_rate": learning_rate,
                "use_curriculum": use_curriculum,
                "use_label_smoothing": use_label_smoothing,
                "label_smoothing_factor": label_smoothing_factor,
            }
        )

        # Training
        print("Starting advanced SFT training...")
        trainer.train()

        # Save model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        wandb.finish()
        print(f"Advanced SFT completed! Model saved to {output_dir}")

        return trainer

    def evaluate_with_multiple_metrics(self, test_dataset, metrics=['perplexity', 'bleu']):
        """Evaluate model with multiple metrics."""

        import sacrebleu

        self.model.eval()
        results = {}

        if 'perplexity' in metrics:
            # Calculate perplexity
            total_loss = 0
            total_tokens = 0

            for example in test_dataset:
                inputs = {k: torch.tensor(v).unsqueeze(0).to(self.model.device)
                          for k, v in example.items() if k in ['input_ids', 'attention_mask']}

                with torch.no_grad():
                    outputs = self.model(**inputs, labels=inputs['input_ids'])
                    loss = outputs.loss.item()
                    num_tokens = inputs['input_ids'].numel()

                total_loss += loss * num_tokens
                total_tokens += num_tokens

            perplexity = torch.exp(torch.tensor(total_loss / total_tokens))
            results['perplexity'] = perplexity.item()

        if 'bleu' in metrics:
            # Calculate BLEU score (simplified example)
            references = []
            predictions = []

            # Sample a subset for efficiency
            subset = test_dataset.select(range(min(100, len(test_dataset))))

            for example in subset:
                # This is a simplified example - adapt based on your data format
                input_ids = example['input_ids'][:50]   # First 50 tokens as input
                target_ids = example['input_ids'][50:]  # Rest as target

                inputs = torch.tensor(input_ids).unsqueeze(0).to(self.model.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs,
                        max_new_tokens=len(target_ids),
                        do_sample=False,
                        pad_token_id=self.tokenizer.eos_token_id
                    )

                pred_text = self.tokenizer.decode(outputs[0][len(input_ids):], skip_special_tokens=True)
                ref_text = self.tokenizer.decode(target_ids, skip_special_tokens=True)

                predictions.append(pred_text)
                references.append(ref_text)

            if predictions and references:
                # sacrebleu expects one list of references per reference set
                bleu_score = sacrebleu.corpus_bleu(predictions, [references])
                results['bleu'] = bleu_score.score

        return results

# Example usage
if __name__ == "__main__":
    # Initialize advanced trainer
    trainer = AdvancedSFTTrainer("mistralai/Mistral-7B-v0.1")

    # Setup model with optimizations
    trainer.setup_model_with_optimizations(
        gradient_checkpointing=True,
        mixed_precision=True
    )

    # Prepare sample dataset (replace with your actual data)
    from datasets import Dataset

    sample_data = [
        {"input_ids": [1, 2, 3, 4, 5] * 100, "attention_mask": [1] * 500,
         "labels": [1, 2, 3, 4, 5] * 100},
        {"input_ids": [6, 7, 8, 9, 10] * 100, "attention_mask": [1] * 500,
         "labels": [6, 7, 8, 9, 10] * 100},
        # Add more examples...
    ]

    train_dataset = Dataset.from_list(sample_data)
    eval_dataset = Dataset.from_list(sample_data[:2])  # Small eval set

    # Train with advanced techniques
    trainer.train_with_advanced_techniques(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir="./advanced_sft_model",
        num_epochs=2,
        batch_size=2,
        learning_rate=5e-5,
        use_curriculum=True,
        use_label_smoothing=True,
        label_smoothing_factor=0.1,
        use_cosine_schedule=True,
    )

    # Evaluate with multiple metrics
    results = trainer.evaluate_with_multiple_metrics(eval_dataset)
    print("Evaluation results:", results)
