Saving and Loading State

During training, you'll need to save checkpoints for two primary purposes: sampling (to test your model) and resuming training (to continue from where you left off). The TrainingClient provides three methods to handle these use cases efficiently.

State Management Methods

1. save_weights_for_sampler(): Saves model weights only (fast, smaller storage)
2. save_state(): Saves weights + optimizer state (for resuming training)
3. load_state(): Restores weights + optimizer state (full training resume)

Checkpoint Naming and Paths

Both save_* methods require a name parameter: a string that identifies the checkpoint within your training run.

Checkpoint Naming Examples
# Sequential numbering
training_client.save_state(name="0000")
training_client.save_state(name="0001")

# Step-based naming
training_client.save_state(name="step_1000")
training_client.save_state(name="step_2000")

# Epoch-based naming (epoch and step come from your training loop)
training_client.save_state(name=f"epoch_{epoch}_step_{step}")

# Descriptive naming
training_client.save_state(name="finance_model_v1_final")
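Names are free-form strings. If you later list or sort checkpoints by name, zero-padding the numeric part keeps lexical order aligned with training order, which is what the "0000" style achieves. A minimal sketch; `checkpoint_name` is a hypothetical helper, not part of the Bios API.

```python
def checkpoint_name(step: int, prefix: str = "step", width: int = 6) -> str:
    """Build a zero-padded checkpoint name such as 'step_001000'."""
    return f"{prefix}_{step:0{width}d}"

steps = (2000, 10000, 500)

# Unpadded names sort lexically, not numerically
unpadded = [f"step_{s}" for s in steps]
padded = [checkpoint_name(s) for s in steps]

print(sorted(unpadded))  # ['step_10000', 'step_2000', 'step_500']
print(sorted(padded))    # ['step_000500', 'step_002000', 'step_010000']
```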

Checkpoint Paths

Each save call returns a future; its result has a path field containing a fully qualified identifier that persists across sessions:

bios://<model_id>/<checkpoint_name>

Example: bios://usf-finance-abc123/step_5000

This path can be used later by a new ServiceClient or TrainingClient to restore the checkpoint.
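Because the path is a plain string with a fixed shape, you can store it (in logs, a config file, or a database) and split it back into its parts when needed. A minimal sketch that assumes only the bios://<model_id>/<checkpoint_name> format shown above; `parse_bios_path` and `BiosPath` are hypothetical helpers. The split happens on the first `/`, so any further slashes stay in the checkpoint-name part.

```python
from typing import NamedTuple


class BiosPath(NamedTuple):
    model_id: str
    checkpoint_name: str


def parse_bios_path(path: str) -> BiosPath:
    """Split a bios://<model_id>/<checkpoint_name> string into its parts."""
    scheme = "bios://"
    if not path.startswith(scheme):
        raise ValueError(f"not a bios:// path: {path!r}")
    # Split on the first slash only; the remainder is the checkpoint name
    model_id, sep, name = path[len(scheme):].partition("/")
    if not model_id or not sep or not name:
        raise ValueError(f"expected bios://<model_id>/<checkpoint_name>, got {path!r}")
    return BiosPath(model_id, name)


print(parse_bios_path("bios://usf-finance-abc123/step_5000"))
# BiosPath(model_id='usf-finance-abc123', checkpoint_name='step_5000')
```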

Saving for Sampling

Use save_weights_for_sampler() when you want to test your model during training without preserving optimizer state. This is faster and requires less storage.

Save Weights for Sampling
import bios
from bios import types

# Setup
service_client = bios.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-finance",
    rank=32
)

# ... perform training steps ...

# Save checkpoint for sampling
sampling_path = training_client.save_weights_for_sampler(
    name="checkpoint_1000"
).result().path

print(f"Saved weights at: {sampling_path}")

# Create sampling client with that checkpoint
sampling_client = service_client.create_sampling_client(
    model_path=sampling_path
)

# Test the model
# (assumes `tokenizer` is the base model's tokenizer, loaded earlier)
prompt = types.ModelInput.from_ints(
    tokenizer.encode("Analyze this financial report...")
)
result = sampling_client.sample(
    prompt=prompt,
    sampling_params=types.SamplingParams(max_tokens=256, temperature=0.7)
).result()

print(tokenizer.decode(result.sequences[0].tokens))

Shortcut Method

Combine saving and sampling client creation in one step:

Shortcut: Save and Sample
# Shortcut: save weights and get sampling client in one call
sampling_client = training_client.save_weights_and_get_sampling_client(
    name="checkpoint_1000"
)

# Immediately ready to sample (prompt, tokenizer, and types as in the previous example)
params = types.SamplingParams(max_tokens=256, temperature=0.7)
result = sampling_client.sample(prompt=prompt, sampling_params=params).result()
print(tokenizer.decode(result.sequences[0].tokens))

✓ When to use save_weights_for_sampler()

  • Testing model quality during training (intermediate checkpoints)
  • Generating samples for evaluation or human review
  • Creating demonstration outputs
  • When you don't need to resume training from this point

Saving to Resume Training

Use save_state() and load_state() when you need to pause and continue training with full optimizer state preserved.

Save and Load Training State
import bios

# Setup training client
service_client = bios.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-healthcare",
    rank=16
)

# Train for some steps (`batch` is assumed to be prepared elsewhere)
for step in range(1000):
    training_client.forward_backward(batch, "cross_entropy")
    training_client.optim_step()

# Save complete training state (weights + optimizer)
resume_path = training_client.save_state(
    name="checkpoint_step_1000"
).result().path

print(f"Saved training state at: {resume_path}")

# Later (e.g., after an interruption): load the checkpoint to resume training
training_client.load_state(resume_path)

# Continue training from exactly where you left off
for step in range(1000, 2000):
    training_client.forward_backward(batch, "cross_entropy")
    training_client.optim_step()

What's Saved in save_state()?

  • Model Weights: All LoRA adapter parameters
  • Optimizer State: Adam momentum buffers, variance estimates
  • Learning Rate Schedule: Current step and scheduler state
  • Training Metadata: Configuration, hyperparameters, step count
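Why does the optimizer state matter? Adam keeps running averages of the gradient and the squared gradient, plus a step counter used for bias correction; if those are reset, the next updates differ from an uninterrupted run. The toy sketch below uses a hand-written scalar Adam on the loss (x - 3)^2 (illustrative only, not Bios code and not the optimizer Bios runs internally) to show that resuming with the saved moments reproduces the uninterrupted run exactly, while restoring the weight alone does not.

```python
import math


def adam_step(param, grad, state, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on a scalar; `state` holds the moments and step count."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias-corrected first moment
    v_hat = state["v"] / (1 - beta2 ** state["t"])  # bias-corrected second moment
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


def grad_fn(x):
    return 2 * (x - 3.0)  # gradient of (x - 3)^2


def fresh_state():
    return {"m": 0.0, "v": 0.0, "t": 0}


def run(param, state, steps):
    for _ in range(steps):
        param = adam_step(param, grad_fn(param), state)
    return param


# Reference: 20 uninterrupted steps
reference = run(0.0, fresh_state(), 20)

# "Checkpoint" after 10 steps: keep both the weight and the optimizer state
state = fresh_state()
saved_param = run(0.0, state, 10)
saved_state = dict(state)

resumed = run(saved_param, dict(saved_state), 10)   # weights + optimizer state
weights_only = run(saved_param, fresh_state(), 10)  # optimizer state reset

print(resumed == reference)       # True
print(weights_only == reference)  # False
```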

✓ When to use save_state() and load_state()

  • Multi-stage training: Supervised learning followed by reinforcement learning
  • Hyperparameter adjustment: Changing learning rate or other settings mid-training
  • Recovery from interruptions: Hardware failures, timeouts, or manual stops
  • Exact optimizer state: Preserving momentum, variance, and learning rate schedules
  • Experiment forking: Branching from a checkpoint to try different configurations

Complete Checkpoint Management Example

Here's a comprehensive example showing both approaches:

Training with Checkpoints
import bios
from bios import types

# Initialize training
service_client = bios.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-code",
    rank=16
)

# Assumes `batches` (one batch per step) and `tokenizer` are defined elsewhere

# Training loop with periodic checkpointing
checkpoint_interval = 500
sampling_interval = 1000

for step in range(10000):
    # Training step
    training_client.forward_backward(batches[step], "cross_entropy")
    training_client.optim_step()
    completed = step + 1  # number of steps finished so far

    # Periodic full checkpoint for resumption
    if completed % checkpoint_interval == 0:
        state_path = training_client.save_state(
            name=f"step_{completed}"
        ).result().path
        print(f"Step {completed}: Saved state at {state_path}")

    # Periodic sampling checkpoint for evaluation
    if completed % sampling_interval == 0:
        # Quick save for sampling (doesn't save optimizer state)
        sampling_client = training_client.save_weights_and_get_sampling_client(
            name=f"eval_{completed}"
        )

        # Test the model
        test_prompt = types.ModelInput.from_ints(
            tokenizer.encode("Implement a binary search in Python:")
        )
        result = sampling_client.sample(
            prompt=test_prompt,
            sampling_params=types.SamplingParams(max_tokens=512, temperature=0.3)
        ).result()

        print(f"Step {completed}: Generated code sample")
        print(tokenizer.decode(result.sequences[0].tokens))

print("Training complete!")

# Final model save
final_path = training_client.save_state(name="final_model").result().path
print(f"Final model saved at: {final_path}")
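Before launching a long run, it can help to count how many saves a given pair of intervals will produce, since each save_state() call costs time and storage. A minimal sketch; the `checkpoint_schedule` helper is hypothetical. It counts completed steps, so the first save happens after a full interval rather than right after step 0.

```python
def checkpoint_schedule(total_steps, state_every, sample_every):
    """Return the completed-step counts that trigger each kind of save."""
    state_saves, sampler_saves = [], []
    for step in range(total_steps):
        completed = step + 1  # steps finished after this iteration
        if completed % state_every == 0:
            state_saves.append(completed)
        if completed % sample_every == 0:
            sampler_saves.append(completed)
    return state_saves, sampler_saves


state_saves, sampler_saves = checkpoint_schedule(10000, 500, 1000)
print(len(state_saves), len(sampler_saves))  # 20 10
print(state_saves[:3], sampler_saves[:3])    # [500, 1000, 1500] [1000, 2000, 3000]
```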

Resuming Training from Checkpoint

Load a previously saved checkpoint to continue training with exact optimizer state preservation:

Resume from Checkpoint
import bios

# Create new service client (e.g., after restart)
service_client = bios.ServiceClient()
checkpoint_path = "bios://usf-finance-abc123/step_5000"

# Use ONE of the following two options.

# Option 1: Create a training client directly from the checkpoint path
training_client = service_client.load_training_client(checkpoint_path)

# Option 2: Create a new client with the same configuration, then load state
training_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-finance",
    rank=32  # must match the LoRA rank used when the checkpoint was saved
)
training_client.load_state(checkpoint_path)

# Continue training seamlessly (`batch` is assumed to be prepared elsewhere)
for step in range(5000, 10000):
    training_client.forward_backward(batch, "cross_entropy")
    training_client.optim_step()

print("Training resumed and completed!")

Important Notes

  • Loading a checkpoint restores the exact optimizer state (momentum, learning rate position)
  • Checkpoints are tied to the specific base model and LoRA configuration
  • You cannot mix checkpoints from different base models
  • Checkpoints persist in Bios storage until explicitly deleted
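Since a checkpoint only fits the base model and LoRA configuration it was saved with, a cheap guard is to record those settings alongside each checkpoint path and compare them before calling load_state(). A minimal sketch; `RunConfig` and `check_compatible` are hypothetical, and it assumes you store the configuration yourself rather than reading it back from Bios.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunConfig:
    """Settings you record yourself when saving (hypothetical, not a Bios type)."""
    base_model: str
    lora_rank: int


def check_compatible(saved: RunConfig, requested: RunConfig) -> None:
    """Raise before load_state() if the new client cannot accept the checkpoint."""
    if saved.base_model != requested.base_model:
        raise ValueError(
            f"checkpoint is for {saved.base_model!r}, not {requested.base_model!r}"
        )
    if saved.lora_rank != requested.lora_rank:
        raise ValueError(
            f"checkpoint has LoRA rank {saved.lora_rank}, client uses {requested.lora_rank}"
        )


saved = RunConfig("ultrasafe/usf-finance", 32)
check_compatible(saved, RunConfig("ultrasafe/usf-finance", 32))  # passes silently
try:
    check_compatible(saved, RunConfig("ultrasafe/usf-code", 32))
except ValueError as err:
    print(err)  # checkpoint is for 'ultrasafe/usf-finance', not 'ultrasafe/usf-code'
```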

Multi-Stage Training Pipeline

A common pattern: supervised fine-tuning followed by RLHF. Use checkpoints to transition between stages:

Multi-Stage Training
import bios
from bios import types
from bios.rlhf import PPOTrainer

# Stage 1: Supervised Fine-Tuning
service_client = bios.ServiceClient()
sft_client = service_client.create_lora_training_client(
    base_model="ultrasafe/usf-conversation",
    rank=16
)

# SFT training loop (`sft_dataloader` is assumed to yield prepared batches)
for epoch in range(3):
    for batch in sft_dataloader:
        sft_client.forward_backward(batch, "cross_entropy")
        sft_client.optim_step()

# Save SFT checkpoint
sft_path = sft_client.save_state(name="sft_complete").result().path
print(f"SFT complete: {sft_path}")

# Stage 2: RLHF Training
# Load SFT checkpoint as starting point for RLHF
rlhf_client = service_client.load_training_client(sft_path)

# Configure for PPO training
ppo_trainer = PPOTrainer.from_training_client(rlhf_client)

# RLHF training loop (`prompts` is assumed to be a list of rollout prompts)
for iteration in range(1000):
    rollouts = ppo_trainer.collect_rollouts(prompts)
    rewards = ppo_trainer.compute_rewards(rollouts)
    ppo_trainer.ppo_update(rollouts, rewards)

# Save final RLHF model
final_path = rlhf_client.save_state(name="rlhf_final").result().path
print(f"RLHF complete: {final_path}")


Downloading Weights

You can download trained model weights from any Bios checkpoint to use outside the Bios platform, for example with your own inference infrastructure or for deployment to production environments.

Download Checkpoint Archive
import bios

# Create service client
sc = bios.ServiceClient()

# Create REST client for downloads
rc = sc.create_rest_client()

# Download checkpoint archive (replace <unique_id> with your checkpoint's ID)
checkpoint_path = "bios://<unique_id>/sampler_weights/final"
future = rc.download_checkpoint_archive_from_bios_path(checkpoint_path)
archive_data = future.result()  # raw archive bytes

# Save to local file
with open("model-checkpoint.tar.gz", "wb") as f:
    f.write(archive_data)

print("✓ Checkpoint downloaded: model-checkpoint.tar.gz")

Archive Contents

The downloaded archive (model-checkpoint.tar.gz) contains:

  • LoRA adapter weights in SafeTensors or PyTorch format
  • adapter_config.json with LoRA configuration (rank, alpha, target modules)
  • tokenizer files for the base model
  • metadata.json with training information and base model reference
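The exact directory layout inside the archive is not specified here, so it is worth listing the members before hard-coding paths such as ./downloaded_model/adapter. A minimal sketch using only the standard library's tarfile and json modules; the helper names are hypothetical, and reading keys such as "r" or "lora_alpha" from the result assumes adapter_config.json follows the PEFT LoraConfig format.

```python
import json
import tarfile


def list_archive(archive_path: str) -> list[str]:
    """Return the member names inside a .tar.gz checkpoint archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        return tar.getnames()


def read_adapter_config(archive_path: str) -> dict:
    """Parse the first adapter_config.json in the archive without extracting it."""
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            if member.isfile() and member.name.endswith("adapter_config.json"):
                return json.load(tar.extractfile(member))
    raise FileNotFoundError("adapter_config.json not found in archive")
```

For example, list_archive("model-checkpoint.tar.gz") returns the member names, which tell you which subdirectory to pass to PeftModel.from_pretrained, and read_adapter_config reads the LoRA rank without unpacking anything.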

Extract and Use Weights

After downloading, extract the archive and load the weights with standard ML frameworks:

Extract and Load Weights
import tarfile
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Extract the archive (filter="data" rejects unsafe member paths; available
# in Python 3.12+ and backported to recent 3.8 to 3.11 security releases)
with tarfile.open("model-checkpoint.tar.gz", "r:gz") as tar:
    tar.extractall("./downloaded_model", filter="data")

# Load base model (use the base model named in metadata.json)
base_model = AutoModelForCausalLM.from_pretrained(
    "ultrasafe/usf-finance-base"  # Base model reference
)

# Load LoRA adapter (adjust the subdirectory to match the archive layout)
model = PeftModel.from_pretrained(
    base_model,
    "./downloaded_model/adapter"
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("./downloaded_model/tokenizer")

# Use the model for inference
inputs = tokenizer("Analyze Q3 earnings for...", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Deploy to Production

Use downloaded weights with your preferred inference provider:

Local Deployment

  • Load with HuggingFace Transformers + PEFT
  • Deploy with vLLM or TGI for production inference
  • Integrate with LangChain or other frameworks

Cloud Deployment

  • Upload to your cloud storage (S3, GCS, Azure)
  • Deploy on managed inference platforms
  • Use with serverless functions or containers

Method Comparison

Method                      What's Saved               Storage Size  Speed   Use Case
save_weights_for_sampler()  Weights only               Small         Fast    Testing/evaluation
save_state()                Weights + optimizer state  Larger        Slower  Resume training
load_state()                Restores everything        N/A           Fast    Continue from checkpoint

Advanced Checkpoint Patterns

Automatic Checkpoint Management

Implement automatic checkpointing with configurable intervals and retention:

Automated Checkpointing
import bios
from collections import deque

class CheckpointManager:
    def __init__(self, training_client, max_checkpoints=5):
        self.training_client = training_client
        # Remembers only the most recent paths locally; older checkpoints
        # remain in Bios storage until explicitly deleted
        self.checkpoints = deque(maxlen=max_checkpoints)

    def save_checkpoint(self, step: int):
        """Save a checkpoint and record its path (last max_checkpoints only)"""
        path = self.training_client.save_state(
            name=f"auto_checkpoint_{step}"
        ).result().path

        self.checkpoints.append({
            'step': step,
            'path': path
        })

        return path

    def get_latest_checkpoint(self):
        """Get most recent checkpoint"""
        return self.checkpoints[-1] if self.checkpoints else None

# Usage (`batch` is assumed to be prepared elsewhere)
manager = CheckpointManager(training_client, max_checkpoints=5)

for step in range(10000):
    training_client.forward_backward(batch, "cross_entropy")
    training_client.optim_step()

    # Save every 500 completed steps (the manager tracks only the last 5)
    if (step + 1) % 500 == 0:
        manager.save_checkpoint(step + 1)

# Resume from latest
latest = manager.get_latest_checkpoint()
if latest:
    training_client.load_state(latest['path'])
    print(f"Resumed from step {latest['step']}")
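Note that deque(maxlen=...) only limits the local list: the oldest entry is dropped silently, while the checkpoint itself stays in Bios storage. If retention should actually free storage, evict explicitly and pass each evicted path to a deletion function. A minimal sketch; `save_fn` and `delete_fn` are hypothetical callables (this page does not show a Bios delete call), exercised here with local stand-ins.

```python
from collections import deque
from typing import Callable, Optional


class RetainingCheckpointManager:
    """Keep the newest `keep` checkpoints and hand evicted paths to `delete_fn`.

    `save_fn(name) -> path` and `delete_fn(path)` are hypothetical callables;
    with Bios, save_fn could wrap training_client.save_state(name=...).result().path.
    """

    def __init__(self, save_fn: Callable[[str], str],
                 delete_fn: Callable[[str], None], keep: int = 5):
        if keep < 1:
            raise ValueError("keep must be at least 1")
        self.save_fn = save_fn
        self.delete_fn = delete_fn
        self.keep = keep
        self.checkpoints: deque = deque()  # (step, path) pairs, oldest first

    def save(self, step: int) -> str:
        path = self.save_fn(f"auto_checkpoint_{step}")
        self.checkpoints.append((step, path))
        # Evict explicitly so the caller can delete the stored checkpoint too
        while len(self.checkpoints) > self.keep:
            _, old_path = self.checkpoints.popleft()
            self.delete_fn(old_path)
        return path

    def latest(self) -> Optional[tuple]:
        return self.checkpoints[-1] if self.checkpoints else None


# Usage with local stand-ins for the save and delete calls
def fake_save(name: str) -> str:
    return f"bios://run/{name}"


deleted = []
manager = RetainingCheckpointManager(fake_save, deleted.append, keep=2)
for step in (500, 1000, 1500):
    manager.save(step)

print(manager.latest())  # (1500, 'bios://run/auto_checkpoint_1500')
print(deleted)           # ['bios://run/auto_checkpoint_500']
```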

Checkpoint Metadata

Include metadata with checkpoints for better tracking:

Checkpoint with Metadata
from datetime import datetime

# Save checkpoint with metadata
# (step, current_epoch, current_loss, current_lr, and training_config
#  are assumed to come from your training loop)
checkpoint = training_client.save_state(
    name=f"checkpoint_step_{step}",
    metadata={
        "step": step,
        "epoch": current_epoch,
        "loss": current_loss,
        "learning_rate": current_lr,
        "timestamp": datetime.now().isoformat(),
        "dataset": "financial_corpus_v2",
        "config": training_config
    }
).result()

print(f"Checkpoint saved: {checkpoint.path}")
print(f"Metadata: {checkpoint.metadata}")
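If you also want this information available locally (for example, to pick the lowest-loss checkpoint without querying Bios), you can record each checkpoint path with its metadata in a JSON Lines file. A minimal sketch; the helpers and the file layout are hypothetical, and every metadata value must be JSON-serializable (a custom config object would need converting to a dict first).

```python
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional


def record_checkpoint(index_file: Path, path: str, **metadata) -> dict:
    """Append one entry (checkpoint path + metadata) to a local JSON Lines index."""
    entry = {"path": path, "saved_at": datetime.now(timezone.utc).isoformat(), **metadata}
    with index_file.open("a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")
    return entry


def load_index(index_file: Path) -> list[dict]:
    """Read all recorded entries, oldest first."""
    if not index_file.exists():
        return []
    with index_file.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def best_checkpoint(index_file: Path, key: str = "loss") -> Optional[dict]:
    """Return the entry with the smallest value for `key`, if any entry has it."""
    entries = [e for e in load_index(index_file) if key in e]
    return min(entries, key=lambda e: e[key]) if entries else None
```

For example, call record_checkpoint(Path("checkpoints.jsonl"), checkpoint.path, step=step, loss=current_loss) after each save, then best_checkpoint(...)["path"] gives a path you can pass to load_state().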

Checkpoint Best Practices

✓ Do

  • Use descriptive checkpoint names with step/epoch numbers
  • Save regularly during long training runs
  • Use save_weights_for_sampler() for frequent evaluation
  • Use save_state() for important milestones
  • Include metadata for tracking and debugging
  • Test checkpoint loading before long training runs

✗ Don't

  • Don't call save_state() too frequently (it is expensive)
  • Don't reuse checkpoint names (it creates confusion)
  • Don't assume checkpoints work across different base models
  • Don't treat checkpoint paths as temporary; they are persistent identifiers you can reuse across sessions
  • Don't switch optimizer types when loading a checkpoint (the optimizer must match the saved state)