hotfix(dns): fix behavior for DNS_UPSTREAM_PLAIN_ADDRESSES

This commit is contained in:
Quentin McGaw
2026-03-21 23:37:25 +00:00
parent 8a2e8bda0f
commit 72af17cc91
2 changed files with 43 additions and 54 deletions
+21 -32
View File
@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"time"
"github.com/qdm12/dns/v2/pkg/provider"
@@ -31,8 +30,8 @@ type DNS struct {
// the internal state.
UpdatePeriod *time.Duration
// Providers is a list of DNS providers.
// It defaults to either ["cloudflare"] or [] if the
// UpstreamPlainAddresses field is set.
// It defaults to ["cloudflare"] and is ignored if the UpstreamType is
// [DNSUpstreamTypePlain] and the UpstreamPlainAddresses field is set.
Providers []string `json:"providers"`
// Caching is true if the server should cache
// DNS responses.
@@ -45,9 +44,7 @@ type DNS struct {
// UpstreamPlainAddresses are the upstream plaintext DNS resolver
// addresses to use by the built-in DNS server forwarder.
// Note, if the upstream type is [dnsUpstreamTypePlain] and this field is set,
// the Providers field will default to the empty slice. If the Providers field
// is set by the user, then the content of this field will be merged together
// with the plain addresses of the providers set in the Providers field.
// the Providers field is ignored.
UpstreamPlainAddresses []netip.AddrPort
}
@@ -69,33 +66,27 @@ func (d DNS) validate() (err error) {
ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod)
}
providers := provider.NewProviders()
if d.UpstreamType == DNSUpstreamTypePlain {
selectedHasPlainIPv4, selectedHasPlainIPv6 := false, false
for _, providerName := range d.Providers {
provider, err := providers.Get(providerName)
if err != nil {
return err
}
if !selectedHasPlainIPv4 && len(provider.Plain.IPv4) > 0 {
for _, addrPort := range d.UpstreamPlainAddresses {
if !selectedHasPlainIPv4 && addrPort.Addr().Is4() {
selectedHasPlainIPv4 = true
}
if !selectedHasPlainIPv6 && len(provider.Plain.IPv6) > 0 {
if !selectedHasPlainIPv6 && addrPort.Addr().Is6() {
selectedHasPlainIPv6 = true
}
if selectedHasPlainIPv4 && selectedHasPlainIPv6 {
break
}
if d.UpstreamType == DNSUpstreamTypePlain {
if *d.IPv6 && !selectedHasPlainIPv6 &&
!slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool {
return addrPort.Addr().Is6()
}) {
}
switch {
case *d.IPv6 && !selectedHasPlainIPv6:
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses))
} else if !selectedHasPlainIPv4 && !slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool {
return addrPort.Addr().Is4()
}) {
case !*d.IPv6 && !selectedHasPlainIPv4:
return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses))
}
}
// Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes
err = d.Blacklist.validate()
if err != nil {
@@ -113,7 +104,7 @@ func (d *DNS) Copy() (copied DNS) {
Caching: gosettings.CopyPointer(d.Caching),
IPv6: gosettings.CopyPointer(d.IPv6),
Blacklist: d.Blacklist.copy(),
UpstreamPlainAddresses: d.UpstreamPlainAddresses,
UpstreamPlainAddresses: gosettings.CopySlice(d.UpstreamPlainAddresses),
}
}
@@ -135,11 +126,7 @@ func (d *DNS) setDefaults() {
const defaultUpdatePeriod = 24 * time.Hour
d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod)
d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{})
defaultProviders := defaultDNSProviders()
if d.UpstreamType == DNSUpstreamTypePlain && len(d.UpstreamPlainAddresses) == 0 {
defaultProviders = []string{}
}
d.Providers = gosettings.DefaultSlice(d.Providers, defaultProviders)
d.Providers = gosettings.DefaultSlice(d.Providers, defaultDNSProviders())
d.Caching = gosettings.DefaultPointer(d.Caching, true)
d.IPv6 = gosettings.DefaultPointer(d.IPv6, false)
d.Blacklist.setDefaults()
@@ -208,6 +195,9 @@ func (d DNS) toLinesNode() (node *gotree.Node) {
}
} else {
node.Appendf("Upstream plain addresses: ignored because upstream type is not plain")
for _, provider := range d.Providers {
upstreamResolvers.Append(provider)
}
}
} else {
for _, provider := range d.Providers {
@@ -273,8 +263,8 @@ func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) {
// Retro-compatibility - remove in v4
// If DNS_ADDRESS is set to a non-localhost address, append it to the other
// upstream plain addresses, assuming port 53, and force the upstream type to plain AND
// clear any user picked providers, to maintain retro-compatibility behavior.
// upstream plain addresses, assuming port 53, and force the upstream type to plain
// to maintain retro-compatibility behavior.
serverAddress, err := r.NetipAddr("DNS_ADDRESS",
reader.RetroKeys("DNS_PLAINTEXT_ADDRESS"),
reader.IsRetro("DNS_UPSTREAM_PLAIN_ADDRESSES"))
@@ -291,6 +281,5 @@ func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) {
addrPort := netip.AddrPortFrom(serverAddress, defaultPlainPort)
d.UpstreamPlainAddresses = append(d.UpstreamPlainAddresses, addrPort)
d.UpstreamType = DNSUpstreamTypePlain
d.Providers = []string{}
return nil
}
+14 -14
View File
@@ -120,22 +120,23 @@ func buildServerSettings(userSettings settings.DNS,
func buildProviders(userSettings settings.DNS, localSubnets []netip.Prefix,
logger Logger,
) (providers []provider.Provider) {
providersCount := len(userSettings.Providers)
if userSettings.UpstreamType == settings.DNSUpstreamTypePlain {
providersCount += len(userSettings.UpstreamPlainAddresses)
}
providers = make([]provider.Provider, 0, providersCount)
userDefinedPlainAddresses := userSettings.UpstreamType == settings.DNSUpstreamTypePlain &&
len(userSettings.UpstreamPlainAddresses) > 0
if !userDefinedPlainAddresses {
providers = make([]provider.Provider, len(userSettings.Providers))
providersData := provider.NewProviders()
for _, providerName := range userSettings.Providers {
provider, err := providersData.Get(providerName)
for i, providerName := range userSettings.Providers {
var err error
providers[i], err = providersData.Get(providerName)
if err != nil {
panic(err) // this should already had been checked
}
providers = append(providers, provider)
}
return providers
}
for _, addrPort := range userSettings.UpstreamPlainAddresses {
providers = make([]provider.Provider, len(userSettings.UpstreamPlainAddresses))
for i, addrPort := range userSettings.UpstreamPlainAddresses {
addr := addrPort.Addr()
if addr.IsPrivate() && !addr.IsLoopback() &&
!slices.ContainsFunc(localSubnets, func(prefix netip.Prefix) bool {
@@ -146,15 +147,14 @@ func buildProviders(userSettings settings.DNS, localSubnets []netip.Prefix,
addr, netip.PrefixFrom(addr, addr.BitLen()))
}
provider := provider.Provider{
providers[i] = provider.Provider{
Name: addrPort.String(),
}
if addr.Is4() {
provider.Plain.IPv4 = []netip.AddrPort{addrPort}
providers[i].Plain.IPv4 = []netip.AddrPort{addrPort}
} else {
provider.Plain.IPv6 = []netip.AddrPort{addrPort}
providers[i].Plain.IPv6 = []netip.AddrPort{addrPort}
}
providers = append(providers, provider)
}
return providers