177 lines
4.5 KiB
Go
Raw Normal View History

2024-09-14 09:12:10 +03:00
package main
import (
"flag"
"log"
"net"
"strings"
"time"
nft "github.com/google/nftables"
)
// Effective configuration, filled in by init from the command-line flags.
var (
// cfgListen is the addr:port to listen on ("listen" flag).
cfgListen string
// cfgTtlMin and cfgTtlMax are the "ttl-min"/"ttl-max" flag values after
// clipping by flagClipTtl.
cfgTtlMin uint32
cfgTtlMax uint32
// cfgTtlFuzzy is true only when "ttl-fuzz" was requested and the TTL range
// (cfgTtlMax - cfgTtlMin) is greater than 10.
cfgTtlFuzzy bool
// cfgResolverEndpoint is the upstream DNS resolver addr:port
// ("resolver-endpoint" flag).
cfgResolverEndpoint string
// cfgResolverProto is the "resolver-proto" flag normalized by
// flagResolverProtoMap: "" (DNS over UDP), "tcp" or "tcp-tls".
cfgResolverProto string
// cfgResolverTimeout is the "resolver-timeout" flag after clipping by
// flagClipResolverTimeout.
cfgResolverTimeout time.Duration
// cfgWithNft is true when a non-empty nft table name was given.
cfgWithNft bool
// cfgNftTable is the nft table name ("nft-table" flag); empty disables nft.
cfgNftTable string
// cfgNftTableFamily is resolved from the "nft-table-family" flag; it is only
// assigned when cfgWithNft is true.
cfgNftTableFamily nft.TableFamily
// cfgNftMapV4 and cfgNftMapV6 are the nft map names from the "nft-map-ipv4"
// and "nft-map-ipv6" flags; either may be empty.
cfgNftMapV4 string
cfgNftMapV6 string
// cfgCidrV4 and cfgCidrV6 are the networks parsed from the "cidr-ipv4" and
// "cidr-ipv6" flags; each stays nil when its flag is empty.
cfgCidrV4 *net.IPNet
cfgCidrV6 *net.IPNet
// cfgSoaNs and cfgSoaMbox are the fake SOA name server and mailbox ("soa-ns"
// and "soa-mbox" flags). Both are mandatory, and init appends a trailing dot
// when it is missing.
cfgSoaNs string
cfgSoaMbox string
)
// init registers the command-line flags, parses them, and validates and
// normalizes the results into the package-level cfg* variables. Any invalid
// setting terminates the process through log.Fatal/log.Fatalf.
//
// NOTE(review): flag.Parse runs from init, i.e. before main starts — confirm
// that no other package expects to register flags afterwards.
func init() {
	// Raw flag values that are converted or validated before being stored
	// in their cfg* counterparts.
	var _cfgTtlMin, _cfgTtlMax uint
	var _cfgNftTableFamily, _cfgCidrV4, _cfgCidrV6 string
	var _cfgTtlFuzzy bool

	flag.StringVar(&cfgListen, "listen", "127.0.0.1:8080", "listen on addr:port")
	flag.UintVar(&_cfgTtlMin, "ttl-min", 120, "minimum TTL")
	flag.UintVar(&_cfgTtlMax, "ttl-max", 1800, "maximum TTL")
	flag.BoolVar(&_cfgTtlFuzzy, "ttl-fuzz", false, "fuzz TTL in DNS responses")
	flag.StringVar(&cfgResolverEndpoint, "resolver-endpoint", "127.0.0.1:53", "dns resolver addr:port")
	flag.StringVar(&cfgResolverProto, "resolver-proto", "", "dns resolver protocol ('udp' or '' for DNS over UDP, 'tcp' for DNS over TCP, 'tcp-tls' for DNS over TLS)")
	flag.DurationVar(&cfgResolverTimeout, "resolver-timeout", 5*time.Second, "dns resolver timeout")
	flag.StringVar(&_cfgCidrV4, "cidr-ipv4", "", "IPv4 CIDR mapping (e.g. 192.0.2.0/24)")
	flag.StringVar(&_cfgCidrV6, "cidr-ipv6", "", "IPv6 CIDR mapping (e.g. 2001:db8::/64)")
	flag.StringVar(&cfgNftTable, "nft-table", "", "nft table name (e.g. 'fw4'); leave empty to not bother with nft")
	flag.StringVar(&_cfgNftTableFamily, "nft-table-family", "inet", "nft table family (e.g. 'inet')")
	flag.StringVar(&cfgNftMapV4, "nft-map-ipv4", "", "nft IPv4:IPv4 map name")
	flag.StringVar(&cfgNftMapV6, "nft-map-ipv6", "", "nft IPv6:IPv6 map name")
	flag.StringVar(&cfgSoaNs, "soa-ns", "", "fake SOA name server in dotted form (e.g. 'example.org.')")
	flag.StringVar(&cfgSoaMbox, "soa-mbox", "", "fake SOA mailbox in dotted form (e.g. 'dns.example.org.')")
	flag.Parse()

	if _cfgTtlMin > _cfgTtlMax {
		// Report the raw flag values: cfgTtlMin/cfgTtlMax are not assigned
		// until after this check.
		log.Fatalf("invalid ttl range: %d-%d", _cfgTtlMin, _cfgTtlMax)
	}
	cfgTtlMin = flagClipTtl(_cfgTtlMin)
	cfgTtlMax = flagClipTtl(_cfgTtlMax)

	cfgResolverProto = flagResolverProtoMap(cfgResolverProto)
	cfgResolverTimeout = flagClipResolverTimeout(cfgResolverTimeout)

	// nft support is switched on solely by naming a table.
	cfgWithNft = cfgNftTable != ""
	if cfgWithNft {
		cfgNftTableFamily = flagNftTableFamilyMap(_cfgNftTableFamily)

		if cfgNftMapV4 == "" && cfgNftMapV6 == "" {
			log.Fatalf("at least one nft map must be specified")
		}
		// A map without its CIDR is rejected for either address family.
		if cfgNftMapV4 != "" && _cfgCidrV4 == "" {
			log.Fatalf("IPv4: nft map requires CIDR to be specified")
		}
		if cfgNftMapV6 != "" && _cfgCidrV6 == "" {
			log.Fatalf("IPv6: nft map requires CIDR to be specified")
		}
	}

	// CIDRs are optional; an empty flag leaves the corresponding cfgCidr* nil.
	var err error
	if _cfgCidrV4 != "" {
		_, cfgCidrV4, err = net.ParseCIDR(_cfgCidrV4)
		if err != nil {
			log.Fatal(err)
		}
	}
	if _cfgCidrV6 != "" {
		_, cfgCidrV6, err = net.ParseCIDR(_cfgCidrV6)
		if err != nil {
			log.Fatal(err)
		}
	}

	if cfgSoaNs == "" || cfgSoaMbox == "" {
		log.Fatalf("both SOA NS and SOA MBOX must be specified")
	}

	// naive adjustments: make sure both SOA names end with a dot
	if !strings.HasSuffix(cfgSoaNs, ".") {
		cfgSoaNs = cfgSoaNs + "."
	}
	if !strings.HasSuffix(cfgSoaMbox, ".") {
		cfgSoaMbox = cfgSoaMbox + "."
	}

	// TTL fuzzing is only kept enabled when the clipped range is wider than 10.
	dnsTtlRange = cfgTtlMax - cfgTtlMin
	cfgTtlFuzzy = _cfgTtlFuzzy && dnsTtlRange > 10
}
// Bounds used to clip user-supplied flag values.
const (
// _ttlMin and _ttlMax are the inclusive limits enforced by flagClipTtl.
// NOTE(review): presumably expressed in seconds, as DNS TTLs are — confirm.
_ttlMin uint = 30
_ttlMax uint = 86400
// _resolverTimeoutMin and _resolverTimeoutMax are the inclusive limits
// enforced by flagClipResolverTimeout.
_resolverTimeoutMin time.Duration = time.Millisecond
_resolverTimeoutMax time.Duration = 30 * time.Second
)
// flagClipTtl constrains v to the inclusive range [_ttlMin, _ttlMax] and
// returns the result as a uint32.
func flagClipTtl(v uint) uint32 {
	clipped := v
	switch {
	case v < _ttlMin:
		clipped = _ttlMin
	case v > _ttlMax:
		clipped = _ttlMax
	}
	return uint32(clipped)
}
// flagResolverProtoMap normalizes the resolver protocol flag value.
// "udp" and the empty string both map to "", "tcp" and "tcp-tls" are
// returned unchanged, and any other value aborts the program via log.Fatalf.
func flagResolverProtoMap(flag string) string {
	if flag == "" || flag == "udp" {
		return ""
	}
	if flag == "tcp" || flag == "tcp-tls" {
		return flag
	}

	log.Fatalf("invalid resolver proto: %s", flag)
	// unreachable: log.Fatalf exits the process
	return ""
}
// flagClipResolverTimeout constrains v to the inclusive range
// [_resolverTimeoutMin, _resolverTimeoutMax].
func flagClipResolverTimeout(v time.Duration) time.Duration {
	switch {
	case v < _resolverTimeoutMin:
		return _resolverTimeoutMin
	case v > _resolverTimeoutMax:
		return _resolverTimeoutMax
	default:
		return v
	}
}
var (
// nftTableFamilyFromString maps the textual table family accepted by the
// "nft-table-family" flag to its nft.TableFamily constant.
// flagNftTableFamilyMap rejects any name that is not a key of this map.
nftTableFamilyFromString = map[string]nft.TableFamily{
"inet": nft.TableFamilyINet,
"ip": nft.TableFamilyIPv4,
"ip6": nft.TableFamilyIPv6,
"arp": nft.TableFamilyARP,
"netdev": nft.TableFamilyNetdev,
"bridge": nft.TableFamilyBridge,
}
)
// flagNftTableFamilyMap converts a table family name into its
// nft.TableFamily value using nftTableFamilyFromString. An unknown name
// aborts the program via log.Fatalf.
func flagNftTableFamilyMap(flag string) nft.TableFamily {
	family, known := nftTableFamilyFromString[flag]
	if !known {
		log.Fatalf("invalid nft table family: %s", flag)
		// unreachable: log.Fatalf exits the process
		return nft.TableFamilyUnspecified
	}
	return family
}