package tlsauth_test

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"

	"tildegit.org/tjp/gus/contrib/tlsauth"
)

// TestRequireSpecificIdentity checks that the predicate returned by
// tlsauth.RequireSpecificIdentity(c) reports true for c itself and false for
// a different certificate, using the two client fixtures in testdata/.
func TestRequireSpecificIdentity(t *testing.T) {
	cert1, err := leafCert("testdata/client1.crt", "testdata/client1.key")
	if err != nil {
		// Without the fixture nothing below is meaningful; stop here rather
		// than continuing with a nil certificate and cascading failures.
		t.Fatalf("loading client1 fixture: %v", err)
	}

	cert2, err := leafCert("testdata/client2.crt", "testdata/client2.key")
	if err != nil {
		t.Fatalf("loading client2 fixture: %v", err)
	}

	// Sanity-check the fixtures: the identity assertions below only prove
	// something if the two certificates are actually distinct.
	assert.True(t, cert1.Equal(cert1), "cert1 should equal itself")
	assert.False(t, cert1.Equal(cert2), "fixtures client1 and client2 must differ")
	assert.False(t, cert2.Equal(cert1), "fixtures client2 and client1 must differ")
	assert.True(t, cert2.Equal(cert2), "cert2 should equal itself")

	assert.True(t, tlsauth.RequireSpecificIdentity(cert1)(cert1), "predicate for cert1 should accept cert1")
	assert.False(t, tlsauth.RequireSpecificIdentity(cert1)(cert2), "predicate for cert1 should reject cert2")
	assert.False(t, tlsauth.RequireSpecificIdentity(cert2)(cert1), "predicate for cert2 should reject cert1")
	assert.True(t, tlsauth.RequireSpecificIdentity(cert2)(cert2), "predicate for cert2 should accept cert2")
}

// leafCert loads the PEM certificate/key pair at certfile and keyfile with
// tls.LoadX509KeyPair and returns the first certificate of the chain.
//
// If the loaded pair already carries a parsed Leaf, that value is returned
// as-is; otherwise the first DER block is parsed with x509.ParseCertificate.
// Any load or parse error is returned unchanged, with a nil certificate.
func leafCert(certfile, keyfile string) (*x509.Certificate, error) {
	pair, loadErr := tls.LoadX509KeyPair(certfile, keyfile)

	switch {
	case loadErr != nil:
		return nil, loadErr
	case pair.Leaf != nil:
		// Reuse the already-parsed leaf instead of parsing it again.
		return pair.Leaf, nil
	case len(pair.Certificate) == 0:
		// Defensive: guard the index below against an empty chain.
		return nil, errors.New("no certificate blocks found")
	default:
		return x509.ParseCertificate(pair.Certificate[0])
	}
}