Skip to content


Restricting what packages are allowed to import in Golang using ASTs

I’m working on refactoring a monolithic end-to-end test suite for my company. It’s grown organically over the past 2 years as many, many new engineers and teams have onboarded. As you may know, “grown organically” is a euphemism for “it’s a hot mess”.

The question I wanted to answer today was “how do I restrict what packages my golang code can import?”

The solution I whipped up is a “linter” that runs as part of CI and fails if my import rules aren’t met. But what does it look like for real?

After a bunch of refactoring, we have test directories roughly corresponding to individual teams, and a set of library packages with stable interfaces that are intended to be used by those teams for testing. Now we need to make sure that each team’s tests depend _only_ on the stable library packages and not on random helpers in the test packages of other teams.

Say this is our repo layout:

$ tree .
.
├── pkg
└── test
    ├── connect
    │   ├── connector_sink_test.go
    │   ├── ...
    │   └── utils.go
    ├── kafka
    │   ├── byok
    │   │   ├── aws_test.go
    │   │   └── gcp_test.go
    │   ├── kafka_suite
    │   │   ├── ...
    │   │   └── storage_test.go
    │   └── test_helpers.go
    └── test_helpers.go

With this layout, we’d expect that

  • connector_sink_test.go can import from test/connect/utils.go or test/test_helpers.go
  • connector_sink_test.go can NOT import from test/kafka/test_helpers.go
  • aws_test.go can import from anything in test/kafka/byok, test/kafka/test_helpers.go, or test/test_helpers.go
  • aws_test.go can NOT import from anything from test/kafka/kafka_suite or test/connect

Roughly speaking, we can break this down into three phases.

1. Parse all the golang packages in the test directories (recursively) into ASTs
2. For each golang file, define the list of allowed import directories
3. Iterate through the ASTs and evaluate what we import versus what we allow

Here’s the heart of the approach. We have a list of ast.Packages and we want to validate that they only import paths that are in the parent directory hierarchy, rooted at testDir. Since this uses local file paths for parsing ASTs, we also pass our goSrcDir for translating between import paths and disk locations (using $GOPATH/src).

func validateAllowedImports(pkgs []*ast.Package, testDir string, goSrcDir string) []error {
  testImport := strings.TrimPrefix(testDir, goSrcDir)
 
  var errors []error
  for _, pkg := range pkgs {
    for filename, file := range pkg.Files {
      allowedDirs := directoriesBetween(filename, testDir)
      allowedImports := make(map[string]struct{})
      for _, path := range allowedDirs {
        importLine := strings.TrimPrefix(path, goSrcDir)
        allowedImports[importLine] = struct{}{}
      }
      for _, imp := range file.Imports {
        path := imp.Path.Value[1 : len(imp.Path.Value)-1] // trim " from start and end
        // if we're importing from another test directory
        if strings.HasPrefix(path, testImport) {
          // make sure that its one of our allowed imports from a parent package
          if _, allowed := allowedImports[path]; !allowed {
            filename := strings.TrimPrefix(filename, goSrcDir)
            errors = append(errors, &IllegalImportError{ImportPath: path, Filepath: filename})
          }
        }
      }
    }
  }
 
  return errors
}

This should be pretty straightforward to follow. There are only a few unusual things in here.

1. The import line from the AST includes quotes, like "github.com/codyaray/break" so we trim them off
2. We use a map[string]struct{} as a poor man’s “set” data structure in golang for easy “set contains” operations (the if _, allowed := allowedImports[path]; !allowed line).
3. The “magic” that must be in the directoriesBetween(filename, testDir) function

Ok, ok. There’s not really much magic here. We start with the file’s directory and walk up the path one directory at a time until we hit the root test directory. That set of directories is what we allow the file to import from.

// directoriesBetween returns the directory containing filename followed by each
// of its ancestors, ending with rootDir.
//
// e.g., given filename=test/kafka/byok/aws_test.go and rootDir=test,
// this returns test/kafka/byok, test/kafka, and test.
//
// rootDir is cleaned first, so a trailing separator does not prevent a match.
// If filename is not located under rootDir, the walk stops at the top of the
// path ("/" or ".") instead of looping forever.
func directoriesBetween(filename string, rootDir string) []string {
  root := filepath.Clean(rootDir)
  dir := filepath.Dir(filename)
  dirs := []string{dir}
  for dir != root {
    parent := filepath.Dir(dir)
    if parent == dir {
      // filepath.Dir has reached its fixed point; rootDir is not an ancestor.
      break
    }
    dir = parent
    dirs = append(dirs, dir)
  }
  return dirs
}

The rest is just boring boilerplate code, but for the sake of completeness, here it is in full:

package main
 
import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"path/filepath"
	"strings"
)
 
