summaryrefslogtreecommitdiff
path: root/spartan
diff options
context:
space:
mode:
authortjpcc <tjp@ctrl-c.club>2023-05-01 08:38:54 -0600
committertjpcc <tjp@ctrl-c.club>2023-05-01 08:38:54 -0600
commita5d8aeb0bb53420a6dc87dcdc72494fa3f44224f (patch)
treeba60dbb50c1b7f94d59863e511fb389b6eba6bac /spartan
parenta97480d593dfc9ec1121d65b1697b53750f0d979 (diff)
make spartan.Client.RoundTrip match the API of other clients.
- Request.Meta is already used as an *io.LimitedReader in spartan servers, so by following this convention the RoundTrip method doesn't need anything more than the *Request. - Make a new public method for setting the body on a spartan request.
Diffstat (limited to 'spartan')
-rw-r--r--spartan/client.go15
-rw-r--r--spartan/request.go21
2 files changed, 27 insertions, 9 deletions
diff --git a/spartan/client.go b/spartan/client.go
index e571c14..40c9dd6 100644
--- a/spartan/client.go
+++ b/spartan/client.go
@@ -18,7 +18,7 @@ import (
type Client struct{}
// RoundTrip sends a single spartan request and returns its response.
-func (c Client) RoundTrip(request *sr.Request, body io.Reader) (*sr.Response, error) {
+func (c Client) RoundTrip(request *sr.Request) (*sr.Response, error) {
if request.Scheme != "spartan" && request.Scheme != "" {
return nil, errors.New("non-spartan protocols not supported")
}
@@ -38,20 +38,17 @@ func (c Client) RoundTrip(request *sr.Request, body io.Reader) (*sr.Response, er
request.RemoteAddr = conn.RemoteAddr()
- var bodyBytes []byte = nil
- if body != nil {
- bodyBytes, err = io.ReadAll(body)
- if err != nil {
- return nil, err
- }
+ rdr, ok := request.Meta.(*io.LimitedReader)
+ if !ok {
+ return nil, errors.New("request body must be an *io.LimitedReader")
}
- requestLine := host + " " + request.EscapedPath() + " " + strconv.Itoa(len(bodyBytes)) + "\r\n"
+ requestLine := host + " " + request.EscapedPath() + " " + strconv.Itoa(int(rdr.N)) + "\r\n"
if _, err := conn.Write([]byte(requestLine)); err != nil {
return nil, err
}
- if _, err := conn.Write(bodyBytes); err != nil {
+ if _, err := io.Copy(conn, rdr); err != nil {
return nil, err
}
diff --git a/spartan/request.go b/spartan/request.go
index a9b2815..c056af0 100644
--- a/spartan/request.go
+++ b/spartan/request.go
@@ -2,6 +2,7 @@ package spartan
import (
"bufio"
+ "bytes"
"errors"
"io"
"net/url"
@@ -80,3 +81,23 @@ func GetRequestBody(request *sr.Request) io.Reader {
}
return nil
}
+
+// SetRequestBody adds an io.Reader as a request body.
+//
+// It is for use in clients, preparing the request to be sent.
+//
+// This function will read the entire contents into memory unless
+// the reader is already an *io.LimitedReader.
+func SetRequestBody(request *sr.Request, body io.Reader) error {
+ if rdr, ok := body.(*io.LimitedReader); ok {
+ request.Meta = rdr
+ return nil
+ }
+
+ buf, err := io.ReadAll(body)
+ if err != nil {
+ return err
+ }
+ request.Meta = io.LimitReader(bytes.NewBuffer(buf), int64(len(buf)))
+ return nil
+}