hotfix(dns): fix behavior for DNS_UPSTREAM_PLAIN_ADDRESSES

This commit is contained in:
Quentin McGaw
2026-03-21 23:37:25 +00:00
parent 8a2e8bda0f
commit 72af17cc91
2 changed files with 43 additions and 54 deletions
+26 -37
View File
@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/qdm12/dns/v2/pkg/provider" "github.com/qdm12/dns/v2/pkg/provider"
@@ -31,8 +30,8 @@ type DNS struct {
// the internal state. // the internal state.
UpdatePeriod *time.Duration UpdatePeriod *time.Duration
// Providers is a list of DNS providers. // Providers is a list of DNS providers.
// It defaults to either ["cloudflare"] or [] if the // It defaults to ["cloudflare"] and is ignored if the UpstreamType is
// UpstreamPlainAddresses field is set. // [DNSUpstreamTypePlain] and the UpstreamPlainAddresses field is set.
Providers []string `json:"providers"` Providers []string `json:"providers"`
// Caching is true if the server should cache // Caching is true if the server should cache
// DNS responses. // DNS responses.
@@ -45,9 +44,7 @@ type DNS struct {
// UpstreamPlainAddresses are the upstream plaintext DNS resolver // UpstreamPlainAddresses are the upstream plaintext DNS resolver
// addresses to use by the built-in DNS server forwarder. // addresses to use by the built-in DNS server forwarder.
// Note, if the upstream type is [dnsUpstreamTypePlain] and this field is set, // Note, if the upstream type is [dnsUpstreamTypePlain] and this field is set,
// the Providers field will default to the empty slice. If the Providers field // the Providers field is ignored.
// is set by the user, then the content of this field will be merged together
// with the plain addresses of the providers set in the Providers field.
UpstreamPlainAddresses []netip.AddrPort UpstreamPlainAddresses []netip.AddrPort
} }
@@ -69,33 +66,27 @@ func (d DNS) validate() (err error) {
ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod) ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod)
} }
providers := provider.NewProviders()
selectedHasPlainIPv4, selectedHasPlainIPv6 := false, false
for _, providerName := range d.Providers {
provider, err := providers.Get(providerName)
if err != nil {
return err
}
if !selectedHasPlainIPv4 && len(provider.Plain.IPv4) > 0 {
selectedHasPlainIPv4 = true
}
if !selectedHasPlainIPv6 && len(provider.Plain.IPv6) > 0 {
selectedHasPlainIPv6 = true
}
}
if d.UpstreamType == DNSUpstreamTypePlain { if d.UpstreamType == DNSUpstreamTypePlain {
if *d.IPv6 && !selectedHasPlainIPv6 && selectedHasPlainIPv4, selectedHasPlainIPv6 := false, false
!slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool { for _, addrPort := range d.UpstreamPlainAddresses {
return addrPort.Addr().Is6() if !selectedHasPlainIPv4 && addrPort.Addr().Is4() {
}) { selectedHasPlainIPv4 = true
}
if !selectedHasPlainIPv6 && addrPort.Addr().Is6() {
selectedHasPlainIPv6 = true
}
if selectedHasPlainIPv4 && selectedHasPlainIPv6 {
break
}
}
switch {
case *d.IPv6 && !selectedHasPlainIPv6:
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses)) return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses))
} else if !selectedHasPlainIPv4 && !slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool { case !*d.IPv6 && !selectedHasPlainIPv4:
return addrPort.Addr().Is4()
}) {
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses)) return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses))
} }
} }
// Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes
err = d.Blacklist.validate() err = d.Blacklist.validate()
if err != nil { if err != nil {
@@ -113,7 +104,7 @@ func (d *DNS) Copy() (copied DNS) {
Caching: gosettings.CopyPointer(d.Caching), Caching: gosettings.CopyPointer(d.Caching),
IPv6: gosettings.CopyPointer(d.IPv6), IPv6: gosettings.CopyPointer(d.IPv6),
Blacklist: d.Blacklist.copy(), Blacklist: d.Blacklist.copy(),
UpstreamPlainAddresses: d.UpstreamPlainAddresses, UpstreamPlainAddresses: gosettings.CopySlice(d.UpstreamPlainAddresses),
} }
} }
@@ -135,11 +126,7 @@ func (d *DNS) setDefaults() {
const defaultUpdatePeriod = 24 * time.Hour const defaultUpdatePeriod = 24 * time.Hour
d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod) d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod)
d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{}) d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{})
defaultProviders := defaultDNSProviders() d.Providers = gosettings.DefaultSlice(d.Providers, defaultDNSProviders())
if d.UpstreamType == DNSUpstreamTypePlain && len(d.UpstreamPlainAddresses) == 0 {
defaultProviders = []string{}
}
d.Providers = gosettings.DefaultSlice(d.Providers, defaultProviders)
d.Caching = gosettings.DefaultPointer(d.Caching, true) d.Caching = gosettings.DefaultPointer(d.Caching, true)
d.IPv6 = gosettings.DefaultPointer(d.IPv6, false) d.IPv6 = gosettings.DefaultPointer(d.IPv6, false)
d.Blacklist.setDefaults() d.Blacklist.setDefaults()
@@ -208,6 +195,9 @@ func (d DNS) toLinesNode() (node *gotree.Node) {
} }
} else { } else {
node.Appendf("Upstream plain addresses: ignored because upstream type is not plain") node.Appendf("Upstream plain addresses: ignored because upstream type is not plain")
for _, provider := range d.Providers {
upstreamResolvers.Append(provider)
}
} }
} else { } else {
for _, provider := range d.Providers { for _, provider := range d.Providers {
@@ -273,8 +263,8 @@ func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) {
// Retro-compatibility - remove in v4 // Retro-compatibility - remove in v4
// If DNS_ADDRESS is set to a non-localhost address, append it to the other // If DNS_ADDRESS is set to a non-localhost address, append it to the other
// upstream plain addresses, assuming port 53, and force the upstream type to plain AND // upstream plain addresses, assuming port 53, and force the upstream type to plain
// clear any user picked providers, to maintain retro-compatibility behavior. // to maintain retro-compatibility behavior.
serverAddress, err := r.NetipAddr("DNS_ADDRESS", serverAddress, err := r.NetipAddr("DNS_ADDRESS",
reader.RetroKeys("DNS_PLAINTEXT_ADDRESS"), reader.RetroKeys("DNS_PLAINTEXT_ADDRESS"),
reader.IsRetro("DNS_UPSTREAM_PLAIN_ADDRESSES")) reader.IsRetro("DNS_UPSTREAM_PLAIN_ADDRESSES"))
@@ -291,6 +281,5 @@ func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) {
addrPort := netip.AddrPortFrom(serverAddress, defaultPlainPort) addrPort := netip.AddrPortFrom(serverAddress, defaultPlainPort)
d.UpstreamPlainAddresses = append(d.UpstreamPlainAddresses, addrPort) d.UpstreamPlainAddresses = append(d.UpstreamPlainAddresses, addrPort)
d.UpstreamType = DNSUpstreamTypePlain d.UpstreamType = DNSUpstreamTypePlain
d.Providers = []string{}
return nil return nil
} }
+17 -17
View File
@@ -120,22 +120,23 @@ func buildServerSettings(userSettings settings.DNS,
func buildProviders(userSettings settings.DNS, localSubnets []netip.Prefix, func buildProviders(userSettings settings.DNS, localSubnets []netip.Prefix,
logger Logger, logger Logger,
) (providers []provider.Provider) { ) (providers []provider.Provider) {
providersCount := len(userSettings.Providers) userDefinedPlainAddresses := userSettings.UpstreamType == settings.DNSUpstreamTypePlain &&
if userSettings.UpstreamType == settings.DNSUpstreamTypePlain { len(userSettings.UpstreamPlainAddresses) > 0
providersCount += len(userSettings.UpstreamPlainAddresses) if !userDefinedPlainAddresses {
} providers = make([]provider.Provider, len(userSettings.Providers))
providers = make([]provider.Provider, 0, providersCount) providersData := provider.NewProviders()
for i, providerName := range userSettings.Providers {
providersData := provider.NewProviders() var err error
for _, providerName := range userSettings.Providers { providers[i], err = providersData.Get(providerName)
provider, err := providersData.Get(providerName) if err != nil {
if err != nil { panic(err) // this should already had been checked
panic(err) // this should already had been checked }
} }
providers = append(providers, provider) return providers
} }
for _, addrPort := range userSettings.UpstreamPlainAddresses { providers = make([]provider.Provider, len(userSettings.UpstreamPlainAddresses))
for i, addrPort := range userSettings.UpstreamPlainAddresses {
addr := addrPort.Addr() addr := addrPort.Addr()
if addr.IsPrivate() && !addr.IsLoopback() && if addr.IsPrivate() && !addr.IsLoopback() &&
!slices.ContainsFunc(localSubnets, func(prefix netip.Prefix) bool { !slices.ContainsFunc(localSubnets, func(prefix netip.Prefix) bool {
@@ -146,15 +147,14 @@ func buildProviders(userSettings settings.DNS, localSubnets []netip.Prefix,
addr, netip.PrefixFrom(addr, addr.BitLen())) addr, netip.PrefixFrom(addr, addr.BitLen()))
} }
provider := provider.Provider{ providers[i] = provider.Provider{
Name: addrPort.String(), Name: addrPort.String(),
} }
if addr.Is4() { if addr.Is4() {
provider.Plain.IPv4 = []netip.AddrPort{addrPort} providers[i].Plain.IPv4 = []netip.AddrPort{addrPort}
} else { } else {
provider.Plain.IPv6 = []netip.AddrPort{addrPort} providers[i].Plain.IPv6 = []netip.AddrPort{addrPort}
} }
providers = append(providers, provider)
} }
return providers return providers