From 0faa8ec86a28efafb475a08066ff13b39fcaab66 Mon Sep 17 00:00:00 2001 From: Konstantin Demin Date: Thu, 6 Jun 2024 07:44:26 +0300 Subject: [PATCH] rework address mapping --- addr-map.go | 165 ++++++++++++++++++++++++++++++------- cfg.go | 4 +- dns-remap.go | 9 +- example-conf/nftables.conf | 4 +- main.go | 1 + 5 files changed, 145 insertions(+), 38 deletions(-) diff --git a/addr-map.go b/addr-map.go index 603d81e..762ed28 100644 --- a/addr-map.go +++ b/addr-map.go @@ -1,43 +1,148 @@ package main import ( + "crypto/rand" "encoding/binary" "log" + "math" "net" + "sync" + "time" "github.com/cespare/xxhash/v2" ) -// TODO: write more convenient version -func addrMap(srcIp net.IP, dstCidr *net.IPNet) net.IP { - addrlen := len(srcIp) - if addrlen != len(dstCidr.IP) { - log.Fatalf("addrMap(): src/dst size mismatch: %v vs %v", len(srcIp), len(dstCidr.IP)) - } +const ( + addrMapHouseKeepInterval = time.Second / 2 +) - var addr net.IP = make([]byte, addrlen) +var ( + addr4, addr6 sync.Map +) - hsum := xxhash.Sum64(srcIp) - bsum := binary.NativeEndian.AppendUint64([]byte{}, hsum) - blen := len(bsum) - if addrlen > blen { - // extend bsum - tmp := append(make([]byte, addrlen-blen), bsum...) - bsum = tmp - } - if addrlen < blen { - // trim bsum - bsum = bsum[:addrlen] - } - - // apply inverted mask to bsum and sum with addr - for i := range addrlen / 4 { - a := binary.NativeEndian.Uint32(dstCidr.IP[i*4:]) - b := binary.NativeEndian.Uint32(bsum[i*4:]) - m := binary.NativeEndian.Uint32(dstCidr.Mask[i*4:]) - a += (b & ^m) - binary.NativeEndian.PutUint32(addr[i*4:], a) - } - - return addr +type AddrMap struct { + SrcAddr net.IP + DstAddr net.IP + Ttl uint32 + Created time.Time +} + +func (a *AddrMap) GetTtl() int32 { + x := math.Trunc(time.Since(a.Created).Round(addrMapHouseKeepInterval).Seconds()) + return int32(a.Ttl) - int32(x) +} + +func setupAddrMapHousekeeping() { + go func() { + for { + time.Sleep(addrMapHouseKeepInterval) + addr4.Range(func(key, value any) bool { + a, ok := value.(AddrMap) + if ok { + if a.GetTtl() > 1 { + return true + } + } + // delete if value is bogus or if ttl is less than second + addr4.Delete(key) + return true + }) + } + }() + go func() { + for { + time.Sleep(addrMapHouseKeepInterval) + addr6.Range(func(key, value any) bool { + a, ok := value.(AddrMap) + if ok { + if a.GetTtl() > 1 { + return true + } + } + // delete if value is bogus or if ttl is less than second + addr6.Delete(key) + return true + }) + } + }() +} + +func addrMapGet(srcIp net.IP, dstCidr *net.IPNet, ttl uint32) net.IP { + addrlen := len(srcIp) + switch addrlen { + case net.IPv4len, net.IPv6len: + default: + log.Fatalf("addrMapGet(): src size mismatch: %v", addrlen) + } + if addrlen != len(dstCidr.IP) { + log.Fatalf("addrMapGet(): src/dst size mismatch: %v vs %v", addrlen, len(dstCidr.IP)) + } + if addrlen != len(dstCidr.Mask) { + log.Fatalf("addrMapGet(): src/dst size mismatch: %v vs %v", addrlen, len(dstCidr.IP)) + } + + var curr AddrMap + curr.SrcAddr = make([]byte, addrlen) + curr.DstAddr = make([]byte, addrlen) + copy(curr.DstAddr, srcIp) + curr.Ttl = ttl + + for { + _, err := rand.Read(curr.SrcAddr) + if err != nil { + log.Fatalf("rand.Read(): error %v", err) + } + + // adjust random bytes to dstCidr + for i := range addrlen / 4 { + a := binary.NativeEndian.Uint32(dstCidr.IP[i*4:]) + b := binary.NativeEndian.Uint32(curr.SrcAddr[i*4:]) + m := binary.NativeEndian.Uint32(dstCidr.Mask[i*4:]) + a += (b & ^m) + binary.NativeEndian.PutUint32(curr.SrcAddr[i*4:], a) + } + hsum := xxhash.Sum64(curr.SrcAddr) + + curr.Created = time.Now() + + var xprev any + var loaded bool + switch addrlen { + case net.IPv4len: + xprev, loaded = addr4.LoadOrStore(hsum, curr) + case net.IPv6len: + xprev, loaded = addr6.LoadOrStore(hsum, curr) + } + if !loaded { + // early return + return curr.SrcAddr + } + + prev, ok := xprev.(AddrMap) + if !ok { + log.Fatalf("addrMapGet(): wrong value type from sync.Map") + } + + if !net.IP.Equal(curr.SrcAddr, prev.SrcAddr) { + // generate next random address + continue + } + if !net.IP.Equal(curr.DstAddr, prev.DstAddr) { + // generate next random address + continue + } + + if prev.GetTtl() < int32(curr.Ttl) { + switch addrlen { + case net.IPv4len: + addr4.Store(hsum, curr) + case net.IPv6len: + addr6.Store(hsum, curr) + } + } + + break + } + + return curr.SrcAddr } diff --git a/cfg.go b/cfg.go index d75a96e..078f487 100644 --- a/cfg.go +++ b/cfg.go @@ -20,8 +20,8 @@ const ( cfgNftTableFamily = nft.TableFamilyINet cfgNftMapV4 = "tele4" cfgNftMapV6 = "tele6" - cfgNftCidrV4 = "251.0.0.0/8" - cfgNftCidrV6 = "2001:db8:11::/48" + cfgNftCidrV4 = "198.18.0.0/15" + cfgNftCidrV6 = "2001:db8:11::/80" cfgSoaNs = "gw.vpn." cfgSoaMbox = "dns.gw.vpn." diff --git a/dns-remap.go b/dns-remap.go index 96731f3..7cd20b8 100644 --- a/dns-remap.go +++ b/dns-remap.go @@ -140,18 +140,19 @@ func dnsRemap(qname string, qtype uint16, orig *dns.Msg) ([]PowerDnsAnswer, erro continue } + // HACK: clip ttl + r.Ttl = dnsClipTtl(r.Ttl) + var srcAddr net.IP = make([]byte, r.AddrLen) copy(srcAddr, r.Addr) var dstAddr net.IP switch r.AddrLen { case net.IPv4len: - dstAddr = addrMap(r.Addr, nftCidrV4) + dstAddr = addrMapGet(r.Addr, nftCidrV4, r.Ttl) case net.IPv6len: - dstAddr = addrMap(r.Addr, nftCidrV6) + dstAddr = addrMapGet(r.Addr, nftCidrV6, r.Ttl) } - // HACK: clip ttl - r.Ttl = dnsClipTtl(r.Ttl) // HACK: replace addr copy(r.Addr, dstAddr) diff --git a/example-conf/nftables.conf b/example-conf/nftables.conf index 5ccb60e..2d4c74b 100644 --- a/example-conf/nftables.conf +++ b/example-conf/nftables.conf @@ -1,7 +1,7 @@ #!/usr/sbin/nft -f -define n_tele4 = 251.0.0.0/8 -define n_tele6 = 2001:db8:11::/48 +define n_tele4 = 198.18.0.0/15 +define n_tele6 = 2001:db8:11::/80 table inet uni { diff --git a/main.go b/main.go index 9ecf8ed..2ec8903 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ func main() { setupNftables() r := setupGin() + setupAddrMapHousekeeping() log.Printf("%s: ready\n", userAgent) if err := r.Run(cfgListen); err != nil {