dnscrypt-proxy/dnscrypt-proxy/plugins.go

417 lines
9.6 KiB
Go
Raw Normal View History

package main
import (
2018-01-10 18:32:05 +01:00
"crypto/sha512"
"encoding/binary"
"errors"
"net"
2018-01-10 18:32:05 +01:00
"sync"
"time"
2018-01-10 19:02:43 +01:00
lru "github.com/hashicorp/golang-lru"
"github.com/miekg/dns"
)
// PluginsAction tells the proxy what to do with a message once the plugin
// chain has run.
type PluginsAction int

// Values stored in PluginsState.action. They are deliberately untyped integer
// constants (not PluginsAction-typed), so they also assign to plain ints.
const (
	PluginsActionNone    = iota // 0: no action recorded
	PluginsActionForward        // 1: keep going / forward the query upstream
	PluginsActionDrop           // 2: set when a plugin's Eval returns an error
	PluginsActionReject         // 3: not produced by any plugin in this file
	PluginsActionSynth          // 4: answer with PluginsState.synthResponse
)
// PluginsGlobals holds the installed query and response plugin chains.
// InitPluginsGlobals builds and installs both chains; ApplyQueryPlugins and
// ApplyResponsePlugins hold the embedded read lock while running them.
type PluginsGlobals struct {
	sync.RWMutex
	// Plugins run, in order, on every decoded client query.
	queryPlugins *[]Plugin
	// Plugins run, in order, on every decoded upstream response.
	responsePlugins *[]Plugin
}

