Store the normalized qName in the plugin state
We now enforce the fact that a query always include a question. It holds true for all practical use cases of dnscrypt-proxy. This avoids quite a lot of redundant code in plugins, and is faster.
This commit is contained in:
parent
ee24bf0421
commit
4fd54a4919
|
@ -2,9 +2,11 @@ package main
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
@ -42,7 +44,9 @@ func RefusedResponseFromMessage(srcMsg *dns.Msg, refusedCode bool, ipv4 net.IP,
|
|||
} else {
|
||||
dstMsg.Rcode = dns.RcodeSuccess
|
||||
questions := srcMsg.Question
|
||||
if len(questions) > 0 {
|
||||
if len(questions) == 0 {
|
||||
return dstMsg
|
||||
}
|
||||
question := questions[0]
|
||||
sendHInfoResponse := true
|
||||
|
||||
|
@ -73,7 +77,6 @@ func RefusedResponseFromMessage(srcMsg *dns.Msg, refusedCode bool, ipv4 net.IP,
|
|||
dstMsg.Answer = []dns.RR{hinfo}
|
||||
}
|
||||
}
|
||||
}
|
||||
return dstMsg
|
||||
}
|
||||
|
||||
|
@ -93,7 +96,7 @@ func Rcode(packet []byte) uint8 {
|
|||
return packet[3] & 0xf
|
||||
}
|
||||
|
||||
func NormalizeName(name *[]byte) {
|
||||
func NormalizeRawQName(name *[]byte) {
|
||||
for i, c := range *name {
|
||||
if c >= 65 && c <= 90 {
|
||||
(*name)[i] = c + 32
|
||||
|
@ -101,11 +104,33 @@ func NormalizeName(name *[]byte) {
|
|||
}
|
||||
}
|
||||
|
||||
func StripTrailingDot(str string) string {
|
||||
if len(str) > 1 && strings.HasSuffix(str, ".") {
|
||||
str = str[:len(str)-1]
|
||||
func NormalizeQName(str string) (string, error) {
|
||||
if len(str) == 0 || str == "." {
|
||||
return ".", nil
|
||||
}
|
||||
return str
|
||||
hasUpper := false
|
||||
str = strings.TrimSuffix(str, ".")
|
||||
strLen := len(str)
|
||||
for i := 0; i < strLen; i++ {
|
||||
c := str[i]
|
||||
if c >= utf8.RuneSelf {
|
||||
return str, errors.New("Query name is not an ASCII string")
|
||||
}
|
||||
hasUpper = hasUpper || ('A' <= c && c <= 'Z')
|
||||
}
|
||||
if !hasUpper {
|
||||
return str, nil
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(str))
|
||||
for i := 0; i < strLen; i++ {
|
||||
c := str[i]
|
||||
if 'A' <= c && c <= 'Z' {
|
||||
c += 'a' - 'A'
|
||||
}
|
||||
b.WriteByte(c)
|
||||
}
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func getMinTTL(msg *dns.Msg, minTTL uint32, maxTTL uint32, cacheNegMinTTL uint32, cacheNegMaxTTL uint32) time.Duration {
|
||||
|
|
|
@ -122,14 +122,7 @@ func (plugin *PluginBlockIP) Eval(pluginsState *PluginsState, msg *dns.Msg) erro
|
|||
pluginsState.action = PluginsActionReject
|
||||
pluginsState.returnCode = PluginsReturnCodeReject
|
||||
if plugin.logger != nil {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
qName := strings.ToLower(StripTrailingDot(questions[0].Name))
|
||||
if len(qName) < 2 {
|
||||
return nil
|
||||
}
|
||||
qName := pluginsState.qName
|
||||
var clientIPStr string
|
||||
if pluginsState.clientProto == "udp" {
|
||||
clientIPStr = (*pluginsState.clientAddr).(*net.UDPAddr).IP.String()
|
||||
|
|
|
@ -29,11 +29,7 @@ func (plugin *PluginBlockIPv6) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginBlockIPv6) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
question := questions[0]
|
||||
question := msg.Question[0]
|
||||
if question.Qclass != dns.ClassINET || question.Qtype != dns.TypeAAAA {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -25,10 +25,9 @@ const aliasesLimit = 8
|
|||
var blockedNames *BlockedNames
|
||||
|
||||
func (blockedNames *BlockedNames) check(pluginsState *PluginsState, qName string, aliasFor *string) (bool, error) {
|
||||
qName = strings.ToLower(StripTrailingDot(qName))
|
||||
reject, reason, xweeklyRanges := blockedNames.patternMatcher.Eval(qName)
|
||||
if aliasFor != nil {
|
||||
reason = reason + " (alias for [" + StripTrailingDot(*aliasFor) + "])"
|
||||
reason = reason + " (alias for [" + *aliasFor + "])"
|
||||
}
|
||||
var weeklyRanges *WeeklyRanges
|
||||
if xweeklyRanges != nil {
|
||||
|
@ -144,11 +143,7 @@ func (plugin *PluginBlockName) Eval(pluginsState *PluginsState, msg *dns.Msg) er
|
|||
if blockedNames == nil || pluginsState.sessionData["whitelisted"] != nil {
|
||||
return nil
|
||||
}
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
_, err := blockedNames.check(pluginsState, questions[0].Name, nil)
|
||||
_, err := blockedNames.check(pluginsState, pluginsState.qName, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -181,11 +176,7 @@ func (plugin *PluginBlockNameResponse) Eval(pluginsState *PluginsState, msg *dns
|
|||
if blockedNames == nil || pluginsState.sessionData["whitelisted"] != nil {
|
||||
return nil
|
||||
}
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
aliasFor := questions[0].Name
|
||||
aliasFor := pluginsState.qName
|
||||
aliasesLeft := aliasesLimit
|
||||
answers := msg.Answer
|
||||
for _, answer := range answers {
|
||||
|
@ -193,7 +184,11 @@ func (plugin *PluginBlockNameResponse) Eval(pluginsState *PluginsState, msg *dns
|
|||
if header.Class != dns.ClassINET || header.Rrtype != dns.TypeCNAME {
|
||||
continue
|
||||
}
|
||||
if blocked, err := blockedNames.check(pluginsState, answer.(*dns.CNAME).Target, &aliasFor); blocked || err != nil {
|
||||
target, err := NormalizeQName(answer.(*dns.CNAME).Target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if blocked, err := blockedNames.check(pluginsState, target, &aliasFor); blocked || err != nil {
|
||||
return err
|
||||
}
|
||||
aliasesLeft--
|
||||
|
|
|
@ -3,7 +3,6 @@ package main
|
|||
import (
|
||||
"github.com/k-sone/critbitgo"
|
||||
"github.com/miekg/dns"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var undelegatedSet = []string{
|
||||
|
@ -179,11 +178,7 @@ func (plugin *PluginBlockUndelegated) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginBlockUndelegated) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
revQname := strings.ToLower(StringReverse(questions[0].Name))
|
||||
revQname := StringReverse(pluginsState.qName)
|
||||
match, _, found := plugin.suffixes.LongestPrefix([]byte(revQname))
|
||||
if !found {
|
||||
return nil
|
||||
|
|
|
@ -30,17 +30,11 @@ func (plugin *PluginBlockUnqualified) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginBlockUnqualified) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
question := questions[0]
|
||||
question := msg.Question[0]
|
||||
if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) {
|
||||
return nil
|
||||
}
|
||||
qName := question.Name
|
||||
idx := strings.IndexByte(qName, '.')
|
||||
if idx == -1 || (idx == 0 || idx+1 != len(qName)) {
|
||||
if strings.IndexByte(pluginsState.qName, '.') >= 0 {
|
||||
return nil
|
||||
}
|
||||
synth := EmptyResponseFromMessage(msg)
|
||||
|
|
|
@ -3,7 +3,6 @@ package main
|
|||
import (
|
||||
"crypto/sha512"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -23,12 +22,8 @@ type CachedResponses struct {
|
|||
|
||||
var cachedResponses CachedResponses
|
||||
|
||||
func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) ([32]byte, error) {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return [32]byte{}, errors.New("No question present")
|
||||
}
|
||||
question := questions[0]
|
||||
func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) [32]byte {
|
||||
question := msg.Question[0]
|
||||
h := sha512.New512_256()
|
||||
var tmp [5]byte
|
||||
binary.LittleEndian.PutUint16(tmp[0:2], question.Qtype)
|
||||
|
@ -37,12 +32,13 @@ func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) ([32]byte, error)
|
|||
tmp[4] = 1
|
||||
}
|
||||
h.Write(tmp[:])
|
||||
normalizedName := []byte(question.Name)
|
||||
NormalizeName(&normalizedName)
|
||||
h.Write(normalizedName)
|
||||
normalizedRawQName := []byte(question.Name)
|
||||
NormalizeRawQName(&normalizedRawQName)
|
||||
h.Write(normalizedRawQName)
|
||||
var sum [32]byte
|
||||
h.Sum(sum[:0])
|
||||
return sum, nil
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
// ---
|
||||
|
@ -71,10 +67,7 @@ func (plugin *PluginCache) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginCache) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
cacheKey, err := computeCacheKey(pluginsState, msg)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
cacheKey := computeCacheKey(pluginsState, msg)
|
||||
cachedResponses.RLock()
|
||||
defer cachedResponses.RUnlock()
|
||||
if cachedResponses.cache == nil {
|
||||
|
@ -134,10 +127,7 @@ func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg
|
|||
if msg.Truncated {
|
||||
return nil
|
||||
}
|
||||
cacheKey, err := computeCacheKey(pluginsState, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cacheKey := computeCacheKey(pluginsState, msg)
|
||||
ttl := getMinTTL(msg, pluginsState.cacheMinTTL, pluginsState.cacheMaxTTL, pluginsState.cacheNegMinTTL, pluginsState.cacheNegMaxTTL)
|
||||
cachedResponse := CachedResponse{
|
||||
expiration: time.Now().Add(ttl),
|
||||
|
@ -145,6 +135,7 @@ func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg
|
|||
}
|
||||
cachedResponses.Lock()
|
||||
if cachedResponses.cache == nil {
|
||||
var err error
|
||||
cachedResponses.cache, err = lru.NewARC(pluginsState.cacheSize)
|
||||
if err != nil {
|
||||
cachedResponses.Unlock()
|
||||
|
|
|
@ -100,21 +100,13 @@ func (plugin *PluginCloak) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginCloak) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
question := questions[0]
|
||||
question := msg.Question[0]
|
||||
if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) {
|
||||
return nil
|
||||
}
|
||||
qName := strings.ToLower(StripTrailingDot(questions[0].Name))
|
||||
if len(qName) < 2 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
plugin.RLock()
|
||||
_, _, xcloakedName := plugin.patternMatcher.Eval(qName)
|
||||
_, _, xcloakedName := plugin.patternMatcher.Eval(pluginsState.qName)
|
||||
if xcloakedName == nil {
|
||||
plugin.RUnlock()
|
||||
return nil
|
||||
|
|
|
@ -34,16 +34,12 @@ func (plugin *PluginFirefox) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginFirefox) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
question := questions[0]
|
||||
question := msg.Question[0]
|
||||
if question.Qclass != dns.ClassINET || (question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA) {
|
||||
return nil
|
||||
}
|
||||
qName := strings.ToLower(question.Name)
|
||||
if qName != "use-application-dns.net." && !strings.HasSuffix(qName, ".use-application-dns.net.") {
|
||||
qName := pluginsState.qName
|
||||
if qName != "use-application-dns.net" && !strings.HasSuffix(qName, ".use-application-dns.net") {
|
||||
return nil
|
||||
}
|
||||
synth := EmptyResponseFromMessage(msg)
|
||||
|
|
|
@ -71,19 +71,15 @@ func (plugin *PluginForward) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginForward) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
question := strings.ToLower(StripTrailingDot(questions[0].Name))
|
||||
questionLen := len(question)
|
||||
qName := pluginsState.qName
|
||||
qNameLen := len(qName)
|
||||
var servers []string
|
||||
for _, candidate := range plugin.forwardMap {
|
||||
candidateLen := len(candidate.domain)
|
||||
if candidateLen > questionLen {
|
||||
if candidateLen > qNameLen {
|
||||
continue
|
||||
}
|
||||
if question[questionLen-candidateLen:] == candidate.domain && (candidateLen == questionLen || (question[questionLen-candidateLen-1] == '.')) {
|
||||
if qName[qNameLen-candidateLen:] == candidate.domain && (candidateLen == qNameLen || (qName[qNameLen-candidateLen-1] == '.')) {
|
||||
servers = candidate.servers
|
||||
break
|
||||
}
|
||||
|
|
|
@ -43,11 +43,7 @@ func (plugin *PluginNxLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error
|
|||
if msg.Rcode != dns.RcodeNameError {
|
||||
return nil
|
||||
}
|
||||
questions := msg.Question
|
||||
if len(questions) == 0 {
|
||||
return nil
|
||||
}
|
||||
question := questions[0]
|
||||
question := msg.Question[0]
|
||||
qType, ok := dns.TypeToString[question.Qtype]
|
||||
if !ok {
|
||||
qType = string(qType)
|
||||
|
@ -58,7 +54,7 @@ func (plugin *PluginNxLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error
|
|||
} else {
|
||||
clientIPStr = (*pluginsState.clientAddr).(*net.TCPAddr).IP.String()
|
||||
}
|
||||
qName := StripTrailingDot(question.Name)
|
||||
qName := pluginsState.qName
|
||||
|
||||
var line string
|
||||
if plugin.format == "tsv" {
|
||||
|
|
|
@ -43,11 +43,7 @@ func (plugin *PluginQueryLog) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginQueryLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
questions := msg.Question
|
||||
if len(questions) == 0 {
|
||||
return nil
|
||||
}
|
||||
question := questions[0]
|
||||
question := msg.Question[0]
|
||||
qType, ok := dns.TypeToString[question.Qtype]
|
||||
if !ok {
|
||||
qType = string(qType)
|
||||
|
@ -65,7 +61,7 @@ func (plugin *PluginQueryLog) Eval(pluginsState *PluginsState, msg *dns.Msg) err
|
|||
} else {
|
||||
clientIPStr = (*pluginsState.clientAddr).(*net.TCPAddr).IP.String()
|
||||
}
|
||||
qName := StripTrailingDot(question.Name)
|
||||
qName := pluginsState.qName
|
||||
|
||||
if pluginsState.cacheHit {
|
||||
pluginsState.serverName = "-"
|
||||
|
|
|
@ -34,10 +34,6 @@ func (plugin *PluginQueryMeta) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginQueryMeta) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
questions := msg.Question
|
||||
if len(questions) == 0 {
|
||||
return nil
|
||||
}
|
||||
msg.Extra = []dns.RR{plugin.queryMetaRR}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -82,11 +82,7 @@ func (plugin *PluginWhitelistName) Reload() error {
|
|||
}
|
||||
|
||||
func (plugin *PluginWhitelistName) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return nil
|
||||
}
|
||||
qName := strings.ToLower(StripTrailingDot(questions[0].Name))
|
||||
qName := pluginsState.qName
|
||||
whitelist, reason, xweeklyRanges := plugin.patternMatcher.Eval(qName)
|
||||
var weeklyRanges *WeeklyRanges
|
||||
if xweeklyRanges != nil {
|
||||
|
|
|
@ -80,6 +80,7 @@ type PluginsState struct {
|
|||
cacheMaxTTL uint32
|
||||
rejectTTL uint32
|
||||
questionMsg *dns.Msg
|
||||
qName string
|
||||
requestStart time.Time
|
||||
requestEnd time.Time
|
||||
cacheHit bool
|
||||
|
@ -235,25 +236,31 @@ func NewPluginsState(proxy *Proxy, clientProto string, clientAddr *net.Addr, sta
|
|||
cacheMaxTTL: proxy.cacheMaxTTL,
|
||||
rejectTTL: proxy.rejectTTL,
|
||||
questionMsg: nil,
|
||||
qName: "",
|
||||
requestStart: start,
|
||||
maxUnencryptedUDPSafePayloadSize: MaxDNSUDPSafePacketSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (pluginsState *PluginsState) ApplyQueryPlugins(pluginsGlobals *PluginsGlobals, packet []byte, serverName string) ([]byte, error) {
|
||||
if len(*pluginsGlobals.queryPlugins) == 0 && len(*pluginsGlobals.loggingPlugins) == 0 {
|
||||
return packet, nil
|
||||
}
|
||||
pluginsState.serverName = serverName
|
||||
pluginsState.action = PluginsActionForward
|
||||
msg := dns.Msg{}
|
||||
if err := msg.Unpack(packet); err != nil {
|
||||
return packet, err
|
||||
}
|
||||
if len(msg.Question) > 1 {
|
||||
if len(msg.Question) != 1 {
|
||||
return packet, errors.New("Unexpected number of questions")
|
||||
}
|
||||
qName, err := NormalizeQName(msg.Question[0].Name)
|
||||
if err != nil {
|
||||
return packet, err
|
||||
}
|
||||
pluginsState.qName = qName
|
||||
pluginsState.questionMsg = &msg
|
||||
if len(*pluginsGlobals.queryPlugins) == 0 && len(*pluginsGlobals.loggingPlugins) == 0 {
|
||||
return packet, nil
|
||||
}
|
||||
pluginsGlobals.RLock()
|
||||
defer pluginsGlobals.RUnlock()
|
||||
for _, plugin := range *pluginsGlobals.queryPlugins {
|
||||
|
@ -307,7 +314,6 @@ func (pluginsState *PluginsState) ApplyResponsePlugins(pluginsGlobals *PluginsGl
|
|||
}
|
||||
if pluginsState.action == PluginsActionReject {
|
||||
synth := RefusedResponseFromMessage(&msg, pluginsGlobals.refusedCodeInResponses, pluginsGlobals.respondWithIPv4, pluginsGlobals.respondWithIPv6, pluginsState.rejectTTL)
|
||||
dlog.Infof("Blocking [%s]", synth.Question[0].Name)
|
||||
pluginsState.synthResponse = synth
|
||||
}
|
||||
if pluginsState.action != PluginsActionForward {
|
||||
|
@ -330,8 +336,8 @@ func (pluginsState *PluginsState) ApplyLoggingPlugins(pluginsGlobals *PluginsGlo
|
|||
}
|
||||
pluginsState.requestEnd = time.Now()
|
||||
questionMsg := pluginsState.questionMsg
|
||||
if questionMsg == nil || len(questionMsg.Question) > 1 {
|
||||
return errors.New("Unexpected number of questions")
|
||||
if questionMsg == nil {
|
||||
return errors.New("Question not found")
|
||||
}
|
||||
pluginsGlobals.RLock()
|
||||
defer pluginsGlobals.RUnlock()
|
||||
|
|
Loading…
Reference in New Issue