diff options
Diffstat (limited to 'internal/pathtree.go')
-rw-r--r-- | internal/pathtree.go | 275 |
1 files changed, 275 insertions, 0 deletions
diff --git a/internal/pathtree.go b/internal/pathtree.go new file mode 100644 index 0000000..563e85c --- /dev/null +++ b/internal/pathtree.go @@ -0,0 +1,275 @@ +package internal + +import ( + "sort" + "strings" +) + +type PathTree[V any] struct { + tree subtree[V] + root *V +} + +func (pt PathTree[V]) Match(path string) (*V, map[string]string) { + path = strings.TrimPrefix(path, "/") + if path == "" { + return pt.root, map[string]string{} + } + m := pt.tree.Match(strings.Split(path, "/"), nil) + if m == nil { + return nil, nil + } + + return m.value, m.params() +} + +func (pt *PathTree[V]) Add(pattern string, value V) { + pattern = strings.TrimPrefix(pattern, "/") + if pattern == "" { + pt.root = &value + } else { + pt.tree.Add(strings.Split(pattern, "/"), value) + } +} + +// pattern segment which must be a specific string ("/users/"). +type segmentNode[V any] struct { + label string + value *V + subtree subtree[V] +} + +func makeSegment[V any](pattern []string, value V) segmentNode[V] { + node := segmentNode[V]{label: pattern[0]} + if len(pattern) == 1 { + node.value = &value + } else { + node.subtree.Add(pattern[1:], value) + } + return node +} + +func (sn segmentNode[V]) Match(segments []string, m *match[V]) *match[V] { + var l int + if m != nil { + l = m.length + } + m = &match[V]{ + value: sn.value, + length: l + len(sn.label), + prev: m, + } + if len(segments) == 1 { + return m + } + return sn.subtree.Match(segments[1:], m) +} + +func (sn *segmentNode[V]) Add(pattern []string, value V) { + if len(pattern) == 0 { + if sn.value != nil { + panic("pattern already exists") + } + sn.value = &value + return + } + sn.subtree.Add(pattern, value) +} + +// single-segment param-capturing wildcard ("/:username/") +type wildcardNode[V any] struct { + param string + value *V + subtree subtree[V] +} + +func makeWildcard[V any](pattern []string, value V) wildcardNode[V] { + node := wildcardNode[V]{param: pattern[0][1:]} + if len(pattern) == 1 { + node.value = &value + } else { + node.subtree.Add(pattern[1:], value) + } + return node +} + +func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] { + var l int + if m != nil { + l = m.length + } + + m = &match[V]{ + value: wn.value, + length: l + len(segments[0]), + prev: m, + } + if wn.param != "" { + m.paramkey = wn.param + m.paramval = segments[0] + } + + if len(segments) == 1 { + return m + } + + return wn.subtree.Match(segments[1:], m) +} + +func (wn *wildcardNode[V]) Add(pattern []string, value V) { + if len(pattern) == 0 { + if wn.value != nil { + panic("pattern already exists") + } + wn.value = &value + return + } + wn.subtree.Add(pattern, value) +} + +// "rest of the path" capturing node ("/*path_info") - always a final node +type remainderNode[V any] struct { + param string + value V +} + +func (rn remainderNode[V]) Match(segments []string, m *match[V]) *match[V] { + m = &match[V]{ + value: &rn.value, + length: m.length, + prev: m, + } + if rn.param != "" { + m.paramkey = rn.param + m.paramval = strings.Join(segments, "/") + } + return m +} + +// all the children under a tree node +type subtree[V any] struct { + segments childSegments[V] + wildcards childWildcards[V] + remainder *remainderNode[V] +} + +func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] { + var best *match[V] + + if st.remainder != nil { + best = st.remainder.Match(segments, m) + } + + for _, wc := range st.wildcards { + candidate := wc.Match(segments, m) + if best == nil || candidate.length > best.length { + best = candidate + } + } + + childSeg := st.segments.Find(segments[0]) + if childSeg == nil { + return best + } + candidate := childSeg.Match(segments, m) + if best == nil || (candidate != nil && candidate.length >= best.length) { + best = candidate + } + + return best +} + +func (st *subtree[V]) Add(pattern []string, value V) { + if len(pattern[0]) == 0 { + panic("invalid pattern") + } + + switch pattern[0][0] { + case '*': + if len(pattern) > 1 { + panic("invalid pattern: segments after *remainder") + } + if st.remainder != nil { + panic("pattern already exists") + } + + st.remainder = &remainderNode[V]{ + param: pattern[0][1:], + value: value, + } + case ':': + child := st.wildcards.Find(pattern[0][1:]) + if child != nil { + child.Add(pattern[1:], value) + } else { + st.wildcards = append(st.wildcards, makeWildcard(pattern, value)) + sort.Sort(st.wildcards) + } + default: + child := st.segments.Find(pattern[0]) + if child != nil { + child.Add(pattern[1:], value) + } else { + st.segments = append(st.segments, makeSegment(pattern, value)) + sort.Sort(st.segments) + } + } +} + +type childSegments[V any] []segmentNode[V] + +func (cs childSegments[V]) Len() int { return len(cs) } +func (cs childSegments[V]) Less(i, j int) bool { return cs[i].label < cs[j].label } +func (cs childSegments[V]) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } + +func (cs childSegments[V]) Find(label string) *segmentNode[V] { + idx := sort.Search(len(cs), func(i int) bool { + return label <= cs[i].label + }) + if idx < len(cs) && cs[idx].label == label { + return &cs[idx] + } + return nil +} + +type childWildcards[V any] []wildcardNode[V] + +func (cw childWildcards[V]) Len() int { return len(cw) } +func (cw childWildcards[V]) Less(i, j int) bool { return cw[i].param < cw[j].param } +func (cw childWildcards[V]) Swap(i, j int) { cw[i], cw[j] = cw[j], cw[i] } + +func (cw childWildcards[V]) Find(param string) *wildcardNode[V] { + i := sort.Search(len(cw), func(i int) bool { + return param <= cw[i].param + }) + if i < len(cw) && cw[i].param == param { + return &cw[i] + } + return nil +} + +// linked list we build up as we match our way through the path segments +// - if there is a parameter captured it goes in paramkey/paramval +// - if there was a value it goes in value +// - also store cumulative length to choose the longest match +type match[V any] struct { + paramkey string + paramval string + + value *V + length int + prev *match[V] +} + +func (m match[_]) params() map[string]string { + mch := &m + out := make(map[string]string) + for mch != nil { + if mch.paramkey != "" { + out[mch.paramkey] = mch.paramval + } + mch = mch.prev + } + + return out +} |