Overwrite the server name only when we need to send an upstream query

This commit is contained in:
Frank Denis 2020-03-13 17:50:58 +01:00
parent c17637c026
commit 19647e03a6
5 changed files with 13 additions and 14 deletions

View File

@ -53,7 +53,7 @@ func (handler localDoHHandler) ServeHTTP(writer http.ResponseWriter, request *ht
writer.WriteHeader(400) writer.WriteHeader(400)
return return
} }
response := proxy.processIncomingQuery(proxy.serversInfo.getOne(), "local_doh", proxy.mainProto, packet, &xClientAddr, nil, start) response := proxy.processIncomingQuery("local_doh", proxy.mainProto, packet, &xClientAddr, nil, start)
if len(response) == 0 { if len(response) == 0 {
writer.WriteHeader(500) writer.WriteHeader(500)
return return

View File

@ -26,6 +26,9 @@ var blockedNames *BlockedNames
func (blockedNames *BlockedNames) check(pluginsState *PluginsState, qName string, aliasFor *string) (bool, error) { func (blockedNames *BlockedNames) check(pluginsState *PluginsState, qName string, aliasFor *string) (bool, error) {
reject, reason, xweeklyRanges := blockedNames.patternMatcher.Eval(qName) reject, reason, xweeklyRanges := blockedNames.patternMatcher.Eval(qName)
if aliasFor != nil {
reason = reason + " (alias for [" + *aliasFor + "])"
}
var weeklyRanges *WeeklyRanges var weeklyRanges *WeeklyRanges
if xweeklyRanges != nil { if xweeklyRanges != nil {
weeklyRanges = xweeklyRanges.(*WeeklyRanges) weeklyRanges = xweeklyRanges.(*WeeklyRanges)
@ -40,11 +43,6 @@ func (blockedNames *BlockedNames) check(pluginsState *PluginsState, qName string
} }
pluginsState.action = PluginsActionReject pluginsState.action = PluginsActionReject
pluginsState.returnCode = PluginsReturnCodeReject pluginsState.returnCode = PluginsReturnCodeReject
if aliasFor != nil {
reason = reason + " (alias for [" + *aliasFor + "])"
} else {
pluginsState.noServed = true
}
if blockedNames.logger != nil { if blockedNames.logger != nil {
var clientIPStr string var clientIPStr string
if pluginsState.clientProto == "udp" { if pluginsState.clientProto == "udp" {

View File

@ -63,7 +63,7 @@ func (plugin *PluginQueryLog) Eval(pluginsState *PluginsState, msg *dns.Msg) err
} }
qName := pluginsState.qName qName := pluginsState.qName
if pluginsState.cacheHit || pluginsState.noServed { if pluginsState.cacheHit {
pluginsState.serverName = "-" pluginsState.serverName = "-"
} else { } else {
switch pluginsState.returnCode { switch pluginsState.returnCode {

View File

@ -84,7 +84,6 @@ type PluginsState struct {
requestStart time.Time requestStart time.Time
requestEnd time.Time requestEnd time.Time
cacheHit bool cacheHit bool
noServed bool
returnCode PluginsReturnCode returnCode PluginsReturnCode
serverName string serverName string
} }
@ -238,14 +237,14 @@ func NewPluginsState(proxy *Proxy, clientProto string, clientAddr *net.Addr, sta
rejectTTL: proxy.rejectTTL, rejectTTL: proxy.rejectTTL,
questionMsg: nil, questionMsg: nil,
qName: "", qName: "",
serverName: "-",
requestStart: start, requestStart: start,
maxUnencryptedUDPSafePayloadSize: MaxDNSUDPSafePacketSize, maxUnencryptedUDPSafePayloadSize: MaxDNSUDPSafePacketSize,
sessionData: make(map[string]interface{}), sessionData: make(map[string]interface{}),
} }
} }
func (pluginsState *PluginsState) ApplyQueryPlugins(pluginsGlobals *PluginsGlobals, packet []byte, serverName string, needsEDNS0Padding bool) ([]byte, error) { func (pluginsState *PluginsState) ApplyQueryPlugins(pluginsGlobals *PluginsGlobals, packet []byte, needsEDNS0Padding bool) ([]byte, error) {
pluginsState.serverName = serverName
msg := dns.Msg{} msg := dns.Msg{}
if err := msg.Unpack(packet); err != nil { if err := msg.Unpack(packet); err != nil {
return packet, err return packet, err

View File

@ -269,7 +269,7 @@ func (proxy *Proxy) udpListener(clientPc *net.UDPConn) {
return return
} }
defer proxy.clientsCountDec() defer proxy.clientsCountDec()
proxy.processIncomingQuery(proxy.serversInfo.getOne(), "udp", proxy.mainProto, packet, &clientAddr, clientPc, start) proxy.processIncomingQuery("udp", proxy.mainProto, packet, &clientAddr, clientPc, start)
}() }()
} }
} }
@ -307,7 +307,7 @@ func (proxy *Proxy) tcpListener(acceptPc *net.TCPListener) {
return return
} }
clientAddr := clientPc.RemoteAddr() clientAddr := clientPc.RemoteAddr()
proxy.processIncomingQuery(proxy.serversInfo.getOne(), "tcp", "tcp", packet, &clientAddr, clientPc, start) proxy.processIncomingQuery("tcp", "tcp", packet, &clientAddr, clientPc, start)
}() }()
} }
} }
@ -438,18 +438,19 @@ func (proxy *Proxy) clientsCountDec() {
} }
} }
func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto string, serverProto string, query []byte, clientAddr *net.Addr, clientPc net.Conn, start time.Time) (response []byte) { func (proxy *Proxy) processIncomingQuery(clientProto string, serverProto string, query []byte, clientAddr *net.Addr, clientPc net.Conn, start time.Time) (response []byte) {
if len(query) < MinDNSPacketSize { if len(query) < MinDNSPacketSize {
return return
} }
pluginsState := NewPluginsState(proxy, clientProto, clientAddr, start) pluginsState := NewPluginsState(proxy, clientProto, clientAddr, start)
serverName := "-" serverName := "-"
needsEDNS0Padding := false needsEDNS0Padding := false
serverInfo := proxy.serversInfo.getOne()
if serverInfo != nil { if serverInfo != nil {
serverName = serverInfo.Name serverName = serverInfo.Name
needsEDNS0Padding = (serverInfo.Proto == stamps.StampProtoTypeDoH || serverInfo.Proto == stamps.StampProtoTypeTLS) needsEDNS0Padding = (serverInfo.Proto == stamps.StampProtoTypeDoH || serverInfo.Proto == stamps.StampProtoTypeTLS)
} }
query, _ = pluginsState.ApplyQueryPlugins(&proxy.pluginsGlobals, query, serverName, needsEDNS0Padding) query, _ = pluginsState.ApplyQueryPlugins(&proxy.pluginsGlobals, query, needsEDNS0Padding)
if len(query) < MinDNSPacketSize || len(query) > MaxDNSPacketSize { if len(query) < MinDNSPacketSize || len(query) > MaxDNSPacketSize {
return return
} }
@ -469,6 +470,7 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str
} }
if len(response) == 0 && serverInfo != nil { if len(response) == 0 && serverInfo != nil {
var ttl *uint32 var ttl *uint32
pluginsState.serverName = serverName
if serverInfo.Proto == stamps.StampProtoTypeDNSCrypt { if serverInfo.Proto == stamps.StampProtoTypeDNSCrypt {
sharedKey, encryptedQuery, clientNonce, err := proxy.Encrypt(serverInfo, query, serverProto) sharedKey, encryptedQuery, clientNonce, err := proxy.Encrypt(serverInfo, query, serverProto)
if err != nil { if err != nil {