diff --git a/dns-remap.go b/dns-remap.go index 5a65612..96731f3 100644 --- a/dns-remap.go +++ b/dns-remap.go @@ -18,117 +18,114 @@ type DnsAnswer struct { } func dnsRemap(qname string, qtype uint16, orig *dns.Msg) ([]PowerDnsAnswer, error) { + result := make([]PowerDnsAnswer, 0) interim := make([]DnsAnswer, 0, len(orig.Answer)) - real_qnames := make([]string, 0) - for _, rr := range orig.Answer { - if rr.Header().Name != qname { + needle := qname + name_seen := map[string]bool{ + needle: true, + } + cname_dive := true + for { + if !cname_dive { + break + } + cname_dive = false + + for _, rr := range orig.Answer { + if rr.Header().Name != needle { + continue + } + if rr.Header().Rrtype != dns.TypeCNAME { + continue + } + + cname_dive = true + cname := rr.(*dns.CNAME) + _, seen := name_seen[cname.Target] + if seen { + // CNAME loop? + return []PowerDnsAnswer{}, nil + } + + needle = cname.Target + name_seen[needle] = true + break + } + if cname_dive { continue } - r := DnsAnswer{ - Qname: qname, - Qtype: rr.Header().Rrtype, - Ttl: rr.Header().Ttl, - } - switch r.Qtype { - case dns.TypeA: - r.Addr = rr.(*dns.A).A - r.AddrLen = net.IPv4len + for _, rr := range orig.Answer { + if rr.Header().Name != needle { + continue + } + + t := rr.Header().Rrtype + switch t { + case dns.TypeA, dns.TypeAAAA: + // continue below + default: + continue + } + + r := DnsAnswer{ + Qname: qname, + Qtype: t, + Ttl: rr.Header().Ttl, + } + switch r.Qtype { + case dns.TypeA: + r.Addr = rr.(*dns.A).A + r.AddrLen = net.IPv4len + case dns.TypeAAAA: + r.Addr = rr.(*dns.AAAA).AAAA + r.AddrLen = net.IPv6len + } interim = append(interim, r) - case dns.TypeAAAA: - r.Addr = rr.(*dns.AAAA).AAAA - r.AddrLen = net.IPv6len - interim = append(interim, r) - case dns.TypeCNAME: - cname := rr.(*dns.CNAME) - real_qnames = append(real_qnames, cname.Target) } } - var wg sync.WaitGroup - var mtx_interim sync.Mutex - // reprocess answers due to CNAME - for _, real_qname := range real_qnames { - wg.Add(1) - go func(real_name string) { - defer wg.Done() + // fix missing A/AAAA records + if (len(interim) == 0) && ((needle != qname) || (qtype == dns.TypeANY)) { + var wg sync.WaitGroup + var a, aaaa []PowerDnsAnswer - found_qname := false - for _, rr := range orig.Answer { - if rr.Header().Name != real_name { - continue + if qtype != dns.TypeAAAA { + wg.Add(1) + go func() { + defer wg.Done() + x, _ := dnsApi_lookup_int(needle, dns.TypeA) + if x != nil { + a = x.([]PowerDnsAnswer) } - found_qname = true - - r := DnsAnswer{ - Qname: qname, - Qtype: rr.Header().Rrtype, - Ttl: rr.Header().Ttl, + }() + } + if qtype != dns.TypeA { + wg.Add(1) + go func() { + defer wg.Done() + x, _ := dnsApi_lookup_int(needle, dns.TypeAAAA) + if x != nil { + aaaa = x.([]PowerDnsAnswer) } - switch r.Qtype { - case dns.TypeA: - r.Addr = rr.(*dns.A).A - r.AddrLen = net.IPv4len + }() + } - mtx_interim.Lock() - interim = append(interim, r) - mtx_interim.Unlock() - case dns.TypeAAAA: - r.Addr = rr.(*dns.AAAA).AAAA - r.AddrLen = net.IPv6len + wg.Wait() - mtx_interim.Lock() - interim = append(interim, r) - mtx_interim.Unlock() - } - } - if found_qname { - return - } + if a != nil { + result = append(result, a...) + } + if aaaa != nil { + result = append(result, aaaa...) + } - resp, err := dnsCustomResolve(real_name, dns.TypeANY) - if err != nil { - return - } - if resp == nil { - return - } + // HACK: replace qname + for i := range result { + result[i].Qname = qname + } - for _, rr := range resp.Answer { - if rr.Header().Name != real_name { - continue - } - - r := DnsAnswer{ - Qname: qname, - Qtype: rr.Header().Rrtype, - Ttl: rr.Header().Ttl, - } - switch r.Qtype { - case dns.TypeA: - r.Addr = rr.(*dns.A).A - r.AddrLen = net.IPv4len - - mtx_interim.Lock() - interim = append(interim, r) - mtx_interim.Unlock() - case dns.TypeAAAA: - r.Addr = rr.(*dns.AAAA).AAAA - r.AddrLen = net.IPv6len - - mtx_interim.Lock() - interim = append(interim, r) - mtx_interim.Unlock() - } - } - }(real_qname) - } - wg.Wait() - - result := make([]PowerDnsAnswer, 0, len(interim)) - // nothing to do - if len(interim) == 0 { return result, nil } @@ -139,7 +136,6 @@ func dnsRemap(qname string, qtype uint16, orig *dns.Msg) ([]PowerDnsAnswer, erro for _, r := range interim { switch r.AddrLen { case net.IPv4len, net.IPv6len: - break default: continue }