initial commit
This commit is contained in:
345
dns-remap.go
Normal file
345
dns-remap.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsCnameWalkLimit = 16
|
||||
nftTimeoutSkew = time.Second * 1
|
||||
)
|
||||
|
||||
type DnsAnswer struct {
|
||||
Qname string
|
||||
Qtype uint16
|
||||
Ttl uint32
|
||||
AddrLen uint32
|
||||
Addr net.IP
|
||||
}
|
||||
|
||||
func dnsRemap(qname string, qtype uint16, resp *dns.Msg) ([]PowerDnsAnswer, error) {
|
||||
r_main := make([]dns.RR, 0)
|
||||
r_extra := make([]dns.RR, 0)
|
||||
|
||||
answ := make([]dns.RR, 0, len(resp.Answer)+len(resp.Extra))
|
||||
answ = append(answ, resp.Answer...)
|
||||
answ = append(answ, resp.Extra...)
|
||||
|
||||
last_qname := qname
|
||||
needle := qname
|
||||
name_seen := make(map[string]bool)
|
||||
name_seen[needle] = true
|
||||
|
||||
for range dnsCnameWalkLimit {
|
||||
i_cnames := make([]int, 0, len(answ))
|
||||
i_addrs := make([]int, 0, len(answ))
|
||||
for i := range answ {
|
||||
switch answ[i].Header().Rrtype {
|
||||
case dns.TypeCNAME:
|
||||
i_cnames = append(i_cnames, i)
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
i_addrs = append(i_addrs, i)
|
||||
}
|
||||
}
|
||||
|
||||
if len(i_cnames) > 0 {
|
||||
seen := false
|
||||
cname_dive := true
|
||||
for {
|
||||
if !cname_dive {
|
||||
break
|
||||
}
|
||||
cname_dive = false
|
||||
|
||||
for _, i := range i_cnames {
|
||||
if answ[i].Header().Name != needle {
|
||||
continue
|
||||
}
|
||||
|
||||
cname := answ[i].(*dns.CNAME)
|
||||
needle = cname.Target
|
||||
_, seen = name_seen[needle]
|
||||
if seen {
|
||||
// CNAME loop?
|
||||
log.Printf("CNAME loop: %v -> %v", qname, needle)
|
||||
return []PowerDnsAnswer{}, nil
|
||||
}
|
||||
name_seen[needle] = true
|
||||
cname_dive = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, i := range i_addrs {
|
||||
if answ[i].Header().Name != needle {
|
||||
continue
|
||||
}
|
||||
found = true
|
||||
r_main = append(r_main, answ[i])
|
||||
}
|
||||
if found {
|
||||
for i := range answ {
|
||||
if !dnsIsAllowedExtraQtype(answ[i].Header().Rrtype) {
|
||||
continue
|
||||
}
|
||||
// if answ[i].Header().Name != needle {
|
||||
// continue
|
||||
// }
|
||||
|
||||
r_extra = append(r_extra, answ[i])
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// trim
|
||||
answ = answ[:0]
|
||||
|
||||
if (needle == last_qname) && (qtype == dns.TypeANY) {
|
||||
answ = dnsRemapAnyFallback(needle)
|
||||
continue
|
||||
}
|
||||
|
||||
last_qname = needle
|
||||
resp, err := dnsCustomResolve(needle, qtype)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if resp == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if len(resp.Answer) != 0 {
|
||||
for i := range resp.Answer {
|
||||
answ = append(answ, dns.Copy(resp.Answer[i]))
|
||||
}
|
||||
for i := range resp.Extra {
|
||||
answ = append(answ, dns.Copy(resp.Extra[i]))
|
||||
}
|
||||
resp = nil
|
||||
continue
|
||||
}
|
||||
|
||||
if qtype != dns.TypeANY {
|
||||
break
|
||||
}
|
||||
|
||||
answ = dnsRemapAnyFallback(needle)
|
||||
}
|
||||
|
||||
if len(r_main) == 0 {
|
||||
if qname == needle {
|
||||
log.Printf("not resolved fully %v/%v", qname, dns.TypeToString[qtype])
|
||||
} else {
|
||||
log.Printf("not resolved fully %v/%v (stuck at %v)", qname, dns.TypeToString[qtype], needle)
|
||||
}
|
||||
return []PowerDnsAnswer{}, nil
|
||||
}
|
||||
|
||||
interim := make([]DnsAnswer, 0, len(r_main))
|
||||
for i := range r_main {
|
||||
t := r_main[i].Header().Rrtype
|
||||
r := DnsAnswer{
|
||||
Qname: qname,
|
||||
Qtype: t,
|
||||
Ttl: r_main[i].Header().Ttl,
|
||||
}
|
||||
|
||||
switch t {
|
||||
case dns.TypeA:
|
||||
r.AddrLen = net.IPv4len
|
||||
r.Addr = make([]byte, net.IPv4len)
|
||||
copy(r.Addr, r_main[i].(*dns.A).A)
|
||||
case dns.TypeAAAA:
|
||||
r.AddrLen = net.IPv6len
|
||||
r.Addr = make([]byte, net.IPv6len)
|
||||
copy(r.Addr, r_main[i].(*dns.AAAA).AAAA)
|
||||
}
|
||||
|
||||
interim = append(interim, r)
|
||||
}
|
||||
|
||||
// unify/adjust TTL
|
||||
var ttl uint32
|
||||
if cfgTtlFuzzy {
|
||||
ttl = dnsFuzzClipTtl()
|
||||
} else {
|
||||
ttl = cfgTtlMax
|
||||
for i := range interim {
|
||||
if ttl > interim[i].Ttl {
|
||||
ttl = interim[i].Ttl
|
||||
}
|
||||
}
|
||||
ttl = dnsClipTtl(ttl)
|
||||
}
|
||||
|
||||
unix_start := time.Unix(0, 0)
|
||||
nft_ipv4 := make([]nft.SetElement, 0, len(interim))
|
||||
nft_ipv6 := make([]nft.SetElement, 0, len(interim))
|
||||
|
||||
// remap addresses in answers and prepare nftables maps
|
||||
for i := range interim {
|
||||
addrlen := interim[i].AddrLen
|
||||
var srcAddr net.IP = make([]byte, addrlen)
|
||||
copy(srcAddr, interim[i].Addr)
|
||||
|
||||
var cidr *net.IPNet = nil
|
||||
switch addrlen {
|
||||
case net.IPv4len:
|
||||
if cfgCidrV4 != nil {
|
||||
cidr = cfgCidrV4
|
||||
}
|
||||
case net.IPv6len:
|
||||
if cfgCidrV6 != nil {
|
||||
cidr = cfgCidrV6
|
||||
}
|
||||
}
|
||||
if cidr == nil {
|
||||
// no need to remap or add to nftables
|
||||
continue
|
||||
}
|
||||
|
||||
dstAddr, nft_ttl := addrMapGet(srcAddr, cidr, ttl)
|
||||
// HACK: replace addr
|
||||
copy(interim[i].Addr, dstAddr)
|
||||
|
||||
if !cfgWithNft {
|
||||
continue
|
||||
}
|
||||
|
||||
elem := nft.SetElement{
|
||||
Key: []byte(dstAddr),
|
||||
Val: []byte(srcAddr),
|
||||
// Timeout: time.Duration(nft_ttl),
|
||||
Timeout: time.Unix(int64(nft_ttl), 0).Add(nftTimeoutSkew).Sub(unix_start).Round(time.Millisecond),
|
||||
}
|
||||
|
||||
switch addrlen {
|
||||
case net.IPv4len:
|
||||
nft_ipv4 = append(nft_ipv4, elem)
|
||||
case net.IPv6len:
|
||||
nft_ipv6 = append(nft_ipv6, elem)
|
||||
}
|
||||
}
|
||||
|
||||
// perform nftables assignment
|
||||
if (len(nft_ipv4) > 0) && (cfgNftMapV4 != "") {
|
||||
nftDoWithMap(cfgNftTable, cfgNftTableFamily, cfgNftMapV4, func(c *nft.Conn, t *nft.Table, m *nft.Set) error {
|
||||
_ = c.SetDeleteElements(m, nft_ipv4)
|
||||
return nil
|
||||
})
|
||||
nftDoWithMap(cfgNftTable, cfgNftTableFamily, cfgNftMapV4, func(c *nft.Conn, t *nft.Table, m *nft.Set) error {
|
||||
return c.SetAddElements(m, nft_ipv4)
|
||||
})
|
||||
}
|
||||
if (len(nft_ipv6) > 0) && (cfgNftMapV6 != "") {
|
||||
nftDoWithMap(cfgNftTable, cfgNftTableFamily, cfgNftMapV6, func(c *nft.Conn, t *nft.Table, m *nft.Set) error {
|
||||
_ = c.SetDeleteElements(m, nft_ipv6)
|
||||
return nil
|
||||
})
|
||||
nftDoWithMap(cfgNftTable, cfgNftTableFamily, cfgNftMapV6, func(c *nft.Conn, t *nft.Table, m *nft.Set) error {
|
||||
return c.SetAddElements(m, nft_ipv6)
|
||||
})
|
||||
}
|
||||
|
||||
result := make([]PowerDnsAnswer, 0)
|
||||
for i := range interim {
|
||||
r := PowerDnsAnswer{
|
||||
Qname: interim[i].Qname,
|
||||
Qtype: dns.TypeToString[interim[i].Qtype],
|
||||
Ttl: ttl,
|
||||
Content: interim[i].Addr.String(),
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
|
||||
// extra records (if any)
|
||||
for i := range r_extra {
|
||||
r, err := dnsRrToPowerDnsAnswer(r_extra[i])
|
||||
if err != nil {
|
||||
log.Printf("%v", err)
|
||||
continue
|
||||
}
|
||||
result = append(result, r)
|
||||
}
|
||||
|
||||
// adjust results
|
||||
for i := range result {
|
||||
if result[i].Qname == needle {
|
||||
result[i].Qname = qname
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func dnsRemapAnyFallback(qname string) []dns.RR {
|
||||
var wg sync.WaitGroup
|
||||
var r_a, r_aaaa []dns.RR
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resp, err := dnsCustomResolve(qname, dns.TypeA)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
if len(resp.Answer) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r_a = make([]dns.RR, 0, len(resp.Answer))
|
||||
for i := range resp.Answer {
|
||||
r_a = append(r_a, dns.Copy(resp.Answer[i]))
|
||||
}
|
||||
for i := range resp.Extra {
|
||||
r_a = append(r_a, dns.Copy(resp.Extra[i]))
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resp, err := dnsCustomResolve(qname, dns.TypeAAAA)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
if len(resp.Answer) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r_aaaa = make([]dns.RR, 0, len(resp.Answer))
|
||||
for i := range resp.Answer {
|
||||
r_aaaa = append(r_aaaa, dns.Copy(resp.Answer[i]))
|
||||
}
|
||||
for i := range resp.Extra {
|
||||
r_aaaa = append(r_aaaa, dns.Copy(resp.Extra[i]))
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
answ := make([]dns.RR, 0, len(r_a)+len(r_aaaa))
|
||||
// TODO: very naive (no unique record checks)
|
||||
if r_a != nil {
|
||||
answ = append(answ, r_a...)
|
||||
}
|
||||
if r_aaaa != nil {
|
||||
answ = append(answ, r_aaaa...)
|
||||
}
|
||||
|
||||
return answ
|
||||
}
|
Reference in New Issue
Block a user