Add Init/Drop/Update methods to plugins

Eventually, we may want to provide a specific structure for plugin
initialization. Sending the whole Proxy structure doesn't scale well.
This commit is contained in:
Frank Denis 2018-01-15 23:07:41 +01:00
parent b945e23101
commit 3ffad7be44
2 changed files with 137 additions and 27 deletions

View File

@ -28,6 +28,7 @@ type Proxy struct {
cacheNegTTL uint32 cacheNegTTL uint32
cacheMinTTL uint32 cacheMinTTL uint32
cacheMaxTTL uint32 cacheMaxTTL uint32
pluginsGlobals PluginsGlobals
} }
func main() { func main() {
@ -37,6 +38,9 @@ func main() {
if err := ConfigLoad(&proxy, "dnscrypt-proxy.toml"); err != nil { if err := ConfigLoad(&proxy, "dnscrypt-proxy.toml"); err != nil {
dlog.Fatal(err) dlog.Fatal(err)
} }
if err := InitPluginsGlobals(&proxy.pluginsGlobals, &proxy); err != nil {
dlog.Fatal(err)
}
if proxy.daemonize { if proxy.daemonize {
Daemonize() Daemonize()
} }
@ -178,7 +182,7 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str
return return
} }
pluginsState := NewPluginsState(proxy, clientProto, clientAddr) pluginsState := NewPluginsState(proxy, clientProto, clientAddr)
query, _ = pluginsState.ApplyQueryPlugins(query) query, _ = pluginsState.ApplyQueryPlugins(&proxy.pluginsGlobals, query)
var response []byte var response []byte
var err error var err error
if pluginsState.action != PluginsActionForward { if pluginsState.action != PluginsActionForward {
@ -204,7 +208,7 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str
serverInfo.noticeFailure(proxy) serverInfo.noticeFailure(proxy)
return return
} }
response, _ = pluginsState.ApplyResponsePlugins(response) response, _ = pluginsState.ApplyResponsePlugins(&proxy.pluginsGlobals, response)
} }
if clientProto == "udp" { if clientProto == "udp" {
if len(response) > MaxDNSUDPPacketSize { if len(response) > MaxDNSUDPPacketSize {

View File

@ -22,6 +22,14 @@ const (
PluginsActionSynth = 4 PluginsActionSynth = 4
) )
type PluginsGlobals struct {
sync.RWMutex
queryPlugins *[]Plugin
responsePlugins *[]Plugin
}
var pluginsGlobals PluginsGlobals
type PluginsState struct { type PluginsState struct {
sessionData map[string]interface{} sessionData map[string]interface{}
action PluginsAction action PluginsAction
@ -29,8 +37,6 @@ type PluginsState struct {
maxPayloadSize int maxPayloadSize int
clientProto string clientProto string
clientAddr *net.Addr clientAddr *net.Addr
queryPlugins *[]Plugin
responsePlugins *[]Plugin
synthResponse *dns.Msg synthResponse *dns.Msg
dnssec bool dnssec bool
cacheSize int cacheSize int
@ -39,13 +45,7 @@ type PluginsState struct {
cacheMaxTTL uint32 cacheMaxTTL uint32
} }
type Plugin interface { func InitPluginsGlobals(pluginsGlobals *PluginsGlobals, proxy *Proxy) error {
Name() string
Description() string
Eval(pluginsState *PluginsState, msg *dns.Msg) error
}
func NewPluginsState(proxy *Proxy, clientProto string, clientAddr *net.Addr) PluginsState {
queryPlugins := &[]Plugin{} queryPlugins := &[]Plugin{}
if proxy.pluginBlockIPv6 { if proxy.pluginBlockIPv6 {
*queryPlugins = append(*queryPlugins, Plugin(new(PluginBlockIPv6))) *queryPlugins = append(*queryPlugins, Plugin(new(PluginBlockIPv6)))
@ -60,24 +60,48 @@ func NewPluginsState(proxy *Proxy, clientProto string, clientAddr *net.Addr) Plu
*responsePlugins = append(*responsePlugins, Plugin(new(PluginCacheResponse))) *responsePlugins = append(*responsePlugins, Plugin(new(PluginCacheResponse)))
} }
for _, plugin := range *queryPlugins {
if err := plugin.Init(proxy); err != nil {
return err
}
}
for _, plugin := range *responsePlugins {
if err := plugin.Init(proxy); err != nil {
return err
}
}
(*pluginsGlobals).queryPlugins = queryPlugins
(*pluginsGlobals).responsePlugins = responsePlugins
return nil
}
type Plugin interface {
Name() string
Description() string
Init(proxy *Proxy) error
Drop() error
Reload() error
Eval(pluginsState *PluginsState, msg *dns.Msg) error
}
func NewPluginsState(proxy *Proxy, clientProto string, clientAddr *net.Addr) PluginsState {
return PluginsState{ return PluginsState{
action: PluginsActionForward, action: PluginsActionForward,
maxPayloadSize: MaxDNSUDPPacketSize - ResponseOverhead, maxPayloadSize: MaxDNSUDPPacketSize - ResponseOverhead,
queryPlugins: queryPlugins, clientProto: clientProto,
responsePlugins: responsePlugins, clientAddr: clientAddr,
clientProto: clientProto, cacheSize: proxy.cacheSize,
clientAddr: clientAddr, cacheNegTTL: proxy.cacheNegTTL,
cacheSize: proxy.cacheSize, cacheMinTTL: proxy.cacheMinTTL,
cacheNegTTL: proxy.cacheNegTTL, cacheMaxTTL: proxy.cacheMaxTTL,
cacheMinTTL: proxy.cacheMinTTL,
cacheMaxTTL: proxy.cacheMaxTTL,
} }
} }
// ---------------- Query plugins ---------------- // ---------------- Query plugins ----------------
func (pluginsState *PluginsState) ApplyQueryPlugins(packet []byte) ([]byte, error) { func (pluginsState *PluginsState) ApplyQueryPlugins(pluginsGlobals *PluginsGlobals, packet []byte) ([]byte, error) {
if len(*pluginsState.queryPlugins) == 0 { if len(*pluginsGlobals.queryPlugins) == 0 {
return packet, nil return packet, nil
} }
pluginsState.action = PluginsActionForward pluginsState.action = PluginsActionForward
@ -85,8 +109,10 @@ func (pluginsState *PluginsState) ApplyQueryPlugins(packet []byte) ([]byte, erro
if err := msg.Unpack(packet); err != nil { if err := msg.Unpack(packet); err != nil {
return packet, err return packet, err
} }
for _, plugin := range *pluginsState.queryPlugins { pluginsGlobals.RLock()
for _, plugin := range *pluginsGlobals.queryPlugins {
if ret := plugin.Eval(pluginsState, &msg); ret != nil { if ret := plugin.Eval(pluginsState, &msg); ret != nil {
pluginsGlobals.RUnlock()
pluginsState.action = PluginsActionDrop pluginsState.action = PluginsActionDrop
return packet, ret return packet, ret
} }
@ -94,6 +120,7 @@ func (pluginsState *PluginsState) ApplyQueryPlugins(packet []byte) ([]byte, erro
break break
} }
} }
pluginsGlobals.RUnlock()
packet2, err := msg.PackBuffer(packet) packet2, err := msg.PackBuffer(packet)
if err != nil { if err != nil {
return packet, err return packet, err
@ -113,6 +140,18 @@ func (plugin *PluginGetSetPayloadSize) Description() string {
return "Adjusts the maximum payload size advertised in queries sent to upstream servers." return "Adjusts the maximum payload size advertised in queries sent to upstream servers."
} }
func (plugin *PluginGetSetPayloadSize) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginGetSetPayloadSize) Drop() error {
return nil
}
func (plugin *PluginGetSetPayloadSize) Reload() error {
return nil
}
func (plugin *PluginGetSetPayloadSize) Eval(pluginsState *PluginsState, msg *dns.Msg) error { func (plugin *PluginGetSetPayloadSize) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
pluginsState.originalMaxPayloadSize = 512 - ResponseOverhead pluginsState.originalMaxPayloadSize = 512 - ResponseOverhead
opt := msg.IsEdns0() opt := msg.IsEdns0()
@ -148,6 +187,18 @@ func (plugin *PluginBlockIPv6) Description() string {
return "Immediately return a synthetic response to AAAA queries" return "Immediately return a synthetic response to AAAA queries"
} }
func (plugin *PluginBlockIPv6) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginBlockIPv6) Drop() error {
return nil
}
func (plugin *PluginBlockIPv6) Reload() error {
return nil
}
func (plugin *PluginBlockIPv6) Eval(pluginsState *PluginsState, msg *dns.Msg) error { func (plugin *PluginBlockIPv6) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
questions := msg.Question questions := msg.Question
if len(questions) != 1 { if len(questions) != 1 {
@ -166,10 +217,38 @@ func (plugin *PluginBlockIPv6) Eval(pluginsState *PluginsState, msg *dns.Msg) er
return nil return nil
} }
// -------- querylog plugin --------
type PluginQueryLog struct{}
func (plugin *PluginQueryLog) Name() string {
return "querylog"
}
func (plugin *PluginQueryLog) Description() string {
return "Log DNS queries"
}
func (plugin *PluginQueryLog) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginQueryLog) Drop() error {
return nil
}
func (plugin *PluginQueryLog) Reload() error {
return nil
}
func (plugin *PluginQueryLog) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
return nil
}
// ---------------- Response plugins ---------------- // ---------------- Response plugins ----------------
func (pluginsState *PluginsState) ApplyResponsePlugins(packet []byte) ([]byte, error) { func (pluginsState *PluginsState) ApplyResponsePlugins(pluginsGlobals *PluginsGlobals, packet []byte) ([]byte, error) {
if len(*pluginsState.responsePlugins) == 0 { if len(*pluginsGlobals.responsePlugins) == 0 {
return packet, nil return packet, nil
} }
pluginsState.action = PluginsActionForward pluginsState.action = PluginsActionForward
@ -177,8 +256,10 @@ func (pluginsState *PluginsState) ApplyResponsePlugins(packet []byte) ([]byte, e
if err := msg.Unpack(packet); err != nil { if err := msg.Unpack(packet); err != nil {
return packet, err return packet, err
} }
for _, plugin := range *pluginsState.responsePlugins { pluginsGlobals.RLock()
for _, plugin := range *pluginsGlobals.responsePlugins {
if ret := plugin.Eval(pluginsState, &msg); ret != nil { if ret := plugin.Eval(pluginsState, &msg); ret != nil {
pluginsGlobals.RUnlock()
pluginsState.action = PluginsActionDrop pluginsState.action = PluginsActionDrop
return packet, ret return packet, ret
} }
@ -186,6 +267,7 @@ func (pluginsState *PluginsState) ApplyResponsePlugins(packet []byte) ([]byte, e
break break
} }
} }
pluginsGlobals.RUnlock()
packet2, err := msg.PackBuffer(packet) packet2, err := msg.PackBuffer(packet)
if err != nil { if err != nil {
return packet, err return packet, err
@ -219,6 +301,18 @@ func (plugin *PluginCacheResponse) Description() string {
return "DNS cache (writer)." return "DNS cache (writer)."
} }
func (plugin *PluginCacheResponse) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginCacheResponse) Drop() error {
return nil
}
func (plugin *PluginCacheResponse) Reload() error {
return nil
}
func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg) error { func (plugin *PluginCacheResponse) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
plugin.cachedResponses = &cachedResponses plugin.cachedResponses = &cachedResponses
if msg.Rcode == dns.RcodeServerFailure { if msg.Rcode == dns.RcodeServerFailure {
@ -257,6 +351,18 @@ func (plugin *PluginCache) Description() string {
return "DNS cache (reader)." return "DNS cache (reader)."
} }
func (plugin *PluginCache) Init(proxy *Proxy) error {
return nil
}
func (plugin *PluginCache) Drop() error {
return nil
}
func (plugin *PluginCache) Reload() error {
return nil
}
func (plugin *PluginCache) Eval(pluginsState *PluginsState, msg *dns.Msg) error { func (plugin *PluginCache) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
plugin.cachedResponses = &cachedResponses plugin.cachedResponses = &cachedResponses