Skip to content

Commit

Permalink
Move to using variable for group memebership, fix crash when prev kv …
Browse files Browse the repository at this point in the history
…is empty
  • Loading branch information
NHAS committed May 8, 2024
1 parent 08b7446 commit 5144115
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 9 deletions.
18 changes: 17 additions & 1 deletion internal/data/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"

"github.com/NHAS/wag/internal/acls"
"github.com/NHAS/wag/internal/config"
Expand Down Expand Up @@ -81,9 +82,10 @@ func GetEffectiveAcl(username string) acls.Acl {
resultingACLs.Allow = []string{config.Values.Wireguard.ServerAddress.String() + "/32"}

txn := etcd.Txn(context.Background())
txn.Then(clientv3.OpGet("wag-acls-*"), clientv3.OpGet("wag-acls-"+username), clientv3.OpGet("wag-membership"), clientv3.OpGet(dnsKey))
txn.Then(clientv3.OpGet("wag-acls-*"), clientv3.OpGet("wag-acls-"+username), clientv3.OpGet(MembershipKey), clientv3.OpGet(dnsKey))
resp, err := txn.Commit()
if err != nil {
log.Println("failed to get policy data for user", username, "err:", err)
return acls.Acl{}
}

Expand All @@ -95,6 +97,9 @@ func GetEffectiveAcl(username string) acls.Acl {
if err == nil {
resultingACLs.Allow = append(resultingACLs.Allow, acl.Allow...)
resultingACLs.Mfa = append(resultingACLs.Mfa, acl.Mfa...)
} else {
RaiseError(err, []byte("failed to unmarshal default acls policy"))
log.Println("failed to unmarshal default acls policy: ", err)
}
}

Expand All @@ -106,6 +111,8 @@ func GetEffectiveAcl(username string) acls.Acl {
if err == nil {
resultingACLs.Allow = append(resultingACLs.Allow, acl.Allow...)
resultingACLs.Mfa = append(resultingACLs.Mfa, acl.Mfa...)
} else {
log.Println("failed to unmarshal user specific acls: ", err)
}
}

Expand All @@ -115,6 +122,8 @@ func GetEffectiveAcl(username string) acls.Acl {

err = json.Unmarshal(resp.Responses[2].GetResponseRange().Kvs[0].Value, &rGroupLookup)
if err == nil {
log.Println("DEBUG got reverse groups map", rGroupLookup)
log.Println("DEBUG got groups map for user", username, rGroupLookup[username])

txn := etcd.Txn(context.Background())

Expand All @@ -126,6 +135,8 @@ func GetEffectiveAcl(username string) acls.Acl {

resp, err := txn.Then(ops...).Commit()
if err != nil {
log.Println("failed to get acls for groups: ", err)
RaiseError(err, []byte("failed to determine acls from groups"))
return acls.Acl{}
}

Expand All @@ -137,9 +148,12 @@ func GetEffectiveAcl(username string) acls.Acl {

err := json.Unmarshal(r.Kvs[0].Value, &acl)
if err != nil {
log.Println("failed to unmarshal acl from response: ", err, string(r.Kvs[0].Value))
continue
}

log.Println("DEBUG user", username, " acls construction, adding: ", acl)

resultingACLs.Allow = append(resultingACLs.Allow, acl.Allow...)
resultingACLs.Mfa = append(resultingACLs.Mfa, acl.Mfa...)
}
Expand All @@ -158,6 +172,8 @@ func GetEffectiveAcl(username string) acls.Acl {
for _, server := range dns {
resultingACLs.Allow = append(resultingACLs.Allow, fmt.Sprintf("%s 53/any", server))
}
} else {
log.Println("failed to unmarshal dns setting: ", err)
}
}

Expand Down
2 changes: 2 additions & 0 deletions internal/data/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ const (

externalAddressKey = "wag-config-network-external-address"
dnsKey = "wag-config-network-dns"

MembershipKey = "wag-membership"
)

func getString(key string) (ret string, err error) {
Expand Down
13 changes: 10 additions & 3 deletions internal/data/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/NHAS/wag/pkg/queue"
"go.etcd.io/etcd/api/v3/mvccpb"
clientv3 "go.etcd.io/etcd/client/v3"
)

Expand Down Expand Up @@ -110,7 +111,7 @@ func RegisterEventListener[T any](path string, isPrefix bool, f func(key string,
}
}

go func(key, previous []byte) {
go func(key []byte, prevKv *mvccpb.KeyValue) {
if err := f(string(key), currentValue, previousValue, state); err != nil {
log.Println("applying event failed: ", state, currentValue, "err:", err)
err = RaiseError(err, value)
Expand All @@ -126,11 +127,17 @@ func RegisterEventListener[T any](path string, isPrefix bool, f func(key string,
EventsQueue.Write([]byte(fmt.Sprintf("%s[%s]", key, state)))

case MODIFIED:
EventsQueue.Write([]byte(fmt.Sprintf("%s[%s]: %s -> %s", key, state, string(previous), string(value))))

previous := "nil"
if prevKv != nil {
previous = string(prevKv.Value)
}

EventsQueue.Write([]byte(fmt.Sprintf("%s[%s]: %s -> %s", key, state, previous, string(value))))

}

}(event.Kv.Key, event.PrevKv.Value)
}(event.Kv.Key, event.PrevKv)

}
}
Expand Down
8 changes: 4 additions & 4 deletions internal/data/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func SetGroup(group string, members []string, overwrite bool) error {
}
}

err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, err error) {
err = doSafeUpdate(context.Background(), MembershipKey, func(gr *clientv3.GetResponse) (value string, err error) {

if len(gr.Kvs) != 1 {
return "", errors.New("bad number of membership keys")
Expand Down Expand Up @@ -112,7 +112,7 @@ func RemoveGroup(groupName string) error {
}
}

err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, err error) {
err = doSafeUpdate(context.Background(), MembershipKey, func(gr *clientv3.GetResponse) (value string, err error) {

if len(gr.Kvs) != 1 {
return "", errors.New("bad number of membership keys")
Expand All @@ -138,7 +138,7 @@ func RemoveGroup(groupName string) error {

func GetUserGroupMembership(username string) ([]string, error) {

response, err := etcd.Get(context.Background(), "wag-membership")
response, err := etcd.Get(context.Background(), MembershipKey)
if err != nil {
return nil, err
}
Expand All @@ -159,7 +159,7 @@ func GetUserGroupMembership(username string) ([]string, error) {

func SetUserGroupMembership(username string, newGroups []string) error {

err := doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, err error) {
err := doSafeUpdate(context.Background(), MembershipKey, func(gr *clientv3.GetResponse) (value string, err error) {

if len(gr.Kvs) != 1 {
return "", errors.New("bad number of membership keys")
Expand Down
2 changes: 1 addition & 1 deletion internal/data/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func loadInitialSettings() error {
}

reverseMappingJson, _ := json.Marshal(rGroupLookup)
_, err = etcd.Put(context.Background(), "wag-membership", string(reverseMappingJson))
_, err = etcd.Put(context.Background(), MembershipKey, string(reverseMappingJson))
if err != nil {
return err
}
Expand Down
32 changes: 32 additions & 0 deletions internal/routetypes/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,35 @@ func TestParseRules(t *testing.T) {
}

}

func TestParseRulesDuplicates(t *testing.T) {
/*
"*": {
"Allow": [
"7.7.7.7",
"google.com"
]
},
"tester": {
"Mfa": [
"192.168.3.0/24",
"192.168.5.0/24"
],
"Allow": [
"4.3.3.3/32"
]
},
*/

mfaRules := []string{"192.168.33.1/32", "192.168.33.1/32"}

result, err := ParseRules(mfaRules, []string{}, []string{})
if err != nil {
t.Fatal(err)
}

if len(result) < 1 {
t.Fatal("resulting number of rules was wrong")
}

}

0 comments on commit 5144115

Please sign in to comment.