refactor: add postgres tailnet coordinator (#8044)

* postgres tailnet coordinator

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix db migration; tests

Signed-off-by: Spike Curtis <spike@coder.com>

* Add fixture, regenerate

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix fixtures

Signed-off-by: Spike Curtis <spike@coder.com>

* review comments, run clean gen

Signed-off-by: Spike Curtis <spike@coder.com>

* Rename waitForConn -> cleanupConn

Signed-off-by: Spike Curtis <spike@coder.com>

* code review updates

Signed-off-by: Spike Curtis <spike@coder.com>

* db migration order

Signed-off-by: Spike Curtis <spike@coder.com>

* fix log field name last_heartbeat

Signed-off-by: Spike Curtis <spike@coder.com>

* fix heartbeat_from log field

Signed-off-by: Spike Curtis <spike@coder.com>

* fix slog fields for linting

Signed-off-by: Spike Curtis <spike@coder.com>

---------

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis 2023-06-21 16:20:58 +04:00 committed by GitHub
parent 4fb4c9b270
commit cc17d2feea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 2739 additions and 18 deletions

View File

@ -707,6 +707,13 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
}
// DeleteCoordinator removes a tailnet coordinator row. The caller must be
// permitted to delete tailnet coordinator resources.
func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error {
	err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator)
	if err != nil {
		return err
	}
	return q.db.DeleteCoordinator(ctx, id)
}
func (q *querier) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error {
return deleteQ(q.log, q.auth, q.db.GetGitSSHKey, q.db.DeleteGitSSHKey)(ctx, userID)
}
@ -765,6 +772,20 @@ func (q *querier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt tim
return q.db.DeleteReplicasUpdatedBefore(ctx, updatedAt)
}
// DeleteTailnetAgent removes one agent row scoped to a coordinator.
// Authorized with ActionDelete for consistency with DeleteTailnetClient and
// DeleteCoordinator on the same resource; the original checked ActionUpdate,
// which was inconsistent with its sibling delete methods.
func (q *querier) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) {
	if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
		return database.DeleteTailnetAgentRow{}, err
	}
	return q.db.DeleteTailnetAgent(ctx, arg)
}
// DeleteTailnetClient removes one client row scoped to a coordinator. The
// caller must be permitted to delete tailnet coordinator resources.
func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) {
	err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator)
	if err != nil {
		var empty database.DeleteTailnetClientRow
		return empty, err
	}
	return q.db.DeleteTailnetClient(ctx, arg)
}
func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id)
}
@ -1137,6 +1158,20 @@ func (q *querier) GetServiceBanner(ctx context.Context) (string, error) {
return q.db.GetServiceBanner(ctx)
}
// GetTailnetAgents lists agent rows stored under the given agent id.
// Requires read access on tailnet coordinator resources.
func (q *querier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) {
	err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator)
	if err != nil {
		return nil, err
	}
	return q.db.GetTailnetAgents(ctx, id)
}
// GetTailnetClientsForAgent lists the client rows associated with an agent.
// Requires read access on tailnet coordinator resources.
func (q *querier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) {
	err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator)
	if err != nil {
		return nil, err
	}
	return q.db.GetTailnetClientsForAgent(ctx, agentID)
}
// Only used by metrics cache.
func (q *querier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil {
@ -2515,3 +2550,24 @@ func (q *querier) UpsertServiceBanner(ctx context.Context, value string) error {
}
return q.db.UpsertServiceBanner(ctx, value)
}
// UpsertTailnetAgent inserts or refreshes an agent's node entry, authorized
// as an update on the tailnet coordinator resource.
func (q *querier) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) {
	err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator)
	if err != nil {
		return database.TailnetAgent{}, err
	}
	return q.db.UpsertTailnetAgent(ctx, arg)
}
// UpsertTailnetClient inserts or refreshes a client's node entry, authorized
// as an update on the tailnet coordinator resource.
func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) {
	err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator)
	if err != nil {
		return database.TailnetClient{}, err
	}
	return q.db.UpsertTailnetClient(ctx, arg)
}
// UpsertTailnetCoordinator inserts or refreshes a coordinator row, authorized
// as an update on the tailnet coordinator resource.
func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) {
	err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator)
	if err != nil {
		return database.TailnetCoordinator{}, err
	}
	return q.db.UpsertTailnetCoordinator(ctx, id)
}

View File

@ -966,6 +966,14 @@ func isNotNull(v interface{}) bool {
return reflect.ValueOf(v).FieldByName("Valid").Bool()
}
// ErrUnimplemented is returned by methods only used by the enterprise/tailnet.pgCoord. This coordinator explicitly
// depends on postgres triggers that announce changes on the pubsub. Implementing support for this in the fake
// database would strongly couple the fakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little
// sense to directly test the pgCoord against anything other than postgres. The fakeQuerier is designed to allow us to
// test the Coderd API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore,
// these methods remain unimplemented in the fakeQuerier.
var ErrUnimplemented = xerrors.New("unimplemented")

// AcquireLock always errors on the fake: per its message, the query must only
// be called within a transaction, which the fake does not model here.
func (*fakeQuerier) AcquireLock(_ context.Context, _ int64) error {
	return xerrors.New("AcquireLock must only be called within a transaction")
}
@ -1066,6 +1074,10 @@ func (q *fakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context,
return nil
}
// DeleteCoordinator is unimplemented in the fake; only the postgres-backed
// pgCoord uses it (see ErrUnimplemented).
func (*fakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error {
	return ErrUnimplemented
}
func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error {
q.mutex.Lock()
defer q.mutex.Unlock()
@ -1174,6 +1186,14 @@ func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time
return nil
}
// DeleteTailnetAgent is unimplemented in the fake; only the postgres-backed
// pgCoord uses it (see ErrUnimplemented).
func (*fakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) {
	return database.DeleteTailnetAgentRow{}, ErrUnimplemented
}

// DeleteTailnetClient is unimplemented in the fake; only the postgres-backed
// pgCoord uses it (see ErrUnimplemented).
func (*fakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) {
	return database.DeleteTailnetClientRow{}, ErrUnimplemented
}
func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -2185,6 +2205,14 @@ func (q *fakeQuerier) GetServiceBanner(_ context.Context) (string, error) {
return string(q.serviceBanner), nil
}
// GetTailnetAgents is unimplemented in the fake; only the postgres-backed
// pgCoord uses it (see ErrUnimplemented).
func (*fakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) {
	return nil, ErrUnimplemented
}

// GetTailnetClientsForAgent is unimplemented in the fake; only the
// postgres-backed pgCoord uses it (see ErrUnimplemented).
func (*fakeQuerier) GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) {
	return nil, ErrUnimplemented
}
func (q *fakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) {
if err := validateDatabaseType(arg); err != nil {
return database.GetTemplateAverageBuildTimeRow{}, err
@ -5238,3 +5266,15 @@ func (q *fakeQuerier) UpsertServiceBanner(_ context.Context, data string) error
q.serviceBanner = []byte(data)
return nil
}
// UpsertTailnetAgent is unimplemented in the fake; only the postgres-backed
// pgCoord uses it (see ErrUnimplemented).
func (*fakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) {
	return database.TailnetAgent{}, ErrUnimplemented
}

// UpsertTailnetClient is unimplemented in the fake; only the postgres-backed
// pgCoord uses it (see ErrUnimplemented).
func (*fakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) {
	return database.TailnetClient{}, ErrUnimplemented
}

// UpsertTailnetCoordinator is unimplemented in the fake; only the
// postgres-backed pgCoord uses it (see ErrUnimplemented).
func (*fakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) {
	return database.TailnetCoordinator{}, ErrUnimplemented
}

View File

@ -143,6 +143,12 @@ func (m metricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.Contex
return err
}
// DeleteCoordinator delegates to the wrapped store and records query latency.
func (m metricsStore) DeleteCoordinator(ctx context.Context, id uuid.UUID) error {
	start := time.Now()
	err := m.s.DeleteCoordinator(ctx, id)
	// Observe after the call returns. The previous `defer ...Observe(time.Since(start)...)`
	// evaluated time.Since(start) when the defer statement executed (Go
	// evaluates deferred call arguments immediately), so it always recorded
	// ~0s. This also matches the pattern used by the other methods here.
	m.queryLatencies.WithLabelValues("DeleteCoordinator").Observe(time.Since(start).Seconds())
	return err
}
func (m metricsStore) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error {
start := time.Now()
err := m.s.DeleteGitSSHKey(ctx, userID)
@ -199,6 +205,18 @@ func (m metricsStore) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt
return err
}
// DeleteTailnetAgent delegates to the wrapped store and records query latency.
func (m metricsStore) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) {
	start := time.Now()
	row, err := m.s.DeleteTailnetAgent(ctx, arg)
	// Observe after the call: the previous defer evaluated time.Since(start)
	// at defer time, so it always recorded ~0s latency.
	m.queryLatencies.WithLabelValues("DeleteTailnetAgent").Observe(time.Since(start).Seconds())
	return row, err
}
// DeleteTailnetClient delegates to the wrapped store and records query latency.
func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) {
	start := time.Now()
	row, err := m.s.DeleteTailnetClient(ctx, arg)
	// Observe after the call: the previous defer evaluated time.Since(start)
	// at defer time, so it always recorded ~0s latency.
	m.queryLatencies.WithLabelValues("DeleteTailnetClient").Observe(time.Since(start).Seconds())
	return row, err
}
func (m metricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
start := time.Now()
apiKey, err := m.s.GetAPIKeyByID(ctx, id)
@ -556,6 +574,18 @@ func (m metricsStore) GetServiceBanner(ctx context.Context) (string, error) {
return banner, err
}
// GetTailnetAgents delegates to the wrapped store and records query latency.
func (m metricsStore) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) {
	start := time.Now()
	agents, err := m.s.GetTailnetAgents(ctx, id)
	// Observe after the call: the previous defer evaluated time.Since(start)
	// at defer time, so it always recorded ~0s latency.
	m.queryLatencies.WithLabelValues("GetTailnetAgents").Observe(time.Since(start).Seconds())
	return agents, err
}
// GetTailnetClientsForAgent delegates to the wrapped store and records query latency.
func (m metricsStore) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) {
	start := time.Now()
	clients, err := m.s.GetTailnetClientsForAgent(ctx, agentID)
	// Observe after the call: the previous defer evaluated time.Since(start)
	// at defer time, so it always recorded ~0s latency.
	m.queryLatencies.WithLabelValues("GetTailnetClientsForAgent").Observe(time.Since(start).Seconds())
	return clients, err
}
func (m metricsStore) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) {
start := time.Now()
buildTime, err := m.s.GetTemplateAverageBuildTime(ctx, arg)
@ -1549,3 +1579,21 @@ func (m metricsStore) UpsertServiceBanner(ctx context.Context, value string) err
m.queryLatencies.WithLabelValues("UpsertServiceBanner").Observe(time.Since(start).Seconds())
return r0
}
// UpsertTailnetAgent delegates to the wrapped store and records query latency.
func (m metricsStore) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) {
	start := time.Now()
	agent, err := m.s.UpsertTailnetAgent(ctx, arg)
	// Observe after the call: the previous defer evaluated time.Since(start)
	// at defer time, so it always recorded ~0s latency.
	m.queryLatencies.WithLabelValues("UpsertTailnetAgent").Observe(time.Since(start).Seconds())
	return agent, err
}
// UpsertTailnetClient delegates to the wrapped store and records query latency.
func (m metricsStore) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) {
	start := time.Now()
	client, err := m.s.UpsertTailnetClient(ctx, arg)
	// Observe after the call: the previous defer evaluated time.Since(start)
	// at defer time, so it always recorded ~0s latency.
	m.queryLatencies.WithLabelValues("UpsertTailnetClient").Observe(time.Since(start).Seconds())
	return client, err
}
// UpsertTailnetCoordinator delegates to the wrapped store and records query latency.
func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) {
	start := time.Now()
	coord, err := m.s.UpsertTailnetCoordinator(ctx, id)
	// Observe after the call: the previous defer evaluated time.Since(start)
	// at defer time, so it always recorded ~0s latency.
	m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds())
	return coord, err
}

