346 lines
7.0 KiB
Go
Raw Normal View History

2024-09-14 09:12:10 +03:00
package main
import (
"log"
"net"
"sync"
"time"
nft "github.com/google/nftables"
"github.com/miekg/dns"
)
// Tunables for dnsRemap.
const (
	// nftTimeoutSkew is added on top of the TTL-derived timeout of every
	// element written into the nftables address maps.
	nftTimeoutSkew = 1 * time.Second

	// dnsCnameWalkLimit caps the number of rounds of dnsRemap's resolve loop
	// (each round follows known CNAMEs and may re-resolve the chain's end).
	dnsCnameWalkLimit = 16
)
// DnsAnswer holds one resolved address record in the intermediate form used
// by dnsRemap, between collecting A/AAAA records and producing PowerDNS
// answers.
type DnsAnswer struct {
	// Qname is the name the answer is reported under.
	Qname string
	// Qtype is the resource record type (dns.TypeA or dns.TypeAAAA in dnsRemap).
	Qtype uint16
	// Ttl is the TTL copied from the source record.
	Ttl uint32
	// AddrLen is the length of Addr in bytes (net.IPv4len or net.IPv6len).
	AddrLen uint32
	// Addr holds the raw address bytes; dnsRemap may overwrite them in place
	// with a remapped address.
	Addr net.IP
}
// dnsRemap resolves qname/qtype, starting from the upstream response resp,
// and converts the resulting address records into PowerDNS answers.
//
// Resolution: the answer and additional sections are searched for a CNAME
// chain starting at qname. Owner names and targets are compared
// case-insensitively, as DNS requires (RFC 4343). While no A/AAAA record
// exists for the current end of the chain, that name is re-resolved with
// dnsCustomResolve. ANY queries may fall back to dnsRemapAnyFallback. The
// loop runs for at most dnsCnameWalkLimit rounds.
//
// Remapping: for each address family with a configured CIDR
// (cfgCidrV4/cfgCidrV6), the address is replaced by the one returned from
// addrMapGet. When cfgWithNft is set, the mapped→original pairs are also
// written into the nftables maps cfgNftMapV4/cfgNftMapV6.
//
// All address answers are reported under qname with one unified TTL. That
// TTL comes from dnsFuzzClipTtl when cfgTtlFuzzy is set; otherwise it is the
// minimum record TTL (capped at cfgTtlMax) passed through dnsClipTtl. Records
// accepted by dnsIsAllowedExtraQtype are appended after the address answers.
//
// An empty slice is returned when no address could be resolved or a CNAME
// loop was detected. The returned error is currently always nil.
func dnsRemap(qname string, qtype uint16, resp *dns.Msg) ([]PowerDnsAnswer, error) {
	var rMain, rExtra []dns.RR

	answ := make([]dns.RR, 0, len(resp.Answer)+len(resp.Extra))
	answ = append(answ, resp.Answer...)
	answ = append(answ, resp.Extra...)

	lastQname := qname
	needle := qname
	// needleKey is the canonical (lower-case, fully qualified) form of needle,
	// used for all name comparisons and for CNAME loop detection.
	needleKey := dns.CanonicalName(needle)
	nameSeen := map[string]struct{}{needleKey: {}}

	for range dnsCnameWalkLimit {
		// Follow the CNAME chain as far as the current record set allows,
		// rescanning from the start after every hop.
		for hopped := true; hopped; {
			hopped = false
			for _, rr := range answ {
				if rr.Header().Rrtype != dns.TypeCNAME || dns.CanonicalName(rr.Header().Name) != needleKey {
					continue
				}
				cname, ok := rr.(*dns.CNAME)
				if !ok {
					continue
				}
				needle = cname.Target
				needleKey = dns.CanonicalName(needle)
				if _, seen := nameSeen[needleKey]; seen {
					log.Printf("CNAME loop: %v -> %v", qname, needle)
					return []PowerDnsAnswer{}, nil
				}
				nameSeen[needleKey] = struct{}{}
				hopped = true
				break
			}
		}

		// Collect the address records owned by the end of the chain.
		for _, rr := range answ {
			switch rr.Header().Rrtype {
			case dns.TypeA, dns.TypeAAAA:
				if dns.CanonicalName(rr.Header().Name) == needleKey {
					rMain = append(rMain, rr)
				}
			}
		}
		if len(rMain) > 0 {
			for _, rr := range answ {
				if dnsIsAllowedExtraQtype(rr.Header().Rrtype) {
					rExtra = append(rExtra, rr)
				}
			}
			break
		}

		// No addresses for needle yet: fetch a fresh record set for it.
		answ = answ[:0]
		if needle == lastQname && qtype == dns.TypeANY {
			answ = dnsRemapAnyFallback(needle)
			continue
		}
		lastQname = needle
		next, err := dnsCustomResolve(needle, qtype)
		if err != nil || next == nil {
			break
		}
		if len(next.Answer) != 0 {
			for _, rr := range next.Answer {
				answ = append(answ, dns.Copy(rr))
			}
			for _, rr := range next.Extra {
				answ = append(answ, dns.Copy(rr))
			}
			continue
		}
		if qtype != dns.TypeANY {
			break
		}
		answ = dnsRemapAnyFallback(needle)
	}

	if len(rMain) == 0 {
		if qname == needle {
			log.Printf("not resolved fully %v/%v", qname, dns.TypeToString[qtype])
		} else {
			log.Printf("not resolved fully %v/%v (stuck at %v)", qname, dns.TypeToString[qtype], needle)
		}
		return []PowerDnsAnswer{}, nil
	}

	interim := make([]DnsAnswer, 0, len(rMain))
	for _, rr := range rMain {
		var (
			rrtype uint16
			ip     net.IP
		)
		// Normalise to the family's canonical length. An A record may carry a
		// 16-byte (IPv4-mapped) net.IP, e.g. when built via net.ParseIP.
		// Copying that into 4 bytes would silently yield 0.0.0.0.
		switch v := rr.(type) {
		case *dns.A:
			rrtype, ip = dns.TypeA, v.A.To4()
		case *dns.AAAA:
			rrtype, ip = dns.TypeAAAA, v.AAAA.To16()
		}
		if ip == nil {
			log.Printf("skipping malformed address record: %v", rr)
			continue
		}
		// Copy: the address may be rewritten in place during remapping below
		// and must not alias the source record.
		addr := make(net.IP, len(ip))
		copy(addr, ip)
		interim = append(interim, DnsAnswer{
			Qname:   qname,
			Qtype:   rrtype,
			Ttl:     rr.Header().Ttl,
			AddrLen: uint32(len(addr)),
			Addr:    addr,
		})
	}
	if len(interim) == 0 {
		log.Printf("not resolved fully %v/%v (no usable addresses)", qname, dns.TypeToString[qtype])
		return []PowerDnsAnswer{}, nil
	}

	// Unify/adjust the TTL shared by all address answers.
	var ttl uint32
	if cfgTtlFuzzy {
		ttl = dnsFuzzClipTtl()
	} else {
		ttl = cfgTtlMax
		for _, ans := range interim {
			ttl = min(ttl, ans.Ttl)
		}
		ttl = dnsClipTtl(ttl)
	}

	// Remap addresses in the answers and prepare the nftables map elements.
	nftIPv4 := make([]nft.SetElement, 0, len(interim))
	nftIPv6 := make([]nft.SetElement, 0, len(interim))
	for i := range interim {
		ans := &interim[i]
		var cidr *net.IPNet
		switch ans.AddrLen {
		case net.IPv4len:
			cidr = cfgCidrV4
		case net.IPv6len:
			cidr = cfgCidrV6
		}
		if cidr == nil {
			// No remapping configured for this family: keep the address as is.
			continue
		}
		srcAddr := make(net.IP, ans.AddrLen)
		copy(srcAddr, ans.Addr)
		dstAddr, nftTTL := addrMapGet(srcAddr, cidr, ttl)
		// Replace the address reported to the client with the mapped one.
		copy(ans.Addr, dstAddr)
		if !cfgWithNft {
			continue
		}
		elem := nft.SetElement{
			Key:     []byte(dstAddr),
			Val:     []byte(srcAddr),
			Timeout: time.Duration(nftTTL)*time.Second + nftTimeoutSkew,
		}
		if ans.AddrLen == net.IPv4len {
			nftIPv4 = append(nftIPv4, elem)
		} else {
			nftIPv6 = append(nftIPv6, elem)
		}
	}
	dnsRemapNftRefresh(cfgNftMapV4, nftIPv4)
	dnsRemapNftRefresh(cfgNftMapV6, nftIPv6)

	result := make([]PowerDnsAnswer, 0, len(interim)+len(rExtra))
	for _, ans := range interim {
		result = append(result, PowerDnsAnswer{
			Qname:   ans.Qname,
			Qtype:   dns.TypeToString[ans.Qtype],
			Ttl:     ttl,
			Content: ans.Addr.String(),
		})
	}
	// Extra records (if any).
	for _, rr := range rExtra {
		r, err := dnsRrToPowerDnsAnswer(rr)
		if err != nil {
			log.Printf("%v", err)
			continue
		}
		result = append(result, r)
	}
	// Records owned by the end of the CNAME chain are reported under qname.
	for i := range result {
		if dns.CanonicalName(result[i].Qname) == needleKey {
			result[i].Qname = qname
		}
	}
	return result, nil
}

