-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1191 from apernet/wip-sni-guard
feat: local cert loader & sni guard
- Loading branch information
Showing
9 changed files
with
581 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ obfs: | |
tls: | ||
cert: some.crt | ||
key: some.key | ||
sniGuard: strict | ||
|
||
acme: | ||
domains: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
package utils | ||
|
||
import ( | ||
"crypto/tls" | ||
"crypto/x509" | ||
"fmt" | ||
"net" | ||
"os" | ||
"strings" | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
) | ||
|
||
type LocalCertificateLoader struct { | ||
CertFile string | ||
KeyFile string | ||
SNIGuard SNIGuardFunc | ||
|
||
lock sync.Mutex | ||
cache atomic.Pointer[localCertificateCache] | ||
} | ||
|
||
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error | ||
|
||
// localCertificateCache holds the certificate and its mod times. | ||
// this struct is designed to be read-only. | ||
// | ||
// to update the cache, use LocalCertificateLoader.makeCache and | ||
// update the LocalCertificateLoader.cache field. | ||
type localCertificateCache struct { | ||
certificate *tls.Certificate | ||
certModTime time.Time | ||
keyModTime time.Time | ||
} | ||
|
||
func (l *LocalCertificateLoader) InitializeCache() error { | ||
l.lock.Lock() | ||
defer l.lock.Unlock() | ||
|
||
cache, err := l.makeCache() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
l.cache.Store(cache) | ||
return nil | ||
} | ||
|
||
func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { | ||
cert, err := l.getCertificateWithCache() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if l.SNIGuard == nil { | ||
return cert, nil | ||
} | ||
err = l.SNIGuard(info, cert) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return cert, nil | ||
} | ||
|
||
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) { | ||
fi, err := os.Stat(l.CertFile) | ||
if err != nil { | ||
err = fmt.Errorf("failed to stat certificate file: %w", err) | ||
return | ||
} | ||
certModTime = fi.ModTime() | ||
|
||
fi, err = os.Stat(l.KeyFile) | ||
if err != nil { | ||
err = fmt.Errorf("failed to stat key file: %w", err) | ||
return | ||
} | ||
keyModTime = fi.ModTime() | ||
return | ||
} | ||
|
||
func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) { | ||
c := &localCertificateCache{} | ||
|
||
c.certModTime, c.keyModTime, err = l.checkModTime() | ||
if err != nil { | ||
return | ||
} | ||
|
||
cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile) | ||
if err != nil { | ||
return | ||
} | ||
c.certificate = &cert | ||
if c.certificate.Leaf == nil { | ||
// certificate.Leaf was left nil by tls.LoadX509KeyPair before Go 1.23 | ||
c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) | ||
if err != nil { | ||
return | ||
} | ||
} | ||
|
||
cache = c | ||
return | ||
} | ||
|
||
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) { | ||
cache := l.cache.Load() | ||
|
||
certModTime, keyModTime, terr := l.checkModTime() | ||
if terr != nil { | ||
if cache != nil { | ||
// use cache when file is temporarily unavailable | ||
return cache.certificate, nil | ||
} | ||
return nil, terr | ||
} | ||
|
||
if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) { | ||
// cache is up-to-date | ||
return cache.certificate, nil | ||
} | ||
|
||
if cache != nil { | ||
if !l.lock.TryLock() { | ||
// another goroutine is updating the cache | ||
return cache.certificate, nil | ||
} | ||
} else { | ||
l.lock.Lock() | ||
} | ||
defer l.lock.Unlock() | ||
|
||
if l.cache.Load() != cache { | ||
// another goroutine updated the cache | ||
return l.cache.Load().certificate, nil | ||
} | ||
|
||
newCache, err := l.makeCache() | ||
if err != nil { | ||
if cache != nil { | ||
// use cache when loading failed | ||
return cache.certificate, nil | ||
} | ||
return nil, err | ||
} | ||
|
||
l.cache.Store(newCache) | ||
return newCache.certificate, nil | ||
} | ||
|
||
// getNameFromClientHello returns a normalized form of hello.ServerName. | ||
// If hello.ServerName is empty (i.e. client did not use SNI), then the | ||
// associated connection's local address is used to extract an IP address. | ||
// | ||
// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838 | ||
func getNameFromClientHello(hello *tls.ClientHelloInfo) string { | ||
normalizedName := func(serverName string) string { | ||
return strings.ToLower(strings.TrimSpace(serverName)) | ||
} | ||
localIPFromConn := func(c net.Conn) string { | ||
if c == nil { | ||
return "" | ||
} | ||
localAddr := c.LocalAddr().String() | ||
ip, _, err := net.SplitHostPort(localAddr) | ||
if err != nil { | ||
ip = localAddr | ||
} | ||
if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 { | ||
ip = ip[:scopeIDStart] | ||
} | ||
return ip | ||
} | ||
|
||
if name := normalizedName(hello.ServerName); name != "" { | ||
return name | ||
} | ||
return localIPFromConn(hello.Conn) | ||
} | ||
|
||
func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error { | ||
if len(cert.Leaf.DNSNames) == 0 { | ||
return nil | ||
} | ||
return SNIGuardStrict(info, cert) | ||
} | ||
|
||
func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error { | ||
hostname := getNameFromClientHello(info) | ||
err := cert.Leaf.VerifyHostname(hostname) | ||
if err != nil { | ||
return fmt.Errorf("sni guard: %w", err) | ||
} | ||
return nil | ||
} |
Oops, something went wrong.