mirror of
https://github.com/qdm12/gluetun.git
synced 2026-06-15 07:54:08 +02:00
wip
This commit is contained in:
@@ -0,0 +1,99 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
// SaveAndRestore saves the current nftables tree and returns a restore function that
|
||||
// can be called to restore the saved tree.
|
||||
func (f *Firewall) SaveAndRestore(_ context.Context) (restore func(context.Context), err error) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating nftables connection: %w", err)
|
||||
}
|
||||
tables, err := saveTables(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving nftables state: %w", err)
|
||||
}
|
||||
return func(_ context.Context) {
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
f.logger.Warnf("creating nftables connection for restore: %s", err)
|
||||
return
|
||||
}
|
||||
err = restoreTables(conn, tables)
|
||||
if err != nil {
|
||||
f.logger.Warnf("restoring nftables state: %s", err)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type savedTable struct {
|
||||
table *nftables.Table
|
||||
chains []savedChain
|
||||
}
|
||||
|
||||
type savedChain struct {
|
||||
chain *nftables.Chain
|
||||
rules []*nftables.Rule
|
||||
}
|
||||
|
||||
func saveTables(conn *nftables.Conn) ([]savedTable, error) {
|
||||
tables, err := conn.ListTables()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedTables := make([]savedTable, len(tables))
|
||||
for i, table := range tables {
|
||||
savedTables[i].table = table
|
||||
|
||||
chains, err := conn.ListChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != table.Name ||
|
||||
chain.Table.Family != table.Family {
|
||||
continue
|
||||
}
|
||||
rules, err := conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting rules for chain %s in table %s: %w", chain.Name, table.Name, err)
|
||||
}
|
||||
savedChain := savedChain{chain: chain, rules: rules}
|
||||
savedTables[i].chains = append(savedTables[i].chains, savedChain)
|
||||
}
|
||||
}
|
||||
|
||||
return savedTables, nil
|
||||
}
|
||||
|
||||
func restoreTables(conn *nftables.Conn, savedTables []savedTable) error {
|
||||
conn.FlushRuleset()
|
||||
|
||||
for _, savedTable := range savedTables {
|
||||
table := conn.AddTable(savedTable.table)
|
||||
for _, savedChain := range savedTable.chains {
|
||||
// Make the [nftables.Chain.Table] points to the new [nftables.Table]
|
||||
// created in this connection.
|
||||
savedChain.chain.Table = table
|
||||
savedChain.chain = conn.AddChain(savedChain.chain)
|
||||
|
||||
for _, rule := range savedChain.rules {
|
||||
rule.Table = table
|
||||
rule.Chain = savedChain.chain
|
||||
conn.AddRule(rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return conn.Flush()
|
||||
}
|
||||
Reference in New Issue
Block a user