dnscrypt-proxy/dnscrypt-proxy/oblivious_doh.go

192 lines
5.2 KiB
Go
Raw Normal View History

2021-03-30 11:03:25 +02:00
package main
import (
"crypto/subtle"
"encoding/binary"
"fmt"
2021-06-12 14:16:20 +02:00
"github.com/jedisct1/dlog"
2021-03-30 11:03:25 +02:00
hpkecompact "github.com/jedisct1/go-hpke-compact"
)
// ODoH configuration versions accepted by this client, and the cap on how many
// target configurations are kept from a single server response.
const (
	// odohVersion is the current ObliviousDoHConfig version.
	odohVersion uint16 = 0x0001
	// odohTestVersion is a legacy version that is still accepted (its use is logged).
	odohTestVersion uint16 = 0xff06
	// maxODoHConfigs bounds the number of configurations parseODoHTargetConfigs returns.
	maxODoHConfigs = 10
)
type ODoHTargetConfig struct {
2021-03-30 11:03:25 +02:00
suite *hpkecompact.Suite
keyID []byte
publicKey []byte
}
// encodeLengthValue returns a new slice holding len(b) as a big-endian uint16
// followed by the bytes of b, the length-prefixed form used for the key ID and
// the encrypted message in ODoH messages.
//
// The caller must ensure len(b) <= 0xffff: longer inputs get a length prefix
// truncated to its low 16 bits.
func encodeLengthValue(b []byte) []byte {
	// Reserve room for prefix and payload up front so the append below never
	// has to reallocate (one allocation per call instead of two).
	out := make([]byte, 2, 2+len(b))
	binary.BigEndian.PutUint16(out, uint16(len(b)))
	return append(out, b...)
}
// parseODoHTargetConfig decodes a single configuration body laid out as
//
//	kem_id(2) || kdf_id(2) || aead_id(2) || public_key_length(2) || public_key
//
// (all integers big-endian). It rejects bodies whose declared public key length
// does not match the remaining bytes, builds the HPKE suite, checks that a
// client context can be created with the public key, and derives the key ID
// from the whole body. The returned keyID is stored length-prefixed.
func parseODoHTargetConfig(config []byte) (ODoHTargetConfig, error) {
	if len(config) < 8 {
		return ODoHTargetConfig{}, fmt.Errorf("Malformed config")
	}
	header, publicKey := config[:8], config[8:]
	if int(binary.BigEndian.Uint16(header[6:8])) != len(publicKey) {
		return ODoHTargetConfig{}, fmt.Errorf("Malformed config")
	}

	suite, err := hpkecompact.NewSuite(
		hpkecompact.KemID(binary.BigEndian.Uint16(header[0:2])),
		hpkecompact.KdfID(binary.BigEndian.Uint16(header[2:4])),
		hpkecompact.AeadID(binary.BigEndian.Uint16(header[4:6])),
	)
	if err != nil {
		return ODoHTargetConfig{}, err
	}

	// The context is thrown away: this call only validates the public key.
	if _, _, err = suite.NewClientContext(publicKey, []byte("odoh query"), nil); err != nil {
		return ODoHTargetConfig{}, err
	}

	prk := suite.Extract(config, nil)
	keyID, err := suite.Expand(prk, []byte("odoh key id"), uint16(suite.Hash().Size()))
	if err != nil {
		return ODoHTargetConfig{}, err
	}

	return ODoHTargetConfig{
		suite:     suite,
		publicKey: publicKey,
		keyID:     encodeLengthValue(keyID),
	}, nil
}
// parseODoHTargetConfigs parses the configuration list returned by an ODoH
// target. It consists of a 2-byte big-endian total length followed by entries of
// the form version(2) || length(2) || body(length).
//
// Entries with a version other than odohVersion or odohTestVersion, and entries
// whose body fails parseODoHTargetConfig, are skipped. At most maxODoHConfigs
// configurations are returned. Parsing stops when fewer than 4 bytes (an entry
// header) remain.
//
// An error is returned when the input is empty, when the outer length does not
// match the input size, or when an entry declares more bytes than remain.
func parseODoHTargetConfigs(configs []byte) ([]ODoHTargetConfig, error) {
	if len(configs) <= 2 {
		return nil, fmt.Errorf("Server didn't return any ODoH configurations")
	}
	length := binary.BigEndian.Uint16(configs)
	if len(configs) != int(length)+2 {
		return nil, fmt.Errorf("Malformed configs")
	}
	targets := make([]ODoHTargetConfig, 0)
	offset := 2
	for offset+4 <= len(configs) && len(targets) < maxODoHConfigs {
		configVersion := binary.BigEndian.Uint16(configs[offset : offset+2])
		configLength := int(binary.BigEndian.Uint16(configs[offset+2 : offset+4]))
		bodyStart := offset + 4
		bodyEnd := bodyStart + configLength
		// The entry length is server-controlled: never slice past the buffer.
		if bodyEnd > len(configs) {
			return nil, fmt.Errorf("Malformed configs")
		}
		if configVersion == odohVersion || configVersion == odohTestVersion {
			if configVersion != odohVersion {
				dlog.Debugf("Server still uses the legacy 0x%x ODoH version", configVersion)
			}
			// Unparseable bodies are skipped so that one bad entry does not hide
			// the usable ones.
			if target, err := parseODoHTargetConfig(configs[bodyStart:bodyEnd]); err == nil {
				targets = append(targets, target)
			}
		}
		offset = bodyEnd
	}
	return targets, nil
}
// ODoHQuery is the result of ODoHTargetConfig.encryptQuery. Besides the encoded
// query message, it keeps the state decryptResponse needs to open the matching
// response: the HPKE suite, the client context (used to export the response
// secret) and the query plaintext (mixed into the response key derivation).
type ODoHQuery struct {
	suite *hpkecompact.Suite
	ctx   hpkecompact.ClientContext

	odohPlaintext []byte
	odohMessage   []byte
}
func (t ODoHTargetConfig) encryptQuery(query []byte) (ODoHQuery, error) {
2021-03-30 11:03:25 +02:00
clientCtx, encryptedSharedSecret, err := t.suite.NewClientContext(t.publicKey, []byte("odoh query"), nil)
if err != nil {
return ODoHQuery{}, err
}
odohPlaintext := make([]byte, 4+len(query))
binary.BigEndian.PutUint16(odohPlaintext[0:2], uint16(len(query)))
copy(odohPlaintext[2:], query)
aad := append([]byte{0x01}, t.keyID...)
ciphertext, err := clientCtx.EncryptToServer(odohPlaintext, aad)
2021-05-09 16:16:38 +02:00
if err != nil {
return ODoHQuery{}, err
}
2021-03-30 11:03:25 +02:00
encryptedMessage := encodeLengthValue(append(encryptedSharedSecret, ciphertext...))
odohMessage := append(append([]byte{0x01}, t.keyID...), encryptedMessage...)
return ODoHQuery{
suite: t.suite,
odohPlaintext: odohPlaintext,
odohMessage: odohMessage,
ctx: clientCtx,
}, nil
}
// decryptResponse opens the response message sent back for query q and returns
// the DNS message it carries.
//
// Expected layout of response (integers big-endian):
//
//	[0]          message type, must be 0x02
//	[1:3]        N, length of the response nonce
//	[3:3+N]      response nonce
//	[3+N:5+N]    length of the ciphertext, must equal len(response)-(5+N)
//	[5+N:]       ciphertext
//
// The AEAD key and nonce are derived from a secret exported from q's HPKE
// context ("odoh response"), with salt = q.odohPlaintext || response[1:3+N].
// The AAD is response[0:3+N].
//
// The decrypted plaintext is len(msg)(2) || msg || 2 bytes || padding. The two
// bytes are the padding-length field in RFC 9230 and are not checked here.
// Every padding byte must be zero, and this is checked in constant time.
//
// A "Malformed response" error is returned on any framing or padding violation.
// Errors from the key schedule or AEAD are returned as-is.
func (q ODoHQuery) decryptResponse(response []byte) ([]byte, error) {
	if len(response) < 3 {
		return nil, fmt.Errorf("Malformed response")
	}
	if response[0] != uint8(0x02) {
		return nil, fmt.Errorf("Malformed response")
	}
	// Offsets are computed in int: uint16 arithmetic on 3+N could wrap around
	// and yield an invalid slice.
	nonceLength := int(binary.BigEndian.Uint16(response[1:3]))
	headerEnd := 3 + nonceLength
	if len(response) < headerEnd+2 {
		return nil, fmt.Errorf("Malformed response")
	}
	ctLength := int(binary.BigEndian.Uint16(response[headerEnd : headerEnd+2]))
	ct := response[headerEnd+2:]
	if ctLength != len(ct) {
		return nil, fmt.Errorf("Malformed response")
	}
	responseNonceEnc := response[1:headerEnd] // length-prefixed nonce
	aad := response[0:headerEnd]

	secret, err := q.ctx.Export([]byte("odoh response"), q.suite.KeyBytes)
	if err != nil {
		return nil, err
	}
	// Fresh buffer: appending directly to q.odohPlaintext could write into its
	// spare capacity.
	salt := make([]byte, 0, len(q.odohPlaintext)+len(responseNonceEnc))
	salt = append(salt, q.odohPlaintext...)
	salt = append(salt, responseNonceEnc...)
	prk := q.suite.Extract(secret, salt)
	key, err := q.suite.Expand(prk, []byte("odoh key"), q.suite.KeyBytes)
	if err != nil {
		return nil, err
	}
	nonce, err := q.suite.Expand(prk, []byte("odoh nonce"), q.suite.NonceBytes)
	if err != nil {
		return nil, err
	}
	aead, err := q.suite.NewRawCipher(key)
	if err != nil {
		return nil, err
	}
	responsePlaintext, err := aead.Open(nil, nonce, ct, aad)
	if err != nil {
		return nil, err
	}

	// Authenticated data can still be inconsistent; bound-check before slicing.
	if len(responsePlaintext) < 4 {
		return nil, fmt.Errorf("Malformed response")
	}
	responseLength := int(binary.BigEndian.Uint16(responsePlaintext[0:2]))
	paddingStart := 4 + responseLength
	if len(responsePlaintext) < paddingStart {
		return nil, fmt.Errorf("Malformed response")
	}
	// Padding must be all zeros; it is checked in the decrypted plaintext (not the
	// ciphertext) without an early exit.
	valid := 1
	for _, b := range responsePlaintext[paddingStart:] {
		valid &= subtle.ConstantTimeByteEq(b, 0x00)
	}
	if valid != 1 {
		return nil, fmt.Errorf("Malformed response")
	}
	return responsePlaintext[2 : 2+responseLength], nil
}