196 lines
4.3 KiB
Go
196 lines
4.3 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"log"
|
||
|
"net"
|
||
|
"runtime"
|
||
|
|
||
|
nft "github.com/google/nftables"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
nftCidrV4 *net.IPNet
|
||
|
nftCidrV6 *net.IPNet
|
||
|
nftTableFamilyToString = map[nft.TableFamily]string{
|
||
|
nft.TableFamilyINet: "inet",
|
||
|
nft.TableFamilyIPv4: "ip",
|
||
|
nft.TableFamilyIPv6: "ip6",
|
||
|
nft.TableFamilyARP: "arp",
|
||
|
nft.TableFamilyNetdev: "netdev",
|
||
|
nft.TableFamilyBridge: "bridge",
|
||
|
}
|
||
|
)
|
||
|
|
||
|
func setupNftables() {
|
||
|
var net_err error
|
||
|
_, nftCidrV4, net_err = net.ParseCIDR(cfgNftCidrV4)
|
||
|
if net_err != nil {
|
||
|
log.Fatal(net_err)
|
||
|
}
|
||
|
_, nftCidrV6, net_err = net.ParseCIDR(cfgNftCidrV6)
|
||
|
if net_err != nil {
|
||
|
log.Fatal(net_err)
|
||
|
}
|
||
|
|
||
|
nftDoWithTable(cfgNftTable, cfgNftTableFamily, func(c *nft.Conn, t *nft.Table) error {
|
||
|
m4, nf_err := nftGetMapByName(c, t, cfgNftMapV4)
|
||
|
if nf_err != nil {
|
||
|
log.Fatal(nf_err)
|
||
|
}
|
||
|
if m4 == nil {
|
||
|
log.Fatalf("unable to find nft map %#v", cfgNftMapV4)
|
||
|
}
|
||
|
m6, nf_err := nftGetMapByName(c, t, cfgNftMapV6)
|
||
|
if nf_err != nil {
|
||
|
log.Fatal(nf_err)
|
||
|
}
|
||
|
if m6 == nil {
|
||
|
log.Fatalf("unable to find nft map %#v", cfgNftMapV6)
|
||
|
}
|
||
|
|
||
|
c.FlushSet(m4)
|
||
|
c.FlushSet(m6)
|
||
|
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func nftDo(actor func(*nft.Conn) error) error {
|
||
|
if actor == nil {
|
||
|
log.Fatal("nftDo(): actor is nil")
|
||
|
}
|
||
|
|
||
|
runtime.LockOSThread()
|
||
|
defer runtime.UnlockOSThread()
|
||
|
|
||
|
c, err := nft.New()
|
||
|
if err != nil {
|
||
|
log.Printf("nft.New() error: %v", err)
|
||
|
log.Panic(err)
|
||
|
}
|
||
|
if c == nil {
|
||
|
log.Fatal("nft.New() returned nil")
|
||
|
}
|
||
|
|
||
|
err = actor(c)
|
||
|
if err == nil {
|
||
|
err = c.Flush()
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func nftGetTableByName(nftConn *nft.Conn, tableName string, tableFamily nft.TableFamily) (*nft.Table, error) {
|
||
|
var err error
|
||
|
var all []*nft.Table
|
||
|
|
||
|
if tableFamily == nft.TableFamilyUnspecified {
|
||
|
all, err = nftConn.ListTables()
|
||
|
if err != nil {
|
||
|
log.Printf("nft.ListTables: %v", err)
|
||
|
return nil, err
|
||
|
}
|
||
|
if all == nil {
|
||
|
log.Fatal("nft.Conn.ListTables() returned nil")
|
||
|
}
|
||
|
} else {
|
||
|
all, err = nftConn.ListTablesOfFamily(tableFamily)
|
||
|
if err != nil {
|
||
|
log.Printf("nft.ListTablesOfFamily: %v", err)
|
||
|
return nil, err
|
||
|
}
|
||
|
if all == nil {
|
||
|
log.Fatal("nft.Conn.ListTablesOfFamily() returned nil")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var table *nft.Table
|
||
|
for i := range all {
|
||
|
if all[i].Name != tableName {
|
||
|
continue
|
||
|
}
|
||
|
table = new(nft.Table)
|
||
|
*table = *(all[i])
|
||
|
break
|
||
|
}
|
||
|
return table, nil
|
||
|
}
|
||
|
|
||
|
func nftDoWithTable(tableName string, tableFamily nft.TableFamily, actor func(*nft.Conn, *nft.Table) error) error {
|
||
|
if actor == nil {
|
||
|
log.Fatal("nftDoWithTable(): actor is nil")
|
||
|
}
|
||
|
|
||
|
return nftDo(func(c *nft.Conn) error {
|
||
|
t, err := nftGetTableByName(c, tableName, tableFamily)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if t == nil {
|
||
|
if tableFamily == nft.TableFamilyUnspecified {
|
||
|
return fmt.Errorf("unable to find nft table %#v", tableName)
|
||
|
}
|
||
|
|
||
|
family, ok := nftTableFamilyToString[tableFamily]
|
||
|
if ok {
|
||
|
return fmt.Errorf("unable to find nft table %#v (family %v)", tableName, family)
|
||
|
} else {
|
||
|
return fmt.Errorf("unable to find nft table %#v (family id = %v)", tableName, tableFamily)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return actor(c, t)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func nftDoWithSet(tableName string, tableFamily nft.TableFamily, setName string, actor func(*nft.Conn, *nft.Table, *nft.Set) error) error {
|
||
|
if actor == nil {
|
||
|
log.Fatal("nftDoWithSet(): actor is nil")
|
||
|
}
|
||
|
|
||
|
return nftDoWithTable(tableName, tableFamily, func(c *nft.Conn, t *nft.Table) error {
|
||
|
s, err := c.GetSetByName(t, setName)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if s == nil {
|
||
|
return fmt.Errorf("unable to find nft set %#v", setName)
|
||
|
}
|
||
|
|
||
|
return actor(c, t, s)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func nftGetMapByName(nftConn *nft.Conn, nftTable *nft.Table, mapName string) (*nft.Set, error) {
|
||
|
m, err := nftConn.GetSetByName(nftTable, mapName)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if m == nil {
|
||
|
return nil, nil
|
||
|
}
|
||
|
if !m.IsMap {
|
||
|
return nil, fmt.Errorf("nft set %#v is not map", mapName)
|
||
|
}
|
||
|
|
||
|
return m, nil
|
||
|
}
|
||
|
|
||
|
func nftDoWithMap(tableName string, tableFamily nft.TableFamily, mapName string, actor func(*nft.Conn, *nft.Table, *nft.Set) error) error {
|
||
|
if actor == nil {
|
||
|
log.Fatal("nftDoWithMap(): actor is nil")
|
||
|
}
|
||
|
|
||
|
return nftDoWithTable(tableName, tableFamily, func(c *nft.Conn, t *nft.Table) error {
|
||
|
m, err := nftGetMapByName(c, t, mapName)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if m == nil {
|
||
|
return fmt.Errorf("unable to find nft map %#v", mapName)
|
||
|
}
|
||
|
|
||
|
return actor(c, t, m)
|
||
|
})
|
||
|
}
|