Files
gluetun/internal/firewall/enable.go
T
Quentin McGaw ec24ffdfd8 hotfix(firewall): save and restore behavior fixed
- restore if IPv4 set all policies fails
- fix deadlock when using iptables custom rules
- fix setting ipv6 rules when running runMixedIptablesInstruction
2026-02-28 14:37:58 +00:00

201 lines
4.6 KiB
Go

package firewall
import (
"context"
"fmt"
"github.com/qdm12/gluetun/internal/netlink"
)
// SetEnabled switches the firewall on or off while holding stateMutex.
// If the requested state equals the current one, it only logs that fact and
// returns nil. Disabling calls the restore function saved by the last
// successful enable (with ctx) and marks the firewall disabled. Enabling
// installs the rule set via enable and marks the firewall enabled only if
// that succeeded.
func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
	c.stateMutex.Lock()
	defer c.stateMutex.Unlock()

	switch {
	case enabled && c.enabled:
		c.logger.Info("already enabled")
		return nil
	case !enabled && !c.enabled:
		c.logger.Info("already disabled")
		return nil
	case !enabled:
		// Currently enabled and asked to disable.
		c.logger.Info("disabling...")
		c.restore(ctx)
		c.enabled = false
		c.logger.Info("disabled successfully")
		return nil
	}

	// Currently disabled and asked to enable.
	c.logger.Info("enabling...")
	if err = c.enable(ctx); err != nil {
		return fmt.Errorf("enabling firewall: %w", err)
	}
	c.enabled = true
	c.logger.Info("enabled successfully")
	return nil
}
// enable installs the firewall rule set.
//
// It first snapshots the current rules with SaveAndRestore and stores the
// returned restore function in c.restore. It then sets all IPv4 and IPv6
// policies to DROP and adds the allow rules in this order:
//   - loopback input and output
//   - established/related traffic
//   - the VPN endpoint
//   - output from each local network (plus IPv6 multicast output once per
//     local interface)
//   - configured outbound subnets
//   - input to local subnets
//   - allowed input ports
//   - port redirections
//
// Finally it runs the user defined post rules from c.customRulesPath.
// If any step after the snapshot fails, the saved rules are restored before
// the error is returned.
func (c *Config) enable(ctx context.Context) (err error) {
	c.restore, err = c.impl.SaveAndRestore(ctx)
	if err != nil {
		return fmt.Errorf("saving firewall rules: %w", err)
	}
	// Roll back to the saved rules on any failure below. The rollback uses
	// a fresh context, presumably so it is not cut short if ctx is canceled.
	// This relies on every return path assigning the named err result, so
	// the inner statements must not shadow err with :=.
	defer func() {
		if err != nil {
			c.restore(context.Background())
		}
	}()

	if err = c.impl.SetIPv4AllPolicies(ctx, "DROP"); err != nil {
		return fmt.Errorf("setting IPv4 policies: %w", err)
	}
	if err = c.impl.SetIPv6AllPolicies(ctx, "DROP"); err != nil {
		return fmt.Errorf("setting IPv6 policies: %w", err)
	}

	const remove = false

	// Loopback traffic
	if err = c.impl.AcceptInputThroughInterface(ctx, "lo"); err != nil {
		return fmt.Errorf("accepting input through loopback interface: %w", err)
	}
	if err = c.impl.AcceptOutputThroughInterface(ctx, "lo", remove); err != nil {
		return fmt.Errorf("accepting output through loopback interface: %w", err)
	}

	if err = c.impl.AcceptEstablishedRelatedTraffic(ctx); err != nil {
		return fmt.Errorf("accepting established and related traffic: %w", err)
	}

	// allowVPNIP already adds context to its errors.
	if err = c.allowVPNIP(ctx); err != nil {
		return err
	}

	// Output from each local network; the IPv6 multicast output rule is added
	// only once per local interface.
	localInterfaces := make(map[string]struct{}, len(c.localNetworks))
	for _, network := range c.localNetworks {
		err = c.impl.AcceptOutputFromIPToSubnet(ctx,
			network.InterfaceName, network.IP, network.IPNet, remove)
		if err != nil {
			return fmt.Errorf("accepting output from local network on interface %s: %w",
				network.InterfaceName, err)
		}

		if _, seen := localInterfaces[network.InterfaceName]; seen {
			continue
		}
		localInterfaces[network.InterfaceName] = struct{}{}
		err = c.impl.AcceptIpv6MulticastOutput(ctx, network.InterfaceName)
		if err != nil {
			return fmt.Errorf("accepting IPv6 multicast output: %w", err)
		}
	}

	if err = c.allowOutboundSubnets(ctx); err != nil {
		return fmt.Errorf("allowing outbound subnets: %w", err)
	}

	// Allows packets from any IP address to go through eth0 / local network
	// to reach Gluetun.
	for _, network := range c.localNetworks {
		err = c.impl.AcceptInputToSubnet(ctx, network.InterfaceName, network.IPNet)
		if err != nil {
			return fmt.Errorf("accepting input to local subnet on interface %s: %w",
				network.InterfaceName, err)
		}
	}

	// allowInputPorts already adds context to its errors.
	if err = c.allowInputPorts(ctx); err != nil {
		return err
	}

	if err = c.redirectPorts(ctx); err != nil {
		return fmt.Errorf("redirecting ports: %w", err)
	}

	if err = c.impl.RunUserPostRules(ctx, c.customRulesPath); err != nil {
		return fmt.Errorf("running user defined post firewall rules: %w", err)
	}
	return nil
}
// allowVPNIP accepts output traffic to the VPN connection through each
// distinct interface found in c.defaultRoutes. An interface that appears
// in several default routes gets the rule only once. It does nothing when
// the VPN connection IP is not valid.
func (c *Config) allowVPNIP(ctx context.Context) (err error) {
	if !c.vpnConnection.IP.IsValid() {
		return nil
	}

	const remove = false
	handled := make(map[string]struct{}, len(c.defaultRoutes))
	for _, route := range c.defaultRoutes {
		netInterface := route.NetInterface
		if _, done := handled[netInterface]; done {
			continue
		}
		handled[netInterface] = struct{}{}

		if err = c.impl.AcceptOutputTrafficToVPN(ctx, netInterface, c.vpnConnection, remove); err != nil {
			return fmt.Errorf("accepting output traffic through VPN: %w", err)
		}
	}
	return nil
}
// allowOutboundSubnets accepts output traffic to each configured outbound
// subnet. The rule is added through every default route whose IP family
// (IPv4 or IPv6) matches the subnet's address family, using that route's
// assigned IP as the source. If a subnet has no default route of its
// family, it is reported through logIgnoredSubnetFamily and skipped.
func (c *Config) allowOutboundSubnets(ctx context.Context) (err error) {
	const remove = false
	for _, subnet := range c.outboundSubnets {
		subnetIsIPv6 := subnet.Addr().Is6()
		firewallUpdated := false
		for _, defaultRoute := range c.defaultRoutes {
			defaultRouteIsIPv6 := defaultRoute.Family == netlink.FamilyV6
			if subnetIsIPv6 != defaultRouteIsIPv6 {
				continue // IP family mismatch
			}
			firewallUpdated = true
			err = c.impl.AcceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
				defaultRoute.AssignedIP, subnet, remove)
			if err != nil {
				return fmt.Errorf("accepting output to subnet %s through interface %s: %w",
					subnet, defaultRoute.NetInterface, err)
			}
		}
		if !firewallUpdated {
			c.logIgnoredSubnetFamily(subnet)
		}
	}
	return nil
}
// allowInputPorts accepts input traffic to every port in
// c.allowedInputPorts, on each network interface listed for that port.
// It stops at the first failure and returns an error naming the port and
// interface.
func (c *Config) allowInputPorts(ctx context.Context) (err error) {
	const remove = false
	for port, interfaces := range c.allowedInputPorts {
		for iface := range interfaces {
			if err = c.impl.AcceptInputToPort(ctx, iface, port, remove); err != nil {
				return fmt.Errorf("accepting input port %d on interface %s: %w",
					port, iface, err)
			}
		}
	}
	return nil
}
// redirectPorts installs each configured port redirection from its source
// port to its destination port on the redirection's interface. It returns
// the first error unchanged; the caller adds context.
func (c *Config) redirectPorts(ctx context.Context) (err error) {
	const remove = false
	for _, redirection := range c.portRedirections {
		err = c.impl.RedirectPort(ctx,
			redirection.interfaceName,
			redirection.sourcePort,
			redirection.destinationPort,
			remove)
		if err != nil {
			return err
		}
	}
	return nil
}