Functional table-driven tests in Go


There are numerous blog posts about table-driven tests in Go. In this blog post, I want to show a technique we have recently started using in our unit tests. I first saw it when my friend Matt Layher introduced it to me at work, and since then I have gradually started using it wherever it's suitable. This technique comes in especially handy for large structs with multiple nested fields.

First, let me describe the problem with large structs (with multiple nested fields).

Problem

Table-driven tests work well when the inputs and outputs of a function are simple values, such as strings or numbers. Here is an example from the Go wiki:

var flagtests = []struct {
	in  string
	out string
}{
	{"%a", "[%a]"},
	{"%-a", "[%-a]"},
	{"%+a", "[%+a]"},
	{"%#a", "[%#a]"},
	{"% a", "[% a]"},
	{"%0a", "[%0a]"},
	{"%1.2a", "[%1.2a]"},
	{"%-1.2a", "[%-1.2a]"},
	{"%+1.2a", "[%+1.2a]"},
	{"%-+1.2a", "[%+-1.2a]"},
	{"%-+1.2abc", "[%+-1.2a]bc"},
	{"%-1.2abc", "[%-1.2a]bc"},
}

func TestFlagParser(t *testing.T) {
	var flagprinter flagPrinter
	for _, tt := range flagtests {
		t.Run(tt.in, func(t *testing.T) {
			s := Sprintf(tt.in, &flagprinter)
			if s != tt.out {
				t.Errorf("got %q, want %q", s, tt.out)
			}
		})
	}
}

As you see, it's perfectly fine, and you can test multiple cases by adding a new item to the flagtests slice. It has a string input and a string output. This is a prime example of an excellent table-driven test.

However, if your input and output parameters are structs, a single test case in the table can take up far more than a few lines of code. This leads to tables that are pretty hard to read, which defeats the main purpose of table-driven tests: readability. Let me show an example. Assume we have a validate() function that ensures that a Kubernetes Pod is valid:

func validate(pod *corev1.Pod) error {
	if pod.Name == "" {
		return errors.New("pod.Name is empty")
	}

	if pod.Namespace == "" {
		return errors.New("pod.Namespace is empty")
	}

	if _, ok := pod.Annotations["ready"]; !ok {
		return errors.New("pod.Annotations['ready'] key is not set")
	}

	if len(pod.Spec.Containers) == 0 {
		return errors.New("spec.Containers is empty")
	}

	for _, container := range pod.Spec.Containers {
		if container.Name == "" {
			return errors.New("container.Name is empty")
		}

		if container.Image == "" {
			return errors.New("container.Image is empty")
		}

		if len(container.Command) == 0 {
			return errors.New("container.Command is not set")
		}

		if len(container.Ports) == 0 {
			return errors.New("container.Ports is not set")
		}
	}

	return nil
}

The validate function checks whether a given Pod is valid. For the sake of this post, we keep it simple and only check a handful of fields.

Now let's write a table-driven test with a single test case. We'll add more later:

package main

import (
	"testing"

	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestValidate(t *testing.T) {
	tests := []struct {
		name string
		pod  *corev1.Pod
		err  string
	}{
		{
			name: "valid pod",
			pod: &corev1.Pod{
				ObjectMeta: metav1.ObjectMeta{
					Namespace: "default",
					Name:      "pod-123",
					Annotations: map[string]string{
						"ready": "ensure that this annotation is set",
					},
				},
				Spec: corev1.PodSpec{
					Containers: []corev1.Container{
						{
							Name:  "some-container",
							Image: "fatih/foo:test",
							Command: []string{
								"./foo",
								"--port=8800",
							},
							Ports: []corev1.ContainerPort{
								{
									Name:          "http",
									ContainerPort: 8800,
									Protocol:      corev1.ProtocolTCP,
								},
							},
						},
					},
				},
			},
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			err := validate(tt.pod)
			// should it error?
			if tt.err != "" {
				if err == nil {
					t.Fatal("validate should error, but got nil error")
				}

				if err.Error() != tt.err {
					t.Errorf("err msg want: %s got: %s", tt.err, err.Error())
				}

				return
			}

			// should not error
			if err != nil {
				t.Fatalf("validate error: %s", err)
			}
		})
	}
}

So as you see, it's already pretty long, even though we only have a single test case. Let's run the test:

$ go test -v
=== RUN   TestValidate
=== RUN   TestValidate/valid_pod
--- PASS: TestValidate (0.00s)
    --- PASS: TestValidate/valid_pod (0.00s)
PASS
ok      demo    0.771s

It's all good!

Now let's add another case where container.Image is missing, and check that the validate function returns the corresponding error:

		{
			name: "invalid pod, image is not set",
			pod: &corev1.Pod{
				ObjectMeta: metav1.ObjectMeta{
					Namespace: "default",
					Name:      "pod-123",
					Annotations: map[string]string{
						"ready": "ensure that this annotation is set",
					},
				},
				Spec: corev1.PodSpec{
					Containers: []corev1.Container{
						{
							Name: "some-container",
							Command: []string{
								"./foo",
								"--port=8800",
							},
							Ports: []corev1.ContainerPort{
								{
									Name:          "http",
									ContainerPort: 8800,
									Protocol:      corev1.ProtocolTCP,
								},
							},
						},
					},
				},
			},
			err: "container.Image is empty",
		},

Let's rerun it:

$ go test -v
=== RUN   TestValidate
=== RUN   TestValidate/valid_pod
=== RUN   TestValidate/invalid_pod,_image_is_not_set
--- PASS: TestValidate (0.00s)
    --- PASS: TestValidate/valid_pod (0.00s)
    --- PASS: TestValidate/invalid_pod,_image_is_not_set (0.00s)
PASS
ok      demo    0.305s

Let's add another case. This time, the Ports slice is not set, and we want to make sure validate() errors:

		{
			name: "invalid pod, ports is not set",
			pod: &corev1.Pod{
				ObjectMeta: metav1.ObjectMeta{
					Namespace: "default",
					Name:      "pod-123",
					Annotations: map[string]string{
						"ready": "ensure that this annotation is set",
					},
				},
				Spec: corev1.PodSpec{
					Containers: []corev1.Container{
						{
							Name:  "some-container",
							Image: "fatih/foo:test",
							Command: []string{
								"./foo",
								"--port=8800",
							},
						},
					},
				},
			},
			err: "container.Ports is not set",
		},
	}

Let's run our test:

$ go test -v
=== RUN   TestValidate
=== RUN   TestValidate/valid_pod
=== RUN   TestValidate/invalid_pod,_image_is_not_set
=== RUN   TestValidate/invalid_pod,_ports_is_not_set
--- PASS: TestValidate (0.00s)
    --- PASS: TestValidate/valid_pod (0.00s)
    --- PASS: TestValidate/invalid_pod,_image_is_not_set (0.00s)
    --- PASS: TestValidate/invalid_pod,_ports_is_not_set (0.00s)
PASS
ok      demo    0.298s

So far, all is good. However, now look at how big our table-driven test became. Because the file is too long, I will share it as a snippet. Please open it:

https://go.dev/play/p/ekc3cVvN__D

Compare this to our initial prime example, and you'll see how bad it is. There are many such tests, especially in the Kubernetes community, where you have to mock large K8s resources.

Another issue is that the test cases in the table are mostly duplicates of each other. Even though the input parameter is a struct, most of the time we change only a single field in it (or, in our case, remove one).

This leads to copying and pasting whole test cases just to change a single field, which causes repetition and hurts both readability and maintainability. Adding a new edge case or refactoring the table becomes incredibly tedious.

How can we improve the situation?

Solution


One trick to make the table shorter and more readable is to define a base value for the struct and use function types to modify that base value in each test case. A code snippet explains this better than words, so here is an example. First, let's move the construction of the Pod into the t.Run() closure. Let's also change the table field that holds a pointer to the struct into a function that accepts one:

func TestValidate(t *testing.T) {
	tests := []struct {
		name string
		pod  func(pod *corev1.Pod)
		err  string
	}{
		{
			name: "valid pod",
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			pod := testPod()
			if tt.pod != nil {
				tt.pod(pod)
			}

			err := validate(pod)
			// should it error?
			if tt.err != "" {
				if err == nil {
					t.Fatal("validate should error, but got nil error")
				}

				if err.Error() != tt.err {
					t.Errorf("err msg\nwant: %q\n got: %q", tt.err, err.Error())
				}

				return
			}

			if err != nil {
				t.Fatalf("validate error: %s", err)
			}
		})
	}
}

func testPod() *corev1.Pod {
	return &corev1.Pod{
		ObjectMeta: metav1.ObjectMeta{
			Namespace: "default",
			Name:      "pod-123",
			Annotations: map[string]string{
				"ready": "ensure that this annotation is set",
			},
		},
		Spec: corev1.PodSpec{
			Containers: []corev1.Container{
				{
					Name:  "some-container",
					Image: "fatih/foo:test",
					Command: []string{
						"./foo",
						"--port=8800",
					},
					Ports: []corev1.ContainerPort{
						{
							Name:          "http",
							ContainerPort: 8800,
							Protocol:      corev1.ProtocolTCP,
						},
					},
				},
			},
		},
	}
}

The most significant change is that we turned the pod field from a struct pointer into a function type:

func TestValidate(t *testing.T) {
        tests := []struct {
                name string
-               pod  *corev1.Pod
+               pod  func(pod *corev1.Pod)
                err  string
        }{

The idea is that instead of defining a fully populated Pod struct, we assume it's already valid and change only the fields we're interested in. By default, the Pod is valid (testPod() is a helper function that returns a valid Pod value). Instead of passing the tt.pod value to validate(), we pass the Pod returned by testPod(), and modify it first when tt.pod is defined:

        for _, tt := range tests {
                tt := tt
                t.Run(tt.name, func(t *testing.T) {
-                       err := validate(tt.pod)
+                       pod := testPod()
+                       if tt.pod != nil {
+                               tt.pod(pod)
+                       }
+
+                       err := validate(pod)

It's the same test, but with one significant change: we no longer define the struct with all its nested fields in each case. Instead, we define a function that modifies only specific fields of an already defined struct.
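One detail worth spelling out is why the base value comes from a function that is called inside the loop. The functions in the table write through the pointer, so every case needs its own fresh copy; otherwise a change made by one case would leak into the next. Here is a minimal sketch of that, using only the standard library. The pod and container types, validatePod, and basePod are simplified, hypothetical stand-ins for the Kubernetes types, validate(), and testPod() used in this post:

```go
package main

import (
	"errors"
	"testing"
)

// container and pod are hypothetical, cut-down stand-ins for the
// Kubernetes types, so the sketch needs only the standard library.
type container struct {
	Name  string
	Image string
}

type pod struct {
	Name       string
	Containers []container
}

// validatePod is a reduced version of validate() from this post.
func validatePod(p *pod) error {
	if p.Name == "" {
		return errors.New("pod.Name is empty")
	}
	for _, c := range p.Containers {
		if c.Image == "" {
			return errors.New("container.Image is empty")
		}
	}
	return nil
}

// basePod builds a brand-new valid pod on every call. The functions in
// the table mutate the pod in place, so a shared base value would carry
// one case's changes over into the following cases.
func basePod() *pod {
	return &pod{
		Name: "pod-123",
		Containers: []container{
			{Name: "some-container", Image: "fatih/foo:test"},
		},
	}
}

func TestValidatePod(t *testing.T) {
	tests := []struct {
		name   string
		mutate func(p *pod) // nil means "use the base pod unchanged"
		err    string
	}{
		{
			name:   "invalid pod, image is not set",
			mutate: func(p *pod) { p.Containers[0].Image = "" },
			err:    "container.Image is empty",
		},
		// This case runs after a mutation and still sees a valid pod,
		// because basePod() is called again for it.
		{name: "valid pod"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := basePod()
			if tt.mutate != nil {
				tt.mutate(p)
			}

			// Reduce the error to a string so both outcomes share one check.
			got := ""
			if err := validatePod(p); err != nil {
				got = err.Error()
			}
			if got != tt.err {
				t.Errorf("err msg\nwant: %q\n got: %q", tt.err, got)
			}
		})
	}
}
```

If the base pod were a package-level variable instead, the "valid pod" case in this sketch would see the empty Image left behind by the case before it.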

This approach shows its strength when you apply it to the remaining cases. Let's change our test for the remaining cases, where we check the container.Ports and container.Image fields:

func TestValidate(t *testing.T) {
	tests := []struct {
		name string
		pod  func(pod *corev1.Pod)
		err  string
	}{
		{
			name: "valid pod",
		},
		{
			name: "invalid pod, image is not set",
			pod: func(pod *corev1.Pod) {
				pod.Spec.Containers[0].Image = ""
			},
			err: "container.Image is empty",
		},
		{
			name: "invalid pod, ports is not set",
			pod: func(pod *corev1.Pod) {
				pod.Spec.Containers[0].Ports = nil
			},
			err: "container.Ports is not set",
		},
	}
...

Let's run the tests:

$ go test -v
=== RUN   TestValidate
=== RUN   TestValidate/valid_pod
=== RUN   TestValidate/invalid_pod,_image_is_not_set
=== RUN   TestValidate/invalid_pod,_ports_is_not_set
--- PASS: TestValidate (0.00s)
    --- PASS: TestValidate/valid_pod (0.00s)
    --- PASS: TestValidate/invalid_pod,_image_is_not_set (0.00s)
    --- PASS: TestValidate/invalid_pod,_ports_is_not_set (0.00s)
PASS
ok      demo    0.556s

Here is the final test code:

https://go.dev/play/p/Uzspa-PtHjd

Previously we had to copy and paste the whole struct just to modify a few lines, but now, as you see, we can achieve the same result with only a few lines per case. It also reads a lot better because you can see at a glance which fields you have modified for a particular test case.

This pattern is also very flexible. In our test, we assumed a Pod is, by default, valid. But you can also assume the opposite, where the Pod is not valid by default, and you change the fields, so it becomes valid. You can also use function types for the output rather than the input. In our example, our validate() function only returns an error type, hence it's not needed. But if you return a complex, large struct, you can also use a function type for the return type in the table.