Start implementing a basic cache

This commit is contained in:
Frank Denis 2018-01-10 18:32:05 +01:00
parent da3e3e61b4
commit 77cdc1db78
5 changed files with 176 additions and 6 deletions

View File

@ -18,6 +18,7 @@ type Config struct {
Timeout int `toml:"timeout_ms"`
CertRefreshDelay int `toml:"cert_refresh_delay"`
BlockIPv6 bool `toml:"block_ipv6"`
Cache bool
ServersConfig map[string]ServerConfig `toml:"servers"`
}
@ -58,6 +59,10 @@ func ConfigLoad(proxy *Proxy, config_file string) error {
proxy.listenAddresses = config.ListenAddresses
proxy.daemonize = config.Daemonize
proxy.pluginBlockIPv6 = config.BlockIPv6
proxy.cache = config.Cache
proxy.negCacheMinTTL = 60
proxy.minTTL = 60
proxy.maxTTL = 86400
if len(config.ServerNames) == 0 {
for serverName := range config.ServersConfig {
config.ServerNames = append(config.ServerNames, serverName)

View File

@ -48,6 +48,13 @@ cert_refresh_delay = 30
block_ipv6 = false
############## DNS Cache ##############
## Enable a basic DNS cache to reduce outgoing traffic
cache = true
############## Servers ##############
## Static list of available servers

View File

@ -1,6 +1,8 @@
package main
import (
"time"
"github.com/miekg/dns"
)
@ -30,3 +32,27 @@ func EmptyResponseFromMessage(srcMsg *dns.Msg) (*dns.Msg, error) {
func HasTCFlag(packet []byte) bool {
return packet[2]&2 == 2
}
func NormalizeName(name *[]byte) {
for i, c := range *name {
if c >= 65 && c <= 90 {
(*name)[i] = c + 32
}
}
}
func getMinTTL(msg *dns.Msg, minTTL uint32, maxTTL uint32, negCacheMinTTL uint32) time.Duration {
if len(msg.Answer) <= 0 {
return time.Duration(negCacheMinTTL) * time.Second
}
ttl := uint32(maxTTL)
for _, rr := range msg.Answer {
if rr.Header().Ttl < ttl {
ttl = rr.Header().Ttl
}
}
if ttl < minTTL {
ttl = minTTL
}
return time.Duration(ttl) * time.Second
}

View File

@ -24,6 +24,10 @@ type Proxy struct {
daemonize bool
registeredServers []RegisteredServer
pluginBlockIPv6 bool
cache bool
negCacheMinTTL uint32
minTTL uint32
maxTTL uint32
}
func main() {
@ -191,6 +195,7 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, serverProto str
serverInfo.noticeFailure(proxy)
return
}
response, _ = pluginsState.ApplyResponsePlugins(response)
}
if clientAddr != nil {
if len(response) > MaxDNSUDPPacketSize {

View File

@ -1,6 +1,13 @@
package main
import (
"crypto/sha512"
"encoding/binary"
"errors"
"math/rand"
"sync"
"time"
"github.com/miekg/dns"
)
@ -23,6 +30,7 @@ type PluginsState struct {
queryPlugins *[]Plugin
responsePlugins *[]Plugin
synthResponse *dns.Msg
dnssec bool
}
type Plugin interface {
@ -39,12 +47,25 @@ func NewPluginsState(proxy *Proxy, proto string) PluginsState {
*queryPlugins = append(*queryPlugins, Plugin(new(PluginGetSetPayloadSize)))
responsePlugins := &[]Plugin{}
if proxy.cache {
*responsePlugins = append(*responsePlugins, Plugin(new(PluginCache)))
}
return PluginsState{action: PluginsActionForward, maxPayloadSize: MaxDNSUDPPacketSize - ResponseOverhead,
queryPlugins: queryPlugins, responsePlugins: responsePlugins, proto: proto}
return PluginsState{
action: PluginsActionForward,
maxPayloadSize: MaxDNSUDPPacketSize - ResponseOverhead,
queryPlugins: queryPlugins,
responsePlugins: responsePlugins,
proto: proto,
}
}
// ---------------- Query plugins ----------------
func (pluginsState *PluginsState) ApplyQueryPlugins(packet []byte) ([]byte, error) {
if len(*pluginsState.queryPlugins) == 0 {
return packet, nil
}
pluginsState.action = PluginsActionForward
msg := dns.Msg{}
if err := msg.Unpack(packet); err != nil {
@ -86,6 +107,7 @@ func (plugin *PluginGetSetPayloadSize) Eval(pluginsState *PluginsState, msg *dns
pluginsState.originalMaxPayloadSize = Min(int(opt.UDPSize())-ResponseOverhead, pluginsState.originalMaxPayloadSize)
dnssec = opt.Do()
}
pluginsState.dnssec = dnssec
pluginsState.maxPayloadSize = Min(MaxDNSUDPPacketSize-ResponseOverhead, Max(pluginsState.originalMaxPayloadSize, pluginsState.maxPayloadSize))
if pluginsState.maxPayloadSize > 512 {
extra2 := []dns.RR{}
@ -129,3 +151,108 @@ func (plugin *PluginBlockIPv6) Eval(pluginsState *PluginsState, msg *dns.Msg) er
pluginsState.action = PluginsActionSynth
return nil
}
// ---------------- Response plugins ----------------
func (pluginsState *PluginsState) ApplyResponsePlugins(packet []byte) ([]byte, error) {
if len(*pluginsState.responsePlugins) == 0 {
return packet, nil
}
pluginsState.action = PluginsActionForward
msg := dns.Msg{}
if err := msg.Unpack(packet); err != nil {
return packet, err
}
for _, plugin := range *pluginsState.responsePlugins {
if ret := plugin.Eval(pluginsState, &msg); ret != nil {
pluginsState.action = PluginsActionDrop
return packet, ret
}
if pluginsState.action != PluginsActionForward {
break
}
}
packet2, err := msg.PackBuffer(packet)
if err != nil {
return packet, err
}
return packet2, nil
}
// -------- cache plugin --------
type CachedResponse struct {
expiration time.Time
msg dns.Msg
}
type CachedResponses struct {
sync.RWMutex
cache map[[32]byte]CachedResponse
}
var cachedResponses CachedResponses
type PluginCache struct {
cachedResponses *CachedResponses
}
func (plugin *PluginCache) Name() string {
return "cache"
}
func (plugin *PluginCache) Description() string {
return "DNS cache."
}
func (plugin *PluginCache) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
plugin.cachedResponses = &cachedResponses
cacheKey, err := computeCacheKey(pluginsState, msg)
if err != nil {
return nil
}
ttl := getMinTTL(msg, 60, 86400, 60)
cachedResponse := CachedResponse{
expiration: time.Now().Add(ttl),
msg: *msg,
}
plugin.cachedResponses.Lock()
defer plugin.cachedResponses.Unlock()
if plugin.cachedResponses.cache == nil {
plugin.cachedResponses.cache = make(map[[32]byte]CachedResponse)
}
if len(plugin.cachedResponses.cache) > 1000 {
z := byte(rand.Uint32())
for k := range plugin.cachedResponses.cache {
delete(plugin.cachedResponses.cache, k)
if k[0] == z {
break
}
}
}
plugin.cachedResponses.cache[cacheKey] = cachedResponse
return nil
}
func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) ([32]byte, error) {
questions := msg.Question
if len(questions) != 1 {
return [32]byte{}, errors.New("No question present")
}
question := questions[0]
h := sha512.New512_256()
var tmp [5]byte
binary.LittleEndian.PutUint16(tmp[0:2], question.Qtype)
binary.LittleEndian.PutUint16(tmp[2:4], question.Qclass)
if pluginsState.dnssec {
tmp[4] = 1
}
h.Write(tmp[:])
normalizedName := []byte(question.Name)
NormalizeName(&normalizedName)
h.Write(normalizedName)
var sum [32]byte
h.Sum(sum[:])
return sum, nil
}