diff options
Diffstat (limited to 'spartan')
-rw-r--r-- | spartan/client.go | 62 |
1 files changed, 58 insertions, 4 deletions
diff --git a/spartan/client.go b/spartan/client.go index 40c9dd6..a98949d 100644 --- a/spartan/client.go +++ b/spartan/client.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net" + neturl "net/url" "strconv" sr "tildegit.org/tjp/sliderule" @@ -15,7 +16,19 @@ import ( // It carries no state and is reusable simultaneously by multiple goroutines. // // The zero value is immediately usabble. -type Client struct{} +type Client struct{ + MaxRedirects int +} + +func NewClient() Client { + return Client{MaxRedirects: DefaultMaxRedirects} +} + +// DefaultMaxRedirects is the number of chained redirects a Client will perform for a +// single request by default. This can be changed by altering the MaxRedirects field. +const DefaultMaxRedirects int = 2 + +var ExceededMaxRedirects = errors.New("spartan.Client: exceeded MaxRedirects") // RoundTrip sends a single spartan request and returns its response. func (c Client) RoundTrip(request *sr.Request) (*sr.Response, error) { @@ -38,9 +51,15 @@ func (c Client) RoundTrip(request *sr.Request) (*sr.Response, error) { request.RemoteAddr = conn.RemoteAddr() - rdr, ok := request.Meta.(*io.LimitedReader) - if !ok { - return nil, errors.New("request body must be an *io.LimitedReader") + var rdr *io.LimitedReader + if request.Meta == nil { + rdr = &io.LimitedReader{R: devnull{}, N: 0} + } else { + var ok bool + rdr, ok = request.Meta.(*io.LimitedReader) + if !ok { + return nil, errors.New("request body must be nil or an *io.LimitedReader") + } } requestLine := host + " " + request.EscapedPath() + " " + strconv.Itoa(int(rdr.N)) + "\r\n" @@ -65,3 +84,38 @@ func (c Client) RoundTrip(request *sr.Request) (*sr.Response, error) { return response, nil } + +// Fetch parses a URL string and fetches the spartan resource. +// +// It will resolve any redirects along the way, up to client.MaxRedirects. +func (c Client) Fetch(url string) (*sr.Response, error) { + u, err := neturl.Parse(url) + if err != nil { + return nil, err + } + + for i := 0; i <= c.MaxRedirects; i += 1 { + response, err := c.RoundTrip(&sr.Request{URL: u}) + if err != nil { + return nil, err + } + if response.Status != StatusRedirect { + return response, nil + } + + prev := u + u, err = neturl.Parse(url) + if err != nil { + return nil, err + } + u = prev.ResolveReference(u) + } + + return nil, ExceededMaxRedirects +} + +type devnull struct{} + +func (_ devnull) Read(p []byte) (int, error) { + return 0, nil +} |