coder/tailnet/service_test.go

145 lines
4.2 KiB
Go

package tailnet_test
import (
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
// TestClientService_ServeClient_V2 checks that ServeClient, when given API
// version "2.0", serves the DRPC tailnet API over the supplied connection:
// Coordinate traffic is relayed in both directions between the client and the
// (fake) coordinator, StreamDERPMaps returns the map produced by the DERP map
// function passed to NewClientService, and ServeClient returns once the client
// side of the connection is closed.
func TestClientService_ServeClient_V2(t *testing.T) {
	t.Parallel()
	fCoord := tailnettest.NewFakeCoordinator()
	var coord tailnet.Coordinator = fCoord
	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
	coordPtr.Store(&coord)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	derpMap := &tailcfg.DERPMap{Regions: map[int]*tailcfg.DERPRegion{999: {RegionCode: "test"}}}
	uut, err := tailnet.NewClientService(
		logger, &coordPtr,
		time.Millisecond, func() *tailcfg.DERPMap { return derpMap },
	)
	require.NoError(t, err)

	ctx := testutil.Context(t, testutil.WaitShort)
	c, s := net.Pipe()
	defer c.Close()
	defer s.Close()
	clientID := uuid.MustParse("10000001-0000-0000-0000-000000000000")
	agentID := uuid.MustParse("20000001-0000-0000-0000-000000000000")

	// ServeClient blocks for the lifetime of the session, so run it in the
	// background and collect its result for the final assertion.
	errCh := make(chan error, 1)
	go func() {
		err := uut.ServeClient(ctx, "2.0", s, clientID, agentID)
		t.Logf("ServeClient returned; err=%v", err)
		errCh <- err
	}()

	client, err := tailnet.NewDRPCClient(c, logger)
	require.NoError(t, err)

	// Coordinate
	stream, err := client.Coordinate(ctx)
	require.NoError(t, err)
	defer stream.Close()

	err = stream.Send(&proto.CoordinateRequest{
		UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 11}},
	})
	require.NoError(t, err)

	call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
	require.NotNil(t, call)
	require.Equal(t, clientID, call.ID)
	require.Equal(t, "client", call.Name)
	// The auth handed to the coordinator must permit a tunnel to the agent
	// this client was served for.
	require.NoError(t, call.Auth.Authorize(&proto.CoordinateRequest{
		AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
	}))
	req := testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.Equal(t, int32(11), req.GetUpdateSelf().GetNode().GetPreferredDerp())

	// Guard the send with ctx so a stuck service fails the test at the
	// WaitShort deadline instead of hanging it.
	select {
	case call.Resps <- &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
		{
			Kind: proto.CoordinateResponse_PeerUpdate_NODE,
			Node: &proto.Node{PreferredDerp: 22},
			Id:   agentID[:],
		},
	}}:
		// ok!
	case <-ctx.Done():
		t.Fatalf("timeout sending coordinate response")
	}

	resp, err := stream.Recv()
	require.NoError(t, err)
	u := resp.GetPeerUpdates()
	require.Len(t, u, 1)
	require.Equal(t, int32(22), u[0].GetNode().GetPreferredDerp())

	err = stream.Close()
	require.NoError(t, err)

	// DERP Map
	dms, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{})
	require.NoError(t, err)

	gotDerpMap, err := dms.Recv()
	require.NoError(t, err)
	require.Equal(t, "test", gotDerpMap.GetRegions()[999].GetRegionCode())
	err = dms.Close()
	require.NoError(t, err)

	// RPCs closed; we need to close the Conn to end the session.
	err = c.Close()
	require.NoError(t, err)

	// NOTE(review): either EOF or ErrClosedPipe is accepted, presumably
	// depending on how the pipe teardown is observed on the server side.
	err = testutil.RequireRecvCtx(ctx, t, errCh)
	require.True(t, xerrors.Is(err, io.EOF) || xerrors.Is(err, io.ErrClosedPipe),
		"ServeClient returned unexpected error: %v", err)
}
// TestClientService_ServeClient_V1 checks that ServeClient, when given API
// version "1.0", hands the raw connection, client ID, and agent ID to the
// coordinator's ServeClient, and returns the error the coordinator reports.
func TestClientService_ServeClient_V1(t *testing.T) {
	t.Parallel()
	fCoord := tailnettest.NewFakeCoordinator()
	var coord tailnet.Coordinator = fCoord
	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
	coordPtr.Store(&coord)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	uut, err := tailnet.NewClientService(logger, &coordPtr, 0, nil)
	require.NoError(t, err)

	ctx := testutil.Context(t, testutil.WaitShort)
	c, s := net.Pipe()
	defer c.Close()
	defer s.Close()
	clientID := uuid.MustParse("10000001-0000-0000-0000-000000000000")
	agentID := uuid.MustParse("20000001-0000-0000-0000-000000000000")

	// ServeClient blocks until the coordinator reports a result, so run it in
	// the background and collect its return value.
	errCh := make(chan error, 1)
	go func() {
		err := uut.ServeClient(ctx, "1.0", s, clientID, agentID)
		t.Logf("ServeClient returned; err=%v", err)
		errCh <- err
	}()

	call := testutil.RequireRecvCtx(ctx, t, fCoord.ServeClientCalls)
	require.NotNil(t, call)
	require.Equal(t, clientID, call.ID)
	require.Equal(t, agentID, call.Agent)
	require.Equal(t, s, call.Conn)

	// Whatever error the coordinator reports must be what ServeClient returns.
	expectedError := xerrors.New("test error")
	select {
	case call.ErrCh <- expectedError:
		// ok!
	case <-ctx.Done():
		t.Fatalf("timeout sending error")
	}

	err = testutil.RequireRecvCtx(ctx, t, errCh)
	require.ErrorIs(t, err, expectedError)
}