Saving and Loading State
During training, you'll need to save checkpoints for two primary purposes: sampling (to test your model) and resuming training (to continue from where you left off). The TrainingClient provides three methods to handle these use cases efficiently.
State Management Methods
- save_weights_for_sampler(): Saves model weights only (fast, smaller storage)
- save_state(): Saves weights + optimizer state (for resuming training)
- load_state(): Restores weights + optimizer state (full training resume)

Checkpoint Naming and Paths
Both save_* functions require a name parameter—a string identifier for the checkpoint within your training run.
# Sequential numbering
training_client.save_state(name="0000")
training_client.save_state(name="0001")

# Step-based naming
training_client.save_state(name="step_1000")
training_client.save_state(name="step_2000")

# Epoch-based naming
training_client.save_state(name=f"epoch_{epoch}_step_{step}")

# Descriptive naming
training_client.save_state(name="finance_model_v1_final")

Checkpoint Paths
The return value contains a path field—a fully-qualified identifier that persists across sessions:
bios://<model_id>/<checkpoint_name>

Example: bios://usf-finance-abc123/step_5000
This path can be used later by a new ServiceClient or TrainingClient to restore the checkpoint.
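Because the path is an ordinary string, you can record it outside the training process and reuse it after a restart. A minimal sketch, assuming training_client and service_client are set up as in the examples below; the checkpoints.json file is an arbitrary local record, not part of the Bios API:

import json

# Record the returned path so a restarted job can find it later
path = training_client.save_state(name="step_1000").result().path
with open("checkpoints.json", "w") as f:
    json.dump({"latest": path}, f)

# Later, in a fresh process
with open("checkpoints.json") as f:
    latest_path = json.load(f)["latest"]
training_client = service_client.load_training_client(latest_path)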
Saving for Sampling
Use save_weights_for_sampler() when you want to test your model during training without preserving optimizer state. This is faster and requires less storage.
import bios

# Setup
service_client = bios.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-finance",
    rank=32
)

# ... perform training steps ...

# Save checkpoint for sampling
sampling_path = training_client.save_weights_for_sampler(
    name="checkpoint_1000"
).result().path

print(f"Saved weights at: {sampling_path}")

# Create sampling client with that checkpoint
sampling_client = service_client.create_sampling_client(
    model_path=sampling_path
)

# Test the model
from bios import types
prompt = types.ModelInput.from_ints(
    tokenizer.encode("Analyze this financial report...")
)
result = sampling_client.sample(
    prompt=prompt,
    sampling_params=types.SamplingParams(max_tokens=256, temperature=0.7)
).result()

print(tokenizer.decode(result.sequences[0].tokens))

Shortcut Method
Combine saving and sampling client creation in one step:
# Shortcut: save weights and get sampling client in one call
sampling_client = training_client.save_weights_and_get_sampling_client(
    name="checkpoint_1000"
)

# Immediately ready to sample (reuses the prompt from the previous example)
params = types.SamplingParams(max_tokens=256, temperature=0.7)
result = sampling_client.sample(prompt=prompt, sampling_params=params).result()
print(tokenizer.decode(result.sequences[0].tokens))

✓ When to use save_weights_for_sampler()
- Testing model quality during training (intermediate checkpoints)
- Generating samples for evaluation or human review
- Creating demonstration outputs
- When you don't need to resume training from this point
Saving to Resume Training
Use save_state() and load_state() when you need to pause and continue training with full optimizer state preserved.
import bios

# Setup training client
service_client = bios.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-healthcare",
    rank=16
)

# Train for some steps
for step in range(1000):
    training_client.forward_backward(batch, "cross_entropy")
    training_client.optim_step()

# Save complete training state (weights + optimizer)
resume_path = training_client.save_state(
    name="checkpoint_step_1000"
).result().path

print(f"Saved training state at: {resume_path}")

# Later: Load the checkpoint to resume training
training_client.load_state(resume_path)

# Continue training from exactly where you left off
for step in range(1000, 2000):
    training_client.forward_backward(batch, "cross_entropy")
    training_client.optim_step()

What's Saved in save_state()?
- Model Weights: All LoRA adapter parameters
- Optimizer State: Adam momentum buffers and variance estimates
- Learning Rate Schedule: Current step and scheduler state
- Training Metadata: Configuration, hyperparameters, step count
✓ When to use save_state() and load_state()
- Multi-stage training: Supervised fine-tuning followed by reinforcement learning
- Hyperparameter adjustment: Changing learning rate or other settings mid-training
- Recovery from interruptions: Hardware failures, timeouts, or manual stops
- Exact optimizer state: Preserving momentum, variance, and learning rate schedules
- Experiment forking: Branching from a checkpoint to try different configurations (see the sketch after this list)
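As an illustration of the forking pattern, the same saved state can seed two independent runs. A minimal sketch, assuming base_path comes from an earlier save_state() call and variant_a_batches / variant_b_batches are your own data iterators:

# Fork the same checkpoint into two independent training clients
base_path = "bios://usf-finance-abc123/step_5000"  # from an earlier save_state()

run_a = service_client.load_training_client(base_path)
run_b = service_client.load_training_client(base_path)

# Each fork continues with its own data (or other configuration changes)
for batch in variant_a_batches:
    run_a.forward_backward(batch, "cross_entropy")
    run_a.optim_step()

for batch in variant_b_batches:
    run_b.forward_backward(batch, "cross_entropy")
    run_b.optim_step()

# Save each branch under its own name so the lineage stays clear
path_a = run_a.save_state(name="fork_a_step_100").result().path
path_b = run_b.save_state(name="fork_b_step_100").result().path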
Complete Checkpoint Management Example
Here's a comprehensive example showing both approaches:
import bios
from bios import types

# Initialize training
service_client = bios.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-code",
    rank=16
)

# Training loop with periodic checkpointing
checkpoint_interval = 500
sampling_interval = 1000

for step in range(10000):
    # Training step
    training_client.forward_backward(batches[step], "cross_entropy")
    training_client.optim_step()

    # Periodic full checkpoint for resumption
    if step % checkpoint_interval == 0:
        state_path = training_client.save_state(
            name=f"step_{step}"
        ).result().path
        print(f"Step {step}: Saved state at {state_path}")

    # Periodic sampling checkpoint for evaluation
    if step % sampling_interval == 0:
        # Quick save for sampling (doesn't save optimizer state)
        sampling_client = training_client.save_weights_and_get_sampling_client(
            name=f"eval_{step}"
        )

        # Test the model
        test_prompt = types.ModelInput.from_ints(
            tokenizer.encode("Implement a binary search in Python:")
        )
        result = sampling_client.sample(
            prompt=test_prompt,
            sampling_params=types.SamplingParams(max_tokens=512, temperature=0.3)
        ).result()

        print(f"Step {step}: Generated code sample")
        print(tokenizer.decode(result.sequences[0].tokens))

print("Training complete!")

# Final model save
final_path = training_client.save_state(name="final_model").result().path
print(f"Final model saved at: {final_path}")

Resuming Training from Checkpoint
Load a previously saved checkpoint to continue training with exact optimizer state preservation:
import bios

# Create new service client (e.g., after restart)
service_client = bios.ServiceClient()

# Option 1: Load a training client directly from the path string
checkpoint_path = "bios://usf-finance-abc123/step_5000"
training_client = service_client.load_training_client(checkpoint_path)

# Option 2: Create a new client and load state into it
training_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-finance"
)
training_client.load_state(checkpoint_path)

# Continue training seamlessly
for step in range(5000, 10000):
    training_client.forward_backward(batch, "cross_entropy")
    training_client.optim_step()

print("Training resumed and completed!")

Important Notes
- Loading a checkpoint restores the exact optimizer state (momentum, learning rate position)
- Checkpoints are tied to the specific base model and LoRA configuration
- You cannot mix checkpoints from different base models (see the sketch below)
- Checkpoints persist in Bios storage until explicitly deleted
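Given the compatibility notes above, one defensive pattern is to record the base model and LoRA rank next to each saved path and check them before loading. A minimal sketch; the registry dict and its fields are illustrative bookkeeping, not part of the Bios API:

# Record configuration alongside each checkpoint path (illustrative registry)
registry = {}

path = training_client.save_state(name="step_2000").result().path
registry[path] = {"base_model": "ultrasafe/usf-finance", "rank": 32}

# Before resuming, confirm the new client matches the recorded configuration
wanted = registry[path]
assert wanted["base_model"] == "ultrasafe/usf-finance", "base model mismatch"

training_client = service_client.create_lora_training_client(
    base_model=wanted["base_model"],
    rank=wanted["rank"]
)
training_client.load_state(path)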
Multi-Stage Training Pipeline
A common pattern: supervised fine-tuning followed by RLHF. Use checkpoints to transition between stages:
import bios
from bios import types
from bios.rlhf import PPOTrainer

# Stage 1: Supervised Fine-Tuning
service_client = bios.ServiceClient()
sft_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-conversation",
    rank=16
)

# SFT training loop
for epoch in range(3):
    for batch in sft_dataloader:
        sft_client.forward_backward(batch, "cross_entropy")
        sft_client.optim_step()

# Save SFT checkpoint
sft_path = sft_client.save_state(name="sft_complete").result().path
print(f"SFT complete: {sft_path}")

# Stage 2: RLHF Training
# Load SFT checkpoint as starting point for RLHF
rlhf_client = service_client.load_training_client(sft_path)

# Configure for PPO training
ppo_trainer = PPOTrainer.from_training_client(rlhf_client)

# RLHF training loop
for iteration in range(1000):
    rollouts = ppo_trainer.collect_rollouts(prompts)
    rewards = ppo_trainer.compute_rewards(rollouts)
    ppo_trainer.ppo_update(rollouts, rewards)

# Save final RLHF model
final_path = rlhf_client.save_state(name="rlhf_final").result().path
print(f"RLHF complete: {final_path}")

Method Comparison
| Method | What's Saved | Storage Size | Speed | Use Case | 
|---|---|---|---|---|
| save_weights_for_sampler() | Weights only | Small | Fast | Testing/evaluation | 
| save_state() | Weights + optimizer state | Larger | Slower | Resume training | 
| load_state() | Restores everything | N/A | Fast | Continue from checkpoint | 
Advanced Checkpoint Patterns
Automatic Checkpoint Management
Implement automatic checkpointing with a configurable interval, keeping track of only the most recent checkpoint paths:
import bios
from collections import deque

class CheckpointManager:
    def __init__(self, training_client, max_checkpoints=5):
        self.training_client = training_client
        # Only the most recent max_checkpoints references are kept here;
        # the checkpoints themselves remain in Bios storage until deleted
        self.checkpoints = deque(maxlen=max_checkpoints)

    def save_checkpoint(self, step: int):
        """Save a full training state and record its path"""
        path = self.training_client.save_state(
            name=f"auto_checkpoint_{step}"
        ).result().path

        self.checkpoints.append({
            'step': step,
            'path': path
        })

        return path

    def get_latest_checkpoint(self):
        """Get the most recent checkpoint record"""
        return self.checkpoints[-1] if self.checkpoints else None

# Usage
manager = CheckpointManager(training_client, max_checkpoints=5)

for step in range(10000):
    training_client.forward_backward(batch, "cross_entropy")
    training_client.optim_step()

    # Save every 500 steps (tracks only the last 5 paths)
    if step % 500 == 0:
        manager.save_checkpoint(step)

# Resume from latest
latest = manager.get_latest_checkpoint()
if latest:
    training_client.load_state(latest['path'])
    print(f"Resumed from step {latest['step']}")

Checkpoint Metadata
Include metadata with checkpoints for better tracking:
from datetime import datetime

# Save checkpoint with metadata
checkpoint = training_client.save_state(
    name=f"checkpoint_step_{step}",
    metadata={
        "step": step,
        "epoch": current_epoch,
        "loss": current_loss,
        "learning_rate": current_lr,
        "timestamp": datetime.now().isoformat(),
        "dataset": "financial_corpus_v2",
        "config": training_config
    }
).result()

print(f"Checkpoint saved: {checkpoint.path}")
print(f"Metadata: {checkpoint.metadata}")

Checkpoint Best Practices
✓ Do
- Use descriptive checkpoint names with step/epoch numbers
- Save regularly during long training runs
- Use save_weights_for_sampler() for frequent evaluation
- Use save_state() for important milestones
- Include metadata for tracking and debugging
- Test checkpoint loading before long training runs (see the smoke-test sketch after these lists)
✗ Don't
- Don't use save_state() too frequently (it is slower and uses more storage)
- Don't reuse checkpoint names (creates confusion)
- Don't assume checkpoints work across different base models
- Don't lose track of checkpoint paths; they are the persistent identifiers needed to restore state
- Don't mix optimizer types when loading (they must match)
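For the smoke test mentioned in the Do list, a quick save/load round trip before a long run confirms the checkpoint plumbing works end to end. A minimal sketch using only calls shown earlier in this guide:

# Quick smoke test: save a state, then load it back before committing to a long run
test_path = training_client.save_state(name="smoke_test_0000").result().path
print(f"Smoke-test checkpoint saved at: {test_path}")

try:
    training_client.load_state(test_path)
    print("Checkpoint round trip succeeded; starting the full training run")
except Exception as exc:
    # Surface the failure early instead of discovering it hours into training
    raise RuntimeError(f"Checkpoint load failed: {exc}") from exc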
Next Steps
Learn more about advanced training techniques and deployment: