WordPiece Step-by-Step Implementation#

In this lecture, we will walk you through the process of implementing WordPiece tokenization, a subword tokenization algorithm used in models such as BERT, DistilBERT, and ELECTRA. We will use a dataset of financial news headlines for this purpose.

Dataset Preparation#

First, we need to load our dataset. We will use the ashraq/financial-news dataset from the Hugging Face Hub, taking the headline column as our text data and randomly sampling 1000 records for the tokenization walkthrough. Here is how we can do this:

from datasets import load_dataset

dataset = load_dataset("ashraq/financial-news")
texts = dataset["train"].shuffle(seed=1234).select(range(1000))["headline"]

Pre-tokenization#

Before we start with the WordPiece algorithm itself, we need to pre-tokenize our text. Pre-tokenization splits the raw text into words, which is necessary because WordPiece learns its subword vocabulary within word boundaries. Here is a simple function that lowercases the text, collapses whitespace, and splits on spaces:

import re


def pre_tokenize(text, lowercase=True):
    if lowercase:
        text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.split(" ")
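
For example, on a made-up headline (not taken from the dataset), the function lowercases the text, collapses repeated whitespace, and keeps punctuation attached to the words, which is why punctuation characters will show up in the alphabet later:

print(pre_tokenize("Fed Raises  Rates:  Stocks Rally"))
# ['fed', 'raises', 'rates:', 'stocks', 'rally']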

Vocabulary Initialization#

The next step is to prepare the vocabulary initialization by counting how often each word occurs in the corpus. The initial character vocabulary and all pair statistics below are derived from these word frequencies:

from collections import defaultdict


def initialize_vocab(texts, lowercase=True):
    vocab = defaultdict(int)
    for text in texts:
        words = pre_tokenize(text, lowercase)
        for word in words:
            vocab[word] += 1
    return vocab


word_freqs = initialize_vocab(texts)
print("Number of words: {}".format(len(word_freqs.keys())))
Number of words: 3636
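
To get a feel for the counts, we can peek at a few of the most frequent words (the exact values depend on the sampled headlines, so they are not reproduced here):

top_words = sorted(word_freqs.items(), key=lambda x: x[1], reverse=True)[:5]
print(top_words)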

The alphabet is the set of all first letters of words, plus every other letter that appears inside a word, written with the ## prefix:

characters = []
for word in word_freqs.keys():
    if word[0] not in characters:
        characters.append(word[0])
    for letter in word[1:]:
        if f"##{letter}" not in characters:
            characters.append(f"##{letter}")

characters = sorted(characters)
print(characters)
['"', '#', '##!', '##"', '##$', '##%', '##&', "##'", '##)', '##+', '##,', '##-', '##.', '##/', '##0', '##1', '##2', '##3', '##4', '##5', '##6', '##7', '##8', '##9', '##:', '##;', '##?', '##]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##j', '##k', '##l', '##m', '##n', '##o', '##p', '##q', '##r', '##s', '##t', '##u', '##v', '##w', '##x', '##y', '##z', '##|', '##®', '##é', '##–', '##—', '##…', '$', '&', "'", '(', ')', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '[', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '|', '~', '–', '—', '€']

Adding Special Tokens#

BERT and other models that use WordPiece tokenization use special tokens like “[PAD]”, “[UNK]”, “[CLS]”, “[SEP]”, “[MASK]”. We need to add these tokens to our vocabulary:

vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + characters.copy()
len(vocab)
118

Splitting Words into Characters#

Next, we split each word into characters. We also add a special prefix “##” to all characters except the first one in each word. This prefix is used to indicate that a character is not the start of a new word:

splits = {
    word: [c if i == 0 else f"##{c}" for i, c in enumerate(word)]
    for word in word_freqs.keys()
}
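
For instance, applying the same comprehension to a single literal word shows what an initial split looks like:

print([c if i == 0 else f"##{c}" for i, c in enumerate("stocks")])
# ['s', '##t', '##o', '##c', '##k', '##s']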

Computing Pair Scores#

Now, we compute a score for every pair of adjacent symbols in the current splits. The score of a pair is its frequency divided by the product of the frequencies of its two parts. A pair therefore scores highly when its parts rarely occur outside of it, which is exactly what we want to merge first:

def compute_pair_scores(splits, word_freqs):
    letter_freqs = defaultdict(int)
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        split = splits[word]
        if len(split) == 1:
            letter_freqs[split[0]] += freq
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i + 1])
            letter_freqs[split[i]] += freq
            pair_freqs[pair] += freq
        letter_freqs[split[-1]] += freq

    scores = {
        pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }
    return scores


pair_scores = compute_pair_scores(splits, word_freqs)
for i, key in enumerate(pair_scores.keys()):
    print(f"{key}: {pair_scores[key]}")
    if i >= 5:
        break
('t', '##r'): 5.0029720626114525e-05
('##r', '##u'): 1.623534759879209e-05
('##u', '##d'): 9.79063701800694e-06
('##d', '##e'): 4.1215004357264235e-05
('##e', '##a'): 2.057878735731085e-05
('##a', '##u'): 4.019809621816311e-06

The compute_pair_scores function computes these scores for every adjacent pair in the current splits. This is the key difference from BPE: where BPE simply merges the most frequent pair, WordPiece normalizes the pair frequency by the frequencies of its parts, so rare but inseparable pairs are merged before common but independent ones.
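
To make the formula concrete, here is the same computation on a tiny, invented corpus (the words and counts below are made up purely for illustration):

toy_word_freqs = {"hug": 4, "hat": 6}
toy_splits = {"hug": ["h", "##u", "##g"], "hat": ["h", "##a", "##t"]}

toy_scores = compute_pair_scores(toy_splits, toy_word_freqs)
for pair, score in toy_scores.items():
    print(pair, score)
# ('h', '##u') scores 4 / (10 * 4) = 0.1, while ('##u', '##g') scores 4 / (4 * 4) = 0.25:
# '##u' and '##g' never appear apart, so that pair is the most attractive merge.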

Finding the Best Pair#

Next, we find the pair with the highest score. This pair is the best candidate for merging:

best_pair = ""
max_score = None
for pair, score in pair_scores.items():
    if max_score is None or max_score < score:
        best_pair = pair
        max_score = score

print(best_pair, max_score)
('~', '##$') 0.3333333333333333

Merging the Best Pair#

Once we have identified the best pair, we merge it and update our splits:

def merge_pair(a, b, splits, word_freqs):
    for word in word_freqs:
        split = splits[word]
        if len(split) == 1:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                # Strip the "##" prefix of the second symbol before gluing the pair together.
                merge = a + b[2:] if b.startswith("##") else a + b
                split = split[:i] + [merge] + split[i + 2 :]
            else:
                i += 1
        splits[word] = split
    return splits

The merge_pair function applies the merge to the split of every word in our corpus. If the second symbol of the pair starts with “##”, we strip that prefix before concatenating, and wherever the pair occurs the two symbols are replaced by the merged token.
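
Continuing the toy example from above, merging ("##u", "##g") rewrites the split of "hug" in place:

toy_splits = merge_pair("##u", "##g", toy_splits, toy_word_freqs)
print(toy_splits["hug"])
# ['h', '##ug']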

Repeating the Process#

We repeat the process of computing pair scores, finding the best pair, and merging the best pair until we reach our desired vocabulary size:

vocab_size = 1000
while len(vocab) < vocab_size:
    scores = compute_pair_scores(splits, word_freqs)
    best_pair, max_score = "", None
    for pair, score in scores.items():
        if max_score is None or max_score < score:
            best_pair = pair
            max_score = score
    splits = merge_pair(*best_pair, splits, word_freqs)
    new_token = (
        best_pair[0] + best_pair[1][2:]
        if best_pair[1].startswith("##")
        else best_pair[0] + best_pair[1]
    )
    vocab.append(new_token)

print("First 10 tokens: {}".format(vocab[:10]))
print("Last 50 tokens: {}".format(vocab[-50:]))
First 10 tokens: ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '"', '#', '##!', '##"', '##$']
Last 50 tokens: ['##tangibl', '##ntangibl', '##intangibl', "'intangibl", '##rangl', 'wrangl', 'dange', 'danger', '##eng', '##leng', '##lleng', '##alleng', 'challeng', 'challenge', 'challenge:', 'ange', 'anger', 'angers', "'intangible", 'wrangle', 'wrangler', 'change', 'changes', 'challenger', 'challenges', 'chi', 'chin', 'chip', 'chic', 'chil', 'chico', "chico's", 'china', "china's", 'china,', 'chipo', 'chipot', 'chipotl', 'chile', 'chipotle', "chipotle's", 'chipotle,', 'chine', 'chines', 'chinese', "##rsday's", "##ursday's", "thursday's", "##esday's", "##uesday's"]

Encoding Words#

Now that we have our WordPiece vocabulary, we can use it to encode words:

def encode_word(word):
    tokens = []
    while len(word) > 0:
        i = len(word)
        while i > 0 and word[:i] not in vocab:
            i -= 1
        if i == 0:
            return ["[UNK]"]
        tokens.append(word[:i])
        word = word[i:]
        if len(word) > 0:
            word = f"##{word}"
    return tokens


print(encode_word("company"))
print(encode_word("companies"))
print(encode_word("회사"))  # UNK
['c', '##o', '##m', '##p', '##a', '##n', '##y']
['c', '##o', '##m', '##p', '##a', '##n', '##i', '##e', '##s']
['[UNK]']

The encode_word function encodes a word with the learned vocabulary using greedy longest-prefix matching: it finds the longest prefix of the word that is in the vocabulary, emits it as a token, and repeats on the remainder, which is prefixed with “##”. If at any point no prefix of the remainder is in the vocabulary, the whole word is mapped to the unknown token “[UNK]”.
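
To see the greedy matching in isolation, here is the same loop with the vocabulary passed in explicitly, run against a small hypothetical vocabulary (both the helper and the vocabulary are invented for illustration, not part of the tokenizer we just trained):

def encode_word_with(word, vocabulary):
    # Identical to encode_word above, but the vocabulary is an explicit argument.
    tokens = []
    while len(word) > 0:
        i = len(word)
        while i > 0 and word[:i] not in vocabulary:
            i -= 1
        if i == 0:
            return ["[UNK]"]
        tokens.append(word[:i])
        word = word[i:]
        if len(word) > 0:
            word = f"##{word}"
    return tokens


toy_vocab = ["un", "##aff", "##ord", "##able"]
print(encode_word_with("unaffordable", toy_vocab))
# ['un', '##aff', '##ord', '##able']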

Tokenizing Text#

Finally, we can use our WordPiece tokenizer to tokenize text:

def tokenize(text):
    words = pre_tokenize(text)
    encoded_words = [encode_word(word) for word in words]
    return sum(encoded_words, [])


tokenized_text = tokenize("Investment opportunities in the company")
print(tokenized_text)
['investment', 'opportunities', 'in', 'the', 'c', '##o', '##m', '##p', '##a', '##n', '##y']

The tokenize function splits the text into words with pre_tokenize, encodes each word with encode_word, and flattens the per-word token lists into a single sequence of tokens.

That’s it! You have now implemented WordPiece tokenization from scratch. You can use this knowledge to understand how subword tokenization works in models like BERT, DistilBERT, and ELECTRA.
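
As a cross-check, the Hugging Face tokenizers library ships a fast WordPiece trainer. A minimal sketch on the same texts might look like the following (the settings are chosen to mirror the ones above; the learned vocabulary will not match ours exactly, because the library's pre-tokenization and training details differ):

from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.normalizers import Lowercase
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer

# Build a WordPiece tokenizer that lowercases and splits on whitespace/punctuation.
tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = Lowercase()
tokenizer.pre_tokenizer = Whitespace()

trainer = WordPieceTrainer(
    vocab_size=1000,
    special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
)
tokenizer.train_from_iterator(texts, trainer)

print(tokenizer.encode("Investment opportunities in the company").tokens)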

Unigram Tokenization#

For comparison, we can also train a Unigram tokenizer (the algorithm behind SentencePiece, used by models such as ALBERT, T5, and XLNet) on the same word frequencies. The setup differs from WordPiece in a few ways:

  • We need to initialize our vocabulary to something larger than the vocabulary size we want, and then prune it down.

  • We have to include all the basic characters (otherwise we won’t be able to tokenize every word).

  • For the bigger substrings, we can use the most frequent substrings in the corpus.

character_freqs = defaultdict(int)
subwords_freqs = defaultdict(int)
for word, freq in word_freqs.items():
    for i in range(len(word)):
        character_freqs[word[i]] += freq
        # Loop through the subwords of length at least 2
        for j in range(i + 2, len(word) + 1):
            subwords_freqs[word[i:j]] += freq

# Sort subwords by frequency
sorted_subwords = sorted(subwords_freqs.items(), key=lambda x: x[1], reverse=True)
print(sorted_subwords[:10])
[('in', 6437), ('th', 6241), ('he', 4833), ('er', 4585), ('re', 4475), ('an', 4311), ('the', 3977), ('on', 3842), ('es', 3360), ('ar', 3269)]
len(sorted_subwords)
60023

We group the characters with the best subwords to arrive at an initial vocabulary of size 2000:

token_freqs = (
    list(character_freqs.items()) + sorted_subwords[: 2000 - len(character_freqs)]
)
token_freqs = {token: freq for token, freq in token_freqs}
len(token_freqs)
2000

Next, we compute the sum of all frequencies to convert the frequencies into probabilities, and store each token's negative log probability as the model. Lower scores mean more probable tokens, so the best segmentation of a word is the one whose token scores add up to the smallest total.

from math import log

total_sum = sum([freq for token, freq in token_freqs.items()])
model = {token: -log(freq / total_sum) for token, freq in token_freqs.items()}
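
Since each score is the negative log of freq / total_sum, the implied probabilities should sum to one over the vocabulary. A quick sanity check:

from math import exp, isclose

# exp(-score) recovers each token's probability; together they form a distribution.
assert isclose(sum(exp(-score) for score in model.values()), 1.0)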

The main function tokenizes a word with the Viterbi algorithm: for each position in the word it records the best-scoring segmentation of the prefix ending there, then walks back from the end of the word to recover the tokens.

def encode_word(word, model):
    best_segmentations = [{"start": 0, "score": 1}] + [
        {"start": None, "score": None} for _ in range(len(word))
    ]
    for start_idx in range(len(word)):
        # This should be properly filled by the previous steps of the loop
        best_score_at_start = best_segmentations[start_idx]["score"]
        for end_idx in range(start_idx + 1, len(word) + 1):
            token = word[start_idx:end_idx]
            if token in model and best_score_at_start is not None:
                score = model[token] + best_score_at_start
                # If we have found a better segmentation ending at end_idx, we update
                if (
                    best_segmentations[end_idx]["score"] is None
                    or best_segmentations[end_idx]["score"] > score
                ):
                    best_segmentations[end_idx] = {"start": start_idx, "score": score}

    segmentation = best_segmentations[-1]
    if segmentation["score"] is None:
        # We did not find a tokenization of the word -> unknown
        return ["<unk>"], None

    score = segmentation["score"]
    start = segmentation["start"]
    end = len(word)
    tokens = []
    while start != 0:
        tokens.insert(0, word[start:end])
        next_start = best_segmentations[start]["start"]
        end = start
        start = next_start
    tokens.insert(0, word[start:end])
    return tokens, score
print(encode_word("apple", model))
print(encode_word("investment", model))
(['app', 'le'], 16.199784937807312)
(['investment'], 9.955290111180942)

The loss of the model over the corpus is the frequency-weighted sum of the scores of each word's best segmentation:

def compute_loss(model):
    loss = 0
    for word, freq in word_freqs.items():
        _, word_loss = encode_word(word, model)
        loss += freq * word_loss
    return loss


compute_loss(model)
802891.4150846584

Next, we compute for each token how much the total loss would increase if that token were removed from the vocabulary. Tokens of length 1 are always kept, since they are needed to tokenize every word:

import copy


def compute_scores(model):
    scores = {}
    model_loss = compute_loss(model)
    for token, score in model.items():
        # We always keep tokens of length 1
        if len(token) == 1:
            continue
        model_without_token = copy.deepcopy(model)
        _ = model_without_token.pop(token)
        scores[token] = compute_loss(model_without_token) - model_loss
    return scores


scores = compute_scores(model)
print(scores["app"])
print(scores["le"])
print(scores["investment"])
print(scores["invest"])
print(scores["ment"])
102.53267826826777
239.44326648849528
326.00756032345816
113.19295595935546
435.93838484131265

Iterate until we have the desired vocabulary size:

percent_to_remove = 0.1
while len(model) > 1000:
    scores = compute_scores(model)
    sorted_scores = sorted(scores.items(), key=lambda x: x[1])
    # Remove percent_to_remove tokens with the lowest scores.
    for i in range(int(len(model) * percent_to_remove)):
        _ = token_freqs.pop(sorted_scores[i][0])

    total_sum = sum([freq for token, freq in token_freqs.items()])
    model = {token: -log(freq / total_sum) for token, freq in token_freqs.items()}

To tokenize a sentence, we can apply this function to each word:

def tokenize(text, model):
    words = pre_tokenize(text)
    encoded_words = [encode_word(word, model)[0] for word in words]
    return sum(encoded_words, [])
tokenized_text = tokenize("investment opportunities in the company", model)
print(tokenized_text)
['investment', 'o', 'pport', 'un', 'ities', 'in', 'the', 'company']