// dnscrypt-proxy/dnscrypt-proxy/plugin_forward.go
//
// 125 lines
// 3.0 KiB
// Go

package main
import (
"fmt"
"math/rand"
"net"
"strings"
"github.com/jedisct1/dlog"
"github.com/miekg/dns"
)
type (
	// PluginForwardEntry is a single forwarding rule: queries whose name is
	// domain, or a subdomain of it, are sent to one of servers.
	PluginForwardEntry struct {
		domain  string
		servers []string
	}

	// PluginForward routes queries for configured domains to dedicated
	// upstream servers. forwardMap holds the rules in the order they were read.
	PluginForward struct {
		forwardMap []PluginForwardEntry
	}
)

// Name returns the identifier of this plugin.
func (plugin *PluginForward) Name() string {
	const name = "forward"
	return name
}

// Description returns a one-line, human-readable summary of the plugin.
func (plugin *PluginForward) Description() string {
	const description = "Route queries matching specific domains to a dedicated set of servers"
	return description
}
// Init loads the forwarding rules from proxy.forwardFile and appends them, in
// file order, to plugin.forwardMap.
//
// Each non-empty line (after TrimAndStripInlineComments) is expected to look like
//
//	<domain> <server>[,<server>...]
//
// The domain is lowercased. Each server entry is whitespace-trimmed and then
// handled as follows:
//   - a bare IP address gets the default DNS port 53; net.JoinHostPort
//     brackets any address containing ':' (IPv6, including IPv4-mapped);
//   - a bracketed address without a port ("[2620:fe::fe]") is unwrapped and
//     then treated as a bare IP;
//   - anything else (for example "host:port" or "[::1]:5353") is kept verbatim.
//
// Empty entries produced by stray commas are ignored. A rule left with no
// servers is skipped with a warning.
//
// It returns the error from ReadTextFile, or a syntax error carrying the
// 1-based line number when StringTwoFields rejects a line.
func (plugin *PluginForward) Init(proxy *Proxy) error {
	dlog.Noticef("Loading the set of forwarding rules from [%s]", proxy.forwardFile)
	lines, err := ReadTextFile(proxy.forwardFile)
	if err != nil {
		return err
	}
	for lineNo, line := range strings.Split(lines, "\n") {
		line = TrimAndStripInlineComments(line)
		if len(line) == 0 {
			continue
		}
		domain, serversStr, ok := StringTwoFields(line)
		if !ok {
			return fmt.Errorf(
				"Syntax error for a forwarding rule at line %d. Expected syntax: example.com 9.9.9.9,8.8.8.8",
				1+lineNo,
			)
		}
		domain = strings.ToLower(domain)
		var servers []string
		for _, server := range strings.Split(serversStr, ",") {
			server = strings.TrimSpace(server)
			// Only unwrap when both brackets are present. Trimming each bracket
			// independently would turn "[::1]:5353" into the invalid "::1]:5353".
			if len(server) >= 2 && server[0] == '[' && server[len(server)-1] == ']' {
				server = server[1 : len(server)-1]
			}
			if len(server) == 0 {
				// Stray commas ("9.9.9.9,,8.8.8.8" or a trailing ",") must not
				// produce an empty upstream address that Eval could pick.
				continue
			}
			if net.ParseIP(server) != nil {
				// Bare IP: add the default DNS port. JoinHostPort brackets every
				// host containing ':', which also covers IPv4-mapped IPv6
				// ("::ffff:1.2.3.4"). The To4() check misclassified that form.
				server = net.JoinHostPort(server, "53")
			}
			dlog.Infof("Forwarding [%s] to %s", domain, server)
			servers = append(servers, server)
		}
		if len(servers) == 0 {
			dlog.Warnf("No servers for the forwarding rule [%s] at line %d, ignoring it", domain, 1+lineNo)
			continue
		}
		plugin.forwardMap = append(plugin.forwardMap, PluginForwardEntry{
			domain:  domain,
			servers: servers,
		})
	}
	return nil
}
// Drop releases plugin resources. The forward plugin holds none, so it always
// returns nil.
func (plugin *PluginForward) Drop() error { return nil }

// Reload does nothing and always returns nil. The rules loaded by Init are
// left unchanged.
func (plugin *PluginForward) Reload() error { return nil }
// Eval forwards the query to an upstream server chosen by the first
// forwarding rule that matches pluginsState.qName. It then stores the upstream
// reply as a synthesized response.
//
// A rule matches when any of the following holds:
//   - its domain is not longer than the query name and equals "." (catch-all);
//   - its domain equals the query name;
//   - the query name ends with "." followed by the domain.
//
// Rules are tried in forwardMap order and the first match wins. If nothing
// matches, Eval returns nil and leaves pluginsState untouched.
//
// One server of the matching rule is picked at random and recorded in
// pluginsState.serverName before the query is sent. The query uses
// pluginsState.serverProto and pluginsState.timeout. A truncated reply is
// retried once over TCP, and exchange errors are returned unchanged. The AD
// flag is cleared unless the reply carries an EDNS0 record with the DO bit set.
// The reply ID is rewritten to match the query.
func (plugin *PluginForward) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
	qName := pluginsState.qName

	var servers []string
	for _, rule := range plugin.forwardMap {
		if len(rule.domain) > len(qName) {
			continue
		}
		matched := rule.domain == "."
		if !matched && strings.HasSuffix(qName, rule.domain) {
			// Require a label boundary so "example.com" does not match "badexample.com".
			prefixLen := len(qName) - len(rule.domain)
			matched = prefixLen == 0 || qName[prefixLen-1] == '.'
		}
		if matched {
			servers = rule.servers
			break
		}
	}
	if len(servers) == 0 {
		return nil
	}

	server := servers[rand.Intn(len(servers))]
	pluginsState.serverName = server

	client := dns.Client{Net: pluginsState.serverProto, Timeout: pluginsState.timeout}
	respMsg, _, err := client.Exchange(msg, server)
	if err == nil && respMsg.Truncated {
		// The answer did not fit in the first transport; ask again over TCP.
		client.Net = "tcp"
		respMsg, _, err = client.Exchange(msg, server)
	}
	if err != nil {
		return err
	}

	edns0 := respMsg.IsEdns0()
	if dnssecOK := edns0 != nil && edns0.Do(); !dnssecOK {
		respMsg.AuthenticatedData = false
	}
	respMsg.Id = msg.Id

	pluginsState.synthResponse = respMsg
	pluginsState.action = PluginsActionSynth
	pluginsState.returnCode = PluginsReturnCodeForward
	return nil
}