package main import ( "encoding/binary" "errors" "net" "strings" "time" "unicode/utf8" "github.com/jedisct1/dlog" "github.com/miekg/dns" ) func EmptyResponseFromMessage(srcMsg *dns.Msg) *dns.Msg { dstMsg := dns.Msg{MsgHdr: srcMsg.MsgHdr, Compress: true} dstMsg.Question = srcMsg.Question dstMsg.Response = true if srcMsg.RecursionDesired { dstMsg.RecursionAvailable = true } dstMsg.RecursionDesired = false dstMsg.CheckingDisabled = false dstMsg.AuthenticatedData = false if edns0 := srcMsg.IsEdns0(); edns0 != nil { dstMsg.SetEdns0(edns0.UDPSize(), edns0.Do()) } return &dstMsg } func TruncatedResponse(packet []byte) ([]byte, error) { srcMsg := dns.Msg{} if err := srcMsg.Unpack(packet); err != nil { return nil, err } dstMsg := EmptyResponseFromMessage(&srcMsg) dstMsg.Truncated = true return dstMsg.Pack() } func RefusedResponseFromMessage(srcMsg *dns.Msg, refusedCode bool, ipv4 net.IP, ipv6 net.IP, ttl uint32) *dns.Msg { dstMsg := EmptyResponseFromMessage(srcMsg) ede := new(dns.EDNS0_EDE) if edns0 := dstMsg.IsEdns0(); edns0 != nil { edns0.Option = append(edns0.Option, ede) } ede.InfoCode = dns.ExtendedErrorCodeFiltered if refusedCode { dstMsg.Rcode = dns.RcodeRefused } else { dstMsg.Rcode = dns.RcodeSuccess questions := srcMsg.Question if len(questions) == 0 { return dstMsg } question := questions[0] sendHInfoResponse := true if ipv4 != nil && question.Qtype == dns.TypeA { rr := new(dns.A) rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl} rr.A = ipv4.To4() if rr.A != nil { dstMsg.Answer = []dns.RR{rr} sendHInfoResponse = false ede.InfoCode = dns.ExtendedErrorCodeForgedAnswer } } else if ipv6 != nil && question.Qtype == dns.TypeAAAA { rr := new(dns.AAAA) rr.Hdr = dns.RR_Header{Name: question.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl} rr.AAAA = ipv6.To16() if rr.AAAA != nil { dstMsg.Answer = []dns.RR{rr} sendHInfoResponse = false ede.InfoCode = dns.ExtendedErrorCodeForgedAnswer } } if sendHInfoResponse { hinfo := new(dns.HINFO) hinfo.Hdr = dns.RR_Header{ Name: question.Name, Rrtype: dns.TypeHINFO, Class: dns.ClassINET, Ttl: ttl, } hinfo.Cpu = "This query has been locally blocked" hinfo.Os = "by dnscrypt-proxy" dstMsg.Answer = []dns.RR{hinfo} } else { ede.ExtraText = "This query has been locally blocked by dnscrypt-proxy" } } return dstMsg } func HasTCFlag(packet []byte) bool { return packet[2]&2 == 2 } func TransactionID(packet []byte) uint16 { return binary.BigEndian.Uint16(packet[0:2]) } func SetTransactionID(packet []byte, tid uint16) { binary.BigEndian.PutUint16(packet[0:2], tid) } func Rcode(packet []byte) uint8 { return packet[3] & 0xf } func NormalizeRawQName(name *[]byte) { for i, c := range *name { if c >= 65 && c <= 90 { (*name)[i] = c + 32 } } } func NormalizeQName(str string) (string, error) { if len(str) == 0 || str == "." { return ".", nil } hasUpper := false str = strings.TrimSuffix(str, ".") strLen := len(str) for i := 0; i < strLen; i++ { c := str[i] if c >= utf8.RuneSelf { return str, errors.New("Query name is not an ASCII string") } hasUpper = hasUpper || ('A' <= c && c <= 'Z') } if !hasUpper { return str, nil } var b strings.Builder b.Grow(len(str)) for i := 0; i < strLen; i++ { c := str[i] if 'A' <= c && c <= 'Z' { c += 'a' - 'A' } b.WriteByte(c) } return b.String(), nil } func getMinTTL(msg *dns.Msg, minTTL uint32, maxTTL uint32, cacheNegMinTTL uint32, cacheNegMaxTTL uint32) time.Duration { if (msg.Rcode != dns.RcodeSuccess && msg.Rcode != dns.RcodeNameError) || (len(msg.Answer) <= 0 && len(msg.Ns) <= 0) { return time.Duration(cacheNegMinTTL) * time.Second } var ttl uint32 if msg.Rcode == dns.RcodeSuccess { ttl = uint32(maxTTL) } else { ttl = uint32(cacheNegMaxTTL) } if len(msg.Answer) > 0 { for _, rr := range msg.Answer { if rr.Header().Ttl < ttl { ttl = rr.Header().Ttl } } } else { for _, rr := range msg.Ns { if rr.Header().Ttl < ttl { ttl = rr.Header().Ttl } } } if msg.Rcode == dns.RcodeSuccess { if ttl < minTTL { ttl = minTTL } } else { if ttl < cacheNegMinTTL { ttl = cacheNegMinTTL } } return time.Duration(ttl) * time.Second } func setMaxTTL(msg *dns.Msg, ttl uint32) { for _, rr := range msg.Answer { if ttl < rr.Header().Ttl { rr.Header().Ttl = ttl } } for _, rr := range msg.Ns { if ttl < rr.Header().Ttl { rr.Header().Ttl = ttl } } for _, rr := range msg.Extra { header := rr.Header() if header.Rrtype == dns.TypeOPT { continue } if ttl < rr.Header().Ttl { rr.Header().Ttl = ttl } } } func updateTTL(msg *dns.Msg, expiration time.Time) { until := time.Until(expiration) ttl := uint32(0) if until > 0 { ttl = uint32(until / time.Second) if until-time.Duration(ttl)*time.Second >= time.Second/2 { ttl += 1 } } for _, rr := range msg.Answer { rr.Header().Ttl = ttl } for _, rr := range msg.Ns { rr.Header().Ttl = ttl } for _, rr := range msg.Extra { if rr.Header().Rrtype != dns.TypeOPT { rr.Header().Ttl = ttl } } } func hasEDNS0Padding(packet []byte) (bool, error) { msg := dns.Msg{} if err := msg.Unpack(packet); err != nil { return false, err } if edns0 := msg.IsEdns0(); edns0 != nil { for _, option := range edns0.Option { if option.Option() == dns.EDNS0PADDING { return true, nil } } } return false, nil } func addEDNS0PaddingIfNoneFound(msg *dns.Msg, unpaddedPacket []byte, paddingLen int) ([]byte, error) { edns0 := msg.IsEdns0() if edns0 == nil { msg.SetEdns0(uint16(MaxDNSPacketSize), false) edns0 = msg.IsEdns0() if edns0 == nil { return unpaddedPacket, nil } } for _, option := range edns0.Option { if option.Option() == dns.EDNS0PADDING { return unpaddedPacket, nil } } ext := new(dns.EDNS0_PADDING) padding := make([]byte, paddingLen) for i := range padding { padding[i] = 'X' } ext.Padding = padding[:paddingLen] edns0.Option = append(edns0.Option, ext) return msg.Pack() } func removeEDNS0Options(msg *dns.Msg) bool { edns0 := msg.IsEdns0() if edns0 == nil { return false } edns0.Option = []dns.EDNS0{} return true } func dddToByte(s []byte) byte { return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0')) } func PackTXTRR(s string) []byte { bs := make([]byte, len(s)) msg := make([]byte, 0) copy(bs, s) for i := 0; i < len(bs); i++ { if bs[i] == '\\' { i++ if i == len(bs) { break } if i+2 < len(bs) && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) { msg = append(msg, dddToByte(bs[i:])) i += 2 } else if bs[i] == 't' { msg = append(msg, '\t') } else if bs[i] == 'r' { msg = append(msg, '\r') } else if bs[i] == 'n' { msg = append(msg, '\n') } else { msg = append(msg, bs[i]) } } else { msg = append(msg, bs[i]) } } return msg } type DNSExchangeResponse struct { response *dns.Msg rtt time.Duration priority int fragmentsBlocked bool err error } func DNSExchange( proxy *Proxy, proto string, query *dns.Msg, serverAddress string, relay *DNSCryptRelay, serverName *string, tryFragmentsSupport bool, ) (*dns.Msg, time.Duration, bool, error) { for { cancelChannel := make(chan struct{}) maxTries := 3 channel := make(chan DNSExchangeResponse, 2*maxTries) var err error options := 0 for tries := 0; tries < maxTries; tries++ { if tryFragmentsSupport { queryCopy := query.Copy() queryCopy.Id += uint16(options) go func(query *dns.Msg, delay time.Duration) { time.Sleep(delay) option := DNSExchangeResponse{err: errors.New("Canceled")} select { case <-cancelChannel: default: option = _dnsExchange(proxy, proto, query, serverAddress, relay, 1500) } option.fragmentsBlocked = false option.priority = 0 channel <- option }(queryCopy, time.Duration(200*tries)*time.Millisecond) options++ } queryCopy := query.Copy() queryCopy.Id += uint16(options) go func(query *dns.Msg, delay time.Duration) { time.Sleep(delay) option := DNSExchangeResponse{err: errors.New("Canceled")} select { case <-cancelChannel: default: option = _dnsExchange(proxy, proto, query, serverAddress, relay, 480) } option.fragmentsBlocked = true option.priority = 1 channel <- option }(queryCopy, time.Duration(250*tries)*time.Millisecond) options++ } var bestOption *DNSExchangeResponse for i := 0; i < options; i++ { if dnsExchangeResponse := <-channel; dnsExchangeResponse.err == nil { if bestOption == nil || dnsExchangeResponse.priority < bestOption.priority || (dnsExchangeResponse.priority == bestOption.priority && dnsExchangeResponse.rtt < bestOption.rtt) { bestOption = &dnsExchangeResponse if bestOption.priority == 0 { close(cancelChannel) break } } } else { err = dnsExchangeResponse.err } } if bestOption != nil { if bestOption.fragmentsBlocked { dlog.Debugf("[%v] public key retrieval succeeded but server is blocking fragments", *serverName) } else { dlog.Debugf("[%v] public key retrieval succeeded", *serverName) } return bestOption.response, bestOption.rtt, bestOption.fragmentsBlocked, nil } if relay == nil || !proxy.anonDirectCertFallback { if err == nil { err = errors.New("Unable to reach the server") } return nil, 0, false, err } dlog.Infof( "Unable to get the public key for [%v] via relay [%v], retrying over a direct connection", *serverName, relay.RelayUDPAddr.IP, ) relay = nil } } func _dnsExchange( proxy *Proxy, proto string, query *dns.Msg, serverAddress string, relay *DNSCryptRelay, paddedLen int, ) DNSExchangeResponse { var packet []byte var rtt time.Duration if proto == "udp" { qNameLen, padding := len(query.Question[0].Name), 0 if qNameLen < paddedLen { padding = paddedLen - qNameLen } if padding > 0 { opt := new(dns.OPT) opt.Hdr.Name = "." ext := new(dns.EDNS0_PADDING) ext.Padding = make([]byte, padding) opt.Option = append(opt.Option, ext) query.Extra = []dns.RR{opt} } binQuery, err := query.Pack() if err != nil { return DNSExchangeResponse{err: err} } udpAddr, err := net.ResolveUDPAddr("udp", serverAddress) if err != nil { return DNSExchangeResponse{err: err} } upstreamAddr := udpAddr if relay != nil { proxy.prepareForRelay(udpAddr.IP, udpAddr.Port, &binQuery) upstreamAddr = relay.RelayUDPAddr } now := time.Now() pc, err := net.DialUDP("udp", nil, upstreamAddr) if err != nil { return DNSExchangeResponse{err: err} } defer pc.Close() if err := pc.SetDeadline(time.Now().Add(proxy.timeout)); err != nil { return DNSExchangeResponse{err: err} } if _, err := pc.Write(binQuery); err != nil { return DNSExchangeResponse{err: err} } packet = make([]byte, MaxDNSPacketSize) length, err := pc.Read(packet) if err != nil { return DNSExchangeResponse{err: err} } rtt = time.Since(now) packet = packet[:length] } else { binQuery, err := query.Pack() if err != nil { return DNSExchangeResponse{err: err} } tcpAddr, err := net.ResolveTCPAddr("tcp", serverAddress) if err != nil { return DNSExchangeResponse{err: err} } upstreamAddr := tcpAddr if relay != nil { proxy.prepareForRelay(tcpAddr.IP, tcpAddr.Port, &binQuery) upstreamAddr = relay.RelayTCPAddr } now := time.Now() var pc net.Conn proxyDialer := proxy.xTransport.proxyDialer if proxyDialer == nil { pc, err = net.DialTCP("tcp", nil, upstreamAddr) } else { pc, err = (*proxyDialer).Dial("tcp", tcpAddr.String()) } if err != nil { return DNSExchangeResponse{err: err} } defer pc.Close() if err := pc.SetDeadline(time.Now().Add(proxy.timeout)); err != nil { return DNSExchangeResponse{err: err} } binQuery, err = PrefixWithSize(binQuery) if err != nil { return DNSExchangeResponse{err: err} } if _, err := pc.Write(binQuery); err != nil { return DNSExchangeResponse{err: err} } packet, err = ReadPrefixed(&pc) if err != nil { return DNSExchangeResponse{err: err} } rtt = time.Since(now) } msg := dns.Msg{} if err := msg.Unpack(packet); err != nil { return DNSExchangeResponse{err: err} } return DNSExchangeResponse{response: &msg, rtt: rtt, err: nil} }