dnscrypt-proxy/dnscrypt-proxy/sources_test.go

492 lines
16 KiB
Go

package main
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/hectane/go-acl"
"github.com/jedisct1/dlog"
"github.com/jedisct1/go-minisign"
"github.com/powerman/check"
)
// SourceFixture describes one file used by the source tests, either written to
// disk as a cache file or served as a response body by the test HTTP server.
type SourceFixture struct {
// suffix is appended to the cache path to form the file name; the fixtures
// loaded in this file use "" for the list itself and ".minisig" for its signature.
suffix string
// content is the file body; fixtures with nil content are not written to the
// cache and are answered with 404 by the test server.
content []byte
length string // HTTP Content-Length header
// perms is the mode applied to the cache file; zero means 0o644 when written.
perms os.FileMode
// mtime is the modification time for the cache file; when zero, the
// expectation's mtime is used instead.
mtime time.Time
}
// SourceTestState identifies the condition simulated by a set of fixtures,
// for cache files on disk and/or for files served by the test HTTP server.
type SourceTestState uint8
// The numeric value of a state is also used as the first path segment of the
// test server URLs, which is how the server selects the fixture set to serve.
const (
TestStateCorrect SourceTestState = iota // valid files
TestStateExpired // modification time of files set in distant past (cache only)
TestStatePartial // incomplete files
TestStatePartialSig // incomplete .minisig
TestStateMissing // non-existent files
TestStateMissingSig // non-existent .minisig
TestStateReadErr // I/O error on reading files (download only)
TestStateReadSigErr // I/O error on reading .minisig (download only)
TestStateOpenErr // I/O error on opening files
TestStateOpenSigErr // I/O error on opening .minisig
TestStatePathErr // unparsable path to files (download only)
)
// SourceTestData holds the state shared by all subtests of one test function:
// the loaded fixtures, the test HTTP server and its request log, the temporary
// cache directory, and the tables of cache/download scenarios.
type SourceTestData struct {
n int // subtest counter
// xTransport is the transport handed to NewSource and PrefetchSources.
xTransport *XTransport
// key is the parsed snakeoil public key used in expected Source values.
key *minisign.PublicKey
// keyStr is the text of snakeoil.pub after its first line; tempDir is the
// directory in which cache files are created.
keyStr, tempDir string
// sources lists the test source names found under testdata/sources
// (one per ".minisig" file, with that suffix removed).
sources []string
// fixtures maps a test state and a file name to the fixture for that pair.
fixtures map[SourceTestState]map[string]SourceFixture
// timeNow is the frozen "current" time, timeOld lies DefaultPrefetchDelay*4
// before it and timeUpd lies DefaultPrefetchDelay after it.
timeNow, timeOld, timeUpd time.Time
// server serves fixtures over HTTP.
server *httptest.Server
// reqActual counts requests received per URL path; reqExpect counts the
// requests each test case expects. Both are reset after every comparison.
reqActual, reqExpect map[string]uint
// cacheTests maps a scenario name to the state of the cache files written
// before the call under test.
cacheTests map[string]SourceTestState
// downloadTests maps a scenario name to the states of the URLs passed to
// the call under test, in order.
downloadTests map[string][]SourceTestState
}
// SourceTestExpect describes the inputs and the expected outcome of a single
// test case.
type SourceTestExpect struct {
// success records whether a usable source was found so far (valid cache or
// a successful download) while the expectations are being built.
success bool
// err is a pattern the returned error must match (empty means no error is
// expected); cachePath is the path of the cache file without any suffix.
err, cachePath string
// cache lists the fixtures expected on disk at cachePath after the call.
cache []SourceFixture
// mtime is the expected modification time for fixtures that do not set one.
mtime time.Time
// urls are the URL strings passed to NewSource.
urls []string
// Source is the value the returned *Source is compared against.
Source *Source
// delay is the prefetch delay expected for this case.
delay time.Duration
// prefix is passed through as the last argument of NewSource.
prefix string
}
// readFixture returns the content of the file called name inside the
// "testdata" directory, failing the test immediately if it cannot be read.
func readFixture(t *testing.T, name string) []byte {
	t.Helper() // attribute failures to the caller rather than to this helper
	bin, err := os.ReadFile(filepath.Join("testdata", name))
	if err != nil {
		t.Fatalf("Unable to read test fixture %s: %v", name, err)
	}
	return bin
}
// writeSourceCache writes the fixtures in e.cache to disk so that the code
// under test finds them as its cache files.
//
// Each fixture is written to e.cachePath plus the fixture's suffix. Fixtures
// with nil content are skipped (the file is left absent). Permissions default
// to 0o644 when the fixture does not specify any. Only the file with an empty
// suffix gets its modification time set, using the fixture's mtime or, when
// that is zero, e.mtime. Any failure aborts the test.
func writeSourceCache(t *testing.T, e *SourceTestExpect) {
	t.Helper() // attribute failures to the caller rather than to this helper
	for _, f := range e.cache {
		if f.content == nil {
			continue
		}
		path := e.cachePath + f.suffix
		perms := f.perms
		if perms == 0 {
			perms = 0o644
		}
		if err := os.WriteFile(path, f.content, perms); err != nil {
			t.Fatalf("Unable to write cache file %s: %v", path, err)
		}
		// Apply the permissions again through acl.Chmod after the file exists.
		if err := acl.Chmod(path, perms); err != nil {
			t.Fatalf("Unable to set permissions on cache file %s: %v", path, err)
		}
		if f.suffix != "" {
			continue
		}
		mtime := f.mtime
		if mtime.IsZero() {
			mtime = e.mtime
		}
		if err := os.Chtimes(path, mtime, mtime); err != nil {
			// Report the time that was actually applied, not the possibly-zero fixture time.
			t.Fatalf("Unable to touch cache file %s to %v: %v", path, mtime, err)
		}
	}
}
// checkSourceCache verifies that the files on disk match the fixtures listed
// in e.cache: the content of every file, and the modification time of the
// file with an empty suffix.
func checkSourceCache(c *check.C, e *SourceTestExpect) {
	for _, fixture := range e.cache {
		file := e.cachePath + fixture.suffix
		// Make the file readable again; a failure here is ignored because the
		// read below reports the same problem.
		_ = acl.Chmod(file, 0o644)
		content, readErr := os.ReadFile(file)
		c.DeepEqual(content, fixture.content, "Unexpected content for cache file '%s', err %v", file, readErr)
		if fixture.suffix != "" {
			continue
		}
		info, statErr := os.Stat(file)
		if statErr != nil {
			continue // a missing or unreadable file was already reported above
		}
		want := e.mtime
		if !fixture.mtime.IsZero() {
			want = fixture.mtime
		}
		c.EQ(info.ModTime(), want, "Unexpected timestamp for cache file '%s'", file)
	}
}
// loadSnakeoil loads the test signing key from testdata/snakeoil.pub into d.
//
// d.key receives the parsed public key. d.keyStr receives everything in the
// file after its first line, which is the form passed to NewSource by the
// tests. The test is aborted if the key cannot be parsed or if the file has
// no content after the first line.
func loadSnakeoil(t *testing.T, d *SourceTestData) {
	t.Helper() // attribute failures to the caller rather than to this helper
	key, err := minisign.NewPublicKeyFromFile(filepath.Join("testdata", "snakeoil.pub"))
	if err != nil {
		t.Fatalf("Unable to load snakeoil key: %v", err)
	}
	lines := bytes.SplitN(readFixture(t, "snakeoil.pub"), []byte("\n"), 2)
	if len(lines) < 2 {
		// Without this guard a file lacking a newline would panic on lines[1].
		t.Fatal("Unable to load snakeoil key: snakeoil.pub has no content after its first line")
	}
	d.keyStr = string(lines[1])
	d.key = &key
}
// loadTestSourceNames scans testdata/sources and appends to d.sources the
// name of every regular entry ending in ".minisig", with that suffix removed.
// The test is aborted if the directory cannot be listed.
func loadTestSourceNames(t *testing.T, d *SourceTestData) {
	const sigExt = ".minisig"
	entries, err := os.ReadDir(filepath.Join("testdata", "sources"))
	if err != nil {
		t.Fatalf("Unable to load list of test sources: %v", err)
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		if !strings.HasSuffix(name, sigExt) {
			continue
		}
		d.sources = append(d.sources, strings.TrimSuffix(name, sigExt))
	}
}
// generateFixtureState derives the fixture for file in the given state from
// the TestStateCorrect fixture of the same file and stores it in d.fixtures.
//
// For the "*Sig" states only the ".minisig" file is altered; any other file
// reuses its correct fixture unchanged. Otherwise the derived fixture is:
//   - expired: correct content with the old modification time;
//   - partial: only the first byte of the correct content;
//   - read error: empty content with a Content-Length of "1";
//   - open error: first byte of the correct content with mode 0o200;
//   - any other state: no content at all.
func generateFixtureState(_ *testing.T, d *SourceTestData, suffix, file string, state SourceTestState) {
	byFile, ok := d.fixtures[state]
	if !ok {
		byFile = map[string]SourceFixture{}
		d.fixtures[state] = byFile
	}
	correct := d.fixtures[TestStateCorrect][file]

	sigOnly := state == TestStatePartialSig || state == TestStateMissingSig ||
		state == TestStateReadSigErr || state == TestStateOpenSigErr
	if sigOnly && suffix != ".minisig" {
		byFile[file] = correct
		return
	}

	fixture := SourceFixture{suffix: suffix}
	switch state {
	case TestStateExpired:
		fixture.content = correct.content
		fixture.mtime = d.timeOld
	case TestStatePartial, TestStatePartialSig:
		fixture.content = correct.content[:1]
	case TestStateReadErr, TestStateReadSigErr:
		fixture.content = []byte{}
		fixture.length = "1"
	case TestStateOpenErr, TestStateOpenSigErr:
		fixture.content = correct.content[:1]
		fixture.perms = 0o200
	}
	byFile[file] = fixture
}
// loadFixtures (re)builds d.fixtures. For every source in d.sources it reads
// the list and its ".minisig" from testdata/sources as the TestStateCorrect
// fixtures, then derives the fixtures for each other generated state.
func loadFixtures(t *testing.T, d *SourceTestData) {
	derived := []SourceTestState{
		TestStateExpired,
		TestStatePartial,
		TestStateReadErr,
		TestStateOpenErr,
		TestStatePartialSig,
		TestStateMissingSig,
		TestStateReadSigErr,
		TestStateOpenSigErr,
	}
	correct := map[string]SourceFixture{}
	d.fixtures = map[SourceTestState]map[string]SourceFixture{TestStateCorrect: correct}
	for _, source := range d.sources {
		for _, suffix := range []string{"", ".minisig"} {
			name := source + suffix
			correct[name] = SourceFixture{
				suffix:  suffix,
				content: readFixture(t, filepath.Join("sources", name)),
			}
			for _, state := range derived {
				generateFixtureState(t, d, suffix, name, state)
			}
		}
	}
}
// makeTempDir creates a fresh temporary directory named after the running
// test and records its path in d.tempDir, aborting the test on failure.
func makeTempDir(t *testing.T, d *SourceTestData) {
	pattern := "sources_test.go." + t.Name()
	dir, err := os.MkdirTemp("", pattern)
	if err != nil {
		t.Fatalf("Unable to create temporary directory: %v", err)
	}
	d.tempDir = dir
}
// makeTestServer starts the HTTP server that serves fixtures and stores it in
// d.server, resetting the request logs in d.
//
// Request paths have the form "/<state>/<file>", where <state> is the decimal
// value of a SourceTestState. Every request is counted in d.reqActual under
// its path. When a fixture with non-nil content exists for that state and file
// its content is sent, with the Content-Length header overridden if the
// fixture sets a length. Everything else, including paths that do not have
// this form or whose state is not a number, is answered with 404.
func makeTestServer(t *testing.T, d *SourceTestData) {
	d.reqActual, d.reqExpect = map[string]uint{}, map[string]uint{}
	d.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		d.reqActual[r.URL.Path]++
		var data []byte
		pathParts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 2)
		// Guard against paths without a file segment (which would otherwise
		// panic on pathParts[1]) and against non-numeric states (which would
		// otherwise be treated as state 0, i.e. the correct fixtures).
		if len(pathParts) == 2 {
			if state, err := strconv.ParseUint(pathParts[0], 10, 8); err == nil {
				if fixture, ok := d.fixtures[SourceTestState(state)][pathParts[1]]; ok {
					if len(fixture.length) > 0 {
						w.Header().Set("Content-Length", fixture.length) // client will return unexpected EOF
					}
					data = fixture.content
				}
			}
		}
		// nil (as opposed to empty) content means the file does not exist.
		if data == nil {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		if _, err := w.Write(data); err != nil {
			t.Logf("Error writing HTTP response for request [%s]: %v", r.URL.Path, err)
		}
	}))
}
// checkTestServer compares the requests the test server received with the
// requests the test case expected, then clears both logs for the next case.
func checkTestServer(c *check.C, d *SourceTestData) {
	c.DeepEqual(d.reqActual, d.reqExpect, "Unexpected HTTP request log")
	d.reqActual = map[string]uint{}
	d.reqExpect = map[string]uint{}
}
// setupSourceTest builds the shared state for a source test: the scenario
// tables, a frozen clock, a temporary cache directory, the fixture HTTP
// server, the snakeoil key and all fixtures.
//
// It returns a teardown function and the test data. The teardown removes the
// temporary directory, stops the server and restores the package-level
// timeNow function that is replaced here, so the override does not leak into
// tests that run afterwards.
func setupSourceTest(t *testing.T) (func(), *SourceTestData) {
	d := &SourceTestData{n: -1, xTransport: NewXTransport()}
	d.cacheTests = map[string]SourceTestState{ // determines cache files written to disk before each call
		"correct":      TestStateCorrect,
		"expired":      TestStateExpired,
		"partial":      TestStatePartial,
		"partial-sig":  TestStatePartialSig,
		"missing":      TestStateMissing,
		"missing-sig":  TestStateMissingSig,
		"open-err":     TestStateOpenErr,
		"open-sig-err": TestStateOpenSigErr,
	}
	d.downloadTests = map[string][]SourceTestState{ // determines the list of URLs passed in each call and how they will respond
		"correct":              {TestStateCorrect},
		"partial":              {TestStatePartial},
		"partial-sig":          {TestStatePartialSig},
		"missing":              {TestStateMissing},
		"missing-sig":          {TestStateMissingSig},
		"read-err":             {TestStateReadErr},
		"read-sig-err":         {TestStateReadSigErr},
		"open-err":             {TestStateOpenErr},
		"open-sig-err":         {TestStateOpenSigErr},
		"path-err":             {TestStatePathErr},
		"partial,correct":      {TestStatePartial, TestStateCorrect},
		"partial-sig,correct":  {TestStatePartialSig, TestStateCorrect},
		"missing,correct":      {TestStateMissing, TestStateCorrect},
		"missing-sig,correct":  {TestStateMissingSig, TestStateCorrect},
		"read-err,correct":     {TestStateReadErr, TestStateCorrect},
		"read-sig-err,correct": {TestStateReadSigErr, TestStateCorrect},
		"open-err,correct":     {TestStateOpenErr, TestStateCorrect},
		"open-sig-err,correct": {TestStateOpenSigErr, TestStateCorrect},
		"path-err,correct":     {TestStatePathErr, TestStateCorrect},
		"no-urls":              {},
	}
	d.xTransport.rebuildTransport()
	// Whole seconds only, so the value can be compared with file modification times.
	d.timeNow = time.Now().Truncate(time.Second)
	d.timeOld = d.timeNow.Add(DefaultPrefetchDelay * -4)
	d.timeUpd = d.timeNow.Add(DefaultPrefetchDelay)
	origTimeNow := timeNow
	timeNow = func() time.Time { return d.timeNow } // originally defined in sources.go, replaced during testing to ensure consistent results
	makeTempDir(t, d)
	makeTestServer(t, d)
	loadSnakeoil(t, d)
	loadTestSourceNames(t, d)
	loadFixtures(t, d)
	return func() {
		if err := os.RemoveAll(d.tempDir); err != nil {
			t.Logf("Unable to remove temporary directory %s: %v", d.tempDir, err)
		}
		d.server.Close()
		timeNow = origTimeNow
	}, d
}
// prepSourceTestCache selects the cache fixtures of source for the given
// state, records what loading that cache is expected to yield in e, and
// writes the fixtures to disk.
func prepSourceTestCache(t *testing.T, d *SourceTestData, e *SourceTestExpect, source string, state SourceTestState) {
	files := d.fixtures[state]
	list, sig := files[source], files[source+".minisig"]
	e.cache = []SourceFixture{list, sig}
	switch state {
	case TestStateCorrect:
		e.Source.bin = list.content
		e.success = true
	case TestStateExpired:
		// The content is expected to be loaded, but the case is not marked successful.
		e.Source.bin = list.content
	case TestStatePartial, TestStatePartialSig:
		e.err = "signature"
	case TestStateMissing, TestStateMissingSig, TestStateOpenErr, TestStateOpenSigErr:
		e.err = "open"
	}
	writeSourceCache(t, e)
}
// prepSourceTestDownload adds to e the URLs for one download scenario and the
// outcome expected from trying them in order.
//
// For every state in downloadTest it builds a URL of the form
// "<server>/<state>/<source>", appends it to e.urls (as a string) and, when it
// parses, to e.Source.urls. Until a state is expected to succeed it also
// records in d.reqExpect which requests the test server should receive and
// sets e.err to the error expected from that URL. Finally it sets e.delay and
// e.Source.refresh according to whether any source (cache or download) was
// expected to succeed. With an empty downloadTest, e is left untouched.
func prepSourceTestDownload(
_ *testing.T,
d *SourceTestData,
e *SourceTestExpect,
source string,
downloadTest []SourceTestState,
) {
if len(downloadTest) == 0 {
return
}
for _, state := range downloadTest {
// The state number in the path tells the test server which fixtures to serve.
path := "/" + strconv.FormatUint(uint64(state), 10) + "/" + source
serverURL := d.server.URL
// e.err is overwritten on every iteration, so it ends up describing the
// last state that sets one; it is cleared below if the case succeeds.
switch state {
case TestStateMissing, TestStateMissingSig:
e.err = "404 Not Found"
case TestStatePartial, TestStatePartialSig:
e.err = "signature"
case TestStateReadErr, TestStateReadSigErr:
e.err = "unexpected EOF"
case TestStateOpenErr, TestStateOpenSigErr:
if u, err := url.Parse(serverURL + path); err == nil {
host, port := ExtractHostAndPort(u.Host, -1)
u.Host = fmt.Sprintf(
"%s:%d",
host,
port|0x10000,
) // high numeric port is parsed but then fails to connect
// NOTE(review): u was parsed from serverURL+path, so serverURL now
// already ends in path and path is appended a second time below.
// No request is expected for this state, so it looks harmless — confirm.
serverURL = u.String()
}
e.err = "invalid port"
case TestStatePathErr:
path = "..." + path // non-numeric port fails URL parsing
}
// Unparsable URLs are passed to the code under test (e.urls) but are not
// part of the expected Source.
if u, err := url.Parse(serverURL + path); err == nil {
e.Source.urls = append(e.Source.urls, u)
}
e.urls = append(e.urls, serverURL+path)
// Once a usable source is expected, later URLs are not expected to be requested.
if e.success {
continue
}
// Expected requests, using fallthrough to accumulate: a correct download
// also updates the expected cache and content; states in the second group
// expect both the list and its .minisig to be requested; states in the
// last group expect only the list. Open and path errors expect none.
switch state {
case TestStateCorrect:
e.cache = []SourceFixture{d.fixtures[state][source], d.fixtures[state][source+".minisig"]}
e.Source.bin, e.success = e.cache[0].content, true
fallthrough
case TestStateMissingSig, TestStatePartial, TestStatePartialSig, TestStateReadSigErr:
d.reqExpect[path+".minisig"]++
fallthrough
case TestStateMissing, TestStateReadErr:
d.reqExpect[path]++
}
}
// Success (from the cache or from a download) means no error and the regular
// prefetch delay; failure means the minimum interval.
if e.success {
e.err = ""
e.delay = DefaultPrefetchDelay
} else {
e.delay = MinimumPrefetchInterval
}
// A refresh time is only expected when at least one URL parsed; with no
// usable URL the case is not considered successful.
if len(e.Source.urls) > 0 {
e.Source.refresh = d.timeNow.Add(e.delay)
} else {
e.success = false
}
}
// setupSourceTestCase prepares one test case for the i-th test source and
// returns its identifier together with the expectations for it.
//
// The identifier is "<d.n>-<i>" and also names the cache file inside
// d.tempDir. When cacheTest is non-nil the cache files are written in that
// state, and the download fixtures are taken from the next source (wrapping
// around) so that cached and downloaded content differ.
func setupSourceTestCase(t *testing.T, d *SourceTestData, i int,
	cacheTest *SourceTestState, downloadTest []SourceTestState,
) (id string, e *SourceTestExpect) {
	id = strconv.Itoa(d.n) + "-" + strconv.Itoa(i)
	cachePath := filepath.Join(d.tempDir, id)
	e = &SourceTestExpect{
		cachePath: cachePath,
		mtime:     d.timeNow,
		Source: &Source{
			name:          id,
			urls:          []*url.URL{},
			format:        SourceFormatV2,
			minisignKey:   d.key,
			cacheFile:     cachePath,
			cacheTTL:      DefaultPrefetchDelay * 3,
			prefetchDelay: DefaultPrefetchDelay,
		},
	}
	downloadIdx := i
	if cacheTest != nil {
		prepSourceTestCache(t, d, e, d.sources[i], *cacheTest)
		downloadIdx = (i + 1) % len(d.sources) // make the cached and downloaded fixtures different
	}
	prepSourceTestDownload(t, d, e, d.sources[downloadIdx], downloadTest)
	return id, e
}
// TestNewSource exercises NewSource, first with a few fixed argument
// combinations that are expected to fail, then with every combination of
// cache scenario and download scenario for every test source.
func TestNewSource(t *testing.T) {
	if testing.Verbose() {
		dlog.SetLogLevel(dlog.SeverityDebug)
		dlog.UseSyslog(false)
	}
	teardown, d := setupSourceTest(t)
	defer teardown()

	// verify compares the outcome of one NewSource call with its expectations.
	verify := func(t *testing.T, e *SourceTestExpect, got *Source, err error) {
		c := check.T(t)
		if e.err == "" {
			c.Nil(err, "Unexpected error")
		} else {
			c.Match(err, e.err, "Unexpected error")
		}
		c.DeepEqual(got, e.Source, "Unexpected return")
		checkTestServer(c, d)
		checkSourceCache(c, e)
	}

	d.n++
	type argCase struct {
		formatStr, key string
		refreshDelay   time.Duration
		expect         *SourceTestExpect
	}
	argCases := []argCase{
		{
			formatStr:    "",
			key:          "",
			refreshDelay: 0,
			expect: &SourceTestExpect{
				err: " ",
				Source: &Source{
					name:          "short refresh delay",
					urls:          []*url.URL{},
					cacheTTL:      DefaultPrefetchDelay,
					prefetchDelay: DefaultPrefetchDelay,
					prefix:        "",
				},
			},
		},
		{
			formatStr:    "v1",
			key:          d.keyStr,
			refreshDelay: DefaultPrefetchDelay * 2,
			expect: &SourceTestExpect{
				err: "Unsupported source format",
				Source: &Source{
					name:          "old format",
					urls:          []*url.URL{},
					cacheTTL:      DefaultPrefetchDelay * 2,
					prefetchDelay: DefaultPrefetchDelay,
				},
			},
		},
		{
			formatStr:    "v2",
			key:          "",
			refreshDelay: DefaultPrefetchDelay * 3,
			expect: &SourceTestExpect{
				err: "Invalid encoded public key",
				Source: &Source{
					name:          "invalid public key",
					urls:          []*url.URL{},
					cacheTTL:      DefaultPrefetchDelay * 3,
					prefetchDelay: DefaultPrefetchDelay,
				},
			},
		},
	}
	for _, tc := range argCases {
		t.Run(tc.expect.Source.name, func(t *testing.T) {
			got, err := NewSource(
				tc.expect.Source.name,
				d.xTransport,
				tc.expect.urls,
				tc.key,
				tc.expect.cachePath,
				tc.formatStr,
				tc.refreshDelay,
				tc.expect.prefix,
			)
			verify(t, tc.expect, got, err)
		})
	}

	for cacheName, cacheState := range d.cacheTests {
		for downloadName, downloadStates := range d.downloadTests {
			d.n++
			label := "cache " + cacheName + ", download " + downloadName + "/"
			for i := range d.sources {
				id, e := setupSourceTestCase(t, d, i, &cacheState, downloadStates)
				t.Run(label+id, func(t *testing.T) {
					got, err := NewSource(
						id,
						d.xTransport,
						e.urls,
						d.keyStr,
						e.cachePath,
						"v2",
						DefaultPrefetchDelay*3,
						"",
					)
					verify(t, e, got, err)
				})
			}
		}
	}
}
// TestPrefetchSources exercises PrefetchSources once per download scenario,
// passing it one source (without cached content) for every test source.
func TestPrefetchSources(t *testing.T) {
	if testing.Verbose() {
		dlog.SetLogLevel(dlog.SeverityDebug)
		dlog.UseSyslog(false)
	}
	teardown, d := setupSourceTest(t)
	defer teardown()

	// verify checks the returned delay against the expected one, then the
	// request log and the cache files of every source.
	verify := func(t *testing.T, expects []*SourceTestExpect, got time.Duration) {
		c := check.T(t)
		want := MinimumPrefetchInterval
		for _, e := range expects {
			if e.delay < MinimumPrefetchInterval {
				continue
			}
			if want == MinimumPrefetchInterval || e.delay < want {
				want = e.delay
			}
		}
		c.InDelta(got, want, time.Second, "Unexpected return")
		checkTestServer(c, d)
		for _, e := range expects {
			checkSourceCache(c, e)
		}
	}

	timeNow = func() time.Time { return d.timeUpd } // since the fixtures are prepared using real now, make the tested code think it's the future
	for name, states := range d.downloadTests {
		d.n++
		sources := make([]*Source, 0, len(d.sources))
		expects := make([]*SourceTestExpect, 0, len(d.sources))
		for i := range d.sources {
			_, e := setupSourceTestCase(t, d, i, nil, states)
			e.mtime = d.timeUpd
			// Hand the code under test a copy of the expected source without
			// its content.
			src := *e.Source
			src.bin = nil
			sources = append(sources, &src)
			expects = append(expects, e)
		}
		t.Run("download "+name, func(t *testing.T) {
			verify(t, expects, PrefetchSources(d.xTransport, sources))
		})
	}
}
// TestMain is the entry point for this package's tests; it delegates to
// check.TestMain.
func TestMain(m *testing.M) {
	check.TestMain(m)
}