From 4fd54a4919177ec163171e0763900eb5690284c0 Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Tue, 17 Dec 2019 09:38:53 +0100 Subject: [PATCH] Store the normalized qName in the plugin state We now enforce the fact that a query always include a question. It holds true for all practical use cases of dnscrypt-proxy. This avoids quite a lot of redundant code in plugins, and is faster. --- dnscrypt-proxy/dnsutils.go | 89 ++++++++++++++-------- dnscrypt-proxy/plugin_block_ip.go | 9 +-- dnscrypt-proxy/plugin_block_ipv6.go | 6 +- dnscrypt-proxy/plugin_block_name.go | 21 ++--- dnscrypt-proxy/plugin_block_undelegated.go | 7 +- dnscrypt-proxy/plugin_block_unqualified.go | 10 +-- dnscrypt-proxy/plugin_cache.go | 29 +++---- dnscrypt-proxy/plugin_cloak.go | 12 +-- dnscrypt-proxy/plugin_firefox.go | 10 +-- dnscrypt-proxy/plugin_forward.go | 12 +-- dnscrypt-proxy/plugin_nx_log.go | 8 +- dnscrypt-proxy/plugin_query_log.go | 8 +- dnscrypt-proxy/plugin_querymeta.go | 4 - dnscrypt-proxy/plugin_whitelist_name.go | 6 +- dnscrypt-proxy/plugins.go | 20 +++-- 15 files changed, 107 insertions(+), 144 deletions(-) diff --git a/dnscrypt-proxy/dnsutils.go b/dnscrypt-proxy/dnsutils.go index 3cc2b586..2daed8e4 100644 --- a/dnscrypt-proxy/dnsutils.go +++ b/dnscrypt-proxy/dnsutils.go @@ -2,9 +2,11 @@ package main import ( "encoding/binary" + "errors" "net" "strings" "time" + "unicode/utf8" "github.com/miekg/dns" ) @@ -42,36 +44,37 @@ func RefusedResponseFromMessage(srcMsg *dns.Msg, refusedCode bool, ipv4 net.IP, } else { dstMsg.Rcode = dns.RcodeSuccess questions := srcMsg.Question - if len(questions) > 0 { - question := questions[0] - sendHInfoResponse := true + if len(questions) == 0 { + return dstMsg + } + question := questions[0] + sendHInfoResponse := true - if ipv4 != nil && question.Qtype == dns.TypeA { - rr := new(dns.A) - rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl} - rr.A = ipv4.To4() - if rr.A != nil { - dstMsg.Answer = []dns.RR{rr} - sendHInfoResponse = false - } - } else if ipv6 != nil && question.Qtype == dns.TypeAAAA { - rr := new(dns.AAAA) - rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl} - rr.AAAA = ipv6.To16() - if rr.AAAA != nil { - dstMsg.Answer = []dns.RR{rr} - sendHInfoResponse = false - } + if ipv4 != nil && question.Qtype == dns.TypeA { + rr := new(dns.A) + rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl} + rr.A = ipv4.To4() + if rr.A != nil { + dstMsg.Answer = []dns.RR{rr} + sendHInfoResponse = false } + } else if ipv6 != nil && question.Qtype == dns.TypeAAAA { + rr := new(dns.AAAA) + rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl} + rr.AAAA = ipv6.To16() + if rr.AAAA != nil { + dstMsg.Answer = []dns.RR{rr} + sendHInfoResponse = false + } + } - if sendHInfoResponse { - hinfo := new(dns.HINFO) - hinfo.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeHINFO, - Class: dns.ClassINET, Ttl: 1} - hinfo.Cpu = "This query has been locally blocked" - hinfo.Os = "by dnscrypt-proxy" - dstMsg.Answer = []dns.RR{hinfo} - } + if sendHInfoResponse { + hinfo := new(dns.HINFO) + hinfo.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeHINFO, + Class: dns.ClassINET, Ttl: 1} + hinfo.Cpu = "This query has been locally blocked" + hinfo.Os = "by dnscrypt-proxy" + dstMsg.Answer = []dns.RR{hinfo} } } return dstMsg @@ -93,7 +96,7 @@ func Rcode(packet []byte) uint8 { return packet[3] & 0xf } -func NormalizeName(name *[]byte) { +func NormalizeRawQName(name *[]byte) { for i, c := range *name { if c >= 65 && c <= 90 { (*name)[i] = c + 32 @@ -101,11 +104,33 @@ func NormalizeName(name *[]byte) { } } -func StripTrailingDot(str string) string { - if len(str) > 1 && strings.HasSuffix(str, ".") { - str = str[:len(str)-1] +func NormalizeQName(str string) (string, error) { + if len(str) == 0 || str == "." { + return ".", nil } - return str + hasUpper := false + str = strings.TrimSuffix(str, ".") + strLen := len(str) + for i := 0; i < strLen; i++ { + c := str[i] + if c >= utf8.RuneSelf { + return str, errors.New("Query name is not an ASCII string") + } + hasUpper = hasUpper || ('A' <= c && c <= 'Z') + } + if !hasUpper { + return str, nil + } + var b strings.Builder + b.Grow(len(str)) + for i := 0; i < strLen; i++ { + c := str[i] + if 'A' <= c && c <= 'Z' { + c += 'a' - 'A' + } + b.WriteByte(c) + } + return b.String(), nil } func getMinTTL(msg *dns.Msg, minTTL uint32, maxTTL uint32, cacheNegMinTTL uint32, cacheNegMaxTTL uint32) time.Duration { diff --git a/dnscrypt-proxy/plugin_block_ip.go b/dnscrypt-proxy/plugin_block_ip.go index 3c4f0b32..317ea314 100644 --- a/dnscrypt-proxy/plugin_block_ip.go +++ b/dnscrypt-proxy/plugin_block_ip.go @@ -122,14 +122,7 @@ func (plugin *PluginBlockIP) Eval(pluginsState *PluginsState, msg *dns.Msg) erro pluginsState.action = PluginsActionReject pluginsState.returnCode = PluginsReturnCodeReject if plugin.logger != nil { - questions := msg.Question - if len(questions) != 1 { - return nil - } - qName := strings.ToLower(StripTrailingDot(questions[0].Name)) - if len(qName) < 2 { - return nil - } + qName := pluginsState.qName var clientIPStr string if pluginsState.clientProto == "udp" { clientIPStr = (*pluginsState.clientAddr).(*net.UDPAddr).IP.String() diff --git a/dnscrypt-proxy/plugin_block_ipv6.go b/dnscrypt-proxy/plugin_block_ipv6.go index 52fd63cd..6dbbd7e7 100644 --- a/dnscrypt-proxy/plugin_block_ipv6.go +++ b/dnscrypt-proxy/plugin_block_ipv6.go @@ -29,11 +29,7 @@ func (plugin *PluginBlockIPv6) Reload() error { } func (plugin *PluginBlockIPv6) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - questions := msg.Question - if len(questions) != 1 { - return nil - } - question := questions[0] + question := msg.Question[0] if question.Qclass != dns.ClassINET || question.Qtype != dns.TypeAAAA { return nil } diff --git a/dnscrypt-proxy/plugin_block_name.go b/dnscrypt-proxy/plugin_block_name.go index cc9418f2..73e96211 100644 --- a/dnscrypt-proxy/plugin_block_name.go +++ b/dnscrypt-proxy/plugin_block_name.go @@ -25,10 +25,9 @@ const aliasesLimit = 8 var blockedNames *BlockedNames func (blockedNames *BlockedNames) check(pluginsState *PluginsState, qName string, aliasFor *string) (bool, error) { - qName = strings.ToLower(StripTrailingDot(qName)) reject, reason, xweeklyRanges := blockedNames.patternMatcher.Eval(qName) if aliasFor != nil { - reason = reason + " (alias for [" + StripTrailingDot(*aliasFor) + "])" + reason = reason + " (alias for [" + *aliasFor + "])" } var weeklyRanges *WeeklyRanges if xweeklyRanges != nil { @@ -144,11 +143,7 @@ func (plugin *PluginBlockName) Eval(pluginsState *PluginsState, msg *dns.Msg) er if blockedNames == nil || pluginsState.sessionData["whitelisted"] != nil { return nil } - questions := msg.Question - if len(questions) != 1 { - return nil - } - _, err := blockedNames.check(pluginsState, questions[0].Name, nil) + _, err := blockedNames.check(pluginsState, pluginsState.qName, nil) return err } @@ -181,11 +176,7 @@ func (plugin *PluginBlockNameResponse) Eval(pluginsState *PluginsState, msg *dns if blockedNames == nil || pluginsState.sessionData["whitelisted"] != nil { return nil } - questions := msg.Question - if len(questions) != 1 { - return nil - } - aliasFor := questions[0].Name + aliasFor := pluginsState.qName aliasesLeft := aliasesLimit answers := msg.Answer for _, answer := range answers { @@ -193,7 +184,11 @@ func (plugin *PluginBlockNameResponse) Eval(pluginsState *PluginsState, msg *dns if header.Class != dns.ClassINET || header.Rrtype != dns.TypeCNAME { continue } - if blocked, err := blockedNames.check(pluginsState, answer.(*dns.CNAME).Target, &aliasFor); blocked || err != nil { + target, err := NormalizeQName(answer.(*dns.CNAME).Target) + if err != nil { + return err + } + if blocked, err := blockedNames.check(pluginsState, target, &aliasFor); blocked || err != nil { return err } aliasesLeft-- diff --git a/dnscrypt-proxy/plugin_block_undelegated.go b/dnscrypt-proxy/plugin_block_undelegated.go index a1b0eada..0fd6b601 100644 --- a/dnscrypt-proxy/plugin_block_undelegated.go +++ b/dnscrypt-proxy/plugin_block_undelegated.go @@ -3,7 +3,6 @@ package main import ( "github.com/k-sone/critbitgo" "github.com/miekg/dns" - "strings" ) var undelegatedSet = []string{ @@ -179,11 +178,7 @@ func (plugin *PluginBlockUndelegated) Reload() error { } func (plugin *PluginBlockUndelegated) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - questions := msg.Question - if len(questions) != 1 { - return nil - } - revQname := strings.ToLower(StringReverse(questions[0].Name)) + revQname := StringReverse(pluginsState.qName) match, _, found := plugin.suffixes.LongestPrefix([]byte(revQname)) if !found { return nil diff --git a/dnscrypt-proxy/plugin_block_unqualified.go b/dnscrypt-proxy/plugin_block_unqualified.go index 553cdd68..125a54bb 100644 --- a/dnscrypt-proxy/plugin_block_unqualified.go +++ b/dnscrypt-proxy/plugin_block_unqualified.go @@ -30,17 +30,11 @@ func (plugin *PluginBlockUnqualified) Reload() error { } func (plugin *PluginBlockUnqualified) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - questions := msg.Question - if len(questions) != 1 { - return nil - } - question := questions[0] + question := msg.Question[0] if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) { return nil } - qName := question.Name - idx := strings.IndexByte(qName, '.') - if idx == -1 || (idx == 0 || idx+1 != len(qName)) { + if strings.IndexByte(pluginsState.qName, '.') >= 0 { return nil } synth := EmptyResponseFromMessage(msg) diff --git a/dnscrypt-proxy/plugin_cache.go b/dnscrypt-proxy/plugin_cache.go index 2d4ef5a1..4e342033 100644 --- a/dnscrypt-proxy/plugin_cache.go +++ b/dnscrypt-proxy/plugin_cache.go @@ -3,7 +3,6 @@ package main import ( "crypto/sha512" "encoding/binary" - "errors" "sync" "time" @@ -23,12 +22,8 @@ type CachedResponses struct { var cachedResponses CachedResponses -func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) ([32]byte, error) { - questions := msg.Question - if len(questions) != 1 { - return [32]byte{}, errors.New("No question present") - } - question := questions[0] +func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) [32]byte { + question := msg.Question[0] h := sha512.New512_256() var tmp [5]byte binary.LittleEndian.PutUint16(tmp[0:2], question.Qtype) @@ -37,12 +32,13 @@ func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) ([32]byte, error) tmp[4] = 1 } h.Write(tmp[:]) - normalizedName := []byte(question.Name) - NormalizeName(&normalizedName) - h.Write(normalizedName) + normalizedRawQName := []byte(question.Name) + NormalizeRawQName(&normalizedRawQName) + h.Write(normalizedRawQName) var sum [32]byte h.Sum(sum[:0]) - return sum, nil + + return sum } // --- @@ -71,10 +67,7 @@ func (plugin *PluginCache) Reload() error { } func (plugin *PluginCache) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - cacheKey, err := computeCacheKey(pluginsState, msg) - if err != nil { - return nil - } + cacheKey := computeCacheKey(pluginsState, msg) cachedResponses.RLock() defer cachedResponses.RUnlock() if cachedResponses.cache == nil { @@ -134,10 +127,7 @@ func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg if msg.Truncated { return nil } - cacheKey, err := computeCacheKey(pluginsState, msg) - if err != nil { - return err - } + cacheKey := computeCacheKey(pluginsState, msg) ttl := getMinTTL(msg, pluginsState.cacheMinTTL, pluginsState.cacheMaxTTL, pluginsState.cacheNegMinTTL, pluginsState.cacheNegMaxTTL) cachedResponse := CachedResponse{ expiration: time.Now().Add(ttl), @@ -145,6 +135,7 @@ func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg } cachedResponses.Lock() if cachedResponses.cache == nil { + var err error cachedResponses.cache, err = lru.NewARC(pluginsState.cacheSize) if err != nil { cachedResponses.Unlock() diff --git a/dnscrypt-proxy/plugin_cloak.go b/dnscrypt-proxy/plugin_cloak.go index ad595ce0..3bad64f9 100644 --- a/dnscrypt-proxy/plugin_cloak.go +++ b/dnscrypt-proxy/plugin_cloak.go @@ -100,21 +100,13 @@ func (plugin *PluginCloak) Reload() error { } func (plugin *PluginCloak) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - questions := msg.Question - if len(questions) != 1 { - return nil - } - question := questions[0] + question := msg.Question[0] if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) { return nil } - qName := strings.ToLower(StripTrailingDot(questions[0].Name)) - if len(qName) < 2 { - return nil - } now := time.Now() plugin.RLock() - _, _, xcloakedName := plugin.patternMatcher.Eval(qName) + _, _, xcloakedName := plugin.patternMatcher.Eval(pluginsState.qName) if xcloakedName == nil { plugin.RUnlock() return nil diff --git a/dnscrypt-proxy/plugin_firefox.go b/dnscrypt-proxy/plugin_firefox.go index c3c0dcf5..4ec7bfb0 100644 --- a/dnscrypt-proxy/plugin_firefox.go +++ b/dnscrypt-proxy/plugin_firefox.go @@ -34,16 +34,12 @@ func (plugin *PluginFirefox) Reload() error { } func (plugin *PluginFirefox) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - questions := msg.Question - if len(questions) != 1 { - return nil - } - question := questions[0] + question := msg.Question[0] if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) { return nil } - qName := strings.ToLower(question.Name) - if qName != "use-application-dns.net." && !strings.HasSuffix(qName, ".use-application-dns.net.") { + qName := pluginsState.qName + if qName != "use-application-dns.net" && !strings.HasSuffix(qName, ".use-application-dns.net") { return nil } synth := EmptyResponseFromMessage(msg) diff --git a/dnscrypt-proxy/plugin_forward.go b/dnscrypt-proxy/plugin_forward.go index 5abd13bd..1cc1c4ee 100644 --- a/dnscrypt-proxy/plugin_forward.go +++ b/dnscrypt-proxy/plugin_forward.go @@ -71,19 +71,15 @@ func (plugin *PluginForward) Reload() error { } func (plugin *PluginForward) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - questions := msg.Question - if len(questions) != 1 { - return nil - } - question := strings.ToLower(StripTrailingDot(questions[0].Name)) - questionLen := len(question) + qName := pluginsState.qName + qNameLen := len(qName) var servers []string for _, candidate := range plugin.forwardMap { candidateLen := len(candidate.domain) - if candidateLen > questionLen { + if candidateLen > qNameLen { continue } - if question[questionLen-candidateLen:] == candidate.domain && (candidateLen == questionLen || (question[questionLen-candidateLen-1] == '.')) { + if qName[qNameLen-candidateLen:] == candidate.domain && (candidateLen == qNameLen || (qName[qNameLen-candidateLen-1] == '.')) { servers = candidate.servers break } diff --git a/dnscrypt-proxy/plugin_nx_log.go b/dnscrypt-proxy/plugin_nx_log.go index f252d3b0..b194be02 100644 --- a/dnscrypt-proxy/plugin_nx_log.go +++ b/dnscrypt-proxy/plugin_nx_log.go @@ -43,11 +43,7 @@ func (plugin *PluginNxLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error if msg.Rcode != dns.RcodeNameError { return nil } - questions := msg.Question - if len(questions) == 0 { - return nil - } - question := questions[0] + question := msg.Question[0] qType, ok := dns.TypeToString[question.Qtype] if !ok { qType = string(qType) @@ -58,7 +54,7 @@ func (plugin *PluginNxLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error } else { clientIPStr = (*pluginsState.clientAddr).(*net.TCPAddr).IP.String() } - qName := StripTrailingDot(question.Name) + qName := pluginsState.qName var line string if plugin.format == "tsv" { diff --git a/dnscrypt-proxy/plugin_query_log.go b/dnscrypt-proxy/plugin_query_log.go index f5d603d9..d94b4b94 100644 --- a/dnscrypt-proxy/plugin_query_log.go +++ b/dnscrypt-proxy/plugin_query_log.go @@ -43,11 +43,7 @@ func (plugin *PluginQueryLog) Reload() error { } func (plugin *PluginQueryLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - questions := msg.Question - if len(questions) == 0 { - return nil - } - question := questions[0] + question := msg.Question[0] qType, ok := dns.TypeToString[question.Qtype] if !ok { qType = string(qType) @@ -65,7 +61,7 @@ func (plugin *PluginQueryLog) Eval(pluginsState *PluginsState, msg *dns.Msg) err } else { clientIPStr = (*pluginsState.clientAddr).(*net.TCPAddr).IP.String() } - qName := StripTrailingDot(question.Name) + qName := pluginsState.qName if pluginsState.cacheHit { pluginsState.serverName = "-" diff --git a/dnscrypt-proxy/plugin_querymeta.go b/dnscrypt-proxy/plugin_querymeta.go index 68398f31..9dd8da2f 100644 --- a/dnscrypt-proxy/plugin_querymeta.go +++ b/dnscrypt-proxy/plugin_querymeta.go @@ -34,10 +34,6 @@ func (plugin *PluginQueryMeta) Reload() error { } func (plugin *PluginQueryMeta) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - questions := msg.Question - if len(questions) == 0 { - return nil - } msg.Extra = []dns.RR{plugin.queryMetaRR} return nil } diff --git a/dnscrypt-proxy/plugin_whitelist_name.go b/dnscrypt-proxy/plugin_whitelist_name.go index 5afc28d8..6c4a44f7 100644 --- a/dnscrypt-proxy/plugin_whitelist_name.go +++ b/dnscrypt-proxy/plugin_whitelist_name.go @@ -82,11 +82,7 @@ func (plugin *PluginWhitelistName) Reload() error { } func (plugin *PluginWhitelistName) Eval(pluginsState *PluginsState, msg *dns.Msg) error { - questions := msg.Question - if len(questions) != 1 { - return nil - } - qName := strings.ToLower(StripTrailingDot(questions[0].Name)) + qName := pluginsState.qName whitelist, reason, xweeklyRanges := plugin.patternMatcher.Eval(qName) var weeklyRanges *WeeklyRanges if xweeklyRanges != nil { diff --git a/dnscrypt-proxy/plugins.go b/dnscrypt-proxy/plugins.go index 1b8dd43d..adce6fc8 100644 --- a/dnscrypt-proxy/plugins.go +++ b/dnscrypt-proxy/plugins.go @@ -80,6 +80,7 @@ type PluginsState struct { cacheMaxTTL uint32 rejectTTL uint32 questionMsg *dns.Msg + qName string requestStart time.Time requestEnd time.Time cacheHit bool @@ -235,25 +236,31 @@ func NewPluginsState(proxy *Proxy, clientProto string, clientAddr *net.Addr, sta cacheMaxTTL: proxy.cacheMaxTTL, rejectTTL: proxy.rejectTTL, questionMsg: nil, + qName: "", requestStart: start, maxUnencryptedUDPSafePayloadSize: MaxDNSUDPSafePacketSize, } } func (pluginsState *PluginsState) ApplyQueryPlugins(pluginsGlobals *PluginsGlobals, packet []byte, serverName string) ([]byte, error) { - if len(*pluginsGlobals.queryPlugins) == 0 && len(*pluginsGlobals.loggingPlugins) == 0 { - return packet, nil - } pluginsState.serverName = serverName pluginsState.action = PluginsActionForward msg := dns.Msg{} if err := msg.Unpack(packet); err != nil { return packet, err } - if len(msg.Question) > 1 { + if len(msg.Question) != 1 { return packet, errors.New("Unexpected number of questions") } + qName, err := NormalizeQName(msg.Question[0].Name) + if err != nil { + return packet, err + } + pluginsState.qName = qName pluginsState.questionMsg = &msg + if len(*pluginsGlobals.queryPlugins) == 0 && len(*pluginsGlobals.loggingPlugins) == 0 { + return packet, nil + } pluginsGlobals.RLock() defer pluginsGlobals.RUnlock() for _, plugin := range *pluginsGlobals.queryPlugins { @@ -307,7 +314,6 @@ func (pluginsState *PluginsState) ApplyResponsePlugins(pluginsGlobals *PluginsGl } if pluginsState.action == PluginsActionReject { synth := RefusedResponseFromMessage(&msg, pluginsGlobals.refusedCodeInResponses, pluginsGlobals.respondWithIPv4, pluginsGlobals.respondWithIPv6, pluginsState.rejectTTL) - dlog.Infof("Blocking [%s]", synth.Question[0].Name) pluginsState.synthResponse = synth } if pluginsState.action != PluginsActionForward { @@ -330,8 +336,8 @@ func (pluginsState *PluginsState) ApplyLoggingPlugins(pluginsGlobals *PluginsGlo } pluginsState.requestEnd = time.Now() questionMsg := pluginsState.questionMsg - if questionMsg == nil || len(questionMsg.Question) > 1 { - return errors.New("Unexpected number of questions") + if questionMsg == nil { + return errors.New("Question not found") } pluginsGlobals.RLock() defer pluginsGlobals.RUnlock()