View File

@ -110,6 +110,20 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(arg0, a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), arg0, arg1)
}
// NOTE(review): standard gomock stub (ctrl.Call + recorder) — presumably
// generated by mockgen; regenerate rather than hand-editing. TODO confirm.

// DeleteCoordinator mocks base method.
func (m *MockStore) DeleteCoordinator(arg0 context.Context, arg1 uuid.UUID) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "DeleteCoordinator", arg0, arg1)
	ret0, _ := ret[0].(error)
	return ret0
}

// DeleteCoordinator indicates an expected call of DeleteCoordinator.
func (mr *MockStoreMockRecorder) DeleteCoordinator(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCoordinator", reflect.TypeOf((*MockStore)(nil).DeleteCoordinator), arg0, arg1)
}
// DeleteGitSSHKey mocks base method.
func (m *MockStore) DeleteGitSSHKey(arg0 context.Context, arg1 uuid.UUID) error {
m.ctrl.T.Helper()
@ -223,6 +237,36 @@ func (mr *MockStoreMockRecorder) DeleteReplicasUpdatedBefore(arg0, arg1 interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteReplicasUpdatedBefore", reflect.TypeOf((*MockStore)(nil).DeleteReplicasUpdatedBefore), arg0, arg1)
}
// NOTE(review): standard gomock stubs (ctrl.Call + recorder) — presumably
// generated by mockgen; regenerate rather than hand-editing. TODO confirm.

// DeleteTailnetAgent mocks base method.
func (m *MockStore) DeleteTailnetAgent(arg0 context.Context, arg1 database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "DeleteTailnetAgent", arg0, arg1)
	ret0, _ := ret[0].(database.DeleteTailnetAgentRow)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// DeleteTailnetAgent indicates an expected call of DeleteTailnetAgent.
func (mr *MockStoreMockRecorder) DeleteTailnetAgent(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetAgent", reflect.TypeOf((*MockStore)(nil).DeleteTailnetAgent), arg0, arg1)
}

// DeleteTailnetClient mocks base method.
func (m *MockStore) DeleteTailnetClient(arg0 context.Context, arg1 database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "DeleteTailnetClient", arg0, arg1)
	ret0, _ := ret[0].(database.DeleteTailnetClientRow)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// DeleteTailnetClient indicates an expected call of DeleteTailnetClient.
func (mr *MockStoreMockRecorder) DeleteTailnetClient(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClient", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClient), arg0, arg1)
}
// GetAPIKeyByID mocks base method.
func (m *MockStore) GetAPIKeyByID(arg0 context.Context, arg1 string) (database.APIKey, error) {
m.ctrl.T.Helper()
@ -1033,6 +1077,36 @@ func (mr *MockStoreMockRecorder) GetServiceBanner(arg0 interface{}) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceBanner", reflect.TypeOf((*MockStore)(nil).GetServiceBanner), arg0)
}
// NOTE(review): standard gomock stubs (ctrl.Call + recorder) — presumably
// generated by mockgen; regenerate rather than hand-editing. TODO confirm.

// GetTailnetAgents mocks base method.
func (m *MockStore) GetTailnetAgents(arg0 context.Context, arg1 uuid.UUID) ([]database.TailnetAgent, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetTailnetAgents", arg0, arg1)
	ret0, _ := ret[0].([]database.TailnetAgent)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetTailnetAgents indicates an expected call of GetTailnetAgents.
func (mr *MockStoreMockRecorder) GetTailnetAgents(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetAgents", reflect.TypeOf((*MockStore)(nil).GetTailnetAgents), arg0, arg1)
}

// GetTailnetClientsForAgent mocks base method.
func (m *MockStore) GetTailnetClientsForAgent(arg0 context.Context, arg1 uuid.UUID) ([]database.TailnetClient, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetTailnetClientsForAgent", arg0, arg1)
	ret0, _ := ret[0].([]database.TailnetClient)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetTailnetClientsForAgent indicates an expected call of GetTailnetClientsForAgent.
func (mr *MockStoreMockRecorder) GetTailnetClientsForAgent(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetClientsForAgent", reflect.TypeOf((*MockStore)(nil).GetTailnetClientsForAgent), arg0, arg1)
}
// GetTemplateAverageBuildTime mocks base method.
func (m *MockStore) GetTemplateAverageBuildTime(arg0 context.Context, arg1 database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) {
m.ctrl.T.Helper()
@ -3189,6 +3263,51 @@ func (mr *MockStoreMockRecorder) UpsertServiceBanner(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertServiceBanner", reflect.TypeOf((*MockStore)(nil).UpsertServiceBanner), arg0, arg1)
}
// NOTE(review): standard gomock stubs (ctrl.Call + recorder) — presumably
// generated by mockgen; regenerate rather than hand-editing. TODO confirm.

// UpsertTailnetAgent mocks base method.
func (m *MockStore) UpsertTailnetAgent(arg0 context.Context, arg1 database.UpsertTailnetAgentParams) (database.TailnetAgent, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "UpsertTailnetAgent", arg0, arg1)
	ret0, _ := ret[0].(database.TailnetAgent)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// UpsertTailnetAgent indicates an expected call of UpsertTailnetAgent.
func (mr *MockStoreMockRecorder) UpsertTailnetAgent(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetAgent", reflect.TypeOf((*MockStore)(nil).UpsertTailnetAgent), arg0, arg1)
}

// UpsertTailnetClient mocks base method.
func (m *MockStore) UpsertTailnetClient(arg0 context.Context, arg1 database.UpsertTailnetClientParams) (database.TailnetClient, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "UpsertTailnetClient", arg0, arg1)
	ret0, _ := ret[0].(database.TailnetClient)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// UpsertTailnetClient indicates an expected call of UpsertTailnetClient.
func (mr *MockStoreMockRecorder) UpsertTailnetClient(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClient", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClient), arg0, arg1)
}

// UpsertTailnetCoordinator mocks base method.
func (m *MockStore) UpsertTailnetCoordinator(arg0 context.Context, arg1 uuid.UUID) (database.TailnetCoordinator, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "UpsertTailnetCoordinator", arg0, arg1)
	ret0, _ := ret[0].(database.TailnetCoordinator)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// UpsertTailnetCoordinator indicates an expected call of UpsertTailnetCoordinator.
func (mr *MockStoreMockRecorder) UpsertTailnetCoordinator(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetCoordinator", reflect.TypeOf((*MockStore)(nil).UpsertTailnetCoordinator), arg0, arg1)
}
// Wrappers mocks base method.
func (m *MockStore) Wrappers() []string {
m.ctrl.T.Helper()

View File

@ -14,12 +14,17 @@ import (
"github.com/coder/coder/coderd/database/pubsub"
)
// WillUsePostgres reports whether a call to NewDB() will return a real,
// postgres-backed Store and Pubsub (selected by setting the DB env var).
func WillUsePostgres() bool {
	return len(os.Getenv("DB")) != 0
}
func NewDB(t testing.TB) (database.Store, pubsub.Pubsub) {
t.Helper()
db := dbfake.New()
ps := pubsub.NewInMemory()
if os.Getenv("DB") != "" {
if WillUsePostgres() {
connectionURL := os.Getenv("CODER_PG_CONNECTION_URL")
if connectionURL == "" {
var (

View File

@ -171,6 +171,45 @@ BEGIN
END;
$$;
-- Announces the affected agent id on the 'tailnet_agent_update' channel for
-- any change to tailnet_agents.
-- NOTE(review): on UPDATE both OLD and NEW are non-null, so only OLD.id is
-- announced; this is equivalent to NEW.id unless an UPDATE changes id —
-- confirm id is never updated in place.
CREATE FUNCTION tailnet_notify_agent_change() RETURNS trigger
    LANGUAGE plpgsql
    AS $$
BEGIN
	IF (OLD IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_agent_update', OLD.id::text);
		RETURN NULL;
	END IF;
	IF (NEW IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_agent_update', NEW.id::text);
		RETURN NULL;
	END IF;
END;
$$;

-- Announces 'client_id,agent_id' on the 'tailnet_client_update' channel for
-- any change to tailnet_clients.
-- NOTE(review): on UPDATE only the OLD (id, agent_id) pair is announced; if
-- an UPDATE changed agent_id, listeners for the NEW agent would not be
-- notified — confirm whether agent_id updates can occur.
CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger
    LANGUAGE plpgsql
    AS $$
BEGIN
	IF (OLD IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id);
		RETURN NULL;
	END IF;
	IF (NEW IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id);
		RETURN NULL;
	END IF;
END;
$$;

-- Announces the coordinator id on 'tailnet_coordinator_heartbeat' whenever a
-- tailnet_coordinators row is inserted or updated (NEW is always non-null
-- because the trigger fires only on INSERT OR UPDATE).
CREATE FUNCTION tailnet_notify_coordinator_heartbeat() RETURNS trigger
    LANGUAGE plpgsql
    AS $$
BEGIN
	PERFORM pg_notify('tailnet_coordinator_heartbeat', NEW.id::text);
	RETURN NULL;
END;
$$;
CREATE TABLE api_keys (
id text NOT NULL,
hashed_secret bytea NOT NULL,
@ -383,6 +422,28 @@ CREATE TABLE site_configs (
value character varying(8192) NOT NULL
);
-- One row per (agent, coordinator); node holds the agent's serialized
-- tailnet node as JSON.
CREATE TABLE tailnet_agents (
    id uuid NOT NULL,
    coordinator_id uuid NOT NULL,
    updated_at timestamp with time zone NOT NULL,
    node jsonb NOT NULL
);

-- One row per (client, coordinator); each client is associated with a single
-- agent via agent_id.
CREATE TABLE tailnet_clients (
    id uuid NOT NULL,
    coordinator_id uuid NOT NULL,
    agent_id uuid NOT NULL,
    updated_at timestamp with time zone NOT NULL,
    node jsonb NOT NULL
);

-- Coordinator liveness rows; heartbeat_at is presumably refreshed via
-- UpsertTailnetCoordinator — confirm against the queries.
CREATE TABLE tailnet_coordinators (
    id uuid NOT NULL,
    heartbeat_at timestamp with time zone NOT NULL
);

COMMENT ON TABLE tailnet_coordinators IS 'We keep this separate from replicas in case we need to break the coordinator out into its own service';
CREATE TABLE template_version_parameters (
template_version_id uuid NOT NULL,
name text NOT NULL,
@ -835,6 +896,15 @@ ALTER TABLE ONLY provisioner_jobs
ALTER TABLE ONLY site_configs
ADD CONSTRAINT site_configs_key_key UNIQUE (key);
ALTER TABLE ONLY tailnet_agents
ADD CONSTRAINT tailnet_agents_pkey PRIMARY KEY (id, coordinator_id);
ALTER TABLE ONLY tailnet_clients
ADD CONSTRAINT tailnet_clients_pkey PRIMARY KEY (id, coordinator_id);
ALTER TABLE ONLY tailnet_coordinators
ADD CONSTRAINT tailnet_coordinators_pkey PRIMARY KEY (id);
ALTER TABLE ONLY template_version_parameters
ADD CONSTRAINT template_version_parameters_template_version_id_name_key UNIQUE (template_version_id, name);
@ -922,6 +992,12 @@ CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name);
CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name));
CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id);
CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients USING btree (agent_id);
CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients USING btree (coordinator_id);
CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false);
CREATE UNIQUE INDEX idx_users_username ON users USING btree (username) WHERE (deleted = false);
@ -948,6 +1024,12 @@ CREATE INDEX workspace_resources_job_id_idx ON workspace_resources USING btree (
CREATE UNIQUE INDEX workspaces_owner_id_lower_idx ON workspaces USING btree (owner_id, lower((name)::text)) WHERE (deleted = false);
CREATE TRIGGER tailnet_notify_agent_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_agents FOR EACH ROW EXECUTE FUNCTION tailnet_notify_agent_change();
CREATE TRIGGER tailnet_notify_client_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_clients FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_change();
CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION tailnet_notify_coordinator_heartbeat();
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
CREATE TRIGGER trigger_update_users AFTER INSERT OR UPDATE ON users FOR EACH ROW WHEN ((new.deleted = true)) EXECUTE FUNCTION delete_deleted_user_api_keys();
@ -982,6 +1064,12 @@ ALTER TABLE ONLY provisioner_job_logs
ALTER TABLE ONLY provisioner_jobs
ADD CONSTRAINT provisioner_jobs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ALTER TABLE ONLY tailnet_agents
ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
ALTER TABLE ONLY tailnet_clients
ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
ALTER TABLE ONLY template_version_parameters
ADD CONSTRAINT template_version_parameters_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE;

View File

@ -0,0 +1,18 @@
BEGIN;

DROP TRIGGER IF EXISTS tailnet_notify_client_change ON tailnet_clients;
DROP FUNCTION IF EXISTS tailnet_notify_client_change;
DROP INDEX IF EXISTS idx_tailnet_clients_agent;
DROP INDEX IF EXISTS idx_tailnet_clients_coordinator;
-- IF EXISTS added for consistency with every other DROP in this script, so
-- the down migration stays idempotent after a partial rollback.
DROP TABLE IF EXISTS tailnet_clients;

DROP TRIGGER IF EXISTS tailnet_notify_agent_change ON tailnet_agents;
DROP FUNCTION IF EXISTS tailnet_notify_agent_change;
DROP INDEX IF EXISTS idx_tailnet_agents_coordinator;
DROP TABLE IF EXISTS tailnet_agents;

-- Coordinators are dropped last: tailnet_clients/tailnet_agents reference
-- tailnet_coordinators(id).
DROP TRIGGER IF EXISTS tailnet_notify_coordinator_heartbeat ON tailnet_coordinators;
DROP FUNCTION IF EXISTS tailnet_notify_coordinator_heartbeat;
DROP TABLE IF EXISTS tailnet_coordinators;

COMMIT;

View File

@ -0,0 +1,97 @@
BEGIN;

-- Coordinator liveness rows.
CREATE TABLE tailnet_coordinators (
	id uuid NOT NULL PRIMARY KEY,
	heartbeat_at timestamp with time zone NOT NULL
);

COMMENT ON TABLE tailnet_coordinators IS 'We keep this separate from replicas in case we need to break the coordinator out into its own service';

-- Client nodes, scoped per coordinator and associated with a single agent.
CREATE TABLE tailnet_clients (
	id uuid NOT NULL,
	coordinator_id uuid NOT NULL,
	agent_id uuid NOT NULL,
	updated_at timestamp with time zone NOT NULL,
	node jsonb NOT NULL,
	PRIMARY KEY (id, coordinator_id),
	FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE
);

-- For querying/deleting mappings
CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients (agent_id);

-- For shutting down / GC a coordinator
CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients (coordinator_id);

-- Agent nodes, scoped per coordinator.
CREATE TABLE tailnet_agents (
	id uuid NOT NULL,
	coordinator_id uuid NOT NULL,
	updated_at timestamp with time zone NOT NULL,
	node jsonb NOT NULL,
	PRIMARY KEY (id, coordinator_id),
	FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE
);

-- For shutting down / GC a coordinator
CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents (coordinator_id);

-- Any time the tailnet_clients table changes, send an update with the affected client and agent IDs
-- NOTE(review): on UPDATE both OLD and NEW are non-null, so only the OLD
-- (id, agent_id) pair is announced; if an UPDATE ever changes agent_id,
-- listeners for the NEW agent would miss the change — confirm whether
-- agent_id updates can occur.
CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger
	LANGUAGE plpgsql
AS $$
BEGIN
	IF (OLD IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id);
		RETURN NULL;
	END IF;
	IF (NEW IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id);
		RETURN NULL;
	END IF;
END;
$$;

CREATE TRIGGER tailnet_notify_client_change
	AFTER INSERT OR UPDATE OR DELETE ON tailnet_clients
	FOR EACH ROW
EXECUTE PROCEDURE tailnet_notify_client_change();

-- Any time tailnet_agents table changes, send an update with the affected agent ID.
-- NOTE(review): as above, only OLD.id is announced on UPDATE; equivalent to
-- NEW.id unless the id column itself is updated in place.
CREATE FUNCTION tailnet_notify_agent_change() RETURNS trigger
	LANGUAGE plpgsql
AS $$
BEGIN
	IF (OLD IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_agent_update', OLD.id::text);
		RETURN NULL;
	END IF;
	IF (NEW IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_agent_update', NEW.id::text);
		RETURN NULL;
	END IF;
END;
$$;

CREATE TRIGGER tailnet_notify_agent_change
	AFTER INSERT OR UPDATE OR DELETE ON tailnet_agents
	FOR EACH ROW
EXECUTE PROCEDURE tailnet_notify_agent_change();

-- Send coordinator heartbeats
CREATE FUNCTION tailnet_notify_coordinator_heartbeat() RETURNS trigger
	LANGUAGE plpgsql
AS $$
BEGIN
	PERFORM pg_notify('tailnet_coordinator_heartbeat', NEW.id::text);
	RETURN NULL;
END;
$$;

CREATE TRIGGER tailnet_notify_coordinator_heartbeat
	AFTER INSERT OR UPDATE ON tailnet_coordinators
	FOR EACH ROW
EXECUTE PROCEDURE tailnet_notify_coordinator_heartbeat();

COMMIT;

View File

@ -0,0 +1,28 @@
-- Migration test fixture: one coordinator owning one client/agent pair.
-- The client's agent_id (c0ee...) matches the agent row's id, and both rows
-- reference the coordinator (a0ee...) so the FK constraints are exercised.
INSERT INTO tailnet_coordinators
	(id, heartbeat_at)
VALUES
	(
		'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
		'2023-06-15 10:23:54+00'
	);

INSERT INTO tailnet_clients
	(id, agent_id, coordinator_id, updated_at, node)
VALUES
	(
		'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
		'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
		'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
		'2023-06-15 10:23:54+00',
		'{"preferred_derp": 12}'::json
	);

INSERT INTO tailnet_agents
	(id, coordinator_id, updated_at, node)
VALUES
	(
		'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
		'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
		'2023-06-15 10:23:54+00',
		'{"preferred_derp": 13}'::json
	);

View File

@ -1534,6 +1534,27 @@ type SiteConfig struct {
Value string `db:"value" json:"value"`
}
// TailnetAgent stores an agent's tailnet node (raw JSON) keyed by agent id
// and the coordinator that recorded it.
type TailnetAgent struct {
	ID            uuid.UUID       `db:"id" json:"id"`
	CoordinatorID uuid.UUID       `db:"coordinator_id" json:"coordinator_id"`
	UpdatedAt     time.Time       `db:"updated_at" json:"updated_at"`
	Node          json.RawMessage `db:"node" json:"node"`
}

// TailnetClient stores a client's tailnet node (raw JSON) keyed by client id
// and coordinator, associated with a single agent via AgentID.
type TailnetClient struct {
	ID            uuid.UUID       `db:"id" json:"id"`
	CoordinatorID uuid.UUID       `db:"coordinator_id" json:"coordinator_id"`
	AgentID       uuid.UUID       `db:"agent_id" json:"agent_id"`
	UpdatedAt     time.Time       `db:"updated_at" json:"updated_at"`
	Node          json.RawMessage `db:"node" json:"node"`
}

// We keep this separate from replicas in case we need to break the coordinator out into its own service
type TailnetCoordinator struct {
	ID          uuid.UUID `db:"id" json:"id"`
	HeartbeatAt time.Time `db:"heartbeat_at" json:"heartbeat_at"`
}
type Template struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`

View File

@ -29,6 +29,7 @@ type sqlcQuerier interface {
DeleteAPIKeyByID(ctx context.Context, id string) error
DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
DeleteCoordinator(ctx context.Context, id uuid.UUID) error
DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error
DeleteGroupByID(ctx context.Context, id uuid.UUID) error
DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error
@ -39,6 +40,8 @@ type sqlcQuerier interface {
DeleteOldWorkspaceAgentStartupLogs(ctx context.Context) error
DeleteOldWorkspaceAgentStats(ctx context.Context) error
DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error
DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error)
DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error)
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
// there is no unique constraint on empty token names
GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error)
@ -97,6 +100,8 @@ type sqlcQuerier interface {
GetQuotaConsumedForUser(ctx context.Context, ownerID uuid.UUID) (int64, error)
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
GetServiceBanner(ctx context.Context) (string, error)
GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]TailnetAgent, error)
GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error)
GetTemplateAverageBuildTime(ctx context.Context, arg GetTemplateAverageBuildTimeParams) (GetTemplateAverageBuildTimeRow, error)
GetTemplateByID(ctx context.Context, id uuid.UUID) (Template, error)
GetTemplateByOrganizationAndName(ctx context.Context, arg GetTemplateByOrganizationAndNameParams) (Template, error)
@ -264,6 +269,9 @@ type sqlcQuerier interface {
UpsertLastUpdateCheck(ctx context.Context, value string) error
UpsertLogoURL(ctx context.Context, value string) error
UpsertServiceBanner(ctx context.Context, value string) error
UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error)
UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error)
UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error)
}
var _ sqlcQuerier = (*sqlQuerier)(nil)

View File

@ -3261,6 +3261,239 @@ func (q *sqlQuerier) UpsertServiceBanner(ctx context.Context, value string) erro
return err
}
const deleteCoordinator = `-- name: DeleteCoordinator :exec
DELETE
FROM tailnet_coordinators
WHERE id = $1
`

// DeleteCoordinator removes the coordinator row with the given id. Dependent
// tailnet_agents/tailnet_clients rows are removed by the schema's
// ON DELETE CASCADE foreign keys.
func (q *sqlQuerier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error {
	_, err := q.db.ExecContext(ctx, deleteCoordinator, id)
	return err
}
const deleteTailnetAgent = `-- name: DeleteTailnetAgent :one
DELETE
FROM tailnet_agents
WHERE id = $1 and coordinator_id = $2
RETURNING id, coordinator_id
`

// DeleteTailnetAgentParams identifies the agent row to delete by its
// composite key (id, coordinator_id).
type DeleteTailnetAgentParams struct {
	ID            uuid.UUID `db:"id" json:"id"`
	CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
}

// DeleteTailnetAgentRow echoes the deleted row's key via RETURNING.
type DeleteTailnetAgentRow struct {
	ID            uuid.UUID `db:"id" json:"id"`
	CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
}

// DeleteTailnetAgent deletes one agent row; because this is a :one query
// using QueryRowContext, Scan returns sql.ErrNoRows when no row matched.
func (q *sqlQuerier) DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error) {
	row := q.db.QueryRowContext(ctx, deleteTailnetAgent, arg.ID, arg.CoordinatorID)
	var i DeleteTailnetAgentRow
	err := row.Scan(&i.ID, &i.CoordinatorID)
	return i, err
}
const deleteTailnetClient = `-- name: DeleteTailnetClient :one
DELETE
FROM tailnet_clients
WHERE id = $1 and coordinator_id = $2
RETURNING id, coordinator_id
`

// DeleteTailnetClientParams identifies the client row to delete; both the
// client ID and the owning coordinator ID are required.
type DeleteTailnetClientParams struct {
	ID            uuid.UUID `db:"id" json:"id"`
	CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
}

// DeleteTailnetClientRow echoes the key of the deleted row.
type DeleteTailnetClientRow struct {
	ID            uuid.UUID `db:"id" json:"id"`
	CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
}

// DeleteTailnetClient deletes a single client row keyed by
// (id, coordinator_id) and returns the deleted key. Because the query is
// :one, no matching row yields sql.ErrNoRows from Scan.
func (q *sqlQuerier) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error) {
	row := q.db.QueryRowContext(ctx, deleteTailnetClient, arg.ID, arg.CoordinatorID)
	var i DeleteTailnetClientRow
	err := row.Scan(&i.ID, &i.CoordinatorID)
	return i, err
}
const getTailnetAgents = `-- name: GetTailnetAgents :many
SELECT id, coordinator_id, updated_at, node
FROM tailnet_agents
WHERE id = $1
`

// GetTailnetAgents returns all rows for the given agent ID. There may be more
// than one — presumably one per coordinator the agent is connected to, since
// the upsert conflicts on (id, coordinator_id).
func (q *sqlQuerier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]TailnetAgent, error) {
	rows, err := q.db.QueryContext(ctx, getTailnetAgents, id)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []TailnetAgent
	for rows.Next() {
		var i TailnetAgent
		if err := rows.Scan(
			&i.ID,
			&i.CoordinatorID,
			&i.UpdatedAt,
			&i.Node,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Explicit Close (in addition to the defer) surfaces any error from
	// releasing the result set; the deferred Close then becomes a no-op.
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
const getTailnetClientsForAgent = `-- name: GetTailnetClientsForAgent :many
SELECT id, coordinator_id, agent_id, updated_at, node
FROM tailnet_clients
WHERE agent_id = $1
`

// GetTailnetClientsForAgent returns every client row whose agent_id matches,
// across all coordinators.
func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) {
	rows, err := q.db.QueryContext(ctx, getTailnetClientsForAgent, agentID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var items []TailnetClient
	for rows.Next() {
		var i TailnetClient
		if err := rows.Scan(
			&i.ID,
			&i.CoordinatorID,
			&i.AgentID,
			&i.UpdatedAt,
			&i.Node,
		); err != nil {
			return nil, err
		}
		items = append(items, i)
	}
	// Explicit Close (in addition to the defer) surfaces any error from
	// releasing the result set; the deferred Close then becomes a no-op.
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
const upsertTailnetAgent = `-- name: UpsertTailnetAgent :one
INSERT INTO
tailnet_agents (
id,
coordinator_id,
node,
updated_at
)
VALUES
($1, $2, $3, now() at time zone 'utc')
ON CONFLICT (id, coordinator_id)
DO UPDATE SET
id = $1,
coordinator_id = $2,
node = $3,
updated_at = now() at time zone 'utc'
RETURNING id, coordinator_id, updated_at, node
`

// UpsertTailnetAgentParams carries the agent's identity and its serialized
// node (JSON); updated_at is always set server-side to UTC now.
type UpsertTailnetAgentParams struct {
	ID            uuid.UUID       `db:"id" json:"id"`
	CoordinatorID uuid.UUID       `db:"coordinator_id" json:"coordinator_id"`
	Node          json.RawMessage `db:"node" json:"node"`
}

// UpsertTailnetAgent inserts or replaces the agent row keyed by
// (id, coordinator_id) and returns the stored row. Re-setting id and
// coordinator_id in DO UPDATE is redundant with the conflict target but
// harmless (generated query).
func (q *sqlQuerier) UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error) {
	row := q.db.QueryRowContext(ctx, upsertTailnetAgent, arg.ID, arg.CoordinatorID, arg.Node)
	var i TailnetAgent
	err := row.Scan(
		&i.ID,
		&i.CoordinatorID,
		&i.UpdatedAt,
		&i.Node,
	)
	return i, err
}
const upsertTailnetClient = `-- name: UpsertTailnetClient :one
INSERT INTO
tailnet_clients (
id,
coordinator_id,
agent_id,
node,
updated_at
)
VALUES
($1, $2, $3, $4, now() at time zone 'utc')
ON CONFLICT (id, coordinator_id)
DO UPDATE SET
id = $1,
coordinator_id = $2,
agent_id = $3,
node = $4,
updated_at = now() at time zone 'utc'
RETURNING id, coordinator_id, agent_id, updated_at, node
`

// UpsertTailnetClientParams carries the client's identity, the agent it wants
// to reach, and its serialized node (JSON); updated_at is always set
// server-side to UTC now.
type UpsertTailnetClientParams struct {
	ID            uuid.UUID       `db:"id" json:"id"`
	CoordinatorID uuid.UUID       `db:"coordinator_id" json:"coordinator_id"`
	AgentID       uuid.UUID       `db:"agent_id" json:"agent_id"`
	Node          json.RawMessage `db:"node" json:"node"`
}

// UpsertTailnetClient inserts or replaces the client row keyed by
// (id, coordinator_id) and returns the stored row.
func (q *sqlQuerier) UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) {
	row := q.db.QueryRowContext(ctx, upsertTailnetClient,
		arg.ID,
		arg.CoordinatorID,
		arg.AgentID,
		arg.Node,
	)
	var i TailnetClient
	err := row.Scan(
		&i.ID,
		&i.CoordinatorID,
		&i.AgentID,
		&i.UpdatedAt,
		&i.Node,
	)
	return i, err
}
const upsertTailnetCoordinator = `-- name: UpsertTailnetCoordinator :one
INSERT INTO
tailnet_coordinators (
id,
heartbeat_at
)
VALUES
($1, now() at time zone 'utc')
ON CONFLICT (id)
DO UPDATE SET
id = $1,
heartbeat_at = now() at time zone 'utc'
RETURNING id, heartbeat_at
`

// UpsertTailnetCoordinator records a heartbeat for the coordinator with the
// given ID, creating the row if needed; heartbeat_at is always set
// server-side to UTC now.
func (q *sqlQuerier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error) {
	row := q.db.QueryRowContext(ctx, upsertTailnetCoordinator, id)
	var i TailnetCoordinator
	err := row.Scan(&i.ID, &i.HeartbeatAt)
	return i, err
}
const getTemplateAverageBuildTime = `-- name: GetTemplateAverageBuildTime :one
WITH build_times AS (
SELECT

View File

@ -0,0 +1,79 @@
-- name: UpsertTailnetClient :one
INSERT INTO
tailnet_clients (
id,
coordinator_id,
agent_id,
node,
updated_at
)
VALUES
($1, $2, $3, $4, now() at time zone 'utc')
ON CONFLICT (id, coordinator_id)
DO UPDATE SET
id = $1,
coordinator_id = $2,
agent_id = $3,
node = $4,
updated_at = now() at time zone 'utc'
RETURNING *;

-- name: UpsertTailnetAgent :one
INSERT INTO
tailnet_agents (
id,
coordinator_id,
node,
updated_at
)
VALUES
($1, $2, $3, now() at time zone 'utc')
ON CONFLICT (id, coordinator_id)
DO UPDATE SET
id = $1,
coordinator_id = $2,
node = $3,
updated_at = now() at time zone 'utc'
RETURNING *;

-- name: DeleteTailnetClient :one
DELETE
FROM tailnet_clients
WHERE id = $1 and coordinator_id = $2
RETURNING id, coordinator_id;

-- name: DeleteTailnetAgent :one
DELETE
FROM tailnet_agents
WHERE id = $1 and coordinator_id = $2
RETURNING id, coordinator_id;

-- name: DeleteCoordinator :exec
DELETE
FROM tailnet_coordinators
WHERE id = $1;

-- Agents/clients are keyed by (id, coordinator_id), so these lookups can
-- return multiple rows for one ID.

-- name: GetTailnetAgents :many
SELECT *
FROM tailnet_agents
WHERE id = $1;

-- name: GetTailnetClientsForAgent :many
SELECT *
FROM tailnet_clients
WHERE agent_id = $1;

-- heartbeat_at is always set server-side to UTC now.

-- name: UpsertTailnetCoordinator :one
INSERT INTO
tailnet_coordinators (
id,
heartbeat_at
)
VALUES
($1, now() at time zone 'utc')
ON CONFLICT (id)
DO UPDATE SET
id = $1,
heartbeat_at = now() at time zone 'utc'
RETURNING *;

View File

@ -173,6 +173,11 @@ var (
ResourceSystem = Object{
Type: "system",
}
// ResourceTailnetCoordinator is a pseudo-resource for use by the tailnet coordinator
ResourceTailnetCoordinator = Object{
Type: "tailnet_coordinator",
}
)
// Object is used to create objects for authz checks when you have none in

View File

@ -18,6 +18,7 @@ func AllResources() []Object {
ResourceReplicas,
ResourceRoleAssignment,
ResourceSystem,
ResourceTailnetCoordinator,
ResourceTemplate,
ResourceUser,
ResourceUserData,

View File

@ -95,7 +95,7 @@ func TestCoordinatorSingle(t *testing.T) {
assert.NoError(t, err)
close(closeAgentChan)
}()
sendAgentNode(&agpl.Node{})
sendAgentNode(&agpl.Node{PreferredDERP: 1})
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
@ -117,12 +117,12 @@ func TestCoordinatorSingle(t *testing.T) {
}()
agentNodes := <-clientNodeChan
require.Len(t, agentNodes, 1)
sendClientNode(&agpl.Node{})
sendClientNode(&agpl.Node{PreferredDERP: 2})
clientNodes := <-agentNodeChan
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client!
sendAgentNode(&agpl.Node{})
sendAgentNode(&agpl.Node{PreferredDERP: 3})
agentNodes = <-clientNodeChan
require.Len(t, agentNodes, 1)
@ -188,7 +188,7 @@ func TestCoordinatorHA(t *testing.T) {
assert.NoError(t, err)
close(closeAgentChan)
}()
sendAgentNode(&agpl.Node{})
sendAgentNode(&agpl.Node{PreferredDERP: 1})
require.Eventually(t, func() bool {
return coordinator1.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
@ -214,13 +214,13 @@ func TestCoordinatorHA(t *testing.T) {
}()
agentNodes := <-clientNodeChan
require.Len(t, agentNodes, 1)
sendClientNode(&agpl.Node{})
sendClientNode(&agpl.Node{PreferredDERP: 2})
_ = sendClientNode
clientNodes := <-agentNodeChan
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client!
sendAgentNode(&agpl.Node{})
sendAgentNode(&agpl.Node{PreferredDERP: 3})
agentNodes = <-clientNodeChan
require.Len(t, agentNodes, 1)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,655 @@
package tailnet_test
import (
"context"
"database/sql"
"encoding/json"
"io"
"net"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbtestutil"
"github.com/coder/coder/enterprise/tailnet"
agpl "github.com/coder/coder/tailnet"
"github.com/coder/coder/testutil"
)
// TestMain verifies that no goroutines are leaked by any test in this
// package.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}
// TestPGCoordinatorSingle_ClientWithoutAgent connects a client for an agent
// that never connects, and checks that the client's node is persisted to the
// database and cleaned up when the client disconnects.
func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	store, pubsub := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coordinator.Close()
	agentID := uuid.New()
	client := newTestClient(t, coordinator, agentID)
	defer client.close()
	client.sendNode(&agpl.Node{PreferredDERP: 10})
	// The client's node should land in tailnet_clients with the DERP we sent.
	require.Eventually(t, func() bool {
		clients, err := store.GetTailnetClientsForAgent(ctx, agentID)
		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
			t.Fatalf("database error: %v", err)
		}
		if len(clients) == 0 {
			return false
		}
		var node agpl.Node
		err = json.Unmarshal(clients[0].Node, &node)
		assert.NoError(t, err)
		assert.Equal(t, 10, node.PreferredDERP)
		return true
	}, testutil.WaitShort, testutil.IntervalFast)
	err = client.close()
	require.NoError(t, err)
	<-client.errChan
	<-client.closeChan
	// Disconnect should remove the row.
	assertEventuallyNoClientsForAgent(ctx, t, store, agentID)
}
// TestPGCoordinatorSingle_AgentWithoutClients connects an agent with no
// clients, and checks that the agent's node is persisted to the database and
// cleaned up when the agent disconnects.
func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	store, pubsub := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coordinator.Close()
	agent := newTestAgent(t, coordinator)
	defer agent.close()
	agent.sendNode(&agpl.Node{PreferredDERP: 10})
	// The agent's node should land in tailnet_agents with the DERP we sent.
	require.Eventually(t, func() bool {
		agents, err := store.GetTailnetAgents(ctx, agent.id)
		if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
			t.Fatalf("database error: %v", err)
		}
		if len(agents) == 0 {
			return false
		}
		var node agpl.Node
		err = json.Unmarshal(agents[0].Node, &node)
		assert.NoError(t, err)
		assert.Equal(t, 10, node.PreferredDERP)
		return true
	}, testutil.WaitShort, testutil.IntervalFast)
	err = agent.close()
	require.NoError(t, err)
	<-agent.errChan
	<-agent.closeChan
	// Disconnect should remove the row.
	assertEventuallyNoAgents(ctx, t, store, agent.id)
}
// TestPGCoordinatorSingle_AgentWithClient exercises the mainline
// client-agent exchange on a single coordinator: updates flow both ways, an
// agent reconnect immediately receives the client's node, and rapid-fire
// updates converge on the latest value.
func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	store, pubsub := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coordinator.Close()
	agent := newTestAgent(t, coordinator)
	defer agent.close()
	agent.sendNode(&agpl.Node{PreferredDERP: 10})
	client := newTestClient(t, coordinator, agent.id)
	defer client.close()
	agentNodes := client.recvNodes(ctx, t)
	require.Len(t, agentNodes, 1)
	assert.Equal(t, 10, agentNodes[0].PreferredDERP)
	client.sendNode(&agpl.Node{PreferredDERP: 11})
	clientNodes := agent.recvNodes(ctx, t)
	require.Len(t, clientNodes, 1)
	assert.Equal(t, 11, clientNodes[0].PreferredDERP)

	// Ensure an update to the agent node reaches the client!
	agent.sendNode(&agpl.Node{PreferredDERP: 12})
	agentNodes = client.recvNodes(ctx, t)
	require.Len(t, agentNodes, 1)
	assert.Equal(t, 12, agentNodes[0].PreferredDERP)

	// Close the agent WebSocket so a new one can connect.
	err = agent.close()
	require.NoError(t, err)
	_ = agent.recvErr(ctx, t)
	agent.waitForClose(ctx, t)

	// Create a new agent connection. This is to simulate a reconnect!
	agent = newTestAgent(t, coordinator, agent.id)
	// Ensure the existing client's node is sent to the new agent immediately!
	clientNodes = agent.recvNodes(ctx, t)
	require.Len(t, clientNodes, 1)
	assert.Equal(t, 11, clientNodes[0].PreferredDERP)

	// Send a bunch of updates in rapid succession, and test that we eventually get the latest. We don't want the
	// coordinator accidentally reordering things.
	for d := 13; d < 36; d++ {
		agent.sendNode(&agpl.Node{PreferredDERP: d})
	}
	for {
		nodes := client.recvNodes(ctx, t)
		if !assert.Len(t, nodes, 1) {
			break
		}
		if nodes[0].PreferredDERP == 35 {
			// got latest!
			break
		}
	}
	err = agent.close()
	require.NoError(t, err)
	_ = agent.recvErr(ctx, t)
	agent.waitForClose(ctx, t)
	err = client.close()
	require.NoError(t, err)
	_ = client.recvErr(ctx, t)
	client.waitForClose(ctx, t)
	// Both rows should be cleaned up after disconnect.
	assertEventuallyNoAgents(ctx, t, store, agent.id)
	assertEventuallyNoClientsForAgent(ctx, t, store, agent.id)
}
// TestPGCoordinatorSingle_MissedHeartbeats verifies that when a peer
// coordinator stops heart-beating, the remaining coordinator expires the
// dead coordinator's data and reverts clients to the surviving agent node.
func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	store, pubsub := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coordinator.Close()
	agent := newTestAgent(t, coordinator)
	defer agent.close()
	agent.sendNode(&agpl.Node{PreferredDERP: 10})
	client := newTestClient(t, coordinator, agent.id)
	defer client.close()
	nodes := client.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 10)
	client.sendNode(&agpl.Node{PreferredDERP: 11})
	nodes = agent.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 11)

	// simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a
	// real coordinator
	fCoord := &fakeCoordinator{
		ctx:   ctx,
		t:     t,
		store: store,
		id:    uuid.New(),
	}
	start := time.Now()
	fCoord.heartbeat()
	// The fake coordinator writes a newer node for the agent (DERP 12).
	fCoord.agentNode(agent.id, &agpl.Node{PreferredDERP: 12})
	nodes = client.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 12)

	// when the fake coordinator misses enough heartbeats, the real coordinator should send an update with the old
	// node for the agent.
	nodes = client.recvNodes(ctx, t)
	assert.Greater(t, time.Since(start), tailnet.HeartbeatPeriod*tailnet.MissedHeartbeats)
	assertHasDERPs(t, nodes, 10)
	err = agent.close()
	require.NoError(t, err)
	_ = agent.recvErr(ctx, t)
	agent.waitForClose(ctx, t)
	err = client.close()
	require.NoError(t, err)
	_ = client.recvErr(ctx, t)
	client.waitForClose(ctx, t)
	assertEventuallyNoClientsForAgent(ctx, t, store, agent.id)
}
// TestPGCoordinatorSingle_SendsHeartbeats subscribes to the heartbeat pubsub
// event and checks the coordinator publishes heartbeats promptly on startup
// and then periodically (roughly every HeartbeatPeriod).
func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	store, pubsub := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)

	// Record heartbeat arrival times; mu guards heartbeats because the pubsub
	// callback runs concurrently with the Eventually polling below.
	mu := sync.Mutex{}
	heartbeats := []time.Time{}
	unsub, err := pubsub.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) {
		assert.NoError(t, err)
		mu.Lock()
		defer mu.Unlock()
		heartbeats = append(heartbeats, time.Now())
	})
	require.NoError(t, err)
	defer unsub()
	start := time.Now()
	coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coordinator.Close()
	require.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		if len(heartbeats) < 2 {
			return false
		}
		require.Greater(t, heartbeats[0].Sub(start), time.Duration(0))
		require.Greater(t, heartbeats[1].Sub(start), time.Duration(0))
		// Allow 10% scheduling slop on the inter-heartbeat gap.
		return assert.Greater(t, heartbeats[1].Sub(heartbeats[0]), tailnet.HeartbeatPeriod*9/10)
	}, testutil.WaitMedium, testutil.IntervalMedium)
}
// TestPGCoordinatorDual_Mainline tests with 2 coordinators, one agent connected to each, and 2 clients per agent.
//
// +---------+
// agent1 ---> | coord1 | <--- client11 (coord 1, agent 1)
// | |
// | | <--- client12 (coord 1, agent 2)
// +---------+
// +---------+
// agent2 ---> | coord2 | <--- client21 (coord 2, agent 1)
// | |
// | | <--- client22 (coord2, agent 2)
// +---------+
func TestPGCoordinatorDual_Mainline(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	store, pubsub := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coord1.Close()
	coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coord2.Close()

	agent1 := newTestAgent(t, coord1)
	defer agent1.close()
	agent2 := newTestAgent(t, coord2)
	defer agent2.close()
	// Client naming: client<coordinator><agent>.
	client11 := newTestClient(t, coord1, agent1.id)
	defer client11.close()
	client12 := newTestClient(t, coord1, agent2.id)
	defer client12.close()
	client21 := newTestClient(t, coord2, agent1.id)
	defer client21.close()
	client22 := newTestClient(t, coord2, agent2.id)
	defer client22.close()

	// Updates must cross coordinators in both directions.
	client11.sendNode(&agpl.Node{PreferredDERP: 11})
	nodes := agent1.recvNodes(ctx, t)
	assert.Len(t, nodes, 1)
	assertHasDERPs(t, nodes, 11)
	client21.sendNode(&agpl.Node{PreferredDERP: 21})
	nodes = agent1.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 21, 11)
	client22.sendNode(&agpl.Node{PreferredDERP: 22})
	nodes = agent2.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 22)
	agent2.sendNode(&agpl.Node{PreferredDERP: 2})
	nodes = client22.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 2)
	nodes = client12.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 2)
	client12.sendNode(&agpl.Node{PreferredDERP: 12})
	nodes = agent2.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 12, 22)
	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
	nodes = client21.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 1)
	nodes = client11.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 1)

	// let's close coord2
	err = coord2.Close()
	require.NoError(t, err)
	// this closes agent2, client22, client21
	err = agent2.recvErr(ctx, t)
	require.ErrorIs(t, err, io.EOF)
	err = client22.recvErr(ctx, t)
	require.ErrorIs(t, err, io.EOF)
	err = client21.recvErr(ctx, t)
	require.ErrorIs(t, err, io.EOF)
	// agent1 will see an update that drops client21.
	// In this case the update is superfluous because client11's node hasn't changed, and agents don't deprogram clients
	// from the dataplane even if they are missing. Suppressing this kind of update would require the coordinator to
	// store all the data its sent to each connection, so we don't bother.
	nodes = agent1.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 11)

	// note that although agent2 is disconnected, client12 does NOT get an update because we suppress empty updates.
	// (Its easy to tell these are superfluous.)
	assertEventuallyNoAgents(ctx, t, store, agent2.id)

	// Close coord1
	err = coord1.Close()
	require.NoError(t, err)
	// this closes agent1, client12, client11
	err = agent1.recvErr(ctx, t)
	require.ErrorIs(t, err, io.EOF)
	err = client12.recvErr(ctx, t)
	require.ErrorIs(t, err, io.EOF)
	err = client11.recvErr(ctx, t)
	require.ErrorIs(t, err, io.EOF)

	// wait for all connections to close
	err = agent1.close()
	require.NoError(t, err)
	agent1.waitForClose(ctx, t)
	err = agent2.close()
	require.NoError(t, err)
	agent2.waitForClose(ctx, t)
	err = client11.close()
	require.NoError(t, err)
	client11.waitForClose(ctx, t)
	err = client12.close()
	require.NoError(t, err)
	client12.waitForClose(ctx, t)
	err = client21.close()
	require.NoError(t, err)
	client21.waitForClose(ctx, t)
	err = client22.close()
	require.NoError(t, err)
	client22.waitForClose(ctx, t)
	assertEventuallyNoAgents(ctx, t, store, agent1.id)
	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
	assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id)
}
// TestPGCoordinator_MultiAgent tests when a single agent connects to multiple coordinators.
// We use two agent connections, but they share the same AgentID. This could happen due to a reconnection,
// or an infrastructure problem where an old workspace is not fully cleaned up before a new one started.
//
// +---------+
// agent1 ---> | coord1 |
// +---------+
// +---------+
// agent2 ---> | coord2 |
// +---------+
// +---------+
// | coord3 | <--- client
// +---------+
func TestPGCoordinator_MultiAgent(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	store, pubsub := dbtestutil.NewDB(t)
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
	defer cancel()
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coord1.Close()
	coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coord2.Close()
	coord3, err := tailnet.NewPGCoord(ctx, logger, pubsub, store)
	require.NoError(t, err)
	defer coord3.Close()

	agent1 := newTestAgent(t, coord1)
	defer agent1.close()
	// agent2 shares agent1's ID, simulating a duplicate agent connection.
	agent2 := newTestAgent(t, coord2, agent1.id)
	defer agent2.close()
	client := newTestClient(t, coord3, agent1.id)
	defer client.close()

	// The client's node must reach both agent connections.
	client.sendNode(&agpl.Node{PreferredDERP: 3})
	nodes := agent1.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 3)
	nodes = agent2.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 3)
	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
	nodes = client.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 1)

	// agent2's update overrides agent1 because it is newer
	agent2.sendNode(&agpl.Node{PreferredDERP: 2})
	nodes = client.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 2)

	// agent2 disconnects, and we should revert back to agent1
	err = agent2.close()
	require.NoError(t, err)
	err = agent2.recvErr(ctx, t)
	require.ErrorIs(t, err, io.ErrClosedPipe)
	agent2.waitForClose(ctx, t)
	nodes = client.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 1)
	agent1.sendNode(&agpl.Node{PreferredDERP: 11})
	nodes = client.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 11)
	client.sendNode(&agpl.Node{PreferredDERP: 31})
	nodes = agent1.recvNodes(ctx, t)
	assertHasDERPs(t, nodes, 31)

	err = agent1.close()
	require.NoError(t, err)
	err = agent1.recvErr(ctx, t)
	require.ErrorIs(t, err, io.ErrClosedPipe)
	agent1.waitForClose(ctx, t)
	err = client.close()
	require.NoError(t, err)
	err = client.recvErr(ctx, t)
	require.ErrorIs(t, err, io.ErrClosedPipe)
	client.waitForClose(ctx, t)
	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
	assertEventuallyNoAgents(ctx, t, store, agent1.id)
}
// testConn is a fake client or agent connection for exercising a coordinator
// in tests. The coordinator serves serverWS while the test drives ws.
type testConn struct {
	ws, serverWS net.Conn
	nodeChan     chan []*agpl.Node
	sendNode     func(node *agpl.Node)
	errChan      <-chan error
	id           uuid.UUID
	// closeChan is closed when the coordinator's serve goroutine exits.
	closeChan chan struct{}
}
// newTestConn builds a testConn over a net.Pipe. At most one ID may be
// supplied; if none is, a fresh UUID is generated.
func newTestConn(ids []uuid.UUID) *testConn {
	conn := &testConn{
		nodeChan:  make(chan []*agpl.Node),
		closeChan: make(chan struct{}),
	}
	conn.ws, conn.serverWS = net.Pipe()
	conn.sendNode, conn.errChan = agpl.ServeCoordinator(conn.ws, func(nodes []*agpl.Node) error {
		conn.nodeChan <- nodes
		return nil
	})
	switch len(ids) {
	case 0:
		conn.id = uuid.New()
	case 1:
		conn.id = ids[0]
	default:
		panic("too many")
	}
	return conn
}
// newTestAgent starts serving a fake agent connection on the coordinator,
// closing closeChan when ServeAgent returns. An optional ID may be supplied.
func newTestAgent(t *testing.T, coord agpl.Coordinator, id ...uuid.UUID) *testConn {
	a := newTestConn(id)
	go func() {
		err := coord.ServeAgent(a.serverWS, a.id, "")
		assert.NoError(t, err)
		close(a.closeChan)
	}()
	return a
}
// close closes the test-side end of the pipe; the coordinator's serve loop
// then observes the closure on its end.
func (c *testConn) close() error {
	return c.ws.Close()
}
// recvNodes blocks until the next batch of node updates arrives, failing the
// test if ctx expires first.
func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node {
	t.Helper()
	select {
	case batch := <-c.nodeChan:
		return batch
	case <-ctx.Done():
		t.Fatal("timeout receiving nodes")
		return nil
	}
}
// recvErr blocks until the connection's error channel yields, failing the
// test if ctx expires first.
func (c *testConn) recvErr(ctx context.Context, t *testing.T) error {
	t.Helper()
	select {
	case err := <-c.errChan:
		return err
	case <-ctx.Done():
		t.Fatal("timeout receiving error")
		return ctx.Err()
	}
}
// waitForClose blocks until the coordinator's serve goroutine for this
// connection exits, failing the test if ctx expires first.
func (c *testConn) waitForClose(ctx context.Context, t *testing.T) {
	t.Helper()
	select {
	case <-c.closeChan:
	case <-ctx.Done():
		t.Fatal("timeout waiting for connection to close")
	}
}
// newTestClient starts serving a fake client connection (targeting agentID)
// on the coordinator, closing closeChan when ServeClient returns. An optional
// client ID may be supplied.
func newTestClient(t *testing.T, coord agpl.Coordinator, agentID uuid.UUID, id ...uuid.UUID) *testConn {
	c := newTestConn(id)
	go func() {
		err := coord.ServeClient(c.serverWS, c.id, agentID)
		assert.NoError(t, err)
		close(c.closeChan)
	}()
	return c
}
// assertHasDERPs asserts that nodes has exactly len(expected) entries and
// that every expected PreferredDERP value appears among them (in any order).
func assertHasDERPs(t *testing.T, nodes []*agpl.Node, expected ...int) {
	if !assert.Len(t, nodes, len(expected), "expected %d node(s), got %d", len(expected), len(nodes)) {
		return
	}
	got := make([]int, 0, len(nodes))
	for _, node := range nodes {
		got = append(got, node.PreferredDERP)
	}
	for _, want := range expected {
		assert.Contains(t, got, want, "expected DERP %v, got %v", want, got)
	}
}
// assertEventuallyNoAgents polls until no tailnet_agents rows remain for
// agentID (a missing row, i.e. sql.ErrNoRows, also counts as gone).
func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
	t.Helper()
	assert.Eventually(t, func() bool {
		agents, err := store.GetTailnetAgents(ctx, agentID)
		if xerrors.Is(err, sql.ErrNoRows) {
			return true
		}
		if err != nil {
			// Do not call t.Fatal here: assert.Eventually runs this condition
			// on a separate goroutine, and FailNow/Fatal are only safe from
			// the goroutine running the test function.
			t.Error(err)
			return false
		}
		return len(agents) == 0
	}, testutil.WaitShort, testutil.IntervalFast)
}
// assertEventuallyNoClientsForAgent polls until no tailnet_clients rows
// remain for agentID (a missing row, i.e. sql.ErrNoRows, also counts as
// gone).
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
	t.Helper()
	assert.Eventually(t, func() bool {
		clients, err := store.GetTailnetClientsForAgent(ctx, agentID)
		if xerrors.Is(err, sql.ErrNoRows) {
			return true
		}
		if err != nil {
			// Do not call t.Fatal here: assert.Eventually runs this condition
			// on a separate goroutine, and FailNow/Fatal are only safe from
			// the goroutine running the test function.
			t.Error(err)
			return false
		}
		return len(clients) == 0
	}, testutil.WaitShort, testutil.IntervalFast)
}
// fakeCoordinator simulates a peer coordinator purely through database
// writes, so tests can control (and break) its heartbeat behavior.
type fakeCoordinator struct {
	ctx   context.Context
	t     *testing.T
	store database.Store
	id    uuid.UUID
}
// heartbeat upserts this fake coordinator's heartbeat row in the database.
func (c *fakeCoordinator) heartbeat() {
	c.t.Helper()
	_, err := c.store.UpsertTailnetCoordinator(c.ctx, c.id)
	require.NoError(c.t, err)
}
// agentNode writes an agent node, attributed to this fake coordinator,
// directly into the database (marshaled to JSON).
func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) {
	c.t.Helper()
	nodeRaw, err := json.Marshal(node)
	require.NoError(c.t, err)
	_, err = c.store.UpsertTailnetAgent(c.ctx, database.UpsertTailnetAgentParams{
		ID:            agentID,
		CoordinatorID: c.id,
		Node:          nodeRaw,
	})
	require.NoError(c.t, err)
}

