package main

import (
	"net"
	"strings"
	"sync"
	"time"

	nft "github.com/google/nftables"
	"github.com/miekg/dns"
)

// DnsAnswer is an intermediate A/AAAA answer collected by dnsRemap before its
// address is rewritten and it is converted into a PowerDnsAnswer.
type DnsAnswer struct {
	// Qname is the name the answer is reported under. dnsRemap sets it to the
	// original query name even when the record was reached through a CNAME.
	Qname string
	// Qtype is the record type (dnsRemap only produces dns.TypeA / dns.TypeAAAA).
	Qtype uint16
	// Ttl is the record TTL in seconds.
	Ttl uint32
	// AddrLen is the byte length of Addr: net.IPv4len or net.IPv6len.
	AddrLen uint32
	// Addr is the address, in its AddrLen-byte form.
	Addr net.IP
}

// dnsRemap converts the A/AAAA answers for qname in orig into PowerDNS
// answers whose addresses are rewritten by addrMap (into nftCidrV4 /
// nftCidrV6). For every rewritten address it installs a
// mapped-address -> real-address element into the nftables map cfgNftMapV4
// or cfgNftMapV6, with a timeout equal to the clipped TTL. The answers report
// that same clipped TTL, so clients never cache a mapped address longer than
// its mapping lives.
//
// Records reached through CNAME chains inside orig are reported under qname.
// A CNAME target that owns no record in orig is looked up via
// dnsCustomResolve. qtype is currently unused. orig is not modified. The
// returned slice is never nil, and the returned error is currently always nil.
func dnsRemap(qname string, qtype uint16, orig *dns.Msg) ([]PowerDnsAnswer, error) {
	if orig == nil {
		return make([]PowerDnsAnswer, 0), nil
	}

	interim, dangling := remapCollectChain(nil, orig.Answer, qname, qname)
	interim = append(interim, remapResolveDangling(qname, dangling)...)

	// Non-nil even when empty, matching the previous behaviour (e.g. for JSON).
	result := make([]PowerDnsAnswer, 0, len(interim))
	if len(interim) == 0 {
		return result, nil
	}

	nftIPv4 := make([]nft.SetElement, 0, len(interim))
	nftIPv6 := make([]nft.SetElement, 0, len(interim))
	for i := range interim {
		// Work through a pointer so the TTL clip and the address rewrite below
		// are visible when the answers are built at the end.
		r := &interim[i]

		var dstAddr net.IP
		switch r.AddrLen {
		case net.IPv4len:
			dstAddr = addrMap(r.Addr, nftCidrV4)
		case net.IPv6len:
			dstAddr = addrMap(r.Addr, nftCidrV6)
		default:
			continue
		}

		// HACK: clip ttl. It is used both as the nftables timeout and as the TTL
		// handed back to clients, so the two cannot disagree.
		r.Ttl = dnsClipTtl(r.Ttl)

		if len(dstAddr) == 0 {
			// No mapped address: keep answering with the real one and install nothing.
			continue
		}

		// Own copy of the real address for the map value.
		srcAddr := make(net.IP, len(r.Addr))
		copy(srcAddr, r.Addr)

		// HACK: replace addr. Assign instead of copying into r.Addr, whose
		// backing array belongs to the records in orig.
		r.Addr = dstAddr

		elem := nft.SetElement{
			Key:     []byte(dstAddr),
			Val:     []byte(srcAddr),
			Timeout: time.Duration(r.Ttl) * time.Second,
		}
		if r.AddrLen == net.IPv4len {
			nftIPv4 = append(nftIPv4, elem)
		} else {
			nftIPv6 = append(nftIPv6, elem)
		}
	}

	remapReplaceNftMap(cfgNftMapV4, nftIPv4)
	remapReplaceNftMap(cfgNftMapV6, nftIPv6)

	for _, a := range interim {
		typ, ok := dns.TypeToString[a.Qtype]
		if !ok {
			continue
		}
		result = append(result, PowerDnsAnswer{
			Qname:   a.Qname,
			Qtype:   typ,
			Ttl:     a.Ttl,
			Content: a.Addr.String(),
		})
	}
	return result, nil
}

