diff --git a/dnscrypt-proxy/sources.go b/dnscrypt-proxy/sources.go index bd4c04ff..110defba 100644 --- a/dnscrypt-proxy/sources.go +++ b/dnscrypt-proxy/sources.go @@ -31,12 +31,13 @@ const ( ) type Source struct { - urls []string - prefetch []*URLToPrefetch - format SourceFormat - in []byte - minisignKey *minisign.PublicKey - cacheFile string + urls []string + prefetch []*URLToPrefetch + format SourceFormat + in []byte + minisignKey *minisign.PublicKey + cacheFile string + cacheTTL, prefetchDelay time.Duration } func (source *Source) checkSignature(bin, sig []byte) (err error) { @@ -50,26 +51,23 @@ func (source *Source) checkSignature(bin, sig []byte) (err error) { // timeNow can be replaced by tests to provide a static value var timeNow = time.Now -func fetchFromCache(cacheFile string, refreshDelay time.Duration) (bin, sig []byte, delayTillNextUpdate time.Duration, err error) { - delayTillNextUpdate = time.Duration(0) - if refreshDelay < DefaultPrefetchDelay { - refreshDelay = DefaultPrefetchDelay - } +func (source *Source) fetchFromCache() (bin, sig []byte, delayTillNextUpdate time.Duration, err error) { + delayTillNextUpdate = 0 var fi os.FileInfo - if fi, err = os.Stat(cacheFile); err != nil { + if fi, err = os.Stat(source.cacheFile); err != nil { return } - if bin, err = ioutil.ReadFile(cacheFile); err != nil { + if bin, err = ioutil.ReadFile(source.cacheFile); err != nil { return } - if sig, err = ioutil.ReadFile(cacheFile + ".minisig"); err != nil { + if sig, err = ioutil.ReadFile(source.cacheFile + ".minisig"); err != nil { return } - if elapsed := timeNow().Sub(fi.ModTime()); elapsed < refreshDelay { - dlog.Debugf("Cache file [%s] is still fresh", cacheFile) - delayTillNextUpdate = DefaultPrefetchDelay - elapsed + if elapsed := timeNow().Sub(fi.ModTime()); elapsed < source.cacheTTL { + dlog.Debugf("Cache file [%s] is still fresh", source.cacheFile) + delayTillNextUpdate = source.prefetchDelay - elapsed } else { - dlog.Debugf("Cache file [%s] needs to be refreshed", cacheFile) + dlog.Debugf("Cache file [%s] needs to be refreshed", source.cacheFile) } return } @@ -83,13 +81,13 @@ func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) { return } -func fetchWithCache(xTransport *XTransport, urlStr string, cacheFile string, refreshDelay time.Duration) (bin, sig []byte, delayTillNextUpdate time.Duration, err error) { - if bin, sig, delayTillNextUpdate, err = fetchFromCache(cacheFile, refreshDelay); err != nil { +func (source *Source) fetchWithCache(xTransport *XTransport, urlStr string) (bin, sig []byte, delayTillNextUpdate time.Duration, err error) { + if bin, sig, delayTillNextUpdate, err = source.fetchFromCache(); err != nil { if len(urlStr) == 0 { - dlog.Errorf("Cache file [%s] not present and no URL given to retrieve it", cacheFile) + dlog.Errorf("Cache file [%s] not present and no URL given to retrieve it", source.cacheFile) return } - dlog.Debugf("Cache file [%s] not present", cacheFile) + dlog.Debugf("Cache file [%s] not present", source.cacheFile) } if err == nil && delayTillNextUpdate > 0 { dlog.Debugf("Delay till next update: %v", delayTillNextUpdate) @@ -111,17 +109,17 @@ func fetchWithCache(xTransport *XTransport, urlStr string, cacheFile string, ref if sig, err = fetchFromURL(xTransport, sigURL); err != nil { return } - if err = AtomicFileWrite(cacheFile, bin); err != nil { - if absPath, err2 := filepath.Abs(cacheFile); err2 == nil { + if err = AtomicFileWrite(source.cacheFile, bin); err != nil { + if absPath, err2 := filepath.Abs(source.cacheFile); err2 == nil { dlog.Warnf("%s: %s", absPath, err) } } - if err = AtomicFileWrite(cacheFile+".minisig", sig); err != nil { - if absPath, err2 := filepath.Abs(cacheFile + ".minisig"); err2 == nil { + if err = AtomicFileWrite(source.cacheFile+".minisig", sig); err != nil { + if absPath, err2 := filepath.Abs(source.cacheFile + ".minisig"); err2 == nil { dlog.Warnf("%s: %s", absPath, err) } } - delayTillNextUpdate = DefaultPrefetchDelay + delayTillNextUpdate = source.prefetchDelay return } @@ -135,7 +133,10 @@ type URLToPrefetch struct { } func NewSource(xTransport *XTransport, urls []string, minisignKeyStr string, cacheFile string, formatStr string, refreshDelay time.Duration) (source *Source, err error) { - source = &Source{urls: urls, cacheFile: cacheFile} + if refreshDelay < DefaultPrefetchDelay { + refreshDelay = DefaultPrefetchDelay + } + source = &Source{urls: urls, cacheFile: cacheFile, cacheTTL: refreshDelay, prefetchDelay: DefaultPrefetchDelay} if formatStr == "v2" { source.format = SourceFormatV2 } else { @@ -153,11 +154,11 @@ func NewSource(xTransport *XTransport, urls []string, minisignKeyStr string, cac var delayTillNextUpdate time.Duration var preloadURL string if len(urls) <= 0 { - bin, sig, delayTillNextUpdate, err = fetchWithCache(xTransport, "", cacheFile, refreshDelay) + bin, sig, delayTillNextUpdate, err = source.fetchWithCache(xTransport, "") } else { preloadURL = urls[0] for _, url := range urls { - bin, sig, delayTillNextUpdate, err = fetchWithCache(xTransport, url, cacheFile, refreshDelay) + bin, sig, delayTillNextUpdate, err = source.fetchWithCache(xTransport, url) if err == nil { preloadURL = url break @@ -188,7 +189,7 @@ func PrefetchSources(xTransport *XTransport, sources []*Source) time.Duration { for _, urlToPrefetch := range source.prefetch { if now.After(urlToPrefetch.when) { dlog.Debugf("Prefetching [%s]", urlToPrefetch.url) - _, _, delay, err := fetchWithCache(xTransport, urlToPrefetch.url, source.cacheFile, DefaultPrefetchDelay) + _, _, delay, err := source.fetchWithCache(xTransport, urlToPrefetch.url) if err != nil { dlog.Debugf("Prefetching [%s] failed: %v", urlToPrefetch.url, err) continue diff --git a/dnscrypt-proxy/sources_test.go b/dnscrypt-proxy/sources_test.go index 1423c8a7..72800b9d 100644 --- a/dnscrypt-proxy/sources_test.go +++ b/dnscrypt-proxy/sources_test.go @@ -308,7 +308,8 @@ func setupSourceTestCase(t *testing.T, d *SourceTestData, i int, cachePath: filepath.Join(d.tempDir, id), refresh: d.timeNow, } - e.Source = &Source{urls: []string{}, prefetch: []*URLToPrefetch{}, format: SourceFormatV2, minisignKey: d.key, cacheFile: e.cachePath} + e.Source = &Source{urls: []string{}, prefetch: []*URLToPrefetch{}, format: SourceFormatV2, minisignKey: d.key, + cacheFile: e.cachePath, cacheTTL: DefaultPrefetchDelay * 3, prefetchDelay: DefaultPrefetchDelay} if cacheTest != nil { prepSourceTestCache(t, d, e, d.sources[i], *cacheTest) i = (i + 1) % len(d.sources) // make the cached and downloaded fixtures different @@ -338,9 +339,9 @@ func TestNewSource(t *testing.T) { e *SourceTestExpect }{ {"old format", d.keyStr, "v1", DefaultPrefetchDelay * 3, &SourceTestExpect{ - Source: &Source{}, err: "Unsupported source format"}}, + Source: &Source{cacheTTL: DefaultPrefetchDelay * 3, prefetchDelay: DefaultPrefetchDelay}, err: "Unsupported source format"}}, {"invalid public key", "", "v2", DefaultPrefetchDelay * 3, &SourceTestExpect{ - Source: &Source{}, err: "Invalid encoded public key"}}, + Source: &Source{cacheTTL: DefaultPrefetchDelay * 3, prefetchDelay: DefaultPrefetchDelay}, err: "Invalid encoded public key"}}, } { t.Run(tt.name, func(t *testing.T) { got, err := NewSource(d.xTransport, tt.e.Source.urls, tt.key, tt.e.cachePath, tt.v, tt.refresh)