diff options
author | tjpcc <tjp@ctrl-c.club> | 2023-11-01 16:17:41 -0600 |
---|---|---|
committer | tjpcc <tjp@ctrl-c.club> | 2023-11-01 16:17:41 -0600 |
commit | a808b4692656c10bb43e2d54a2f5ef2746d231d5 (patch) | |
tree | 6e432d4a0271bb080451e3c87cffa6831c5736cd | |
parent | 5be7e44150d31582e6d8577d48c842a421fa8ddc (diff) |
allow titan uploads in the meta-client
fixes #18
-rw-r--r-- | client.go | 6 | ||||
-rw-r--r-- | gemini/client.go | 68 | ||||
-rw-r--r-- | tools/sw-fetch/main.go | 19 |
3 files changed, 71 insertions, 22 deletions
@@ -40,11 +40,13 @@ func NewClient(tlsConf *tls.Config) Client { if tlsConf != nil { hc.tp.TLSClientConfig = tlsConf } + gemcl := gemini.NewClient(tlsConf) return Client{ protos: map[string]protocolClient{ "finger": finger.Client{}, "gopher": gopher.Client{}, - "gemini": gemini.NewClient(tlsConf), + "gemini": gemcl, + "titan": gemcl, "spartan": spartan.NewClient(), "http": hc, "https": hc, @@ -95,7 +97,7 @@ func (c Client) Fetch(url string) (*Response, error) { } // Upload sends a request with a body and returns any redirect response. -func (c Client) Upload(url string, contents *io.LimitedReader) (*Response, error) { +func (c Client) Upload(url string, contents io.Reader) (*Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err diff --git a/gemini/client.go b/gemini/client.go index 0a5dd76..c60e92e 100644 --- a/gemini/client.go +++ b/gemini/client.go @@ -7,6 +7,8 @@ import ( "io" "net" neturl "net/url" + "strconv" + "strings" "tildegit.org/tjp/sliderule/internal/types" ) @@ -48,7 +50,7 @@ var ExceededMaxRedirects = errors.New("gemini.Client: exceeded MaxRedirects") // This method will not automatically follow redirects or cache permanent failures or // redirects. func (client Client) RoundTrip(request *types.Request) (*types.Response, error) { - if request.Scheme != "gemini" && request.Scheme != "" { + if request.Scheme != "gemini" && request.Scheme != "titan" && request.Scheme != "" { return nil, errors.New("non-gemini protocols not supported") } @@ -72,7 +74,33 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error) st := conn.ConnectionState() request.TLSState = &st - if _, err := conn.Write([]byte(request.URL.String() + "\r\n")); err != nil { + destURL := *request.URL + + var body []byte + if request.Scheme == "titan" { + var err error + if bodyrdr, ok := request.Meta.(io.Reader); ok { + body, err = io.ReadAll(bodyrdr) + if err != nil { + return nil, err + } + if err := close(request.Meta); err != nil { + return nil, err + } + + path, params := pathparams(destURL.Path) + params["size"] = strconv.Itoa(len(body)) + destURL.Path = assemblepath(path, params) + } else { + body = []byte{} + } + } + + if _, err := conn.Write([]byte(destURL.String() + "\r\n")); err != nil { + return nil, err + } + + if _, err := conn.Write(body); err != nil { return nil, err } @@ -124,3 +152,39 @@ func (c Client) Fetch(url string) (*types.Response, error) { func (c Client) IsRedirect(response *types.Response) bool { return ResponseCategoryForStatus(response.Status) == ResponseCategoryRedirect } + +func pathparams(basepath string) (string, map[string]string) { + params := map[string]string{} + path, paramstr, found := strings.Cut(basepath, ";") + if !found { + return path, params + } + + for _, pairstr := range strings.Split(paramstr, ";") { + key, val, found := strings.Cut(pairstr, "=") + if found { + params[key] = val + } + } + + return path, params +} + +func assemblepath(basepath string, params map[string]string) string { + path := strings.Builder{} + _, _ = path.WriteString(basepath) + for key, val := range params { + _ = path.WriteByte(';') + _, _ = path.WriteString(key) + _ = path.WriteByte('=') + _, _ = path.WriteString(val) + } + return path.String() +} + +func close(closer any) error { + if cl, ok := closer.(io.Closer); ok { + return cl.Close() + } + return nil +} diff --git a/tools/sw-fetch/main.go b/tools/sw-fetch/main.go index 8591fe2..c2cb3e8 100644 --- a/tools/sw-fetch/main.go +++ b/tools/sw-fetch/main.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "crypto/tls" "fmt" "io" @@ -38,12 +37,7 @@ func main() { var err error if conf.upload { - body, er := stdinContents() - if er != nil { - err = er - } else { - response, err = cl.Upload(conf.url.String(), body) - } + response, err = cl.Upload(conf.url.String(), os.Stdin) } else { response, err = cl.Fetch(conf.url.String()) } @@ -155,17 +149,6 @@ func failf(msg string, args ...any) { os.Exit(1) } -func stdinContents() (*io.LimitedReader, error) { - contents, err := io.ReadAll(os.Stdin) - if err != nil { - return nil, err - } - return &io.LimitedReader{ - R: bytes.NewBuffer(contents), - N: int64(len(contents)), - }, nil -} - func printResponse(response *sliderule.Response, conf config) bool { success := true |