WordPiece Step-by-Step Implementation#
In this lecture, we will walk through the process of implementing WordPiece tokenization, the subword tokenization algorithm used in models like BERT and DistilBERT (RoBERTa, by contrast, uses byte-level BPE). We will use a dataset of financial news headlines for this purpose.
Dataset Preparation#
First, we need to load our dataset. We will use the ashraq/financial-news dataset from the Hugging Face Hub, taking the headline column as our text data and randomly sampling 1,000 records for the tokenization process. Here is how we can do this:
from datasets import load_dataset
dataset = load_dataset("ashraq/financial-news")
texts = dataset["train"].shuffle(seed=1234).select(range(1000))["headline"]
Pre-tokenization#
Before we start with WordPiece training, we need to pre-tokenize our text. Pre-tokenization splits the text into words, which is necessary because WordPiece learns its merges within word boundaries. The simple function below lowercases the text, collapses runs of whitespace, and splits on spaces:
import re
def pre_tokenize(text, lowercase=True):
if lowercase:
text = text.lower()
text = re.sub(r"\s+", " ", text)
return text.split(" ")
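As a quick sanity check, here is the function applied to a made-up headline (an illustrative string, not drawn from the dataset):
print(pre_tokenize("Fed Raises   Interest Rates by 0.25%"))
# ['fed', 'raises', 'interest', 'rates', 'by', '0.25%']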
Vocabulary Initialization#
The next step is to initialize our vocabulary. In the case of WordPiece, we start with a vocabulary of individual characters. Here is how we can do this:
from collections import defaultdict
def initialize_vocab(texts, lowercase=True):
vocab = defaultdict(int)
for text in texts:
words = pre_tokenize(text, lowercase)
for word in words:
vocab[word] += 1
return vocab
word_freqs = initialize_vocab(texts)
print("Number of words: {}".format(len(word_freqs.keys())))
Number of words: 3636
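To get a feel for the counts, we can peek at the most frequent words; the exact output depends on the sampled headlines, so it is omitted here:
from collections import Counter
print(Counter(word_freqs).most_common(5))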
The alphabet is the set of all first letters of words, plus all the other letters that appear inside words, prefixed by ##:
characters = []
for word in word_freqs.keys():
if word[0] not in characters:
characters.append(word[0])
for letter in word[1:]:
if f"##{letter}" not in characters:
characters.append(f"##{letter}")
characters = sorted(characters)
print(characters)
['"', '#', '##!', '##"', '##$', '##%', '##&', "##'", '##)', '##+', '##,', '##-', '##.', '##/', '##0', '##1', '##2', '##3', '##4', '##5', '##6', '##7', '##8', '##9', '##:', '##;', '##?', '##]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##j', '##k', '##l', '##m', '##n', '##o', '##p', '##q', '##r', '##s', '##t', '##u', '##v', '##w', '##x', '##y', '##z', '##|', '##®', '##é', '##–', '##—', '##…', '$', '&', "'", '(', ')', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '[', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '|', '~', '–', '—', '€']
Adding Special Tokens#
BERT and other models that use WordPiece tokenization use special tokens like “[PAD]”, “[UNK]”, “[CLS]”, “[SEP]”, “[MASK]”. We need to add these tokens to our vocabulary:
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + characters.copy()
len(vocab)
118
Splitting Words into Characters#
Next, we split each word into characters. We also add a special prefix “##” to all characters except the first one in each word. This prefix is used to indicate that a character is not the start of a new word:
splits = {
word: [c if i == 0 else f"##{c}" for i, c in enumerate(word)]
for word in word_freqs.keys()
}
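For example, applying the same comprehension to the standalone word "stocks" (an arbitrary example, not necessarily in the corpus) gives:
print([c if i == 0 else f"##{c}" for i, c in enumerate("stocks")])
# ['s', '##t', '##o', '##c', '##k', '##s']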
Computing Pair Scores#
Now, we compute scores for all pairs of adjacent tokens in the current splits. The score of a pair is its frequency divided by the product of the frequencies of its two parts: score(a, b) = freq(a, b) / (freq(a) × freq(b)). The idea is to prioritize pairs whose parts occur together more often than they occur apart:
def compute_pair_scores(splits, word_freqs):
letter_freqs = defaultdict(int)
pair_freqs = defaultdict(int)
for word, freq in word_freqs.items():
split = splits[word]
if len(split) == 1:
letter_freqs[split[0]] += freq
continue
for i in range(len(split) - 1):
pair = (split[i], split[i + 1])
letter_freqs[split[i]] += freq
pair_freqs[pair] += freq
letter_freqs[split[-1]] += freq
scores = {
pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])
for pair, freq in pair_freqs.items()
}
return scores
pair_scores = compute_pair_scores(splits, word_freqs)
for i, key in enumerate(pair_scores.keys()):
print(f"{key}: {pair_scores[key]}")
if i >= 5:
break
('t', '##r'): 5.0029720626114525e-05
('##r', '##u'): 1.623534759879209e-05
('##u', '##d'): 9.79063701800694e-06
('##d', '##e'): 4.1215004357264235e-05
('##e', '##a'): 2.057878735731085e-05
('##a', '##u'): 4.019809621816311e-06
The compute_pair_scores function calculates a score for every pair of adjacent tokens in the current splits. Rather than simply picking the most frequent pair, WordPiece normalizes the pair frequency by the frequencies of its two parts, so it favors pairs whose parts rarely appear apart from each other.
Finding the Best Pair#
Next, we find the pair with the highest score. This pair is the best candidate for merging:
best_pair = ""
max_score = None
for pair, score in pair_scores.items():
if max_score is None or max_score < score:
best_pair = pair
max_score = score
print(best_pair, max_score)
('~', '##$') 0.3333333333333333
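Note that the first merge found here joins two very rare symbols: because the score divides by the product of the individual frequencies, a pair of tokens that almost never appear apart from each other can outscore far more frequent pairs.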
Merging the Best Pair#
Once we have identified the best pair, we merge it and update our splits:
def merge_pair(a, b, splits, word_freqs):
for word in word_freqs:
split = splits[word]
if len(split) == 1:
continue
i = 0
while i < len(split) - 1:
if split[i] == a and split[i + 1] == b:
merge = a + b[2:] if b.startswith("##") else a + b
split = split[:i] + [merge] + split[i + 2 :]
else:
i += 1
splits[word] = split
return splits
The merge_pair function merges a given pair wherever it occurs in the splits. If the second token of the pair starts with “##”, we strip the “##” before concatenating, and the two adjacent tokens are replaced by the merged token.
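To see the function in action, here is a tiny, made-up example (the dictionaries below are purely illustrative and not part of our corpus):
toy_freqs = {"bank": 3}
toy_splits = {"bank": ["b", "##a", "##n", "##k"]}
print(merge_pair("##a", "##n", toy_splits, toy_freqs))
# {'bank': ['b', '##an', '##k']}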
Repeating the Process#
We repeat the process of computing pair scores, finding the best pair, and merging the best pair until we reach our desired vocabulary size:
vocab_size = 1000
while len(vocab) < vocab_size:
scores = compute_pair_scores(splits, word_freqs)
best_pair, max_score = "", None
for pair, score in scores.items():
if max_score is None or max_score < score:
best_pair = pair
max_score = score
splits = merge_pair(*best_pair, splits, word_freqs)
new_token = (
best_pair[0] + best_pair[1][2:]
if best_pair[1].startswith("##")
else best_pair[0] + best_pair[1]
)
vocab.append(new_token)
print("First 10 tokens: {}".format(vocab[:10]))
print("Last 50 tokens: {}".format(vocab[-50:]))
First 10 tokens: ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '"', '#', '##!', '##"', '##$']
Last 50 tokens: ['##tangibl', '##ntangibl', '##intangibl', "'intangibl", '##rangl', 'wrangl', 'dange', 'danger', '##eng', '##leng', '##lleng', '##alleng', 'challeng', 'challenge', 'challenge:', 'ange', 'anger', 'angers', "'intangible", 'wrangle', 'wrangler', 'change', 'changes', 'challenger', 'challenges', 'chi', 'chin', 'chip', 'chic', 'chil', 'chico', "chico's", 'china', "china's", 'china,', 'chipo', 'chipot', 'chipotl', 'chile', 'chipotle', "chipotle's", 'chipotle,', 'chine', 'chines', 'chinese', "##rsday's", "##ursday's", "thursday's", "##esday's", "##uesday's"]
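Judging from the last 50 tokens printed above, whole words such as "china" and "chipotle" made it into the vocabulary, which we can confirm directly:
print("china" in vocab, "chipotle" in vocab)
# True True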
Encoding Words#
Now that we have our WordPiece vocabulary, we can use it to encode words:
def encode_word(word):
tokens = []
while len(word) > 0:
i = len(word)
while i > 0 and word[:i] not in vocab:
i -= 1
if i == 0:
return ["[UNK]"]
tokens.append(word[:i])
word = word[i:]
if len(word) > 0:
word = f"##{word}"
return tokens
print(encode_word("company"))
print(encode_word("companies"))
print(encode_word("회사")) # UNK
['c', '##o', '##m', '##p', '##a', '##n', '##y']
['c', '##o', '##m', '##p', '##a', '##n', '##i', '##e', '##s']
['[UNK]']
The encode_word function takes a word and encodes it using the WordPiece vocabulary. It greedily finds the longest prefix of the word that is in the vocabulary; if none exists, it returns the unknown token “[UNK]”. Otherwise, it adds that prefix to the list of tokens, prepends “##” to the remainder, and repeats the process until the word is fully consumed.
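In contrast, "challenges" appears in the learned vocabulary above, so the greedy longest-prefix match returns it as a single token:
print(encode_word("challenges"))
# ['challenges']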
Tokenizing Text#
Finally, we can use our WordPiece tokenizer to tokenize text:
def tokenize(text):
words = pre_tokenize(text)
encoded_words = [encode_word(word) for word in words]
return sum(encoded_words, [])
tokenized_text = tokenize("Investment opportunities in the company")
print(tokenized_text)
['investment', 'opportunities', 'in', 'the', 'c', '##o', '##m', '##p', '##a', '##n', '##y']
The tokenize function takes a text, splits it into words with the pre_tokenize function, encodes each word with the encode_word function, and returns the flattened list of subword tokens.
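As a side note, sum(encoded_words, []) simply flattens the per-word lists of tokens; an equivalent sketch using itertools.chain avoids the repeated list concatenation:
from itertools import chain
def tokenize(text):
    words = pre_tokenize(text)
    # Flatten the per-word token lists into a single list of tokens
    return list(chain.from_iterable(encode_word(word) for word in words))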
That’s it! You have now implemented WordPiece tokenization from scratch. You can use this knowledge to understand how subword tokenization works in models like BERT and DistilBERT.
Unigram Tokenization#
The remaining steps walk through a different subword algorithm on the same corpus: Unigram, which is used (via SentencePiece) by models such as ALBERT, T5, and XLNet. Instead of growing a vocabulary by merging pairs, Unigram starts from a large candidate vocabulary and prunes it down. We therefore initialize the vocabulary to something larger than the size we want. It must include all the basic characters (otherwise we will not be able to tokenize every word); for the longer substrings, we use the most frequent substrings in the corpus.
character_freqs = defaultdict(int)
subwords_freqs = defaultdict(int)
for word, freq in word_freqs.items():
for i in range(len(word)):
character_freqs[word[i]] += freq
# Loop through the subwords of length at least 2
for j in range(i + 2, len(word) + 1):
subwords_freqs[word[i:j]] += freq
# Sort subwords by frequency
sorted_subwords = sorted(subwords_freqs.items(), key=lambda x: x[1], reverse=True)
print(sorted_subwords[:10])
[('in', 6437), ('th', 6241), ('he', 4833), ('er', 4585), ('re', 4475), ('an', 4311), ('the', 3977), ('on', 3842), ('es', 3360), ('ar', 3269)]
len(sorted_subwords)
60023
We combine all the individual characters with the most frequent subwords to reach an initial vocabulary of size 2000:
token_freqs = (
list(character_freqs.items()) + sorted_subwords[: 2000 - len(character_freqs)]
)
token_freqs = {token: freq for token, freq in token_freqs}
len(token_freqs)
2000
Next, we compute the sum of all frequencies to convert the frequencies into probabilities, and store each token’s negative log probability. With this convention, the score of a segmentation is the sum of the scores of its tokens, and lower is better.
from math import log
total_sum = sum([freq for token, freq in token_freqs.items()])
model = {token: -log(freq / total_sum) for token, freq in token_freqs.items()}
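A quick numeric illustration of these scores, with made-up counts rather than the real corpus totals:
from math import log
freq, total = 100, 10_000  # hypothetical counts, purely illustrative
print(-log(freq / total))  # ≈ 4.61: rarer tokens get larger (worse) scores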
The main function tokenizes a word with the Viterbi algorithm: a dynamic program that, for each position in the word, records the lowest-scoring segmentation of the prefix ending at that position, then backtracks from the end of the word to recover the tokens.
def encode_word(word, model):
best_segmentations = [{"start": 0, "score": 1}] + [
{"start": None, "score": None} for _ in range(len(word))
]
for start_idx in range(len(word)):
# This should be properly filled by the previous steps of the loop
best_score_at_start = best_segmentations[start_idx]["score"]
for end_idx in range(start_idx + 1, len(word) + 1):
token = word[start_idx:end_idx]
if token in model and best_score_at_start is not None:
score = model[token] + best_score_at_start
# If we have found a better segmentation ending at end_idx, we update
if (
best_segmentations[end_idx]["score"] is None
or best_segmentations[end_idx]["score"] > score
):
best_segmentations[end_idx] = {"start": start_idx, "score": score}
segmentation = best_segmentations[-1]
if segmentation["score"] is None:
# We did not find a tokenization of the word -> unknown
return ["<unk>"], None
score = segmentation["score"]
start = segmentation["start"]
end = len(word)
tokens = []
while start != 0:
tokens.insert(0, word[start:end])
next_start = best_segmentations[start]["start"]
end = start
start = next_start
tokens.insert(0, word[start:end])
return tokens, score
print(encode_word("apple", model))
print(encode_word("investment", model))
(['app', 'le'], 16.199784937807312)
(['investment'], 9.955290111180942)
The loss of the model over the corpus is the sum, over all words, of the word’s frequency times the score of its best segmentation:
def compute_loss(model):
loss = 0
for word, freq in word_freqs.items():
_, word_loss = encode_word(word, model)
loss += freq * word_loss
return loss
compute_loss(model)
802891.4150846584
Next, we compute a score for each token: how much the corpus loss would increase if that token were removed from the vocabulary. Single-character tokens are skipped, since we always keep them so that every word remains tokenizable:
import copy
def compute_scores(model):
scores = {}
model_loss = compute_loss(model)
for token, score in model.items():
# We always keep tokens of length 1
if len(token) == 1:
continue
model_without_token = copy.deepcopy(model)
_ = model_without_token.pop(token)
scores[token] = compute_loss(model_without_token) - model_loss
return scores
scores = compute_scores(model)
print(scores["app"])
print(scores["le"])
print(scores["investment"])
print(scores["invest"])
print(scores["ment"])
102.53267826826777
239.44326648849528
326.00756032345816
113.19295595935546
435.93838484131265
The scores above show, for instance, that removing “ment” or “investment” would hurt the loss far more than removing “invest”, so the former are worth keeping. We now iterate, removing the 10% of tokens with the lowest scores and re-normalizing the probabilities at each step, until we reach the desired vocabulary size:
percent_to_remove = 0.1
while len(model) > 1000:
scores = compute_scores(model)
sorted_scores = sorted(scores.items(), key=lambda x: x[1])
# Remove percent_to_remove tokens with the lowest scores.
for i in range(int(len(model) * percent_to_remove)):
_ = token_freqs.pop(sorted_scores[i][0])
total_sum = sum([freq for token, freq in token_freqs.items()])
model = {token: -log(freq / total_sum) for token, freq in token_freqs.items()}
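After the loop finishes we can check what is left; since each round removes 10% of the current vocabulary, the size ends up just below the 1,000 threshold:
print(len(model))  # 958 with the exact settings above (2000 → 1800 → 1620 → ... → 958)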
To tokenize a sentence, we can apply this function to each word:
def tokenize(text, model):
words = pre_tokenize(text)
encoded_words = [encode_word(word, model)[0] for word in words]
return sum(encoded_words, [])
tokenized_text = tokenize("investment opportunities in the company", model)
print(tokenized_text)
['investment', 'o', 'pport', 'un', 'ities', 'in', 'the', 'company']