Skip to content

Commit

Permalink
feat: support multistage group
Browse files Browse the repository at this point in the history
  • Loading branch information
tonny-zhang committed Mar 1, 2022
1 parent bf1dc5b commit af38a26
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 140 deletions.
45 changes: 26 additions & 19 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,19 @@ func Default() *Router {

// Group get group router
func (router *Router) Group(path string, handler ...HandlerFunc) *Router {
if router.prefix != "" {
panic(fmt.Errorf("group [%s] can not group again", router.prefix))
}
if len(path) == 0 || path[0] != '/' {
panic(fmt.Errorf("group [%s] must start with /", path))
}
if strings.Index(path, "*") > -1 {
if strings.Index(path, "*") > -1 || strings.Index(path, ":") > -1 {
panic(fmt.Errorf("group path [%s] can not has parameter", path))
}
prefix := utils.CleanPath(path + "/")
for _, g := range router.groups {
if matchGroup(g, prefix) {
panic(fmt.Errorf("group [%s] conflicts with [%s]", prefix, g.prefix))
}
matchedGroup := router.matchGroup(prefix)
if matchedGroup != nil {
panic(fmt.Errorf("group [%s] conflicts with [%s]", prefix, matchedGroup.prefix))
}
if router.prefix != "" {
prefix = utils.CleanPath(router.prefix + "/" + prefix)
}
r := &Router{
prefix: prefix,
Expand All @@ -78,6 +77,20 @@ func (router *Router) Group(path string, handler ...HandlerFunc) *Router {
router.groups = append(router.groups, r)
return r
}
func (router *Router) matchGroup(path string) *Router {
for _, g := range router.groups {
if len(g.groups) > 0 {
gg := g.matchGroup(path)
if gg != nil {
return gg
}
}
if matchGroup(g, path) {
return g
}
}
return nil
}
func matchGroup(router *Router, path string) bool {
if len(router.prefix) > 0 {
if strings.HasPrefix(path, router.prefix) {
Expand All @@ -90,18 +103,15 @@ func matchGroup(router *Router, path string) bool {
}

for i, j := 0, len(arrRP); i < j; i++ {
if strings.Index(arrRP[i], ":") > -1 || strings.Index(arrPath[i], ":") > -1 {
continue
}
if i == j-1 && arrRP[i] == "" {
break
return true
}
if arrRP[i] != arrPath[i] {
return false
}
}
}
return true
return false
}

// NotFound custom NotFoundHandler
Expand Down Expand Up @@ -201,12 +211,9 @@ func (router *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

routerUse := router
for _, g := range router.groups {
if matchGroup(g, reqURI) {
routerUse = g
break
}
routerUse := router.matchGroup(reqURI)
if routerUse == nil {
routerUse = router
}

notfoundHandlers := routerUse.notfoundHandlers
Expand Down
147 changes: 147 additions & 0 deletions router_group_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package cotton

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGroup(t *testing.T) {
router := NewRouter()
g1 := router.Group("/v1")
g1.Get("/a", func(c *Context) {
c.String(http.StatusOK, "g1 a")
})
g1.Get("/b", func(c *Context) {
c.String(http.StatusBadGateway, "g1 b")
})

w := doRequest(router, http.MethodGet, "/v1/a")

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "g1 a", w.Body.String())

w = doRequest(router, http.MethodGet, "/v1/b")

assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Equal(t, "g1 b", w.Body.String())
// assert.True(t, false)
}
func TestGroupPrefix(t *testing.T) {
router := NewRouter()
g1 := router.Group("/v1")
g1.Get("/a", func(c *Context) {
c.String(http.StatusOK, "g1 a")
})

w := doRequest(router, http.MethodGet, "/v1/v1/a")

assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestGroupPanic(t *testing.T) {
assert.PanicsWithError(t, "group [] must start with /", func() {
router := NewRouter()
router.Group("")
})
assert.PanicsWithError(t, "group [abc] must start with /", func() {
router := NewRouter()
router.Group("abc")
})

assert.PanicsWithError(t, "group [/a/] conflicts with [/a/]", func() {
router := NewRouter()
router.Group("/a")
router.Group("/a")
})

assert.PanicsWithError(t, "group path [/:method] can not has parameter", func() {
router := NewRouter()
router.Group("/:method")
})

assert.NotPanics(t, func() {
router := NewRouter()
router.Group("/s")
router.Group("/static")
})
}

func TestMatchGroup(t *testing.T) {
assert.True(t, matchGroup(&Router{
prefix: "/v1/",
}, "/v1/test"))

assert.False(t, matchGroup(&Router{
prefix: "/v1/",
}, "/v2/test"))
}

func TestCustomGroupNotFound(t *testing.T) {
router := NewRouter()

infoCustomNotFound := "not found from custom"
infoCustomGroupNotFound := "not found from custom group"
infoCustomGroupUserNotFound := "not found from custom group user"
router.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomNotFound)
})

g := router.Group("/v1")
g.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomGroupNotFound)
})

w := doRequest(router, http.MethodGet, "/path404")
assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomNotFound, w.Body.String())

w = doRequest(router, http.MethodGet, "/v1/path404")

assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomGroupNotFound, w.Body.String())

gUser := g.Group("/user")

w = doRequest(router, http.MethodGet, "/v1/user/path404")

assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomGroupNotFound, w.Body.String())

