diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 74b821deb8..0247096d71 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -136,28 +136,8 @@ func NewServerTailnet( return nil, xerrors.Errorf("get initial multi agent: %w", err) } tn.agentConn.Store(&agentConn) - - pn, err := tailnet.NodeToProto(conn.Node()) - if err != nil { - tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) - } else { - err = tn.getAgentConn().UpdateSelf(pn) - if err != nil { - tn.logger.Warn(context.Background(), "server tailnet update self", slog.Error(err)) - } - } - - conn.SetNodeCallback(func(node *tailnet.Node) { - pn, err := tailnet.NodeToProto(node) - if err != nil { - tn.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) - return - } - err = tn.getAgentConn().UpdateSelf(pn) - if err != nil { - tn.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err)) - } - }) + // registering the callback also triggers send of the initial node + tn.coordinatee.SetNodeCallback(tn.nodeCallback) // This is set to allow local DERP traffic to be proxied through memory // instead of needing to hit the external access URL. Don't use the ctx @@ -183,6 +163,18 @@ func NewServerTailnet( return tn, nil } +func (s *ServerTailnet) nodeCallback(node *tailnet.Node) { + pn, err := tailnet.NodeToProto(node) + if err != nil { + s.logger.Critical(context.Background(), "failed to convert node", slog.Error(err)) + return + } + err = s.getAgentConn().UpdateSelf(pn) + if err != nil { + s.logger.Warn(context.Background(), "broadcast server node to agents", slog.Error(err)) + } +} + func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) { s.connsPerAgent.Describe(descs) s.totalConns.Describe(descs) @@ -285,6 +277,9 @@ func (s *ServerTailnet) reinitCoordinator() { continue } s.agentConn.Store(&agentConn) + // reset the Node callback, which triggers the conn to send the node immediately, and also + // register for updates + s.coordinatee.SetNodeCallback(s.nodeCallback) // Resubscribe to all of the agents we're tracking. for agentID := range s.agentConnectionTimes { diff --git a/coderd/tailnet_internal_test.go b/coderd/tailnet_internal_test.go index f09ac1d28b..f8750dcbe9 100644 --- a/coderd/tailnet_internal_test.go +++ b/coderd/tailnet_internal_test.go @@ -48,6 +48,7 @@ func TestServerTailnet_Reconnect(t *testing.T) { agentConnectionTimes: make(map[uuid.UUID]time.Time), } // reinit the Coordinator once, to load mMultiAgent0 + mCoord.EXPECT().SetNodeCallback(gomock.Any()).Times(1) uut.reinitCoordinator() mMultiAgent0.EXPECT().NextUpdate(gomock.Any()). @@ -57,6 +58,7 @@ func TestServerTailnet_Reconnect(t *testing.T) { Times(1). Return(true) // this triggers reconnect setLost := mCoord.EXPECT().SetAllPeersLost().Times(1).After(closed0) + mCoord.EXPECT().SetNodeCallback(gomock.Any()).Times(1).After(closed0) mMultiAgent1.EXPECT().NextUpdate(gomock.Any()). Times(1). After(setLost).