From 4d861a2c395c0926b066014b92999d4dda454b2b Mon Sep 17 00:00:00 2001 From: tjp Date: Sat, 13 Jan 2024 11:29:17 -0700 Subject: dial timeouts for clients, and catch up on test fixes --- spartan/client.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'spartan/client.go') diff --git a/spartan/client.go b/spartan/client.go index 81f2132..d77e791 100644 --- a/spartan/client.go +++ b/spartan/client.go @@ -2,6 +2,7 @@ package spartan import ( "bytes" + "context" "errors" "io" "net" @@ -16,7 +17,7 @@ import ( // It carries no state and is reusable simultaneously by multiple goroutines. // // The zero value is immediately usabble, but will not follow redirects. -type Client struct{ +type Client struct { MaxRedirects int } @@ -32,7 +33,7 @@ const DefaultMaxRedirects int = 2 var ExceededMaxRedirects = errors.New("spartan.Client: exceeded MaxRedirects") // RoundTrip sends a single spartan request and returns its response. -func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { +func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) { if request.Scheme != "spartan" && request.Scheme != "" { return nil, errors.New("non-spartan protocols not supported") } @@ -44,7 +45,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { } addr := net.JoinHostPort(host, port) - conn, err := net.Dial("tcp", addr) + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) if err != nil { return nil, err } @@ -90,14 +91,14 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { // Fetch parses a URL string and fetches the spartan resource. // // It will resolve any redirects along the way, up to client.MaxRedirects. -func (c Client) Fetch(url string) (*types.Response, error) { +func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err } for i := 0; i <= c.MaxRedirects; i += 1 { - response, err := c.RoundTrip(&types.Request{URL: u}) + response, err := c.RoundTrip(ctx, &types.Request{URL: u}) if err != nil { return nil, err } -- cgit v1.2.3