test(agent): improve TestAgent_Dial tests (#11013)

Refs #11008
This commit is contained in:
Mathias Fredriksson 2023-12-04 13:11:30 +02:00 committed by GitHub
parent b212bd4ac5
commit 70cede8f7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 43 additions and 22 deletions

View File

@ -1547,32 +1547,33 @@ func TestAgent_Dial(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
t.Parallel()
// Setup listener
// The purpose of this test is to ensure that a client can dial a
// listener in the workspace over tailnet.
l := c.setup(t)
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
return
}
done := make(chan struct{})
defer func() {
l.Close()
<-done
}()
go testAccept(t, c)
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
go func() {
defer close(done)
c, err := l.Accept()
assert.NoError(t, err, "accept connection")
defer c.Close()
testAccept(ctx, t, c)
}()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
require.True(t, conn.AwaitReachable(context.Background()))
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
require.True(t, agentConn.AwaitReachable(ctx))
conn, err := agentConn.DialContext(ctx, l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn1.Close()
conn2, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
require.NoError(t, err)
defer conn2.Close()
testDial(t, conn2)
testDial(t, conn1)
time.Sleep(150 * time.Millisecond)
defer conn.Close()
testDial(ctx, t, conn)
})
}
}
@ -2002,22 +2003,41 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
var dialTestPayload = []byte("dean-was-here123")
func testDial(t *testing.T, c net.Conn) {
func testDial(ctx context.Context, t *testing.T, c net.Conn) {
t.Helper()
if deadline, ok := ctx.Deadline(); ok {
err := c.SetDeadline(deadline)
assert.NoError(t, err)
defer func() {
err := c.SetDeadline(time.Time{})
assert.NoError(t, err)
}()
}
assertWritePayload(t, c, dialTestPayload)
assertReadPayload(t, c, dialTestPayload)
}
func testAccept(t *testing.T, c net.Conn) {
func testAccept(ctx context.Context, t *testing.T, c net.Conn) {
t.Helper()
defer c.Close()
if deadline, ok := ctx.Deadline(); ok {
err := c.SetDeadline(deadline)
assert.NoError(t, err)
defer func() {
err := c.SetDeadline(time.Time{})
assert.NoError(t, err)
}()
}
assertReadPayload(t, c, dialTestPayload)
assertWritePayload(t, c, dialTestPayload)
}
func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
t.Helper()
b := make([]byte, len(payload)+16)
n, err := r.Read(b)
assert.NoError(t, err, "read payload")
@ -2026,6 +2046,7 @@ func assertReadPayload(t *testing.T, r io.Reader, payload []byte) {
}
func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
t.Helper()
n, err := w.Write(payload)
assert.NoError(t, err, "write payload")
assert.Equal(t, len(payload), n, "payload length does not match")