2024-09-14 09:12:10 +03:00

198 lines
4.2 KiB
Go

package main
import (
"fmt"
"log"
"runtime"
nft "github.com/google/nftables"
)
// nftTableFamilyToString gives a human-readable name for each known nftables
// table family. It is only used to make "table not found" errors more
// readable; families missing from this map are reported by numeric id.
var nftTableFamilyToString = map[nft.TableFamily]string{
	nft.TableFamilyARP:    "arp",
	nft.TableFamilyBridge: "bridge",
	nft.TableFamilyINet:   "inet",
	nft.TableFamilyIPv4:   "ip",
	nft.TableFamilyIPv6:   "ip6",
	nft.TableFamilyNetdev: "netdev",
}
// setupNftables flushes the nftables maps named by cfgNftMapV4 and
// cfgNftMapV6 (an empty name means "skip that map") in table cfgNftTable of
// family cfgNftTableFamily. It does nothing unless cfgWithNft is set.
//
// Any failure terminates the process via log.Fatal. This covers:
//   - a missing table;
//   - a configured map that is missing or is not a map;
//   - a failed commit of the flush.
//
// Previously the error returned by nftDoWithTable was discarded, so a
// missing table or a Flush failure went unnoticed.
func setupNftables() {
	if !cfgWithNft {
		return
	}
	err := nftDoWithTable(cfgNftTable, cfgNftTableFamily, func(c *nft.Conn, t *nft.Table) error {
		// Resolve every configured map first; a lookup error aborts before any
		// flush is requested on the connection.
		var maps []*nft.Set
		for _, name := range []string{cfgNftMapV4, cfgNftMapV6} {
			if name == "" {
				continue
			}
			m, err := nftGetMapByName(c, t, name)
			if err != nil {
				return err
			}
			if m == nil {
				return fmt.Errorf("unable to find nft map %#v", name)
			}
			maps = append(maps, m)
		}
		for _, m := range maps {
			c.FlushSet(m)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
// nftDo opens a new nftables connection, runs actor on it and, only if actor
// returns nil, calls c.Flush() on the connection.
//
// The calling goroutine stays locked to its OS thread for the whole call.
// NOTE(review): presumably because netlink / network-namespace state is
// per-thread — confirm.
//
// A nil actor aborts via log.Fatal, and a failure to create the connection
// aborts via log.Panicf. Otherwise nftDo returns actor's error unchanged,
// or the Flush error wrapped with "nft flush:" context (errors.Is/As still
// work through %w).
func nftDo(actor func(*nft.Conn) error) error {
	if actor == nil {
		log.Fatal("nftDo(): actor is nil")
	}

	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	c, err := nft.New()
	if err != nil {
		// log.Panicf already writes the message to the log; a separate
		// log.Printf beforehand would only duplicate it.
		log.Panicf("nft.New() error: %v", err)
	}
	if c == nil {
		log.Fatal("nft.New() returned nil")
	}

	if err := actor(c); err != nil {
		return err
	}
	if err := c.Flush(); err != nil {
		return fmt.Errorf("nft flush: %w", err)
	}
	return nil
}
// nftGetTableByName returns a copy of the first table named tableName, or
// (nil, nil) if there is no such table.
//
// When tableFamily is nft.TableFamilyUnspecified the candidates come from
// nftConn.ListTables(); otherwise they come from
// nftConn.ListTablesOfFamily(tableFamily). A listing error is logged and
// returned unchanged.
//
// An empty listing is not an error. The library may return a nil slice with
// a nil error when no tables exist, which here simply means "not found" and
// lets the caller report the missing table. The previous version
// log.Fatal'ed on a nil slice instead.
func nftGetTableByName(nftConn *nft.Conn, tableName string, tableFamily nft.TableFamily) (*nft.Table, error) {
	var (
		all []*nft.Table
		err error
	)
	if tableFamily == nft.TableFamilyUnspecified {
		all, err = nftConn.ListTables()
		if err != nil {
			log.Printf("nft.ListTables: %v", err)
			return nil, err
		}
	} else {
		all, err = nftConn.ListTablesOfFamily(tableFamily)
		if err != nil {
			log.Printf("nft.ListTablesOfFamily: %v", err)
			return nil, err
		}
	}
	for _, t := range all {
		if t.Name == tableName {
			// Return a copy so callers never alias the listing's entries.
			found := *t
			return &found, nil
		}
	}
	return nil, nil
}
// nftDoWithTable looks up the table tableName (family tableFamily) and runs
// actor on it, all within a single nftDo connection.
//
// Lookup errors are returned unchanged. If the table does not exist, an
// "unable to find nft table" error is returned. That error names the family
// when it is known in nftTableFamilyToString, and gives its numeric id
// otherwise. A nil actor aborts via log.Fatal.
func nftDoWithTable(tableName string, tableFamily nft.TableFamily, actor func(*nft.Conn, *nft.Table) error) error {
	if actor == nil {
		log.Fatal("nftDoWithTable(): actor is nil")
	}
	return nftDo(func(conn *nft.Conn) error {
		table, lookupErr := nftGetTableByName(conn, tableName, tableFamily)
		switch {
		case lookupErr != nil:
			return lookupErr
		case table != nil:
			return actor(conn, table)
		case tableFamily == nft.TableFamilyUnspecified:
			return fmt.Errorf("unable to find nft table %#v", tableName)
		}
		if familyName, known := nftTableFamilyToString[tableFamily]; known {
			return fmt.Errorf("unable to find nft table %#v (family %v)", tableName, familyName)
		}
		return fmt.Errorf("unable to find nft table %#v (family id = %v)", tableName, tableFamily)
	})
}
// nftDoWithSet runs actor against the set setName in table tableName (family
// tableFamily), all within a single nftDo connection.
//
// Errors from c.GetSetByName are returned unchanged. A nil set yields an
// "unable to find nft set" error. A nil actor aborts via log.Fatal.
func nftDoWithSet(tableName string, tableFamily nft.TableFamily, setName string, actor func(*nft.Conn, *nft.Table, *nft.Set) error) error {
	if actor == nil {
		log.Fatal("nftDoWithSet(): actor is nil")
	}
	withSet := func(conn *nft.Conn, table *nft.Table) error {
		set, lookupErr := conn.GetSetByName(table, setName)
		switch {
		case lookupErr != nil:
			return lookupErr
		case set == nil:
			return fmt.Errorf("unable to find nft set %#v", setName)
		default:
			return actor(conn, table, set)
		}
	}
	return nftDoWithTable(tableName, tableFamily, withSet)
}
// nftGetMapByName fetches the set mapName from nftTable and checks that it
// is a map.
//
// The possible results are:
//   - an error from GetSetByName, returned unchanged;
//   - (nil, nil) when GetSetByName yields no set;
//   - an error when the set exists but its IsMap flag is false;
//   - the set itself otherwise.
func nftGetMapByName(nftConn *nft.Conn, nftTable *nft.Table, mapName string) (*nft.Set, error) {
	set, lookupErr := nftConn.GetSetByName(nftTable, mapName)
	switch {
	case lookupErr != nil:
		return nil, lookupErr
	case set == nil:
		return nil, nil
	case !set.IsMap:
		return nil, fmt.Errorf("nft set %#v is not map", mapName)
	}
	return set, nil
}
// nftDoWithMap runs actor against the map mapName in table tableName (family
// tableFamily), all within a single nftDo connection.
//
// The map is resolved with nftGetMapByName, and its errors (including "not
// map") are returned unchanged. A missing map yields an "unable to find nft
// map" error. A nil actor aborts via log.Fatal.
func nftDoWithMap(tableName string, tableFamily nft.TableFamily, mapName string, actor func(*nft.Conn, *nft.Table, *nft.Set) error) error {
	if actor == nil {
		log.Fatal("nftDoWithMap(): actor is nil")
	}
	withMap := func(conn *nft.Conn, table *nft.Table) error {
		nftMap, lookupErr := nftGetMapByName(conn, table, mapName)
		if lookupErr != nil {
			return lookupErr
		}
		if nftMap != nil {
			return actor(conn, table, nftMap)
		}
		return fmt.Errorf("unable to find nft map %#v", mapName)
	}
	return nftDoWithTable(tableName, tableFamily, withMap)
}