diff --git a/auth/cra.go b/auth/cra.go index 9d0d2c1..913135e 100644 --- a/auth/cra.go +++ b/auth/cra.go @@ -11,6 +11,7 @@ import ( "golang.org/x/crypto/pbkdf2" "github.com/xconnio/wampproto-go/messages" + "github.com/xconnio/wampproto-go/util" ) const MethodCRA = "wampcra" @@ -54,8 +55,8 @@ func (a *craAuthenticator) Authenticate(challenge messages.Challenge) (*messages // If no salt given, use raw password as key. if saltStr != "" { // If salting info give, then compute a derived key using PBKDF2. - iters, _ := messages.AsInt64(challenge.Extra()["iterations"]) - keylen, _ := messages.AsInt64(challenge.Extra()["keylen"]) + iters, _ := util.AsInt64(challenge.Extra()["iterations"]) + keylen, _ := util.AsInt64(challenge.Extra()["keylen"]) rawSecret = DeriveCRAKey(saltStr, a.secret, int(iters), int(keylen)) } else { diff --git a/messages/hello.go b/messages/hello.go index 0025449..b4e10e6 100644 --- a/messages/hello.go +++ b/messages/hello.go @@ -2,6 +2,8 @@ package messages import ( "fmt" + + "github.com/xconnio/wampproto-go/util" ) const MessageTypeHello = 1 @@ -94,7 +96,7 @@ func (h *Hello) Parse(wampMsg []any) error { roles = map[string]any{} } - authMethodsStr, err := AnysToStrings(authMethods) + authMethodsStr, err := util.AnysToStrings(authMethods) if err != nil { return fmt.Errorf("hello: failed to parse auth methods: %w", err) } diff --git a/messages/validator.go b/messages/validator.go index 63877da..52c2dca 100644 --- a/messages/validator.go +++ b/messages/validator.go @@ -3,6 +3,8 @@ package messages import ( "errors" "fmt" + + "github.com/xconnio/wampproto-go/util" ) const errString = "item at index %d must be of type %s but was %T" @@ -61,7 +63,7 @@ func sanityCheck(wampMsg []any, minLength, maxLength int) error { } func validateID(wampMsg []any, index int) (int64, error) { - item, ok := AsInt64(wampMsg[index]) + item, ok := util.AsInt64(wampMsg[index]) if !ok { return 0, fmt.Errorf(errString, index, "int64", wampMsg[index]) } @@ -286,43 +288,3 @@ func ValidateMessage(wampMsg []any, spec ValidationSpec) (*Fields, error) { return f, nil } - -func AsInt64(i interface{}) (int64, bool) { - switch v := i.(type) { - case int64: - return v, true - case uint64: - return int64(v), true // #nosec - case uint8: - return int64(v), true - case int: - return int64(v), true - case int8: - return int64(v), true - case int32: - return int64(v), true - case uint: - return int64(v), true // #nosec - case uint16: - return int64(v), true - case uint32: - return int64(v), true - case float64: - return int64(v), true - case float32: - return int64(v), true - } - return 0, false -} - -func AnysToStrings(input []any) ([]string, error) { - result := make([]string, 0, len(input)) - for _, item := range input { - str, ok := item.(string) - if !ok { - return nil, fmt.Errorf("element %v is not a string", item) - } - result = append(result, str) - } - return result, nil -} diff --git a/messages/validator_test.go b/messages/validator_test.go index ea4879a..9b89ab2 100644 --- a/messages/validator_test.go +++ b/messages/validator_test.go @@ -413,54 +413,3 @@ item at index 3 must be of type []any but was string`, }, err.Error()) }) } - -func TestAsInt64(t *testing.T) { - t.Run("ValidConversion", func(t *testing.T) { - tests := []struct { - input interface{} - expected int64 - }{ - {input: int64(123), expected: 123}, - {input: uint64(456), expected: 456}, - {input: uint8(7), expected: 7}, - {input: 890, expected: 890}, - {input: int8(-12), expected: -12}, - {input: int32(345), expected: 345}, - {input: uint(678), expected: 678}, - {input: uint16(901), expected: 901}, - {input: uint32(234), expected: 234}, - {input: 56.78, expected: 56}, - {input: float32(9.01), expected: 9}, - } - - for _, test := range tests { - result, ok := messages.AsInt64(test.input) - require.True(t, ok) - require.Equal(t, test.expected, result) - } - }) - - t.Run("InvalidConversion", func(t *testing.T) { - result, ok := messages.AsInt64("invalid") - require.False(t, ok) - require.Equal(t, int64(0), result) - }) -} - -func TestAnysToStrings(t *testing.T) { - t.Run("ValidConversion", func(t *testing.T) { - input := []any{"foo", "bar", "helloo"} - - result, err := messages.AnysToStrings(input) - require.NoError(t, err) - require.Equal(t, []string{"foo", "bar", "helloo"}, result) - }) - - t.Run("InvalidConversion", func(t *testing.T) { - input := []any{"foo", 123, "bar"} - - _, err := messages.AnysToStrings(input) - require.Error(t, err) - require.Contains(t, err.Error(), "element 123 is not a string") - }) -} diff --git a/serializers/helpers.go b/serializers/helpers.go index 71ed802..4531691 100644 --- a/serializers/helpers.go +++ b/serializers/helpers.go @@ -4,10 +4,11 @@ import ( "fmt" "github.com/xconnio/wampproto-go/messages" + "github.com/xconnio/wampproto-go/util" ) func ToMessage(wampMsg []any) (messages.Message, error) { - messageType, _ := messages.AsInt64(wampMsg[0]) + messageType, _ := util.AsInt64(wampMsg[0]) var msg messages.Message switch messageType { case messages.MessageTypeAbort: diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..4dde301 --- /dev/null +++ b/util/util.go @@ -0,0 +1,91 @@ +package util + +import "fmt" + +func AsInt64(i any) (int64, bool) { + switch v := i.(type) { + case int64: + return v, true + case uint64: + return int64(v), true // #nosec + case uint8: + return int64(v), true + case int: + return int64(v), true + case int8: + return int64(v), true + case int32: + return int64(v), true + case uint: + return int64(v), true // #nosec + case uint16: + return int64(v), true + case uint32: + return int64(v), true + case float64: + return int64(v), true + case float32: + return int64(v), true + } + return 0, false +} + +func AsFloat64(v interface{}) (float64, bool) { + switch v := v.(type) { + case float64: + return v, true + case float32: + return float64(v), true + case int64: + return float64(v), true + case uint64: + return float64(v), true + case int: + return float64(v), true + case int8: + return float64(v), true + case int32: + return float64(v), true + case uint: + return float64(v), true + case uint32: + return float64(v), true + case uint8: + return float64(v), true + case uint16: + return float64(v), true + } + return 0.0, false +} + +func AsBool(i any) (bool, bool) { + boolean, ok := i.(bool) + return boolean, ok +} + +func ToBool(i any) bool { + boolean, _ := i.(bool) + return boolean +} + +func AsString(i any) (string, bool) { + str, ok := i.(string) + return str, ok +} + +func ToString(i any) string { + str, _ := i.(string) + return str +} + +func AnysToStrings(input []any) ([]string, error) { + result := make([]string, 0, len(input)) + for _, item := range input { + str, ok := item.(string) + if !ok { + return nil, fmt.Errorf("element %v is not a string", item) + } + result = append(result, str) + } + return result, nil +} diff --git a/util/util_test.go b/util/util_test.go new file mode 100644 index 0000000..d2ab245 --- /dev/null +++ b/util/util_test.go @@ -0,0 +1,136 @@ +package util_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/xconnio/wampproto-go/util" +) + +func TestAsInt64(t *testing.T) { + t.Run("ValidConversion", func(t *testing.T) { + tests := []struct { + input interface{} + expected int64 + }{ + {input: int64(123), expected: 123}, + {input: uint64(456), expected: 456}, + {input: uint8(7), expected: 7}, + {input: 890, expected: 890}, + {input: int8(-12), expected: -12}, + {input: int32(345), expected: 345}, + {input: uint(678), expected: 678}, + {input: uint16(901), expected: 901}, + {input: uint32(234), expected: 234}, + {input: 56.78, expected: 56}, + {input: float32(9.01), expected: 9}, + } + + for _, test := range tests { + result, ok := util.AsInt64(test.input) + require.True(t, ok) + require.Equal(t, test.expected, result) + } + }) + + t.Run("InvalidConversion", func(t *testing.T) { + result, ok := util.AsInt64("invalid") + require.False(t, ok) + require.Equal(t, int64(0), result) + }) +} + +func TestAsFloat64(t *testing.T) { + t.Run("ValidConversion", func(t *testing.T) { + tests := []struct { + input any + expected float64 + }{ + {input: float64(123.45), expected: 123.45}, + {input: float32(67.89), expected: 67.88999938964844}, + {input: int64(123), expected: 123.0}, + {input: uint64(456), expected: 456.0}, + {input: 789, expected: 789.0}, + {input: int8(-12), expected: -12.0}, + {input: int32(345), expected: 345.0}, + {input: uint(678), expected: 678.0}, + {input: uint32(234), expected: 234.0}, + {input: uint8(7), expected: 7.0}, + {input: uint16(90), expected: 90.0}, + } + + for _, test := range tests { + result, ok := util.AsFloat64(test.input) + require.True(t, ok) + require.Equal(t, test.expected, result) + } + }) + + t.Run("InvalidConversion", func(t *testing.T) { + result, ok := util.AsFloat64("invalid") + require.False(t, ok) + require.Equal(t, float64(0), result) + }) +} + +func TestAsBool(t *testing.T) { + t.Run("ValidConversion", func(t *testing.T) { + result, ok := util.AsBool(true) + require.True(t, ok) + require.True(t, result) + + result, ok = util.AsBool(false) + require.True(t, ok) + require.False(t, result) + }) + + t.Run("InvalidConversion", func(t *testing.T) { + result, ok := util.AsBool(123) + require.False(t, ok) + require.False(t, result) + }) +} + +func TestToBool(t *testing.T) { + require.True(t, util.ToBool(true)) + require.False(t, util.ToBool(false)) + require.False(t, util.ToBool(123)) +} + +func TestAsString(t *testing.T) { + t.Run("ValidConversion", func(t *testing.T) { + result, ok := util.AsString("hello") + require.True(t, ok) + require.Equal(t, "hello", result) + }) + + t.Run("InvalidConversion", func(t *testing.T) { + result, ok := util.AsString(123) + require.False(t, ok) + require.Equal(t, "", result) + }) +} + +func TestToString(t *testing.T) { + require.Equal(t, "hello", util.ToString("hello")) + require.Equal(t, "", util.ToString(123)) +} + +func TestAnysToStrings(t *testing.T) { + t.Run("ValidConversion", func(t *testing.T) { + input := []any{"foo", "bar", "helloo"} + + result, err := util.AnysToStrings(input) + require.NoError(t, err) + require.Equal(t, []string{"foo", "bar", "helloo"}, result) + }) + + t.Run("InvalidConversion", func(t *testing.T) { + input := []any{"foo", 123, "bar"} + + _, err := util.AnysToStrings(input) + require.Error(t, err) + require.Contains(t, err.Error(), "element 123 is not a string") + }) +}