PyanNet Model Training for Speaker Diarization¶
This page walks through the steps taken to train the model on the CallHome dataset; for this particular run we used the English subset of CallHome. It covers the model architecture, loss functions, optimization techniques, data augmentation, and the metrics used.
Segmentation Model Configuration Explained¶
Overview¶
Model Architecture¶
- SegmentationModel: A wrapper around the PyanNet segmentation model used for speaker diarization. It inherits from PreTrainedModel so that it is compatible with the HF Trainer, and it can be used to train segmentation models for the speaker diarization task in pyannote.
Forward
forward: Forward pass of the pretrained model.
Parameters:
- waveforms (torch.Tensor): Tensor containing the audio data to be processed by the model.
- labels (torch.Tensor, optional): Ground-truth labels used for training. Defaults to None.
- nb_speakers (int, optional): Number of speakers. Defaults to None.
Returns: A dictionary with the predictions (logits) and, when labels are provided, the loss.
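The snippet below is a minimal usage sketch (not part of the original code, using randomly initialised weights and dummy tensors) showing how the forward pass behaves with and without labels under the default configuration:

```python
import torch
from diarizers import SegmentationModel

model = SegmentationModel()                        # default SegmentationModelConfig()
waveforms = torch.randn(2, 160000)                 # (batch, samples): two 10 s chunks at 16 kHz

# Inference: only the logits are returned.
num_frames = model(waveforms)["logits"].shape[1]

# Training: passing labels additionally returns the permutation-invariant loss.
labels = torch.randint(0, 2, (2, num_frames, 3)).float()  # (batch, num_frames, max_speakers_per_chunk)
outputs = model(waveforms, labels=labels)
print(outputs["loss"], outputs["logits"].shape)
```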
Setup loss function
setup_loss_func: Sets up the loss function, in particular the Powerset conversion used when the model is trained on powerset classes (i.e. self.specifications.powerset is True).
Segmentation Loss Function
segmentation_loss: Computes the permutation-invariant segmentation loss, using nll_loss (negative log-likelihood) in the powerset case or binary_cross_entropy otherwise.
Parameters:
- permutated_prediction (torch.Tensor): Predictions after permutation.
- target (torch.Tensor): Ground-truth labels.
- weight (Optional[torch.Tensor]): Frame-level weights.
Returns: The permutation-invariant segmentation loss (torch.Tensor).
To pyannote
to_pyannote_model: Converts the current model to a pyannote segmentation model for use in pyannote pipelines.
class SegmentationModel(PreTrainedModel):
config_class = SegmentationModelConfig
def __init__(
self,
config=SegmentationModelConfig(),
):
super().__init__(config)
self.model = PyanNet_nn(sincnet={"stride": 10})
self.weigh_by_cardinality = config.weigh_by_cardinality
self.max_speakers_per_frame = config.max_speakers_per_frame
self.chunk_duration = config.chunk_duration
self.min_duration = config.min_duration
self.warm_up = config.warm_up
self.max_speakers_per_chunk = config.max_speakers_per_chunk
self.specifications = Specifications(
problem=Problem.MULTI_LABEL_CLASSIFICATION
if self.max_speakers_per_frame is None
else Problem.MONO_LABEL_CLASSIFICATION,
resolution=Resolution.FRAME,
duration=self.chunk_duration,
min_duration=self.min_duration,
warm_up=self.warm_up,
classes=[f"speaker#{i+1}" for i in range(self.max_speakers_per_chunk)],
powerset_max_classes=self.max_speakers_per_frame,
permutation_invariant=True,
)
self.model.specifications = self.specifications
self.model.build()
self.setup_loss_func()
def forward(self, waveforms, labels=None, nb_speakers=None):
prediction = self.model(waveforms.unsqueeze(1))
batch_size, num_frames, _ = prediction.shape
if labels is not None:
weight = torch.ones(batch_size, num_frames, 1, device=waveforms.device)
warm_up_left = round(self.specifications.warm_up[0] / self.specifications.duration * num_frames)
weight[:, :warm_up_left] = 0.0
warm_up_right = round(self.specifications.warm_up[1] / self.specifications.duration * num_frames)
weight[:, num_frames - warm_up_right :] = 0.0
if self.specifications.powerset:
multilabel = self.model.powerset.to_multilabel(prediction)
permutated_target, _ = permutate(multilabel, labels)
permutated_target_powerset = self.model.powerset.to_powerset(permutated_target.float())
loss = self.segmentation_loss(prediction, permutated_target_powerset, weight=weight)
else:
permutated_prediction, _ = permutate(labels, prediction)
loss = self.segmentation_loss(permutated_prediction, labels, weight=weight)
return {"loss": loss, "logits": prediction}
return {"logits": prediction}
def setup_loss_func(self):
if self.specifications.powerset:
self.model.powerset = Powerset(
len(self.specifications.classes),
self.specifications.powerset_max_classes,
)
def segmentation_loss(
self,
permutated_prediction: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.specifications.powerset:
# `clamp_min` is needed to set non-speech weight to 1.
class_weight = torch.clamp_min(self.model.powerset.cardinality, 1.0) if self.weigh_by_cardinality else None
seg_loss = nll_loss(
permutated_prediction,
torch.argmax(target, dim=-1),
class_weight=class_weight,
weight=weight,
)
else:
seg_loss = binary_cross_entropy(permutated_prediction, target.float(), weight=weight)
return seg_loss
@classmethod
def from_pyannote_model(cls, pretrained):
# Initialize model:
specifications = copy.deepcopy(pretrained.specifications)
# Copy pretrained model hyperparameters:
chunk_duration = specifications.duration
max_speakers_per_frame = specifications.powerset_max_classes
weigh_by_cardinality = False
min_duration = specifications.min_duration
warm_up = specifications.warm_up
max_speakers_per_chunk = len(specifications.classes)
config = SegmentationModelConfig(
chunk_duration=chunk_duration,
max_speakers_per_frame=max_speakers_per_frame,
weigh_by_cardinality=weigh_by_cardinality,
min_duration=min_duration,
warm_up=warm_up,
max_speakers_per_chunk=max_speakers_per_chunk,
)
model = cls(config)
# Copy pretrained model weights:
model.model.hparams = copy.deepcopy(pretrained.hparams)
model.model.sincnet = copy.deepcopy(pretrained.sincnet)
model.model.sincnet.load_state_dict(pretrained.sincnet.state_dict())
model.model.lstm = copy.deepcopy(pretrained.lstm)
model.model.lstm.load_state_dict(pretrained.lstm.state_dict())
model.model.linear = copy.deepcopy(pretrained.linear)
model.model.linear.load_state_dict(pretrained.linear.state_dict())
model.model.classifier = copy.deepcopy(pretrained.classifier)
model.model.classifier.load_state_dict(pretrained.classifier.state_dict())
model.model.activation = copy.deepcopy(pretrained.activation)
model.model.activation.load_state_dict(pretrained.activation.state_dict())
return model
def to_pyannote_model(self):
seg_model = PyanNet(sincnet={"stride": 10})
seg_model.hparams.update(self.model.hparams)
seg_model.sincnet = copy.deepcopy(self.model.sincnet)
seg_model.sincnet.load_state_dict(self.model.sincnet.state_dict())
seg_model.lstm = copy.deepcopy(self.model.lstm)
seg_model.lstm.load_state_dict(self.model.lstm.state_dict())
seg_model.linear = copy.deepcopy(self.model.linear)
seg_model.linear.load_state_dict(self.model.linear.state_dict())
seg_model.classifier = copy.deepcopy(self.model.classifier)
seg_model.classifier.load_state_dict(self.model.classifier.state_dict())
seg_model.activation = copy.deepcopy(self.model.activation)
seg_model.activation.load_state_dict(self.model.activation.state_dict())
seg_model.specifications = self.specifications
return seg_model
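For reference, here is a brief sketch of converting between the pyannote and Hugging Face formats; it assumes you have accepted the terms of the gated pyannote/segmentation-3.0 checkpoint and have a valid access token:

```python
from pyannote.audio import Model
from diarizers import SegmentationModel

# Load the pretrained pyannote checkpoint (gated model, token required).
pretrained = Model.from_pretrained("pyannote/segmentation-3.0", use_auth_token=True)

# Wrap it as an HF-Trainer-compatible model, then convert back for use in pyannote pipelines.
hf_model = SegmentationModel.from_pyannote_model(pretrained)
pyannote_model = hf_model.to_pyannote_model()
```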
Segmentation Model Configuration
SegmentationModelConfig
Configuration class for the segmentation model, specifying parameters such as the chunk duration, the maximum number of speakers per frame and per chunk, the minimum duration, and the warm-up period.
class SegmentationModelConfig(PretrainedConfig):
model_type = "pyannet"
def __init__(
self,
chunk_duration=10,
max_speakers_per_frame=2,
max_speakers_per_chunk=3,
min_duration=None,
warm_up=(0.0, 0.0),
weigh_by_cardinality=False,
**kwargs,
):
super().__init__(**kwargs)
self.chunk_duration = chunk_duration
self.max_speakers_per_frame = max_speakers_per_frame
self.max_speakers_per_chunk = max_speakers_per_chunk
self.min_duration = min_duration
self.warm_up = warm_up
self.weigh_by_cardinality = weigh_by_cardinality
# For now, the model handles only 16000 Hz sampling rate
self.sample_rate = 16000
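For illustration, the configuration can also be instantiated with explicit values (the ones below simply restate the defaults) and passed to the model:

```python
# Illustrative configuration: 10-second chunks, up to 3 speakers per chunk,
# at most 2 simultaneous speakers per frame (which enables powerset training).
config = SegmentationModelConfig(
    chunk_duration=10,
    max_speakers_per_chunk=3,
    max_speakers_per_frame=2,
    weigh_by_cardinality=False,
)
model = SegmentationModel(config)
```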
Loss Functions¶
Binary Cross-Entropy¶
- Used when the model does not use the powerset approach.
- Computes the binary cross-entropy loss between the predicted and actual speaker activity.
Negative Log-Likelihood (NLL) Loss¶
- Used when the model uses the powerset approach.
- Computes the NLL loss considering class weights if specified.
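As a rough illustration of the two cases, here is a sketch with dummy tensors, assuming the pyannote.audio.utils.loss helpers used by segmentation_loss above:

```python
import torch
from pyannote.audio.utils.loss import binary_cross_entropy, nll_loss

batch, frames, speakers = 2, 100, 3
num_powerset_classes = 7  # {}, 3 singletons and 3 pairs, for 3 speakers with at most 2 active

# Powerset case: log-probabilities over powerset classes vs. powerset class indices.
log_probs = torch.log_softmax(torch.randn(batch, frames, num_powerset_classes), dim=-1)
powerset_target = torch.randint(0, num_powerset_classes, (batch, frames))
loss_powerset = nll_loss(log_probs, powerset_target)

# Multi-label case: per-speaker activity probabilities vs. binary targets.
probs = torch.sigmoid(torch.randn(batch, frames, speakers))
multilabel_target = torch.randint(0, 2, (batch, frames, speakers)).float()
loss_multilabel = binary_cross_entropy(probs, multilabel_target)
```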
Optimization Techniques¶
Batch Size¶
- The batch size is the number of samples fed to the model at each iteration of the training process. It can be adjusted to optimize the performance of your model.
Learning Rate¶
- The learning rate is an optimization tuning parameter that determines the step size taken at each iteration while moving towards a minimum of the loss function.
Training Epochs¶
- An epoch is one complete pass through the entire training dataset: the model is exposed to all the training examples and updates its parameters based on the patterns it learns. In our case we trained with 5, 10 and 20 epochs and found that the diarization error rate remained essentially constant at 'der': 0.23994926057695026.
Warm-up¶
- The warm-up period allows the model to adjust at the beginning of each chunk, ensuring the central part of the chunk is more accurate.
- The warm-up is applied to both the left and right parts of each chunk.
Permutation-Invariant Training¶
- This technique permutes predictions and targets to find the optimal alignment, ensuring the loss computation is invariant to the order of speakers.
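A small sketch of this idea, using the permutate helper from pyannote (the same one called in forward above) on dummy tensors:

```python
import torch
from pyannote.audio.utils.permutation import permutate

# Targets and predictions containing the same activity, but with speaker columns in a different order.
target = torch.randint(0, 2, (1, 100, 3)).float()   # (batch, frames, speakers)
prediction = target[:, :, [2, 0, 1]]                  # same content, permuted speakers

# permutate reorders the columns of `prediction` so that they best match `target`.
permutated_prediction, permutations = permutate(target, prediction)
assert torch.equal(permutated_prediction, target)
print(permutations)  # [(1, 2, 0)]
```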
Data Augmentation Methods¶
- For our case this is done using the DataCollator class (implemented further below), which collates the examples into a batch and dynamically pads the target labels.
- Pads the target labels so that they all have the same shape.
- Pads with zeros when the number of speakers in a chunk is less than the maximum number of speakers per chunk.
Preprocessing Steps¶
- Preprocessing steps like random overlap and fixed overlap during chunking can be considered a form of augmentation as they provide varied inputs to the model.
The Preprocess class, shown below, handles these preprocessing steps and is responsible for preparing the input data.
class Preprocess:
def __init__(
self,
config,
):
self.chunk_duration = config.chunk_duration
self.max_speakers_per_frame = config.max_speakers_per_frame
self.max_speakers_per_chunk = config.max_speakers_per_chunk
self.min_duration = config.min_duration
self.warm_up = config.warm_up
self.sample_rate = config.sample_rate
self.model = SegmentationModel(config).to_pyannote_model()
# Get the number of frames associated to a chunk:
_, self.num_frames_per_chunk, _ = self.model(
torch.rand((1, int(self.chunk_duration * self.sample_rate)))
).shape
def get_labels_in_file(self, file):
file_labels = []
for i in range(len(file["speakers"][0])):
if file["speakers"][0][i] not in file_labels:
file_labels.append(file["speakers"][0][i])
return file_labels
def get_segments_in_file(self, file, labels):
file_annotations = []
for i in range(len(file["timestamps_start"][0])):
start_segment = file["timestamps_start"][0][i]
end_segment = file["timestamps_end"][0][i]
label = labels.index(file["speakers"][0][i])
file_annotations.append((start_segment, end_segment, label))
dtype = [("start", "<f4"), ("end", "<f4"), ("labels", "i1")]
annotations = np.array(file_annotations, dtype)
return annotations
def get_chunk(self, file, start_time):
sample_rate = file["audio"][0]["sampling_rate"]
assert sample_rate == self.sample_rate
end_time = start_time + self.chunk_duration
start_frame = math.floor(start_time * sample_rate)
num_frames_waveform = math.floor(self.chunk_duration * sample_rate)
end_frame = start_frame + num_frames_waveform
waveform = file["audio"][0]["array"][start_frame:end_frame]
labels = self.get_labels_in_file(file)
file_segments = self.get_segments_in_file(file, labels)
chunk_segments = file_segments[(file_segments["start"] < end_time) & (file_segments["end"] > start_time)]
# compute frame resolution:
# resolution = self.chunk_duration / self.num_frames_per_chunk
# discretize chunk annotations at model output resolution
step = self.model.receptive_field.step
half = 0.5 * self.model.receptive_field.duration
# discretize chunk annotations at model output resolution
start = np.maximum(chunk_segments["start"], start_time) - start_time - half
start_idx = np.maximum(0, np.round(start / step)).astype(int)
# start_idx = np.floor(start / resolution).astype(int)
end = np.minimum(chunk_segments["end"], end_time) - start_time - half
end_idx = np.round(end / step).astype(int)
# end_idx = np.ceil(end / resolution).astype(int)
# get list and number of labels for current scope
labels = list(np.unique(chunk_segments["labels"]))
num_labels = len(labels)
# initial frame-level targets
y = np.zeros((self.num_frames_per_chunk, num_labels), dtype=np.uint8)
# map labels to indices
mapping = {label: idx for idx, label in enumerate(labels)}
for start, end, label in zip(start_idx, end_idx, chunk_segments["labels"]):
mapped_label = mapping[label]
y[start : end + 1, mapped_label] = 1
return waveform, y, labels
def get_start_positions(self, file, overlap, random=False):
sample_rate = file["audio"][0]["sampling_rate"]
assert sample_rate == self.sample_rate
file_duration = len(file["audio"][0]["array"]) / sample_rate
start_positions = np.arange(0, file_duration - self.chunk_duration, self.chunk_duration * (1 - overlap))
if random:
nb_samples = int(file_duration / self.chunk_duration)
start_positions = np.random.uniform(0, file_duration, nb_samples)
return start_positions
def __call__(self, file, random=False, overlap=0.0):
new_batch = {"waveforms": [], "labels": [], "nb_speakers": []}
if random:
start_positions = self.get_start_positions(file, overlap, random=True)
else:
start_positions = self.get_start_positions(file, overlap)
for start_time in start_positions:
waveform, target, label = self.get_chunk(file, start_time)
new_batch["waveforms"].append(waveform)
new_batch["labels"].append(target)
new_batch["nb_speakers"].append(label)
return new_batch
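A minimal sketch of how this class would be applied to a Hugging Face dataset (assumed usage; the column names follow the diarizers-community/callhome layout and the overlap value is illustrative):

```python
from datasets import load_dataset

dataset = load_dataset("diarizers-community/callhome", "eng", split="data")
preprocess = Preprocess(SegmentationModelConfig())

# Process one file at a time and replace the original columns with chunked examples.
processed = dataset.map(
    lambda file: preprocess(file, random=False, overlap=0.25),
    batched=True,
    batch_size=1,
    remove_columns=dataset.column_names,
)
```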
Metrics and Trainer¶
- Initializes the Metrics class for evaluation.
- Configures the Trainer with the model, training arguments, datasets, data collator, and metrics.
- For the metrics we use the Diarization Error Rate (DER), False Alarm Rate, Missed Detection Rate and Speaker Confusion Rate, implemented in the Metrics class below.
import numpy as np
import torch
from pyannote.audio.torchmetrics import (DiarizationErrorRate, FalseAlarmRate,
MissedDetectionRate,
SpeakerConfusionRate)
from pyannote.audio.utils.powerset import Powerset
class Metrics:
"""Metric class used by the HF trainer to compute speaker diarization metrics."""
def __init__(self, specifications) -> None:
"""init method
Args:
specifications (Specifications): specifications attribute from a SegmentationModel.
"""
self.powerset = specifications.powerset
self.classes = specifications.classes
self.powerset_max_classes = specifications.powerset_max_classes
self.model_powerset = Powerset(
len(self.classes),
self.powerset_max_classes,
)
self.metrics = {
"der": DiarizationErrorRate(0.5),
"confusion": SpeakerConfusionRate(0.5),
"missed_detection": MissedDetectionRate(0.5),
"false_alarm": FalseAlarmRate(0.5),
}
def __call__(self, eval_pred):
logits, labels = eval_pred
if self.powerset:
predictions = self.model_powerset.to_multilabel(torch.tensor(logits))
else:
predictions = torch.tensor(logits)
labels = torch.tensor(labels)
predictions = torch.transpose(predictions, 1, 2)
labels = torch.transpose(labels, 1, 2)
metrics = {"der": 0, "false_alarm": 0, "missed_detection": 0, "confusion": 0}
metrics["der"] += self.metrics["der"](predictions, labels).cpu().numpy()
metrics["false_alarm"] += self.metrics["false_alarm"](predictions, labels).cpu().numpy()
metrics["missed_detection"] += self.metrics["missed_detection"](predictions, labels).cpu().numpy()
metrics["confusion"] += self.metrics["confusion"](predictions, labels).cpu().numpy()
return metrics
class DataCollator:
"""Data collator that will dynamically pad the target labels to have max_speakers_per_chunk"""
def __init__(self, max_speakers_per_chunk) -> None:
self.max_speakers_per_chunk = max_speakers_per_chunk
def __call__(self, features):
"""_summary_
Args:
features (_type_): _description_
Returns:
_type_: _description_
"""
batch = {}
speakers = [f["nb_speakers"] for f in features]
labels = [f["labels"] for f in features]
batch["labels"] = self.pad_targets(labels, speakers)
batch["waveforms"] = torch.stack([f["waveforms"] for f in features])
return batch
def pad_targets(self, labels, speakers):
"""
labels: list of (num_frames, num_speakers) target tensors, one per chunk.
speakers: list of the speaker labels present in each chunk.
Returns:
torch.Tensor:
Collated target tensor of shape (batch_size, num_frames, max_speakers_per_chunk).
If one chunk has more than max_speakers_per_chunk speakers, we keep
the max_speakers_per_chunk most talkative ones. If it has fewer, we pad with
zeros (artificial inactive speakers).
"""
targets = []
for i in range(len(labels)):
label = speakers[i]
target = labels[i].numpy()
num_speakers = len(label)
if num_speakers > self.max_speakers_per_chunk:
indices = np.argsort(-np.sum(target, axis=0), axis=0)
target = target[:, indices[: self.max_speakers_per_chunk]]
elif num_speakers < self.max_speakers_per_chunk:
target = np.pad(
target,
((0, 0), (0, self.max_speakers_per_chunk - num_speakers)),
mode="constant",
)
targets.append(target)
return torch.from_numpy(np.stack(targets))
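To show how these pieces fit together, here is an illustrative sketch of configuring the Trainer (a simplified stand-in for the train_segmentation.py script described below, assuming `processed` is the chunked dataset from the preprocessing sketch above and mirroring the script's hyperparameters):

```python
from transformers import Trainer, TrainingArguments

config = SegmentationModelConfig()
model = SegmentationModel(config)

# Illustrative train/eval split of the preprocessed chunks.
dataset = processed.with_format("torch").train_test_split(test_size=0.1)

training_args = TrainingArguments(
    output_dir="./speaker-segmentation-fine-tuned-callhome-eng",
    learning_rate=1e-3,
    num_train_epochs=20,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=DataCollator(max_speakers_per_chunk=config.max_speakers_per_chunk),
    compute_metrics=Metrics(model.specifications),
)
trainer.train()
```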
Training Script¶
- The script train_segmentation.py can be used to pre-process a diarization dataset and subsequently fine-tune the pyannote segmentation model. In the following example, we fine-tune the segmentation model on the English subset of the CallHome dataset, a conversational dataset between native speakers:
!python3 train_segmentation.py \
--dataset_name=diarizers-community/callhome \
--dataset_config_name=eng \
--split_on_subset=data \
--model_name_or_path=pyannote/segmentation-3.0 \
--output_dir=./speaker-segmentation-fine-tuned-callhome-eng \
--do_train \
--do_eval \
--learning_rate=1e-3 \
--num_train_epochs=20 \
--lr_scheduler_type=cosine \
--per_device_train_batch_size=32 \
--per_device_eval_batch_size=32 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--preprocessing_num_workers=2 \
--dataloader_num_workers=2 \
--logging_steps=100 \
--load_best_model_at_end \
--push_to_hub
Evaluation Script¶
The script test_segmentation.py can be used to evaluate a fine-tuned model on a diarization dataset. In the following example, we evaluate the fine-tuned model from the previous step on the test split of the CallHome English dataset:
!python3 test_segmentation.py \
--dataset_name=diarizers-community/callhome \
--dataset_config_name=eng \
--split_on_subset=data \
--test_split_name=test \
--model_name_or_path=diarizers-community/speaker-segmentation-fine-tuned-callhome-eng \
--preprocessing_num_workers=2 \
--evaluate_with_pipeline
Sample Output
Inference with Pyannote¶
- The fine-tuned segmentation model can easily be loaded into the pyannote speaker diarization pipeline for inference. To do so, we load the pre-trained speaker diarization pipeline, and subsequently override the segmentation model with our fine-tuned checkpoint:
from diarizers import SegmentationModel
from pyannote.audio import Pipeline
from datasets import load_dataset
import torch
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# load the pre-trained pyannote pipeline
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
pipeline.to(device)
# replace the segmentation model with your fine-tuned one
model = SegmentationModel.from_pretrained("diarizers-community/speaker-segmentation-fine-tuned-callhome-eng")
model = model.to_pyannote_model()
pipeline._segmentation.model = model.to(device)
# load dataset example
dataset = load_dataset("diarizers-community/callhome", "eng", split="data")
sample = dataset[0]["audio"]
# pre-process inputs
sample["waveform"] = torch.from_numpy(sample.pop("array")[None, :]).to(device, dtype=model.dtype)
sample["sample_rate"] = sample.pop("sampling_rate")
# perform inference
diarization = pipeline(sample)
# dump the diarization output to disk using RTTM format
with open("audio.rttm", "w") as rttm:
diarization.write_rttm(rttm)
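The returned diarization is a pyannote Annotation, so the speaker turns can also be inspected directly:

```python
# Print each speaker turn from the diarization output.
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s")
```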