A Knowledge Distillation library for PyTorch.

Compress large models into smaller ones by transferring learned representations.

Key Features

  • Minimal API : small surface area, easy to learn
  • Strategy : change distillation methods without changing much code
  • PyTorch : works with any nn.Module
  • Composable : combine multiple strategies with weighted sums into a custom one
  • HuggingFace & Accelerate : third party integrations

Installation

Requires Python 3.13+ and PyTorch 2.0+.

# using uv (recommended)
uv add rictr

# using pip
pip install rictr

# with Hugging Face support (quotes keep shells like zsh from expanding the brackets)
pip install "rictr[hf]"

# with Accelerate support
pip install "rictr[accelerate]"

Quick Start

Distill a teacher into a student in a few lines:

import torch
from rictr import Distiller, SoftTarget

# your models (must accept **kwargs in forward)
teacher = LargeModel()
student = SmallModel()

# strategy defines what to distill
strategy = SoftTarget(temperature=4.0)

# any standard PyTorch optimizer works
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

# create distiller
distiller = Distiller(
    teacher=teacher,
    student=student,
    strategy=strategy,
    optimizer=optimizer,
)

# training loop
for batch in dataloader:
    loss = distiller.distill_step(batch=batch)
    print(f"Loss: {loss.item():.4f}")

Distiller

The Distiller runs a single distillation step: it freezes the teacher, runs both forward passes, computes the loss, and updates the student.

import torch
from rictr import Distiller

distiller = Distiller(
    teacher=teacher,          # frozen automatically
    student=student,          # trainable model
    strategy=strategy,        # defines how the loss is computed
    optimizer=optimizer,      # updates the student
    device=torch.device("cuda"),  # optional, defaults to cpu when not specified
)

What Distiller Does

  • freezes teacher parameters (requires_grad=False)
  • sets teacher to eval mode
  • moves batch tensors to device automatically
  • runs forward passes with/without gradients
  • calls strategy to compute loss
  • runs backward pass and optimizer step
  • moves models to specified device
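
The steps above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not rictr's actual implementation: the helper name `sketch_distill_step`, the toy linear models, and the MSE-on-logits loss are hypothetical stand-ins for whatever the library does internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def sketch_distill_step(teacher, student, optimizer, batch, device):
    """One distillation step (hypothetical helper, not part of rictr)."""
    # move batch tensors to the target device
    batch = {k: v.to(device) for k, v in batch.items()}

    # teacher forward pass without gradients
    with torch.no_grad():
        teacher_logits = teacher(batch["x"])

    # student forward pass with gradients
    student_logits = student(batch["x"])

    # stand-in loss; a real strategy would be called here
    loss = F.mse_loss(student_logits, teacher_logits)

    # backward pass and optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()


device = torch.device("cpu")
teacher = nn.Linear(8, 3).to(device)
student = nn.Linear(8, 3).to(device)

# freeze the teacher and put it in eval mode
for p in teacher.parameters():
    p.requires_grad_(False)
teacher.eval()

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
batch = {"x": torch.randn(4, 8)}
loss = sketch_distill_step(teacher, student, optimizer, batch, device)
```

Only the student's parameters are handed to the optimizer, so the frozen teacher is never updated.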

What Distiller Does NOT Do

  • run multi-epoch training loops (use Trainer or write your own)
  • load data or build batches
  • log metrics
  • schedule the learning rate

Trainer

For multi-epoch training with callbacks, use the Trainer class.

from rictr import Trainer

def log_callback(state, output):
    if state.step % 100 == 0:
        print(f"Step {state.step}: loss={output.loss:.4f}")

trainer = Trainer(distiller, callbacks=[log_callback])
epoch_losses = trainer.train(dataloader, epochs=10)

Training State

Tracks progress and is passed to callbacks.

from rictr import TrainingState

# available attributes:
state.step       # current step number (int)
state.epoch      # current epoch number (int)
state.best_loss  # best loss seen so far (float)
state.history    # list of StepOutput objects

Step Output

Returned from each training step.

from rictr import StepOutput

output.loss    # loss value (float)
output.step    # step number (int)
output.extras  # optional dict for custom data

Callbacks

Callbacks are functions that receive (state, output).

def checkpoint_callback(state, output):
    if state.step % 1000 == 0:
        torch.save(student.state_dict(), f"ckpt_{state.step}.pt")

def early_stop_callback(state, output):
    if output.loss < 0.01:
        raise StopIteration("Target reached")

trainer = Trainer(distiller, callbacks=[
    log_callback,
    checkpoint_callback,
])
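
To make the (state, output) contract concrete, here is a stdlib-only sketch of how a training loop might invoke callbacks. It is not rictr's Trainer: `SketchState`, `SketchOutput`, and `run_steps` are hypothetical names, and treating StopIteration as a stop signal is an assumption, since the source does not say how Trainer handles it.

```python
from dataclasses import dataclass, field


@dataclass
class SketchOutput:
    """Hypothetical stand-in for StepOutput."""
    loss: float
    step: int


@dataclass
class SketchState:
    """Hypothetical stand-in for TrainingState."""
    step: int = 0
    best_loss: float = float("inf")
    history: list = field(default_factory=list)


def run_steps(losses, callbacks):
    """Feed loss values through the callbacks, stopping on StopIteration."""
    state = SketchState()
    for loss in losses:
        state.step += 1
        output = SketchOutput(loss=loss, step=state.step)
        state.best_loss = min(state.best_loss, loss)
        state.history.append(output)
        try:
            for callback in callbacks:
                callback(state, output)
        except StopIteration:
            break  # assumption: a callback raising StopIteration ends training
    return state


def early_stop_callback(state, output):
    if output.loss < 0.01:
        raise StopIteration("Target reached")


# the third loss drops below 0.01, so the fourth step never runs
state = run_steps([0.5, 0.2, 0.005, 0.001], callbacks=[early_stop_callback])
```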

Model Requirements

Models must meet two requirements to work with rictr.

1. Accept **kwargs in forward()

The distiller passes batch dicts as **kwargs, so your model's forward() must accept extra arguments.

import torch.nn as nn

class MyModel(nn.Module):
    def forward(self, x, **kwargs):  # accept **kwargs
        return self.layers(x)

# or explicitly handle expected keys
class MyModel(nn.Module):
    def forward(self, input_ids, attention_mask=None, **kwargs):
        return self.transformer(input_ids, attention_mask)

2. Return Dict or Tensor

Model outputs are normalized to dicts. Supported formats:

# dict (recommended)
return {"logits": logits, "hidden_states": hidden}

# single tensor -> becomes {"logits": tensor}
return logits

# tuple/list -> first element becomes {"logits": outputs[0]}
return (logits, hidden_states)

# named tuple -> converted via _asdict()
return ModelOutput(logits=logits, hidden=hidden)
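
The normalization rules above can be sketched as a small function. This is an illustration only: `normalize_outputs` is a hypothetical name, not rictr's internal helper, and plain floats stand in for tensors so the example runs without PyTorch.

```python
from collections import namedtuple


def normalize_outputs(outputs):
    """Convert a model's return value to a dict (hypothetical helper)."""
    if isinstance(outputs, dict):
        return dict(outputs)
    # named tuples are also tuples, so check for _asdict() first
    if hasattr(outputs, "_asdict"):
        return dict(outputs._asdict())
    if isinstance(outputs, (tuple, list)):
        return {"logits": outputs[0]}
    # anything else is treated as a single tensor of logits
    return {"logits": outputs}


ModelOutput = namedtuple("ModelOutput", ["logits", "hidden"])

from_dict = normalize_outputs({"logits": 1.0, "hidden_states": 2.0})
from_tuple = normalize_outputs((1.0, 2.0))
from_named = normalize_outputs(ModelOutput(logits=1.0, hidden=2.0))
from_single = normalize_outputs(1.0)
```

The order of the checks matters: a named tuple would otherwise match the tuple branch and lose every field except the first.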

Batch Format

Batches must be dicts with tensor values. The labels key is used by SoftTarget when alpha is set.

# example collate function
def collate_fn(batch):
    xs, ys = zip(*batch)
    return {
        "x": torch.stack(xs),
        "labels": torch.stack(ys),  # required if using alpha
    }

Soft Target Strategy

Logit-based distillation from Hinton et al. (2015). It matches softened probability distributions using a temperature-scaled KL divergence loss.

Pure Distillation

from rictr import SoftTarget

strategy = SoftTarget(temperature=4.0)

With Task Loss Blending

# loss = alpha * task_loss + (1 - alpha) * distill_loss
strategy = SoftTarget(
    temperature=4.0,
    alpha=0.5,  # requires "labels" in batch
)
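
The blending formula in the comment above can be written out in plain PyTorch. This is a sketch, not rictr's implementation: `sketch_soft_target_loss` is a hypothetical name, and multiplying the KL term by T^2 follows the usual Hinton et al. formulation, which the source does not confirm rictr applies.

```python
import torch
import torch.nn.functional as F


def sketch_soft_target_loss(student_logits, teacher_logits, labels=None,
                            temperature=4.0, alpha=None):
    """Temperature-scaled KL loss, optionally blended with a task loss."""
    # soften both distributions with the same temperature
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

    # KL(teacher || student); the T^2 factor is an assumption (see lead-in)
    distill_loss = F.kl_div(
        student_log_probs, teacher_probs, reduction="batchmean"
    ) * temperature ** 2

    if alpha is None:
        return distill_loss  # pure distillation

    task_loss = F.cross_entropy(student_logits, labels)
    return alpha * task_loss + (1 - alpha) * distill_loss


torch.manual_seed(0)
teacher_logits = torch.randn(8, 10)
student_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

pure = sketch_soft_target_loss(student_logits, teacher_logits)
blended = sketch_soft_target_loss(student_logits, teacher_logits, labels, alpha=0.5)
same = sketch_soft_target_loss(teacher_logits, teacher_logits)  # close to zero
```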

Parameters

Parameter    Type          Description
temperature  float         softmax temperature (must be > 0), higher = softer. typical: 2-20
alpha        float | None  task loss weight in [0, 1]. None = pure distillation
task_loss    callable      loss function for labels. default: F.cross_entropy

Temperature Guide

Temperature  Effect
T = 1        original (hard) distribution
T = 2-5      moderate softening (good default)
T = 10-20    very soft, reveals class relationships
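
The effect of temperature is easy to see numerically. The stdlib-only sketch below uses made-up logits and a hypothetical helper, `softmax_with_temperature`, to show how dividing logits by T flattens the distribution.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax over logits divided by the temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [4.0, 2.0, 0.0]  # illustrative values only
for t in (1.0, 4.0, 10.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t:>4}: " + ", ".join(f"{p:.3f}" for p in probs))
```

For these logits the top-class probability falls from about 0.87 at T = 1 to about 0.40 at T = 10, so the student sees more of the teacher's relative preferences among the other classes.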

Hidden State Distillation

Feature-based distillation from FitNets (Romero et al., 2014). It matches intermediate layer activations between teacher and student.

import torch.nn.functional as F
from rictr import HiddenStateDistillation, LayerMap, make_projector

# define layer pairs: (teacher_layer, student_layer)
layer_map = LayerMap(
    pairs=[
        ("encoder.layer4", "encoder.layer2"),
    ],
    projectors={
        # project student dim (256) to teacher dim (512)
        "encoder.layer2": make_projector(256, 512),
    },
)

strategy = HiddenStateDistillation(
    teacher=teacher,
    student=student,
    layer_map=layer_map,
    loss_fn=F.mse_loss,  # default
)

# IMPORTANT: clean up hooks when done!
strategy.remove_hooks()

Parameters

Parameter  Type       Description
teacher    nn.Module  teacher model
student    nn.Module  student model
layer_map  LayerMap   defines layer pairs and projectors
loss_fn    callable   loss for comparing features. default: F.mse_loss

Composite Strategy

Combines multiple strategies with a weighted sum.

from rictr import Composite, SoftTarget, HiddenStateDistillation

soft = SoftTarget(temperature=4.0)
hidden = HiddenStateDistillation(...)

# loss = 0.7 * soft_loss + 0.3 * hidden_loss
strategy = Composite([
    (soft, 0.7),
    (hidden, 0.3),
])
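
The weighted sum itself is simple enough to sketch without the library. `WeightedSum` below is a hypothetical stand-in for Composite, and the toy strategies return plain floats instead of tensors. The keyword-only call signature is borrowed from the Custom Strategies section.

```python
class WeightedSum:
    """Hypothetical stand-in for a composite strategy."""

    def __init__(self, weighted_strategies):
        self.weighted_strategies = list(weighted_strategies)

    def __call__(self, *, student_outputs, teacher_outputs, targets=None):
        total = 0.0
        for strategy, weight in self.weighted_strategies:
            # each sub-strategy sees the same outputs and targets
            total = total + weight * strategy(
                student_outputs=student_outputs,
                teacher_outputs=teacher_outputs,
                targets=targets,
            )
        return total


# toy strategies returning fixed losses
def soft(*, student_outputs, teacher_outputs, targets=None):
    return 2.0


def hidden(*, student_outputs, teacher_outputs, targets=None):
    return 1.0


combined = WeightedSum([(soft, 0.7), (hidden, 0.3)])
loss = combined(student_outputs={}, teacher_outputs={})  # 0.7 * 2.0 + 0.3 * 1.0
```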

Custom Strategies

Implement the DistillationStrategy protocol.

from rictr import DistillationStrategy
import torch

class MyStrategy:
    
    def __call__(
        self,
        *,
        student_outputs: dict[str, torch.Tensor],
        teacher_outputs: dict[str, torch.Tensor],
        targets: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        student_logits = student_outputs["logits"]
        teacher_logits = teacher_outputs["logits"]
        
        # your custom loss computation
        loss = your_loss_function(student_logits, teacher_logits)
        
        return loss  # must be scalar tensor

Layer Mapping

LayerMap defines correspondences between teacher and student layers.

from rictr import LayerMap

layer_map = LayerMap(
    pairs=[
        ("teacher.layer4", "student.layer2"),  # (teacher, student)
        ("teacher.layer8", "student.layer4"),
    ],
)

# access layer lists
layer_map.teacher_layers  # ["teacher.layer4", "teacher.layer8"]
layer_map.student_layers  # ["student.layer2", "student.layer4"]
len(layer_map)            # 2

With Projectors

When dimensions don't match, add projectors keyed by student layer name.

layer_map = LayerMap(
    pairs=[("teacher.fc", "student.fc")],
    projectors={
        "student.fc": make_projector(256, 512),  # 256 -> 512
    },
)

Projectors

Transform student features to match teacher dimensions.

from rictr import make_projector
from rictr.alignment import LinearProjector, MLPProjector

# linear projection (default)
proj = make_projector(in_features=256, out_features=512)

# MLP with hidden layer
proj = make_projector(
    in_features=256,
    out_features=512,
    hidden=384,  # adds relu hidden layer
)

Include in Optimizer

Projectors have learnable parameters, so pass them to the optimizer alongside the student's parameters.

optimizer = torch.optim.Adam(
    list(student.parameters()) + list(proj.parameters()),
    lr=1e-3,
)

CNN Projectors

For convolutional features, use 1x1 convolutions.

proj = nn.Conv2d(
    in_channels=32,   # student channels
    out_channels=128, # teacher channels
    kernel_size=1,
)
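
A quick shape check shows why a 1x1 convolution works as a projector: it changes the channel count and leaves the spatial dimensions alone. The tensor sizes below are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

proj = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=1)

student_feat = torch.randn(2, 32, 16, 16)   # [B, C_student, H, W]
teacher_feat = torch.randn(2, 128, 16, 16)  # [B, C_teacher, H, W]

# channels go from 32 to 128, height and width are preserved
projected = proj(student_feat)
loss = F.mse_loss(projected, teacher_feat)
```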

Feature Extractor

Captures intermediate activations using forward hooks.

from rictr import FeatureExtractor

extractor = FeatureExtractor(
    model=model,
    layer_names=["encoder.layer2", "encoder.layer4"],
    transform=None,  # optional transform for captured features
)

# forward pass populates features
output = model(input)

# access captured features
features = extractor.features  # {"encoder.layer2": tensor, ...}

# clear for next batch
extractor.clear()

# IMPORTANT: remove hooks when done
extractor.remove_hooks()
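
For readers unfamiliar with forward hooks, the mechanism can be sketched with PyTorch's own `register_forward_hook` and `get_submodule`. `HookCapture` is a hypothetical class written for illustration, not rictr's FeatureExtractor, and the toy model is an assumption.

```python
import torch
import torch.nn as nn


class HookCapture:
    """Hypothetical minimal feature capture built on forward hooks."""

    def __init__(self, model, layer_names):
        self.features = {}
        self._handles = []
        for name in layer_names:
            module = model.get_submodule(name)  # resolves dotted names
            handle = module.register_forward_hook(self._make_hook(name))
            self._handles.append(handle)

    def _make_hook(self, name):
        def hook(module, inputs, output):
            # store the layer's output under its dotted name
            self.features[name] = output
        return hook

    def clear(self):
        self.features.clear()

    def remove_hooks(self):
        for handle in self._handles:
            handle.remove()
        self._handles.clear()


model = nn.Sequential()
model.add_module("encoder", nn.Sequential(nn.Linear(8, 16), nn.ReLU()))
model.add_module("head", nn.Linear(16, 4))

capture = HookCapture(model, ["encoder", "encoder.0"])
model(torch.randn(2, 8))
captured_shapes = {k: tuple(v.shape) for k, v in capture.features.items()}

capture.clear()
capture.remove_hooks()
model(torch.randn(2, 8))  # hooks are gone, so nothing is captured
```

Hooks that are never removed keep firing on every forward pass and hold references to the captured tensors, which is why cleanup matters.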

Layer Naming

Use dot notation to specify nested modules.

# find layer names
for name, module in model.named_modules():
    print(name, type(module).__name__)

# examples
layer_names = [
    "encoder",                      # direct child
    "encoder.layers.0",             # indexed
    "encoder.layers.0.attention",   # deeply nested
]

Loss Functions

Reusable loss primitives in rictr.losses.

Function        Description
kl_divergence   temperature-scaled KL divergence
mse_loss        mean squared error with optional L2 normalization
cosine_loss     1 - cosine similarity
smooth_l1_loss  Huber loss (less sensitive to outliers)
attention_loss  attention transfer loss

Usage

from rictr import kl_divergence, mse_loss, cosine_loss, attention_loss
from rictr.losses import smooth_l1_loss, compute_attention_map

# kl divergence with temperature
loss = kl_divergence(student_logits, teacher_logits, temperature=4.0)

# mean squared error with optional normalization
loss = mse_loss(student_feat, teacher_feat, normalize=True)

# cosine similarity loss
loss = cosine_loss(student_feat, teacher_feat)

# smooth l1 (huber)
loss = smooth_l1_loss(student_feat, teacher_feat, beta=1.0)

# attention transfer
s_attn = compute_attention_map(student_features)  # [B,C,H,W] -> [B,H,W]
t_attn = compute_attention_map(teacher_features)
loss = attention_loss(s_attn, t_attn, normalize=True)

Metrics

Evaluation helpers in rictr.metrics.

from rictr.metrics import LossTracker, accuracy, top_k_accuracy, perplexity

# track loss during training
tracker = LossTracker()
for batch in dataloader:
    loss = distiller.distill_step(batch=batch)
    tracker.update(loss)

print(f"Mean: {tracker.mean():.4f}")
print(f"Min:  {tracker.min():.4f}")
print(f"Last: {tracker.last():.4f}")
tracker.reset()  # clear tracked values

# classification metrics
acc = accuracy(logits, labels)           # top-1
acc5 = top_k_accuracy(logits, labels, k=5)  # top-5

# language modeling
ppl = perplexity(logits, labels, ignore_index=-100)
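
As a rough picture of what such helpers compute, here is a stdlib-only sketch. `RunningLoss` is a hypothetical class, not rictr's LossTracker, and the perplexity line assumes the tracked values are mean cross-entropy losses in nats.

```python
import math


class RunningLoss:
    """Hypothetical minimal loss tracker."""

    def __init__(self):
        self.values = []

    def update(self, loss):
        self.values.append(float(loss))

    def mean(self):
        return sum(self.values) / len(self.values)

    def min(self):
        return min(self.values)  # the builtin min, not this method

    def last(self):
        return self.values[-1]

    def reset(self):
        self.values.clear()


tracker = RunningLoss()
for loss in (2.0, 1.5, 1.0):
    tracker.update(loss)

mean_loss = tracker.mean()
# perplexity is the exponential of the mean cross-entropy
ppl = math.exp(mean_loss)
```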

Utilities

Helper functions for common operations.

Freezing Models

from rictr import freeze, unfreeze
from rictr.utils import freeze_except

freeze(teacher)  # all params: requires_grad=False
unfreeze(model)  # all params: requires_grad=True

# freeze all except specific layers
freeze_except(model, layer_names=["classifier", "head"])

Device Management

from rictr import auto_device, get_device
from rictr.utils import move_batch

device = auto_device()    # cuda > mps > cpu
device = get_device(model)  # get model's device

batch = move_batch(batch, device)  # move all tensors

Model Inspection

from rictr import get_submodule
from rictr.utils import count_parameters

# count parameters
total = count_parameters(model)
trainable = count_parameters(model, trainable_only=True)

# access nested module by path
encoder = get_submodule(model, "transformer.encoder.layer.0")

Integrations

HuggingFace Transformers

from transformers import AutoModelForSequenceClassification
from rictr.integrations import HFOutputAdapter, hf_to_dict

# wrap a Hugging Face model for rictr
hf_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
teacher = HFOutputAdapter(hf_model)

# or convert outputs manually
outputs = hf_model(**inputs)
outputs_dict = hf_to_dict(outputs)
# {"logits": tensor, "hidden_states": tuple, ...}

Accelerate (Multi-GPU)

from accelerate import Accelerator
from rictr.integrations import AcceleratedDistiller

accelerator = Accelerator()

distiller = AcceleratedDistiller(
    teacher=teacher,
    student=student,
    strategy=strategy,
    optimizer=optimizer,
    accelerator=accelerator,
)

dataloader = accelerator.prepare(dataloader)
for batch in dataloader:
    loss = distiller.distill_step(batch=batch)

AcceleratedDistiller wraps Accelerate for multi-GPU and mixed-precision training.

MLP Distillation Example

Simple classification with SoftTarget.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from rictr import Distiller, Trainer, SoftTarget

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x, **kwargs):
        return self.net(x)

# models
teacher = MLP(64, 256, 10)  # large
student = MLP(64, 64, 10)   # small

# setup
strategy = SoftTarget(temperature=4.0, alpha=0.5)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
distiller = Distiller(
    teacher=teacher, student=student,
    strategy=strategy, optimizer=optimizer,
)

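# TensorDataset yields (x, y) tuples; collate them into the dict batch rictr expects
def collate_fn(batch):
    xs, ys = zip(*batch)
    return {"x": torch.stack(xs), "labels": torch.stack(ys)}

dataset = TensorDataset(torch.randn(1000, 64), torch.randint(0, 10, (1000,)))
dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)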

# train
trainer = Trainer(distiller)
losses = trainer.train(dataloader, epochs=5)

CNN Distillation Example

Vision model with feature matching.

from rictr import (
    Distiller, Composite, SoftTarget,
    HiddenStateDistillation, LayerMap,
)
import torch
import torch.nn as nn

# create 1x1 conv projector for channel alignment
projector = nn.Conv2d(32, 128, kernel_size=1)

layer_map = LayerMap(
    pairs=[("conv2", "conv2")],
    projectors={"conv2": projector},
)

# combine logit and feature distillation
soft = SoftTarget(temperature=4.0)
hidden = HiddenStateDistillation(
    teacher=teacher, student=student,
    layer_map=layer_map,
)

strategy = Composite([
    (soft, 0.5),
    (hidden, 0.5),
])

# include projector in optimizer
optimizer = torch.optim.Adam(
    list(student.parameters()) + list(projector.parameters()),
    lr=1e-3,
)

# train...

# cleanup
hidden.remove_hooks()

Transformer Distillation Example

Distilling transformer models with layer mapping.

import torch
from rictr import (
    Distiller, Composite, SoftTarget,
    HiddenStateDistillation, LayerMap, make_projector,
)

# teacher: 256 dim, student: 64 dim
proj_encoder = make_projector(64, 256)

layer_map = LayerMap(
    pairs=[("encoder", "encoder")],
    projectors={"encoder": proj_encoder},
)

soft = SoftTarget(temperature=4.0)
hidden = HiddenStateDistillation(
    teacher=teacher, student=student,
    layer_map=layer_map,
)

strategy = Composite([
    (soft, 0.7),
    (hidden, 0.3),
])

optimizer = torch.optim.AdamW(
    list(student.parameters()) + list(proj_encoder.parameters()),
    lr=5e-4,
    weight_decay=0.01,
)

# train and cleanup ...
hidden.remove_hooks()

Architecture

rictr/
├── core/           # Distiller, Trainer, TrainingState, StepOutput
├── strategies/     # SoftTarget, HiddenStateDistillation, Composite
├── alignment/      # LayerMap, FeatureExtractor, projectors
├── losses/         # kl_divergence, mse_loss, attention_loss, etc.
├── metrics/        # LossTracker, accuracy, perplexity
├── utils/          # freeze, auto_device, count_parameters
└── integrations/   # HFOutputAdapter, AcceleratedDistiller

Public API Exports

from rictr import (
    # core
    Distiller, Trainer, TrainingState, StepOutput,
    # strategies
    DistillationStrategy, SoftTarget, HiddenStateDistillation, Composite,
    # alignment
    LayerMap, FeatureExtractor, make_projector, get_submodule,
    # losses
    kl_divergence, mse_loss, cosine_loss, attention_loss,
    # utils
    freeze, unfreeze, get_device, auto_device,
)