gUser.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomGroupUserNotFound)
})

w = doRequest(router, http.MethodGet, "/v1/user/path404")

assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomGroupUserNotFound, w.Body.String())

w = doRequest(router, http.MethodGet, "/v1/user1/path404")

assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomGroupNotFound, w.Body.String())

w = doRequest(router, http.MethodGet, "/v2/user1/path404")

assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomNotFound, w.Body.String())
}

func TestGroupMulty(t *testing.T) {
router := NewRouter()
g1 := router.Group("/a")
g1.addHandleFunc("GET", "/test", func(ctx *Context) {
ctx.String(http.StatusOK, "/a/test")
})
g2 := g1.Group("/b")
g2.addHandleFunc("GET", "/test", func(ctx *Context) {
ctx.String(http.StatusOK, "/a/b/test")
})

w := doRequest(router, http.MethodGet, "/a/test")
assert.Equal(t, "/a/test", w.Body.String())

w = doRequest(router, http.MethodGet, "/a/b/test")
assert.Equal(t, "/a/b/test", w.Body.String())
}
121 changes: 0 additions & 121 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,127 +70,6 @@ func TestCustomNotFound(t *testing.T) {
assert.Equal(t, infoCustomNotFound, w.Body.String())
}

func TestCustomGroupNotFound(t *testing.T) {
router := NewRouter()

infoCustomNotFound := "not found from custom"
infoCustomGroupNotFound := "not found from custom group"
router.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomNotFound)
})

g := router.Group("/v1")
g.NotFound(func(ctx *Context) {
ctx.String(http.StatusNotFound, infoCustomGroupNotFound)
})

w := doRequest(router, http.MethodGet, "/path404")
assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomNotFound, w.Body.String())

w = doRequest(router, http.MethodGet, "/v1/path404")

assert.Equal(t, http.StatusNotFound, w.Code)
assert.Equal(t, infoCustomGroupNotFound, w.Body.String())
}
func TestGroup(t *testing.T) {
router := NewRouter()
g1 := router.Group("/v1")
g1.Get("/a", func(c *Context) {
c.String(http.StatusOK, "g1 a")
})
g1.Get("/b", func(c *Context) {
c.String(http.StatusBadGateway, "g1 b")
})

w := doRequest(router, http.MethodGet, "/v1/a")

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "g1 a", w.Body.String())

w = doRequest(router, http.MethodGet, "/v1/b")

assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Equal(t, "g1 b", w.Body.String())
// assert.True(t, false)
}
func TestGroupPrefix(t *testing.T) {
router := NewRouter()
g1 := router.Group("/v1")
g1.Get("/a", func(c *Context) {
c.String(http.StatusOK, "g1 a")
})

w := doRequest(router, http.MethodGet, "/v1/v1/a")

assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestGroupPanic(t *testing.T) {
assert.PanicsWithError(t, "group [] must start with /", func() {
router := NewRouter()
router.Group("")
})
assert.PanicsWithError(t, "group [abc] must start with /", func() {
router := NewRouter()
router.Group("abc")
})

assert.PanicsWithError(t, "group [/a/] can not group again", func() {
router := NewRouter()
router.Group("/a").Group("/a")
})

assert.PanicsWithError(t, "group [/a/] conflicts with [/a/]", func() {
router := NewRouter()
router.Group("/a")
router.Group("/a")
})
assert.PanicsWithError(t, "group [/b/] conflicts with [/:method/]", func() {
router := NewRouter()
router.Group("/:method")
router.Group("/b")
})
assert.PanicsWithError(t, "group [/:method/] conflicts with [/a/]", func() {
router := NewRouter()
router.Group("/a")
router.Group("/:method")
})
assert.PanicsWithError(t, "group [/:id/] conflicts with [/:method/]", func() {
router := NewRouter()
router.Group("/:method")
router.Group("/:id")
})

assert.NotPanics(t, func() {
router := NewRouter()
router.Group("/s")
router.Group("/static")
})
}

func TestMatchGroup(t *testing.T) {
assert.True(t, matchGroup(&Router{
prefix: "/v1/",
}, "/v1/test"))

assert.False(t, matchGroup(&Router{
prefix: "/v1/",
}, "/v2/test"))

assert.True(t, matchGroup(&Router{
prefix: "/v1/:method/",
}, "/v1/test/"))
assert.True(t, matchGroup(&Router{
prefix: "/v1/:method/",
}, "/v1/test/abc"))
assert.False(t, matchGroup(&Router{
prefix: "/v1/:method/",
}, "/v1/test"))
assert.False(t, matchGroup(&Router{
prefix: "/v1/:method/",
}, "/v2/test/"))
}

func TestMultipleEOP(t *testing.T) {
router := NewRouter()
content := "router a"
Expand Down

0 comments on commit af38a26

Please sign in to comment.