diff options
author | tjp <tjp@ctrl-c.club> | 2024-01-13 11:29:17 -0700 |
---|---|---|
committer | tjp <tjp@ctrl-c.club> | 2024-01-13 11:29:17 -0700 |
commit | 4d861a2c395c0926b066014b92999d4dda454b2b (patch) | |
tree | 7bbf7d94cec2a5412a68b830c1e5223918d9bc82 | |
parent | de1490808fa6e4d6749ff29d20cc1a589ec476d1 (diff) |
dial timeouts for clients, and catch up on test fixes
-rw-r--r-- | client.go | 24 | ||||
-rw-r--r-- | contrib/tlsauth/auth_test.go | 2 | ||||
-rw-r--r-- | examples/fetch/main.go | 3 | ||||
-rw-r--r-- | finger/client.go | 9 | ||||
-rw-r--r-- | gemini/client.go | 11 | ||||
-rw-r--r-- | gemini/gemtext/parse_line_test.go | 2 | ||||
-rw-r--r-- | gemini/gemtext/parse_test.go | 4 | ||||
-rw-r--r-- | gemini/roundtrip_test.go | 4 | ||||
-rw-r--r-- | gopher/client.go | 9 | ||||
-rw-r--r-- | gopher/gophermap/testdata/customlist_output.gophermap | 2 | ||||
-rw-r--r-- | nex/client.go | 9 | ||||
-rw-r--r-- | spartan/client.go | 11 | ||||
-rw-r--r-- | tools/sw-fetch/main.go | 53 |
13 files changed, 89 insertions, 54 deletions
@@ -1,12 +1,12 @@ package sliderule import ( + "context" "crypto/tls" "errors" "fmt" "io" "net/http" - "net/url" neturl "net/url" "tildegit.org/tjp/sliderule/finger" @@ -18,7 +18,7 @@ import ( ) type protocolClient interface { - RoundTrip(*Request) (*Response, error) + RoundTrip(context.Context, *Request) (*Response, error) IsRedirect(*Response) bool } @@ -61,23 +61,23 @@ func NewClient(tlsConf *tls.Config) Client { // RoundTrip sends a single request and returns the repsonse. // // If the response is a redirect it will be returned, rather than fetched. -func (c Client) RoundTrip(request *Request) (*Response, error) { +func (c Client) RoundTrip(ctx context.Context, request *Request) (*Response, error) { pc, ok := c.protos[request.Scheme] if !ok { return nil, fmt.Errorf("unrecognized protocol: %s", request.Scheme) } - return pc.RoundTrip(request) + return pc.RoundTrip(ctx, request) } // Fetch collects a resource from a URL including following any redirects. -func (c Client) Fetch(url string) (*Response, error) { +func (c Client) Fetch(ctx context.Context, url string) (*Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err } for i := 0; i <= c.MaxRedirects; i += 1 { - response, err := c.RoundTrip(&types.Request{URL: u}) + response, err := c.RoundTrip(ctx, &types.Request{URL: u}) if err != nil { return nil, err } @@ -100,23 +100,23 @@ func (c Client) Fetch(url string) (*Response, error) { } // Upload sends a request with a body and returns any redirect response. -func (c Client) Upload(url string, contents io.Reader) (*Response, error) { +func (c Client) Upload(ctx context.Context, url string, contents io.Reader) (*Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err } switch u.Scheme { case "titan", "spartan", "http", "https": - return c.RoundTrip(&types.Request{URL: u, Meta: contents}) + return c.RoundTrip(ctx, &types.Request{URL: u, Meta: contents}) default: return nil, fmt.Errorf("upload not supported on %s", u.Scheme) } } -func getRedirectLocation(prev *url.URL, proto string, meta any) string { +func getRedirectLocation(prev *neturl.URL, proto string, meta any) string { switch proto { case "gemini", "spartan": - u, _ := url.Parse(meta.(string)) + u, _ := neturl.Parse(meta.(string)) return prev.ResolveReference(u).String() case "http", "https": return meta.(*http.Response).Header.Get("Location") @@ -128,9 +128,9 @@ type httpClient struct { tp *http.Transport } -func (hc httpClient) RoundTrip(request *Request) (*Response, error) { +func (hc httpClient) RoundTrip(ctx context.Context, request *Request) (*Response, error) { body, _ := request.Meta.(io.Reader) - hreq, err := http.NewRequest("GET", request.URL.String(), body) + hreq, err := http.NewRequestWithContext(ctx, "GET", request.URL.String(), body) if err != nil { return nil, err } diff --git a/contrib/tlsauth/auth_test.go b/contrib/tlsauth/auth_test.go index df67159..862d9d8 100644 --- a/contrib/tlsauth/auth_test.go +++ b/contrib/tlsauth/auth_test.go @@ -98,7 +98,7 @@ func requestPath(t *testing.T, client gemini.Client, server sr.Server, path stri u, err := url.Parse("gemini://" + server.Address() + path) require.Nil(t, err) - response, err := client.RoundTrip(&sr.Request{URL: u}) + response, err := client.RoundTrip(context.Background(), &sr.Request{URL: u}) require.Nil(t, err) return response diff --git a/examples/fetch/main.go b/examples/fetch/main.go index 8f03114..d70c941 100644 --- a/examples/fetch/main.go +++ b/examples/fetch/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "io" "log" @@ -33,7 +34,7 @@ func main() { request := &sr.Request{URL: buildURL(os.Args[1])} // fetch the response - response, err := client.RoundTrip(request) + response, err := client.RoundTrip(context.Background(), request) if err != nil { log.Fatal(err) } diff --git a/finger/client.go b/finger/client.go index 0488b79..f89a788 100644 --- a/finger/client.go +++ b/finger/client.go @@ -2,6 +2,7 @@ package finger import ( "bytes" + "context" "errors" "io" "net" @@ -18,7 +19,7 @@ import ( type Client struct{} // RoundTrip sends a single finger request and returns its response. -func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { +func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) { if request.Scheme != "finger" && request.Scheme != "" { return nil, errors.New("non-finger protocols not supported") } @@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { host = net.JoinHostPort(host, "79") } - conn, err := net.Dial("tcp", host) + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host) if err != nil { return nil, err } @@ -55,12 +56,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { } // Fetch resolves a finger query. -func (c Client) Fetch(query string) (*types.Response, error) { +func (c Client) Fetch(ctx context.Context, query string) (*types.Response, error) { req, err := ParseRequest(bytes.NewBufferString(query + "\r\n")) if err != nil { return nil, err } - return c.RoundTrip(req) + return c.RoundTrip(ctx, req) } func (c Client) IsRedirect(_ *types.Response) bool { return false } diff --git a/gemini/client.go b/gemini/client.go index 00e28f6..1e65a39 100644 --- a/gemini/client.go +++ b/gemini/client.go @@ -2,6 +2,7 @@ package gemini import ( "bytes" + "context" "crypto/tls" "errors" "io" @@ -49,7 +50,7 @@ var ExceededMaxRedirects = errors.New("gemini.Client: exceeded MaxRedirects") // // This method will not automatically follow redirects or cache permanent failures or // redirects. -func (client Client) RoundTrip(request *types.Request) (*types.Response, error) { +func (client Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) { if request.Scheme != "gemini" && request.Scheme != "titan" && request.Scheme != "" { return nil, errors.New("non-gemini protocols not supported") } @@ -64,14 +65,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error) tlsConf = &tls.Config{InsecureSkipVerify: true} } - conn, err := tls.Dial("tcp", host, tlsConf) + conn, err := (&tls.Dialer{Config: tlsConf}).DialContext(ctx, "tcp", host) if err != nil { return nil, err } defer conn.Close() request.RemoteAddr = conn.RemoteAddr() - st := conn.ConnectionState() + st := conn.(*tls.Conn).ConnectionState() request.TLSState = &st destURL := *request.URL @@ -124,14 +125,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error) // Fetch parses a URL string and fetches the gemini resource. // // It will resolve any redirects along the way, up to client.MaxRedirects. -func (c Client) Fetch(url string) (*types.Response, error) { +func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err } for i := 0; i <= c.MaxRedirects; i += 1 { - response, err := c.RoundTrip(&types.Request{URL: u}) + response, err := c.RoundTrip(ctx, &types.Request{URL: u}) if err != nil { return nil, err } diff --git a/gemini/gemtext/parse_line_test.go b/gemini/gemtext/parse_line_test.go index 9073df0..a7faf49 100644 --- a/gemini/gemtext/parse_line_test.go +++ b/gemini/gemtext/parse_line_test.go @@ -92,7 +92,7 @@ func TestParsePromptLine(t *testing.T) { if line.Type() != gemtext.LineTypePrompt{ t.Errorf("expected LineTypePrompt, got %d", line.Type()) } - link, ok := line.(gemtext.PromptLine) + link, ok := line.(gemtext.LinkLine) if !ok { t.Fatalf("expected a PromptLine, got %T", line) } diff --git a/gemini/gemtext/parse_test.go b/gemini/gemtext/parse_test.go index 90d2c75..4e7b6ff 100644 --- a/gemini/gemtext/parse_test.go +++ b/gemini/gemtext/parse_test.go @@ -78,8 +78,8 @@ This is some non-blank regular text. assert.Equal(t, gemtext.LineTypePrompt, doc[12].Type()) assert.Equal(t, "=: spartan://foo.bar/baz this should be a spartan prompt\n", string(doc[12].Raw())) - assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.PromptLine).URL()) - assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.PromptLine).Label()) + assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.LinkLine).URL()) + assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.LinkLine).Label()) assertEmptyLine(t, doc[13]) diff --git a/gemini/roundtrip_test.go b/gemini/roundtrip_test.go index 50c1962..4a8097f 100644 --- a/gemini/roundtrip_test.go +++ b/gemini/roundtrip_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "tildegit.org/tjp/sliderule/internal/types" "tildegit.org/tjp/sliderule/gemini" + "tildegit.org/tjp/sliderule/internal/types" ) func TestRoundTrip(t *testing.T) { @@ -36,7 +36,7 @@ func TestRoundTrip(t *testing.T) { require.Nil(t, err) cli := gemini.NewClient(testClientTLS()) - response, err := cli.RoundTrip(&types.Request{URL: u}) + response, err := cli.RoundTrip(context.Background(), &types.Request{URL: u}) require.Nil(t, err) assert.Equal(t, gemini.StatusSuccess, response.Status) diff --git a/gopher/client.go b/gopher/client.go index 0ae1730..6d46323 100644 --- a/gopher/client.go +++ b/gopher/client.go @@ -2,6 +2,7 @@ package gopher import ( "bytes" + "context" "errors" "io" "net" @@ -18,7 +19,7 @@ import ( type Client struct{} // RoundTrip sends a single gopher request and returns its response. -func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { +func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) { if request.Scheme != "gopher" && request.Scheme != "" { return nil, errors.New("non-gopher protocols not supported") } @@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { host = net.JoinHostPort(host, "70") } - conn, err := net.Dial("tcp", host) + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host) if err != nil { return nil, err } @@ -56,12 +57,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { } // Fetch parses a URL string and fetches the gopher resource. -func (c Client) Fetch(url string) (*types.Response, error) { +func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err } - return c.RoundTrip(&types.Request{URL: u}) + return c.RoundTrip(ctx, &types.Request{URL: u}) } func (c Client) IsRedirect(_ *types.Response) bool { return false } diff --git a/gopher/gophermap/testdata/customlist_output.gophermap b/gopher/gophermap/testdata/customlist_output.gophermap index 0b4e334..82330e0 100644 --- a/gopher/gophermap/testdata/customlist_output.gophermap +++ b/gopher/gophermap/testdata/customlist_output.gophermap @@ -11,6 +11,6 @@ i /customlist.gophermap localhost.localdomain 70 0file4.txt /file4.txt localhost.localdomain 70
1subdir title /subdir localhost.localdomain 70
1subdir2 title /subdir2 localhost.localdomain 70
-9uptime /uptime localhost.localdomain 70
+0uptime /uptime localhost.localdomain 70
1uptime_output.gophermap /uptime_output.gophermap localhost.localdomain 70
.
diff --git a/nex/client.go b/nex/client.go index 4a34903..4fa5265 100644 --- a/nex/client.go +++ b/nex/client.go @@ -2,6 +2,7 @@ package nex import ( "bytes" + "context" "errors" "io" "net" @@ -18,7 +19,7 @@ import ( type Client struct{} // RoundTrip sends a single nex request and returns its response. -func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { +func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) { if request.Scheme != "nex" && request.Scheme != "" { return nil, errors.New("non-nex protocols not supported") } @@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { host = net.JoinHostPort(host, "1900") } - conn, err := net.Dial("tcp", host) + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host) if err != nil { return nil, err } @@ -50,12 +51,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { } // Fetch builds and sends a nex request, and returns the response. -func (c Client) Fetch(url string) (*types.Response, error) { +func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err } - return c.RoundTrip(&types.Request{URL: u}) + return c.RoundTrip(ctx, &types.Request{URL: u}) } func (c Client) IsRedirect(response *types.Response) bool { return false } diff --git a/spartan/client.go b/spartan/client.go index 81f2132..d77e791 100644 --- a/spartan/client.go +++ b/spartan/client.go @@ -2,6 +2,7 @@ package spartan import ( "bytes" + "context" "errors" "io" "net" @@ -16,7 +17,7 @@ import ( // It carries no state and is reusable simultaneously by multiple goroutines. // // The zero value is immediately usabble, but will not follow redirects. -type Client struct{ +type Client struct { MaxRedirects int } @@ -32,7 +33,7 @@ const DefaultMaxRedirects int = 2 var ExceededMaxRedirects = errors.New("spartan.Client: exceeded MaxRedirects") // RoundTrip sends a single spartan request and returns its response. -func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { +func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) { if request.Scheme != "spartan" && request.Scheme != "" { return nil, errors.New("non-spartan protocols not supported") } @@ -44,7 +45,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { } addr := net.JoinHostPort(host, port) - conn, err := net.Dial("tcp", addr) + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) if err != nil { return nil, err } @@ -90,14 +91,14 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) { // Fetch parses a URL string and fetches the spartan resource. // // It will resolve any redirects along the way, up to client.MaxRedirects. -func (c Client) Fetch(url string) (*types.Response, error) { +func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) { u, err := neturl.Parse(url) if err != nil { return nil, err } for i := 0; i <= c.MaxRedirects; i += 1 { - response, err := c.RoundTrip(&types.Request{URL: u}) + response, err := c.RoundTrip(ctx, &types.Request{URL: u}) if err != nil { return nil, err } diff --git a/tools/sw-fetch/main.go b/tools/sw-fetch/main.go index c2cb3e8..76e414c 100644 --- a/tools/sw-fetch/main.go +++ b/tools/sw-fetch/main.go @@ -1,12 +1,14 @@ package main import ( + "context" "crypto/tls" "fmt" "io" "net/http" "net/url" "os" + "time" "tildegit.org/tjp/sliderule" "tildegit.org/tjp/sliderule/gemini" @@ -17,29 +19,45 @@ const usage = `Resource fetcher for the small web. Usage: sw-fetch (-h | --help) - sw-fetch [-v | --verbose] [-o PATH | --output PATH] [-k | --keyfile PATH] [ -c | --certfile PATH ] [ -s | --skip-verify ] [ -u | --upload ] URL + sw-fetch + [-v | --verbose] + [-o PATH | --output PATH] + [-k | --keyfile PATH] + [ -c | --certfile PATH ] + [ -s | --skip-verify ] + [ -t | --timeout TIMEOUT ] + [ -u | --upload ] + URL Options: - -h --help Show this screen. - -v --verbose Display more diagnostic information on standard error. - -o --output PATH Send the fetched resource to PATH instead of standard out. - -k --keyfile PATH Path to the TLS key file to use. - -c --certfile PATH Path to the TLS certificate file to use. - -s --skip-verify Don't verify server TLS certificates. - -u --upload Use stdin as the request body on supported protocols and don't follow redirects. + -h --help Show this screen. + -v --verbose Display more diagnostic information on standard error. + -o --output PATH Send the fetched resource to PATH instead of standard out. + -k --keyfile PATH Path to the TLS key file to use. + -c --certfile PATH Path to the TLS certificate file to use. + -s --skip-verify Don't verify server TLS certificates. + -t --timeout TIMEOUT Fail after the given timeout (like "15s"). + -u --upload Use stdin as the request body on supported protocols and don't follow redirects. ` func main() { conf := configure() cl := sliderule.NewClient(conf.clientTLS) + ctx := context.Background() + if conf.timeout != 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, conf.timeout) + defer cancel() + } + var response *sliderule.Response var err error if conf.upload { - response, err = cl.Upload(conf.url.String(), os.Stdin) + response, err = cl.Upload(ctx, conf.url.String(), os.Stdin) } else { - response, err = cl.Fetch(conf.url.String()) + response, err = cl.Fetch(ctx, conf.url.String()) } if err != nil { fail(err.Error() + "\n") @@ -61,6 +79,7 @@ type config struct { output io.WriteCloser url *url.URL clientTLS *tls.Config + timeout time.Duration } func configure() config { @@ -72,6 +91,7 @@ func configure() config { key := "" cert := "" verify := true + var err error for i := 1; i <= len(os.Args)-1; i += 1 { switch os.Args[i] { @@ -87,12 +107,11 @@ func configure() config { out := os.Args[i+1] if out != "-" { - output, err := os.OpenFile(out, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + conf.output, err = os.OpenFile(out, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644) if err != nil { fmt.Println(err.Error()) failf("'%s' is not a valid path\n", out) } - conf.output = output } i += 1 @@ -112,6 +131,16 @@ func configure() config { cert = os.Args[i] case "-s", "--skip-verify": verify = false + case "-t", "--timeout": + if i+1 == len(os.Args)-1 { + fail(usage) + } + + i += 1 + conf.timeout, err = time.ParseDuration(os.Args[i]) + if err != nil { + fail(err.Error()) + } case "-u", "--upload": conf.upload = true } |