BPE Step-by-Step Implementation
In this lecture, we will walk through the implementation of Byte Pair Encoding (BPE), a popular subword tokenization method. We will use a dataset of financial news headlines for this demonstration.
Dataset Preparation
First, we need to load our dataset. We will use the ashraq/financial-news dataset from the Hugging Face Hub and randomly sample 1000 records from it (with a fixed shuffle seed, so the sample is reproducible) for our demonstration.
```python
from datasets import load_dataset

dataset = load_dataset("ashraq/financial-news")
texts = dataset["train"].shuffle(seed=1234).select(range(1000))["headline"]
```
BPE Implementation
Now, let’s dive into the implementation of BPE.
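We start by initializing our vocabulary. We format each word by separating its characters with spaces and appending a special end-of-word marker, ▁ (U+2581); many BPE write-ups use </w> for the same purpose. The vocabulary maps each formatted word to the number of times it occurs in the corpus.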
```python
import collections
import re

def format_word(text, space_token="▁"):
    # "low" -> "l o w ▁": characters separated by spaces, plus the end-of-word marker
    return " ".join(list(text)) + " " + space_token

def initialize_vocab(texts, lowercase=True):
    # Map each formatted word to its frequency in the corpus
    vocab = {}
    for text in texts:
        if lowercase:
            text = text.lower()
        text = re.sub(r"\s+", " ", text)
        all_words = text.split()
        for word in all_words:
            word = format_word(word)
            vocab[word] = vocab.get(word, 0) + 1
    return vocab

vocab = initialize_vocab(texts)
print(f"Number of words: {len(vocab)}")
```
Number of words: 3636
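To make the representation concrete, here is a minimal sketch on a made-up two-headline corpus (toy_texts is hypothetical; format_word and initialize_vocab are repeated in condensed form so the snippet runs on its own).

```python
import re

def format_word(text, space_token="▁"):
    # Separate characters with spaces and append the end-of-word marker
    return " ".join(text) + " " + space_token

def initialize_vocab(texts, lowercase=True):
    # Count how often each formatted word occurs
    vocab = {}
    for text in texts:
        if lowercase:
            text = text.lower()
        for word in re.sub(r"\s+", " ", text).split():
            word = format_word(word)
            vocab[word] = vocab.get(word, 0) + 1
    return vocab

# Hypothetical toy corpus, not taken from the financial-news dataset
toy_texts = ["Low lower", "lowest  LOW"]
toy_vocab = initialize_vocab(toy_texts)
print(toy_vocab)  # {'l o w ▁': 2, 'l o w e r ▁': 1, 'l o w e s t ▁': 1}
```

The keys are distinct words, so len(vocab) counts unique lowercased words rather than total word occurrences.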
Token Extraction
Next, we extract all unique tokens from our vocabulary, count their frequencies, and record how each word is currently segmented.
```python
def get_tokens_from_vocab(vocab):
    # tokens: frequency of every symbol, weighted by the frequency of its word
    # vocab_tokenization: current segmentation of each word, keyed by the joined word (e.g. "low▁")
    tokens = collections.defaultdict(int)
    vocab_tokenization = {}
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens[token] += freq
        vocab_tokenization["".join(word_tokens)] = word_tokens
    return tokens, vocab_tokenization

tokens, vocab_tokenization = get_tokens_from_vocab(vocab)
print(f"Number of tokens: {len(tokens)}")
```
Number of tokens: 64
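Continuing with a small hypothetical vocabulary (the words and counts are invented for illustration), a sketch of what get_tokens_from_vocab returns; the function is repeated so the snippet is self-contained.

```python
import collections

def get_tokens_from_vocab(vocab):
    # Same logic as above: symbol frequencies plus each word's current segmentation
    tokens = collections.defaultdict(int)
    vocab_tokenization = {}
    for word, freq in vocab.items():
        word_tokens = word.split()
        for token in word_tokens:
            tokens[token] += freq
        vocab_tokenization["".join(word_tokens)] = word_tokens
    return tokens, vocab_tokenization

# Hypothetical toy vocabulary: word split into symbols -> frequency
toy_vocab = {"l o w ▁": 5, "l o w e r ▁": 2, "n e w e s t ▁": 6, "w i d e s t ▁": 3}
toy_tokens, toy_segmentation = get_tokens_from_vocab(toy_vocab)
print(dict(toy_tokens))
# {'l': 7, 'o': 7, 'w': 16, '▁': 16, 'e': 17, 'r': 2, 'n': 6, 's': 9, 't': 9, 'i': 3, 'd': 3}
print(toy_segmentation["newest▁"])  # ['n', 'e', 'w', 'e', 's', 't', '▁']
```

Each symbol is weighted by its word's frequency, so e collects 2 from lower, 12 from the two e's in newest, and 3 from widest.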
Bigram Counts
We then count the frequency of each bigram (pair of consecutive tokens) in our vocabulary.
```python
def get_bigram_counts(vocab):
    # Count adjacent symbol pairs, weighting each pair by the frequency of its word
    pairs = {}
    for word, count in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pair = (symbols[i], symbols[i + 1])
            pairs[pair] = pairs.get(pair, 0) + count
    return pairs
```
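On the same hypothetical toy vocabulary, a sketch of the bigram counts and of how max(pairs, key=pairs.get), the selection rule used in find_merges, picks a winner when several pairs tie.

```python
def get_bigram_counts(vocab):
    # Same logic as above
    pairs = {}
    for word, count in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pair = (symbols[i], symbols[i + 1])
            pairs[pair] = pairs.get(pair, 0) + count
    return pairs

# Hypothetical toy vocabulary (invented counts)
toy_vocab = {"l o w ▁": 5, "l o w e r ▁": 2, "n e w e s t ▁": 6, "w i d e s t ▁": 3}
pairs = get_bigram_counts(toy_vocab)
print(pairs[("e", "s")], pairs[("w", "e")])  # 9 8

# ('e', 's'), ('s', 't') and ('t', '▁') all have count 9; max() returns the
# first maximal key in dict insertion order, which is ('e', 's')
best_pair = max(pairs, key=pairs.get)
print(best_pair)  # ('e', 's')
```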
Merge Operations
The core of BPE is a series of merge operations. In each operation, we find the most frequent bigram and merge it into a single token. We repeat this process for a specified number of iterations.
```python
def merge_vocab(pair, vocab_in):
    # Replace every standalone occurrence of "a b" with "ab" in each word
    vocab_out = {}
    bigram = re.escape(" ".join(pair))
    # (?<!\S) and (?!\S) ensure we match whole symbols, not parts of longer symbols
    p = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
    bytepair = "".join(pair)
    for word in vocab_in:
        w_out = p.sub(bytepair, word)
        vocab_out[w_out] = vocab_in[word]
    return vocab_out, (bigram, bytepair)

def find_merges(vocab, tokens, num_merges, indices_to_print=(0, 1, 2)):
    merges = []
    tokens, vocab_tokenization = get_tokens_from_vocab(vocab)
    for i in range(num_merges):
        pairs = get_bigram_counts(vocab)
        if not pairs:
            # Every word is already a single token; nothing is left to merge
            break
        best_pair = max(pairs, key=pairs.get)
        best_count = pairs[best_pair]
        vocab, (bigram, bytepair) = merge_vocab(best_pair, vocab)
        # Record the merge as (regex pattern, replacement) so it can be replayed later
        merges.append((r"(?<!\S)" + bigram + r"(?!\S)", bytepair))
        tokens, vocab_tokenization = get_tokens_from_vocab(vocab)
        if i in indices_to_print:
            print(f"Merge {i}: {best_pair} with count {best_count}")
            print("All tokens: {}".format(tokens.keys()))
            print("Number of tokens: {}".format(len(tokens.keys())))
    return vocab, tokens, merges, vocab_tokenization

num_merges = 1000
indices_to_print = [0, 1, 2, num_merges - 1]
vocab, tokens, merges, vocab_tokenization = find_merges(
    vocab, tokens, num_merges, indices_to_print
)
```
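The same merge loop can also be written without regular expressions by storing each word as a tuple of symbols. The sketch below is such an alternative (learn_bpe and the toy counts are hypothetical, not part of the lecture's code), run for five merges. It breaks ties the same way as find_merges, because Counter preserves insertion order.

```python
from collections import Counter

def learn_bpe(word_freqs, num_merges):
    # word_freqs: {tuple_of_symbols: frequency}; returns the ordered merges and the final vocab
    vocab = dict(word_freqs)
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, freq in vocab.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[a, b] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # first pair with the highest count
        merges.append(best)
        new_vocab = {}
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                # Merge non-overlapping occurrences left to right, like the regex substitution
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] = freq
        vocab = new_vocab
    return merges, vocab

# Hypothetical toy vocabulary with the ▁ end-of-word marker
toy = {tuple("low") + ("▁",): 5, tuple("lower") + ("▁",): 2,
       tuple("newest") + ("▁",): 6, tuple("widest") + ("▁",): 3}
merges, final_vocab = learn_bpe(toy, 5)
print(merges)
# [('e', 's'), ('es', 't'), ('est', '▁'), ('l', 'o'), ('lo', 'w')]
```

Recording merges as plain symbol pairs, rather than regex patterns, makes them easy to save and replay on new words.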
Encoding and Decoding
Decoding is straightforward. Because every word ends with the end-of-word marker ▁ (the </w> token in many BPE write-ups), we concatenate all the tokens and turn each marker back into a space. For example, if the encoded sequence is [the▁, high, est▁, moun, tain▁], the decoded sequence is "the highest mountain".
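A minimal decoder, assuming the ▁ marker used in this lecture (the function name decode is ours, not part of the lecture's code):

```python
def decode(tokens, space_token="▁"):
    # Concatenate the tokens, turn each end-of-word marker back into a space,
    # and drop the trailing space left by the last word's marker
    return "".join(tokens).replace(space_token, " ").rstrip()

print(decode(["the▁", "high", "est▁", "moun", "tain▁"]))  # the highest mountain
```

Because the text was lowercased and its whitespace normalized before training, decoding recovers that normalized form rather than the original string.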
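Encoding is more involved. For each word, we try the vocabulary tokens from longest to shortest and take the first (that is, longest) token that occurs in the word. Its occurrences split the word into the pieces before, between, and after them, and each piece is encoded recursively with the remaining, shorter tokens. Any piece that no token can cover is replaced with the unknown token </u>. Because every word may be checked against every token in the vocabulary, this process is computationally expensive.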
```python
def measure_token_length(token, space_token="▁"):
    # Length of a token, counting the end-of-word marker as a single character
    space_token_len = len(space_token)
    if token[-space_token_len:] == space_token:
        return len(token) - space_token_len + 1
    return len(token)
```
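With the single-character ▁ marker, measure_token_length returns the same value as len. It only makes a difference for a multi-character marker such as </w>, which would otherwise make every word-final token look three characters longer than it is. A quick check (the example tokens are illustrative):

```python
def measure_token_length(token, space_token="▁"):
    # Count the end-of-word marker as one character, however many characters it has
    if token.endswith(space_token):
        return len(token) - len(space_token) + 1
    return len(token)

print(measure_token_length("est▁"))             # 4
print(measure_token_length("est</w>", "</w>"))  # 4, whereas len("est</w>") is 7
print(measure_token_length("moun", "</w>"))     # 4
```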
```python
def encode_word(string, sorted_tokens, unknown_token="</u>"):
    # Greedily segment `string` using tokens sorted from longest to shortest
    if string == "":
        return []
    if sorted_tokens == []:
        # No tokens left to try: this piece cannot be covered
        return [unknown_token]

    string_tokens = []
    for i in range(len(sorted_tokens)):
        token = sorted_tokens[i]
        # re.escape already escapes ".", so no extra escaping is needed
        token_reg = re.escape(token)
        matched_positions = [
            (m.start(0), m.end(0)) for m in re.finditer(token_reg, string)
        ]
        if len(matched_positions) == 0:
            continue
        substring_end_positions = [
            matched_position[0] for matched_position in matched_positions
        ]
        # Encode the pieces around each match with the remaining (shorter) tokens
        substring_start_position = 0
        for substring_end_position in substring_end_positions:
            substring = string[substring_start_position:substring_end_position]
            string_tokens += encode_word(
                string=substring,
                sorted_tokens=sorted_tokens[i + 1 :],
                unknown_token=unknown_token,
            )
            string_tokens += [token]
            substring_start_position = substring_end_position + len(token)
        remaining_substring = string[substring_start_position:]
        string_tokens += encode_word(
            string=remaining_substring,
            sorted_tokens=sorted_tokens[i + 1 :],
            unknown_token=unknown_token,
        )
        return string_tokens
    # No token occurs anywhere in the string
    return [unknown_token]
```
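To trace the recursion on something small, here is a compact sketch of the same greedy longest-match strategy that uses str.find instead of regular expressions (encode_word_greedy and the token set are hypothetical, chosen for illustration). It also shows the unknown-token fallback for the character z, which no token covers.

```python
def encode_word_greedy(string, sorted_tokens, unknown_token="</u>"):
    # Same greedy idea as encode_word, using str.find instead of regular expressions
    if string == "":
        return []
    for i, token in enumerate(sorted_tokens):
        if token not in string:
            continue
        rest = sorted_tokens[i + 1:]  # tokens after this one in the sorted order
        result, start = [], 0
        pos = string.find(token)
        while pos != -1:
            # Encode the gap before this occurrence, then emit the token itself
            result += encode_word_greedy(string[start:pos], rest, unknown_token)
            result.append(token)
            start = pos + len(token)
            pos = string.find(token, start)
        result += encode_word_greedy(string[start:], rest, unknown_token)
        return result
    # No token occurs in the string (or no tokens are left)
    return [unknown_token]

# Hypothetical token set, sorted longest first as in the lecture
toy_tokens = sorted(["est▁", "low", "lo", "e", "r", "s", "t", "w", "l", "o", "▁"],
                    key=len, reverse=True)
print(encode_word_greedy("lowest▁", toy_tokens))  # ['low', 'est▁']
print(encode_word_greedy("lowez▁", toy_tokens))   # ['low', 'e', '</u>', '▁']
```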
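We can now use encode_word to encode a given word. For a word that appeared in the training data, we also print the segmentation recorded during training for comparison.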
```python
def print_tokenization(word_given, sorted_tokens, vocab_tokenization):
    print("Tokenizing word: {}...".format(word_given))
    if word_given in vocab_tokenization:
        # Known word: compare the segmentation from training with the greedy encoder
        print("Tokenization of the known word:")
        print(vocab_tokenization[word_given])
        print("Tokenization treating the known word as unknown:")
        print(encode_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token="</u>"))
    else:
        print("Tokenization of the unknown word:")
        print(encode_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token="</u>"))

# Sort tokens by length in descending order, counting the end-of-word marker as one character
sorted_tokens = sorted(tokens.keys(), key=measure_token_length, reverse=True)
word_given_known = "investors▁"
print_tokenization(word_given_known, sorted_tokens, vocab_tokenization)
word_given_unknown = "dogecoin▁"
print_tokenization(word_given_unknown, sorted_tokens, vocab_tokenization)
```
Tokenizing word: investors▁...
Tokenization of the known word:
Tokenization treating the known word as unknown:
Tokenizing word: dogecoin▁...
Tokenization of the unknown word:
['do', 'ge', 'co', 'in▁']
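Greedy longest-match search is not the only way to encode. Another common approach is to split a new word into characters and replay the learned merges in the order they were learned. Unlike the greedy search, this reproduces exactly the segmentation a training word had at the end of training, so the two segmentations printed for a known word would then agree. A minimal sketch of that alternative technique, assuming merges stored as plain symbol pairs (apply_merges and the merge list are hypothetical; the merges are the first five one obtains from the toy vocabulary low: 5, lower: 2, newest: 6, widest: 3). It takes a bare word and appends ▁ itself.

```python
def apply_merges(word, merges, space_token="▁"):
    # Start from single characters plus the end-of-word marker
    symbols = list(word) + [space_token]
    for a, b in merges:
        # Apply one merge across the word, left to right
        merged, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

# Hypothetical merge list, in the order it was learned
toy_merges = [("e", "s"), ("es", "t"), ("est", "▁"), ("l", "o"), ("lo", "w")]
print(apply_merges("lowest", toy_merges))  # ['low', 'est▁']
print(apply_merges("slower", toy_merges))  # ['s', 'low', 'e', 'r', '▁']
```

Characters never seen during training simply remain single symbols in this sketch; a fuller implementation would map them to an unknown token.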
Finally, we can use encode_word to tokenize an entire sentence.
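```python
def tokenize(text, sorted_tokens, space_token="▁"):
    # Apply the same normalization as training: lowercase and collapse whitespace
    text = re.sub(r"\s+", " ", text.lower())
    # split() without arguments also drops leading/trailing whitespace
    words = [word + space_token for word in text.split()]
    return [
        token
        for word in words
        for token in encode_word(word, sorted_tokens, unknown_token="</u>")
    ]

tokenized_text = tokenize("Investment opportunities in the company", sorted_tokens)
print(tokenized_text)
```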
['investment▁', 'op', 'port', 'un', 'ities▁', 'in▁', 'the▁', 'company▁']
That’s it! You have now implemented BPE from scratch. This should give you a good understanding of how subword tokenization works in practice.