Skip to content

Commit

Permalink
perf: check internet with DNS lookup (#1034)
Browse files Browse the repository at this point in the history
DNS lookups are much cheaper than HTTP requests since we only need to check if the Internet is available.

See: https://stackoverflow.com/a/50058255
  • Loading branch information
WaterLemons2k authored Mar 7, 2024
1 parent ea85383 commit ae0f47f
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 90 deletions.
4 changes: 2 additions & 2 deletions dns/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"time"

"github.com/jeessy2/ddns-go/v6/config"
"github.com/jeessy2/ddns-go/v6/dns/internal"
"github.com/jeessy2/ddns-go/v6/dns/internet"
"github.com/jeessy2/ddns-go/v6/util"
)

Expand Down Expand Up @@ -34,7 +34,7 @@ var (

// RunTimer 定时运行
func RunTimer(delay time.Duration) {
internal.WaitForNetworkConnected(addresses)
internet.Wait(addresses)

for {
RunOnce()
Expand Down
48 changes: 0 additions & 48 deletions dns/internal/wait_net.go

This file was deleted.

56 changes: 56 additions & 0 deletions dns/internet/wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Package internet implements utilities for checking the Internet connection.
package internet

import (
"strings"
"time"

"github.com/jeessy2/ddns-go/v6/util"
)

const (
// fallbackDNS used when a fallback occurs.
fallbackDNS = "1.1.1.1"

// delay is the delay time for each DNS lookup.
delay = time.Second * 5
)

// Wait blocks until the Internet is connected.
//
// See also:
//
// - https://stackoverflow.com/a/50058255
// - https://github.com/ddev/ddev/blob/v1.22.7/pkg/globalconfig/global_config.go#L776
func Wait(addresses []string) {
// fallbase in case loopback DNS is unavailable and only once.
fallback := false

for {
for _, addr := range addresses {
err := util.LookupHost(addr)
// Internet is connected.
if err == nil {
return
}

if !fallback && isLoopback(err) {
util.Log("本机DNS异常! 将默认使用 %s, 可参考文档通过 -dns 自定义 DNS 服务器", fallbackDNS)
util.SetDNS(fallbackDNS)

fallback = true
continue
}

util.Log("等待网络连接: %s", err)

util.Log("%s 后重试...", delay)
time.Sleep(delay)
}
}
}

// isLoopback checks if the error is a loopback error.
func isLoopback(e error) bool {
return strings.Contains(e.Error(), "[::1]:53")
}
7 changes: 1 addition & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"os/exec"
"path/filepath"
"strconv"
"strings"
"time"

"github.com/jeessy2/ddns-go/v6/config"
Expand Down Expand Up @@ -85,11 +84,7 @@ func main() {
util.SetInsecureSkipVerify()
}
if *customDNSServer != "" {
if !strings.Contains(*customDNSServer, ":") {
util.NewDialerResolver(*customDNSServer + ":53")
} else {
util.NewDialerResolver(*customDNSServer)
}
util.SetDNS(*customDNSServer)
}
os.Setenv(util.IPCacheTimesENV, strconv.Itoa(*ipCacheTimes))
switch *serviceType {
Expand Down
31 changes: 16 additions & 15 deletions util/net_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,26 @@ import (
"net"
)

// NewDialerResolver 使用 s 将 dialer.Resolver 设置为新的 net.Resolver。
//
// s:用于创建新 net.Resolver 的字符串。
func NewDialerResolver(s string) {
dialer.Resolver = newNetResolver(s)
}

// newNetResolver 当 s 不为空时返回使用 s 的 Go 内置 DNS 解析器。
//
// s:net.Resolver 的 DNS 服务器地址。
func newNetResolver(s string) *net.Resolver {
if s == "" {
return net.DefaultResolver
// SetDNS sets the dialer.Resolver to use the given DNS server.
func SetDNS(dns string) {
// Error means that the given DNS doesn't have a port. Add it.
if _, _, err := net.SplitHostPort(dns); err != nil {
dns = net.JoinHostPort(dns, "53")
}

return &net.Resolver{
dialer.Resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial("udp", s)
return net.Dial(network, dns)
},
}
}

// LookupHost looks up the host based on the given URL using the dialer.Resolver.
// A wrapper for [net.Resolver.LookupHost].
func LookupHost(url string) error {
name := toHostname(url)

_, err := dialer.Resolver.LookupHost(context.Background(), name)
return err
}
44 changes: 27 additions & 17 deletions util/net_resolver_test.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,38 @@
package util

import (
"context"
"testing"
import "testing"

const (
testDNS = "1.1.1.1"
testURL = "https://cloudflare.com"
)

// TestNewDialerResolver 测试传递 DNS 服务器地址时能否设置 dialer.Resolver。
func TestNewDialerResolver(t *testing.T) {
// 测试前重置以确保正常设置
dialer.Resolver = nil
func TestSetDNS(t *testing.T) {
SetDNS(testDNS)

NewDialerResolver("1.1.1.1:53")
if dialer.Resolver == nil {
t.Error("Failed to set dialer.Resolver")
}

// 测试后重置以确保与测试前的值一致
dialer.Resolver = nil
}

// TestNewNetResolver 测试能否通过 newNetResolver 返回的 net.Resolver 解析域名的 IP。
func TestNewNetResolver(t *testing.T) {
_, err := newNetResolver("1.1.1.1:53").LookupIP(context.Background(), "ip", "cloudflare.com")
if err != nil {
t.Errorf("Failed to lookup IP, err: %v", err)
}
func TestLookupHost(t *testing.T) {
t.Run("Valid URL", func(t *testing.T) {
if err := LookupHost(testURL); err != nil {
t.Errorf("Expected nil error, got %v", err)
}
})

t.Run("Invalid URL", func(t *testing.T) {
if err := LookupHost("invalidurl"); err == nil {
t.Error("Expected error, got nil")
}
})

t.Run("After SetDNS", func(t *testing.T) {
SetDNS(testDNS)

if err := LookupHost(testURL); err != nil {
t.Errorf("Expected nil error, got %v", err)
}
})
}
15 changes: 13 additions & 2 deletions util/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ package util

import "strings"

// WriteString 使用 strings.Builder 生成字符串并返回 string
// https://pkg.go.dev/strings#Builder
// WriteString creates a new string using [strings.Builder].
func WriteString(strs ...string) string {
var b strings.Builder
for _, str := range strs {
Expand All @@ -12,3 +11,15 @@ func WriteString(strs ...string) string {

return b.String()
}

// toHostname normalizes a URL with a https scheme to just its hostname.
//
// See also:
//
// - https://github.com/moby/moby/blob/v25.0.3/registry/auth.go#L132
func toHostname(url string) string {
stripped := url
stripped = strings.TrimPrefix(stripped, "https://")

return strings.Split(stripped, "/")[0]
}
43 changes: 43 additions & 0 deletions util/string_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package util

import "testing"

func TestWriteString(t *testing.T) {
tests := []struct {
input []string
expected string
}{
{[]string{"hello", "world"}, "helloworld"},
{[]string{"", "test"}, "test"},
{[]string{"hello", " ", "world"}, "hello world"},
{[]string{""}, ""},
}

for _, tt := range tests {
result := WriteString(tt.input...)
if result != tt.expected {
t.Errorf("Expected %s, but got %s", tt.expected, result)
}
}
}

func TestToHostname(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"With https scheme", "https://www.example.com", "www.example.com"},
{"With path", "www.example.com/path", "www.example.com"},
{"With https scheme and path", "https://www.example.com/path", "www.example.com"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := toHostname(tt.input)
if result != tt.expected {
t.Errorf("Expected %s, but got %s", tt.expected, result)
}
})
}
}

0 comments on commit ae0f47f

Please sign in to comment.