// remapAnswerFromRR converts an A or AAAA record into a DnsAnswer reported
// under qname. ok is false for any other record type, or when the address
// cannot be represented in the record's native length.
func remapAnswerFromRR(qname string, rr dns.RR) (DnsAnswer, bool) {
	ans := DnsAnswer{Qname: qname, Ttl: rr.Header().Ttl}
	// A type switch avoids panicking when the header type and concrete type disagree.
	switch v := rr.(type) {
	case *dns.A:
		ans.Qtype, ans.AddrLen, ans.Addr = dns.TypeA, net.IPv4len, v.A.To4()
	case *dns.AAAA:
		ans.Qtype, ans.AddrLen, ans.Addr = dns.TypeAAAA, net.IPv6len, v.AAAA.To16()
	default:
		return DnsAnswer{}, false
	}
	if ans.Addr == nil {
		return DnsAnswer{}, false
	}
	return ans, true
}

// remapCollectChain appends to dst every A/AAAA record in answers that is
// owned by name, or by any name reachable from name through CNAME records in
// answers, and reports each one under qname.
//
// Owner names are compared case-insensitively, since DNS names are
// case-insensitive. Each name is visited once, so CNAME loops terminate. It
// also returns, in visiting order, the CNAME targets (never name itself) that
// own no record at all in answers and therefore need a separate lookup.
func remapCollectChain(dst []DnsAnswer, answers []dns.RR, qname, name string) ([]DnsAnswer, []string) {
	var dangling []string
	seen := make(map[string]struct{})
	queue := []string{name}
	for len(queue) > 0 {
		cur := queue[0]
		queue = queue[1:]

		key := strings.ToLower(cur)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}

		owned := false
		for _, rr := range answers {
			if !strings.EqualFold(rr.Header().Name, cur) {
				continue
			}
			owned = true
			if cname, isCNAME := rr.(*dns.CNAME); isCNAME {
				queue = append(queue, cname.Target)
				continue
			}
			if ans, ok := remapAnswerFromRR(qname, rr); ok {
				dst = append(dst, ans)
			}
		}
		// Only the start name can equal name here; it is never re-resolved.
		if !owned && cur != name {
			dangling = append(dangling, cur)
		}
	}
	return dst, dangling
}

// remapResolveDangling looks up each target with dnsCustomResolve, in
// parallel. It returns the A/AAAA records found (following CNAMEs inside each
// response), reported under qname and ordered by targets. Lookups that fail
// contribute nothing. Dangling names inside a response are not resolved again.
func remapResolveDangling(qname string, targets []string) []DnsAnswer {
	// One slot per target: goroutines never share a slot, and the final order
	// does not depend on scheduling.
	perTarget := make([][]DnsAnswer, len(targets))
	var wg sync.WaitGroup
	for i, target := range targets {
		wg.Add(1)
		go func(i int, target string) {
			defer wg.Done()
			// NOTE(review): many resolvers answer ANY minimally (RFC 8482);
			// separate A and AAAA queries may be more reliable.
			resp, err := dnsCustomResolve(target, dns.TypeANY)
			if err != nil || resp == nil {
				return
			}
			perTarget[i], _ = remapCollectChain(nil, resp.Answer, qname, target)
		}(i, target)
	}
	wg.Wait()

	var out []DnsAnswer
	for _, answers := range perTarget {
		out = append(out, answers...)
	}
	return out
}

// remapReplaceNftMap replaces elems in the nftables map mapName of the
// configured table. Elements with the same keys are first deleted in one
// nftDoWithTable call, then elems are added in a second call. It does nothing
// when elems is empty.
func remapReplaceNftMap(mapName string, elems []nft.SetElement) {
	if len(elems) == 0 {
		return
	}
	// The delete gets its own call, and its outcome is deliberately ignored.
	// Presumably deleting a key that is absent from the map fails the whole
	// batch, and that must not block the add below (TODO confirm).
	nftDoWithTable(cfgNftTable, cfgNftTableFamily, func(c *nft.Conn, t *nft.Table) error {
		m, err := nftGetMapByName(c, t, mapName)
		if err != nil {
			return err
		}
		_ = c.SetDeleteElements(m, elems)
		return nil
	})
	// NOTE(review): the outcome of the add is ignored as well, so callers
	// still hand out mapped addresses if installing the mapping failed.
	nftDoWithTable(cfgNftTable, cfgNftTableFamily, func(c *nft.Conn, t *nft.Table) error {
		m, err := nftGetMapByName(c, t, mapName)
		if err != nil {
			return err
		}
		return c.SetAddElements(m, elems)
	})
}