View File

@ -1,6 +1,7 @@
package tailnet
import (
"bytes"
"context"
"encoding/json"
"errors"
@ -174,11 +175,12 @@ func newCore(logger slog.Logger) *core {
var ErrWouldBlock = xerrors.New("would block")
type TrackedConn struct {
ctx context.Context
cancel func()
conn net.Conn
updates chan []*Node
logger slog.Logger
ctx context.Context
cancel func()
conn net.Conn
updates chan []*Node
logger slog.Logger
lastData []byte
// ID is an ephemeral UUID used to uniquely identify the owner of the
// connection.
@ -224,6 +226,10 @@ func (t *TrackedConn) SendUpdates() {
t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes))
return
}
if bytes.Equal(t.lastData, data) {
t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", nodes))
continue
}
// Set a deadline so that hung connections don't put back pressure on the system.
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
@ -255,6 +261,7 @@ func (t *TrackedConn) SendUpdates() {
_ = t.Close()
return
}
t.lastData = data
}
}
}

View File

@ -96,7 +96,7 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeAgentChan)
}()
sendAgentNode(&tailnet.Node{})
sendAgentNode(&tailnet.Node{PreferredDERP: 1})
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
@ -122,7 +122,7 @@ func TestCoordinator(t *testing.T) {
case <-ctx.Done():
t.Fatal("timed out")
}
sendClientNode(&tailnet.Node{})
sendClientNode(&tailnet.Node{PreferredDERP: 2})
clientNodes := <-agentNodeChan
require.Len(t, clientNodes, 1)
@ -131,7 +131,7 @@ func TestCoordinator(t *testing.T) {
time.Sleep(tailnet.WriteTimeout * 3 / 2)
// Ensure an update to the agent node reaches the client!
sendAgentNode(&tailnet.Node{})
sendAgentNode(&tailnet.Node{PreferredDERP: 3})
select {
case agentNodes := <-clientNodeChan:
require.Len(t, agentNodes, 1)
@ -193,7 +193,7 @@ func TestCoordinator(t *testing.T) {
assert.NoError(t, err)
close(closeAgentChan1)
}()
sendAgentNode1(&tailnet.Node{})
sendAgentNode1(&tailnet.Node{PreferredDERP: 1})
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
@ -215,12 +215,12 @@ func TestCoordinator(t *testing.T) {
}()
agentNodes := <-clientNodeChan
require.Len(t, agentNodes, 1)
sendClientNode(&tailnet.Node{})
sendClientNode(&tailnet.Node{PreferredDERP: 2})
clientNodes := <-agentNodeChan1
require.Len(t, clientNodes, 1)
// Ensure an update to the agent node reaches the client!
sendAgentNode1(&tailnet.Node{})
sendAgentNode1(&tailnet.Node{PreferredDERP: 3})
agentNodes = <-clientNodeChan
require.Len(t, agentNodes, 1)