// pluginsGlobals is the process-wide plugin registry.
var pluginsGlobals PluginsGlobals
type PluginsState struct {
sessionData map[string]interface{}
action PluginsAction
originalMaxPayloadSize int
maxPayloadSize int
clientProto string
clientAddr *net.Addr
2018-01-10 17:23:20 +01:00
synthResponse *dns.Msg
2018-01-10 18:32:05 +01:00
dnssec bool
cacheSize int
cacheNegTTL uint32
cacheMinTTL uint32
cacheMaxTTL uint32
2018-01-10 10:11:59 +01:00
}
func InitPluginsGlobals(pluginsGlobals *PluginsGlobals, proxy *Proxy) error {
2018-01-10 17:23:20 +01:00
queryPlugins := &[]Plugin{}
if proxy.pluginBlockIPv6 {
*queryPlugins = append(*queryPlugins, Plugin(new(PluginBlockIPv6)))
}
*queryPlugins = append(*queryPlugins, Plugin(new(PluginGetSetPayloadSize)))
2018-01-10 18:53:09 +01:00
if proxy.cache {
*queryPlugins = append(*queryPlugins, Plugin(new(PluginCache)))
}
2018-01-10 17:23:20 +01:00
2018-01-10 10:11:59 +01:00
responsePlugins := &[]Plugin{}
2018-01-10 18:32:05 +01:00
if proxy.cache {
2018-01-10 18:53:09 +01:00
*responsePlugins = append(*responsePlugins, Plugin(new(PluginCacheResponse)))
2018-01-10 18:32:05 +01:00
}
2018-01-10 17:23:20 +01:00
for _, plugin := range *queryPlugins {
if err := plugin.Init(proxy); err != nil {
return err
}
}
for _, plugin := range *responsePlugins {
if err := plugin.Init(proxy); err != nil {
return err
}
}
(*pluginsGlobals).queryPlugins = queryPlugins
(*pluginsGlobals).responsePlugins = responsePlugins
return nil
}
// Plugin is implemented by every query and response plugin.
type Plugin interface {
	// Name returns the plugin's short identifier.
	Name() string
	// Description returns a human-readable summary of the plugin.
	Description() string
	// Init is called once for each plugin by InitPluginsGlobals before the
	// chains are installed; a non-nil error aborts initialization.
	Init(proxy *Proxy) error
	// Drop and Reload are lifecycle hooks. Every implementation in this file
	// is a no-op; NOTE(review): their callers are not visible here.
	Drop() error
	Reload() error
	// Eval inspects and may modify msg, and may update pluginsState (for
	// example its action). A non-nil error makes the Apply*Plugins caller set
	// the action to PluginsActionDrop.
	Eval(pluginsState *PluginsState, msg *dns.Msg) error
}
func NewPluginsState(proxy *Proxy, clientProto string, clientAddr *net.Addr) PluginsState {
2018-01-10 18:32:05 +01:00
return PluginsState{
action: PluginsActionForward,
maxPayloadSize: MaxDNSUDPPacketSize - ResponseOverhead,
clientProto: clientProto,
clientAddr: clientAddr,
cacheSize: proxy.cacheSize,
cacheNegTTL: proxy.cacheNegTTL,
cacheMinTTL: proxy.cacheMinTTL,
cacheMaxTTL: proxy.cacheMaxTTL,
2018-01-10 18:32:05 +01:00
}
}
2018-01-10 18:32:05 +01:00
// ---------------- Query plugins ----------------
func (pluginsState *PluginsState) ApplyQueryPlugins(pluginsGlobals *PluginsGlobals, packet []byte) ([]byte, error) {
if len(*pluginsGlobals.queryPlugins) == 0 {
2018-01-10 18:32:05 +01:00
return packet, nil
}
2018-01-10 10:11:59 +01:00
pluginsState.action = PluginsActionForward
msg := dns.Msg{}
if err := msg.Unpack(packet); err != nil {
return packet, err
}
pluginsGlobals.RLock()
for _, plugin := range *pluginsGlobals.queryPlugins {
2018-01-10 10:11:59 +01:00
if ret := plugin.Eval(pluginsState, &msg); ret != nil {
pluginsGlobals.RUnlock()
2018-01-10 10:11:59 +01:00
pluginsState.action = PluginsActionDrop
return packet, ret
}
2018-01-10 17:23:20 +01:00
if pluginsState.action != PluginsActionForward {
break
}
}
pluginsGlobals.RUnlock()
packet2, err := msg.PackBuffer(packet)
if err != nil {
return packet, err
}
return packet2, nil
}
2018-01-10 17:23:20 +01:00
// -------- get_set_payload_size plugin --------
2018-01-10 10:11:59 +01:00
type PluginGetSetPayloadSize struct{}
func (plugin *PluginGetSetPayloadSize) Name() string {
return "get_set_payload_size"
}
func (plugin *PluginGetSetPayloadSize) Description() string {
return "Adjusts the maximum payload size advertised in queries sent to upstream servers."
}
func (plugin *PluginGetSetPayloadSize) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginGetSetPayloadSize) Drop() error {
return nil
}
func (plugin *PluginGetSetPayloadSize) Reload() error {
return nil
}
2018-01-10 10:11:59 +01:00
func (plugin *PluginGetSetPayloadSize) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
pluginsState.originalMaxPayloadSize = 512 - ResponseOverhead
opt := msg.IsEdns0()
dnssec := false
if opt != nil {
pluginsState.originalMaxPayloadSize = Min(int(opt.UDPSize())-ResponseOverhead, pluginsState.originalMaxPayloadSize)
dnssec = opt.Do()
}
2018-01-10 18:32:05 +01:00
pluginsState.dnssec = dnssec
pluginsState.maxPayloadSize = Min(MaxDNSUDPPacketSize-ResponseOverhead, Max(pluginsState.originalMaxPayloadSize, pluginsState.maxPayloadSize))
if pluginsState.maxPayloadSize > 512 {
extra2 := []dns.RR{}
for _, extra := range msg.Extra {
if extra.Header().Rrtype != dns.TypeOPT {
extra2 = append(extra2, extra)
}
}
msg.Extra = extra2
msg.SetEdns0(uint16(pluginsState.maxPayloadSize), dnssec)
}
return nil
}
2018-01-10 17:23:20 +01:00
// -------- block_ipv6 plugin --------
type PluginBlockIPv6 struct{}
func (plugin *PluginBlockIPv6) Name() string {
return "block_ipv6"
}
func (plugin *PluginBlockIPv6) Description() string {
return "Immediately return a synthetic response to AAAA queries"
}
func (plugin *PluginBlockIPv6) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginBlockIPv6) Drop() error {
return nil
}
func (plugin *PluginBlockIPv6) Reload() error {
return nil
}
2018-01-10 17:23:20 +01:00
func (plugin *PluginBlockIPv6) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question
if len(questions) != 1 {
return nil
}
question := questions[0]
if question.Qclass != dns.ClassINET || question.Qtype != dns.TypeAAAA {
return nil
}
synth, err := EmptyResponseFromMessage(msg)
if err != nil {
return err
}
pluginsState.synthResponse = synth
pluginsState.action = PluginsActionSynth
return nil
}
2018-01-10 18:32:05 +01:00
// -------- querylog plugin --------
type PluginQueryLog struct{}
func (plugin *PluginQueryLog) Name() string {
return "querylog"
}
func (plugin *PluginQueryLog) Description() string {
return "Log DNS queries"
}
func (plugin *PluginQueryLog) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginQueryLog) Drop() error {
return nil
}
func (plugin *PluginQueryLog) Reload() error {
return nil
}
func (plugin *PluginQueryLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
return nil
}
2018-01-10 18:32:05 +01:00
// ---------------- Response plugins ----------------
func (pluginsState *PluginsState) ApplyResponsePlugins(pluginsGlobals *PluginsGlobals, packet []byte) ([]byte, error) {
if len(*pluginsGlobals.responsePlugins) == 0 {
2018-01-10 18:32:05 +01:00
return packet, nil
}
pluginsState.action = PluginsActionForward
msg := dns.Msg{}
if err := msg.Unpack(packet); err != nil {
return packet, err
}
pluginsGlobals.RLock()
for _, plugin := range *pluginsGlobals.responsePlugins {
2018-01-10 18:32:05 +01:00
if ret := plugin.Eval(pluginsState, &msg); ret != nil {
pluginsGlobals.RUnlock()
2018-01-10 18:32:05 +01:00
pluginsState.action = PluginsActionDrop
return packet, ret
}
if pluginsState.action != PluginsActionForward {
break
}
}
pluginsGlobals.RUnlock()
2018-01-10 18:32:05 +01:00
packet2, err := msg.PackBuffer(packet)
if err != nil {
return packet, err
}
return packet2, nil
}
// -------- cache plugin --------
type CachedResponse struct {
expiration time.Time
msg dns.Msg
}
type CachedResponses struct {
sync.RWMutex
2018-01-10 19:02:43 +01:00
cache *lru.ARCCache
2018-01-10 18:32:05 +01:00
}
var cachedResponses CachedResponses
2018-01-10 18:53:09 +01:00
type PluginCacheResponse struct {
2018-01-10 18:32:05 +01:00
cachedResponses *CachedResponses
}
2018-01-10 18:53:09 +01:00
func (plugin *PluginCacheResponse) Name() string {
return "cache_response"
2018-01-10 18:32:05 +01:00
}
2018-01-10 18:53:09 +01:00
func (plugin *PluginCacheResponse) Description() string {
return "DNS cache (writer)."
2018-01-10 18:32:05 +01:00
}
func (plugin *PluginCacheResponse) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginCacheResponse) Drop() error {
return nil
}
func (plugin *PluginCacheResponse) Reload() error {
return nil
}
2018-01-10 18:53:09 +01:00
func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
2018-01-10 18:32:05 +01:00
plugin.cachedResponses = &cachedResponses
if msg.Rcode == dns.RcodeServerFailure {
2018-01-10 18:53:09 +01:00
return nil
}
2018-01-10 18:32:05 +01:00
cacheKey, err := computeCacheKey(pluginsState, msg)
if err != nil {
return err
2018-01-10 18:32:05 +01:00
}
2018-01-10 22:47:29 +01:00
ttl := getMinTTL(msg, pluginsState.cacheMinTTL, pluginsState.cacheMaxTTL, pluginsState.cacheNegTTL)
2018-01-10 18:32:05 +01:00
cachedResponse := CachedResponse{
expiration: time.Now().Add(ttl),
msg: *msg,
}
plugin.cachedResponses.Lock()
defer plugin.cachedResponses.Unlock()
if plugin.cachedResponses.cache == nil {
2018-01-10 22:47:29 +01:00
plugin.cachedResponses.cache, err = lru.NewARC(pluginsState.cacheSize)
2018-01-10 19:02:43 +01:00
if err != nil {
return err
2018-01-10 18:32:05 +01:00
}
}
2018-01-10 19:02:43 +01:00
plugin.cachedResponses.cache.Add(cacheKey, cachedResponse)
2018-01-10 18:32:05 +01:00
return nil
}
2018-01-10 18:53:09 +01:00
type PluginCache struct {
cachedResponses *CachedResponses
}
func (plugin *PluginCache) Name() string {
return "cache"
}
func (plugin *PluginCache) Description() string {
return "DNS cache (reader)."
}
func (plugin *PluginCache) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginCache) Drop() error {
return nil
}
func (plugin *PluginCache) Reload() error {
return nil
}
2018-01-10 18:53:09 +01:00
func (plugin *PluginCache) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
plugin.cachedResponses = &cachedResponses
cacheKey, err := computeCacheKey(pluginsState, msg)
if err != nil {
return nil
}
plugin.cachedResponses.RLock()
defer plugin.cachedResponses.RUnlock()
if plugin.cachedResponses.cache == nil {
return nil
}
2018-01-10 19:02:43 +01:00
cached_any, ok := plugin.cachedResponses.cache.Get(cacheKey)
2018-01-10 18:53:09 +01:00
if !ok {
return nil
}
2018-01-10 19:02:43 +01:00
cached := cached_any.(CachedResponse)
2018-01-10 18:53:09 +01:00
if time.Now().After(cached.expiration) {
return nil
}
synth := cached.msg
2018-01-10 22:47:29 +01:00
synth.Id = msg.Id
synth.Response = true
synth.Compress = true
2018-01-10 18:53:09 +01:00
synth.Question = msg.Question
pluginsState.synthResponse = &synth
pluginsState.action = PluginsActionSynth
return nil
}
2018-01-10 18:32:05 +01:00
func computeCacheKey(pluginsState *PluginsState, msg *dns.Msg) ([32]byte, error) {
questions := msg.Question
if len(questions) != 1 {
return [32]byte{}, errors.New("No question present")
}
question := questions[0]
h := sha512.New512_256()
var tmp [5]byte
binary.LittleEndian.PutUint16(tmp[0:2], question.Qtype)
binary.LittleEndian.PutUint16(tmp[2:4], question.Qclass)
if pluginsState.dnssec {
tmp[4] = 1
}
h.Write(tmp[:])
normalizedName := []byte(question.Name)
NormalizeName(&normalizedName)
h.Write(normalizedName)
var sum [32]byte
2018-01-10 18:53:09 +01:00
h.Sum(sum[:0])
2018-01-10 18:32:05 +01:00
return sum, nil
}