diff options
Diffstat (limited to 'internal/pathtree.go')
-rw-r--r-- | internal/pathtree.go | 46 |
1 files changed, 27 insertions, 19 deletions
diff --git a/internal/pathtree.go b/internal/pathtree.go index 0ab6b81..bbb1f2c 100644 --- a/internal/pathtree.go +++ b/internal/pathtree.go @@ -80,9 +80,6 @@ func (sn segmentNode[V]) Match(segments []string, m *match[V]) *match[V] { func (sn *segmentNode[V]) Add(pattern []string, value V) { if len(pattern) == 0 { - if sn.value != nil { - panic("pattern already exists") - } sn.value = &value return } @@ -91,13 +88,15 @@ func (sn *segmentNode[V]) Add(pattern []string, value V) { // single-segment param-capturing wildcard ("/:username/") type wildcardNode[V any] struct { + prefix string param string value *V subtree subtree[V] } func makeWildcard[V any](pattern []string, value V) wildcardNode[V] { - node := wildcardNode[V]{param: pattern[0][1:]} + prefix, param, _ := strings.Cut(pattern[0], ":") + node := wildcardNode[V]{prefix: prefix, param: param} if len(pattern) == 1 { node.value = &value } else { @@ -119,7 +118,7 @@ func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] { } if wn.param != "" { m.paramkey = wn.param - m.paramval = segments[0] + m.paramval = strings.TrimPrefix(segments[0], wn.prefix) } if len(segments) == 1 { @@ -131,9 +130,6 @@ func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] { func (wn *wildcardNode[V]) Add(pattern []string, value V) { if len(pattern) == 0 { - if wn.value != nil { - panic("pattern already exists") - } wn.value = &value return } @@ -178,6 +174,9 @@ func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] { } for _, wc := range st.wildcards { + if wc.prefix != "" && !strings.HasPrefix(segments[0], wc.prefix) { + continue + } candidate := wc.Match(segments, m) if best == nil || candidate.length > best.length { best = candidate @@ -198,22 +197,20 @@ func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] { func (st *subtree[V]) Add(pattern []string, value V) { if len(pattern[0]) > 0 { - switch pattern[0][0] { - case '*': + switch { + case pattern[0][0] == '*': if len(pattern) > 1 { panic("invalid pattern: segments after *remainder") } - if st.remainder != nil { - panic("pattern already exists") - } st.remainder = &remainderNode[V]{ param: pattern[0][1:], value: value, } return - case ':': - child := st.wildcards.Find(pattern[0][1:]) + case strings.Contains(pattern[0], ":"): + prefix, param, _ := strings.Cut(pattern[0], ":") + child := st.wildcards.Find(prefix, param) if child != nil { child.Add(pattern[1:], value) } else { @@ -291,19 +288,30 @@ func (cs childSegments[V]) Find(label string) *segmentNode[V] { type childWildcards[V any] []wildcardNode[V] func (cw childWildcards[V]) Len() int { return len(cw) } -func (cw childWildcards[V]) Less(i, j int) bool { return cw[i].param < cw[j].param } +func (cw childWildcards[V]) Less(i, j int) bool { return wildcardLess(cw[i], cw[j]) } func (cw childWildcards[V]) Swap(i, j int) { cw[i], cw[j] = cw[j], cw[i] } -func (cw childWildcards[V]) Find(param string) *wildcardNode[V] { +func (cw childWildcards[V]) Find(prefix, param string) *wildcardNode[V] { i := sort.Search(len(cw), func(i int) bool { - return param <= cw[i].param + return wildcardLess(wildcardNode[V]{prefix: prefix, param: param}, cw[i]) }) - if i < len(cw) && cw[i].param == param { + if i < len(cw) && cw[i].prefix == prefix && cw[i].param == param { return &cw[i] } return nil } +func wildcardLess[V any](a, b wildcardNode[V]) bool { + switch { + case a.prefix < b.prefix: + return true + case a.prefix == b.prefix: + return a.param <= b.param + default: + return false + } +} + // linked list we build up as we match our way through the path segments // - if there is a parameter captured it goes in paramkey/paramval // - if there was a value it goes in value |