Skip to content

Commit

Permalink
Add version constraints
Browse files Browse the repository at this point in the history
Signed-off-by: Kimmo Lehto <[email protected]>
  • Loading branch information
kke committed Sep 5, 2023
1 parent a97a29d commit 770cea1
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 0 deletions.
148 changes: 148 additions & 0 deletions constraint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package version

import (
"errors"
"fmt"
"regexp"
"strings"
)

var constraintRegex = regexp.MustCompile(`^(?:(>=|>|<=|<|!=|==?)\s*)?(.+)$`)

type constraintFunc func(a, b *Version) bool
type constraint struct {
f constraintFunc
b *Version
original string
}

// Constraints is a collection of version constraint rules that can be checked against a version.
type Constraints []constraint

// NewConstraint parses a string into a Constraints object that can be used to check
// if a given version satisfies the constraint.
func NewConstraint(cs string) (Constraints, error) {
parts := strings.Split(cs, ",")
newC := make(Constraints, len(parts))
for i, p := range parts {
parts[i] = strings.TrimSpace(p)
}
for i, p := range parts {
c, err := newConstraint(p)
if err != nil {
return Constraints{}, err
}
newC[i] = c
}

return newC, nil
}

// MustConstraint is like NewConstraint but panics if the constraint is invalid.
func MustConstraint(cs string) Constraints {
c, err := NewConstraint(cs)
if err != nil {
panic("github.com/k0sproject/version: NewConstraint: " + err.Error())
}
return c
}

// Check returns true if the given version satisfies all of the constraints.
func (cs Constraints) Check(v *Version) bool {
for _, c := range cs {
if c.b.Prerelease() == "" && v.Prerelease() != "" {
return false
}
if !c.f(c.b, v) {
return false
}
}

return true
}

// CheckString is like Check but takes a string version. If the version is invalid,
// it returns false.
func (cs Constraints) CheckString(v string) bool {
vv, err := NewVersion(v)
if err != nil {
return false
}
return cs.Check(vv)
}

// String returns the original constraint string.
func (c *constraint) String() string {
return c.original
}

func newConstraint(s string) (constraint, error) {
match := constraintRegex.FindStringSubmatch(s)
if len(match) != 3 {
return constraint{}, errors.New("invalid constraint: " + s)
}

op := match[1]
f, err := opfunc(op)
if err != nil {
return constraint{}, err
}

// convert one or two digit constraints to threes digit unless it's an equality operation
if op != "" && op != "=" && op != "==" {
segments := strings.Split(match[2], ".")
if len(segments) < 3 {
lastSegment := segments[len(segments)-1]
var pre string
if strings.Contains(lastSegment, "-") {
parts := strings.Split(lastSegment, "-")
segments[len(segments)-1] = parts[0]
pre = "-" + parts[1]
}
switch len(segments) {
case 1:
// >= 1 becomes >= 1.0.0
// >= 1-rc.1 becomes >= 1.0.0-rc.1
return newConstraint(fmt.Sprintf("%s %s.0.0%s", op, segments[0], pre))
case 2:
// >= 1.1 becomes >= 1.1.0
// >= 1.1-rc.1 becomes >= 1.1.0-rc.1
return newConstraint(fmt.Sprintf("%s %s.%s.0%s", op, segments[0], segments[1], pre))
}
}
}

target, err := NewVersion(match[2])
if err != nil {
return constraint{}, err
}

return constraint{f: f, b: target, original: s}, nil
}

func opfunc(s string) (constraintFunc, error) {
switch s {
case "", "=", "==":
return eq, nil
case ">":
return gt, nil
case ">=":
return gte, nil
case "<":
return lt, nil
case "<=":
return lte, nil
case "!=":
return neq, nil
default:
return nil, errors.New("invalid operator: " + s)
}
}

func gt(a, b *Version) bool { return b.GreaterThan(a) }
func lt(a, b *Version) bool { return b.LessThan(a) }
func gte(a, b *Version) bool { return b.GreaterThanOrEqual(a) }
func lte(a, b *Version) bool { return b.LessThanOrEqual(a) }
func eq(a, b *Version) bool { return b.Equal(a) }
func neq(a, b *Version) bool { return !b.Equal(a) }

