summaryrefslogtreecommitdiff
path: root/spartan
diff options
context:
space:
mode:
authortjp <tjp@ctrl-c.club>2024-01-13 11:29:17 -0700
committertjp <tjp@ctrl-c.club>2024-01-13 11:29:17 -0700
commit4d861a2c395c0926b066014b92999d4dda454b2b (patch)
tree7bbf7d94cec2a5412a68b830c1e5223918d9bc82 /spartan
parentde1490808fa6e4d6749ff29d20cc1a589ec476d1 (diff)
dial timeouts for clients, and catch up on test fixes
Diffstat (limited to 'spartan')
-rw-r--r--spartan/client.go11
1 files changed, 6 insertions, 5 deletions
diff --git a/spartan/client.go b/spartan/client.go
index 81f2132..d77e791 100644
--- a/spartan/client.go
+++ b/spartan/client.go
@@ -2,6 +2,7 @@ package spartan
import (
"bytes"
+ "context"
"errors"
"io"
"net"
@@ -16,7 +17,7 @@ import (
// It carries no state and is reusable simultaneously by multiple goroutines.
//
// The zero value is immediately usabble, but will not follow redirects.
-type Client struct{
+type Client struct {
MaxRedirects int
}
@@ -32,7 +33,7 @@ const DefaultMaxRedirects int = 2
var ExceededMaxRedirects = errors.New("spartan.Client: exceeded MaxRedirects")
// RoundTrip sends a single spartan request and returns its response.
-func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
+func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "spartan" && request.Scheme != "" {
return nil, errors.New("non-spartan protocols not supported")
}
@@ -44,7 +45,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
}
addr := net.JoinHostPort(host, port)
- conn, err := net.Dial("tcp", addr)
+ conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
@@ -90,14 +91,14 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
// Fetch parses a URL string and fetches the spartan resource.
//
// It will resolve any redirects along the way, up to client.MaxRedirects.
-func (c Client) Fetch(url string) (*types.Response, error) {
+func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
for i := 0; i <= c.MaxRedirects; i += 1 {
- response, err := c.RoundTrip(&types.Request{URL: u})
+ response, err := c.RoundTrip(ctx, &types.Request{URL: u})
if err != nil {
return nil, err
}