summaryrefslogtreecommitdiff
path: root/internal/pathtree.go
diff options
context:
space:
mode:
authortjpcc <tjp@ctrl-c.club>2023-02-02 16:15:53 -0700
committertjpcc <tjp@ctrl-c.club>2023-02-02 16:15:53 -0700
commitac024567e880f0da59557f0051018f4ac932c6ad (patch)
treea1d0d5661d5f980ed56716e430a4e4de4b5e7bd3 /internal/pathtree.go
parentb7cb13b4e68568b014c868186f536439e92d662f (diff)
Initial Router work.
- Router type, supports: adding handlers, serving, fetching the matching handler for a route. - Private PathTree type handles the modified radix trie.
Diffstat (limited to 'internal/pathtree.go')
-rw-r--r--internal/pathtree.go275
1 files changed, 275 insertions, 0 deletions
diff --git a/internal/pathtree.go b/internal/pathtree.go
new file mode 100644
index 0000000..563e85c
--- /dev/null
+++ b/internal/pathtree.go
@@ -0,0 +1,275 @@
+package internal
+
+import (
+ "sort"
+ "strings"
+)
+
+type PathTree[V any] struct {
+ tree subtree[V]
+ root *V
+}
+
+func (pt PathTree[V]) Match(path string) (*V, map[string]string) {
+ path = strings.TrimPrefix(path, "/")
+ if path == "" {
+ return pt.root, map[string]string{}
+ }
+ m := pt.tree.Match(strings.Split(path, "/"), nil)
+ if m == nil {
+ return nil, nil
+ }
+
+ return m.value, m.params()
+}
+
+func (pt *PathTree[V]) Add(pattern string, value V) {
+ pattern = strings.TrimPrefix(pattern, "/")
+ if pattern == "" {
+ pt.root = &value
+ } else {
+ pt.tree.Add(strings.Split(pattern, "/"), value)
+ }
+}
+
+// pattern segment which must be a specific string ("/users/").
+type segmentNode[V any] struct {
+ label string
+ value *V
+ subtree subtree[V]
+}
+
+func makeSegment[V any](pattern []string, value V) segmentNode[V] {
+ node := segmentNode[V]{label: pattern[0]}
+ if len(pattern) == 1 {
+ node.value = &value
+ } else {
+ node.subtree.Add(pattern[1:], value)
+ }
+ return node
+}
+
+func (sn segmentNode[V]) Match(segments []string, m *match[V]) *match[V] {
+ var l int
+ if m != nil {
+ l = m.length
+ }
+ m = &match[V]{
+ value: sn.value,
+ length: l + len(sn.label),
+ prev: m,
+ }
+ if len(segments) == 1 {
+ return m
+ }
+ return sn.subtree.Match(segments[1:], m)
+}
+
+func (sn *segmentNode[V]) Add(pattern []string, value V) {
+ if len(pattern) == 0 {
+ if sn.value != nil {
+ panic("pattern already exists")
+ }
+ sn.value = &value
+ return
+ }
+ sn.subtree.Add(pattern, value)
+}
+
+// single-segment param-capturing wildcard ("/:username/")
+type wildcardNode[V any] struct {
+ param string
+ value *V
+ subtree subtree[V]
+}
+
+func makeWildcard[V any](pattern []string, value V) wildcardNode[V] {
+ node := wildcardNode[V]{param: pattern[0][1:]}
+ if len(pattern) == 1 {
+ node.value = &value
+ } else {
+ node.subtree.Add(pattern[1:], value)
+ }
+ return node
+}
+
+func (wn wildcardNode[V]) Match(segments []string, m *match[V]) *match[V] {
+ var l int
+ if m != nil {
+ l = m.length
+ }
+
+ m = &match[V]{
+ value: wn.value,
+ length: l + len(segments[0]),
+ prev: m,
+ }
+ if wn.param != "" {
+ m.paramkey = wn.param
+ m.paramval = segments[0]
+ }
+
+ if len(segments) == 1 {
+ return m
+ }
+
+ return wn.subtree.Match(segments[1:], m)
+}
+
+func (wn *wildcardNode[V]) Add(pattern []string, value V) {
+ if len(pattern) == 0 {
+ if wn.value != nil {
+ panic("pattern already exists")
+ }
+ wn.value = &value
+ return
+ }
+ wn.subtree.Add(pattern, value)
+}
+
+// "rest of the path" capturing node ("/*path_info") - always a final node
+type remainderNode[V any] struct {
+ param string
+ value V
+}
+
+func (rn remainderNode[V]) Match(segments []string, m *match[V]) *match[V] {
+ m = &match[V]{
+ value: &rn.value,
+ length: m.length,
+ prev: m,
+ }
+ if rn.param != "" {
+ m.paramkey = rn.param
+ m.paramval = strings.Join(segments, "/")
+ }
+ return m
+}
+
+// all the children under a tree node
+type subtree[V any] struct {
+ segments childSegments[V]
+ wildcards childWildcards[V]
+ remainder *remainderNode[V]
+}
+
+func (st subtree[V]) Match(segments []string, m *match[V]) *match[V] {
+ var best *match[V]
+
+ if st.remainder != nil {
+ best = st.remainder.Match(segments, m)
+ }
+
+ for _, wc := range st.wildcards {
+ candidate := wc.Match(segments, m)
+ if best == nil || candidate.length > best.length {
+ best = candidate
+ }
+ }
+
+ childSeg := st.segments.Find(segments[0])
+ if childSeg == nil {
+ return best
+ }
+ candidate := childSeg.Match(segments, m)
+ if best == nil || (candidate != nil && candidate.length >= best.length) {
+ best = candidate
+ }
+
+ return best
+}
+
+func (st *subtree[V]) Add(pattern []string, value V) {
+ if len(pattern[0]) == 0 {
+ panic("invalid pattern")
+ }
+
+ switch pattern[0][0] {
+ case '*':
+ if len(pattern) > 1 {
+ panic("invalid pattern: segments after *remainder")
+ }
+ if st.remainder != nil {
+ panic("pattern already exists")
+ }
+
+ st.remainder = &remainderNode[V]{
+ param: pattern[0][1:],
+ value: value,
+ }
+ case ':':
+ child := st.wildcards.Find(pattern[0][1:])
+ if child != nil {
+ child.Add(pattern[1:], value)
+ } else {
+ st.wildcards = append(st.wildcards, makeWildcard(pattern, value))
+ sort.Sort(st.wildcards)
+ }
+ default:
+ child := st.segments.Find(pattern[0])
+ if child != nil {
+ child.Add(pattern[1:], value)
+ } else {
+ st.segments = append(st.segments, makeSegment(pattern, value))
+ sort.Sort(st.segments)
+ }
+ }
+}
+
+type childSegments[V any] []segmentNode[V]
+
+func (cs childSegments[V]) Len() int { return len(cs) }
+func (cs childSegments[V]) Less(i, j int) bool { return cs[i].label < cs[j].label }
+func (cs childSegments[V]) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
+
+func (cs childSegments[V]) Find(label string) *segmentNode[V] {
+ idx := sort.Search(len(cs), func(i int) bool {
+ return label <= cs[i].label
+ })
+ if idx < len(cs) && cs[idx].label == label {
+ return &cs[idx]
+ }
+ return nil
+}
+
+type childWildcards[V any] []wildcardNode[V]
+
+func (cw childWildcards[V]) Len() int { return len(cw) }
+func (cw childWildcards[V]) Less(i, j int) bool { return cw[i].param < cw[j].param }
+func (cw childWildcards[V]) Swap(i, j int) { cw[i], cw[j] = cw[j], cw[i] }
+
+func (cw childWildcards[V]) Find(param string) *wildcardNode[V] {
+ i := sort.Search(len(cw), func(i int) bool {
+ return param <= cw[i].param
+ })
+ if i < len(cw) && cw[i].param == param {
+ return &cw[i]
+ }
+ return nil
+}
+
+// linked list we build up as we match our way through the path segments
+// - if there is a parameter captured it goes in paramkey/paramval
+// - if there was a value it goes in value
+// - also store cumulative length to choose the longest match
+type match[V any] struct {
+ paramkey string
+ paramval string
+
+ value *V
+ length int
+ prev *match[V]
+}
+
+func (m match[_]) params() map[string]string {
+ mch := &m
+ out := make(map[string]string)
+ for mch != nil {
+ if mch.paramkey != "" {
+ out[mch.paramkey] = mch.paramval
+ }
+ mch = mch.prev
+ }
+
+ return out
+}