summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/pathtree.go46
-rw-r--r--internal/pathtree_test.go8
-rw-r--r--router.go17
3 files changed, 45 insertions, 26 deletions
diff --git a/internal/pathtree.go b/internal/pathtree.go
index 0ab6b81..bbb1f2c 100644
--- a/internal/pathtree.go
+++ b/internal/pathtree.go
@@ -80,9 +80,6 @@ func (sn segmentNode[V]) Match(segments []string, m *match[V]) *match[V] {
func (sn *segmentNode[V]) Add(pattern []string, value V) {
if len(pattern) == 0 {
- if sn.value != nil {
- panic("pattern already exists")
- }
sn.value = &value
return
}
@@ -91,13 +88,15 @@ func (sn *segmentNode[V]) Add(pattern []string, value V) {
// single-segment param-capturing wildcard ("/:username/")
type wildcardNode[V any] struct {
+ prefix string
param string
value *V
subtree subtree[V]
}
func makeWildcard[V any](pattern []string, value V) wildcardNode[V] {
- node := wildcardNode[V]{param: pattern[0][1:]}
+ prefix, param, _ := strings.Cut(pattern[0], ":")
+ node := wildcardNode[V]{prefix: prefix, param: param}
if len(pattern) == 1 {
node.value = &value
} else {
@@ -119,7 +118,7 @@ func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] {
}
if wn.param != "" {
m.paramkey = wn.param
- m.paramval = segments[0]
+ m.paramval = strings.TrimPrefix(segments[0], wn.prefix)
}
if len(segments) == 1 {
@@ -131,9 +130,6 @@ func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] {
func (wn *wildcardNode[V]) Add(pattern []string, value V) {
if len(pattern) == 0 {
- if wn.value != nil {
- panic("pattern already exists")
- }
wn.value = &value
return
}
@@ -178,6 +174,9 @@ func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] {
}
for _, wc := range st.wildcards {
+ if wc.prefix != "" && !strings.HasPrefix(segments[0], wc.prefix) {
+ continue
+ }
candidate := wc.Match(segments, m)
if best == nil || candidate.length > best.length {
best = candidate
@@ -198,22 +197,20 @@ func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] {
func (st *subtree[V]) Add(pattern []string, value V) {
if len(pattern[0]) > 0 {
- switch pattern[0][0] {
- case '*':
+ switch {
+ case pattern[0][0] == '*':
if len(pattern) > 1 {
panic("invalid pattern: segments after *remainder")
}
- if st.remainder != nil {
- panic("pattern already exists")
- }
st.remainder = &remainderNode[V]{
param: pattern[0][1:],
value: value,
}
return
- case ':':
- child := st.wildcards.Find(pattern[0][1:])
+ case strings.Contains(pattern[0], ":"):
+ prefix, param, _ := strings.Cut(pattern[0], ":")
+ child := st.wildcards.Find(prefix, param)
if child != nil {
child.Add(pattern[1:], value)
} else {
@@ -291,19 +288,30 @@ func (cs childSegments[V]) Find(label string) *segmentNode[V] {
type childWildcards[V any] []wildcardNode[V]
func (cw childWildcards[V]) Len() int { return len(cw) }
-func (cw childWildcards[V]) Less(i, j int) bool { return cw[i].param < cw[j].param }
+func (cw childWildcards[V]) Less(i, j int) bool { return wildcardLess(cw[i], cw[j]) }
func (cw childWildcards[V]) Swap(i, j int) { cw[i], cw[j] = cw[j], cw[i] }
-func (cw childWildcards[V]) Find(param string) *wildcardNode[V] {
+func (cw childWildcards[V]) Find(prefix, param string) *wildcardNode[V] {
i := sort.Search(len(cw), func(i int) bool {
- return param <= cw[i].param
+ return wildcardLess(wildcardNode[V]{prefix: prefix, param: param}, cw[i])
})
- if i < len(cw) && cw[i].param == param {
+ if i < len(cw) && cw[i].prefix == prefix && cw[i].param == param {
return &cw[i]
}
return nil
}
+func wildcardLess[V any](a, b wildcardNode[V]) bool {
+ switch {
+ case a.prefix < b.prefix:
+ return true
+ case a.prefix == b.prefix:
+ return a.param <= b.param
+ default:
+ return false
+ }
+}
+
// linked list we build up as we match our way through the path segments
// - if there is a parameter captured it goes in paramkey/paramval
// - if there was a value it goes in value
diff --git a/internal/pathtree_test.go b/internal/pathtree_test.go
index e152e85..d9df757 100644
--- a/internal/pathtree_test.go
+++ b/internal/pathtree_test.go
@@ -34,6 +34,7 @@ func TestPathTree(t *testing.T) {
{"/a/b", 3},
{"/c", 4},
{"/x/:y/z/*rest", 5},
+ {"/f/prefix:y/z/*rest", 6},
},
paths: map[string]matchresult{
"/a": {
@@ -66,6 +67,13 @@ func TestPathTree(t *testing.T) {
"/": {
failed: true,
},
+ "/f/mismatch/z/more": {
+ failed: true,
+ },
+ "/f/prefixblargh/z/more": {
+ value: 6,
+ params: map[string]string{"y": "blargh", "rest": "more"},
+ },
},
},
{
diff --git a/router.go b/router.go
index d45a7de..7cf3142 100644
--- a/router.go
+++ b/router.go
@@ -2,6 +2,7 @@ package sliderule
import (
"context"
+ "path"
"strings"
"tildegit.org/tjp/sliderule/internal"
@@ -10,9 +11,11 @@ import (
// Router stores a mapping of request path patterns to handlers.
//
// Pattern may begin with "/" and then contain slash-delimited segments.
-// - Segments beginning with colon (:) are wildcards and will match any path
+// - Segments containing a colon (:) are wildcards and will match any path
// segment at that location. It may optionally have a word after the colon,
-// which will be the parameter name the path segment is captured into.
+// which will be the parameter name the path segment is captured into. It
+// may also optionally have text before the colon, in which case the pattern
+// will not match unless the request path segment contains that prefix.
// - Segments beginning with asterisk (*) are remainder wildcards. This must
// come last and will capture any remainder of the path. It may have a name
// after the asterisk which will be the parameter name.
@@ -50,7 +53,7 @@ func (r Router) Handle(ctx context.Context, request *Request) *Response {
return nil
}
- return handler.Handle(context.WithValue(ctx, routeParamsKey, params), request)
+ return handler.Handle(context.WithValue(ctx, RouteParamsKey, params), request)
}
// Handler builds a Handler
@@ -83,8 +86,8 @@ func (r *Router) Mount(prefix string, subrouter *Router) {
prefix = strings.TrimSuffix(prefix, "/")
for _, subroute := range subrouter.tree.Routes() {
- r.Route(prefix+"/"+subroute.Pattern, subroute.Value)
- if subroute.Pattern == "" {
+ r.Route(path.Join(prefix, subroute.Pattern), subroute.Value)
+ if subroute.Pattern == "/" {
r.Route(prefix, subroute.Value)
}
}
@@ -111,7 +114,7 @@ func (r *Router) Use(mw Middleware) {
// If Router was used but no parameters were captured in the pattern, it
// returns a non-nil empty map.
func RouteParams(ctx context.Context) map[string]string {
- if m, ok := ctx.Value(routeParamsKey).(map[string]string); ok {
+ if m, ok := ctx.Value(RouteParamsKey).(map[string]string); ok {
return m
}
return nil
@@ -119,4 +122,4 @@ func RouteParams(ctx context.Context) map[string]string {
type routeParamsKeyType struct{}
-var routeParamsKey = routeParamsKeyType{}
+var RouteParamsKey = routeParamsKeyType{}