From a58e4febb9feef489b026eaa8202e5389e6d1084 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Fri, 15 Dec 2023 12:48:39 +0400 Subject: [PATCH] feat: add tailnet v2 Service and Client (#11225) Part of #10532 Adds a tailnet ClientService that accepts a net.Conn and serves v1 or v2 of the tailnet API. Also adds a DRPCService that implements the DRPC interface for the v2 API. This component is within the ClientService, but needs to be reusable and exported so that we can also embed it in the Agent API. Finally, includes a NewDRPCClient function that takes a net.Conn and runs dRPC in yamux over it on the client side. --- tailnet/client.go | 22 +++++ tailnet/service.go | 207 ++++++++++++++++++++++++++++++++++++++-- tailnet/service_test.go | 192 +++++++++++++++++++++++++++++++++++++ 3 files changed, 411 insertions(+), 10 deletions(-) create mode 100644 tailnet/client.go diff --git a/tailnet/client.go b/tailnet/client.go new file mode 100644 index 0000000000..db00a9d954 --- /dev/null +++ b/tailnet/client.go @@ -0,0 +1,22 @@ +package tailnet + +import ( + "io" + "net" + + "github.com/hashicorp/yamux" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk/drpc" + "github.com/coder/coder/v2/tailnet/proto" +) + +func NewDRPCClient(conn net.Conn) (proto.DRPCClientClient, error) { + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Client(conn, config) + if err != nil { + return nil, xerrors.Errorf("multiplex client: %w", err) + } + return proto.NewDRPCClientClient(drpc.MultiplexedConn(session)), nil +} diff --git a/tailnet/service.go b/tailnet/service.go index 92140aa3f5..a6c94ef8bf 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -1,8 +1,20 @@ package tailnet import ( + "context" + "io" + "net" "strconv" "strings" + "sync/atomic" + + "github.com/google/uuid" + "github.com/hashicorp/yamux" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "cdr.dev/slog" + "github.com/coder/coder/v2/tailnet/proto" "golang.org/x/xerrors" ) @@ -15,17 +27,9 @@ const ( var SupportedMajors = []int{2, 1} func ValidateVersion(version string) error { - parts := strings.Split(version, ".") - if len(parts) != 2 { - return xerrors.Errorf("invalid version string: %s", version) - } - major, err := strconv.Atoi(parts[0]) + major, minor, err := parseVersion(version) if err != nil { - return xerrors.Errorf("invalid major version: %s", version) - } - minor, err := strconv.Atoi(parts[1]) - if err != nil { - return xerrors.Errorf("invalid minor version: %s", version) + return err } if major > CurrentMajor { return xerrors.Errorf("server is at version %d.%d, behind requested version %s", @@ -45,3 +49,186 @@ func ValidateVersion(version string) error { } return xerrors.Errorf("version %s is no longer supported", version) } + +func parseVersion(version string) (major int, minor int, err error) { + parts := strings.Split(version, ".") + if len(parts) != 2 { + return 0, 0, xerrors.Errorf("invalid version string: %s", version) + } + major, err = strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, xerrors.Errorf("invalid major version: %s", version) + } + minor, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, xerrors.Errorf("invalid minor version: %s", version) + } + return major, minor, nil +} + +type streamIDContextKey struct{} + +// StreamID identifies the caller of the CoordinateTailnet RPC. We store this +// on the context, since the information is extracted at the HTTP layer for +// remote clients of the API, or set outside tailnet for local clients (e.g. +// Coderd's single_tailnet) +type StreamID struct { + Name string + ID uuid.UUID + Auth TunnelAuth +} + +func WithStreamID(ctx context.Context, streamID StreamID) context.Context { + return context.WithValue(ctx, streamIDContextKey{}, streamID) +} + +// ClientService is a tailnet coordination service that accepts a connection and version from a +// tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol. +type ClientService struct { + logger slog.Logger + coordPtr *atomic.Pointer[Coordinator] + drpc *drpcserver.Server +} + +// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is +// loaded on each processed connection. +func NewClientService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) (*ClientService, error) { + s := &ClientService{logger: logger, coordPtr: coordPtr} + mux := drpcmux.New() + drpcService := NewDRPCService(logger, coordPtr) + err := proto.DRPCRegisterClient(mux, drpcService) + if err != nil { + return nil, xerrors.Errorf("register DRPC service: %w", err) + } + server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + if xerrors.Is(err, io.EOF) { + return + } + logger.Debug(context.Background(), "drpc server error", slog.Error(err)) + }, + }) + s.drpc = server + return s, nil +} + +func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + major, _, err := parseVersion(version) + if err != nil { + s.logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err)) + return err + } + switch major { + case 1: + coord := *(s.coordPtr.Load()) + return coord.ServeClient(conn, id, agent) + case 2: + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Server(conn, config) + if err != nil { + return xerrors.Errorf("yamux init failed: %w", err) + } + auth := ClientTunnelAuth{AgentID: agent} + streamID := StreamID{ + Name: "client", + ID: id, + Auth: auth, + } + ctx = WithStreamID(ctx, streamID) + return s.drpc.Serve(ctx, session) + default: + s.logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version)) + return xerrors.New("unsupported version") + } +} + +// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer +type DRPCService struct { + coordPtr *atomic.Pointer[Coordinator] + logger slog.Logger +} + +func NewDRPCService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) *DRPCService { + return &DRPCService{ + coordPtr: coordPtr, + logger: logger, + } +} + +func (*DRPCService) StreamDERPMaps(*proto.StreamDERPMapsRequest, proto.DRPCClient_StreamDERPMapsStream) error { + // TODO integrate with Dean's PR implementation + return xerrors.New("unimplemented") +} + +func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailnetStream) error { + ctx := stream.Context() + streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID) + if !ok { + _ = stream.Close() + return xerrors.New("no Stream ID") + } + logger := s.logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name)) + logger.Debug(ctx, "starting tailnet Coordinate") + coord := *(s.coordPtr.Load()) + reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth) + c := communicator{ + logger: logger, + stream: stream, + reqs: reqs, + resps: resps, + } + c.communicate() + return nil +} + +type communicator struct { + logger slog.Logger + stream proto.DRPCClient_CoordinateTailnetStream + reqs chan<- *proto.CoordinateRequest + resps <-chan *proto.CoordinateResponse +} + +func (c communicator) communicate() { + go c.loopReq() + c.loopResp() +} + +func (c communicator) loopReq() { + ctx := c.stream.Context() + defer close(c.reqs) + for { + req, err := c.stream.Recv() + if err != nil { + c.logger.Debug(ctx, "error receiving requests from DRPC stream", slog.Error(err)) + return + } + err = SendCtx(ctx, c.reqs, req) + if err != nil { + c.logger.Debug(ctx, "context done while sending coordinate request", slog.Error(ctx.Err())) + return + } + } +} + +func (c communicator) loopResp() { + ctx := c.stream.Context() + defer func() { + err := c.stream.Close() + if err != nil { + c.logger.Debug(ctx, "loopResp hit error closing stream", slog.Error(err)) + } + }() + for { + resp, err := RecvCtx(ctx, c.resps) + if err != nil { + c.logger.Debug(ctx, "loopResp failed to get response", slog.Error(err)) + return + } + err = c.stream.Send(resp) + if err != nil { + c.logger.Debug(ctx, "loopResp failed to send response to DRPC stream", slog.Error(err)) + return + } + } +} diff --git a/tailnet/service_test.go b/tailnet/service_test.go index ffb0442aea..359d443822 100644 --- a/tailnet/service_test.go +++ b/tailnet/service_test.go @@ -1,9 +1,23 @@ package tailnet_test import ( + "context" "fmt" + "io" + "net" + "net/http" + "sync/atomic" "testing" + "golang.org/x/xerrors" + + "github.com/google/uuid" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/testutil" + "github.com/stretchr/testify/require" "github.com/coder/coder/v2/tailnet" @@ -72,3 +86,181 @@ func TestValidateVersion(t *testing.T) { }) } } + +func TestClientService_ServeClient_V2(t *testing.T) { + t.Parallel() + fCoord := newFakeCoordinator() + var coord tailnet.Coordinator = fCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut, err := tailnet.NewClientService(logger, &coordPtr) + require.NoError(t, err) + + ctx := testutil.Context(t, testutil.WaitShort) + c, s := net.Pipe() + defer c.Close() + defer s.Close() + clientID := uuid.MustParse("10000001-0000-0000-0000-000000000000") + agentID := uuid.MustParse("20000001-0000-0000-0000-000000000000") + errCh := make(chan error, 1) + go func() { + err := uut.ServeClient(ctx, "2.0", s, clientID, agentID) + t.Logf("ServeClient returned; err=%v", err) + errCh <- err + }() + + client, err := tailnet.NewDRPCClient(c) + require.NoError(t, err) + stream, err := client.CoordinateTailnet(ctx) + require.NoError(t, err) + defer stream.Close() + + err = stream.Send(&proto.CoordinateRequest{ + UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 11}}, + }) + require.NoError(t, err) + + call := testutil.RequireRecvCtx(ctx, t, fCoord.coordinateCalls) + require.NotNil(t, call) + require.Equal(t, call.id, clientID) + require.Equal(t, call.name, "client") + require.True(t, call.auth.Authorize(agentID)) + req := testutil.RequireRecvCtx(ctx, t, call.reqs) + require.Equal(t, int32(11), req.GetUpdateSelf().GetNode().GetPreferredDerp()) + + call.resps <- &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{ + { + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + Node: &proto.Node{PreferredDerp: 22}, + Uuid: agentID[:], + }, + }} + resp, err := stream.Recv() + require.NoError(t, err) + u := resp.GetPeerUpdates() + require.Len(t, u, 1) + require.Equal(t, int32(22), u[0].GetNode().GetPreferredDerp()) + + err = stream.Close() + require.NoError(t, err) + + // stream ^^ is just one RPC; we need to close the Conn to end the session. + err = c.Close() + require.NoError(t, err) + err = testutil.RequireRecvCtx(ctx, t, errCh) + require.ErrorIs(t, err, io.EOF) +} + +func TestClientService_ServeClient_V1(t *testing.T) { + t.Parallel() + fCoord := newFakeCoordinator() + var coord tailnet.Coordinator = fCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut, err := tailnet.NewClientService(logger, &coordPtr) + require.NoError(t, err) + + ctx := testutil.Context(t, testutil.WaitShort) + c, s := net.Pipe() + defer c.Close() + defer s.Close() + clientID := uuid.MustParse("10000001-0000-0000-0000-000000000000") + agentID := uuid.MustParse("20000001-0000-0000-0000-000000000000") + errCh := make(chan error, 1) + go func() { + err := uut.ServeClient(ctx, "1.0", s, clientID, agentID) + t.Logf("ServeClient returned; err=%v", err) + errCh <- err + }() + + call := testutil.RequireRecvCtx(ctx, t, fCoord.serveClientCalls) + require.NotNil(t, call) + require.Equal(t, call.id, clientID) + require.Equal(t, call.agent, agentID) + require.Equal(t, s, call.conn) + expectedError := xerrors.New("test error") + select { + case call.errCh <- expectedError: + // ok! + case <-ctx.Done(): + t.Fatalf("timeout sending error") + } + + err = testutil.RequireRecvCtx(ctx, t, errCh) + require.ErrorIs(t, err, expectedError) +} + +type fakeCoordinator struct { + coordinateCalls chan *fakeCoordinate + serveClientCalls chan *fakeServeClient +} + +func (*fakeCoordinator) ServeHTTPDebug(http.ResponseWriter, *http.Request) { + panic("unimplemented") +} + +func (*fakeCoordinator) Node(uuid.UUID) *tailnet.Node { + panic("unimplemented") +} + +func (f *fakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + errCh := make(chan error) + f.serveClientCalls <- &fakeServeClient{ + conn: conn, + id: id, + agent: agent, + errCh: errCh, + } + return <-errCh +} + +func (*fakeCoordinator) ServeAgent(net.Conn, uuid.UUID, string) error { + panic("unimplemented") +} + +func (*fakeCoordinator) Close() error { + panic("unimplemented") +} + +func (*fakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn { + panic("unimplemented") +} + +func (f *fakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + f.coordinateCalls <- &fakeCoordinate{ + ctx: ctx, + id: id, + name: name, + auth: a, + reqs: reqs, + resps: resps, + } + return reqs, resps +} + +func newFakeCoordinator() *fakeCoordinator { + return &fakeCoordinator{ + coordinateCalls: make(chan *fakeCoordinate, 100), + serveClientCalls: make(chan *fakeServeClient, 100), + } +} + +type fakeCoordinate struct { + ctx context.Context + id uuid.UUID + name string + auth tailnet.TunnelAuth + reqs chan *proto.CoordinateRequest + resps chan *proto.CoordinateResponse +} + +type fakeServeClient struct { + conn net.Conn + id uuid.UUID + agent uuid.UUID + errCh chan error +}