A friendly introduction to testing machine learning projects, using standard libraries such as Pytest and Pytest-cov
Testing is an important component of software development, but in my experience it is widely neglected in machine learning projects. Many people know they should test their code, but few know how to do it, and even fewer actually do it.
The purpose of this guide is to introduce you to the essentials of testing the different parts of a machine learning pipeline. We will focus on fine-tuning BERT for text classification on the IMDb dataset, using industry-standard libraries such as pytest and pytest-cov for testing.
I strongly recommend that you follow along with the code in this GitHub repository:
Here is a brief overview of the project structure:
bert-text-classification/
├── src/
│ ├── data_loader.py
│ ├── evaluation.py
│ ├── main.py
│ ├── trainer.py
│ └── utils.py
├── tests/
│ ├── conftest.py
│ ├── test_data_loader.py
│ ├── test_evaluation.py
│ ├── test_main.py
│ ├── test_trainer.py
│ └── test_utils.py
├── models/
│ └── imdb_bert_finetuned.pth
├── environment.yml
├── requirements.txt
├── README.md
└── setup.py
A common practice is to split the code into several parts:
- src: contains the main files we use to load the dataset, train, and evaluate the model.
- tests: contains the various Python test scripts. Most of the time, there is one test file per script. I personally use the following convention: if the script you want to test is called XXX.py, then the corresponding test script is called test_XXX.py and lives in the tests folder. For example, to test the evaluation.py file, I use a test_evaluation.py file.

Note: in the tests folder, you may notice a conftest.py file. This file does not test functions per se; it contains configuration for the tests, in particular fixtures, which we will explain a little later.
You can just read this article, but I strongly recommend that you clone the repository and start playing with the code, since we always learn better by being active. To do so, clone the GitHub repository, create the environment, and get the model.
# clone github repo
git clone https://github.com/FrancoisPorcher/awesome-ai-tutorials/tree/main

# enter corresponding folder
cd MLOps/how_to_test/
# create environment
conda env create -f environment.yml
conda activate how_to_test
You will also need a model to run the analyses. To reproduce my results, you can run the main file. Training should take 2 to 20 minutes (depending on whether you have CUDA, MPS, or CPU).
python src/main.py
If you don't want to fine-tune BERT yourself (though I strongly suggest you do), you can take the stock version of BERT and add a classification head for 2 classes with the following command:
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
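If you go down that route, a quick sanity check (not part of the repository, just a sketch using the standard transformers API) is to run a forward pass on a single tokenized sentence:

import torch
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.eval()

# tokenize one review and run it through the model
inputs = tokenizer("What a great movie!", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

print(logits.shape)           # torch.Size([1, 2]) -> one score per class
print(logits.argmax(dim=-1))  # predicted class id (essentially random until fine-tuned)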
Now you are all set!
Let's write some tests. But first, a quick introduction to pytest.
pytest
pytest is the industry standard: a mature testing framework that makes writing tests easy.

Something awesome with pytest is that you can test at different levels of granularity: a single function, a script, or the entire project. Let's learn how to do all three.
What does a test look like?
A test is a function that tests the behavior of another function. The convention is that if you want to test a function called foo, you call your test function test_foo.
We then define a number of tests, to check that the function we are testing is behaving as we want.
Let's use an example to illustrate the ideas:
In the data_loader.py script, we use a very standard function called clean_text, which removes capital letters and whitespace. It is defined as follows:
def clean_text(text: str) -> str:
    """
    Clean the input text by converting it to lowercase and stripping whitespace.

    Args:
        text (str): The text to clean.

    Returns:
        str: The cleaned text.
    """
    return text.lower().strip()
We want to make sure this function behaves as expected, so in the test_data_loader.py file we can write a function called test_clean_text:
from src.data_loader import clean_text


def test_clean_text():
    # test capital letters
    assert clean_text("HeLlo, WoRlD!") == "hello, world!"
    # test spaces removed
    assert clean_text(" Spaces ") == "spaces"
    # test empty string
    assert clean_text("") == ""
Note that we use the assert statement here: if the assertion is True, nothing happens; if it is False, an AssertionError is raised.
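As a quick standalone illustration (not from the repository), this is what the two cases look like in plain Python:

# when the condition is True, nothing happens
assert "HeLlo".lower() == "hello"

# when the condition is False, an AssertionError is raised,
# optionally with a message that pytest will display
assert " spaces ".strip() == "spaces!", "unexpected cleaning result"
# AssertionError: unexpected cleaning result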
Now let's run the test with the following command in your terminal:
pytest tests/test_data_loader.py::test_clean_text
This command means that you are using pytest to run the test_data_loader.py script located in the tests folder, and that you only want to run the single test named test_clean_text.
If the test passes, you should get:
What happens when the test is not passed?
For the sake of this example, let's imagine that I modify the clean_text function to:
def clean_text(text: str) -> str:
    # return text.lower().strip()
    return text.lower()
Now the function doesn't remove spaces and is going to fail the tests. This is what we get when we run the test again:
This time we know why the test failed. Great!
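One small aside before moving on: when a test accumulates many similar assert statements, pytest also offers the @pytest.mark.parametrize decorator to express the same cases as a table of inputs and expected outputs. The repository uses the plain-assert style shown above; this is just a sketch of the alternative:

import pytest

from src.data_loader import clean_text


@pytest.mark.parametrize(
    "raw, expected",
    [
        ("HeLlo, WoRlD!", "hello, world!"),  # capital letters
        ("   Spaces   ", "spaces"),          # surrounding whitespace
        ("", ""),                            # empty string
    ],
)
def test_clean_text_parametrized(raw, expected):
    # each (raw, expected) pair runs as a separate test case
    assert clean_text(raw) == expected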
Why would we even want to test a single function?
Well, testing can take a long time. Even for a small project like this, evaluating on the entire IMDb dataset can already take several minutes, and sometimes we want to test a single behavior without re-running the entire test suite each time.
Now let's move on to the next level of granularity: testing the script.
How to test the entire script?
Now let's make the data_loader.py script a little more complex by adding a tokenize_text function, which takes as input a string or a list of strings and outputs a tokenized version of the input.
# src/data_loader.py
from typing import Dict

import torch
from transformers import BertTokenizer


def clean_text(text: str) -> str:
    """
    Clean the input text by converting it to lowercase and stripping whitespace.

    Args:
        text (str): The text to clean.

    Returns:
        str: The cleaned text.
    """
    return text.lower().strip()


def tokenize_text(
    text: str, tokenizer: BertTokenizer, max_length: int
) -> Dict[str, torch.Tensor]:
    """
    Tokenize a single text using the BERT tokenizer.

    Args:
        text (str): The text to tokenize.
        tokenizer (BertTokenizer): The tokenizer to use.
        max_length (int): The maximum length of the tokenized sequence.

    Returns:
        Dict[str, torch.Tensor]: A dictionary containing the tokenized data.
    """
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
Just so you can better understand what this function does, let's try with an example:
from transformers import BertTokenizer

from src.data_loader import tokenize_text

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
txt = ["Hello, @! World! qwefqwef"]
tokenize_text(txt, tokenizer=tokenizer, max_length=16)
This will produce the following result:
{'input_ids': tensor([[ 101, 7592, 1010, 1030, 999, 2088, 999, 1053, 8545, 2546, 4160, 8545,2546, 102, 0, 0]]),
'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])}
- max_length: the maximum length of a sequence. Here we chose 16, but the sentence only needs 14 tokens, so the last 2 tokens are padding.
- input_ids: each token converted to its corresponding id, i.e. its index in the vocabulary. Note that id 101 is the CLS token and id 102 is the SEP token; these 2 tokens mark the beginning and the end of a sentence (see the small sketch after this list, and read the documentation for more details).
- token_type_ids: not very important here. If you feed 2 sequences as input, the tokens of the second sentence get the value 1.
- attention_mask: tells the model which tokens it should attend to in the self-attention mechanism. Since the sentence is padded, the attention mechanism does not need to attend to the last 2 tokens, hence the 0s.
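If you want to see these special tokens for yourself, you can map the ids back to their string form with the tokenizer's convert_ids_to_tokens method; a small sketch reusing the example above:

encoded = tokenize_text(txt, tokenizer=tokenizer, max_length=16)
tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0].tolist())
print(tokens)
# the list starts with '[CLS]', ends with '[SEP]' followed by two '[PAD]' tokens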
Now let's write our test_tokenize_text function, which will check that the tokenize_text function behaves properly:
import torch
from transformers import BertTokenizer

from src.data_loader import tokenize_text


def test_tokenize_text():
    """
    Test the tokenize_text function to ensure it correctly tokenizes text using BERT tokenizer.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Example input texts
    txt = ["Hello, @! World!", "Spaces "]

    # Tokenize the text
    max_length = 128
    res = tokenize_text(text=txt, tokenizer=tokenizer, max_length=max_length)

    # let's test that the output is a dictionary and that the keys are correct
    assert all(key in res for key in ["input_ids", "token_type_ids", "attention_mask"]), "Missing keys in the output dictionary."

    # let's check the dimensions of the output tensors
    assert res["input_ids"].shape[0] == len(txt), "Incorrect number of input_ids."
    assert res["input_ids"].shape[1] == max_length, "Incorrect number of tokens."

    # let's check that all the associated tensors are pytorch tensors
    assert all(isinstance(res[key], torch.Tensor) for key in res), "Not all values are PyTorch tensors."
Now let's run the full test_data_loader.py file, which contains 2 test functions:
- test_tokenize_text
- test_clean_text
You can run the full test using this command from the terminal.
pytest tests/test_data_loader.py
And you should get this result:
Congratulations! Now you know how to test an entire script. Let's move on to the final level: testing the entire codebase.
How to test the entire codebase?
Continuing with the same logic, we can write other tests for each script, and your structure should be similar:
├── tests/
│ ├── conftest.py
│ ├── test_data_loader.py
│ ├── test_evaluation.py
│ ├── test_main.py
│ ├── test_trainer.py
│ └── test_utils.py
Now notice that in all these test functions, some variables keep coming back. For example, we use the same tokenizer in every test script. pytest has a nice way of dealing with this: fixtures.
A fixture is a way to set up some context or condition before running a test and to clean up afterwards. Fixtures provide a mechanism for managing test dependencies and injecting reusable code into tests.
Fixtures are defined using the @pytest.fixture decorator.
A good example is a tokenizer fixture that we can reuse everywhere. Let's add it to the conftest.py file located in the tests folder:
import pytest
from transformers import BertTokenizer


@pytest.fixture()
def bert_tokenizer():
    """Fixture to initialize the BERT tokenizer."""
    return BertTokenizer.from_pretrained("bert-base-uncased")
And now, in the test_data_loader.py file, we can use the bert_tokenizer fixture simply by adding it as an argument of test_tokenize_text:
def test_tokenize_text(bert_tokenizer):
    """
    Test the tokenize_text function to ensure it correctly tokenizes text using BERT tokenizer.
    """
    tokenizer = bert_tokenizer

    # Example input texts
    txt = ["Hello, @! World!", "Spaces "]

    # Tokenize the text
    max_length = 128
    res = tokenize_text(text=txt, tokenizer=tokenizer, max_length=max_length)

    # let's test that the output is a dictionary and that the keys are correct
    assert all(key in res for key in ["input_ids", "token_type_ids", "attention_mask"]), "Missing keys in the output dictionary."

    # let's check the dimensions of the output tensors
    assert res["input_ids"].shape[0] == len(txt), "Incorrect number of input_ids."
    assert res["input_ids"].shape[1] == max_length, "Incorrect number of tokens."

    # let's check that all the associated tensors are pytorch tensors
    assert all(isinstance(res[key], torch.Tensor) for key in res), "Not all values are PyTorch tensors."
Fixtures are a very powerful and versatile tool. If you want to learn more about them, the official documentation is your go-to resource. But at least now you have the tools to cover most ML testing needs.
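One detail worth knowing from the pytest documentation is the fixture scope: by default a fixture is re-created for every test that requests it, so the BERT tokenizer above is reloaded each time. If that becomes slow, you can make the fixture session-scoped so it is built once for the whole test run; a minimal sketch:

import pytest
from transformers import BertTokenizer


@pytest.fixture(scope="session")
def bert_tokenizer():
    """Load the BERT tokenizer once and share it across all tests."""
    return BertTokenizer.from_pretrained("bert-base-uncased")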
Let's run all the tests from the terminal with the following command:
pytest tests
And you should get the following message:
Congratulations!
In the previous sections we learned how to test code. In large projects, it is also important to measure the coverage of your tests, in other words, how much of your code is actually tested.
pytest-cov is a plugin for pytest that generates test coverage reports.
That being said, don't be fooled by the coverage percentage. Just because you have 100% coverage doesn't mean your code is bug-free. It's just a tool for you to identify which parts of your code need further testing.
You can run the following command to generate a coverage report from the terminal.
pytest --cov=src --cov-report=html tests/
And you should get this:
Let's see how to read it:
- Statements: the total number of executable statements in the code. It counts all lines of code that can be executed, including conditionals, loops, and function calls.
- Missing: the number of statements that were not executed during the test run. These are the lines of code not covered by any test.
- Coverage: the percentage of statements executed during the tests, calculated by dividing the number of executed statements by the total number of statements.
- Excluded: lines of code explicitly excluded from coverage measurement. This is useful for ignoring code that is not relevant to test coverage, such as debugging statements (see the short sketch after this list).
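For that last point, coverage.py (which pytest-cov uses under the hood) recognizes a # pragma: no cover comment to exclude a line or a whole block. Here is a small sketch with a hypothetical debug helper, not taken from the repository:

def load_dataset(debug: bool = False):
    if debug:  # pragma: no cover
        # this debugging branch is excluded from the coverage report
        print("loading a tiny debug subset of IMDb")
        return []
    return ["full dataset goes here"]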
We can see that the coverage of the main.py file is 0%. That is normal, since we did not write a test_main.py file.

We can also see that only 19% of the evaluation code is tested, which gives us an idea of where we should focus first.
Congratulations, you've done it!
Thanks for reading! Before you go:
For more great tutorials, check out my compilation of AI tutorials on Github
Get my articles straight in your inbox. Subscribe here.
If you want access to premium articles on Medium, all you need is a subscription for just $5 a month. If you sign up with my link, you support me with a portion of your fee at no additional cost to you.