How should you test your machine learning project? A Beginner's Guide | By François Porcher | July, 2024

A friendly introduction to testing machine learning projects, using standard libraries such as Pytest and Pytest-cov

10 min read

Code testing, photo by the author

Testing is an important component of software development, but in my experience, it is widely neglected in machine learning projects. Many people know they should test their code, but fewer know how to do it, and fewer still actually do it.

The purpose of this guide is to introduce you to the essentials of testing the different parts of a machine learning pipeline. We will focus on fine-tuning BERT for text classification on the IMDb dataset, and we will use industry-standard libraries such as pytest and pytest-cov for testing.

I strongly recommend following along with the code in this GitHub repository:

Here is a brief overview of the project.

bert-text-classification/
├── src/
│ ├── data_loader.py
│ ├── evaluation.py
│ ├── main.py
│ ├── trainer.py
│ └── utils.py
├── tests/
│ ├── conftest.py
│ ├── test_data_loader.py
│ ├── test_evaluation.py
│ ├── test_main.py
│ ├── test_trainer.py
│ └── test_utils.py
├── models/
│ └── imdb_bert_finetuned.pth
├── environment.yml
├── requirements.txt
├── README.md
└── setup.py

A common practice is to split the code into several parts:

  • src: Contains the main files we use to load datasets, train and evaluate models.
  • tests: Contains the test scripts. Most of the time, there is one test file per script. I personally use the following convention: if the script you want to test is called XXX.py, then the corresponding test script is called test_XXX.py and is located in the tests folder.

For example, if you want to test the evaluation.py file, I use the test_evaluation.py file.
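Because the XXX.py to tests/test_XXX.py convention is purely mechanical, it is easy to check automatically. Below is a minimal sketch, assuming the src/ and tests/ layout shown above; expected_test_file and untested_scripts are hypothetical helpers for illustration, not part of pytest.

```python
from pathlib import Path
from typing import List


def expected_test_file(src_file: str, tests_dir: str = "tests") -> str:
    """Map a source script such as src/evaluation.py to tests/test_evaluation.py."""
    script_name = Path(src_file).name  # e.g. "evaluation.py"
    return (Path(tests_dir) / f"test_{script_name}").as_posix()


def untested_scripts(src_files: List[str], test_files: List[str]) -> List[str]:
    """Return the source scripts that have no test file following the convention."""
    existing = {Path(t).as_posix() for t in test_files}
    return [s for s in src_files if expected_test_file(s) not in existing]


print(expected_test_file("src/evaluation.py"))  # tests/test_evaluation.py
print(untested_scripts(["src/evaluation.py", "src/main.py"], ["tests/test_evaluation.py"]))  # ['src/main.py']
```

Note that neither helper name starts with "test", so pytest would not mistake them for tests if they ended up in the tests folder.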

Note: In the tests folder, you may notice a conftest.py file. This file does not test any function itself; instead, it contains configuration shared by the tests, in particular fixtures, which we will explain a little later.

You can just read this article, but I strongly recommend cloning the repository and playing with the code, because we always learn better by being active. To do this, you need to clone the GitHub repository, create the environment, and get the model.

# clone github repo
git clone https://github.com/FrancoisPorcher/awesome-ai-tutorials.git

# enter corresponding folder
cd awesome-ai-tutorials/MLOps/how_to_test/

# create environment
conda env create -f environment.yml
conda activate how_to_test

You will also need a model to run the analyses. To reproduce my results, you can run the main file. Training should take 2 to 20 minutes (depending on whether you have CUDA, MPS, or CPU).

python src/main.py

If you don't want to fine-tune BERT (though I strongly suggest you fine-tune it yourself), you can take the stock version of BERT and add a linear layer that outputs 2 classes with the following command:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
"bert-base-uncased", num_labels=2
)

Now you are all set!

Let's write some tests:

But first, a quick introduction to Pytest.

pytest is an industry-standard, mature testing framework that makes writing tests easy.

One thing that is awesome about pytest is that you can test at different levels of granularity: a single function, a script, or the entire project. Let's learn how to do all 3.

What does a test look like?

A test is a function that tests the behavior of another function. The convention is that if you want to test the function foo, you call your test function test_foo.

We then define a number of tests, to check that the function we are testing is behaving as we want.
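To make the naming convention concrete, here is a deliberately tiny, hypothetical runner that mimics the core idea of pytest's collection: call every function whose name starts with test_ and record whether it raised an AssertionError. Real pytest does much more (it also discovers files named test_*.py or *_test.py, handles fixtures, and rewrites assert statements to produce detailed reports), so mini_run_tests is an illustration only, not how pytest is implemented.

```python
from typing import Callable, Dict, List, Tuple


def mini_run_tests(namespace: Dict[str, Callable[..., object]]) -> Tuple[List[str], List[str]]:
    """Hypothetical mini-runner: call each function named test_*, report passes and failures."""
    passed: List[str] = []
    failed: List[str] = []
    for name in sorted(namespace):
        if not name.startswith("test_"):
            continue  # helpers such as foo are not collected
        try:
            namespace[name]()
            passed.append(name)
        except AssertionError:
            failed.append(name)
    return passed, failed


def example_module() -> Dict[str, Callable[..., object]]:
    """Stand-in for the contents of a file like tests/test_foo.py."""

    def foo(x: int) -> int:
        return 2 * x

    def test_foo() -> None:
        assert foo(2) == 4

    def test_foo_negative() -> None:
        assert foo(-1) == 2  # wrong on purpose: foo(-1) returns -2

    return {"foo": foo, "test_foo": test_foo, "test_foo_negative": test_foo_negative}


passed, failed = mini_run_tests(example_module())
print(passed, failed)  # ['test_foo'] ['test_foo_negative']
```

The example tests live inside example_module so that the intentionally failing one is never picked up by a real pytest run.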

Let's use an example to illustrate the ideas:

In the data_loader.py script, we are using a very standard function called clean_text, which lowercases the text and removes leading and trailing whitespace. It is defined as follows:

def clean_text(text: str) -> str:
    """
    Clean the input text by converting it to lowercase and stripping whitespace.

    Args:
        text (str): The text to clean.

    Returns:
        str: The cleaned text.
    """
    return text.lower().strip()

We want to make sure this function behaves well, so in the test_data_loader.py file we can write a test_clean_text function:

from src.data_loader import clean_text

def test_clean_text():
    # test capital letters
    assert clean_text("HeLlo, WoRlD!") == "hello, world!"
    # test spaces removed
    assert clean_text(" Spaces ") == "spaces"
    # test empty string
    assert clean_text("") == ""

Note that we use the assert statement: if the asserted condition is True, nothing happens; if it is False, an AssertionError is raised.
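A quick way to see this behavior outside of pytest is to trigger an assertion by hand. The sketch below uses clean_text plus a hypothetical buggy_clean_text that forgets to strip whitespace: the passing assert does nothing, while the failing one raises AssertionError carrying the optional message written after the comma.

```python
def clean_text(text: str) -> str:
    """Same as src/data_loader.py: lowercase, then strip surrounding whitespace."""
    return text.lower().strip()


def buggy_clean_text(text: str) -> str:
    """Hypothetical broken variant that forgets to strip whitespace."""
    return text.lower()


# Passing assertion: the condition is True, so nothing happens
assert clean_text(" Spaces ") == "spaces"

# Failing assertion: the condition is False, so AssertionError is raised
# with the optional message written after the comma
error_message = None
try:
    assert buggy_clean_text(" Spaces ") == "spaces", "whitespace was not stripped"
except AssertionError as err:
    error_message = str(err)

print(error_message)  # whitespace was not stripped
```

pytest catches that AssertionError for you and reports the message together with the values being compared.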

Now let's run the test with the following command in your terminal:

pytest tests/test_data_loader.py::test_clean_text

This command tells pytest to run the test_data_loader.py script located in the tests folder, and to run only the test_clean_text test inside it.

If the test passes, you should get:

Passing pytest test, image by author

What happens when a test fails?

For the sake of this example, let's imagine that I edit the clean_text function as follows:

def clean_text(text: str) -> str:
    # return text.lower().strip()
    return text.lower()

Now the function no longer strips whitespace, so the test fails. This is what we get when we run the test again:

Example of a failed test, image by author

This time, the output tells us exactly why the test failed. Great!

Why would we want to test just a single function?

Well, testing can take a long time. Even for a small project like this, going through the entire IMDb dataset can already take several minutes. Sometimes we just want to test one behavior without rerunning the tests for the entire codebase each time.

Now let's move on to the next level of granularity: testing the script.

How to test an entire script?

Now let's make the data_loader.py script a bit more complex by adding a tokenize_text function, which takes as input a string or a list of strings and outputs a tokenized version of the input.

# src/data_loader.py
from typing import Dict, List, Union

import torch
from transformers import BertTokenizer

def clean_text(text: str) -> str:
    """
    Clean the input text by converting it to lowercase and stripping whitespace.

    Args:
        text (str): The text to clean.

    Returns:
        str: The cleaned text.
    """
    return text.lower().strip()

def tokenize_text(
    text: Union[str, List[str]], tokenizer: BertTokenizer, max_length: int
) -> Dict[str, torch.Tensor]:
    """
    Tokenize a text (or a list of texts) using the BERT tokenizer.

    Args:
        text (str or List[str]): The text(s) to tokenize.
        tokenizer (BertTokenizer): The tokenizer to use.
        max_length (int): The maximum length of the tokenized sequence.

    Returns:
        Dict[str, torch.Tensor]: A dictionary containing the tokenized data.
    """
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

Just so you can better understand what this function does, let's try with an example:

from transformers import BertTokenizer
from src.data_loader import tokenize_text

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
txt = ["Hello, @! World! qwefqwef"]
tokenize_text(txt, tokenizer=tokenizer, max_length=16)

This will produce the following result:

{'input_ids': tensor([[ 101, 7592, 1010, 1030,  999, 2088,  999, 1053, 8545, 2546, 4160, 8545,2546,  102,    0,    0]]), 
'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])}
  • max_length: the maximum length of a sequence. In this case we chose 16, but the sequence only has 14 tokens, so the last 2 tokens are padding.
  • input_ids: each token is converted to its corresponding id in the tokenizer's vocabulary. Note: token id 101 is the [CLS] token and token id 102 is the [SEP] token. These 2 tokens mark the beginning and end of a sentence. Read the documentation carefully for more details.
  • token_type_ids: not very important here. If you feed 2 sequences as input, the tokens of the second sentence will have the value 1.
  • attention_mask: this tells the model which tokens to attend to in the self-attention mechanism. Since the sentence is padded, the attention mechanism does not need to attend to the last 2 tokens, hence the 0s.
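To make the relationship between these four fields concrete, here is a pure-Python sketch that assembles them by hand, assuming the word-piece ids are already known and using the special ids from the output above (101 for [CLS], 102 for [SEP], 0 for padding). build_bert_inputs is a hypothetical helper for illustration: it is not the transformers implementation, and it does not model truncation.

```python
from typing import Dict, List, Optional

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special ids from the example output above


def build_bert_inputs(
    ids_a: List[int], max_length: int, ids_b: Optional[List[int]] = None
) -> Dict[str, List[int]]:
    """Assemble [CLS] A [SEP] (B [SEP]) and pad to max_length. Truncation is not modelled."""
    input_ids = [CLS_ID] + ids_a + [SEP_ID]
    token_type_ids = [0] * len(input_ids)  # first segment, including [CLS] and its [SEP]
    if ids_b is not None:
        input_ids += ids_b + [SEP_ID]
        token_type_ids += [1] * (len(ids_b) + 1)  # second segment is marked with 1s
    if len(input_ids) > max_length:
        raise ValueError("sequence is longer than max_length; truncation is not handled here")
    n_pad = max_length - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * n_pad  # 1 = real token, 0 = padding
    input_ids += [PAD_ID] * n_pad
    token_type_ids += [0] * n_pad
    return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}


# Word-piece ids of "Hello, @! World! qwefqwef" taken from the output above (without [CLS]/[SEP])
pieces = [7592, 1010, 1030, 999, 2088, 999, 1053, 8545, 2546, 4160, 8545, 2546]
single = build_bert_inputs(pieces, max_length=16)
print(single["attention_mask"])  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]

pair = build_bert_inputs([5, 6], max_length=8, ids_b=[7])  # arbitrary toy ids for two short sequences
print(pair["token_type_ids"])  # [0, 0, 0, 0, 1, 1, 0, 0]
```

The single-sentence result reproduces the input_ids, token_type_ids and attention_mask printed above, which is a handy sanity check when writing assertions about tokenizer output.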

Now we write our test_tokenize_text function, which checks that the tokenize_text function behaves appropriately:

import torch
from transformers import BertTokenizer

from src.data_loader import tokenize_text

def test_tokenize_text():
    """
    Test the tokenize_text function to ensure it correctly tokenizes text using BERT tokenizer.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Example input texts
    txt = ["Hello, @! World!",
           "Spaces "]

    # Tokenize the text
    max_length = 128
    res = tokenize_text(text=txt, tokenizer=tokenizer, max_length=max_length)

    # let's test that the output contains the expected keys
    assert all(key in res for key in ["input_ids", "token_type_ids", "attention_mask"]), "Missing keys in the output dictionary."

    # let's check the dimensions of the output tensors
    assert res["input_ids"].shape[0] == len(txt), "Incorrect number of input_ids."
    assert res["input_ids"].shape[1] == max_length, "Incorrect number of tokens."

    # let's check that all the associated tensors are pytorch tensors
    assert all(isinstance(res[key], torch.Tensor) for key in res), "Not all values are PyTorch tensors."
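Because this test loads a real tokenizer, it can be slow and needs network access the first time the tokenizer is downloaded. A complementary technique, shown here as a sketch rather than the author's approach, is to pass a test double: a hypothetical FakeTokenizer that only records how it was called. This checks the wiring of tokenize_text (the padding and truncation options it forwards) without loading BERT, but it says nothing about whether the real tokenization is correct, so it complements the test above rather than replacing it. The function body is copied from src/data_loader.py without the type hints so the snippet runs on its own.

```python
from typing import Any, Dict, List


def tokenize_text(text, tokenizer, max_length):
    # Copy of tokenize_text from src/data_loader.py (type hints dropped so no torch/transformers import is needed)
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )


class FakeTokenizer:
    """Hypothetical test double: records each call instead of tokenizing anything."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def __call__(self, text: Any, **kwargs: Any) -> Dict[str, List[List[int]]]:
        self.calls.append({"text": text, **kwargs})
        batch_size = 1 if isinstance(text, str) else len(text)
        # Return a dummy batch shaped like (batch_size, max_length)
        return {"input_ids": [[0] * kwargs["max_length"] for _ in range(batch_size)]}


def test_tokenize_text_forwards_options() -> None:
    fake = FakeTokenizer()
    txt = ["Hello, @! World!", "Spaces "]
    res = tokenize_text(txt, tokenizer=fake, max_length=128)

    assert len(fake.calls) == 1, "The tokenizer should be called exactly once."
    call = fake.calls[0]
    assert call["text"] == txt
    assert call["padding"] == "max_length"
    assert call["truncation"] is True
    assert call["max_length"] == 128
    assert call["return_tensors"] == "pt"
    assert len(res["input_ids"]) == len(txt)
```

pytest collects test_tokenize_text_forwards_options like any other test; since it never touches BERT, it runs in milliseconds.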

Now let's run all the tests in the test_data_loader.py file, which now contains 2 functions:

  • test_tokenize_text
  • test_clean_text

You can run them all with this command from the terminal:

pytest tests/test_data_loader.py

And you should get this result:

Successful tests for the test_data_loader.py script, image by author

Congratulations! Now you know how to test an entire script. Let's move on to the final level: testing the entire codebase.

How to test the entire codebase?

Continuing with the same logic, we can write tests for each script, and your structure should look like this:

├── tests/
│ ├── conftest.py
│ ├── test_data_loader.py
│ ├── test_evaluation.py
│ ├── test_main.py
│ ├── test_trainer.py
│ └── test_utils.py

Notice that some variables are the same across all these test functions. For example, we use the same tokenizer in every script. pytest has a good way to deal with this: fixtures.

A fixture is a way to set up some context or condition before running a test and to clean up afterwards. Fixtures provide a mechanism for managing test dependencies and injecting reusable code into tests.

Fixtures are defined using the @pytest.fixture decorator.
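A fixture can simply return a value, which covers the setup half. For the cleanup half, pytest also supports fixtures written as generators: the code before yield is the setup, the yielded value is injected into the test, and the code after yield runs as teardown once the test finishes. The sketch below illustrates that flow in plain Python with a hypothetical tmp_text_file fixture; run_with_fixture is a hypothetical, heavily simplified stand-in for what pytest does, so in a real project you would decorate tmp_text_file with @pytest.fixture and let pytest drive it. (pytest's built-in tmp_path fixture already handles temporary files; the file here only makes the teardown visible.)

```python
import os
import tempfile
from typing import Callable, Iterator


def tmp_text_file() -> Iterator[str]:
    """Hypothetical fixture; in a real conftest.py it would carry @pytest.fixture."""
    # Setup: write a small text file for the test to read
    fd, path = tempfile.mkstemp(suffix=".txt")
    with os.fdopen(fd, "w") as f:
        f.write("  HeLlo, WoRlD!  ")
    yield path  # the value handed to the test
    # Teardown: runs after the test, even if the test failed
    os.remove(path)


def run_with_fixture(fixture: Callable[[], Iterator[str]], test: Callable[[str], None]) -> None:
    """Simplified, hypothetical stand-in for how pytest drives a yield fixture."""
    gen = fixture()
    value = next(gen)  # run the setup code up to `yield`
    try:
        test(value)
    finally:
        next(gen, None)  # resume after `yield` so the teardown code runs


seen_paths = []


def check_reads_file(path: str) -> None:
    seen_paths.append(path)
    with open(path) as f:
        assert f.read().lower().strip() == "hello, world!"


run_with_fixture(tmp_text_file, check_reads_file)
print(os.path.exists(seen_paths[0]))  # False: the teardown removed the file
```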

The tokenizer is a good candidate for a fixture. Let's add it to the conftest.py file located in the tests folder:

import pytest
from transformers import BertTokenizer

@pytest.fixture()
def bert_tokenizer():
    """Fixture to initialize the BERT tokenizer."""
    return BertTokenizer.from_pretrained("bert-base-uncased")

Now, in the test_data_loader.py file, we can use the bert_tokenizer fixture by passing it as an argument of test_tokenize_text:

import torch

from src.data_loader import tokenize_text

def test_tokenize_text(bert_tokenizer):
    """
    Test the tokenize_text function to ensure it correctly tokenizes text using BERT tokenizer.
    """
    # bert_tokenizer is injected by pytest from conftest.py, no import needed
    tokenizer = bert_tokenizer

    # Example input texts
    txt = ["Hello, @! World!",
           "Spaces "]

    # Tokenize the text
    max_length = 128
    res = tokenize_text(text=txt, tokenizer=tokenizer, max_length=max_length)

    # let's test that the output contains the expected keys
    assert all(key in res for key in ["input_ids", "token_type_ids", "attention_mask"]), "Missing keys in the output dictionary."

    # let's check the dimensions of the output tensors
    assert res["input_ids"].shape[0] == len(txt), "Incorrect number of input_ids."
    assert res["input_ids"].shape[1] == max_length, "Incorrect number of tokens."

    # let's check that all the associated tensors are pytorch tensors
    assert all(isinstance(res[key], torch.Tensor) for key in res), "Not all values are PyTorch tensors."

Fixtures are a very powerful and versatile tool. If you want to know more about them, the official documentation is your go-to resource. But for now, you have the tools to cover most ML testing.

Let's test the entire codebase from the terminal with the following command:

pytest tests

And you should get the following message:

Testing the entire codebase with Pytest, image by author

Congratulations!

In the previous sections, we learned how to test code. In large projects, it is also important to measure the coverage of your tests, in other words, how much of your code is actually exercised by them.

pytest-cov is a pytest plugin that generates test coverage reports.
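To get an intuition for what a statement-coverage tool records, here is a toy line tracer built on Python's sys.settrace. It is only an illustration under simplifying assumptions (a single function, "line" events only); coverage.py, which pytest-cov builds on, is far more complete. lines_executed and sign are hypothetical helpers.

```python
import sys
from types import FrameType
from typing import Any, Callable, Optional, Set


def lines_executed(func: Callable[..., Any], *args: Any) -> Set[int]:
    """Return the body lines of `func` (relative to its `def` line) executed for the given args."""
    code = func.__code__
    hit: Set[int] = set()

    def tracer(frame: FrameType, event: str, arg: Any) -> Optional[Callable]:
        # Record line events inside `func` only, ignoring the `def` line itself
        if event == "line" and frame.f_code is code and frame.f_lineno > code.co_firstlineno:
            hit.add(frame.f_lineno - code.co_firstlineno)
        return tracer  # keep tracing inside newly entered frames

    sys.settrace(tracer)
    try:
        func(*args)
    finally:
        sys.settrace(None)
    return hit


def sign(x: int) -> str:
    if x > 0:  # relative line 1
        return "positive"  # relative line 2
    return "non-positive"  # relative line 3


print(sorted(lines_executed(sign, 5)))  # [1, 2]: line 3 is "missing"
print(sorted(lines_executed(sign, 5) | lines_executed(sign, -1)))  # [1, 2, 3]: the whole body
```

With only a positive input, one of the three statements is never run; adding a second test input closes that gap, which is exactly what the Missing column in the report below helps you spot.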

That being said, don't be fooled by the coverage percentage. Just because you have 100% coverage doesn't mean your code is bug-free. It's just a tool for you to identify which parts of your code need further testing.
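A tiny illustration of that warning, using a hypothetical accuracy helper: the single test below executes every line of the function, so a statement-coverage report would show 100%, yet the function still crashes on an empty evaluation set.

```python
def accuracy(n_correct: int, n_total: int) -> float:
    """Hypothetical helper: fraction of correct predictions."""
    return n_correct / n_total


def test_accuracy() -> None:
    # Executes every statement of accuracy(), i.e. 100% statement coverage
    assert accuracy(8, 10) == 0.8


test_accuracy()  # passes

# ...but an edge case the test never tried still breaks the function
try:
    accuracy(0, 0)  # e.g. an empty evaluation set
    crashed = False
except ZeroDivisionError:
    crashed = True

print(crashed)  # True
```

Only a test for the empty case (and a decision about what accuracy should return there) would catch this; the coverage number alone cannot.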

You can run the following command to generate a coverage report from the terminal.

pytest --cov=src --cov-report=html tests/

This writes an HTML report to the htmlcov/ folder. Open htmlcov/index.html in your browser, and you should get this:

Coverage with pytest-cov, image by author

Let's see how to read it:

  1. Statements: the total number of executable statements in the code. It counts all lines of code that can be executed, including conditionals, loops, and function calls.
  2. Missing: the number of statements that were not executed during the test run. These are lines of code that none of the tests reached.
  3. Coverage: the percentage of statements executed during the test run. It is calculated by dividing the number of executed statements by the total number of statements.
  4. Excluded: lines of code that are explicitly excluded from coverage measurement. This is useful for ignoring code that is not relevant to test coverage, such as debugging statements.
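The Coverage column can be reproduced from the other two numbers. A minimal sketch, assuming plain statement coverage (no branch coverage), where excluded lines are already left out of the statement count; coverage_percent is a hypothetical helper and the numbers are made up for illustration.

```python
def coverage_percent(statements: int, missing: int) -> float:
    """Statement coverage: executed statements divided by total statements, as a percentage."""
    if statements == 0:
        return 100.0  # assumption: a file with nothing to execute counts as fully covered
    executed = statements - missing
    return 100.0 * executed / statements


# Made-up example: 40 statements, 10 of which were never run by the tests
print(coverage_percent(40, 10))  # 75.0
print(round(coverage_percent(7, 2)))  # 71 (5 of 7 statements executed)
```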

We can see that the coverage for the main.py file is 0%. This is expected, since we did not write a test_main.py file.

We can also see that only 19% of the evaluation code is tested, which gives us an idea of where we should focus first.

Congratulations, you've done it!

Thanks for reading! Before you go:

For more great tutorials, check out my compilation of AI tutorials on GitHub.

Want my articles in your inbox? Subscribe here.

If you want to access premium articles on Medium, all you need is a membership for just $5 a month. If you sign up with my link, you support me with a portion of your fee at no additional cost.
