diff --git a/dnscrypt-proxy/sources.go b/dnscrypt-proxy/sources.go index 8be99adf..7d9c8534 100644 --- a/dnscrypt-proxy/sources.go +++ b/dnscrypt-proxy/sources.go @@ -32,7 +32,7 @@ const ( type Source struct { name string - urls []string + urls []*url.URL format SourceFormat in []byte minisignKey *minisign.PublicKey @@ -96,6 +96,16 @@ func (source *Source) writeToCache(bin, sig []byte) (err error) { return } +func (source *Source) parseURLs(urls []string) { + for _, urlStr := range urls { + if srcURL, err := url.Parse(urlStr); err != nil { + dlog.Warnf("Source [%s] failed to parse URL [%s]", source.name, urlStr) + } else { + source.urls = append(source.urls, srcURL) + } + } +} + func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) { var resp *http.Response if resp, _, err = xTransport.Get(u, "", DefaultTimeout); err == nil { @@ -108,7 +118,7 @@ func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) { func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (delay time.Duration, err error) { if delay, err = source.fetchFromCache(now); err != nil { if len(source.urls) == 0 { - dlog.Errorf("Source [%s] cache file [%s] not present and no URL given", source.name, source.cacheFile) + dlog.Errorf("Source [%s] cache file [%s] not present and no valid URL", source.name, source.cacheFile) return } dlog.Debugf("Source [%s] cache file [%s] not present", source.name, source.cacheFile) @@ -123,13 +133,8 @@ func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (del } delay = MinimumPrefetchInterval var bin, sig []byte - for _, urlStr := range source.urls { - dlog.Infof("Source [%s] loading from URL [%s]", source.name, urlStr) - var srcURL *url.URL - if srcURL, err = url.Parse(urlStr); err != nil { - dlog.Debugf("Source [%s] failed to parse URL [%s]", source.name, urlStr) - continue - } + for _, srcURL := range source.urls { + dlog.Infof("Source [%s] loading from URL [%s]", source.name, srcURL) sigURL := &url.URL{} *sigURL = *srcURL // deep copy to avoid parsing twice sigURL.Path += ".minisig" @@ -144,7 +149,7 @@ func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (del if err = source.checkSignature(bin, sig); err == nil { break // valid signature } // above err check inverted to make use of implicit continue - dlog.Debugf("Source [%s] failed signature check using URL [%s]", source.name, urlStr) + dlog.Debugf("Source [%s] failed signature check using URL [%s]", source.name, srcURL) } if err != nil { return @@ -160,7 +165,7 @@ func NewSource(name string, xTransport *XTransport, urls []string, minisignKeySt if refreshDelay < DefaultPrefetchDelay { refreshDelay = DefaultPrefetchDelay } - source = &Source{name: name, urls: urls, cacheFile: cacheFile, cacheTTL: refreshDelay, prefetchDelay: DefaultPrefetchDelay} + source = &Source{name: name, urls: []*url.URL{}, cacheFile: cacheFile, cacheTTL: refreshDelay, prefetchDelay: DefaultPrefetchDelay} if formatStr == "v2" { source.format = SourceFormatV2 } else { @@ -171,10 +176,10 @@ func NewSource(name string, xTransport *XTransport, urls []string, minisignKeySt } else { return source, err } - if _, err = source.fetchWithCache(xTransport, timeNow()); err != nil { - return + source.parseURLs(urls) + if _, err = source.fetchWithCache(xTransport, timeNow()); err == nil { + dlog.Noticef("Source [%s] loaded", name) } - dlog.Noticef("Source [%s] loaded", name) return } diff --git a/dnscrypt-proxy/sources_test.go b/dnscrypt-proxy/sources_test.go index e57893a7..4c43d25c 100644 --- a/dnscrypt-proxy/sources_test.go +++ b/dnscrypt-proxy/sources_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "strconv" @@ -57,6 +58,7 @@ type SourceTestExpect struct { success bool err, cachePath string cache []SourceFixture + urls []string Source *Source delay time.Duration } @@ -204,16 +206,16 @@ func setupSourceTest(t *testing.T) (func(), *SourceTestData) { "open-sig-err": TestStateOpenSigErr, } d.downloadTests = map[string][]SourceTestState{ // determines the list of URLs passed in each call and how they will respond - "correct": {TestStateCorrect}, - "partial": {TestStatePartial}, - "partial-sig": {TestStatePartialSig}, - "missing": {TestStateMissing}, - "missing-sig": {TestStateMissingSig}, - "read-err": {TestStateReadErr}, - "read-sig-err": {TestStateReadSigErr}, - "open-err": {TestStateOpenErr}, - "open-sig-err": {TestStateOpenSigErr}, - //"path-err": {TestStatePathErr}, // TODO: invalid URLs should not be included in the prefetch list + "correct": {TestStateCorrect}, + "partial": {TestStatePartial}, + "partial-sig": {TestStatePartialSig}, + "missing": {TestStateMissing}, + "missing-sig": {TestStateMissingSig}, + "read-err": {TestStateReadErr}, + "read-sig-err": {TestStateReadSigErr}, + "open-err": {TestStateOpenErr}, + "open-sig-err": {TestStateOpenSigErr}, + "path-err": {TestStatePathErr}, "partial,correct": {TestStatePartial, TestStateCorrect}, "partial-sig,correct": {TestStatePartialSig, TestStateCorrect}, "missing,correct": {TestStateMissing, TestStateCorrect}, @@ -222,8 +224,8 @@ func setupSourceTest(t *testing.T) (func(), *SourceTestData) { "read-sig-err,correct": {TestStateReadSigErr, TestStateCorrect}, "open-err,correct": {TestStateOpenErr, TestStateCorrect}, "open-sig-err,correct": {TestStateOpenSigErr, TestStateCorrect}, - //"path-err,correct": {TestStatePathErr, TestStateCorrect}, // TODO: invalid URLs should not be included in the prefetch list - "no-urls": {}, + "path-err,correct": {TestStatePathErr, TestStateCorrect}, + "no-urls": {}, } d.xTransport.rebuildTransport() d.timeNow = time.Now().AddDate(0, 0, 0) @@ -287,9 +289,11 @@ func prepSourceTestDownload(t *testing.T, d *SourceTestData, e *SourceTestExpect e.err = "invalid port" case TestStatePathErr: path = "..." + path // non-numeric port fails URL parsing - e.err = "parse" } - e.Source.urls = append(e.Source.urls, d.server.URL+path) + if u, err := url.Parse(d.server.URL + path); err == nil { + e.Source.urls = append(e.Source.urls, u) + } + e.urls = append(e.urls, d.server.URL+path) } if e.success { e.err = "" @@ -297,7 +301,11 @@ func prepSourceTestDownload(t *testing.T, d *SourceTestData, e *SourceTestExpect } else { e.delay = MinimumPrefetchInterval } - e.Source.refresh = d.timeNow.Add(e.delay) + if len(e.Source.urls) > 0 { + e.Source.refresh = d.timeNow.Add(e.delay) + } else { + e.success = false + } } } @@ -307,7 +315,7 @@ func setupSourceTestCase(t *testing.T, d *SourceTestData, i int, e = &SourceTestExpect{ cachePath: filepath.Join(d.tempDir, id), } - e.Source = &Source{name: id, urls: []string{}, format: SourceFormatV2, minisignKey: d.key, + e.Source = &Source{name: id, urls: []*url.URL{}, format: SourceFormatV2, minisignKey: d.key, cacheFile: e.cachePath, cacheTTL: DefaultPrefetchDelay * 3, prefetchDelay: DefaultPrefetchDelay} if cacheTest != nil { prepSourceTestCache(t, d, e, d.sources[i], *cacheTest) @@ -333,17 +341,14 @@ func TestNewSource(t *testing.T) { } d.n++ for _, tt := range []struct { - v, key string - refresh time.Duration - e *SourceTestExpect + v, key string + e *SourceTestExpect }{ - {"v1", d.keyStr, DefaultPrefetchDelay * 2, &SourceTestExpect{err: "Unsupported source format", - Source: &Source{name: "old format", cacheTTL: DefaultPrefetchDelay * 2, prefetchDelay: DefaultPrefetchDelay}}}, - {"v2", "", DefaultPrefetchDelay * 3, &SourceTestExpect{err: "Invalid encoded public key", - Source: &Source{name: "invalid public key", cacheTTL: DefaultPrefetchDelay * 3, prefetchDelay: DefaultPrefetchDelay}}}, + {"v1", d.keyStr, &SourceTestExpect{err: "Unsupported source format", Source: &Source{name: "old format", urls: []*url.URL{}, cacheTTL: DefaultPrefetchDelay * 2, prefetchDelay: DefaultPrefetchDelay}}}, + {"v2", "", &SourceTestExpect{err: "Invalid encoded public key", Source: &Source{name: "invalid public key", urls: []*url.URL{}, cacheTTL: DefaultPrefetchDelay * 3, prefetchDelay: DefaultPrefetchDelay}}}, } { t.Run(tt.e.Source.name, func(t *testing.T) { - got, err := NewSource(tt.e.Source.name, d.xTransport, tt.e.Source.urls, tt.key, tt.e.cachePath, tt.v, tt.refresh) + got, err := NewSource(tt.e.Source.name, d.xTransport, tt.e.urls, tt.key, tt.e.cachePath, tt.v, tt.e.Source.cacheTTL) checkResult(t, tt.e, got, err) }) } @@ -353,7 +358,7 @@ func TestNewSource(t *testing.T) { for i := range d.sources { id, e := setupSourceTestCase(t, d, i, &cacheTest, downloadTest) t.Run("cache "+cacheTestName+", download "+downloadTestName+"/"+id, func(t *testing.T) { - got, err := NewSource(id, d.xTransport, e.Source.urls, d.keyStr, e.cachePath, "v2", DefaultPrefetchDelay*3) + got, err := NewSource(id, d.xTransport, e.urls, d.keyStr, e.cachePath, "v2", DefaultPrefetchDelay*3) checkResult(t, e, got, err) }) }