feat(dns): re-introduce DNS_SERVER option

- force to set `DNS_UPSTREAM_RESOLVER_TYPE=plain` to avoid any confusion/security hole
- force to set `DNS_UPSTREAM_PLAIN_ADDRESSES` to addresses only with port 53
This commit is contained in:
Quentin McGaw
2026-05-05 21:15:28 +00:00
parent aab10f9d3f
commit 4ea2337668
5 changed files with 97 additions and 18 deletions
+1
View File
@@ -209,6 +209,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
HEALTH_SMALL_CHECK_TYPE=icmp \ HEALTH_SMALL_CHECK_TYPE=icmp \
HEALTH_RESTART_VPN=on \ HEALTH_RESTART_VPN=on \
# DNS # DNS
DNS_SERVER=on \
DNS_UPSTREAM_RESOLVER_TYPE=DoT \ DNS_UPSTREAM_RESOLVER_TYPE=DoT \
# Note: DNS_UPSTREAM_RESOLVERS defaults to cloudflare in code if DNS_UPSTREAM_PLAIN_ADDRESSES is empty # Note: DNS_UPSTREAM_RESOLVERS defaults to cloudflare in code if DNS_UPSTREAM_PLAIN_ADDRESSES is empty
DNS_UPSTREAM_RESOLVERS= \ DNS_UPSTREAM_RESOLVERS= \
@@ -14,10 +14,8 @@ func readObsolete(r *reader.Reader) (warnings []string) {
"DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.", "DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.",
"HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete", "HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete",
"HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete", "HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete",
"DNS_SERVER": "DNS_SERVER is obsolete because the forwarding server is always enabled.", "DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because you should use the built-in server which now " +
"DOT": "DOT is obsolete because the forwarding server is always enabled.", "forwards local names to private DNS resolvers found in /etc/resolv.conf at container start",
"DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because the forwarding server is always used and " +
"forwards local names to private DNS resolvers found in /etc/resolv.conf",
} }
sortedKeys := maps.Keys(keyToMessage) sortedKeys := maps.Keys(keyToMessage)
slices.Sort(sortedKeys) slices.Sort(sortedKeys)
+49 -2
View File
@@ -20,6 +20,9 @@ const (
// DNS contains settings to configure DNS. // DNS contains settings to configure DNS.
type DNS struct { type DNS struct {
// ServerEnabled indicates if the DNS server should be enabled.
// It defaults to true and cannot be nil in the internal state.
ServerEnabled *bool `json:"enabled"`
// UpstreamType can be [DNSUpstreamTypeDot], [DNSUpstreamTypeDoh] // UpstreamType can be [DNSUpstreamTypeDot], [DNSUpstreamTypeDoh]
// or [DNSUpstreamTypePlain]. It defaults to [DNSUpstreamTypeDot]. // or [DNSUpstreamTypePlain]. It defaults to [DNSUpstreamTypeDot].
UpstreamType string `json:"upstream_type"` UpstreamType string `json:"upstream_type"`
@@ -52,6 +55,13 @@ func (d DNS) validate() (err error) {
return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType) return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType)
} }
if !*d.ServerEnabled {
err = d.validateForServerOff()
if err != nil {
return err
}
}
const minUpdatePeriod = 30 * time.Second const minUpdatePeriod = 30 * time.Second
if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod { if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod {
return fmt.Errorf("update period is too short: %s must be bigger than %s", return fmt.Errorf("update period is too short: %s must be bigger than %s",
@@ -90,8 +100,26 @@ func (d DNS) validate() (err error) {
return nil return nil
} }
func (d DNS) validateForServerOff() (err error) {
switch {
case d.UpstreamType != DNSUpstreamTypePlain:
return fmt.Errorf("upstream type %s must be %s if the built-in DNS server is disabled",
d.UpstreamType, DNSUpstreamTypePlain)
case len(d.UpstreamPlainAddresses) == 0:
return fmt.Errorf("if DNS is disabled, at least one upstream plain address must be set")
}
for _, addrPort := range d.UpstreamPlainAddresses {
const defaultDNSPort = 53
if addrPort.Port() != defaultDNSPort {
return fmt.Errorf("invalid DNS port in %s: must be %d", addrPort, defaultDNSPort)
}
}
return nil
}
func (d *DNS) Copy() (copied DNS) { func (d *DNS) Copy() (copied DNS) {
return DNS{ return DNS{
ServerEnabled: gosettings.CopyPointer(d.ServerEnabled),
UpstreamType: d.UpstreamType, UpstreamType: d.UpstreamType,
UpdatePeriod: gosettings.CopyPointer(d.UpdatePeriod), UpdatePeriod: gosettings.CopyPointer(d.UpdatePeriod),
Providers: gosettings.CopySlice(d.Providers), Providers: gosettings.CopySlice(d.Providers),
@@ -106,6 +134,7 @@ func (d *DNS) Copy() (copied DNS) {
// settings object with any field set in the other // settings object with any field set in the other
// settings. // settings.
func (d *DNS) overrideWith(other DNS) { func (d *DNS) overrideWith(other DNS) {
d.ServerEnabled = gosettings.OverrideWithPointer(d.ServerEnabled, other.ServerEnabled)
d.UpstreamType = gosettings.OverrideWithComparable(d.UpstreamType, other.UpstreamType) d.UpstreamType = gosettings.OverrideWithComparable(d.UpstreamType, other.UpstreamType)
d.UpdatePeriod = gosettings.OverrideWithPointer(d.UpdatePeriod, other.UpdatePeriod) d.UpdatePeriod = gosettings.OverrideWithPointer(d.UpdatePeriod, other.UpdatePeriod)
d.Providers = gosettings.OverrideWithSlice(d.Providers, other.Providers) d.Providers = gosettings.OverrideWithSlice(d.Providers, other.Providers)
@@ -116,7 +145,12 @@ func (d *DNS) overrideWith(other DNS) {
} }
func (d *DNS) setDefaults() { func (d *DNS) setDefaults() {
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, DNSUpstreamTypeDot) d.ServerEnabled = gosettings.DefaultPointer(d.ServerEnabled, true)
defaultUpstreamType := DNSUpstreamTypeDot
if !*d.ServerEnabled {
defaultUpstreamType = DNSUpstreamTypePlain
}
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, defaultUpstreamType)
const defaultUpdatePeriod = 24 * time.Hour const defaultUpdatePeriod = 24 * time.Hour
d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod) d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod)
d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{}) d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{})
@@ -139,6 +173,14 @@ func (d DNS) String() string {
func (d DNS) toLinesNode() (node *gotree.Node) { func (d DNS) toLinesNode() (node *gotree.Node) {
node = gotree.New("DNS settings:") node = gotree.New("DNS settings:")
if !*d.ServerEnabled {
plainServers := node.Append("Plain DNS servers to use directly:")
for _, addr := range d.UpstreamPlainAddresses {
plainServers.Append(addr.String())
}
return node
}
node.Appendf("Upstream resolver type: %s", d.UpstreamType) node.Appendf("Upstream resolver type: %s", d.UpstreamType)
upstreamResolvers := node.Append("Upstream resolvers:") upstreamResolvers := node.Append("Upstream resolvers:")
@@ -174,6 +216,11 @@ func (d DNS) toLinesNode() (node *gotree.Node) {
} }
func (d *DNS) read(r *reader.Reader) (err error) { func (d *DNS) read(r *reader.Reader) (err error) {
d.ServerEnabled, err = r.BoolPtr("DNS_SERVER", reader.RetroKeys("DOT"))
if err != nil {
return err
}
d.UpstreamType = r.String("DNS_UPSTREAM_RESOLVER_TYPE") d.UpstreamType = r.String("DNS_UPSTREAM_RESOLVER_TYPE")
d.UpdatePeriod, err = r.DurationPtr("DNS_UPDATE_PERIOD") d.UpdatePeriod, err = r.DurationPtr("DNS_UPDATE_PERIOD")
@@ -207,7 +254,7 @@ func (d *DNS) read(r *reader.Reader) (err error) {
} }
func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) { func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) {
// If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_TYPE=plain // If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_RESOLVER_TYPE=plain
// for these to be used. This is an added safety measure to reduce misunderstandings, and // for these to be used. This is an added safety measure to reduce misunderstandings, and
// reduce odd settings overrides. // reduce odd settings overrides.
d.UpstreamPlainAddresses, err = r.CSVNetipAddrPorts("DNS_UPSTREAM_PLAIN_ADDRESSES") d.UpstreamPlainAddresses, err = r.CSVNetipAddrPorts("DNS_UPSTREAM_PLAIN_ADDRESSES")
+22 -12
View File
@@ -33,9 +33,22 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
for { for {
settings = l.GetSettings() settings = l.GetSettings()
var err error var err error
runError, err = l.setupServer(ctx, settings) if *settings.ServerEnabled { //nolint:nestif
if err == nil { runError, err = l.setupServer(ctx, settings)
break if err == nil {
l.logger.Infof("ready and using DNS server with %s upstream resolvers", settings.UpstreamType)
err = l.updateFiles(ctx, settings)
if err != nil {
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
}
break
}
} else {
err = l.usePlainServers(settings.UpstreamPlainAddresses)
if err == nil {
l.logger.Infof("ready and using plain DNS resolvers: %v", settings.UpstreamPlainAddresses)
break
}
} }
l.signalOrSetStatus(constants.Crashed) l.signalOrSetStatus(constants.Crashed)
@@ -46,12 +59,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
} }
l.backoffTime = defaultBackoffTime l.backoffTime = defaultBackoffTime
l.logger.Infof("ready and using DNS server with %s upstream resolvers", settings.UpstreamType)
err = l.updateFiles(ctx, settings)
if err != nil {
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
}
l.signalOrSetStatus(constants.Running) l.signalOrSetStatus(constants.Running)
l.userTrigger = false l.userTrigger = false
@@ -74,13 +81,13 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
l.stopServer() l.stopServerIfAny()
// TODO revert OS and Go nameserver when exiting // TODO revert OS and Go nameserver when exiting
return true return true
case <-l.stop: case <-l.stop:
l.userTrigger = true l.userTrigger = true
l.logger.Info("stopping") l.logger.Info("stopping")
l.stopServer() l.stopServerIfAny()
l.stopped <- struct{}{} l.stopped <- struct{}{}
case <-l.start: case <-l.start:
l.userTrigger = true l.userTrigger = true
@@ -94,7 +101,10 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
} }
} }
func (l *Loop) stopServer() { func (l *Loop) stopServerIfAny() {
if l.server == nil {
return
}
stopErr := l.server.Stop() stopErr := l.server.Stop()
if stopErr != nil { if stopErr != nil {
l.logger.Error("stopping server: " + stopErr.Error()) l.logger.Error("stopping server: " + stopErr.Error())
+23
View File
@@ -3,6 +3,7 @@ package dns
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update" "github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
"github.com/qdm12/dns/v2/pkg/nameserver" "github.com/qdm12/dns/v2/pkg/nameserver"
@@ -45,3 +46,25 @@ func (l *Loop) setupServer(ctx context.Context, settings settings.DNS) (runError
return runError, nil return runError, nil
} }
func (l *Loop) usePlainServers(addrPorts []netip.AddrPort) (err error) {
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
AddrPort: addrPorts[0],
})
addresses := make([]netip.Addr, len(addrPorts))
for i, addrPort := range addrPorts {
const defaultDNSPort = 53
if addrPort.Port() != defaultDNSPort {
return fmt.Errorf("invalid DNS port: %d, must be %d", addrPort.Port(), defaultDNSPort)
}
addresses[i] = addrPort.Addr()
}
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
IPs: addresses,
ResolvPath: l.resolvConf,
})
if err != nil {
return fmt.Errorf("using DNS system wide: %w", err)
}
return nil
}