diff --git a/coderd/coderd.go b/coderd/coderd.go index 898dcb36d5..ae861d5687 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -479,7 +479,11 @@ func New(options *Options) *API { } } api.TailnetClientService, err = tailnet.NewClientService( - api.Logger.Named("tailnetclient"), &api.TailnetCoordinator) + api.Logger.Named("tailnetclient"), + &api.TailnetCoordinator, + api.Options.DERPMapUpdateFrequency, + api.DERPMap, + ) if err != nil { api.Logger.Fatal(api.ctx, "failed to initialize tailnet client service", slog.Error(err)) } diff --git a/tailnet/service.go b/tailnet/service.go index a6c94ef8bf..a9982798fc 100644 --- a/tailnet/service.go +++ b/tailnet/service.go @@ -7,11 +7,13 @@ import ( "strconv" "strings" "sync/atomic" + "time" "github.com/google/uuid" "github.com/hashicorp/yamux" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" + "tailscale.com/tailcfg" "cdr.dev/slog" "github.com/coder/coder/v2/tailnet/proto" @@ -92,10 +94,22 @@ type ClientService struct { // NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is // loaded on each processed connection. -func NewClientService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) (*ClientService, error) { +func NewClientService( + logger slog.Logger, + coordPtr *atomic.Pointer[Coordinator], + derpMapUpdateFrequency time.Duration, + derpMapFn func() *tailcfg.DERPMap, +) ( + *ClientService, error, +) { s := &ClientService{logger: logger, coordPtr: coordPtr} mux := drpcmux.New() - drpcService := NewDRPCService(logger, coordPtr) + drpcService := &DRPCService{ + CoordPtr: coordPtr, + Logger: logger, + DerpMapUpdateFrequency: derpMapUpdateFrequency, + DerpMapFn: derpMapFn, + } err := proto.DRPCRegisterClient(mux, drpcService) if err != nil { return nil, xerrors.Errorf("register DRPC service: %w", err) @@ -145,22 +159,39 @@ func (s *ClientService) ServeClient(ctx context.Context, version string, conn ne // DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer type DRPCService struct { - coordPtr *atomic.Pointer[Coordinator] - logger slog.Logger + CoordPtr *atomic.Pointer[Coordinator] + Logger slog.Logger + DerpMapUpdateFrequency time.Duration + DerpMapFn func() *tailcfg.DERPMap } -func NewDRPCService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) *DRPCService { - return &DRPCService{ - coordPtr: coordPtr, - logger: logger, +func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream proto.DRPCClient_StreamDERPMapsStream) error { + defer stream.Close() + + ticker := time.NewTicker(s.DerpMapUpdateFrequency) + defer ticker.Stop() + + var lastDERPMap *tailcfg.DERPMap + for { + derpMap := s.DerpMapFn() + if lastDERPMap == nil || !CompareDERPMaps(lastDERPMap, derpMap) { + protoDERPMap := DERPMapToProto(derpMap) + err := stream.Send(protoDERPMap) + if err != nil { + return xerrors.Errorf("send derp map: %w", err) + } + lastDERPMap = derpMap + } + + ticker.Reset(s.DerpMapUpdateFrequency) + select { + case <-stream.Context().Done(): + return nil + case <-ticker.C: + } } } -func (*DRPCService) StreamDERPMaps(*proto.StreamDERPMapsRequest, proto.DRPCClient_StreamDERPMapsStream) error { - // TODO integrate with Dean's PR implementation - return xerrors.New("unimplemented") -} - func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailnetStream) error { ctx := stream.Context() streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID) @@ -168,9 +199,9 @@ func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailne _ = stream.Close() return xerrors.New("no Stream ID") } - logger := s.logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name)) + logger := s.Logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name)) logger.Debug(ctx, "starting tailnet Coordinate") - coord := *(s.coordPtr.Load()) + coord := *(s.CoordPtr.Load()) reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth) c := communicator{ logger: logger, diff --git a/tailnet/service_test.go b/tailnet/service_test.go index c69f5b1469..9a476e4b6d 100644 --- a/tailnet/service_test.go +++ b/tailnet/service_test.go @@ -8,8 +8,10 @@ import ( "net/http" "sync/atomic" "testing" + "time" "golang.org/x/xerrors" + "tailscale.com/tailcfg" "github.com/google/uuid" @@ -94,7 +96,11 @@ func TestClientService_ServeClient_V2(t *testing.T) { coordPtr := atomic.Pointer[tailnet.Coordinator]{} coordPtr.Store(&coord) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - uut, err := tailnet.NewClientService(logger, &coordPtr) + derpMap := &tailcfg.DERPMap{Regions: map[int]*tailcfg.DERPRegion{999: {RegionCode: "test"}}} + uut, err := tailnet.NewClientService( + logger, &coordPtr, + time.Millisecond, func() *tailcfg.DERPMap { return derpMap }, + ) require.NoError(t, err) ctx := testutil.Context(t, testutil.WaitShort) @@ -112,6 +118,8 @@ func TestClientService_ServeClient_V2(t *testing.T) { client, err := tailnet.NewDRPCClient(c) require.NoError(t, err) + + // Coordinate stream, err := client.CoordinateTailnet(ctx) require.NoError(t, err) defer stream.Close() @@ -145,7 +153,17 @@ func TestClientService_ServeClient_V2(t *testing.T) { err = stream.Close() require.NoError(t, err) - // stream ^^ is just one RPC; we need to close the Conn to end the session. + // DERP Map + dms, err := client.StreamDERPMaps(ctx, &proto.StreamDERPMapsRequest{}) + require.NoError(t, err) + + gotDermMap, err := dms.Recv() + require.NoError(t, err) + require.Equal(t, "test", gotDermMap.GetRegions()[999].GetRegionCode()) + err = dms.Close() + require.NoError(t, err) + + // RPCs closed; we need to close the Conn to end the session. err = c.Close() require.NoError(t, err) err = testutil.RequireRecvCtx(ctx, t, errCh) @@ -159,7 +177,7 @@ func TestClientService_ServeClient_V1(t *testing.T) { coordPtr := atomic.Pointer[tailnet.Coordinator]{} coordPtr.Store(&coord) logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) - uut, err := tailnet.NewClientService(logger, &coordPtr) + uut, err := tailnet.NewClientService(logger, &coordPtr, 0, nil) require.NoError(t, err) ctx := testutil.Context(t, testutil.WaitShort)