I’m working on refactoring a monolithic end-to-end test suite for my company. It’s grown organically over the past 2 years as many, many new engineers and teams have onboarded. As you may know, “grown organically” is a euphemism for “it’s a hot mess”.
The question I wanted to answer today was “how do I restrict what packages my golang code can import?”
The solution I whipped up is a “linter” that runs as part of CI and fails if my import rules aren’t met. But what does it look like for real?
After a bunch of refactoring, we have test directories roughly corresponding to individual teams, and a set of library packages with stable interfaces that are intended to be used by those teams for testing. Now we need to make sure that each team’s tests depend _only_ on the stable library packages and not on random helpers in the test packages of other teams.
Say this is our repo layout:
```
$ tree .
.
├── pkg
└── test
    ├── connect
    │   ├── connector_sink_test.go
    │   ├── ...
    │   └── utils.go
    ├── kafka
    │   ├── byok
    │   │   ├── aws_test.go
    │   │   └── gcp_test.go
    │   ├── kafka_suite
    │   │   ├── ...
    │   │   └── storage_test.go
    │   └── test_helpers.go
    └── test_helpers.go
```
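With this layout, we’d expect that:

- `connector_sink_test.go` can import from `test/connect/utils.go` or `test/test_helpers.go`
- `connector_sink_test.go` can NOT import from `test/kafka/test_helpers.go`
- `aws_test.go` can import from anything in `test/kafka/byok`, `test/kafka/test_helpers.go`, or `test/test_helpers.go`
- `aws_test.go` can NOT import from anything in `test/kafka/kafka_suite` or `test/connect`

Since Go imports whole packages rather than individual files, “importing from `utils.go`” really means importing the package in the directory that contains it.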
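The expectations above boil down to one rule: inside the test tree, a file may import only the package in its own directory or in one of its ancestor directories, up to `test/`; anything outside the test tree (like the stable packages in `pkg/`) is fair game. Here’s a minimal sketch of that predicate, assuming slash-separated paths relative to the repo root. `isAllowed` is a hypothetical helper for illustration, not part of the linter itself. Note the `testRoot+"/"` check, which keeps a sibling directory such as `testutil` from being mistaken for part of the test tree.

```go
package main

import (
	"fmt"
	"path"
	"strings"
)

// isAllowed reports whether a file living in fileDir may import the package in importDir.
// All arguments are slash-separated paths relative to the repo root (hypothetical layout),
// e.g. fileDir="test/kafka/byok", importDir="test/kafka", testRoot="test".
func isAllowed(fileDir, importDir, testRoot string) bool {
	// Imports outside the test tree are not restricted by this rule.
	if importDir != testRoot && !strings.HasPrefix(importDir, testRoot+"/") {
		return true
	}
	// Walk up from the file's own directory; importDir must be one of these.
	for d := fileDir; ; d = path.Dir(d) {
		if d == importDir {
			return true
		}
		if d == testRoot || d == "." || d == "/" {
			return false // reached the top without finding importDir
		}
	}
}

func main() {
	fmt.Println(isAllowed("test/connect", "test/connect", "test"))              // true
	fmt.Println(isAllowed("test/connect", "test", "test"))                      // true
	fmt.Println(isAllowed("test/connect", "test/kafka", "test"))                // false
	fmt.Println(isAllowed("test/kafka/byok", "test/kafka", "test"))             // true
	fmt.Println(isAllowed("test/kafka/byok", "test/kafka/kafka_suite", "test")) // false
	fmt.Println(isAllowed("test/kafka/byok", "test/connect", "test"))           // false
	fmt.Println(isAllowed("test/kafka/byok", "pkg", "test"))                    // true
	fmt.Println(isAllowed("test/connect", "testutil", "test"))                  // true
}
```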
Roughly speaking, we can break this down into three phases:
1. Parse all the golang packages in the test directories (recursively) into ASTs
2. For each golang file, define the list of allowed import directories
3. Iterate through the ASTs and evaluate what we import versus what we allow
Here’s the heart of the approach. We have a list of `ast.Package`s, and we want to validate that each file in them only imports test packages from its own directory or a parent directory, up to and including `testDir`. Since this uses local file paths for parsing ASTs, we also pass our `goSrcDir` (`$GOPATH/src/`) for translating between import paths and disk locations.
```go
func validateAllowedImports(pkgs []*ast.Package, testDir string, goSrcDir string) []error {
	testImport := strings.TrimPrefix(testDir, goSrcDir)

	var errs []error
	for _, pkg := range pkgs {
		for filename, file := range pkg.Files {
			allowedDirs := directoriesBetween(filename, testDir)
			allowedImports := make(map[string]struct{})
			for _, path := range allowedDirs {
				importLine := strings.TrimPrefix(path, goSrcDir)
				allowedImports[importLine] = struct{}{}
			}

			for _, imp := range file.Imports {
				path := imp.Path.Value[1 : len(imp.Path.Value)-1] // trim the surrounding quotes

				// if we're importing from the test tree (testImport itself or anything below it)
				if path == testImport || strings.HasPrefix(path, testImport+"/") {
					// make sure that it's one of our allowed imports from a parent package
					if _, allowed := allowedImports[path]; !allowed {
						relFile := strings.TrimPrefix(filename, goSrcDir)
						errs = append(errs, &IllegalImportError{ImportPath: path, Filepath: relFile})
					}
				}
			}
		}
	}
	return errs
}
```
This should be pretty straightforward to follow. There are only a couple of weird things in here.
1. The import path from the AST includes its quotes, like `"github.com/codyaray/break"`, so we trim them off.
2. We use a `map[string]struct{}` as a poor man’s “set” data structure in Go for easy “set contains” operations (the `if _, allowed := allowedImports[path]; !allowed` line).
3. The “magic” must be in the `directoriesBetween(filename, testDir)` function.
Ok, ok. There’s not really much magic here. We start with the file’s directory and walk up the path one directory at a time until we hit the root test directory. That set of directories is what we allow tests to import from.
```go
func directoriesBetween(filename string, rootDir string) []string {
	d := filepath.Dir(filename)
	allowedImports := []string{d}
	for d != rootDir {
		d = filepath.Dir(d)
		allowedImports = append(allowedImports, d)
	}
	return allowedImports
}
```
The rest is just boring boilerplate code, but for the sake of completeness, here it is in full:
```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"path/filepath"
	"strings"
)

func main() {
	rootDir, err := os.Getwd()
	if err != nil {
		panic(err)
	}
	testDir := filepath.Join(rootDir, "./test")

	dirs, err := listDirectoriesRecursive(testDir)
	if err != nil {
		panic(err)
	}

	pkgs, err := parsePackages(dirs)
	if err != nil {
		panic(err)
	}

	goPath := os.Getenv("GOPATH")
	if goPath == "" {
		panic(fmt.Errorf("$GOPATH unset"))
	}
	// keep the trailing separator so TrimPrefix yields import paths like "github.com/..."
	goSrcDir := filepath.Join(goPath, "src") + string(filepath.Separator)

	errs := validateAllowedImports(pkgs, testDir, goSrcDir)
	for _, err := range errs {
		fmt.Println(err)
	}
	if len(errs) > 0 {
		os.Exit(1)
	}
}

func validateAllowedImports(pkgs []*ast.Package, testDir string, goSrcDir string) []error {
	testImport := strings.TrimPrefix(testDir, goSrcDir)

	var errs []error
	for _, pkg := range pkgs {
		for filename, file := range pkg.Files {
			allowedDirs := directoriesBetween(filename, testDir)
			allowedImports := make(map[string]struct{})
			for _, path := range allowedDirs {
				importLine := strings.TrimPrefix(path, goSrcDir)
				allowedImports[importLine] = struct{}{}
			}

			for _, imp := range file.Imports {
				path := imp.Path.Value[1 : len(imp.Path.Value)-1] // trim the surrounding quotes

				// if we're importing from the test tree (testImport itself or anything below it)
				if path == testImport || strings.HasPrefix(path, testImport+"/") {
					// make sure that it's one of our allowed imports from a parent package
					if _, allowed := allowedImports[path]; !allowed {
						relFile := strings.TrimPrefix(filename, goSrcDir)
						errs = append(errs, &IllegalImportError{ImportPath: path, Filepath: relFile})
					}
				}
			}
		}
	}
	return errs
}

type IllegalImportError struct {
	ImportPath string
	Filepath   string
}

func (e *IllegalImportError) Error() string {
	return fmt.Sprintf("Illegal import of %s from %s", e.ImportPath, e.Filepath)
}

// Returns a list of the directories between filename and rootDir
//
// e.g., given filename=test/kafka/byok/aws_test.go and rootDir=test,
// this will return test/kafka/byok, test/kafka, and test
func directoriesBetween(filename string, rootDir string) []string {
	d := filepath.Dir(filename)
	allowedImports := []string{d}
	for d != rootDir {
		d = filepath.Dir(d)
		allowedImports = append(allowedImports, d)
	}
	return allowedImports
}

func parsePackages(dirs []string) ([]*ast.Package, error) {
	var packages []*ast.Package
	fset := token.NewFileSet()
	for _, dir := range dirs {
		pkgs, err := parser.ParseDir(fset, dir, nil, 0)
		if err != nil {
			return nil, err
		}
		for _, pkg := range pkgs {
			packages = append(packages, pkg)
		}
	}
	return packages, nil
}

func listDirectoriesRecursive(rootDir string) ([]string, error) {
	// filepath.Walk visits rootDir itself first, so start from an empty list;
	// seeding it with rootDir would parse (and report errors for) that directory twice.
	var dirs []string
	err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() {
			dirs = append(dirs, path)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return dirs, nil
}
```