Unit tests & Testing Practices
Last updated on 2024-12-19
Overview
Questions
- What to do about complex functions & tests?
- What are some best practices for testing?
- How far should I go with testing?
- How do I add tests to an existing project?
Objectives
- Be able to write effective unit tests for more complex functions
- Understand the AAA pattern for structuring tests
- Understand the benefits of test driven development
- Know how to handle randomness in tests
But what about complicated functions?
Some of the functions that you write will be more complex, resulting in tests that are very complex and hard to debug if they fail. Take this function as an example:
PYTHON
def process_data(data: list, maximum_value: float):
    # Remove negative values
    data_negative_removed = []
    for i in range(len(data)):
        if data[i] >= 0:
            data_negative_removed.append(data[i])
    # Remove values above the maximum value
    data_maximum_removed = []
    for i in range(len(data_negative_removed)):
        if data_negative_removed[i] <= maximum_value:
            data_maximum_removed.append(data_negative_removed[i])
    # Calculate the mean
    mean = sum(data_maximum_removed) / len(data_maximum_removed)
    # Calculate the standard deviation
    variance = sum([(x - mean) ** 2 for x in data_maximum_removed]) / len(data_maximum_removed)
    std_dev = variance ** 0.5
    return mean, std_dev
A test for this function might look like this:
PYTHON
def test_process_data():
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    maximum_value = 5
    mean, std_dev = process_data(data, maximum_value)
    assert mean == 3
    # Population standard deviation of [1, 2, 3, 4, 5] is sqrt(2)
    assert std_dev == 1.4142135623730951
This test is hard to debug if it fails. Imagine that the calculation of the mean broke: the test would fail, but it would not tell us which part of the function was broken, so we would have to check each step manually to find the bug. Not very efficient!
Unit Testing
Unit testing is a fundamental part of software development. It means testing the individual units or components of a piece of software separately, instead of testing many things at once. For example, if you were adding tests to a car, you would want to test the wheels, the engine, the brakes, etc. separately to make sure they all work as expected before testing that the car could drive to the shops. The goal of unit testing is to validate that each unit of the software performs as designed. A unit is the smallest testable part of your code, and a unit test usually has one or a few inputs and a single output.
The above function could usefully be broken down into smaller functions, each of which could be tested separately. This would make the tests easier to write and maintain.
PYTHON
def remove_negative_values(data: list):
    data_negatives_removed = []
    for i in range(len(data)):
        if data[i] >= 0:
            data_negatives_removed.append(data[i])
    # Return the filtered list, not the original input
    return data_negatives_removed

def remove_values_above_maximum(data: list, maximum_value: float):
    data_maximum_removed = []
    for i in range(len(data)):
        if data[i] <= maximum_value:
            data_maximum_removed.append(data[i])
    # Return the filtered list, not the original input
    return data_maximum_removed

def calculate_mean(data: list):
    return sum(data) / len(data)

def calculate_std_dev(data: list):
    mean = calculate_mean(data)
    variance = sum([(x - mean) ** 2 for x in data]) / len(data)
    return variance ** 0.5

def process_data(data: list, maximum_value: float):
    # Remove negative values
    data = remove_negative_values(data)
    # Remove values above the maximum value
    data = remove_values_above_maximum(data, maximum_value)
    # Calculate the mean
    mean = calculate_mean(data)
    # Calculate the standard deviation
    std_dev = calculate_std_dev(data)
    return mean, std_dev
Now we can write tests for each of these functions separately:
PYTHON
def test_remove_negative_values():
    data = [1, -2, 3, -4, 5, -6, 7, -8, 9, -10]
    assert remove_negative_values(data) == [1, 3, 5, 7, 9]

def test_remove_values_above_maximum():
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    maximum_value = 5
    assert remove_values_above_maximum(data, maximum_value) == [1, 2, 3, 4, 5]

def test_calculate_mean():
    data = [1, 2, 3, 4, 5]
    assert calculate_mean(data) == 3

def test_calculate_std_dev():
    data = [1, 2, 3, 4, 5]
    # Population standard deviation of [1, 2, 3, 4, 5] is sqrt(2)
    assert calculate_std_dev(data) == 1.4142135623730951

def test_process_data():
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    maximum_value = 5
    mean, std_dev = process_data(data, maximum_value)
    assert mean == 3
    assert std_dev == 1.4142135623730951
These tests are much easier to read and understand, and if one of them fails, it is much easier to see which part of the function is broken. This is the principle of unit testing: breaking down complex functions into smaller, testable units.
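The standard deviation tests above compare floating-point numbers with ==, which passes only if the result matches the expected value exactly. A common alternative is to compare within a tolerance. The sketch below uses the standard library's math.isclose. The rel_tol value is an illustrative choice, not a rule, and the function is redefined here so the example runs on its own.

```python
import math


def calculate_std_dev(data: list) -> float:
    # Population standard deviation: divide the squared deviations by len(data)
    mean = sum(data) / len(data)
    variance = sum((x - mean) ** 2 for x in data) / len(data)
    return variance ** 0.5


def test_calculate_std_dev_with_tolerance():
    data = [1, 2, 3, 4, 5]
    result = calculate_std_dev(data)
    # The population standard deviation of 1..5 is sqrt(2); compare with a tolerance
    assert math.isclose(result, math.sqrt(2), rel_tol=1e-9)
```

With a tolerance, the test still passes if a harmless change, such as summing the values in a different order, alters the last few bits of the result.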
AAA pattern
When writing tests, it is a good idea to follow the AAA pattern:
- Arrange: Set up the data and the conditions for the test
- Act: Perform the action that you are testing
- Assert: Check that the result of the action is what you expect
It is a standard pattern in unit testing and is used in many testing frameworks. This makes your tests easier to read and understand for both yourself and others reading your code.
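Each of the three steps corresponds to a part of the test body. The minimal sketch below reuses the idea of the remove_values_above_maximum function from earlier. The function is redefined here as a list comprehension so that the example is self-contained.

```python
def remove_values_above_maximum(data: list, maximum_value: float) -> list:
    # Keep only the values that are less than or equal to maximum_value
    return [value for value in data if value <= maximum_value]


def test_remove_values_above_maximum():
    # Arrange: build the input data and the threshold
    data = [3, 8, 5, 12, 1]
    maximum_value = 5

    # Act: call the function under test once
    result = remove_values_above_maximum(data, maximum_value)

    # Assert: only values <= 5 remain, in their original order
    assert result == [3, 5, 1]
```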
Test Driven Development (TDD)
Test Driven Development (TDD) is a software development process that focuses on writing tests before writing the code. This can have several benefits:
- It forces you to think about the requirements of the code before you write it, which is especially useful in research.
- It can help you to write cleaner, more modular code by breaking down complex functions into smaller, testable units.
- It can help you to catch bugs early in the development process.
Without the test driven development process, you might write the code first and then try to write tests for it afterwards. This can lead to tests that are hard to write and maintain, and can result in bugs that are hard to find and fix.
The TDD process usually follows these steps:
- Write a failing test
- Write the minimum amount of code to make the test pass
- Refactor the code to make it clean and maintainable
Here is an example of the TDD process:
- Write a failing test
PYTHON
def test_calculate_mean():
    # Arrange
    data = [1, 2, 3, 4, 5]
    # Act
    mean = calculate_mean(data)
    # Assert
    assert mean == 3
- Write the minimum amount of code to make the test pass
PYTHON
def calculate_mean(data: list):
    total = 0
    for i in range(len(data)):
        total += data[i]
    mean = total / len(data)
    return mean
- Refactor the code to make it clean and maintainable
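For the refactor step, one possible cleanup of the loop-based calculate_mean is to use Python's built-in sum. Re-running the unchanged test confirms that the refactor did not change the behaviour. Here is a sketch of that version:

```python
def calculate_mean(data: list) -> float:
    # Same result as the loop version, using the built-in sum()
    return sum(data) / len(data)


def test_calculate_mean():
    # Arrange
    data = [1, 2, 3, 4, 5]
    # Act
    mean = calculate_mean(data)
    # Assert
    assert mean == 3
```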
This process can help you to write clean, maintainable code that is easy to test and debug.
Of course, in research, sometimes you might not know exactly what the requirements of the code are before you write it. In this case, you can still use the TDD process, but you might need to iterate on the tests and the code as you learn more about the problem you are trying to solve.
Randomness in tests
Some functions use randomness, and you might assume this means we cannot write tests for them. However, by setting a random seed, we can make this randomness deterministic and write tests for these functions.
PYTHON
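import random

def random_number():
    return random.randint(1, 10)

def test_random_number():
    # Seed the generator and record a few outputs
    random.seed(0)
    first_run = [random_number() for _ in range(3)]
    # Re-seed with the same value: the sequence should repeat exactly
    random.seed(0)
    second_run = [random_number() for _ in range(3)]
    assert first_run == second_run
    # Every value should also lie in the requested range
    assert all(1 <= n <= 10 for n in first_run)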
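Random seeds work by setting the initial state of the random number generator. If you set the seed to the same value, you get the same sequence of random numbers each time you run the function. Python only guarantees that random.random() produces the same sequence for a given seed across Python versions. The output of other functions such as randint or sample can change between versions, so hard-coded expected values in seeded tests may need updating when you upgrade Python.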
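Calling random.seed changes the global generator that all code in the process shares. One alternative is to pass a dedicated random.Random instance into the function. This sketch assumes you are free to change the function's signature, and the rng parameter name is our own choice:

```python
import random


def random_number(rng: random.Random) -> int:
    # Draw from the generator that was passed in, not the global one
    return rng.randint(1, 10)


def test_random_number_with_own_generator():
    # Two generators created with the same seed produce the same sequence
    first = random.Random(42)
    second = random.Random(42)
    first_values = [random_number(first) for _ in range(5)]
    second_values = [random_number(second) for _ in range(5)]
    assert first_values == second_values
    assert all(1 <= value <= 10 for value in first_values)
```

In normal use you could call random_number(random.Random()), which seeds the generator from the system instead of a fixed value. Tests stay isolated because they never touch the global generator.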
Challenge: Write your own unit tests
Take this complex function, break it down and write unit tests for it.
- Create a new directory called statistics in your project directory
- Create a new file called stats.py in the statistics directory
- Write the following function in stats.py:
PYTHON
import random

def randomly_sample_and_filter_participants(
    participants: list,
    sample_size: int,
    min_age: int,
    max_age: int,
    min_height: int,
    max_height: int
):
    """Participants is a list of dictionaries, containing the age and height of each participant

    participants = [
        {'age': 25, 'height': 180},
        {'age': 30, 'height': 170},
        {'age': 35, 'height': 160},
    ]
    """
    # Get the indexes to sample
    indexes = random.sample(range(len(participants)), sample_size)
    # Get the sampled participants
    sampled_participants = []
    for i in indexes:
        sampled_participants.append(participants[i])
    # Remove participants that are outside the age range
    sampled_participants_age_filtered = []
    for participant in sampled_participants:
        if participant['age'] >= min_age and participant['age'] <= max_age:
            sampled_participants_age_filtered.append(participant)
    # Remove participants that are outside the height range
    sampled_participants_height_filtered = []
    for participant in sampled_participants_age_filtered:
        if participant['height'] >= min_height and participant['height'] <= max_height:
            sampled_participants_height_filtered.append(participant)
    return sampled_participants_height_filtered
- Create a new file called test_stats.py in the statistics directory
- Write unit tests for the randomly_sample_and_filter_participants function in test_stats.py
The function can be broken down into smaller functions, each of which can be tested separately:
PYTHON
import random

def sample_participants(
    participants: list,
    sample_size: int
):
    indexes = random.sample(range(len(participants)), sample_size)
    sampled_participants = []
    for i in indexes:
        sampled_participants.append(participants[i])
    return sampled_participants

def filter_participants_by_age(
    participants: list,
    min_age: int,
    max_age: int
):
    filtered_participants = []
    for participant in participants:
        if participant['age'] >= min_age and participant['age'] <= max_age:
            filtered_participants.append(participant)
    return filtered_participants

def filter_participants_by_height(
    participants: list,
    min_height: int,
    max_height: int
):
    filtered_participants = []
    for participant in participants:
        if participant['height'] >= min_height and participant['height'] <= max_height:
            filtered_participants.append(participant)
    return filtered_participants

def randomly_sample_and_filter_participants(
    participants: list,
    sample_size: int,
    min_age: int,
    max_age: int,
    min_height: int,
    max_height: int
):
    sampled_participants = sample_participants(participants, sample_size)
    age_filtered_participants = filter_participants_by_age(sampled_participants, min_age, max_age)
    height_filtered_participants = filter_participants_by_height(age_filtered_participants, min_height, max_height)
    return height_filtered_participants
Now we can write tests for each of these functions separately, remembering to set the random seed to make the randomness deterministic:
PYTHON
import random

# test_stats.py sits next to stats.py in the statistics directory
from stats import (
    filter_participants_by_age,
    filter_participants_by_height,
    randomly_sample_and_filter_participants,
    sample_participants,
)

def test_sample_participants():
    # set random seed
    random.seed(0)
    participants = [
        {'age': 25, 'height': 180},
        {'age': 30, 'height': 170},
        {'age': 35, 'height': 160},
    ]
    sample_size = 2
    sampled_participants = sample_participants(participants, sample_size)
    expected = [{'age': 30, 'height': 170}, {'age': 35, 'height': 160}]
    assert sampled_participants == expected

def test_filter_participants_by_age():
    participants = [
        {'age': 25, 'height': 180},
        {'age': 30, 'height': 170},
        {'age': 35, 'height': 160},
    ]
    min_age = 30
    max_age = 35
    filtered_participants = filter_participants_by_age(participants, min_age, max_age)
    expected = [{'age': 30, 'height': 170}, {'age': 35, 'height': 160}]
    assert filtered_participants == expected

def test_filter_participants_by_height():
    participants = [
        {'age': 25, 'height': 180},
        {'age': 30, 'height': 170},
        {'age': 35, 'height': 160},
    ]
    min_height = 160
    max_height = 170
    filtered_participants = filter_participants_by_height(participants, min_height, max_height)
    expected = [{'age': 30, 'height': 170}, {'age': 35, 'height': 160}]
    assert filtered_participants == expected

def test_randomly_sample_and_filter_participants():
    # set random seed
    random.seed(0)
    participants = [
        {"age": 25, "height": 180},
        {"age": 30, "height": 170},
        {"age": 35, "height": 160},
        {"age": 38, "height": 165},
        {"age": 40, "height": 190},
        {"age": 45, "height": 200},
    ]
    sample_size = 5
    min_age = 28
    max_age = 42
    min_height = 159
    max_height = 172
    filtered_participants = randomly_sample_and_filter_participants(
        participants, sample_size, min_age, max_age, min_height, max_height
    )
    expected = [{"age": 38, "height": 165}, {"age": 30, "height": 170}, {"age": 35, "height": 160}]
    assert filtered_participants == expected
These tests are much easier to read and understand, and if one of them fails, it is much easier to see which part of the function is broken.
Adding tests to an existing project
You may have an existing project that does not have any tests yet. Adding tests to an existing project can be a daunting task and it can be hard to know where to start.
In general, it’s a good idea to start by adding regression tests to your most important functions. Regression tests are tests that simply check that the output of a function doesn’t change when you make changes to the code. They don’t check the individual components of the functions like unit testing does.
For example, suppose you had a long processing pipeline that returns a single number, 23, when given a certain set of inputs. You could write a regression test that checks the output is still 23 after you make changes to the code.
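A regression test for that scenario might look like the sketch below. run_pipeline is a hypothetical stand-in for your own pipeline, and its inputs are placeholders. In practice, you would record the expected value by running the current, trusted version of your code on fixed inputs.

```python
def run_pipeline(values: list) -> int:
    # Hypothetical pipeline: a stand-in for a long chain of processing steps
    cleaned = [v for v in values if v >= 0]
    scaled = [v * 2 for v in cleaned]
    return sum(scaled) + 1


def test_run_pipeline_regression():
    # Fixed inputs whose output was recorded from a trusted version of the code
    values = [3, -1, 4, 4]
    # A regression test checks only the final result, not the steps inside
    assert run_pipeline(values) == 23
```

If this test starts failing after a change, you know the pipeline's behaviour has changed. You still need to investigate which step caused it, and that is where unit tests become useful.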
After adding regression tests, you can start adding unit tests to the individual functions in your code. Start with the functions that are used most often or are most likely to break, such as those that handle data processing or input/output.
Should we aim for 100% test coverage?
Although tests make your code more reliable, it is not always practical to spend the development time needed to test every function. When time is limited, it is often better to write tests only for the most critical parts of the code rather than rigorously testing everything.
You should discuss with your team how much of the code you think should be tested, and what the most critical parts of the code are in order to prioritize your time.
Key Points
- Complex functions can be broken down into smaller, testable units.
- Testing each unit separately is called unit testing.
- The AAA pattern is a good way to structure your tests.
- Test driven development can help you to write clean, maintainable code.
- Randomness in tests can be made deterministic using random seeds.
- Adding tests to an existing project can be done incrementally, starting with regression tests.