func main() {
	rootDir, err := os.Getwd()
	if err != nil {
		panic(err)
	}
	testDir := filepath.Join(rootDir, "./test")
 
	dirs, err := listDirectoriesRecursive(testDir)
	if err != nil {
		panic(err)
	}
 
	pkgs, err := parsePackages(dirs)
	if err != nil {
		panic(err)
	}
 
	goPath := os.Getenv("GOPATH")
	if goPath == "" {
		panic(fmt.Errorf("$GOPATH unset"))
	}
	goSrcDir := fmt.Sprintf("%s/src/", goPath)
 
	errs := validateAllowedImports(pkgs, testDir, goSrcDir)
	for _, err := range errs {
		fmt.Println(err)
	}
	if len(errs) > 0 {
		os.Exit(1)
	}
}
 
func validateAllowedImports(pkgs []*ast.Package, testDir string, goSrcDir string) []error {
	testImport := strings.TrimPrefix(testDir, goSrcDir)
 
	var errors []error
	for _, pkg := range pkgs {
		for filename, file := range pkg.Files {
			allowedDirs := directoriesBetween(filename, testDir)
			allowedImports := make(map[string]struct{})
			for _, path := range allowedDirs {
				importLine := strings.TrimPrefix(path, goSrcDir)
				allowedImports[importLine] = struct{}{}
			}
			for _, imp := range file.Imports {
				path := imp.Path.Value[1 : len(imp.Path.Value)-1] // trim " from start and end
				// if we're importing from another test directory
				if strings.HasPrefix(path, testImport) {
					// make sure that its one of our allowed imports from a parent package
					if _, allowed := allowedImports[path]; !allowed {
						filename := strings.TrimPrefix(filename, goSrcDir)
						errors = append(errors, &IllegalImportError{ImportPath: path, Filepath: filename})
					}
				}
			}
		}
	}
 
	return errors
}
 
// IllegalImportError describes a single import that a file is not permitted
// to make.
type IllegalImportError struct {
	// ImportPath is the import path that was rejected.
	ImportPath string
	// Filepath names the file containing the rejected import.
	Filepath string
}

// Error implements the error interface, naming both the rejected import path
// and the file that imports it.
func (e *IllegalImportError) Error() string {
	const format = "Illegal import of %s from %s"
	return fmt.Sprintf(format, e.ImportPath, e.Filepath)
}
 
// directoriesBetween returns the directory containing filename followed by each
// of its ancestors, ending with rootDir.
//
// e.g., given filename=test/kafka/byok/aws_test.go and rootDir=test,
// this will return test/kafka/byok, test/kafka, and test
//
// rootDir is cleaned first, so a trailing separator does not prevent a match.
// If filename is not located under rootDir, the walk stops at the top of the
// path ("/" or ".") instead of looping forever.
func directoriesBetween(filename string, rootDir string) []string {
	root := filepath.Clean(rootDir)
	dir := filepath.Dir(filename)
	dirs := []string{dir}
	for dir != root {
		parent := filepath.Dir(dir)
		if parent == dir {
			// filepath.Dir has reached its fixed point; rootDir is not an ancestor.
			break
		}
		dir = parent
		dirs = append(dirs, dir)
	}
	return dirs
}
 
// parsePackages parses the Go source files in each directory of dirs and
// returns all resulting packages in one flat slice. A single token.FileSet is
// shared across every directory. Parsing stops at the first directory that
// fails, in which case the result is nil and that error is returned.
func parsePackages(dirs []string) ([]*ast.Package, error) {
	var (
		fileSet = token.NewFileSet()
		parsed  []*ast.Package
	)
	for i := range dirs {
		byName, err := parser.ParseDir(fileSet, dirs[i], nil, 0)
		if err != nil {
			return nil, err
		}
		// ParseDir keys its result by package name; only the values are kept.
		for name := range byName {
			parsed = append(parsed, byName[name])
		}
	}
	return parsed, nil
}
 
// listDirectoriesRecursive returns rootDir and every directory beneath it, each
// exactly once, in the order filepath.Walk visits them. Any error reported
// during the walk aborts it; the result is then nil and the error is returned.
func listDirectoriesRecursive(rootDir string) ([]string, error) {
	// filepath.Walk visits rootDir itself before descending, so the slice must
	// start empty; seeding it with rootDir would list the root twice.
	var dirs []string
	err := filepath.Walk(rootDir,
		func(path string, info os.FileInfo, err error) error {
			if err != nil {
				return err
			}
			if info.IsDir() {
				dirs = append(dirs, path)
			}
			return nil
		})
	if err != nil {
		return nil, err
	}
	return dirs, nil
}

Posted in Tech Reference, Tutorials.


0 Responses

Stay in touch with the conversation, subscribe to the RSS feed for comments on this post.



Some HTML is OK

or, reply to this post via trackback.

 



Log in here!