package main import ( "net" "sync" "time" nft "github.com/google/nftables" "github.com/miekg/dns" ) type DnsAnswer struct { Qname string Qtype uint16 Ttl uint32 AddrLen uint32 Addr net.IP } func dnsRemap(qname string, qtype uint16, orig *dns.Msg) ([]PowerDnsAnswer, error) { result := make([]PowerDnsAnswer, 0) interim := make([]DnsAnswer, 0, len(orig.Answer)) needle := qname name_seen := map[string]bool{ needle: true, } cname_dive := true for { if !cname_dive { break } cname_dive = false for _, rr := range orig.Answer { if rr.Header().Name != needle { continue } if rr.Header().Rrtype != dns.TypeCNAME { continue } cname_dive = true cname := rr.(*dns.CNAME) _, seen := name_seen[cname.Target] if seen { // CNAME loop? return []PowerDnsAnswer{}, nil } needle = cname.Target name_seen[needle] = true break } if cname_dive { continue } for _, rr := range orig.Answer { if rr.Header().Name != needle { continue } t := rr.Header().Rrtype switch t { case dns.TypeA, dns.TypeAAAA: // continue below default: continue } r := DnsAnswer{ Qname: qname, Qtype: t, Ttl: dnsClipTtl(rr.Header().Ttl), } switch r.Qtype { case dns.TypeA: r.Addr = rr.(*dns.A).A r.AddrLen = net.IPv4len case dns.TypeAAAA: r.Addr = rr.(*dns.AAAA).AAAA r.AddrLen = net.IPv6len } interim = append(interim, r) } } // fix missing A/AAAA records if (len(interim) == 0) && ((needle != qname) || (qtype == dns.TypeANY)) { var wg sync.WaitGroup var a, aaaa []PowerDnsAnswer if qtype != dns.TypeAAAA { wg.Add(1) go func() { defer wg.Done() x, _ := dnsApi_lookup_int(needle, dns.TypeA) if x != nil { a = x.([]PowerDnsAnswer) } }() } if qtype != dns.TypeA { wg.Add(1) go func() { defer wg.Done() x, _ := dnsApi_lookup_int(needle, dns.TypeAAAA) if x != nil { aaaa = x.([]PowerDnsAnswer) } }() } wg.Wait() if a != nil { result = append(result, a...) } if aaaa != nil { result = append(result, aaaa...) } // HACK: replace qname for i := range result { result[i].Qname = qname } return result, nil } unix_start := time.Unix(0, 0) nft_ipv4 := make([]nft.SetElement, 0, len(interim)) nft_ipv6 := make([]nft.SetElement, 0, len(interim)) // fill map elements for _, r := range interim { switch r.AddrLen { case net.IPv4len, net.IPv6len: default: continue } // HACK: clip ttl r.Ttl = dnsClipTtl(r.Ttl) var srcAddr net.IP = make([]byte, r.AddrLen) copy(srcAddr, r.Addr) var dstAddr net.IP switch r.AddrLen { case net.IPv4len: dstAddr = addrMapGet(r.Addr, nftCidrV4, r.Ttl) case net.IPv6len: dstAddr = addrMapGet(r.Addr, nftCidrV6, r.Ttl) } // HACK: replace addr copy(r.Addr, dstAddr) elem := nft.SetElement{ Key: []byte(dstAddr), Val: []byte(srcAddr), // Timeout: time.Duration(r.Ttl), Timeout: time.Unix(int64(r.Ttl), 0).Sub(unix_start), } switch r.AddrLen { case net.IPv4len: nft_ipv4 = append(nft_ipv4, elem) case net.IPv6len: nft_ipv6 = append(nft_ipv6, elem) } } // perform nftables assignment if len(nft_ipv4) > 0 { nftDoWithTable(cfgNftTable, cfgNftTableFamily, func(c *nft.Conn, t *nft.Table) error { m, err := nftGetMapByName(c, t, cfgNftMapV4) if err != nil { return err } _ = c.SetDeleteElements(m, nft_ipv4) return nil }) nftDoWithTable(cfgNftTable, cfgNftTableFamily, func(c *nft.Conn, t *nft.Table) error { m, err := nftGetMapByName(c, t, cfgNftMapV4) if err != nil { return err } return c.SetAddElements(m, nft_ipv4) }) } if len(nft_ipv6) > 0 { nftDoWithTable(cfgNftTable, cfgNftTableFamily, func(c *nft.Conn, t *nft.Table) error { m, err := nftGetMapByName(c, t, cfgNftMapV6) if err != nil { return err } _ = c.SetDeleteElements(m, nft_ipv6) return nil }) nftDoWithTable(cfgNftTable, cfgNftTableFamily, func(c *nft.Conn, t *nft.Table) error { m, err := nftGetMapByName(c, t, cfgNftMapV6) if err != nil { return err } return c.SetAddElements(m, nft_ipv6) }) } for _, i := range interim { t, ok := dns.TypeToString[i.Qtype] if !ok { continue } r := PowerDnsAnswer{ Qname: i.Qname, Qtype: t, Ttl: i.Ttl, Content: i.Addr.String(), } result = append(result, r) } return result, nil }