Store the normalized qName in the plugin state

We now enforce the fact that a query always include a question.
It holds true for all practical use cases of dnscrypt-proxy.

This avoids quite a lot of redundant code in plugins, and is faster.
This commit is contained in:
Frank Denis 2019-12-17 09:38:53 +01:00
parent ee24bf0421
commit 4fd54a4919
15 changed files with 107 additions and 144 deletions

View File

@ -2,9 +2,11 @@ package main
import (
"encoding/binary"
"errors"
"net"
"strings"
"time"
"unicode/utf8"
"github.com/miekg/dns"
)
@ -42,36 +44,37 @@ func RefusedResponseFromMessage(srcMsg *dns.Msg, refusedCode bool, ipv4 net.IP,
} else {
dstMsg.Rcode = dns.RcodeSuccess
questions := srcMsg.Question
if len(questions) > 0 {
question := questions[0]
sendHInfoResponse := true
if len(questions) == 0 {
return dstMsg
}
question := questions[0]
sendHInfoResponse := true
if ipv4 != nil && question.Qtype == dns.TypeA {
rr := new(dns.A)
rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}
rr.A = ipv4.To4()
if rr.A != nil {
dstMsg.Answer = []dns.RR{rr}
sendHInfoResponse = false
}
} else if ipv6 != nil && question.Qtype == dns.TypeAAAA {
rr := new(dns.AAAA)
rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}
rr.AAAA = ipv6.To16()
if rr.AAAA != nil {
dstMsg.Answer = []dns.RR{rr}
sendHInfoResponse = false
}
if ipv4 != nil && question.Qtype == dns.TypeA {
rr := new(dns.A)
rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}
rr.A = ipv4.To4()
if rr.A != nil {
dstMsg.Answer = []dns.RR{rr}
sendHInfoResponse = false
}
} else if ipv6 != nil && question.Qtype == dns.TypeAAAA {
rr := new(dns.AAAA)
rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}
rr.AAAA = ipv6.To16()
if rr.AAAA != nil {
dstMsg.Answer = []dns.RR{rr}
sendHInfoResponse = false
}
}
if sendHInfoResponse {
hinfo := new(dns.HINFO)
hinfo.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeHINFO,
Class: dns.ClassINET, Ttl: 1}
hinfo.Cpu = "This query has been locally blocked"
hinfo.Os = "by dnscrypt-proxy"
dstMsg.Answer = []dns.RR{hinfo}
}
if sendHInfoResponse {
hinfo := new(dns.HINFO)
hinfo.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeHINFO,
Class: dns.ClassINET, Ttl: 1}
hinfo.Cpu = "This query has been locally blocked"
hinfo.Os = "by dnscrypt-proxy"
dstMsg.Answer = []dns.RR{hinfo}
}
}
return dstMsg
@ -93,7 +96,7 @@ func Rcode(packet []byte) uint8 {
return packet[3] & 0xf
}
func NormalizeName(name *[]byte) {
func NormalizeRawQName(name *[]byte) {
for i, c := range *name {
if c >= 65 && c <= 90 {
(*name)[i] = c + 32
@ -101,11 +104,33 @@ func NormalizeName(name *[]byte) {
}
}
func StripTrailingDot(str string) string {
if len(str) > 1 && strings.HasSuffix(str, ".") {
str = str[:len(str)-1]
func NormalizeQName(str string) (string, error) {
if len(str) == 0 || str == "." {
return ".", nil
}
return str
hasUpper := false
str = strings.TrimSuffix(str, ".")
strLen := len(str)
for i := 0; i < strLen; i++ {
c := str[i]
if c >= utf8.RuneSelf {
return str, errors.New("Query name is not an ASCII string")
}
hasUpper = hasUpper || ('A' <= c && c <= 'Z')
}
if !hasUpper {
return str, nil
}
var b strings.Builder
b.Grow(len(str))
for i := 0; i < strLen; i++ {
c := str[i]
if 'A' <= c && c <= 'Z' {
c += 'a' - 'A'
}
b.WriteByte(c)
}
return b.String(), nil
}
func getMinTTL(msg *dns.Msg, minTTL uint32, maxTTL uint32, cacheNegMinTTL uint32, cacheNegMaxTTL uint32) time.Duration {

View File

@ -122,14 +122,7 @@ func (plugin *PluginBlockIP) Eval(pluginsState *PluginsState, msg *dns.Msg) erro
pluginsState.action = PluginsActionReject
pluginsState.returnCode = PluginsReturnCodeReject
if plugin.logger != nil {
questions := msg.Question
if len(questions) != 1 {
return nil
}
qName := strings.ToLower(StripTrailingDot(questions[0].Name))
if len(qName) < 2 {
return nil
}
qName := pluginsState.qName
var clientIPStr string
if pluginsState.clientProto == "udp" {
clientIPStr = (*pluginsState.clientAddr).(*net.UDPAddr).IP.String()

View File

@ -29,11 +29,7 @@ func (plugin *PluginBlockIPv6) Reload() error {
}
func (plugin *PluginBlockIPv6) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) != 1 {
return nil
}
question := questions[0]
question := msg.Question[0]
if question.Qclass != dns.ClassINET || question.Qtype != dns.TypeAAAA {
return nil
}

View File

@ -25,10 +25,9 @@ const aliasesLimit = 8
var blockedNames *BlockedNames
func (blockedNames *BlockedNames) check(pluginsState *PluginsState, qName string, aliasFor *string) (bool, error) {
qName = strings.ToLower(StripTrailingDot(qName))
reject, reason, xweeklyRanges := blockedNames.patternMatcher.Eval(qName)
if aliasFor != nil {
reason = reason + " (alias for [" + StripTrailingDot(*aliasFor) + "])"
reason = reason + " (alias for [" + *aliasFor + "])"
}
var weeklyRanges *WeeklyRanges
if xweeklyRanges != nil {
@ -144,11 +143,7 @@ func (plugin *PluginBlockName) Eval(pluginsState *PluginsState, msg *dns.Msg) er
if blockedNames == nil || pluginsState.sessionData["whitelisted"] != nil {
return nil
}
questions := msg.Question
if len(questions) != 1 {
return nil
}
_, err := blockedNames.check(pluginsState, questions[0].Name, nil)
_, err := blockedNames.check(pluginsState, pluginsState.qName, nil)
return err
}
@ -181,11 +176,7 @@ func (plugin *PluginBlockNameResponse) Eval(pluginsState *PluginsState, msg *dns
if blockedNames == nil || pluginsState.sessionData["whitelisted"] != nil {
return nil
}
questions := msg.Question
if len(questions) != 1 {
return nil
}
aliasFor := questions[0].Name
aliasFor := pluginsState.qName
aliasesLeft := aliasesLimit
answers := msg.Answer
for _, answer := range answers {
@ -193,7 +184,11 @@ func (plugin *PluginBlockNameResponse) Eval(pluginsState *PluginsState, msg *dns
if header.Class != dns.ClassINET || header.Rrtype != dns.TypeCNAME {
continue
}
if blocked, err := blockedNames.check(pluginsState, answer.(*dns.CNAME).Target, &aliasFor); blocked || err != nil {
target, err := NormalizeQName(answer.(*dns.CNAME).Target)
if err != nil {
return err
}
if blocked, err := blockedNames.check(pluginsState, target, &aliasFor); blocked || err != nil {
return err
}
aliasesLeft--

View File

@ -3,7 +3,6 @@ package main
import (
"github.com/k-sone/critbitgo"
"github.com/miekg/dns"
"strings"
)
var undelegatedSet = []string{
@ -179,11 +178,7 @@ func (plugin *PluginBlockUndelegated) Reload() error {
}
func (plugin *PluginBlockUndelegated) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) != 1 {
return nil
}
revQname := strings.ToLower(StringReverse(questions[0].Name))
revQname := StringReverse(pluginsState.qName)
match, _, found := plugin.suffixes.LongestPrefix([]byte(revQname))
if !found {
return nil

View File

@ -30,17 +30,11 @@ func (plugin *PluginBlockUnqualified) Reload() error {
}
func (plugin *PluginBlockUnqualified) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) != 1 {
return nil
}
question := questions[0]
question := msg.Question[0]
if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) {
return nil
}
qName := question.Name
idx := strings.IndexByte(qName, '.')
if idx == -1 || (idx == 0 || idx+1 != len(qName)) {
if strings.IndexByte(pluginsState.qName, '.') >= 0 {
return nil
}
synth := EmptyResponseFromMessage(msg)

View File

@ -3,7 +3,6 @@ package main
import (
"crypto/sha512"
"encoding/binary"
"errors"
"sync"
"time"
@ -23,12 +22,8 @@ type CachedResponses struct {
var cachedResponses CachedResponses
func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) ([32]byte, error) {
questions := msg.Question
if len(questions) != 1 {
return [32]byte{}, errors.New("No question present")
}
question := questions[0]
func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) [32]byte {
question := msg.Question[0]
h := sha512.New512_256()
var tmp [5]byte
binary.LittleEndian.PutUint16(tmp[0:2], question.Qtype)
@ -37,12 +32,13 @@ func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) ([32]byte, error)
tmp[4] = 1
}
h.Write(tmp[:])
normalizedName := []byte(question.Name)
NormalizeName(&normalizedName)
h.Write(normalizedName)
normalizedRawQName := []byte(question.Name)
NormalizeRawQName(&normalizedRawQName)
h.Write(normalizedRawQName)
var sum [32]byte
h.Sum(sum[:0])
return sum, nil
return sum
}
// ---
@ -71,10 +67,7 @@ func (plugin *PluginCache) Reload() error {
}
func (plugin *PluginCache) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
cacheKey, err := computeCacheKey(pluginsState, msg)
if err != nil {
return nil
}
cacheKey := computeCacheKey(pluginsState, msg)
cachedResponses.RLock()
defer cachedResponses.RUnlock()
if cachedResponses.cache == nil {
@ -134,10 +127,7 @@ func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg
if msg.Truncated {
return nil
}
cacheKey, err := computeCacheKey(pluginsState, msg)
if err != nil {
return err
}
cacheKey := computeCacheKey(pluginsState, msg)
ttl := getMinTTL(msg, pluginsState.cacheMinTTL, pluginsState.cacheMaxTTL, pluginsState.cacheNegMinTTL, pluginsState.cacheNegMaxTTL)
cachedResponse := CachedResponse{
expiration: time.Now().Add(ttl),
@ -145,6 +135,7 @@ func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg
}
cachedResponses.Lock()
if cachedResponses.cache == nil {
var err error
cachedResponses.cache, err = lru.NewARC(pluginsState.cacheSize)
if err != nil {
cachedResponses.Unlock()

View File

@ -100,21 +100,13 @@ func (plugin *PluginCloak) Reload() error {
}
func (plugin *PluginCloak) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) != 1 {
return nil
}
question := questions[0]
question := msg.Question[0]
if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) {
return nil
}
qName := strings.ToLower(StripTrailingDot(questions[0].Name))
if len(qName) < 2 {
return nil
}
now := time.Now()
plugin.RLock()
_, _, xcloakedName := plugin.patternMatcher.Eval(qName)
_, _, xcloakedName := plugin.patternMatcher.Eval(pluginsState.qName)
if xcloakedName == nil {
plugin.RUnlock()
return nil

View File

@ -34,16 +34,12 @@ func (plugin *PluginFirefox) Reload() error {
}
func (plugin *PluginFirefox) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) != 1 {
return nil
}
question := questions[0]
question := msg.Question[0]
if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) {
return nil
}
qName := strings.ToLower(question.Name)
if qName != "use-application-dns.net." && !strings.HasSuffix(qName, ".use-application-dns.net.") {
qName := pluginsState.qName
if qName != "use-application-dns.net" && !strings.HasSuffix(qName, ".use-application-dns.net") {
return nil
}
synth := EmptyResponseFromMessage(msg)

View File

@ -71,19 +71,15 @@ func (plugin *PluginForward) Reload() error {
}
func (plugin *PluginForward) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) != 1 {
return nil
}
question := strings.ToLower(StripTrailingDot(questions[0].Name))
questionLen := len(question)
qName := pluginsState.qName
qNameLen := len(qName)
var servers []string
for _, candidate := range plugin.forwardMap {
candidateLen := len(candidate.domain)
if candidateLen > questionLen {
if candidateLen > qNameLen {
continue
}
if question[questionLen-candidateLen:] == candidate.domain && (candidateLen == questionLen || (question[questionLen-candidateLen-1] == '.')) {
if qName[qNameLen-candidateLen:] == candidate.domain && (candidateLen == qNameLen || (qName[qNameLen-candidateLen-1] == '.')) {
servers = candidate.servers
break
}

View File

@ -43,11 +43,7 @@ func (plugin *PluginNxLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error
if msg.Rcode != dns.RcodeNameError {
return nil
}
questions := msg.Question
if len(questions) == 0 {
return nil
}
question := questions[0]
question := msg.Question[0]
qType, ok := dns.TypeToString[question.Qtype]
if !ok {
qType = string(qType)
@ -58,7 +54,7 @@ func (plugin *PluginNxLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error
} else {
clientIPStr = (*pluginsState.clientAddr).(*net.TCPAddr).IP.String()
}
qName := StripTrailingDot(question.Name)
qName := pluginsState.qName
var line string
if plugin.format == "tsv" {

View File

@ -43,11 +43,7 @@ func (plugin *PluginQueryLog) Reload() error {
}
func (plugin *PluginQueryLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) == 0 {
return nil
}
question := questions[0]
question := msg.Question[0]
qType, ok := dns.TypeToString[question.Qtype]
if !ok {
qType = string(qType)
@ -65,7 +61,7 @@ func (plugin *PluginQueryLog) Eval(pluginsState *PluginsState, msg *dns.Msg) err
} else {
clientIPStr = (*pluginsState.clientAddr).(*net.TCPAddr).IP.String()
}
qName := StripTrailingDot(question.Name)
qName := pluginsState.qName
if pluginsState.cacheHit {
pluginsState.serverName = "-"

View File

@ -34,10 +34,6 @@ func (plugin *PluginQueryMeta) Reload() error {
}
func (plugin *PluginQueryMeta) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) == 0 {
return nil
}
msg.Extra = []dns.RR{plugin.queryMetaRR}
return nil
}

View File

@ -82,11 +82,7 @@ func (plugin *PluginWhitelistName) Reload() error {
}
func (plugin *PluginWhitelistName) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) != 1 {
return nil
}
qName := strings.ToLower(StripTrailingDot(questions[0].Name))
qName := pluginsState.qName
whitelist, reason, xweeklyRanges := plugin.patternMatcher.Eval(qName)
var weeklyRanges *WeeklyRanges
if xweeklyRanges != nil {

View File

@ -80,6 +80,7 @@ type PluginsState struct {
cacheMaxTTL uint32
rejectTTL uint32
questionMsg *dns.Msg
qName string
requestStart time.Time
requestEnd time.Time
cacheHit bool
@ -235,25 +236,31 @@ func NewPluginsState(proxy *Proxy, clientProto string, clientAddr *net.Addr, sta
cacheMaxTTL: proxy.cacheMaxTTL,
rejectTTL: proxy.rejectTTL,
questionMsg: nil,
qName: "",
requestStart: start,
maxUnencryptedUDPSafePayloadSize: MaxDNSUDPSafePacketSize,
}
}
func (pluginsState *PluginsState) ApplyQueryPlugins(pluginsGlobals *PluginsGlobals, packet []byte, serverName string) ([]byte, error) {
if len(*pluginsGlobals.queryPlugins) == 0 && len(*pluginsGlobals.loggingPlugins) == 0 {
return packet, nil
}
pluginsState.serverName = serverName
pluginsState.action = PluginsActionForward
msg := dns.Msg{}
if err := msg.Unpack(packet); err != nil {
return packet, err
}
if len(msg.Question) > 1 {
if len(msg.Question) != 1 {
return packet, errors.New("Unexpected number of questions")
}
qName, err := NormalizeQName(msg.Question[0].Name)
if err != nil {
return packet, err
}
pluginsState.qName = qName
pluginsState.questionMsg = &msg
if len(*pluginsGlobals.queryPlugins) == 0 && len(*pluginsGlobals.loggingPlugins) == 0 {
return packet, nil
}
pluginsGlobals.RLock()
defer pluginsGlobals.RUnlock()
for _, plugin := range *pluginsGlobals.queryPlugins {
@ -307,7 +314,6 @@ func (pluginsState *PluginsState) ApplyResponsePlugins(pluginsGlobals *PluginsGl
}
if pluginsState.action == PluginsActionReject {
synth := RefusedResponseFromMessage(&msg, pluginsGlobals.refusedCodeInResponses, pluginsGlobals.respondWithIPv4, pluginsGlobals.respondWithIPv6, pluginsState.rejectTTL)
dlog.Infof("Blocking [%s]", synth.Question[0].Name)
pluginsState.synthResponse = synth
}
if pluginsState.action != PluginsActionForward {
@ -330,8 +336,8 @@ func (pluginsState *PluginsState) ApplyLoggingPlugins(pluginsGlobals *PluginsGlo
}
pluginsState.requestEnd = time.Now()
questionMsg := pluginsState.questionMsg
if questionMsg == nil || len(questionMsg.Question) > 1 {
return errors.New("Unexpected number of questions")
if questionMsg == nil {
return errors.New("Question not found")
}
pluginsGlobals.RLock()
defer pluginsGlobals.RUnlock()