summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortjp <tjp@ctrl-c.club>2024-01-13 11:29:17 -0700
committertjp <tjp@ctrl-c.club>2024-01-13 11:29:17 -0700
commit4d861a2c395c0926b066014b92999d4dda454b2b (patch)
tree7bbf7d94cec2a5412a68b830c1e5223918d9bc82
parentde1490808fa6e4d6749ff29d20cc1a589ec476d1 (diff)
dial timeouts for clients, and catch up on test fixes
-rw-r--r--client.go24
-rw-r--r--contrib/tlsauth/auth_test.go2
-rw-r--r--examples/fetch/main.go3
-rw-r--r--finger/client.go9
-rw-r--r--gemini/client.go11
-rw-r--r--gemini/gemtext/parse_line_test.go2
-rw-r--r--gemini/gemtext/parse_test.go4
-rw-r--r--gemini/roundtrip_test.go4
-rw-r--r--gopher/client.go9
-rw-r--r--gopher/gophermap/testdata/customlist_output.gophermap2
-rw-r--r--nex/client.go9
-rw-r--r--spartan/client.go11
-rw-r--r--tools/sw-fetch/main.go53
13 files changed, 89 insertions, 54 deletions
diff --git a/client.go b/client.go
index 217a777..8119200 100644
--- a/client.go
+++ b/client.go
@@ -1,12 +1,12 @@
package sliderule
import (
+ "context"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
- "net/url"
neturl "net/url"
"tildegit.org/tjp/sliderule/finger"
@@ -18,7 +18,7 @@ import (
)
type protocolClient interface {
- RoundTrip(*Request) (*Response, error)
+ RoundTrip(context.Context, *Request) (*Response, error)
IsRedirect(*Response) bool
}
@@ -61,23 +61,23 @@ func NewClient(tlsConf *tls.Config) Client {
// RoundTrip sends a single request and returns the repsonse.
//
// If the response is a redirect it will be returned, rather than fetched.
-func (c Client) RoundTrip(request *Request) (*Response, error) {
+func (c Client) RoundTrip(ctx context.Context, request *Request) (*Response, error) {
pc, ok := c.protos[request.Scheme]
if !ok {
return nil, fmt.Errorf("unrecognized protocol: %s", request.Scheme)
}
- return pc.RoundTrip(request)
+ return pc.RoundTrip(ctx, request)
}
// Fetch collects a resource from a URL including following any redirects.
-func (c Client) Fetch(url string) (*Response, error) {
+func (c Client) Fetch(ctx context.Context, url string) (*Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
for i := 0; i <= c.MaxRedirects; i += 1 {
- response, err := c.RoundTrip(&types.Request{URL: u})
+ response, err := c.RoundTrip(ctx, &types.Request{URL: u})
if err != nil {
return nil, err
}
@@ -100,23 +100,23 @@ func (c Client) Fetch(url string) (*Response, error) {
}
// Upload sends a request with a body and returns any redirect response.
-func (c Client) Upload(url string, contents io.Reader) (*Response, error) {
+func (c Client) Upload(ctx context.Context, url string, contents io.Reader) (*Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
switch u.Scheme {
case "titan", "spartan", "http", "https":
- return c.RoundTrip(&types.Request{URL: u, Meta: contents})
+ return c.RoundTrip(ctx, &types.Request{URL: u, Meta: contents})
default:
return nil, fmt.Errorf("upload not supported on %s", u.Scheme)
}
}
-func getRedirectLocation(prev *url.URL, proto string, meta any) string {
+func getRedirectLocation(prev *neturl.URL, proto string, meta any) string {
switch proto {
case "gemini", "spartan":
- u, _ := url.Parse(meta.(string))
+ u, _ := neturl.Parse(meta.(string))
return prev.ResolveReference(u).String()
case "http", "https":
return meta.(*http.Response).Header.Get("Location")
@@ -128,9 +128,9 @@ type httpClient struct {
tp *http.Transport
}
-func (hc httpClient) RoundTrip(request *Request) (*Response, error) {
+func (hc httpClient) RoundTrip(ctx context.Context, request *Request) (*Response, error) {
body, _ := request.Meta.(io.Reader)
- hreq, err := http.NewRequest("GET", request.URL.String(), body)
+ hreq, err := http.NewRequestWithContext(ctx, "GET", request.URL.String(), body)
if err != nil {
return nil, err
}
diff --git a/contrib/tlsauth/auth_test.go b/contrib/tlsauth/auth_test.go
index df67159..862d9d8 100644
--- a/contrib/tlsauth/auth_test.go
+++ b/contrib/tlsauth/auth_test.go
@@ -98,7 +98,7 @@ func requestPath(t *testing.T, client gemini.Client, server sr.Server, path stri
u, err := url.Parse("gemini://" + server.Address() + path)
require.Nil(t, err)
- response, err := client.RoundTrip(&sr.Request{URL: u})
+ response, err := client.RoundTrip(context.Background(), &sr.Request{URL: u})
require.Nil(t, err)
return response
diff --git a/examples/fetch/main.go b/examples/fetch/main.go
index 8f03114..d70c941 100644
--- a/examples/fetch/main.go
+++ b/examples/fetch/main.go
@@ -1,6 +1,7 @@
package main
import (
+ "context"
"fmt"
"io"
"log"
@@ -33,7 +34,7 @@ func main() {
request := &sr.Request{URL: buildURL(os.Args[1])}
// fetch the response
- response, err := client.RoundTrip(request)
+ response, err := client.RoundTrip(context.Background(), request)
if err != nil {
log.Fatal(err)
}
diff --git a/finger/client.go b/finger/client.go
index 0488b79..f89a788 100644
--- a/finger/client.go
+++ b/finger/client.go
@@ -2,6 +2,7 @@ package finger
import (
"bytes"
+ "context"
"errors"
"io"
"net"
@@ -18,7 +19,7 @@ import (
type Client struct{}
// RoundTrip sends a single finger request and returns its response.
-func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
+func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "finger" && request.Scheme != "" {
return nil, errors.New("non-finger protocols not supported")
}
@@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
host = net.JoinHostPort(host, "79")
}
- conn, err := net.Dial("tcp", host)
+ conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
@@ -55,12 +56,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
}
// Fetch resolves a finger query.
-func (c Client) Fetch(query string) (*types.Response, error) {
+func (c Client) Fetch(ctx context.Context, query string) (*types.Response, error) {
req, err := ParseRequest(bytes.NewBufferString(query + "\r\n"))
if err != nil {
return nil, err
}
- return c.RoundTrip(req)
+ return c.RoundTrip(ctx, req)
}
func (c Client) IsRedirect(_ *types.Response) bool { return false }
diff --git a/gemini/client.go b/gemini/client.go
index 00e28f6..1e65a39 100644
--- a/gemini/client.go
+++ b/gemini/client.go
@@ -2,6 +2,7 @@ package gemini
import (
"bytes"
+ "context"
"crypto/tls"
"errors"
"io"
@@ -49,7 +50,7 @@ var ExceededMaxRedirects = errors.New("gemini.Client: exceeded MaxRedirects")
//
// This method will not automatically follow redirects or cache permanent failures or
// redirects.
-func (client Client) RoundTrip(request *types.Request) (*types.Response, error) {
+func (client Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "gemini" && request.Scheme != "titan" && request.Scheme != "" {
return nil, errors.New("non-gemini protocols not supported")
}
@@ -64,14 +65,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error)
tlsConf = &tls.Config{InsecureSkipVerify: true}
}
- conn, err := tls.Dial("tcp", host, tlsConf)
+ conn, err := (&tls.Dialer{Config: tlsConf}).DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
defer conn.Close()
request.RemoteAddr = conn.RemoteAddr()
- st := conn.ConnectionState()
+ st := conn.(*tls.Conn).ConnectionState()
request.TLSState = &st
destURL := *request.URL
@@ -124,14 +125,14 @@ func (client Client) RoundTrip(request *types.Request) (*types.Response, error)
// Fetch parses a URL string and fetches the gemini resource.
//
// It will resolve any redirects along the way, up to client.MaxRedirects.
-func (c Client) Fetch(url string) (*types.Response, error) {
+func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
for i := 0; i <= c.MaxRedirects; i += 1 {
- response, err := c.RoundTrip(&types.Request{URL: u})
+ response, err := c.RoundTrip(ctx, &types.Request{URL: u})
if err != nil {
return nil, err
}
diff --git a/gemini/gemtext/parse_line_test.go b/gemini/gemtext/parse_line_test.go
index 9073df0..a7faf49 100644
--- a/gemini/gemtext/parse_line_test.go
+++ b/gemini/gemtext/parse_line_test.go
@@ -92,7 +92,7 @@ func TestParsePromptLine(t *testing.T) {
if line.Type() != gemtext.LineTypePrompt{
t.Errorf("expected LineTypePrompt, got %d", line.Type())
}
- link, ok := line.(gemtext.PromptLine)
+ link, ok := line.(gemtext.LinkLine)
if !ok {
t.Fatalf("expected a PromptLine, got %T", line)
}
diff --git a/gemini/gemtext/parse_test.go b/gemini/gemtext/parse_test.go
index 90d2c75..4e7b6ff 100644
--- a/gemini/gemtext/parse_test.go
+++ b/gemini/gemtext/parse_test.go
@@ -78,8 +78,8 @@ This is some non-blank regular text.
assert.Equal(t, gemtext.LineTypePrompt, doc[12].Type())
assert.Equal(t, "=: spartan://foo.bar/baz this should be a spartan prompt\n", string(doc[12].Raw()))
- assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.PromptLine).URL())
- assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.PromptLine).Label())
+ assert.Equal(t, "spartan://foo.bar/baz", doc[12].(gemtext.LinkLine).URL())
+ assert.Equal(t, "this should be a spartan prompt", doc[12].(gemtext.LinkLine).Label())
assertEmptyLine(t, doc[13])
diff --git a/gemini/roundtrip_test.go b/gemini/roundtrip_test.go
index 50c1962..4a8097f 100644
--- a/gemini/roundtrip_test.go
+++ b/gemini/roundtrip_test.go
@@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "tildegit.org/tjp/sliderule/internal/types"
"tildegit.org/tjp/sliderule/gemini"
+ "tildegit.org/tjp/sliderule/internal/types"
)
func TestRoundTrip(t *testing.T) {
@@ -36,7 +36,7 @@ func TestRoundTrip(t *testing.T) {
require.Nil(t, err)
cli := gemini.NewClient(testClientTLS())
- response, err := cli.RoundTrip(&types.Request{URL: u})
+ response, err := cli.RoundTrip(context.Background(), &types.Request{URL: u})
require.Nil(t, err)
assert.Equal(t, gemini.StatusSuccess, response.Status)
diff --git a/gopher/client.go b/gopher/client.go
index 0ae1730..6d46323 100644
--- a/gopher/client.go
+++ b/gopher/client.go
@@ -2,6 +2,7 @@ package gopher
import (
"bytes"
+ "context"
"errors"
"io"
"net"
@@ -18,7 +19,7 @@ import (
type Client struct{}
// RoundTrip sends a single gopher request and returns its response.
-func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
+func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "gopher" && request.Scheme != "" {
return nil, errors.New("non-gopher protocols not supported")
}
@@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
host = net.JoinHostPort(host, "70")
}
- conn, err := net.Dial("tcp", host)
+ conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
@@ -56,12 +57,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
}
// Fetch parses a URL string and fetches the gopher resource.
-func (c Client) Fetch(url string) (*types.Response, error) {
+func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
- return c.RoundTrip(&types.Request{URL: u})
+ return c.RoundTrip(ctx, &types.Request{URL: u})
}
func (c Client) IsRedirect(_ *types.Response) bool { return false }
diff --git a/gopher/gophermap/testdata/customlist_output.gophermap b/gopher/gophermap/testdata/customlist_output.gophermap
index 0b4e334..82330e0 100644
--- a/gopher/gophermap/testdata/customlist_output.gophermap
+++ b/gopher/gophermap/testdata/customlist_output.gophermap
@@ -11,6 +11,6 @@ i /customlist.gophermap localhost.localdomain 70
0file4.txt /file4.txt localhost.localdomain 70
1subdir title /subdir localhost.localdomain 70
1subdir2 title /subdir2 localhost.localdomain 70
-9uptime /uptime localhost.localdomain 70
+0uptime /uptime localhost.localdomain 70
1uptime_output.gophermap /uptime_output.gophermap localhost.localdomain 70
.
diff --git a/nex/client.go b/nex/client.go
index 4a34903..4fa5265 100644
--- a/nex/client.go
+++ b/nex/client.go
@@ -2,6 +2,7 @@ package nex
import (
"bytes"
+ "context"
"errors"
"io"
"net"
@@ -18,7 +19,7 @@ import (
type Client struct{}
// RoundTrip sends a single nex request and returns its response.
-func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
+func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "nex" && request.Scheme != "" {
return nil, errors.New("non-nex protocols not supported")
}
@@ -28,7 +29,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
host = net.JoinHostPort(host, "1900")
}
- conn, err := net.Dial("tcp", host)
+ conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", host)
if err != nil {
return nil, err
}
@@ -50,12 +51,12 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
}
// Fetch builds and sends a nex request, and returns the response.
-func (c Client) Fetch(url string) (*types.Response, error) {
+func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
- return c.RoundTrip(&types.Request{URL: u})
+ return c.RoundTrip(ctx, &types.Request{URL: u})
}
func (c Client) IsRedirect(response *types.Response) bool { return false }
diff --git a/spartan/client.go b/spartan/client.go
index 81f2132..d77e791 100644
--- a/spartan/client.go
+++ b/spartan/client.go
@@ -2,6 +2,7 @@ package spartan
import (
"bytes"
+ "context"
"errors"
"io"
"net"
@@ -16,7 +17,7 @@ import (
// It carries no state and is reusable simultaneously by multiple goroutines.
//
// The zero value is immediately usabble, but will not follow redirects.
-type Client struct{
+type Client struct {
MaxRedirects int
}
@@ -32,7 +33,7 @@ const DefaultMaxRedirects int = 2
var ExceededMaxRedirects = errors.New("spartan.Client: exceeded MaxRedirects")
// RoundTrip sends a single spartan request and returns its response.
-func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
+func (c Client) RoundTrip(ctx context.Context, request *types.Request) (*types.Response, error) {
if request.Scheme != "spartan" && request.Scheme != "" {
return nil, errors.New("non-spartan protocols not supported")
}
@@ -44,7 +45,7 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
}
addr := net.JoinHostPort(host, port)
- conn, err := net.Dial("tcp", addr)
+ conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
@@ -90,14 +91,14 @@ func (c Client) RoundTrip(request *types.Request) (*types.Response, error) {
// Fetch parses a URL string and fetches the spartan resource.
//
// It will resolve any redirects along the way, up to client.MaxRedirects.
-func (c Client) Fetch(url string) (*types.Response, error) {
+func (c Client) Fetch(ctx context.Context, url string) (*types.Response, error) {
u, err := neturl.Parse(url)
if err != nil {
return nil, err
}
for i := 0; i <= c.MaxRedirects; i += 1 {
- response, err := c.RoundTrip(&types.Request{URL: u})
+ response, err := c.RoundTrip(ctx, &types.Request{URL: u})
if err != nil {
return nil, err
}
diff --git a/tools/sw-fetch/main.go b/tools/sw-fetch/main.go
index c2cb3e8..76e414c 100644
--- a/tools/sw-fetch/main.go
+++ b/tools/sw-fetch/main.go
@@ -1,12 +1,14 @@
package main
import (
+ "context"
"crypto/tls"
"fmt"
"io"
"net/http"
"net/url"
"os"
+ "time"
"tildegit.org/tjp/sliderule"
"tildegit.org/tjp/sliderule/gemini"
@@ -17,29 +19,45 @@ const usage = `Resource fetcher for the small web.
Usage:
sw-fetch (-h | --help)
- sw-fetch [-v | --verbose] [-o PATH | --output PATH] [-k | --keyfile PATH] [ -c | --certfile PATH ] [ -s | --skip-verify ] [ -u | --upload ] URL
+ sw-fetch
+ [-v | --verbose]
+ [-o PATH | --output PATH]
+ [-k | --keyfile PATH]
+ [ -c | --certfile PATH ]
+ [ -s | --skip-verify ]
+ [ -t | --timeout TIMEOUT ]
+ [ -u | --upload ]
+ URL
Options:
- -h --help Show this screen.
- -v --verbose Display more diagnostic information on standard error.
- -o --output PATH Send the fetched resource to PATH instead of standard out.
- -k --keyfile PATH Path to the TLS key file to use.
- -c --certfile PATH Path to the TLS certificate file to use.
- -s --skip-verify Don't verify server TLS certificates.
- -u --upload Use stdin as the request body on supported protocols and don't follow redirects.
+ -h --help Show this screen.
+ -v --verbose Display more diagnostic information on standard error.
+ -o --output PATH Send the fetched resource to PATH instead of standard out.
+ -k --keyfile PATH Path to the TLS key file to use.
+ -c --certfile PATH Path to the TLS certificate file to use.
+ -s --skip-verify Don't verify server TLS certificates.
+ -t --timeout TIMEOUT Fail after the given timeout (like "15s").
+ -u --upload Use stdin as the request body on supported protocols and don't follow redirects.
`
func main() {
conf := configure()
cl := sliderule.NewClient(conf.clientTLS)
+ ctx := context.Background()
+ if conf.timeout != 0 {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, conf.timeout)
+ defer cancel()
+ }
+
var response *sliderule.Response
var err error
if conf.upload {
- response, err = cl.Upload(conf.url.String(), os.Stdin)
+ response, err = cl.Upload(ctx, conf.url.String(), os.Stdin)
} else {
- response, err = cl.Fetch(conf.url.String())
+ response, err = cl.Fetch(ctx, conf.url.String())
}
if err != nil {
fail(err.Error() + "\n")
@@ -61,6 +79,7 @@ type config struct {
output io.WriteCloser
url *url.URL
clientTLS *tls.Config
+ timeout time.Duration
}
func configure() config {
@@ -72,6 +91,7 @@ func configure() config {
key := ""
cert := ""
verify := true
+ var err error
for i := 1; i <= len(os.Args)-1; i += 1 {
switch os.Args[i] {
@@ -87,12 +107,11 @@ func configure() config {
out := os.Args[i+1]
if out != "-" {
- output, err := os.OpenFile(out, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
+ conf.output, err = os.OpenFile(out, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if err != nil {
fmt.Println(err.Error())
failf("'%s' is not a valid path\n", out)
}
- conf.output = output
}
i += 1
@@ -112,6 +131,16 @@ func configure() config {
cert = os.Args[i]
case "-s", "--skip-verify":
verify = false
+ case "-t", "--timeout":
+ if i+1 == len(os.Args)-1 {
+ fail(usage)
+ }
+
+ i += 1
+ conf.timeout, err = time.ParseDuration(os.Args[i])
+ if err != nil {
+ fail(err.Error())
+ }
case "-u", "--upload":
conf.upload = true
}