// dnsRemapNftRefresh writes elems into the nftables map mapName of
// cfgNftTable. It first deletes any entries for the same keys, ignoring
// errors (e.g. keys that are not present yet), then adds elems. It does
// nothing when elems is empty or mapName is "".
func dnsRemapNftRefresh(mapName string, elems []nft.SetElement) {
	if len(elems) == 0 || mapName == "" {
		return
	}
	nftDoWithMap(cfgNftTable, cfgNftTableFamily, mapName, func(c *nft.Conn, _ *nft.Table, m *nft.Set) error {
		_ = c.SetDeleteElements(m, elems)
		return nil
	})
	nftDoWithMap(cfgNftTable, cfgNftTableFamily, mapName, func(c *nft.Conn, _ *nft.Table, m *nft.Set) error {
		return c.SetAddElements(m, elems)
	})
}
// dnsRemapAnyFallback emulates an ANY lookup for qname by resolving its A and
// AAAA records concurrently through dnsCustomResolve.
//
// It returns deep copies of the answer and additional sections of both
// responses (A results first, then AAAA), with identical records collapsed
// by dns.Dedup. Both lookups may carry the same CNAME or additional records.
// A lookup that fails or has an empty answer section contributes nothing.
// When both do, the result is an empty, non-nil slice.
func dnsRemapAnyFallback(qname string) []dns.RR {
	var (
		wg         sync.WaitGroup
		rA, rAAAA []dns.RR
	)
	wg.Add(2)
	go func() {
		defer wg.Done()
		rA = dnsRemapResolveCopy(qname, dns.TypeA)
	}()
	go func() {
		defer wg.Done()
		rAAAA = dnsRemapResolveCopy(qname, dns.TypeAAAA)
	}()
	wg.Wait()

	answ := make([]dns.RR, 0, len(rA)+len(rAAAA))
	answ = append(answ, rA...)
	answ = append(answ, rAAAA...)
	return dns.Dedup(answ, nil)
}

// dnsRemapResolveCopy resolves qname/qtype through dnsCustomResolve and
// returns deep copies of the answer and additional sections. It returns nil
// when the lookup errors, yields no message, or has an empty answer section.
// Such failures are not reported, because the caller treats each lookup as
// best effort.
func dnsRemapResolveCopy(qname string, qtype uint16) []dns.RR {
	resp, err := dnsCustomResolve(qname, qtype)
	if err != nil || resp == nil || len(resp.Answer) == 0 {
		return nil
	}
	out := make([]dns.RR, 0, len(resp.Answer)+len(resp.Extra))
	for _, rr := range resp.Answer {
		out = append(out, dns.Copy(rr))
	}
	for _, rr := range resp.Extra {
		out = append(out, dns.Copy(rr))
	}
	return out
}