feat(dns): re-introduce DNS_SERVER option

- force to set `DNS_UPSTREAM_RESOLVER_TYPE=plain` to avoid any confusion/security hole
- force to set `DNS_UPSTREAM_PLAIN_ADDRESSES` to addresses only with port 53
This commit is contained in:
Quentin McGaw
2026-05-05 21:15:28 +00:00
parent aab10f9d3f
commit 4ea2337668
5 changed files with 97 additions and 18 deletions
+1
View File
@@ -209,6 +209,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
HEALTH_SMALL_CHECK_TYPE=icmp \
HEALTH_RESTART_VPN=on \
# DNS
DNS_SERVER=on \
DNS_UPSTREAM_RESOLVER_TYPE=DoT \
# Note: DNS_UPSTREAM_RESOLVERS defaults to cloudflare in code if DNS_UPSTREAM_PLAIN_ADDRESSES is empty
DNS_UPSTREAM_RESOLVERS= \
@@ -14,10 +14,8 @@ func readObsolete(r *reader.Reader) (warnings []string) {
"DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.",
"HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete",
"HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete",
"DNS_SERVER": "DNS_SERVER is obsolete because the forwarding server is always enabled.",
"DOT": "DOT is obsolete because the forwarding server is always enabled.",
"DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because the forwarding server is always used and " +
"forwards local names to private DNS resolvers found in /etc/resolv.conf",
"DNS_KEEP_NAMESERVER": "DNS_KEEP_NAMESERVER is obsolete because you should use the built-in server which now " +
"forwards local names to private DNS resolvers found in /etc/resolv.conf at container start",
}
sortedKeys := maps.Keys(keyToMessage)
slices.Sort(sortedKeys)
+49 -2
View File
@@ -20,6 +20,9 @@ const (
// DNS contains settings to configure DNS.
type DNS struct {
// ServerEnabled indicates if the DNS server should be enabled.
// It defaults to true and cannot be nil in the internal state.
ServerEnabled *bool `json:"enabled"`
// UpstreamType can be [DNSUpstreamTypeDot], [DNSUpstreamTypeDoh]
// or [DNSUpstreamTypePlain]. It defaults to [DNSUpstreamTypeDot].
UpstreamType string `json:"upstream_type"`
@@ -52,6 +55,13 @@ func (d DNS) validate() (err error) {
return fmt.Errorf("DNS upstream type is not valid: %s", d.UpstreamType)
}
if !*d.ServerEnabled {
err = d.validateForServerOff()
if err != nil {
return err
}
}
const minUpdatePeriod = 30 * time.Second
if *d.UpdatePeriod != 0 && *d.UpdatePeriod < minUpdatePeriod {
return fmt.Errorf("update period is too short: %s must be bigger than %s",
@@ -90,8 +100,26 @@ func (d DNS) validate() (err error) {
return nil
}
func (d DNS) validateForServerOff() (err error) {
switch {
case d.UpstreamType != DNSUpstreamTypePlain:
return fmt.Errorf("upstream type %s must be %s if the built-in DNS server is disabled",
d.UpstreamType, DNSUpstreamTypePlain)
case len(d.UpstreamPlainAddresses) == 0:
return fmt.Errorf("if DNS is disabled, at least one upstream plain address must be set")
}
for _, addrPort := range d.UpstreamPlainAddresses {
const defaultDNSPort = 53
if addrPort.Port() != defaultDNSPort {
return fmt.Errorf("invalid DNS port in %s: must be %d", addrPort, defaultDNSPort)
}
}
return nil
}
func (d *DNS) Copy() (copied DNS) {
return DNS{
ServerEnabled: gosettings.CopyPointer(d.ServerEnabled),
UpstreamType: d.UpstreamType,
UpdatePeriod: gosettings.CopyPointer(d.UpdatePeriod),
Providers: gosettings.CopySlice(d.Providers),
@@ -106,6 +134,7 @@ func (d *DNS) Copy() (copied DNS) {
// settings object with any field set in the other
// settings.
func (d *DNS) overrideWith(other DNS) {
d.ServerEnabled = gosettings.OverrideWithPointer(d.ServerEnabled, other.ServerEnabled)
d.UpstreamType = gosettings.OverrideWithComparable(d.UpstreamType, other.UpstreamType)
d.UpdatePeriod = gosettings.OverrideWithPointer(d.UpdatePeriod, other.UpdatePeriod)
d.Providers = gosettings.OverrideWithSlice(d.Providers, other.Providers)
@@ -116,7 +145,12 @@ func (d *DNS) overrideWith(other DNS) {
}
func (d *DNS) setDefaults() {
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, DNSUpstreamTypeDot)
d.ServerEnabled = gosettings.DefaultPointer(d.ServerEnabled, true)
defaultUpstreamType := DNSUpstreamTypeDot
if !*d.ServerEnabled {
defaultUpstreamType = DNSUpstreamTypePlain
}
d.UpstreamType = gosettings.DefaultComparable(d.UpstreamType, defaultUpstreamType)
const defaultUpdatePeriod = 24 * time.Hour
d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod)
d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{})
@@ -139,6 +173,14 @@ func (d DNS) String() string {
func (d DNS) toLinesNode() (node *gotree.Node) {
node = gotree.New("DNS settings:")
if !*d.ServerEnabled {
plainServers := node.Append("Plain DNS servers to use directly:")
for _, addr := range d.UpstreamPlainAddresses {
plainServers.Append(addr.String())
}
return node
}
node.Appendf("Upstream resolver type: %s", d.UpstreamType)
upstreamResolvers := node.Append("Upstream resolvers:")
@@ -174,6 +216,11 @@ func (d DNS) toLinesNode() (node *gotree.Node) {
}
func (d *DNS) read(r *reader.Reader) (err error) {
d.ServerEnabled, err = r.BoolPtr("DNS_SERVER", reader.RetroKeys("DOT"))
if err != nil {
return err
}
d.UpstreamType = r.String("DNS_UPSTREAM_RESOLVER_TYPE")
d.UpdatePeriod, err = r.DurationPtr("DNS_UPDATE_PERIOD")
@@ -207,7 +254,7 @@ func (d *DNS) read(r *reader.Reader) (err error) {
}
func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) {
// If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_TYPE=plain
// If DNS_UPSTREAM_PLAIN_ADDRESSES is set, the user must also set DNS_UPSTREAM_RESOLVER_TYPE=plain
// for these to be used. This is an added safety measure to reduce misunderstandings, and
// reduce odd settings overrides.
d.UpstreamPlainAddresses, err = r.CSVNetipAddrPorts("DNS_UPSTREAM_PLAIN_ADDRESSES")
+19 -9
View File
@@ -33,10 +33,23 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
for {
settings = l.GetSettings()
var err error
if *settings.ServerEnabled { //nolint:nestif
runError, err = l.setupServer(ctx, settings)
if err == nil {
l.logger.Infof("ready and using DNS server with %s upstream resolvers", settings.UpstreamType)
err = l.updateFiles(ctx, settings)
if err != nil {
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
}
break
}
} else {
err = l.usePlainServers(settings.UpstreamPlainAddresses)
if err == nil {
l.logger.Infof("ready and using plain DNS resolvers: %v", settings.UpstreamPlainAddresses)
break
}
}
l.signalOrSetStatus(constants.Crashed)
if ctx.Err() != nil {
@@ -46,12 +59,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
}
l.backoffTime = defaultBackoffTime
l.logger.Infof("ready and using DNS server with %s upstream resolvers", settings.UpstreamType)
err = l.updateFiles(ctx, settings)
if err != nil {
l.logger.Warn("downloading block lists failed, skipping: " + err.Error())
}
l.signalOrSetStatus(constants.Running)
l.userTrigger = false
@@ -74,13 +81,13 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
for {
select {
case <-ctx.Done():
l.stopServer()
l.stopServerIfAny()
// TODO revert OS and Go nameserver when exiting
return true
case <-l.stop:
l.userTrigger = true
l.logger.Info("stopping")
l.stopServer()
l.stopServerIfAny()
l.stopped <- struct{}{}
case <-l.start:
l.userTrigger = true
@@ -94,7 +101,10 @@ func (l *Loop) runWait(ctx context.Context, runError <-chan error) (exitLoop boo
}
}
func (l *Loop) stopServer() {
func (l *Loop) stopServerIfAny() {
if l.server == nil {
return
}
stopErr := l.server.Stop()
if stopErr != nil {
l.logger.Error("stopping server: " + stopErr.Error())
+23
View File
@@ -3,6 +3,7 @@ package dns
import (
"context"
"fmt"
"net/netip"
"github.com/qdm12/dns/v2/pkg/middlewares/filter/update"
"github.com/qdm12/dns/v2/pkg/nameserver"
@@ -45,3 +46,25 @@ func (l *Loop) setupServer(ctx context.Context, settings settings.DNS) (runError
return runError, nil
}
func (l *Loop) usePlainServers(addrPorts []netip.AddrPort) (err error) {
nameserver.UseDNSInternally(nameserver.SettingsInternalDNS{
AddrPort: addrPorts[0],
})
addresses := make([]netip.Addr, len(addrPorts))
for i, addrPort := range addrPorts {
const defaultDNSPort = 53
if addrPort.Port() != defaultDNSPort {
return fmt.Errorf("invalid DNS port: %d, must be %d", addrPort.Port(), defaultDNSPort)
}
addresses[i] = addrPort.Addr()
}
err = nameserver.UseDNSSystemWide(nameserver.SettingsSystemDNS{
IPs: addresses,
ResolvPath: l.resolvConf,
})
if err != nil {
return fmt.Errorf("using DNS system wide: %w", err)
}
return nil
}