Skip to content

Commit

Permalink
Merge pull request #1191 from apernet/wip-sni-guard
Browse files Browse the repository at this point in the history
feat: local cert loader & sni guard
  • Loading branch information
tobyxdd authored Aug 25, 2024
2 parents 903666f + d4b9c5a commit 21ea2a0
Show file tree
Hide file tree
Showing 9 changed files with 581 additions and 18 deletions.
48 changes: 32 additions & 16 deletions app/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ type serverConfigObfs struct {
}

type serverConfigTLS struct {
Cert string `mapstructure:"cert"`
Key string `mapstructure:"key"`
Cert string `mapstructure:"cert"`
Key string `mapstructure:"key"`
SNIGuard string `mapstructure:"sniGuard"` // "disable", "dns-san", "strict"
}

type serverConfigACME struct {
Expand Down Expand Up @@ -291,30 +292,45 @@ func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error {
return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
}
if c.TLS != nil {
// SNI guard
var sniGuard utils.SNIGuardFunc
switch strings.ToLower(c.TLS.SNIGuard) {
case "", "dns-san":
sniGuard = utils.SNIGuardDNSSAN
case "strict":
sniGuard = utils.SNIGuardStrict
case "disable":
sniGuard = nil
default:
return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")}
}
// Local TLS cert
if c.TLS.Cert == "" || c.TLS.Key == "" {
return configError{Field: "tls", Err: errors.New("empty cert or key path")}
}
certLoader := &utils.LocalCertificateLoader{
CertFile: c.TLS.Cert,
KeyFile: c.TLS.Key,
SNIGuard: sniGuard,
}
// Try loading the cert-key pair here to catch errors early
// (e.g. invalid files or insufficient permissions)
certPEMBlock, err := os.ReadFile(c.TLS.Cert)
err := certLoader.InitializeCache()
if err != nil {
return configError{Field: "tls.cert", Err: err}
}
keyPEMBlock, err := os.ReadFile(c.TLS.Key)
if err != nil {
return configError{Field: "tls.key", Err: err}
}
_, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return configError{Field: "tls", Err: fmt.Errorf("invalid cert-key pair: %w", err)}
var pathErr *os.PathError
if errors.As(err, &pathErr) {
if pathErr.Path == c.TLS.Cert {
return configError{Field: "tls.cert", Err: pathErr}
}
if pathErr.Path == c.TLS.Key {
return configError{Field: "tls.key", Err: pathErr}
}
}
return configError{Field: "tls", Err: err}
}
// Use GetCertificate instead of Certificates so that
// users can update the cert without restarting the server.
hyConfig.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
return &cert, err
}
hyConfig.TLSConfig.GetCertificate = certLoader.GetCertificate
} else {
// ACME
dataDir := c.ACME.Dir
Expand Down
5 changes: 3 additions & 2 deletions app/cmd/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ func TestServerConfig(t *testing.T) {
},
},
TLS: &serverConfigTLS{
Cert: "some.crt",
Key: "some.key",
Cert: "some.crt",
Key: "some.key",
SNIGuard: "strict",
},
ACME: &serverConfigACME{
Domains: []string{
Expand Down
1 change: 1 addition & 0 deletions app/cmd/server_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ obfs:
tls:
cert: some.crt
key: some.key
sniGuard: strict

acme:
domains:
Expand Down
198 changes: 198 additions & 0 deletions app/internal/utils/certloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package utils

import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"
)

type LocalCertificateLoader struct {
CertFile string
KeyFile string
SNIGuard SNIGuardFunc

lock sync.Mutex
cache atomic.Pointer[localCertificateCache]
}

type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error

// localCertificateCache holds the certificate and its mod times.
// this struct is designed to be read-only.
//
// to update the cache, use LocalCertificateLoader.makeCache and
// update the LocalCertificateLoader.cache field.
type localCertificateCache struct {
certificate *tls.Certificate
certModTime time.Time
keyModTime time.Time
}

func (l *LocalCertificateLoader) InitializeCache() error {
l.lock.Lock()
defer l.lock.Unlock()

cache, err := l.makeCache()
if err != nil {
return err
}

l.cache.Store(cache)
return nil
}

func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := l.getCertificateWithCache()
if err != nil {
return nil, err
}

if l.SNIGuard == nil {
return cert, nil
}
err = l.SNIGuard(info, cert)
if err != nil {
return nil, err
}

return cert, nil
}

func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
fi, err := os.Stat(l.CertFile)
if err != nil {
err = fmt.Errorf("failed to stat certificate file: %w", err)
return
}
certModTime = fi.ModTime()

fi, err = os.Stat(l.KeyFile)
if err != nil {
err = fmt.Errorf("failed to stat key file: %w", err)
return
}
keyModTime = fi.ModTime()
return
}

func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) {
c := &localCertificateCache{}

c.certModTime, c.keyModTime, err = l.checkModTime()
if err != nil {
return
}

cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile)
if err != nil {
return
}
c.certificate = &cert
if c.certificate.Leaf == nil {
// certificate.Leaf was left nil by tls.LoadX509KeyPair before Go 1.23
c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return
}
}

cache = c
return
}

func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
cache := l.cache.Load()

certModTime, keyModTime, terr := l.checkModTime()
if terr != nil {
if cache != nil {
// use cache when file is temporarily unavailable
return cache.certificate, nil
}
return nil, terr
}

if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) {
// cache is up-to-date
return cache.certificate, nil
}

if cache != nil {
if !l.lock.TryLock() {
// another goroutine is updating the cache
return cache.certificate, nil
}
} else {
l.lock.Lock()
}
defer l.lock.Unlock()

if l.cache.Load() != cache {
// another goroutine updated the cache
return l.cache.Load().certificate, nil
}

newCache, err := l.makeCache()
if err != nil {
if cache != nil {
// use cache when loading failed
return cache.certificate, nil
}
return nil, err
}

l.cache.Store(newCache)
return newCache.certificate, nil
}

// getNameFromClientHello returns a normalized form of hello.ServerName.
// If hello.ServerName is empty (i.e. client did not use SNI), then the
// associated connection's local address is used to extract an IP address.
//
// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838
func getNameFromClientHello(hello *tls.ClientHelloInfo) string {
normalizedName := func(serverName string) string {
return strings.ToLower(strings.TrimSpace(serverName))
}
localIPFromConn := func(c net.Conn) string {
if c == nil {
return ""
}
localAddr := c.LocalAddr().String()
ip, _, err := net.SplitHostPort(localAddr)
if err != nil {
ip = localAddr
}
if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 {
ip = ip[:scopeIDStart]
}
return ip
}

if name := normalizedName(hello.ServerName); name != "" {
return name
}
return localIPFromConn(hello.Conn)
}

func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
if len(cert.Leaf.DNSNames) == 0 {
return nil
}
return SNIGuardStrict(info, cert)
}

func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
hostname := getNameFromClientHello(info)
err := cert.Leaf.VerifyHostname(hostname)
if err != nil {
return fmt.Errorf("sni guard: %w", err)
}
return nil
}
Loading

0 comments on commit 21ea2a0

Please sign in to comment.