package main import ( "fmt" "log" "runtime" nft "github.com/google/nftables" ) var ( nftTableFamilyToString = map[nft.TableFamily]string{ nft.TableFamilyINet: "inet", nft.TableFamilyIPv4: "ip", nft.TableFamilyIPv6: "ip6", nft.TableFamilyARP: "arp", nft.TableFamilyNetdev: "netdev", nft.TableFamilyBridge: "bridge", } ) func setupNftables() { if !cfgWithNft { return } nftDoWithTable(cfgNftTable, cfgNftTableFamily, func(c *nft.Conn, t *nft.Table) error { var m4, m6 *nft.Set var nf_err error if cfgNftMapV4 != "" { m4, nf_err = nftGetMapByName(c, t, cfgNftMapV4) if nf_err != nil { log.Fatal(nf_err) } if m4 == nil { log.Fatalf("unable to find nft map %#v", cfgNftMapV4) } } if cfgNftMapV6 != "" { m6, nf_err = nftGetMapByName(c, t, cfgNftMapV6) if nf_err != nil { log.Fatal(nf_err) } if m6 == nil { log.Fatalf("unable to find nft map %#v", cfgNftMapV6) } } if m4 != nil { c.FlushSet(m4) } if m6 != nil { c.FlushSet(m6) } return nil }) } func nftDo(actor func(*nft.Conn) error) error { if actor == nil { log.Fatal("nftDo(): actor is nil") } runtime.LockOSThread() defer runtime.UnlockOSThread() c, err := nft.New() if err != nil { log.Printf("nft.New() error: %v", err) log.Panic(err) } if c == nil { log.Fatal("nft.New() returned nil") } err = actor(c) if err == nil { err = c.Flush() } return err } func nftGetTableByName(nftConn *nft.Conn, tableName string, tableFamily nft.TableFamily) (*nft.Table, error) { var err error var all []*nft.Table if tableFamily == nft.TableFamilyUnspecified { all, err = nftConn.ListTables() if err != nil { log.Printf("nft.ListTables: %v", err) return nil, err } if all == nil { log.Fatal("nft.Conn.ListTables() returned nil") } } else { all, err = nftConn.ListTablesOfFamily(tableFamily) if err != nil { log.Printf("nft.ListTablesOfFamily: %v", err) return nil, err } if all == nil { log.Fatal("nft.Conn.ListTablesOfFamily() returned nil") } } var table *nft.Table for i := range all { if all[i].Name != tableName { continue } table = new(nft.Table) *table = *(all[i]) break } return table, nil } func nftDoWithTable(tableName string, tableFamily nft.TableFamily, actor func(*nft.Conn, *nft.Table) error) error { if actor == nil { log.Fatal("nftDoWithTable(): actor is nil") } return nftDo(func(c *nft.Conn) error { t, err := nftGetTableByName(c, tableName, tableFamily) if err != nil { return err } if t == nil { if tableFamily == nft.TableFamilyUnspecified { return fmt.Errorf("unable to find nft table %#v", tableName) } family, ok := nftTableFamilyToString[tableFamily] if ok { return fmt.Errorf("unable to find nft table %#v (family %v)", tableName, family) } else { return fmt.Errorf("unable to find nft table %#v (family id = %v)", tableName, tableFamily) } } return actor(c, t) }) } func nftDoWithSet(tableName string, tableFamily nft.TableFamily, setName string, actor func(*nft.Conn, *nft.Table, *nft.Set) error) error { if actor == nil { log.Fatal("nftDoWithSet(): actor is nil") } return nftDoWithTable(tableName, tableFamily, func(c *nft.Conn, t *nft.Table) error { s, err := c.GetSetByName(t, setName) if err != nil { return err } if s == nil { return fmt.Errorf("unable to find nft set %#v", setName) } return actor(c, t, s) }) } func nftGetMapByName(nftConn *nft.Conn, nftTable *nft.Table, mapName string) (*nft.Set, error) { m, err := nftConn.GetSetByName(nftTable, mapName) if err != nil { return nil, err } if m == nil { return nil, nil } if !m.IsMap { return nil, fmt.Errorf("nft set %#v is not map", mapName) } return m, nil } func nftDoWithMap(tableName string, tableFamily nft.TableFamily, mapName string, actor func(*nft.Conn, *nft.Table, *nft.Set) error) error { if actor == nil { log.Fatal("nftDoWithMap(): actor is nil") } return nftDoWithTable(tableName, tableFamily, func(c *nft.Conn, t *nft.Table) error { m, err := nftGetMapByName(c, t, mapName) if err != nil { return err } if m == nil { return fmt.Errorf("unable to find nft map %#v", mapName) } return actor(c, t, m) }) }