153 changes: 153 additions & 0 deletions constraint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package version

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestConstraint(t *testing.T) {
type testCase struct {
constraint string
truthTable map[bool][]string
}

testCases := []testCase{
{
constraint: ">= 1.1.0-beta.1+k0s.1",
truthTable: map[bool][]string{
true: {
"1.1.0+k0s.0",
"1.1.0-rc.1+k0s.0",
"1.1.1+k0s.0",
"1.1.1-rc.1+k0s.0",
},
false: {
"1.1.0-alpha.1+k0s.2",
"1.0.1+k0s.10",
},
},
},
{
constraint: ">= 1.1.0+k0s.1",
truthTable: map[bool][]string{
true: {
"1.1.0+k0s.1",
"1.1.0+k0s.2",
"1.1.1+k0s.0",
},
false: {
"1.0.9+k0s.255",
"1.1.0+k0s.0",
},
},
},
// simple operator checks
{
constraint: "= 1.0.0",
truthTable: map[bool][]string{
true: {"1.0.0"},
false: {"1.0.1", "0.9.9"},
},
},
{
constraint: "1.0.0",
truthTable: map[bool][]string{
true: {"1.0.0"},
false: {"1.0.1", "0.9.9"},
},
},
{
constraint: "!= 1.0.0",
truthTable: map[bool][]string{
true: {"1.0.1", "0.9.9"},
false: {"1.0.0"},
},
},
{
constraint: "> 1.0.0",
truthTable: map[bool][]string{
true: {"1.0.1", "1.1.0"},
false: {"1.0.0", "0.9.9"},
},
},
{
constraint: "< 1.0.0",
truthTable: map[bool][]string{
true: {"0.9.9", "0.9.8"},
false: {"1.0.0", "1.0.1"},
},
},
{
constraint: ">= 1.0.0",
truthTable: map[bool][]string{
true: {"1.0.0", "1.0.1"},
false: {"0.9.9"},
},
},
{
constraint: "<= 1.0.0",
truthTable: map[bool][]string{
true: {"1.0.0", "0.9.9"},
false: {"1.0.1"},
},
},
// two digit constraints
{
constraint: ">= 1.0",
truthTable: map[bool][]string{
true: {"1.0.0", "1.0.1", "1.1.0"},
false: {"0.9.9", "1.0.1-alpha.1"},
},
},
{
constraint: ">= 1.0-a",
truthTable: map[bool][]string{
true: {"1.0.0", "1.0.1", "1.0.0-alpha.1"},
false: {"0.9.9"},
},
},
}

for _, tc := range testCases {
t.Run(tc.constraint, func(t *testing.T) {
c, err := NewConstraint(tc.constraint)
assert.NoError(t, err)

for expected, versions := range tc.truthTable {
t.Run(fmt.Sprintf("%t", expected), func(t *testing.T) {
for _, version := range versions {
t.Run(version, func(t *testing.T) {
assert.Equal(t, expected, c.Check(MustParse(version)))
})
}
})
}
})
}
}

func TestInvalidConstraint(t *testing.T) {
invalidConstraints := []string{
"",
"==",
">= ",
"invalid",
">= abc",
}

for _, invalidConstraint := range invalidConstraints {
_, err := newConstraint(invalidConstraint)
assert.Error(t, err, "Expected error for invalid constraint: "+invalidConstraint)
}
}

func TestCheckString(t *testing.T) {
c, err := NewConstraint(">= 1.0.0")
assert.NoError(t, err)

assert.True(t, c.CheckString("1.0.0"))
assert.False(t, c.CheckString("0.9.9"))
assert.False(t, c.CheckString("x"))
}
5 changes: 5 additions & 0 deletions version.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ func (v *Version) UnmarshalJSON(b []byte) error {
})
}

// Satisfies returns true if the version satisfies the supplied constraint
func (v *Version) Satisfies(constraint Constraints) bool {
return constraint.Check(v)
}

// NewVersion returns a new Version created from the supplied string or an error if the string is not a valid version number
func NewVersion(v string) (*Version, error) {
n, err := goversion.NewVersion(strings.TrimPrefix(v, "v"))
Expand Down
14 changes: 14 additions & 0 deletions version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ func TestK0sComparison(t *testing.T) {
assert.False(t, b.Equal(a), "version %s should not be equal to %s", b, a)
}

func TestSatisfies(t *testing.T) {
v, err := NewVersion("1.23.1+k0s.1")
assert.NoError(t, err)
assert.True(t, v.Satisfies(MustConstraint(">=1.23.1")))
assert.True(t, v.Satisfies(MustConstraint(">=1.23.1+k0s.0")))
assert.True(t, v.Satisfies(MustConstraint(">=1.23.1+k0s.1")))
assert.True(t, v.Satisfies(MustConstraint("=1.23.1+k0s.1")))
assert.True(t, v.Satisfies(MustConstraint("<1.23.1+k0s.2")))
assert.False(t, v.Satisfies(MustConstraint(">=1.23.1+k0s.2")))
assert.False(t, v.Satisfies(MustConstraint(">=1.23.2")))
assert.False(t, v.Satisfies(MustConstraint(">1.23.1+k0s.1")))
assert.False(t, v.Satisfies(MustConstraint("<1.23.1+k0s.1")))
}

func TestURLs(t *testing.T) {
a, err := NewVersion("1.23.3+k0s.1")
assert.NoError(t, err)
Expand Down

0 comments on commit 770cea1

Please sign in to comment.