Start implementing a basic cache
This commit is contained in:
parent
da3e3e61b4
commit
77cdc1db78
|
@ -14,10 +14,11 @@ type Config struct {
|
|||
ServerNames []string `toml:"server_names"`
|
||||
ListenAddresses []string `toml:"listen_addresses"`
|
||||
Daemonize bool
|
||||
ForceTCP bool `toml:"force_tcp"`
|
||||
Timeout int `toml:"timeout_ms"`
|
||||
CertRefreshDelay int `toml:"cert_refresh_delay"`
|
||||
BlockIPv6 bool `toml:"block_ipv6"`
|
||||
ForceTCP bool `toml:"force_tcp"`
|
||||
Timeout int `toml:"timeout_ms"`
|
||||
CertRefreshDelay int `toml:"cert_refresh_delay"`
|
||||
BlockIPv6 bool `toml:"block_ipv6"`
|
||||
Cache bool
|
||||
ServersConfig map[string]ServerConfig `toml:"servers"`
|
||||
}
|
||||
|
||||
|
@ -58,6 +59,10 @@ func ConfigLoad(proxy *Proxy, config_file string) error {
|
|||
proxy.listenAddresses = config.ListenAddresses
|
||||
proxy.daemonize = config.Daemonize
|
||||
proxy.pluginBlockIPv6 = config.BlockIPv6
|
||||
proxy.cache = config.Cache
|
||||
proxy.negCacheMinTTL = 60
|
||||
proxy.minTTL = 60
|
||||
proxy.maxTTL = 86400
|
||||
if len(config.ServerNames) == 0 {
|
||||
for serverName := range config.ServersConfig {
|
||||
config.ServerNames = append(config.ServerNames, serverName)
|
||||
|
|
|
@ -48,6 +48,13 @@ cert_refresh_delay = 30
|
|||
block_ipv6 = false
|
||||
|
||||
|
||||
############## DNS Cache ##############
|
||||
|
||||
## Enable a basic DNS cache to reduce outgoing traffic
|
||||
|
||||
cache = true
|
||||
|
||||
|
||||
############## Servers ##############
|
||||
|
||||
## Static list of available servers
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -30,3 +32,27 @@ func EmptyResponseFromMessage(srcMsg *dns.Msg) (*dns.Msg, error) {
|
|||
func HasTCFlag(packet []byte) bool {
|
||||
return packet[2]&2 == 2
|
||||
}
|
||||
|
||||
func NormalizeName(name *[]byte) {
|
||||
for i, c := range *name {
|
||||
if c >= 65 && c <= 90 {
|
||||
(*name)[i] = c + 32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getMinTTL(msg *dns.Msg, minTTL uint32, maxTTL uint32, negCacheMinTTL uint32) time.Duration {
|
||||
if len(msg.Answer) <= 0 {
|
||||
return time.Duration(negCacheMinTTL) * time.Second
|
||||
}
|
||||
ttl := uint32(maxTTL)
|
||||
for _, rr := range msg.Answer {
|
||||
if rr.Header().Ttl < ttl {
|
||||
ttl = rr.Header().Ttl
|
||||
}
|
||||
}
|
||||
if ttl < minTTL {
|
||||
ttl = minTTL
|
||||
}
|
||||
return time.Duration(ttl) * time.Second
|
||||
}
|
||||
|
|
|
@ -24,6 +24,10 @@ type Proxy struct {
|
|||
daemonize bool
|
||||
registeredServers []RegisteredServer
|
||||
pluginBlockIPv6 bool
|
||||
cache bool
|
||||
negCacheMinTTL uint32
|
||||
minTTL uint32
|
||||
maxTTL uint32
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -191,6 +195,7 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, serverProto str
|
|||
serverInfo.noticeFailure(proxy)
|
||||
return
|
||||
}
|
||||
response, _ = pluginsState.ApplyResponsePlugins(response)
|
||||
}
|
||||
if clientAddr != nil {
|
||||
if len(response) > MaxDNSUDPPacketSize {
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/sha512"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -23,6 +30,7 @@ type PluginsState struct {
|
|||
queryPlugins *[]Plugin
|
||||
responsePlugins *[]Plugin
|
||||
synthResponse *dns.Msg
|
||||
dnssec bool
|
||||
}
|
||||
|
||||
type Plugin interface {
|
||||
|
@ -39,12 +47,25 @@ func NewPluginsState(proxy *Proxy, proto string) PluginsState {
|
|||
*queryPlugins = append(*queryPlugins, Plugin(new(PluginGetSetPayloadSize)))
|
||||
|
||||
responsePlugins := &[]Plugin{}
|
||||
if proxy.cache {
|
||||
*responsePlugins = append(*responsePlugins, Plugin(new(PluginCache)))
|
||||
}
|
||||
|
||||
return PluginsState{action: PluginsActionForward, maxPayloadSize: MaxDNSUDPPacketSize - ResponseOverhead,
|
||||
queryPlugins: queryPlugins, responsePlugins: responsePlugins, proto: proto}
|
||||
return PluginsState{
|
||||
action: PluginsActionForward,
|
||||
maxPayloadSize: MaxDNSUDPPacketSize - ResponseOverhead,
|
||||
queryPlugins: queryPlugins,
|
||||
responsePlugins: responsePlugins,
|
||||
proto: proto,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------- Query plugins ----------------
|
||||
|
||||
func (pluginsState *PluginsState) ApplyQueryPlugins(packet []byte) ([]byte, error) {
|
||||
if len(*pluginsState.queryPlugins) == 0 {
|
||||
return packet, nil
|
||||
}
|
||||
pluginsState.action = PluginsActionForward
|
||||
msg := dns.Msg{}
|
||||
if err := msg.Unpack(packet); err != nil {
|
||||
|
@ -86,6 +107,7 @@ func (plugin *PluginGetSetPayloadSize) Eval(pluginsState *PluginsState, msg *dns
|
|||
pluginsState.originalMaxPayloadSize = Min(int(opt.UDPSize())-ResponseOverhead, pluginsState.originalMaxPayloadSize)
|
||||
dnssec = opt.Do()
|
||||
}
|
||||
pluginsState.dnssec = dnssec
|
||||
pluginsState.maxPayloadSize = Min(MaxDNSUDPPacketSize-ResponseOverhead, Max(pluginsState.originalMaxPayloadSize, pluginsState.maxPayloadSize))
|
||||
if pluginsState.maxPayloadSize > 512 {
|
||||
extra2 := []dns.RR{}
|
||||
|
@ -129,3 +151,108 @@ func (plugin *PluginBlockIPv6) Eval(pluginsState *PluginsState, msg *dns.Msg) er
|
|||
pluginsState.action = PluginsActionSynth
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------- Response plugins ----------------
|
||||
|
||||
func (pluginsState *PluginsState) ApplyResponsePlugins(packet []byte) ([]byte, error) {
|
||||
if len(*pluginsState.responsePlugins) == 0 {
|
||||
return packet, nil
|
||||
}
|
||||
pluginsState.action = PluginsActionForward
|
||||
msg := dns.Msg{}
|
||||
if err := msg.Unpack(packet); err != nil {
|
||||
return packet, err
|
||||
}
|
||||
for _, plugin := range *pluginsState.responsePlugins {
|
||||
if ret := plugin.Eval(pluginsState, &msg); ret != nil {
|
||||
pluginsState.action = PluginsActionDrop
|
||||
return packet, ret
|
||||
}
|
||||
if pluginsState.action != PluginsActionForward {
|
||||
break
|
||||
}
|
||||
}
|
||||
packet2, err := msg.PackBuffer(packet)
|
||||
if err != nil {
|
||||
return packet, err
|
||||
}
|
||||
return packet2, nil
|
||||
}
|
||||
|
||||
// -------- cache plugin --------
|
||||
|
||||
type CachedResponse struct {
|
||||
expiration time.Time
|
||||
msg dns.Msg
|
||||
}
|
||||
|
||||
type CachedResponses struct {
|
||||
sync.RWMutex
|
||||
cache map[[32]byte]CachedResponse
|
||||
}
|
||||
|
||||
var cachedResponses CachedResponses
|
||||
|
||||
type PluginCache struct {
|
||||
cachedResponses *CachedResponses
|
||||
}
|
||||
|
||||
func (plugin *PluginCache) Name() string {
|
||||
return "cache"
|
||||
}
|
||||
|
||||
func (plugin *PluginCache) Description() string {
|
||||
return "DNS cache."
|
||||
}
|
||||
|
||||
func (plugin *PluginCache) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
|
||||
plugin.cachedResponses = &cachedResponses
|
||||
|
||||
cacheKey, err := computeCacheKey(pluginsState, msg)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
ttl := getMinTTL(msg, 60, 86400, 60)
|
||||
cachedResponse := CachedResponse{
|
||||
expiration: time.Now().Add(ttl),
|
||||
msg: *msg,
|
||||
}
|
||||
plugin.cachedResponses.Lock()
|
||||
defer plugin.cachedResponses.Unlock()
|
||||
if plugin.cachedResponses.cache == nil {
|
||||
plugin.cachedResponses.cache = make(map[[32]byte]CachedResponse)
|
||||
}
|
||||
if len(plugin.cachedResponses.cache) > 1000 {
|
||||
z := byte(rand.Uint32())
|
||||
for k := range plugin.cachedResponses.cache {
|
||||
delete(plugin.cachedResponses.cache, k)
|
||||
if k[0] == z {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
plugin.cachedResponses.cache[cacheKey] = cachedResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) ([32]byte, error) {
|
||||
questions := msg.Question
|
||||
if len(questions) != 1 {
|
||||
return [32]byte{}, errors.New("No question present")
|
||||
}
|
||||
question := questions[0]
|
||||
h := sha512.New512_256()
|
||||
var tmp [5]byte
|
||||
binary.LittleEndian.PutUint16(tmp[0:2], question.Qtype)
|
||||
binary.LittleEndian.PutUint16(tmp[2:4], question.Qclass)
|
||||
if pluginsState.dnssec {
|
||||
tmp[4] = 1
|
||||
}
|
||||
h.Write(tmp[:])
|
||||
normalizedName := []byte(question.Name)
|
||||
NormalizeName(&normalizedName)
|
||||
h.Write(normalizedName)
|
||||
var sum [32]byte
|
||||
h.Sum(sum[:])
|
||||
return sum, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue