Finetuning XLS-R(Wav2Vec2) on OpenSLR Nepali ASR Dataset

30 minute read

Published:

This blog/tutorial on finetuning XLS-R on OpenSLR’s Nepali ASR dataset is adapted from Huggingface’s blog on “Fine-tuning XLS-R for Multi-Lingual ASR with 🤗 Transformers.” The existing XLS-R model for Nepali was actually finetuned on OpenSLR’s Nepali Text to Speech dataset, which contains high-quality recordings from only one speaker. It is therefore doubtful that the model would work on utterances made in the wild, as we would normally speak. To avoid this problem, speech samples recorded in real life, with natural characteristics like noise and pauses, should be used, and the Large Nepali ASR training data set is the one such dataset we have for the Nepali language. However, the unavailability of computational resources has kept me from completing the finetuning and hyperparameter optimization of this remarkable model on my mother tongue; so far I get a 21% word error rate on the test split I created. Anyone who would like to contribute to finishing this training is heartily welcome. Let’s begin the implementation journey.

Introduction to Wav2Vec2

Wav2Vec2 is a pretrained model for Automatic Speech Recognition (ASR) and was released in September 2020 by Alexei Baevski, Michael Auli, and Alex Conneau. Soon after the superior performance of Wav2Vec2 was demonstrated on one of the most popular English datasets for ASR, called LibriSpeech, Facebook AI presented a multi-lingual version of Wav2Vec2, called XLSR. XLSR stands for cross-lingual speech representations and refers to the model’s ability to learn speech representations that are useful across multiple languages.

XLSR’s successor, simply called XLS-R (referring to “XLM-R for Speech”), was released in November 2021 by Arun Babu, Changhan Wang, Andros Tjandra, et al. XLS-R used almost half a million hours of audio data in 128 languages for self-supervised pre-training and comes in sizes ranging from 300 million up to two billion parameters. You can find the pretrained checkpoints on the 🤗 Hub.

Similar to BERT’s masked language modeling objective, XLS-R learns contextualized speech representations by randomly masking feature vectors before passing them to a transformer network during self-supervised pre-training (i.e. diagram on the left below).

For fine-tuning, a single linear layer is added on top of the pre-trained network to train the model on labeled data of audio downstream tasks such as speech recognition, speech translation and audio classification (i.e. diagram on the right below).

Figure: wav2vec2 / XLS-R architecture (left: self-supervised pre-training with masked feature vectors; right: fine-tuning with a linear layer on top).

XLS-R shows impressive improvements over previous state-of-the-art results on speech recognition, speech translation, and speaker/language identification; cf. Tables 3-6, Tables 7-10, and Tables 11-12, respectively, of the official paper.

Notebook Setup

In this notebook, we will give an in-detail explanation of how XLS-R - more specifically the pre-trained checkpoint Wav2Vec2-XLS-R-300M - can be fine-tuned for ASR.

XLS-R is fine-tuned using Connectionist Temporal Classification (CTC), which is an algorithm that is used to train neural networks for sequence-to-sequence problems, such as ASR and handwriting recognition.

I highly recommend reading the well-written blog post Sequence Modeling with CTC (2017) by Awni Hannun.
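To make the objective concrete, here is a minimal sketch of PyTorch’s built-in CTC loss (my own addition, not from the original notebook); the shapes are made up, but they mirror the ASR setting where the number of input frames is much larger than the number of target tokens and index 0 plays the role of CTC’s blank token:

import torch

# Minimal CTC sketch: 50 input frames must be aligned to a 10-token target,
# with index 0 reserved for the blank token.
T, N, C, S = 50, 2, 30, 10                        # frames, batch, classes (incl. blank), target length
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)
targets = torch.randint(1, C, (N, S))             # labels never use the blank index
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)

ctc_loss = torch.nn.CTCLoss(blank=0)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
print(loss.item())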

First, let’s try to get a good GPU in our Colab! With Google Colab’s free version it’s sadly becoming much harder to get access to a good GPU. With Google Colab Pro, however, one should easily get either a V100 or P100 GPU.

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)
Sun Oct 23 02:39:42 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Before we start, let’s install mutually compatible versions of datasets, transformers, and pytorch. We also need torchaudio to load audio files and jiwer to evaluate our fine-tuned model using the word error rate (WER) metric \({}^1\).

!pip --no-cache-dir install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
!pip --no-cache-dir install transformers==4.23.1
!pip --no-cache-dir install datasets==2.6.0
!pip --no-cache-dir install evaluate==0.3.0
!pip --no-cache-dir install jiwer

If you are training in an environment like Colab where storage space is limited or the environment is temporary, it is strongly recommended to save your model checkpoints somewhere else. Since we’re working with Huggingface’s transformers, why not the Huggingface Hub? The 🤗 Hub has integrated version control, so you can be sure that no model checkpoint is lost during training.

To do so you have to store your authentication token from the Hugging Face website (sign up here if you haven’t already!)

from huggingface_hub import notebook_login

notebook_login()
Login successful
Your token has been saved to /root/.huggingface/token

The Huggingface Hub requires git-lfs, since it is needed to upload large model files.

!apt install git-lfs

\({}^1\) In the paper, the model was evaluated using the phoneme error rate (PER), but by far the most common metric in ASR is the word error rate (WER). To keep this notebook as general as possible we decided to evaluate the model using WER.

Prepare Data, Tokenizer, Feature Extractor

ASR models transcribe speech to text, which means that we both need a feature extractor that processes the speech signal to the model’s input format, e.g. a feature vector, and a tokenizer that processes the model’s output format to text.

In 🤗 Transformers, the XLS-R model is thus accompanied by both a tokenizer, called Wav2Vec2CTCTokenizer, and a feature extractor, called Wav2Vec2FeatureExtractor.

Let’s start by creating the tokenizer to decode the predicted output classes to the output transcription.

Create Wav2Vec2CTCTokenizer

A pre-trained XLS-R model maps the speech signal to a sequence of context representations as illustrated in the figure above. However, for speech recognition the model has to map this sequence of context representations to its corresponding transcription, which means that a linear layer has to be added on top of the transformer block (shown in yellow in the diagram above). This linear layer is used to classify each context representation into a token class, analogous to how, e.g., a linear layer is added on top of BERT’s embeddings for further classification after pretraining - cf. the ‘BERT’ section of this blog post.

The output size of this layer corresponds to the number of tokens in the vocabulary, which does not depend on XLS-R’s pretraining task, but only on the labeled dataset used for fine-tuning. So in the first step, we will take a look at the chosen OpenSLR Nepali ASR dataset and define a vocabulary based on the transcriptions.

Dataset Preparation

First, let’s go to OpenSLR’s official website for Nepali ASR. This data set contains transcribed audio data for Nepali. It consists of zip archives containing FLAC (an audio file format) files, and a TSV file. The file utt_spk_text.tsv contains a FileID, an anonymized UserID, and the transcription of the audio in that file. The data set has been manually quality checked, but there might still be errors.
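If you download the raw archives yourself, the TSV index can be inspected with pandas before any audio is touched (a sketch only; the column names below are my own labels based on the description above):

import pandas as pd

# utt_spk_text.tsv has no header row; the three columns are the FileID,
# the anonymized UserID, and the transcription of the audio.
tsv = pd.read_csv("utt_spk_text.tsv", sep="\t", header=None,
                  names=["utterance_id", "speaker_id", "transcription"])
print(tsv.head())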

Since downloading, extracting, and preprocessing take a lot of work and time, I’ve uploaded the preprocessed dataset along with the original dataset to the Huggingface Hub so that we can interact with it through Huggingface’s Datasets API. In summary, the following are the steps I took to prepare the dataset:

  1. Download the dataset from https://www.openslr.org/54/
  2. Extract the zip files containing flac audio files
  3. Load the audio files and apply the following preprocessing function to each of them (a usage sketch follows after this list)
     import torchaudio
     # The pretrained Wav2Vec2 model was trained on speeches with sample rate 16KHz
     SAMPLING_RATE = 16000
     def process_audio_file(orig_path, new_path):
         """Read and process file in `orig_path` and save it to `new_path`"""
         waveform, sampling_rate = torchaudio.load(orig_path)
         if sampling_rate != SAMPLING_RATE:
             # Resample to 16KHz if the audio originally has different sampling rate.
             waveform = torchaudio.functional.resample(waveform, sampling_rate, SAMPLING_RATE)
         #  Though the ASR models should be resilient to silences at the ends of audio,
         # the leading and trailing silences are removed using Voice Activity Detection(VAD)
         # implemented in torchaudio with default parameters to reduce the demands 
         # for computational resources
         waveform = torchaudio.functional.vad(waveform, sample_rate=SAMPLING_RATE)
         # save the processed audio files to new location
         torchaudio.save(new_path, waveform, sample_rate=SAMPLING_RATE)
    
  4. The processed audio files are again zipped in a similar fashion to the original OpenSLR Nepali ASR dataset.
  5. The zip files and the TSV file containing transcript and audio path mappings are uploaded to “spktsagar/openslr-nepali-asr-cleaned/data”.
  6. A dataset loading script, which can be found here, was developed and pushed to the same repo.
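As referenced in step 3, here is a rough sketch of how process_audio_file might be applied over the extracted archive (the directory names are illustrative, not the actual layout I used):

from pathlib import Path

# Walk the extracted FLAC files and write processed copies to a parallel
# directory tree, mirroring the original layout.
src_root, dst_root = Path("asr_nepali/data"), Path("asr_nepali_cleaned/data")
for orig_path in src_root.rglob("*.flac"):
    new_path = dst_root / orig_path.relative_to(src_root)
    new_path.parent.mkdir(parents=True, exist_ok=True)
    process_audio_file(str(orig_path), str(new_path))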

Now, we load the dataset with the datasets API. Since the dataset does not come with train/val/test splits, the whole dataset will be downloaded and split into train and validation sets later. When I finetuned the Wav2Vec2 model on the preprocessed, cleaned dataset, the model was not learning (WER was always 1 on the validation/test set). Anyone who would like to debug the preprocessed dataset and finetune on it is heartily welcome.

from datasets import load_dataset

DATASET_TYPE = 'original'  # set to 'original' or 'cleaned' to load the original or cleaned version of the OpenSLR dataset

dataset = load_dataset("spktsagar/openslr-nepali-asr-cleaned", name=DATASET_TYPE, split='train')
dataset
Dataset({
    features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
    num_rows: 157905
})

You can see in the info that the dataset contains the following fields:

  • utterance_id
  • speaker_id
  • utterance
  • transcription
  • num_frames

For the description of them, please read the dataset card here.

Text Preprocessing

Although the transcriptions are fairly clean, some of them contain characters other than Nepali. We will remove those examples from our dataset.

import string

def check_english_chars(text):
    """Returns if this text contains any english characters"""
    return any([c in text for c in string.ascii_letters])

# Use dataset filter to remove examples with above function
dataset = dataset.filter(
    lambda ex: not check_english_chars(ex),
    input_columns=['transcription',],
    with_indices=False, batched=False, batch_size=0,
)
dataset
Dataset({
    features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
    num_rows: 157904
})

Let’s see the list of all the characters we have now in the dataset.

''.join(sorted(set([c for s in dataset['transcription'] for c in s])))
' !%.;?\\\xa0ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ािीुूृेैॉॊोौ्ॐ॑ॠ।०१२३४५६७८९॰\u200c\u200d\u200e\u200f“'

You can see there are some characters and symbols that are not used in Nepali. We will remove those from the transcriptions.

remove_chars = ['!', '%', '.', ';', '?', '\\', '।', '\xa0', '\u200c', '\u200d', '\u200e', '\u200f', '“']

def remove_special_characters(row):
    row['transcription'] = ''.join(
        [c for c in row['transcription'] if c not in remove_chars]
    ).strip()
    return row

dataset = dataset.map(remove_special_characters)
''.join(sorted(set([c for s in dataset['transcription'] for c in s])))
' ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ािीुूृेैॉॊोौ्ॐ॑ॠ०१२३४५६७८९॰'

In CTC, it is common to classify speech chunks into letters, so we will do the same here. Let’s extract all distinct letters of the dataset and build our vocabulary from this set of letters.

We write a mapping function that concatenates all transcriptions into one long transcription and then transforms the string into a set of chars. It is important to pass the argument batched=True to the map(...) function so that the mapping function has access to all transcriptions at once.

def extract_all_chars(batch):
    all_text = " ".join(batch["transcription"])
    vocab = list(set(all_text))
    return {"vocab": [vocab]}

vocab_all = dataset.map(extract_all_chars, batched=True,
                        batch_size=-1, keep_in_memory=True,
                        remove_columns=dataset.column_names)
vocab_list = sorted(list(set(vocab_all["vocab"][0])))

Finally, we also add a padding token that corresponds to CTC’s “blank token”. The “blank token” is a core component of the CTC algorithm. For more information, please take a look at the “Alignment” section here.

UNK_TOKEN = '__UNK__'
PAD_TOKEN = '__PAD__'

vocab_list = [PAD_TOKEN, UNK_TOKEN, *vocab_list]

Now, we create an enumerated dictionary so that we have token to id mapping.

vocab_dict = {v: k for k, v in enumerate(vocab_list)}

# for printing vocab in single line
', '.join([f"{k}: {v}" for k, v in (vocab_dict.items())])
'__PAD__: 0, __UNK__: 1,  : 2, ँ: 3, ं: 4, ः: 5, अ: 6, आ: 7, इ: 8, ई: 9, उ: 10, ऊ: 11, ऋ: 12, ए: 13, ऐ: 14, ओ: 15, औ: 16, क: 17, ख: 18, ग: 19, घ: 20, ङ: 21, च: 22, छ: 23, ज: 24, झ: 25, ञ: 26, ट: 27, ठ: 28, ड: 29, ढ: 30, ण: 31, त: 32, थ: 33, द: 34, ध: 35, न: 36, प: 37, फ: 38, ब: 39, भ: 40, म: 41, य: 42, र: 43, ऱ: 44, ल: 45, व: 46, श: 47, ष: 48, स: 49, ह: 50, ़: 51, ा: 52, ि: 53, ी: 54, ु: 55, ू: 56, ृ: 57, े: 58, ै: 59, ॉ: 60, ॊ: 61, ो: 62, ौ: 63, ्: 64, ॐ: 65, ॑: 66, ॠ: 67, ०: 68, १: 69, २: 70, ३: 71, ४: 72, ५: 73, ६: 74, ७: 75, ८: 76, ९: 77, ॰: 78'

To make it clearer that " " has its own token class, we give it a more visible character |. In addition, the “unknown” token added above lets the model later deal with characters it did not encounter in the training set.

WORD_DELIMITER = '|'

vocab_dict[WORD_DELIMITER] = vocab_dict[" "]
del vocab_dict[" "]
len(vocab_dict)
79

Cool, now our vocabulary is complete and consists of 79 tokens, which means that the linear layer that we will add on top of the pretrained XLS-R checkpoint will have an output dimension of 79.

Let’s now save the vocabulary as a json file.

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In a final step, we use the json file to load the vocabulary into an instance of the Wav2Vec2CTCTokenizer class.

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token=UNK_TOKEN, pad_token=PAD_TOKEN, word_delimiter_token=WORD_DELIMITER)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
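As a quick sanity check (my own addition), we can round-trip a short Nepali phrase through the new tokenizer; the space comes back via the | word delimiter:

# Encode a short phrase character by character and decode it back
ids = tokenizer("नेपाली भाषा").input_ids
print(ids)
print(tokenizer.decode(ids))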

Create Wav2Vec2FeatureExtractor

Speech is a continuous signal and to be treated by computers, it first has to be discretized, which is usually called sampling. The sampling rate hereby plays an important role in that it defines how many data points of the speech signal are measured per second. Therefore, sampling with a higher sampling rate results in a better approximation of the real speech signal but also necessitates more values per second.

A pretrained checkpoint expects its input data to have been sampled more or less from the same distribution as the data it was trained on. The same speech signals sampled at two different rates have a very different distribution, e.g., doubling the sampling rate results in data points being twice as long. Thus, before fine-tuning a pretrained checkpoint of an ASR model, it is crucial to verify that the sampling rate of the data that was used to pretrain the model matches the sampling rate of the dataset used to fine-tune the model.

XLS-R was pretrained on audio data of Babel, Multilingual LibriSpeech (MLS), Common Voice, VoxPopuli, and VoxLingua107 at a sampling rate of 16kHz. As stated earlier, the OpenSLR Nepali ASR dataset already has a sampling rate of 16kHz.

# Define a global variable to store our sampling rate
SPEECH_SAMPLING_RATE = 16000
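As a quick sanity check (not in the original notebook), we can confirm that a sample is indeed stored at this rate:

# Every utterance should already be at 16kHz
assert dataset[0]['utterance']['sampling_rate'] == SPEECH_SAMPLING_RATE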

Long input sequences require a lot of memory. Since XLS-R is based on self-attention, the memory requirement scales quadratically with the input length for long input sequences (cf. this reddit post). In case this demo crashes with an “Out-of-memory” error, you might want to use the following code to filter out all sequences that are longer than 5 seconds for training.

MAX_FRAMES = SPEECH_SAMPLING_RATE*5  # 5 sec

dataset = dataset.filter(
    lambda ex: ex < MAX_FRAMES,
    input_columns=['num_frames',],
    with_indices=False, batched=False, batch_size=0,
)

dataset
Dataset({
    features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
    num_rows: 143974
})

This seemed to have worked! Let’s listen to a couple of audio files to better understand the dataset and verify that the audio was correctly loaded.

Note: You can click the following cell a couple of times to listen to different speech samples.

import random
import IPython.display as ipd

sample_idx = random.randrange(len(dataset))  # randrange avoids the off-by-one of randint(0, len(dataset))

print(dataset[sample_idx]['transcription'])
ipd.Audio(dataset[sample_idx]['utterance']["array"], autoplay=True, rate=SPEECH_SAMPLING_RATE)
सुवेदी पार्टीको जिल्ला

A Wav2Vec2FeatureExtractor object requires the following parameters to be instantiated:

  • feature_size: Speech models take a sequence of feature vectors as an input. While the length of this sequence obviously varies, the feature size should not. In the case of Wav2Vec2, the feature size is 1 because the model was trained on the raw speech signal \({}^2\).
  • sampling_rate: The sampling rate at which the model was trained.
  • padding_value: For batched inference, shorter inputs need to be padded with a specific value.
  • do_normalize: Whether the input should be zero-mean-unit-variance normalized or not. Usually, speech models perform better when the input is normalized.
  • return_attention_mask: Whether the model should make use of an attention_mask for batched inference. In general, XLS-R model checkpoints should always use the attention_mask.
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=SPEECH_SAMPLING_RATE,
                                             padding_value=0.0, do_normalize=True,
                                             return_attention_mask=True)

Great, XLS-R’s feature extraction pipeline is thereby fully defined!
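To see what the feature extractor produces, here is a small illustrative call on one second of random noise (fake audio, not a real utterance):

import numpy as np

# One second of fake audio at 16kHz; the extractor normalizes it and also
# returns an attention mask because return_attention_mask=True.
fake_audio = np.random.randn(SPEECH_SAMPLING_RATE).astype(np.float32)
features = feature_extractor(fake_audio, sampling_rate=SPEECH_SAMPLING_RATE, return_tensors="np")
print(features.input_values.shape, features.attention_mask.shape)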

For improved user-friendliness, the feature extractor and tokenizer are wrapped into a single Wav2Vec2Processor class so that one only needs a model and processor object.

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer
)

If one wants to re-use the just created processor including tokenizer and feature extractor with the fine-tuned model of this notebook, it is strongly advised to upload the processor to the 🤗 Hub. Let’s call the repo to which we will upload the files "wav2vec2-large-xls-r-300m-nepali-openslr":

repo_name = "wav2vec2-large-xls-r-300m-nepali-openslr"

and upload the tokenizer to the 🤗 Hub.

processor.push_to_hub(repo_name)

Great, you can see the just created repository under https://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-nepali-openslr

Preprocess Data

So far, we have not looked at the actual values of the speech signal but just at the transcription. In addition to transcription, our dataset includes more columns: utterance_id, speaker_id, utterance, and num_frames. utterance has two fields: array and path. path states the absolute path of the audio file, and array is the numpy array of the same audio file. Let’s take a look.

dataset[45]
{'utterance_id': 'a176fcb0d8',
 'speaker_id': '6a6d1',
 'utterance': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/6bc81d0b078b9cc240efb7e2885d7c845ea238b71125a16d73b19c06621b39f2/asr_nepali/data/a1/a176fcb0d8.flac',
  'array': array([ 0.00030518,  0.00039673,  0.00039673, ..., -0.00057983,
         -0.00036621, -0.00027466], dtype=float32),
  'sampling_rate': 16000},
 'transcription': '० पनि एक',
 'num_frames': 36800}

We will convert Huggingface’s Dataset to a PyTorch dataset so that audio files are loaded lazily, as we are restricted by storage availability and memory size.

import torch

class NepaliASRProcessedDataset(torch.utils.data.Dataset):
    """Takes HF dataset and processor, and process the audio files
    and transcription with the processor only when items are requested
    """
    def __init__(
        self,
        dataset,
        processor,
    ):
        self.dataset = dataset
        self.processor = processor
    
    def __len__(self):
        """Length of dataset"""
        return len(self.dataset)
    
    def __getitem__(self, idx):
        """Return processed data at `idx` index."""
        example = self.dataset[idx]
        
        # Return dict
        return_dict = {}

        # first, process the audio with Wav2Vec2 feature extractor
        return_dict['input_values'] = self.processor(
            audio=example['utterance']['array'],
            sampling_rate=example['utterance']['sampling_rate'],
            return_attention_mask=False,  # will be calculated during batching
        )['input_values'][0]
        # add the length of extracted features of audio
        return_dict['input_length'] = len(return_dict['input_values'])

        # second, process the transcription with Wav2Vec2 tokenizer
        return_dict['labels'] = self.processor(
            text=example['transcription'],
            return_attention_mask=False,  # will be calculated during batching
        )['input_ids']
        return return_dict

Train/Test Split

Since our dataset has no separate splits for training and evaluation, we will create them manually. We will split the dataset into a 15% test set and an 85% train set.

test_size = 0.15
dataset = dataset.sort('utterance_id')
split_dict = dataset.train_test_split(test_size=test_size, seed=42)
train_dataset, test_dataset = split_dict['train'], split_dict['test']
train_dataset, test_dataset
(Dataset({
     features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
     num_rows: 122377
 }), Dataset({
     features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
     num_rows: 21597
 }))

Convert the Huggingface train/test datasets to PyTorch train/test datasets.

train_dataset = NepaliASRProcessedDataset(train_dataset, processor)
test_dataset = NepaliASRProcessedDataset(test_dataset, processor)
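A quick look at one processed training example (my own check) confirms that audio features and label ids come out as expected:

# The processed example holds normalized audio samples, their length, and label ids
sample = train_dataset[0]
print(sample['input_length'], len(sample['labels']))
print(sample['input_values'][:5])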

Training

The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗’s Trainer for which we essentially need to do the following:

  • Define a data collator. In contrast to most NLP models, XLS-R has a much larger input length than output length. E.g., a sample of input length 50000 has an output length of no more than 100. Given the large input sizes, it is much more efficient to pad the training batches dynamically, meaning that all training samples should only be padded to the longest sample in their batch and not the overall longest sample. Therefore, fine-tuning XLS-R requires a special padding data collator, which we will define below.

  • Evaluation metric. During training, the model should be evaluated on the word error rate. We should define a compute_metrics function accordingly

  • Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.

  • Define the training configuration.

After having fine-tuned the model, we will correctly evaluate it on the test data and verify that it has indeed learned to correctly transcribe speech.

Set-up Trainer

Let’s start by defining the data collator. The code for the data collator was copied from this example.

Without going into too many details, in contrast to the common data collators, this data collator treats the input_values and labels differently and thus applies two separate padding functions to them. This is necessary because, in speech, input and output are of different modalities, meaning that they should not be treated by the same padding function. Analogous to the common data collators, the padding tokens in the labels are replaced with -100 so that those tokens are not taken into account when computing the loss.

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


LARGE_NEG = -100

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths,
        # and need different padding methods
        batch = {}
        input_features = [{"input_values": feature["input_values"]} for feature in features if 'input_values' in feature]
        label_features = [{"input_ids": feature["labels"]} for feature in features if 'labels' in feature]

        if input_features:
            batch.update(self.processor.pad(
                input_features,
                padding=self.padding,
                return_tensors="pt",
            ))
        if label_features:
            labels_batch = self.processor.tokenizer.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

            # replace padding with large negative number to ignore loss correctly
            labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), LARGE_NEG)

            batch["labels"] = labels

        return batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
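To verify the collator (an illustrative check, not in the original blog), we can pad two processed examples into one batch:

# Two examples of different lengths are padded to the longer one;
# padded label positions are set to -100 so they are ignored by the CTC loss.
batch = data_collator([train_dataset[0], train_dataset[1]])
print(batch['input_values'].shape, batch['labels'].shape)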

Next, the evaluation metric is defined. As mentioned earlier, the predominant metric in ASR is the word error rate (WER), hence we will use it in this notebook as well.

import evaluate
import numpy as np

wer_metric = evaluate.load("wer")
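As a toy illustration of the metric (the strings below are made up, not taken from the dataset), a single substituted word out of three reference words gives a WER of one third:

# One substitution out of three reference words -> WER = 1/3
print(wer_metric.compute(predictions=["क ख ग"], references=["क ख घ"]))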

The model will return a sequence of logit vectors: \(\mathbf{y}_1, \ldots, \mathbf{y}_m\) with \(\mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0]\) and \(n >> m\).

A logit vector \(\mathbf{y}_i\) contains the log-odds for each token in the vocabulary we defined earlier, thus \(\text{len}(\mathbf{y}_i) =\) config.vocab_size. We are interested in the most likely prediction of the model and thus take the argmax(...) of the logits. Also, we transform the encoded labels back to the original string by replacing LARGE_NEG with the pad_token_id and decoding the ids while making sure that consecutive tokens are not grouped to the same token in CTC style \({}^1\).

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == LARGE_NEG] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

Now, we can load the pretrained checkpoint of Wav2Vec2-XLS-R-300M. The tokenizer’s pad_token_id must be used to define the model’s pad_token_id, or in the case of Wav2Vec2ForCTC also CTC’s blank token \({}^2\). To save GPU memory, we enable PyTorch’s gradient checkpointing and also set the loss reduction to “mean”.

Because the dataset is quite large (~100h of data) and because the ASR data is quite noisy, fine-tuning Facebook’s wav2vec2-xls-r-300m checkpoint seems to require some hyper-parameter tuning. Therefore, one has to play around a bit with different values for dropout, SpecAugment’s masking dropout rate, layer dropout, and the learning rate until training seems to be stable enough.

Note: Since I was not able to run the hyperparameter optimization on Colab, I’m not sure whether the current set of hyperparameters is the best one. Feel free to adapt those parameters and let me know. I’ve used the default ones for Wav2Vec2 models.

from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m", 
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.075,
    layerdrop=0.1,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

The first component of XLS-R consists of a stack of CNN layers that are used to extract acoustically meaningful - but contextually independent - features from the raw speech signal. This part of the model has already been sufficiently trained during pretraining and as stated in the paper does not need to be fine-tuned anymore. Thus, we can set the requires_grad to False for all parameters of the feature extraction part.

model.freeze_feature_encoder()
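A quick count (my own check) shows how much of the ~300M-parameter model remains trainable after freezing the feature encoder:

# Count trainable vs. total parameters after freezing the CNN feature encoder
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / total: {total:,}")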

In a final step, we define all parameters related to training. To give more explanation on some of the parameters:

  • group_by_length makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training time by heavily reducing the overall number of useless padding tokens that are passed through the model
  • learning_rate and weight_decay were heuristically tuned until fine-tuning has become stable. Note that those parameters strongly depend on the dataset used and might be suboptimal for other speech datasets.

For more explanations on other parameters, one can take a look at the docs.

During training, a checkpoint will be uploaded asynchronously to the hub every 400 training steps. It allows you to also play around with the demo widget even while your model is still training.

Note: If one does not want to upload the model checkpoints to the hub, simply set push_to_hub=False.

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=10,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=800,
  eval_steps=800,
  logging_steps=800,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=True,
  hub_strategy='checkpoint',
  resume_from_checkpoint='last-checkpoint',
)

Now, all instances can be passed to Trainer and we are ready to start training!

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor.feature_extractor,
)

\({}^1\) To allow models to become independent of the speaker rate, in CTC, consecutive tokens that are identical are simply grouped as a single token. However, the encoded labels should not be grouped when decoding since they don’t correspond to the predicted tokens of the model, which is why the group_tokens=False parameter has to be passed. If we wouldn’t pass this parameter a word like "hello" would incorrectly be encoded, and decoded as "helo".

\({}^2\) The blank token allows the model to predict a word, such as "hello" by forcing it to insert the blank token between the two l’s. A CTC-conform prediction of "hello" of our model would be [PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD].
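The two footnotes can be made concrete with the tokenizer we built earlier (the frame ids below are hypothetical model outputs, not real predictions):

# Pretend the model emitted "क", "क", [PAD], "क" over four consecutive frames.
k_id = processor.tokenizer.convert_tokens_to_ids("क")
pad_id = processor.tokenizer.pad_token_id
frame_ids = [k_id, k_id, pad_id, k_id]

print(processor.tokenizer.decode(frame_ids))                      # repeats collapse, blank removed
print(processor.tokenizer.decode(frame_ids, group_tokens=False))  # no CTC-style grouping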

Training

Training will take multiple hours depending on the GPU allocated to this notebook.

In case you want to use this google colab to fine-tune your model, you should make sure that your training doesn’t stop due to inactivity. A simple hack to prevent this is to paste the following code into the console of this tab (right mouse click -> inspect -> Console tab and insert code).

function ConnectButton(){
    console.log("Connect pushed"); 
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click() 
}
setInterval(ConnectButton,60000);

Depending on what GPU was allocated to your Google Colab, it is possible that you are seeing an "out-of-memory" error here. In this case, it’s probably best to reduce per_device_train_batch_size to 8 or even less and increase gradient_accumulation_steps accordingly, as sketched below.
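For example (a sketch only), halving the per-device batch size while doubling the accumulation steps keeps the effective batch size at 32:

# Keep the effective batch size at 16 * 2 = 8 * 4 = 32; rebuild the Trainer
# with these arguments if you change them after it was constructed.
training_args.per_device_train_batch_size = 8
training_args.gradient_accumulation_steps = 4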

trainer.train(
    resume_from_checkpoint=True,  # Set to False if you want to start from the beginning
)

If the training loss and validation WER go down nicely, you can now upload the result of the training to the 🤗 Hub. Just execute this:

trainer.push_to_hub()

You can now share this model with all your friends, family, and favorite pets: they can all load it with the identifier “your-username/the-name-you-picked”, so for instance:

from transformers import AutoModelForCTC, Wav2Vec2Processor

model = AutoModelForCTC.from_pretrained("spktsagar/wav2vec2-large-xls-r-300m-nepali-openslr")
processor = Wav2Vec2Processor.from_pretrained("spktsagar/wav2vec2-large-xls-r-300m-nepali-openslr")
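Here is a minimal inference sketch with the loaded model and processor (it assumes the dataset object from earlier in this notebook is still available):

import torch

# Run one utterance through the fine-tuned model and greedily decode the logits
example = dataset[0]['utterance']
inputs = processor(audio=example['array'], sampling_rate=example['sampling_rate'], return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
pred_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(pred_ids)[0])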

For more examples of how XLS-R can be fine-tuned, please take a look at the official speech recognition examples.

Evaluation

As a final check, let’s load the model and verify that it indeed has learned to transcribe Nepali speech.

Let’s first load the fine-tuned checkpoint.

model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)

Now, we will take a few examples of the test set, run them through the model, and take the argmax(...) of the logits to retrieve the predicted token ids. Those token ids will then be decoded to retrieve the transcriptions.

# only take 5 random examples from the test set
pred = trainer.predict(
    torch.utils.data.Subset(
        test_dataset,
        random.sample(list(range(len(test_dataset))), 5)
    )
)
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)

pred.label_ids[pred.label_ids == LARGE_NEG] = processor.tokenizer.pad_token_id

pred_str = processor.batch_decode(pred_ids)
# we do not want to group tokens when computing the metrics
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

Now, let’s look at the first few reference transcriptions and predicted transcriptions.

list(zip(label_str, pred_str))[:5]
[('उपस्थित गराए', 'उपस्थित गराए'),
 ('प्रतिशतले वृद्धि', 'प्रतिशदले वृद्धि'),
 ('उनीहरू जहाँ जुन', 'उनीहरू जाँ जुन'),
 ('टेम्प्लेटहरू पनि', 'टेम्प्लेटहरू पनि'),
 ('रूपमा खाइन्छ', 'रूपमा खाइन्छ')]

Alright! The transcription can definitely be recognized from our prediction, but it is not perfect yet. Training the model a bit longer, spending more time on the data preprocessing, and especially using a language model for decoding would certainly improve the model’s overall performance.

You can play with and use the model I trained from here. Thank you!!