346 lines
7.0 KiB
Go
346 lines
7.0 KiB
Go
package main
|
|
|
|
import (
|
|
"log"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
nft "github.com/google/nftables"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
const (
	// nftTimeoutSkew is added on top of the mapping TTL when dnsRemap
	// computes the timeout of an nftables map element.
	nftTimeoutSkew = 1 * time.Second

	// dnsCnameWalkLimit caps the number of resolution rounds dnsRemap runs
	// while chasing a name (through CNAMEs) towards its address records.
	dnsCnameWalkLimit = 16
)
|
|
|
|
// DnsAnswer holds one address answer in the intermediate form dnsRemap uses
// between selecting A/AAAA records and emitting PowerDnsAnswer values.
type DnsAnswer struct {
	// Qname is the name the answer is reported under; dnsRemap fills in the
	// original query name.
	Qname string
	// Qtype is the RR type code (dnsRemap only stores A or AAAA types here).
	Qtype uint16
	// Ttl is the TTL of the record the address was taken from.
	Ttl uint32
	// AddrLen is the address length in bytes, as set by dnsRemap:
	// net.IPv4len or net.IPv6len.
	AddrLen uint32
	// Addr holds the address bytes; dnsRemap overwrites them in place when
	// the address is remapped.
	Addr net.IP
}
|
|
|
|
func dnsRemap(qname string, qtype uint16, resp *dns.Msg) ([]PowerDnsAnswer, error) {
|
|
r_main := make([]dns.RR, 0)
|
|
r_extra := make([]dns.RR, 0)
|
|
|
|
answ := make([]dns.RR, 0, len(resp.Answer)+len(resp.Extra))
|
|
answ = append(answ, resp.Answer...)
|
|
answ = append(answ, resp.Extra...)
|
|
|
|
last_qname := qname
|
|
needle := qname
|
|
name_seen := make(map[string]bool)
|
|
name_seen[needle] = true
|
|
|
|
for range dnsCnameWalkLimit {
|
|
i_cnames := make([]int, 0, len(answ))
|
|
i_addrs := make([]int, 0, len(answ))
|
|
for i := range answ {
|
|
switch answ[i].Header().Rrtype {
|
|
case dns.TypeCNAME:
|
|
i_cnames = append(i_cnames, i)
|
|
case dns.TypeA, dns.TypeAAAA:
|
|
i_addrs = append(i_addrs, i)
|
|
}
|
|
}
|
|
|
|
if len(i_cnames) > 0 {
|
|
seen := false
|
|
cname_dive := true
|
|
for {
|
|
if !cname_dive {
|
|
break
|
|
}
|
|
cname_dive = false
|
|
|
|
for _, i := range i_cnames {
|
|
if answ[i].Header().Name != needle {
|
|
continue
|
|
}
|
|
|
|
cname := answ[i].(*dns.CNAME)
|
|
needle = cname.Target
|
|
_, seen = name_seen[needle]
|
|
if seen {
|
|
// CNAME loop?
|
|
log.Printf("CNAME loop: %v -> %v", qname, needle)
|
|
return []PowerDnsAnswer{}, nil
|
|
}
|
|
name_seen[needle] = true
|
|
cname_dive = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
found := false
|
|
for _, i := range i_addrs {
|
|
if answ[i].Header().Name != needle {
|
|
continue
|
|
}
|
|
found = true
|
|
r_main = append(r_main, answ[i])
|
|
}
|
|
if found {
|
|
for i := range answ {
|
|
if !dnsIsAllowedExtraQtype(answ[i].Header().Rrtype) {
|
|
continue
|
|
}
|
|
// if answ[i].Header().Name != needle {
|
|
// continue
|
|
// }
|
|
|
|
r_extra = append(r_extra, answ[i])
|
|
}
|
|
break
|
|
}
|
|
|
|
// trim
|
|
answ = answ[:0]
|
|
|
|
if (needle == last_qname) && (qtype == dns.TypeANY) {
|
|
answ = dnsRemapAnyFallback(needle)
|
|
continue
|
|
}
|
|
|
|
last_qname = needle
|
|
resp, err := dnsCustomResolve(needle, qtype)
|
|
if err != nil {
|
|
break
|
|
}
|
|
if resp == nil {
|
|
break
|
|
}
|
|
|
|
if len(resp.Answer) != 0 {
|
|
for i := range resp.Answer {
|
|
answ = append(answ, dns.Copy(resp.Answer[i]))
|
|
}
|
|
for i := range resp.Extra {
|
|
answ = append(answ, dns.Copy(resp.Extra[i]))
|
|
}
|
|
resp = nil
|
|
continue
|
|
}
|
|
|
|
if qtype != dns.TypeANY {
|
|
break
|
|
}
|
|
|
|
answ = dnsRemapAnyFallback(needle)
|
|
}
|
|
|
|
if len(r_main) == 0 {
|
|
if qname == needle {
|
|
log.Printf("not resolved fully %v/%v", qname, dns.TypeToString[qtype])
|
|
} else {
|
|
log.Printf("not resolved fully %v/%v (stuck at %v)", qname, dns.TypeToString[qtype], needle)
|
|
}
|
|
return []PowerDnsAnswer{}, nil
|
|
}
|
|
|
|
interim := make([]DnsAnswer, 0, len(r_main))
|
|
for i := range r_main {
|
|
t := r_main[i].Header().Rrtype
|
|
r := DnsAnswer{
|
|
Qname: qname,
|
|
Qtype: t,
|
|
Ttl: r_main[i].Header().Ttl,
|
|
}
|
|
|
|
switch t {
|
|
case dns.TypeA:
|
|
r.AddrLen = net.IPv4len
|
|
r.Addr = make([]byte, net.IPv4len)
|
|
copy(r.Addr, r_main[i].(*dns.A).A)
|
|
case dns.TypeAAAA:
|
|
r.AddrLen = net.IPv6len
|
|
r.Addr = make([]byte, net.IPv6len)
|
|
copy(r.Addr, r_main[i].(*dns.AAAA).AAAA)
|
|
}
|
|
|
|
interim = append(interim, r)
|
|
}
|
|
|
|
// unify/adjust TTL
|
|
var ttl uint32
|
|
if cfgTtlFuzzy {
|
|
ttl = dnsFuzzClipTtl()
|
|
} else {
|
|
ttl = cfgTtlMax
|
|
for i := range interim {
|
|
if ttl > interim[i].Ttl {
|
|
ttl = interim[i].Ttl
|
|
}
|
|
}
|
|
ttl = dnsClipTtl(ttl)
|
|
}
|
|
|
|
unix_start := time.Unix(0, 0)
|
|
nft_ipv4 := make([]nft.SetElement, 0, len(interim))
|
|
nft_ipv6 := make([]nft.SetElement, 0, len(interim))
|
|
|
|
// remap addresses in answers and prepare nftables maps
|
|
for i := range interim {
|
|
addrlen := interim[i].AddrLen
|
|
var srcAddr net.IP = make([]byte, addrlen)
|
|
copy(srcAddr, interim[i].Addr)
|
|
|
|
var cidr *net.IPNet = nil
|
|
switch addrlen {
|
|
case net.IPv4len:
|
|
if cfgCidrV4 != nil {
|
|
cidr = cfgCidrV4
|
|
}
|
|
case net.IPv6len:
|
|
if cfgCidrV6 != nil {
|
|
cidr = cfgCidrV6
|
|
}
|
|
}
|
|
if cidr == nil {
|
|
// no need to remap or add to nftables
|
|
continue
|
|
}
|
|
|
|
dstAddr, nft_ttl := addrMapGet(srcAddr, cidr, ttl)
|
|
// HACK: replace addr
|
|
copy(interim[i].Addr, dstAddr)
|
|
|
|
if !cfgWithNft {
|
|
continue
|
|
}
|
|
|
|
elem := nft.SetElement{
|
|
Key: []byte(dstAddr),
|
|
Val: []byte(srcAddr),
|
|
// Timeout: time.Duration(nft_ttl),
|
|
Timeout: time.Unix(int64(nft_ttl), 0).Add(nftTimeoutSkew).Sub(unix_start).Round(time.Millisecond),
|
|
}
|
|
|
|
switch addrlen {
|
|
case net.IPv4len:
|
|
nft_ipv4 = append(nft_ipv4, elem)
|
|
case net.IPv6len:
|
|
nft_ipv6 = append(nft_ipv6, elem)
|
|
}
|
|
}
|
|
|
|
// perform nftables assignment
|
|
if (len(nft_ipv4) > 0) && (cfgNftMapV4 != "") {
|
|
nftDoWithMap(cfgNftTable, cfgNftTableFamily, cfgNftMapV4, func(c *nft.Conn, t *nft.Table, m *nft.Set) error {
|
|
_ = c.SetDeleteElements(m, nft_ipv4)
|
|
return nil
|
|
})
|
|
nftDoWithMap(cfgNftTable, cfgNftTableFamily, cfgNftMapV4, func(c *nft.Conn, t *nft.Table, m *nft.Set) error {
|
|
return c.SetAddElements(m, nft_ipv4)
|
|
})
|
|
}
|
|
if (len(nft_ipv6) > 0) && (cfgNftMapV6 != "") {
|
|
nftDoWithMap(cfgNftTable, cfgNftTableFamily, cfgNftMapV6, func(c *nft.Conn, t *nft.Table, m *nft.Set) error {
|
|
_ = c.SetDeleteElements(m, nft_ipv6)
|
|
return nil
|
|
})
|
|
nftDoWithMap(cfgNftTable, cfgNftTableFamily, cfgNftMapV6, func(c *nft.Conn, t *nft.Table, m *nft.Set) error {
|
|
return c.SetAddElements(m, nft_ipv6)
|
|
})
|
|
}
|
|
|
|
result := make([]PowerDnsAnswer, 0)
|
|
for i := range interim {
|
|
r := PowerDnsAnswer{
|
|
Qname: interim[i].Qname,
|
|
Qtype: dns.TypeToString[interim[i].Qtype],
|
|
Ttl: ttl,
|
|
Content: interim[i].Addr.String(),
|
|
}
|
|
result = append(result, r)
|
|
}
|
|
|
|
// extra records (if any)
|
|
for i := range r_extra {
|
|
r, err := dnsRrToPowerDnsAnswer(r_extra[i])
|
|
if err != nil {
|
|
log.Printf("%v", err)
|
|
continue
|
|
}
|
|
result = append(result, r)
|
|
}
|
|
|
|
// adjust results
|
|
for i := range result {
|
|
if result[i].Qname == needle {
|
|
result[i].Qname = qname
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func dnsRemapAnyFallback(qname string) []dns.RR {
|
|
var wg sync.WaitGroup
|
|
var r_a, r_aaaa []dns.RR
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
resp, err := dnsCustomResolve(qname, dns.TypeA)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if resp == nil {
|
|
return
|
|
}
|
|
if len(resp.Answer) == 0 {
|
|
return
|
|
}
|
|
|
|
r_a = make([]dns.RR, 0, len(resp.Answer))
|
|
for i := range resp.Answer {
|
|
r_a = append(r_a, dns.Copy(resp.Answer[i]))
|
|
}
|
|
for i := range resp.Extra {
|
|
r_a = append(r_a, dns.Copy(resp.Extra[i]))
|
|
}
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
resp, err := dnsCustomResolve(qname, dns.TypeAAAA)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if resp == nil {
|
|
return
|
|
}
|
|
if len(resp.Answer) == 0 {
|
|
return
|
|
}
|
|
|
|
r_aaaa = make([]dns.RR, 0, len(resp.Answer))
|
|
for i := range resp.Answer {
|
|
r_aaaa = append(r_aaaa, dns.Copy(resp.Answer[i]))
|
|
}
|
|
for i := range resp.Extra {
|
|
r_aaaa = append(r_aaaa, dns.Copy(resp.Extra[i]))
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
|
|
answ := make([]dns.RR, 0, len(r_a)+len(r_aaaa))
|
|
// TODO: very naive (no unique record checks)
|
|
if r_a != nil {
|
|
answ = append(answ, r_a...)
|
|
}
|
|
if r_aaaa != nil {
|
|
answ = append(answ, r_aaaa...)
|
|
}
|
|
|
|
return answ
|
|
}
|