diff --git a/commands/start.go b/commands/start.go index 2da723cc..4446ec34 100644 --- a/commands/start.go +++ b/commands/start.go @@ -88,7 +88,7 @@ func (g *start) Check() error { } } - err = data.Load(config.Values.DatabaseLocation, g.clusterJoinToken) + err = data.Load(config.Values.DatabaseLocation, g.clusterJoinToken, false) if err != nil { return fmt.Errorf("cannot load database: %v", err) } diff --git a/internal/config/test_disabled_max_lifetime.json b/internal/config/test_disabled_max_lifetime.json index 8418001a..34dfb1dd 100644 --- a/internal/config/test_disabled_max_lifetime.json +++ b/internal/config/test_disabled_max_lifetime.json @@ -13,6 +13,12 @@ "Port": "8080" } }, + "Clustering": { + "ClusterState": "new", + "ETCDLogLevel": "error", + "Witness": false, + "TLSManagerListenURL": "https://localhost:3434" + }, "Authenticators": { "Issuer": "192.168.121.61" }, diff --git a/internal/config/test_disabled_sliding_window.json b/internal/config/test_disabled_sliding_window.json index af110447..5d66ebd6 100644 --- a/internal/config/test_disabled_sliding_window.json +++ b/internal/config/test_disabled_sliding_window.json @@ -16,6 +16,12 @@ "Authenticators": { "Issuer": "192.168.121.61" }, + "Clustering": { + "ClusterState": "new", + "ETCDLogLevel": "error", + "Witness": false, + "TLSManagerListenURL": "https://localhost:3434" + }, "Wireguard": { "DevName": "wg0", "ListenPort": 53230, diff --git a/internal/config/test_fail_with_multiple.json b/internal/config/test_fail_with_multiple.json index 63bf8803..89556d90 100644 --- a/internal/config/test_fail_with_multiple.json +++ b/internal/config/test_fail_with_multiple.json @@ -16,6 +16,12 @@ "Authenticators": { "Issuer": "192.168.121.61" }, + "Clustering": { + "ClusterState": "new", + "ETCDLogLevel": "error", + "Witness": false, + "TLSManagerListenURL": "https://localhost:3434" + }, "Wireguard": { "DevName": "wg45", "ListenPort": 53230, diff --git a/internal/config/test_in_memory_db.json b/internal/config/test_in_memory_db.json index e1ca5f4a..6c5d5a21 100644 --- a/internal/config/test_in_memory_db.json +++ b/internal/config/test_in_memory_db.json @@ -16,6 +16,12 @@ "Authenticators": { "Issuer": "192.168.121.61" }, + "Clustering": { + "ClusterState": "new", + "ETCDLogLevel": "error", + "Witness": false, + "TLSManagerListenURL": "https://localhost:3434" + }, "Wireguard": { "DevName": "wg45", "ListenPort": 53230, diff --git a/internal/config/test_mutliple_rule_definitions_and_mfa_preference.json b/internal/config/test_mutliple_rule_definitions_and_mfa_preference.json index a2e73aa1..3cdf7bb2 100644 --- a/internal/config/test_mutliple_rule_definitions_and_mfa_preference.json +++ b/internal/config/test_mutliple_rule_definitions_and_mfa_preference.json @@ -16,6 +16,12 @@ "Authenticators": { "Issuer": "192.168.121.61" }, + "Clustering": { + "ClusterState": "new", + "ETCDLogLevel": "error", + "Witness": false, + "TLSManagerListenURL": "https://localhost:3434" + }, "Wireguard": { "DevName": "wg45", "ListenPort": 53230, diff --git a/internal/config/test_port_based_rules.json b/internal/config/test_port_based_rules.json index 391a0b1f..69ce0d53 100644 --- a/internal/config/test_port_based_rules.json +++ b/internal/config/test_port_based_rules.json @@ -16,6 +16,12 @@ "Authenticators": { "Issuer": "192.168.121.61" }, + "Clustering": { + "ClusterState": "new", + "ETCDLogLevel": "error", + "Witness": false, + "TLSManagerListenURL": "https://localhost:3434" + }, "Wireguard": { "DevName": "wg45", "ListenPort": 53230, diff --git a/internal/config/test_roaming_all_routes_mfa_priority.json b/internal/config/test_roaming_all_routes_mfa_priority.json index fec116ef..ef4967fd 100644 --- a/internal/config/test_roaming_all_routes_mfa_priority.json +++ b/internal/config/test_roaming_all_routes_mfa_priority.json @@ -16,6 +16,12 @@ "Authenticators": { "Issuer": "192.168.121.61" }, + "Clustering": { + "ClusterState": "new", + "ETCDLogLevel": "error", + "Witness": false, + "TLSManagerListenURL": "https://localhost:3434" + }, "Wireguard": { "DevName": "wg45", "ListenPort": 53230, diff --git a/internal/config/test_route_restriction_preference.json b/internal/config/test_route_restriction_preference.json index 8e7088a8..7a2049ec 100644 --- a/internal/config/test_route_restriction_preference.json +++ b/internal/config/test_route_restriction_preference.json @@ -16,6 +16,12 @@ "Authenticators": { "Issuer": "192.168.121.61" }, + "Clustering": { + "ClusterState": "new", + "ETCDLogLevel": "error", + "Witness": false, + "TLSManagerListenURL": "https://localhost:3434" + }, "Wireguard": { "DevName": "wg45", "ListenPort": 53230, diff --git a/internal/data/devices.go b/internal/data/devices.go index f795f39c..e8625bf9 100644 --- a/internal/data/devices.go +++ b/internal/data/devices.go @@ -25,6 +25,16 @@ type Device struct { Authorised time.Time } +func (d Device) String() string { + + authorised := "no" + if !d.Authorised.Equal(time.Time{}) { + authorised = d.Authorised.Format(time.DateTime) + } + + return fmt.Sprintf("device[%s:%s][active: %t, attempts: %d, authorised: %s]", d.Username, d.Address, d.Active, d.Attempts, authorised) +} + func UpdateDeviceEndpoint(address string, endpoint *net.UDPAddr) error { realKey, err := etcd.Get(context.Background(), "deviceref-"+address) @@ -36,10 +46,7 @@ func UpdateDeviceEndpoint(address string, endpoint *net.UDPAddr) error { return errors.New("device was not found") } - var realDeviceAddr string - json.Unmarshal(realKey.Kvs[0].Value, &realDeviceAddr) - - return doSafeUpdate(context.Background(), realDeviceAddr, func(gr *clientv3.GetResponse) (string, error) { + return doSafeUpdate(context.Background(), string(realKey.Kvs[0].Value), func(gr *clientv3.GetResponse) (string, error) { if len(gr.Kvs) != 1 { return "", errors.New("user device has multiple keys") } diff --git a/internal/data/events.go b/internal/data/events.go index c5544fc2..ece3df2e 100644 --- a/internal/data/events.go +++ b/internal/data/events.go @@ -111,7 +111,7 @@ func RegisterEventListener[T any](path string, isPrefix bool, f func(key string, go func(key []byte) { if err := f(string(key), currentValue, previousValue, state); err != nil { - log.Println("applying event failed: ", currentValue, "err:", err) + log.Println("applying event failed: ", state, currentValue, "err:", err) err = RaiseError(GetServerID(), err, value) if err != nil { log.Println("failed to raise error with cluster: ", err) @@ -180,6 +180,10 @@ func checkClusterHealth() { notifyHealthy() case <-time.After(1 * time.Second): + if etcdServer == nil { + return + } + leader := etcdServer.Server.Leader() if leader == 0 { notifyClusterHealthListeners("electing") diff --git a/internal/data/init.go b/internal/data/init.go index 3d5b7ab8..cccf9284 100644 --- a/internal/data/init.go +++ b/internal/data/init.go @@ -48,7 +48,7 @@ func parseUrls(values ...string) []url.URL { return urls } -func Load(path, joinToken string) error { +func Load(path, joinToken string, testing bool) error { doMigration := true if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { @@ -84,33 +84,35 @@ func Load(path, joinToken string) error { var err error - if joinToken == "" { - TLSManager, err = manager.New(config.Values.Clustering.TLSManagerStorage, config.Values.Clustering.TLSManagerListenURL) - if err != nil { - return fmt.Errorf("tls manager: %s", err) - } - } else { - - if config.Values.Clustering.TLSManagerStorage == "" { - config.Values.Clustering.TLSManagerStorage = "certificates" - } + if TLSManager == nil { + if joinToken == "" { + TLSManager, err = manager.New(config.Values.Clustering.TLSManagerStorage, config.Values.Clustering.TLSManagerListenURL) + if err != nil { + return fmt.Errorf("tls manager: %s", err) + } + } else { - TLSManager, err = manager.Join(joinToken, config.Values.Clustering.TLSManagerStorage, map[string]func(name string, data string){ - "config.json": func(name, data string) { - err := os.WriteFile("config.json", []byte(data), 0600) - if err != nil { - log.Fatal("failed to create config.json from other cluster members config: ", err) - } + if config.Values.Clustering.TLSManagerStorage == "" { + config.Values.Clustering.TLSManagerStorage = "certificates" + } - log.Println("got additional, loading config file") - err = config.Load("config.json") - if err != nil { - log.Fatal("config supplied by other cluster member was invalid (potential version issues?): ", err) - } - }, - }) - if err != nil { - return err + TLSManager, err = manager.Join(joinToken, config.Values.Clustering.TLSManagerStorage, map[string]func(name string, data string){ + "config.json": func(name, data string) { + err := os.WriteFile("config.json", []byte(data), 0600) + if err != nil { + log.Fatal("failed to create config.json from other cluster members config: ", err) + } + + log.Println("got additional, loading config file") + err = config.Load("config.json") + if err != nil { + log.Fatal("config supplied by other cluster member was invalid (potential version issues?): ", err) + } + }, + }) + if err != nil { + return err + } } } part, err := generateRandomBytes(10) @@ -121,8 +123,11 @@ func Load(path, joinToken string) error { cfg := embed.NewConfig() cfg.Name = config.Values.Clustering.Name + if testing { + cfg.Name += part + } cfg.ClusterState = config.Values.Clustering.ClusterState - cfg.InitialClusterToken = "wag-test" + cfg.InitialClusterToken = "wag" cfg.LogLevel = config.Values.Clustering.ETCDLogLevel cfg.ListenPeerUrls = parseUrls(config.Values.Clustering.ListenAddresses...) cfg.ListenClientUrls = parseUrls(etcdUnixSocket) @@ -149,7 +154,7 @@ func Load(path, joinToken string) error { cfg.InitialCluster = cfg.InitialCluster[:len(cfg.InitialCluster)-1] - cfg.Dir = filepath.Join(config.Values.Clustering.DatabaseLocation, config.Values.Clustering.Name+".wag-node.etcd") + cfg.Dir = filepath.Join(config.Values.Clustering.DatabaseLocation, cfg.Name+".wag-node.etcd") etcdServer, err = embed.StartEtcd(cfg) if err != nil { return fmt.Errorf("error starting etcd: %s", err) @@ -454,8 +459,11 @@ func migrateFromSql(database *sql.DB) error { func TearDown() { if etcdServer != nil { - log.Println("Tearing down server") + etcd.Close() etcdServer.Close() + + etcd = nil + etcdServer = nil } } diff --git a/internal/data/user.go b/internal/data/user.go index 89729af9..0821f21a 100644 --- a/internal/data/user.go +++ b/internal/data/user.go @@ -249,12 +249,7 @@ func DeleteUser(username string) error { return err } - _, err = etcd.Delete(context.Background(), "devices-"+username+"-", clientv3.WithPrefix()) - if err != nil { - return err - } - - return err + return DeleteDevices(username) } func GetUserData(username string) (u UserModel, err error) { diff --git a/internal/router/bpf.go b/internal/router/bpf.go index b941b93b..679412f1 100644 --- a/internal/router/bpf.go +++ b/internal/router/bpf.go @@ -527,7 +527,7 @@ func RemoveUser(username string) error { for address, publicKey := range usersToAddresses[username] { err = _removePeer(publicKey, address) if err != nil { - log.Println("unable to remove peer: ", err) + log.Println("unable to remove peer: ", address, err) } } diff --git a/internal/router/fwentry.go b/internal/router/fwentry.go index 8eeb5313..d89a6827 100644 --- a/internal/router/fwentry.go +++ b/internal/router/fwentry.go @@ -37,7 +37,7 @@ func (d fwentry) Bytes() []byte { func (d *fwentry) Unpack(b []byte) error { if len(b) != 40 { - return errors.New("too short") + return errors.New("firewall entry is too short") } d.sessionExpiry = binary.LittleEndian.Uint64(b[:8]) diff --git a/internal/router/statemachine.go b/internal/router/statemachine.go index 7c88c4e3..4d8ed511 100644 --- a/internal/router/statemachine.go +++ b/internal/router/statemachine.go @@ -11,25 +11,25 @@ import ( func handleEvents(erroChan chan<- error) { - _, err := data.RegisterEventListener[data.Device](data.DevicesPrefix, true, deviceChanges) + _, err := data.RegisterEventListener(data.DevicesPrefix, true, deviceChanges) if err != nil { erroChan <- err return } - _, err = data.RegisterEventListener[data.UserModel](data.UsersPrefix, true, userChanges) + _, err = data.RegisterEventListener(data.UsersPrefix, true, userChanges) if err != nil { erroChan <- err return } - _, err = data.RegisterEventListener[acls.Acl](data.AclsPrefix, true, aclsChanges) + _, err = data.RegisterEventListener(data.AclsPrefix, true, aclsChanges) if err != nil { erroChan <- err return } - _, err = data.RegisterEventListener[[]string](data.GroupsPrefix, true, groupChanges) + _, err = data.RegisterEventListener(data.GroupsPrefix, true, groupChanges) if err != nil { erroChan <- err return @@ -45,37 +45,37 @@ func handleEvents(erroChan chan<- error) { func deviceChanges(key string, current data.Device, previous data.Device, et data.EventType) error { - log.Printf("state: %d, event: %+v", et, current) - switch et { case data.DELETED: err := RemovePeer(current.Publickey, current.Address) if err != nil { - log.Println("could not remove peer: ", err) + return fmt.Errorf("unable to remove peer: %s: err: %s", current.Address, err) } + log.Println("removed peer: ", current.Address) case data.CREATED: key, _ := wgtypes.ParseKey(current.Publickey) err := AddPeer(key, current.Username, current.Address, current.PresharedKey) if err != nil { - log.Println("error creating peer: ", err) + return fmt.Errorf("unable to create peer: %s: err: %s", current.Address, err) } + log.Println("added peer: ", current.Address) + case data.MODIFIED: if current.Publickey != previous.Publickey { key, _ := wgtypes.ParseKey(current.Publickey) err := ReplacePeer(previous, key) if err != nil { - log.Println(err) - return err + return fmt.Errorf("failed to replace peer pub key: %s", err) } + log.Println("replaced peer public key: ", current.Address) } lockout, err := data.GetLockout() if err != nil { - log.Println("cannot get lockout:", err) - return err + return fmt.Errorf("cannot get lockout: %s", err) } if (current.Attempts != previous.Attempts && current.Attempts > lockout) || // If the number of authentication attempts on a device has exceeded the max @@ -83,18 +83,19 @@ func deviceChanges(key string, current data.Device, previous data.Device, et dat current.Authorised.IsZero() { // If we've explicitly deauthorised a device err := Deauthenticate(current.Address) if err != nil { - log.Println(err) - return err + return fmt.Errorf("cannot deauthenticate device %s: %s", current.Address, err) } + log.Println("deauthed device: ", current.Address) + } if current.Authorised != previous.Authorised { if !current.Authorised.IsZero() && current.Attempts <= lockout { err := SetAuthorized(current.Address, current.Username) if err != nil { - log.Println(err) - return err + return fmt.Errorf("cannot authorize device %s: %s", current.Address, err) } + log.Println("authorized device: ", current.Address) } } @@ -111,14 +112,14 @@ func userChanges(key string, current data.UserModel, previous data.UserModel, et acls := data.GetEffectiveAcl(current.Username) err := AddUser(current.Username, acls) if err != nil { - log.Println(err) - return err + log.Printf("cannot create user %s: %s", current.Username, err) + return fmt.Errorf("cannot create user %s: %s", current.Username, err) } case data.DELETED: err := RemoveUser(current.Username) if err != nil { - log.Println(err) - return err + log.Printf("cannot remove user %s: %s", current.Username, err) + return fmt.Errorf("cannot remove user %s: %s", current.Username, err) } case data.MODIFIED: @@ -131,16 +132,16 @@ func userChanges(key string, current data.UserModel, previous data.UserModel, et err := SetLockAccount(current.Username, lock) if err != nil { - log.Println(err) - return err + log.Printf("cannot lock user %s: %s", current.Username, err) + return fmt.Errorf("cannot lock user %s: %s", current.Username, err) } } if current.Mfa != previous.Mfa || current.MfaType != previous.MfaType { err := DeauthenticateAllDevices(current.Username) if err != nil { - log.Println(err) - return err + log.Printf("cannot deauthenticate user %s: %s", current.Username, err) + return fmt.Errorf("cannot deauthenticate user %s: %s", current.Username, err) } } @@ -154,7 +155,6 @@ func aclsChanges(key string, current acls.Acl, previous acls.Acl, et data.EventT case data.CREATED, data.DELETED, data.MODIFIED: err := RefreshConfiguration() if err != nil { - log.Println(err) return fmt.Errorf("failed to refresh configuration: %s", err) } @@ -170,8 +170,7 @@ func groupChanges(key string, current []string, previous []string, et data.Event for _, username := range current { err := RefreshUserAcls(username) if err != nil { - log.Println(err) - return err + return fmt.Errorf("failed to refresh acls for user %s: %s", username, err) } } diff --git a/internal/routetypes/key.go b/internal/routetypes/key.go index a76d08b7..881ff1a7 100644 --- a/internal/routetypes/key.go +++ b/internal/routetypes/key.go @@ -29,7 +29,7 @@ func (l Key) Bytes() []byte { func (l *Key) Unpack(b []byte) error { if len(b) != 8 { - return errors.New("too short") + return errors.New("firewall key too short") } l.Prefixlen = binary.LittleEndian.Uint32(b[:4]) diff --git a/internal/routetypes/policy.go b/internal/routetypes/policy.go index e77c4067..3bbf5aa0 100644 --- a/internal/routetypes/policy.go +++ b/internal/routetypes/policy.go @@ -56,7 +56,7 @@ func (r Policy) Bytes() []byte { func (r *Policy) Unpack(b []byte) error { if len(b) < 8 { - return errors.New("too short") + return errors.New("firewall policy is too short") } r.PolicyType = binary.LittleEndian.Uint16(b[0:]) diff --git a/internal/users/user_test.go b/internal/users/user_test.go index bd2f8195..ed8d3c7e 100644 --- a/internal/users/user_test.go +++ b/internal/users/user_test.go @@ -2,6 +2,8 @@ package users import ( "fmt" + "log" + "os" "testing" "github.com/NHAS/wag/internal/config" @@ -20,22 +22,35 @@ func setupWgTest() error { return err } - err = data.Load(fmt.Sprintf("file:%s?mode=memory&cache=shared", m.String()), "") + err = data.Load(fmt.Sprintf("file:%s?mode=memory&cache=shared", m.String()), "", true) if err != nil { return fmt.Errorf("cannot load database: %v", err) } errChan := make(chan error) + err = router.Setup(errChan, false) + return err +} - return router.Setup(errChan, false) +func teatDown() { + router.TearDown(true) + data.TearDown() } -func TestCreateUser(t *testing.T) { +func TestMain(m *testing.M) { + err := setupWgTest() if err != nil { - t.Fatalf("failed to setup wg: %s", err) + log.Println(err) + os.Exit(1) } - defer router.TearDown(false) + code := m.Run() + teatDown() + + os.Exit(code) +} + +func TestCreateUser(t *testing.T) { user, err := CreateUser("fronk") if err != nil { @@ -66,11 +81,6 @@ func TestCreateUser(t *testing.T) { } func TestAddDevice(t *testing.T) { - err := setupWgTest() - if err != nil { - t.Fatalf("failed to setup wg: %s", err) - } - defer router.TearDown(false) user, err := CreateUser("fronk") if err != nil { @@ -103,11 +113,6 @@ func TestAddDevice(t *testing.T) { } func TestDeleteDevice(t *testing.T) { - err := setupWgTest() - if err != nil { - t.Fatalf("failed to setup wg: %s", err) - } - defer router.TearDown(false) user, err := CreateUser("fronk") if err != nil { @@ -134,16 +139,16 @@ func TestDeleteDevice(t *testing.T) { t.Fatal("unable to get all devices:", err) } - if len(devices) != 0 { - t.Fatal("removed only device, should be no devices left in db") + for _, device := range devices { + if device.Publickey == pubkey.String() { + t.Fatal("device with matching public key was found in db") + return + } } + } func TestDeleteUser(t *testing.T) { - err := setupWgTest() - if err != nil { - t.Fatalf("failed to setup wg: %s", err) - } user, err := CreateUser("fronk") if err != nil { @@ -181,6 +186,8 @@ func TestDeleteUser(t *testing.T) { } if len(devices) != 0 { + t.Log(len(devices)) + t.Log(devices) t.Fatal("removed only user, should be no devices left in db") } @@ -192,4 +199,5 @@ func TestDeleteUser(t *testing.T) { if len(users) != 0 { t.Fatal("removed only user, should be no users left in db") } + }