diff --git a/tailnet/client.go b/tailnet/client.go new file mode 100644 index 0000000000..db00a9d954 --- /dev/null +++ b/tailnet/client.go @@ -0,0 +1,22 @@ +package tailnet + +import ( + "io" + "net" + + "github.com/hashicorp/yamux" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk/drpc" + "github.com/coder/coder/v2/tailnet/proto" +) + +func NewDRPCClient(conn net.Conn) (proto.DRPCClientClient, error) { + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Client(conn, config) + if err != nil { + return nil, xerrors.Errorf("multiplex client: %w", err) + } + return proto.NewDRPCClientClient(drpc.MultiplexedConn(session)), nil +} diff --git a/tailnet/service.go b/tailnet/service.go index 92140aa3f5..a6c94ef8bf 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -1,8 +1,20 @@ package tailnet import ( + "context" + "io" + "net" "strconv" "strings" + "sync/atomic" + + "github.com/google/uuid" + "github.com/hashicorp/yamux" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "cdr.dev/slog" + "github.com/coder/coder/v2/tailnet/proto" "golang.org/x/xerrors" ) @@ -15,17 +27,9 @@ const ( var SupportedMajors = []int{2, 1} func ValidateVersion(version string) error { - parts := strings.Split(version, ".") - if len(parts) != 2 { - return xerrors.Errorf("invalid version string: %s", version) - } - major, err := strconv.Atoi(parts[0]) + major, minor, err := parseVersion(version) if err != nil { - return xerrors.Errorf("invalid major version: %s", version) - } - minor, err := strconv.Atoi(parts[1]) - if err != nil { - return xerrors.Errorf("invalid minor version: %s", version) + return err } if major > CurrentMajor { return xerrors.Errorf("server is at version %d.%d, behind requested version %s", @@ -45,3 +49,186 @@ func ValidateVersion(version string) error { } return xerrors.Errorf("version %s is no longer supported", version) } + +func parseVersion(version string) (major int, minor int, err error) { + parts := strings.Split(version, ".") + if len(parts) != 2 { + return 0, 0, xerrors.Errorf("invalid version string: %s", version) + } + major, err = strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, xerrors.Errorf("invalid major version: %s", version) + } + minor, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, xerrors.Errorf("invalid minor version: %s", version) + } + return major, minor, nil +} + +type streamIDContextKey struct{} + +// StreamID identifies the caller of the CoordinateTailnet RPC. We store this +// on the context, since the information is extracted at the HTTP layer for +// remote clients of the API, or set outside tailnet for local clients (e.g. +// Coderd's single_tailnet) +type StreamID struct { + Name string + ID uuid.UUID + Auth TunnelAuth +} + +func WithStreamID(ctx context.Context, streamID StreamID) context.Context { + return context.WithValue(ctx, streamIDContextKey{}, streamID) +} + +// ClientService is a tailnet coordination service that accepts a connection and version from a +// tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol. +type ClientService struct { + logger slog.Logger + coordPtr *atomic.Pointer[Coordinator] + drpc *drpcserver.Server +} + +// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is +// loaded on each processed connection. +func NewClientService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) (*ClientService, error) { + s := &ClientService{logger: logger, coordPtr: coordPtr} + mux := drpcmux.New() + drpcService := NewDRPCService(logger, coordPtr) + err := proto.DRPCRegisterClient(mux, drpcService) + if err != nil { + return nil, xerrors.Errorf("register DRPC service: %w", err) + } + server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + if xerrors.Is(err, io.EOF) { + return + } + logger.Debug(context.Background(), "drpc server error", slog.Error(err)) + }, + }) + s.drpc = server + return s, nil +} + +func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + major, _, err := parseVersion(version) + if err != nil { + s.logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err)) + return err + } + switch major { + case 1: + coord := *(s.coordPtr.Load()) + return coord.ServeClient(conn, id, agent) + case 2: + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Server(conn, config) + if err != nil { + return xerrors.Errorf("yamux init failed: %w", err) + } + auth := ClientTunnelAuth{AgentID: agent} + streamID := StreamID{ + Name: "client", + ID: id, + Auth: auth, + } + ctx = WithStreamID(ctx, streamID) + return s.drpc.Serve(ctx, session) + default: + s.logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version)) + return xerrors.New("unsupported version") + } +} + +// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer +type DRPCService struct { + coordPtr *atomic.Pointer[Coordinator] + logger slog.Logger +} + +func NewDRPCService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) *DRPCService { + return &DRPCService{ + coordPtr: coordPtr, + logger: logger, + } +} + +func (*DRPCService) StreamDERPMaps(*proto.StreamDERPMapsRequest, proto.DRPCClient_StreamDERPMapsStream) error { + // TODO integrate with Dean's PR implementation + return xerrors.New("unimplemented") +} + +func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailnetStream) error { + ctx := stream.Context() + streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID) + if !ok { + _ = stream.Close() + return xerrors.New("no Stream ID") + } + logger := s.logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name)) + logger.Debug(ctx, "starting tailnet Coordinate") + coord := *(s.coordPtr.Load()) + reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth) + c := communicator{ + logger: logger, + stream: stream, + reqs: reqs, + resps: resps, + } + c.communicate() + return nil +} + +type communicator struct { + logger slog.Logger + stream proto.DRPCClient_CoordinateTailnetStream + reqs chan<- *proto.CoordinateRequest + resps <-chan *proto.CoordinateResponse +} + +func (c communicator) communicate() { + go c.loopReq() + c.loopResp() +} + +func (c communicator) loopReq() { + ctx := c.stream.Context() + defer close(c.reqs) + for { + req, err := c.stream.Recv() + if err != nil { + c.logger.Debug(ctx, "error receiving requests from DRPC stream", slog.Error(err)) + return + } + err = SendCtx(ctx, c.reqs, req) + if err != nil { + c.logger.Debug(ctx, "context done while sending coordinate request", slog.Error(ctx.Err())) + return + } + } +} + +func (c communicator) loopResp() { + ctx := c.stream.Context() + defer func() { + err := c.stream.Close() + if err != nil { + c.logger.Debug(ctx, "loopResp hit error closing stream", slog.Error(err)) + } + }() + for { + resp, err := RecvCtx(ctx, c.resps) + if err != nil { + c.logger.Debug(ctx, "loopResp failed to get response", slog.Error(err)) + return + } + err = c.stream.Send(resp) + if err != nil { + c.logger.Debug(ctx, "loopResp failed to send response to DRPC stream", slog.Error(err)) + return + } + } +} diff --git a/tailnet/service_test.go b/tailnet/service_test.go index ffb0442aea..359d443822 100644 --- a/tailnet/service_test.go +++ b/tailnet/service_test.go @@ -1,9 +1,23 @@ package tailnet_test import ( + "context" "fmt" + "io" + "net" + "net/http" + "sync/atomic" "testing" + "golang.org/x/xerrors" + + "github.com/google/uuid" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/tailnet/proto" + "github.com/coder/coder/v2/testutil" + "github.com/stretchr/testify/require" "github.com/coder/coder/v2/tailnet" @@ -72,3 +86,181 @@ func TestValidateVersion(t *testing.T) { }) } } + +func TestClientService_ServeClient_V2(t *testing.T) { + t.Parallel() + fCoord := newFakeCoordinator() + var coord tailnet.Coordinator = fCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut, err := tailnet.NewClientService(logger, &coordPtr) + require.NoError(t, err) + + ctx := testutil.Context(t, testutil.WaitShort) + c, s := net.Pipe() + defer c.Close() + defer s.Close() + clientID := uuid.MustParse("10000001-0000-0000-0000-000000000000") + agentID := uuid.MustParse("20000001-0000-0000-0000-000000000000") + errCh := make(chan error, 1) + go func() { + err := uut.ServeClient(ctx, "2.0", s, clientID, agentID) + t.Logf("ServeClient returned; err=%v", err) + errCh <- err + }() + + client, err := tailnet.NewDRPCClient(c) + require.NoError(t, err) + stream, err := client.CoordinateTailnet(ctx) + require.NoError(t, err) + defer stream.Close() + + err = stream.Send(&proto.CoordinateRequest{ + UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 11}}, + }) + require.NoError(t, err) + + call := testutil.RequireRecvCtx(ctx, t, fCoord.coordinateCalls) + require.NotNil(t, call) + require.Equal(t, call.id, clientID) + require.Equal(t, call.name, "client") + require.True(t, call.auth.Authorize(agentID)) + req := testutil.RequireRecvCtx(ctx, t, call.reqs) + require.Equal(t, int32(11), req.GetUpdateSelf().GetNode().GetPreferredDerp()) + + call.resps <- &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{ + { + Kind: proto.CoordinateResponse_PeerUpdate_NODE, + Node: &proto.Node{PreferredDerp: 22}, + Uuid: agentID[:], + }, + }} + resp, err := stream.Recv() + require.NoError(t, err) + u := resp.GetPeerUpdates() + require.Len(t, u, 1) + require.Equal(t, int32(22), u[0].GetNode().GetPreferredDerp()) + + err = stream.Close() + require.NoError(t, err) + + // stream ^^ is just one RPC; we need to close the Conn to end the session. + err = c.Close() + require.NoError(t, err) + err = testutil.RequireRecvCtx(ctx, t, errCh) + require.ErrorIs(t, err, io.EOF) +} + +func TestClientService_ServeClient_V1(t *testing.T) { + t.Parallel() + fCoord := newFakeCoordinator() + var coord tailnet.Coordinator = fCoord + coordPtr := atomic.Pointer[tailnet.Coordinator]{} + coordPtr.Store(&coord) + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + uut, err := tailnet.NewClientService(logger, &coordPtr) + require.NoError(t, err) + + ctx := testutil.Context(t, testutil.WaitShort) + c, s := net.Pipe() + defer c.Close() + defer s.Close() + clientID := uuid.MustParse("10000001-0000-0000-0000-000000000000") + agentID := uuid.MustParse("20000001-0000-0000-0000-000000000000") + errCh := make(chan error, 1) + go func() { + err := uut.ServeClient(ctx, "1.0", s, clientID, agentID) + t.Logf("ServeClient returned; err=%v", err) + errCh <- err + }() + + call := testutil.RequireRecvCtx(ctx, t, fCoord.serveClientCalls) + require.NotNil(t, call) + require.Equal(t, call.id, clientID) + require.Equal(t, call.agent, agentID) + require.Equal(t, s, call.conn) + expectedError := xerrors.New("test error") + select { + case call.errCh <- expectedError: + // ok! + case <-ctx.Done(): + t.Fatalf("timeout sending error") + } + + err = testutil.RequireRecvCtx(ctx, t, errCh) + require.ErrorIs(t, err, expectedError) +} + +type fakeCoordinator struct { + coordinateCalls chan *fakeCoordinate + serveClientCalls chan *fakeServeClient +} + +func (*fakeCoordinator) ServeHTTPDebug(http.ResponseWriter, *http.Request) { + panic("unimplemented") +} + +func (*fakeCoordinator) Node(uuid.UUID) *tailnet.Node { + panic("unimplemented") +} + +func (f *fakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + errCh := make(chan error) + f.serveClientCalls <- &fakeServeClient{ + conn: conn, + id: id, + agent: agent, + errCh: errCh, + } + return <-errCh +} + +func (*fakeCoordinator) ServeAgent(net.Conn, uuid.UUID, string) error { + panic("unimplemented") +} + +func (*fakeCoordinator) Close() error { + panic("unimplemented") +} + +func (*fakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn { + panic("unimplemented") +} + +func (f *fakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) { + reqs := make(chan *proto.CoordinateRequest, 100) + resps := make(chan *proto.CoordinateResponse, 100) + f.coordinateCalls <- &fakeCoordinate{ + ctx: ctx, + id: id, + name: name, + auth: a, + reqs: reqs, + resps: resps, + } + return reqs, resps +} + +func newFakeCoordinator() *fakeCoordinator { + return &fakeCoordinator{ + coordinateCalls: make(chan *fakeCoordinate, 100), + serveClientCalls: make(chan *fakeServeClient, 100), + } +} + +type fakeCoordinate struct { + ctx context.Context + id uuid.UUID + name string + auth tailnet.TunnelAuth + reqs chan *proto.CoordinateRequest + resps chan *proto.CoordinateResponse +} + +type fakeServeClient struct { + conn net.Conn + id uuid.UUID + agent uuid.UUID + errCh chan error +}