A Knowledge Distillation library for PyTorch.
Compress large models into smaller ones by transferring learned representations.
Key Features
- Minimal API: small surface area, easy to learn
- Strategy pattern: swap distillation methods with minimal code changes
- PyTorch native: works with any nn.Module
- Composable: combine multiple strategies into a custom one via weighted sums
- HuggingFace & Accelerate: optional third-party integrations
Installation
Requires Python 3.13+ and PyTorch 2.0+.
# using uv (recommended)
uv add rictr
# using pip
pip install rictr
# with HuggingFace support
pip install rictr[hf]
# with Accelerate support
pip install rictr[accelerate]
Quick Start
Distill a teacher into a student in a few lines:
import torch
from rictr import Distiller, SoftTarget
# your models (must accept **kwargs in forward)
teacher = LargeModel()
student = SmallModel()
# strategy defines what to distill
strategy = SoftTarget(temperature=4.0)
# any standard pytorch optimizer works
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
# create distiller
distiller = Distiller(
    teacher=teacher,
    student=student,
    strategy=strategy,
    optimizer=optimizer,
)

# training loop
for batch in dataloader:
    loss = distiller.distill_step(batch=batch)
    print(f"Loss: {loss.item():.4f}")
Distiller
The Distiller runs the distillation loop: it freezes the teacher, runs both forward passes, computes the loss, and updates the student.
from rictr import Distiller
distiller = Distiller(
    teacher=teacher,    # frozen automatically
    student=student,    # trainable model
    strategy=strategy,  # how to compute the loss
    optimizer=optimizer,  # updates the student
    device=torch.device("cuda"),  # optional; defaults to cpu
)
What Distiller Does
- freezes teacher parameters (requires_grad=False)
- sets the teacher to eval mode
- moves models and batch tensors to the specified device
- runs the student forward pass with gradients and the teacher's without
- calls the strategy to compute the loss
- runs the backward pass and optimizer step (see the sketch below)
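Conceptually, a single distill_step boils down to the following. This is a simplified sketch using the names from the constructor above, not rictr's actual implementation; in particular, it assumes model outputs are normalized to dicts and that labels are taken from the batch:
import torch

def distill_step_sketch(batch):
    batch = move_batch(batch, device)   # tensors to the target device
    with torch.no_grad():
        teacher_out = teacher(**batch)  # teacher forward, no gradients
    student_out = student(**batch)      # student forward, with gradients
    loss = strategy(
        student_outputs=student_out,
        teacher_outputs=teacher_out,
        targets=batch,                  # assumption: "labels" come from the batch
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss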
What Distiller Does NOT Do
- multi-epoch training loops (use Trainer or write your own)
- data loading or batching
- logging
- learning rate scheduling
Trainer
For multi-epoch training with callbacks, use the Trainer class.
from rictr import Trainer
def log_callback(state, output):
    if state.step % 100 == 0:
        print(f"Step {state.step}: loss={output.loss:.4f}")
trainer = Trainer(distiller, callbacks=[log_callback])
epoch_losses = trainer.train(dataloader, epochs=10)
Training State
Tracks progress and is passed to callbacks.
from rictr import TrainingState
# available attributes:
state.step # current step number (int)
state.epoch # current epoch number (int)
state.best_loss # best loss seen so far (float)
state.history # list of StepOutput objects
Step Output
Returned from each training step.
from rictr import StepOutput
output.loss # loss value (float)
output.step # step number (int)
output.extras # optional dict for custom data
Callbacks
Callbacks are functions that receive (state, output).
def checkpoint_callback(state, output):
    if state.step % 1000 == 0:
        torch.save(student.state_dict(), f"ckpt_{state.step}.pt")

def early_stop_callback(state, output):
    if output.loss < 0.01:
        raise StopIteration("Target reached")

trainer = Trainer(distiller, callbacks=[
    log_callback,
    checkpoint_callback,
])
Model Requirements
Models must meet two requirements to work with rictr.
1. Accept **kwargs in forward()
The distiller passes batch dicts as **kwargs, so your model's forward() must accept extra keyword arguments.
class MyModel(nn.Module):
    def forward(self, x, **kwargs):  # accept **kwargs
        return self.layers(x)

# or explicitly handle expected keys
class MyModel(nn.Module):
    def forward(self, input_ids, attention_mask=None, **kwargs):
        return self.transformer(input_ids, attention_mask)
2. Return Dict or Tensor
Model outputs are normalized to dicts. Supported formats:
# dict (recommended)
return {"logits": logits, "hidden_states": hidden}
# single tensor -> becomes {"logits": tensor}
return logits
# tuple/list -> first element becomes {"logits": outputs[0]}
return (logits, hidden_states)
# named tuple -> converted via _asdict()
return ModelOutput(logits=logits, hidden=hidden)
Batch Format
Batches must be dicts with tensor values. The labels key is used by SoftTarget when alpha is set.
# example collate function
def collate_fn(batch):
    xs, ys = zip(*batch)
    return {
        "x": torch.stack(xs),
        "labels": torch.stack(ys),  # required if using alpha
    }
Soft Target Strategy
Logit-based distillation from Hinton et al. (2015): the student is trained to match the teacher's softened probability distribution via a temperature-scaled KL divergence loss.
Pure Distillation
from rictr import SoftTarget
strategy = SoftTarget(temperature=4.0)
With Task Loss Blending
# loss = alpha * task_loss + (1 - alpha) * distill_loss
strategy = SoftTarget(
    temperature=4.0,
    alpha=0.5,  # requires "labels" in batch
)
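For reference, the classic Hinton soft-target loss this strategy is based on can be written in plain PyTorch as follows (a minimal sketch, not the library's internal code):
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, temperature=4.0):
    # soften both distributions with the temperature, then compare with KL;
    # the T^2 factor keeps gradient magnitudes comparable across temperatures
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(s, t, reduction="batchmean") * temperature ** 2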
Parameters
| Parameter | Type | Description |
|---|---|---|
| temperature | float | softmax temperature (must be > 0); higher = softer. typical: 2-20 |
| alpha | float \| None | task-loss weight in [0, 1]. None = pure distillation |
| task_loss | callable | loss function for labels. default: F.cross_entropy |
Temperature Guide
| Temperature | Effect |
|---|---|
| T = 1 | original (unsoftened) distribution |
| T = 2-5 | moderate softening (good default) |
| T = 10-20 | very soft, reveals class relationships |
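To see the effect concretely, here is the same set of logits softened at two temperatures:
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])
F.softmax(logits / 1.0, dim=-1)  # ~[0.66, 0.24, 0.10] -> sharp
F.softmax(logits / 4.0, dim=-1)  # ~[0.42, 0.32, 0.26] -> soft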
Composite Strategy
Combines multiple strategies with a weighted sum.
from rictr import Composite, SoftTarget, HiddenStateDistillation
soft = SoftTarget(temperature=4.0)
hidden = HiddenStateDistillation(...)
# loss = 0.7 * soft_loss + 0.3 * hidden_loss
strategy = Composite([
    (soft, 0.7),
    (hidden, 0.3),
])
Custom Strategies
Implement the DistillationStrategy protocol.
from rictr import DistillationStrategy
import torch
class MyStrategy:
    def __call__(
        self,
        *,
        student_outputs: dict[str, torch.Tensor],
        teacher_outputs: dict[str, torch.Tensor],
        targets: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        student_logits = student_outputs["logits"]
        teacher_logits = teacher_outputs["logits"]
        # your custom loss computation
        loss = your_loss_function(student_logits, teacher_logits)
        return loss  # must be a scalar tensor
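For example, a complete strategy that matches raw logits with MSE, built on the mse_loss primitive documented below (a sketch; the class name is illustrative):
import torch
from rictr import mse_loss

class LogitMSE:
    def __call__(
        self,
        *,
        student_outputs: dict[str, torch.Tensor],
        teacher_outputs: dict[str, torch.Tensor],
        targets: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        # match raw logits directly instead of softened distributions
        return mse_loss(student_outputs["logits"], teacher_outputs["logits"])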
Layer Mapping
LayerMap defines correspondences between teacher and student layers.
from rictr import LayerMap
layer_map = LayerMap(
    pairs=[
        ("teacher.layer4", "student.layer2"),  # (teacher, student)
        ("teacher.layer8", "student.layer4"),
    ],
)
# access layer lists
layer_map.teacher_layers # ["teacher.layer4", "teacher.layer8"]
layer_map.student_layers # ["student.layer2", "student.layer4"]
len(layer_map) # 2
With Projectors
When dimensions don't match, add projectors keyed by student layer name.
layer_map = LayerMap(
    pairs=[("teacher.fc", "student.fc")],
    projectors={
        "student.fc": make_projector(256, 512),  # 256 -> 512
    },
)
Projectors
Transform student features to match teacher dimensions.
from rictr import make_projector
from rictr.alignment import LinearProjector, MLPProjector
# linear projection (default)
proj = make_projector(in_features=256, out_features=512)
# MLP with hidden layer
proj = make_projector(
    in_features=256,
    out_features=512,
    hidden=384,  # adds relu hidden layer
)
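In plain PyTorch terms, these roughly correspond to the modules below. This is a sketch based on the descriptions above, not rictr's actual LinearProjector / MLPProjector classes:
import torch.nn as nn

linear_equiv = nn.Linear(256, 512)  # linear projection
mlp_equiv = nn.Sequential(          # MLP with one ReLU hidden layer
    nn.Linear(256, 384),
    nn.ReLU(),
    nn.Linear(384, 512),
)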
Include in Optimizer
Projectors have learnable parameters, so add them to the optimizer alongside the student.
optimizer = torch.optim.Adam(
    list(student.parameters()) + list(proj.parameters()),
    lr=1e-3,
)
CNN Projectors
For convolutional features, use 1x1 convolutions.
import torch.nn as nn

proj = nn.Conv2d(
    in_channels=32,    # student channels
    out_channels=128,  # teacher channels
    kernel_size=1,
)
Feature Extractor
Captures intermediate activations using forward hooks.
from rictr import FeatureExtractor
extractor = FeatureExtractor(
    model=model,
    layer_names=["encoder.layer2", "encoder.layer4"],
    transform=None,  # optional transform for captured features
)
# forward pass populates features
output = model(input)
# access captured features
features = extractor.features # {"encoder.layer2": tensor, ...}
# clear for next batch
extractor.clear()
# IMPORTANT: remove hooks when done
extractor.remove_hooks()
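To guarantee cleanup even when an exception interrupts the loop, one option is to wrap usage in try/finally:
extractor = FeatureExtractor(model=model, layer_names=["encoder.layer2"])
try:
    output = model(input)
    features = extractor.features["encoder.layer2"]
finally:
    extractor.remove_hooks()  # hooks are detached even if the forward pass fails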
Layer Naming
Use dot notation to specify nested modules.
# find layer names
for name, module in model.named_modules():
    print(name, type(module).__name__)

# examples
layer_names = [
    "encoder",                     # direct child
    "encoder.layers.0",            # indexed
    "encoder.layers.0.attention",  # deeply nested
]
Loss Functions
Reusable loss primitives in rictr.losses.
| Function | Description |
|---|---|
| kl_divergence | temperature-scaled KL divergence |
| mse_loss | mean squared error with optional L2 normalization |
| cosine_loss | 1 - cosine similarity |
| smooth_l1_loss | Huber loss (less sensitive to outliers) |
| attention_loss | attention transfer loss |
Usage
from rictr import kl_divergence, mse_loss, cosine_loss, attention_loss
from rictr.losses import smooth_l1_loss, compute_attention_map
# kl divergence with temperature
loss = kl_divergence(student_logits, teacher_logits, temperature=4.0)
# mean squared error with optional normalization
loss = mse_loss(student_feat, teacher_feat, normalize=True)
# cosine similarity loss
loss = cosine_loss(student_feat, teacher_feat)
# smooth l1 (huber)
loss = smooth_l1_loss(student_feat, teacher_feat, beta=1.0)
# attention transfer
s_attn = compute_attention_map(student_features) # [B,C,H,W] -> [B,H,W]
t_attn = compute_attention_map(teacher_features)
loss = attention_loss(s_attn, t_attn, normalize=True)
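compute_attention_map presumably follows the standard attention-transfer definition of Zagoruyko & Komodakis (2017), which collapses channels by summing squared activations; a sketch of that mapping:
import torch

def attention_map_sketch(feat: torch.Tensor) -> torch.Tensor:
    # sum of squared activations over channels: [B, C, H, W] -> [B, H, W]
    return feat.pow(2).sum(dim=1)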
Metrics
Evaluation helpers in rictr.metrics.
from rictr.metrics import LossTracker, accuracy, top_k_accuracy, perplexity
# track loss during training
tracker = LossTracker()
for batch in dataloader:
    loss = distiller.distill_step(batch=batch)
    tracker.update(loss)
print(f"Mean: {tracker.mean():.4f}")
print(f"Min: {tracker.min():.4f}")
print(f"Last: {tracker.last():.4f}")
tracker.reset() # clear tracked values
# classification metrics
acc = accuracy(logits, labels) # top-1
acc5 = top_k_accuracy(logits, labels, k=5) # top-5
# language modeling
ppl = perplexity(logits, labels, ignore_index=-100)
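For reference, perplexity is the exponential of the mean cross-entropy, so the call above should roughly agree with the following (a sketch of the relationship, not rictr's implementation):
import torch
import torch.nn.functional as F

ce = F.cross_entropy(
    logits.view(-1, logits.size(-1)),  # flatten [B, T, V] -> [B*T, V]
    labels.view(-1),
    ignore_index=-100,
)
ppl_manual = torch.exp(ce)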
Utilities
Helper functions for common operations.
Freezing Models
from rictr import freeze, unfreeze
from rictr.utils import freeze_except
freeze(teacher) # all params: requires_grad=False
unfreeze(model) # all params: requires_grad=True
# freeze all except specific layers
freeze_except(model, layer_names=["classifier", "head"])
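freeze is essentially the loop below (a sketch; freeze_except additionally leaves parameters under the named layers trainable):
# roughly what freeze(model) does
for p in model.parameters():
    p.requires_grad = False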
Device Management
from rictr import auto_device, get_device
from rictr.utils import move_batch
device = auto_device() # cuda > mps > cpu
device = get_device(model) # get model's device
batch = move_batch(batch, device) # move all tensors
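Since batches are dicts of tensors, move_batch amounts to roughly the following (a sketch; the non-tensor passthrough is an assumption):
import torch

batch = {
    k: v.to(device) if isinstance(v, torch.Tensor) else v
    for k, v in batch.items()
}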
Model Inspection
from rictr import get_submodule
from rictr.utils import count_parameters
# count parameters
total = count_parameters(model)
trainable = count_parameters(model, trainable_only=True)
# access nested module by path
encoder = get_submodule(model, "transformer.encoder.layer.0")
Integrations
HuggingFace Transformers
from transformers import AutoModelForSequenceClassification
from rictr.integrations import HFOutputAdapter, hf_to_dict
# wrap hugging face model for rictr
hf_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
teacher = HFOutputAdapter(hf_model)
# or convert outputs manually
outputs = hf_model(**inputs)
outputs_dict = hf_to_dict(outputs)
# {"logits": tensor, "hidden_states": tuple, ...}
Accelerate (Multi-GPU)
from accelerate import Accelerator
from rictr.integrations import AcceleratedDistiller
accelerator = Accelerator()
distiller = AcceleratedDistiller(
    teacher=teacher,
    student=student,
    strategy=strategy,
    optimizer=optimizer,
    accelerator=accelerator,
)
dataloader = accelerator.prepare(dataloader)
for batch in dataloader:
    loss = distiller.distill_step(batch=batch)
AcceleratedDistiller wraps Accelerate to handle multi-GPU training and mixed precision.
MLP Distillation Example
Simple classification with SoftTarget.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from rictr import Distiller, Trainer, SoftTarget
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x, **kwargs):
        return self.net(x)
# models
teacher = MLP(64, 256, 10) # large
student = MLP(64, 64, 10) # small
# setup
strategy = SoftTarget(temperature=4.0, alpha=0.5)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
distiller = Distiller(
    teacher=teacher, student=student,
    strategy=strategy, optimizer=optimizer,
)
# collate into {"x": tensor, "labels": tensor} dicts
def collate_fn(batch):
    xs, ys = zip(*batch)
    return {"x": torch.stack(xs), "labels": torch.stack(ys)}

dataset = TensorDataset(torch.randn(1000, 64), torch.randint(0, 10, (1000,)))
dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
# train
trainer = Trainer(distiller)
losses = trainer.train(dataloader, epochs=5)
CNN Distillation Example
Vision model with feature matching.
from rictr import (
    Distiller, Composite, SoftTarget,
    HiddenStateDistillation, LayerMap,
)
import torch.nn as nn
# create 1x1 conv projector for channel alignment
projector = nn.Conv2d(32, 128, kernel_size=1)
layer_map = LayerMap(
    pairs=[("conv2", "conv2")],
    projectors={"conv2": projector},
)

# combine logit and feature distillation
soft = SoftTarget(temperature=4.0)
hidden = HiddenStateDistillation(
    teacher=teacher, student=student,
    layer_map=layer_map,
)
strategy = Composite([
    (soft, 0.5),
    (hidden, 0.5),
])

# include projector in optimizer
optimizer = torch.optim.Adam(
    list(student.parameters()) + list(projector.parameters()),
    lr=1e-3,
)
# train...
# cleanup
hidden.remove_hooks()
Transformer Distillation Example
Distilling transformer models with layer mapping.
from rictr import (
    Distiller, Composite, SoftTarget,
    HiddenStateDistillation, LayerMap, make_projector,
)
# teacher: 256 dim, student: 64 dim
proj_encoder = make_projector(64, 256)
layer_map = LayerMap(
    pairs=[("encoder", "encoder")],
    projectors={"encoder": proj_encoder},
)
soft = SoftTarget(temperature=4.0)
hidden = HiddenStateDistillation(
    teacher=teacher, student=student,
    layer_map=layer_map,
)
strategy = Composite([
    (soft, 0.7),
    (hidden, 0.3),
])
optimizer = torch.optim.AdamW(
    list(student.parameters()) + list(proj_encoder.parameters()),
    lr=5e-4,
    weight_decay=0.01,
)
# train and cleanup ...
hidden.remove_hooks()
Architecture
rictr/
├── core/ # Distiller, Trainer, TrainingState, StepOutput
├── strategies/ # SoftTarget, HiddenStateDistillation, Composite
├── alignment/ # LayerMap, FeatureExtractor, projectors
├── losses/ # kl_divergence, mse_loss, attention_loss, etc.
├── metrics/ # LossTracker, accuracy, perplexity
├── utils/ # freeze, auto_device, count_parameters
└── integrations/ # HFOutputAdapter, AcceleratedDistiller
Public API Exports
from rictr import (
    # core
    Distiller, Trainer, TrainingState, StepOutput,
    # strategies
    DistillationStrategy, SoftTarget, HiddenStateDistillation, Composite,
    # alignment
    LayerMap, FeatureExtractor, make_projector, get_submodule,
    # losses
    kl_divergence, mse_loss, cosine_loss, attention_loss,
    # utils
    freeze, unfreeze, get_device, auto_device,
)