diff --git a/internal/configuration/settings/dns.go b/internal/configuration/settings/dns.go index deb808bc..263155bc 100644 --- a/internal/configuration/settings/dns.go +++ b/internal/configuration/settings/dns.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/netip" - "slices" "time" "github.com/qdm12/dns/v2/pkg/provider" @@ -31,8 +30,8 @@ type DNS struct { // the internal state. UpdatePeriod *time.Duration // Providers is a list of DNS providers. - // It defaults to either ["cloudflare"] or [] if the - // UpstreamPlainAddresses field is set. + // It defaults to ["cloudflare"] and is ignored if the UpstreamType is + // [DNSUpstreamTypePlain] and the UpstreamPlainAddresses field is set. Providers []string `json:"providers"` // Caching is true if the server should cache // DNS responses. @@ -45,9 +44,7 @@ type DNS struct { // UpstreamPlainAddresses are the upstream plaintext DNS resolver // addresses to use by the built-in DNS server forwarder. // Note, if the upstream type is [dnsUpstreamTypePlain] and this field is set, - // the Providers field will default to the empty slice. If the Providers field - // is set by the user, then the content of this field will be merged together - // with the plain addresses of the providers set in the Providers field. + // the Providers field is ignored. UpstreamPlainAddresses []netip.AddrPort } @@ -69,33 +66,27 @@ func (d DNS) validate() (err error) { ErrDNSUpdatePeriodTooShort, *d.UpdatePeriod, minUpdatePeriod) } - providers := provider.NewProviders() - selectedHasPlainIPv4, selectedHasPlainIPv6 := false, false - for _, providerName := range d.Providers { - provider, err := providers.Get(providerName) - if err != nil { - return err - } - if !selectedHasPlainIPv4 && len(provider.Plain.IPv4) > 0 { - selectedHasPlainIPv4 = true - } - if !selectedHasPlainIPv6 && len(provider.Plain.IPv6) > 0 { - selectedHasPlainIPv6 = true - } - } - if d.UpstreamType == DNSUpstreamTypePlain { - if *d.IPv6 && !selectedHasPlainIPv6 && - !slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool { - return addrPort.Addr().Is6() - }) { + selectedHasPlainIPv4, selectedHasPlainIPv6 := false, false + for _, addrPort := range d.UpstreamPlainAddresses { + if !selectedHasPlainIPv4 && addrPort.Addr().Is4() { + selectedHasPlainIPv4 = true + } + if !selectedHasPlainIPv6 && addrPort.Addr().Is6() { + selectedHasPlainIPv6 = true + } + if selectedHasPlainIPv4 && selectedHasPlainIPv6 { + break + } + } + switch { + case *d.IPv6 && !selectedHasPlainIPv6: return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv6, len(d.UpstreamPlainAddresses)) - } else if !selectedHasPlainIPv4 && !slices.ContainsFunc(d.UpstreamPlainAddresses, func(addrPort netip.AddrPort) bool { - return addrPort.Addr().Is4() - }) { + case !*d.IPv6 && !selectedHasPlainIPv4: return fmt.Errorf("%w: in %d addresses", ErrDNSUpstreamPlainNoIPv4, len(d.UpstreamPlainAddresses)) } } + // Note: all DNS built in providers have both IPv4 and IPv6 addresses for all modes err = d.Blacklist.validate() if err != nil { @@ -113,7 +104,7 @@ func (d *DNS) Copy() (copied DNS) { Caching: gosettings.CopyPointer(d.Caching), IPv6: gosettings.CopyPointer(d.IPv6), Blacklist: d.Blacklist.copy(), - UpstreamPlainAddresses: d.UpstreamPlainAddresses, + UpstreamPlainAddresses: gosettings.CopySlice(d.UpstreamPlainAddresses), } } @@ -135,11 +126,7 @@ func (d *DNS) setDefaults() { const defaultUpdatePeriod = 24 * time.Hour d.UpdatePeriod = gosettings.DefaultPointer(d.UpdatePeriod, defaultUpdatePeriod) d.UpstreamPlainAddresses = gosettings.DefaultSlice(d.UpstreamPlainAddresses, []netip.AddrPort{}) - defaultProviders := defaultDNSProviders() - if d.UpstreamType == DNSUpstreamTypePlain && len(d.UpstreamPlainAddresses) == 0 { - defaultProviders = []string{} - } - d.Providers = gosettings.DefaultSlice(d.Providers, defaultProviders) + d.Providers = gosettings.DefaultSlice(d.Providers, defaultDNSProviders()) d.Caching = gosettings.DefaultPointer(d.Caching, true) d.IPv6 = gosettings.DefaultPointer(d.IPv6, false) d.Blacklist.setDefaults() @@ -208,6 +195,9 @@ func (d DNS) toLinesNode() (node *gotree.Node) { } } else { node.Appendf("Upstream plain addresses: ignored because upstream type is not plain") + for _, provider := range d.Providers { + upstreamResolvers.Append(provider) + } } } else { for _, provider := range d.Providers { @@ -273,8 +263,8 @@ func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) { // Retro-compatibility - remove in v4 // If DNS_ADDRESS is set to a non-localhost address, append it to the other - // upstream plain addresses, assuming port 53, and force the upstream type to plain AND - // clear any user picked providers, to maintain retro-compatibility behavior. + // upstream plain addresses, assuming port 53, and force the upstream type to plain + // to maintain retro-compatibility behavior. serverAddress, err := r.NetipAddr("DNS_ADDRESS", reader.RetroKeys("DNS_PLAINTEXT_ADDRESS"), reader.IsRetro("DNS_UPSTREAM_PLAIN_ADDRESSES")) @@ -291,6 +281,5 @@ func (d *DNS) readUpstreamPlainAddresses(r *reader.Reader) (err error) { addrPort := netip.AddrPortFrom(serverAddress, defaultPlainPort) d.UpstreamPlainAddresses = append(d.UpstreamPlainAddresses, addrPort) d.UpstreamType = DNSUpstreamTypePlain - d.Providers = []string{} return nil } diff --git a/internal/dns/settings.go b/internal/dns/settings.go index 60840e01..e28471a2 100644 --- a/internal/dns/settings.go +++ b/internal/dns/settings.go @@ -120,22 +120,23 @@ func buildServerSettings(userSettings settings.DNS, func buildProviders(userSettings settings.DNS, localSubnets []netip.Prefix, logger Logger, ) (providers []provider.Provider) { - providersCount := len(userSettings.Providers) - if userSettings.UpstreamType == settings.DNSUpstreamTypePlain { - providersCount += len(userSettings.UpstreamPlainAddresses) - } - providers = make([]provider.Provider, 0, providersCount) - - providersData := provider.NewProviders() - for _, providerName := range userSettings.Providers { - provider, err := providersData.Get(providerName) - if err != nil { - panic(err) // this should already had been checked + userDefinedPlainAddresses := userSettings.UpstreamType == settings.DNSUpstreamTypePlain && + len(userSettings.UpstreamPlainAddresses) > 0 + if !userDefinedPlainAddresses { + providers = make([]provider.Provider, len(userSettings.Providers)) + providersData := provider.NewProviders() + for i, providerName := range userSettings.Providers { + var err error + providers[i], err = providersData.Get(providerName) + if err != nil { + panic(err) // this should already had been checked + } } - providers = append(providers, provider) + return providers } - for _, addrPort := range userSettings.UpstreamPlainAddresses { + providers = make([]provider.Provider, len(userSettings.UpstreamPlainAddresses)) + for i, addrPort := range userSettings.UpstreamPlainAddresses { addr := addrPort.Addr() if addr.IsPrivate() && !addr.IsLoopback() && !slices.ContainsFunc(localSubnets, func(prefix netip.Prefix) bool { @@ -146,15 +147,14 @@ func buildProviders(userSettings settings.DNS, localSubnets []netip.Prefix, addr, netip.PrefixFrom(addr, addr.BitLen())) } - provider := provider.Provider{ + providers[i] = provider.Provider{ Name: addrPort.String(), } if addr.Is4() { - provider.Plain.IPv4 = []netip.AddrPort{addrPort} + providers[i].Plain.IPv4 = []netip.AddrPort{addrPort} } else { - provider.Plain.IPv6 = []netip.AddrPort{addrPort} + providers[i].Plain.IPv6 = []netip.AddrPort{addrPort} } - providers = append(providers, provider) } return providers