feat: add single tailnet support to pgcoord (#9351)

This commit is contained in:
Colin Adler 2023-09-21 15:30:48 -04:00 committed by GitHub
parent fbad06f406
commit c900b5f8df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1647 additions and 293 deletions

View File

@ -694,6 +694,13 @@ func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) e
return q.db.DeleteAPIKeysByUserID(ctx, userID)
}
// DeleteAllTailnetClientSubscriptions deletes every agent subscription held
// by a tailnet client on a coordinator, after checking that the caller is
// authorized to delete tailnet coordinator resources.
func (q *querier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error {
	err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator)
	if err != nil {
		return err
	}
	return q.db.DeleteAllTailnetClientSubscriptions(ctx, arg)
}
func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error {
// TODO: This is not 100% correct because it omits apikey IDs.
err := q.authorizeContext(ctx, rbac.ActionDelete,
@ -783,6 +790,13 @@ func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTa
return q.db.DeleteTailnetClient(ctx, arg)
}
// DeleteTailnetClientSubscription deletes a single client→agent subscription,
// after checking that the caller is authorized to delete tailnet coordinator
// resources.
func (q *querier) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) error {
	err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator)
	if err != nil {
		return err
	}
	return q.db.DeleteTailnetClientSubscription(ctx, arg)
}
func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id)
}
@ -825,9 +839,9 @@ func (q *querier) GetAllTailnetAgents(ctx context.Context) ([]database.TailnetAg
return q.db.GetAllTailnetAgents(ctx)
}
func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) {
func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.GetAllTailnetClientsRow, error) {
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
return []database.TailnetClient{}, err
return []database.GetAllTailnetClientsRow{}, err
}
return q.db.GetAllTailnetClients(ctx)
}
@ -2794,6 +2808,13 @@ func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTa
return q.db.UpsertTailnetClient(ctx, arg)
}
// UpsertTailnetClientSubscription inserts or refreshes a client→agent
// subscription, after checking that the caller is authorized to update
// tailnet coordinator resources.
func (q *querier) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error {
	err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator)
	if err != nil {
		return err
	}
	return q.db.UpsertTailnetClientSubscription(ctx, arg)
}
func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) {
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil {
return database.TailnetCoordinator{}, err

View File

@ -854,6 +854,15 @@ func (q *FakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID)
return nil
}
// DeleteAllTailnetClientSubscriptions is not implemented by the in-memory
// fake; it validates the argument shape and then reports ErrUnimplemented.
func (*FakeQuerier) DeleteAllTailnetClientSubscriptions(_ context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error {
	if err := validateDatabaseType(arg); err != nil {
		return err
	}
	return ErrUnimplemented
}
func (q *FakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, userID uuid.UUID) error {
q.mutex.Lock()
defer q.mutex.Unlock()
@ -987,6 +996,10 @@ func (*FakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetC
return database.DeleteTailnetClientRow{}, ErrUnimplemented
}
// DeleteTailnetClientSubscription is not implemented by the in-memory fake.
func (*FakeQuerier) DeleteTailnetClientSubscription(context.Context, database.DeleteTailnetClientSubscriptionParams) error {
	return ErrUnimplemented
}
func (q *FakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -1102,7 +1115,7 @@ func (*FakeQuerier) GetAllTailnetAgents(_ context.Context) ([]database.TailnetAg
return nil, ErrUnimplemented
}
func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.TailnetClient, error) {
func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.GetAllTailnetClientsRow, error) {
return nil, ErrUnimplemented
}
@ -6112,6 +6125,10 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC
return database.TailnetClient{}, ErrUnimplemented
}
// UpsertTailnetClientSubscription is not implemented by the in-memory fake.
func (*FakeQuerier) UpsertTailnetClientSubscription(context.Context, database.UpsertTailnetClientSubscriptionParams) error {
	return ErrUnimplemented
}
// UpsertTailnetCoordinator is not implemented by the in-memory fake.
func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) {
	return database.TailnetCoordinator{}, ErrUnimplemented
}

View File

@ -128,6 +128,13 @@ func (m metricsStore) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUI
return err
}
// DeleteAllTailnetClientSubscriptions delegates to the wrapped store and
// records the query latency under the method's name.
func (m metricsStore) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error {
	start := time.Now()
	err := m.s.DeleteAllTailnetClientSubscriptions(ctx, arg)
	elapsed := time.Since(start).Seconds()
	m.queryLatencies.WithLabelValues("DeleteAllTailnetClientSubscriptions").Observe(elapsed)
	return err
}
func (m metricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error {
start := time.Now()
err := m.s.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
@ -209,6 +216,13 @@ func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.Dele
return m.s.DeleteTailnetClient(ctx, arg)
}
// DeleteTailnetClientSubscription delegates to the wrapped store and records
// the query latency under the method's name.
func (m metricsStore) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) error {
	start := time.Now()
	err := m.s.DeleteTailnetClientSubscription(ctx, arg)
	elapsed := time.Since(start).Seconds()
	m.queryLatencies.WithLabelValues("DeleteTailnetClientSubscription").Observe(elapsed)
	return err
}
func (m metricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
start := time.Now()
apiKey, err := m.s.GetAPIKeyByID(ctx, id)
@ -265,7 +279,7 @@ func (m metricsStore) GetAllTailnetAgents(ctx context.Context) ([]database.Tailn
return r0, r1
}
func (m metricsStore) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) {
func (m metricsStore) GetAllTailnetClients(ctx context.Context) ([]database.GetAllTailnetClientsRow, error) {
start := time.Now()
r0, r1 := m.s.GetAllTailnetClients(ctx)
m.queryLatencies.WithLabelValues("GetAllTailnetClients").Observe(time.Since(start).Seconds())
@ -1752,6 +1766,13 @@ func (m metricsStore) UpsertTailnetClient(ctx context.Context, arg database.Upse
return m.s.UpsertTailnetClient(ctx, arg)
}
// UpsertTailnetClientSubscription delegates to the wrapped store and records
// the query latency under the method's name.
func (m metricsStore) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error {
	start := time.Now()
	err := m.s.UpsertTailnetClientSubscription(ctx, arg)
	elapsed := time.Since(start).Seconds()
	m.queryLatencies.WithLabelValues("UpsertTailnetClientSubscription").Observe(elapsed)
	return err
}
func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) {
start := time.Now()
defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds())

View File

@ -139,6 +139,20 @@ func (mr *MockStoreMockRecorder) DeleteAPIKeysByUserID(arg0, arg1 interface{}) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteAPIKeysByUserID), arg0, arg1)
}
// DeleteAllTailnetClientSubscriptions mocks base method.
// NOTE(review): gomock-generated code (ctrl.Call / MockRecorder pattern) —
// regenerate with mockgen rather than editing by hand.
func (m *MockStore) DeleteAllTailnetClientSubscriptions(arg0 context.Context, arg1 database.DeleteAllTailnetClientSubscriptionsParams) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "DeleteAllTailnetClientSubscriptions", arg0, arg1)
	ret0, _ := ret[0].(error)
	return ret0
}

// DeleteAllTailnetClientSubscriptions indicates an expected call of DeleteAllTailnetClientSubscriptions.
func (mr *MockStoreMockRecorder) DeleteAllTailnetClientSubscriptions(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllTailnetClientSubscriptions", reflect.TypeOf((*MockStore)(nil).DeleteAllTailnetClientSubscriptions), arg0, arg1)
}
// DeleteApplicationConnectAPIKeysByUserID mocks base method.
func (m *MockStore) DeleteApplicationConnectAPIKeysByUserID(arg0 context.Context, arg1 uuid.UUID) error {
m.ctrl.T.Helper()
@ -310,6 +324,20 @@ func (mr *MockStoreMockRecorder) DeleteTailnetClient(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClient", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClient), arg0, arg1)
}
// DeleteTailnetClientSubscription mocks base method.
// NOTE(review): gomock-generated code (ctrl.Call / MockRecorder pattern) —
// regenerate with mockgen rather than editing by hand.
func (m *MockStore) DeleteTailnetClientSubscription(arg0 context.Context, arg1 database.DeleteTailnetClientSubscriptionParams) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "DeleteTailnetClientSubscription", arg0, arg1)
	ret0, _ := ret[0].(error)
	return ret0
}

// DeleteTailnetClientSubscription indicates an expected call of DeleteTailnetClientSubscription.
func (mr *MockStoreMockRecorder) DeleteTailnetClientSubscription(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClientSubscription", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClientSubscription), arg0, arg1)
}
// GetAPIKeyByID mocks base method.
func (m *MockStore) GetAPIKeyByID(arg0 context.Context, arg1 string) (database.APIKey, error) {
m.ctrl.T.Helper()
@ -431,10 +459,10 @@ func (mr *MockStoreMockRecorder) GetAllTailnetAgents(arg0 interface{}) *gomock.C
}
// GetAllTailnetClients mocks base method.
func (m *MockStore) GetAllTailnetClients(arg0 context.Context) ([]database.TailnetClient, error) {
func (m *MockStore) GetAllTailnetClients(arg0 context.Context) ([]database.GetAllTailnetClientsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllTailnetClients", arg0)
ret0, _ := ret[0].([]database.TailnetClient)
ret0, _ := ret[0].([]database.GetAllTailnetClientsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@ -3681,6 +3709,20 @@ func (mr *MockStoreMockRecorder) UpsertTailnetClient(arg0, arg1 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClient", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClient), arg0, arg1)
}
// UpsertTailnetClientSubscription mocks base method.
// NOTE(review): gomock-generated code (ctrl.Call / MockRecorder pattern) —
// regenerate with mockgen rather than editing by hand.
func (m *MockStore) UpsertTailnetClientSubscription(arg0 context.Context, arg1 database.UpsertTailnetClientSubscriptionParams) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "UpsertTailnetClientSubscription", arg0, arg1)
	ret0, _ := ret[0].(error)
	return ret0
}

// UpsertTailnetClientSubscription indicates an expected call of UpsertTailnetClientSubscription.
func (mr *MockStoreMockRecorder) UpsertTailnetClientSubscription(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClientSubscription", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClientSubscription), arg0, arg1)
}
// UpsertTailnetCoordinator mocks base method.
func (m *MockStore) UpsertTailnetCoordinator(arg0 context.Context, arg1 uuid.UUID) (database.TailnetCoordinator, error) {
m.ctrl.T.Helper()

View File

@ -219,13 +219,57 @@ $$;
CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger
LANGUAGE plpgsql
AS $$
DECLARE
var_client_id uuid;
var_coordinator_id uuid;
var_agent_ids uuid[];
var_agent_id uuid;
BEGIN
IF (OLD IS NOT NULL) THEN
PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id);
RETURN NULL;
IF (NEW.id IS NOT NULL) THEN
var_client_id = NEW.id;
var_coordinator_id = NEW.coordinator_id;
ELSIF (OLD.id IS NOT NULL) THEN
var_client_id = OLD.id;
var_coordinator_id = OLD.coordinator_id;
END IF;
-- Read all agents the client is subscribed to, so we can notify them.
SELECT
array_agg(agent_id)
INTO
var_agent_ids
FROM
tailnet_client_subscriptions subs
WHERE
subs.client_id = NEW.id AND
subs.coordinator_id = NEW.coordinator_id;
-- No agents to notify
if (var_agent_ids IS NULL) THEN
return NULL;
END IF;
-- pg_notify is limited to 8k bytes, which is approximately 221 UUIDs.
-- Instead of sending all agent ids in a single update, send one for each
-- agent id to prevent overflow.
FOREACH var_agent_id IN ARRAY var_agent_ids
LOOP
PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || var_agent_id);
END LOOP;
return NULL;
END;
$$;
CREATE FUNCTION tailnet_notify_client_subscription_change() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
IF (NEW IS NOT NULL) THEN
PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id);
PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id);
RETURN NULL;
ELSIF (OLD IS NOT NULL) THEN
PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id);
RETURN NULL;
END IF;
END;
@ -495,10 +539,16 @@ CREATE TABLE tailnet_agents (
node jsonb NOT NULL
);
CREATE TABLE tailnet_client_subscriptions (
client_id uuid NOT NULL,
coordinator_id uuid NOT NULL,
agent_id uuid NOT NULL,
updated_at timestamp with time zone NOT NULL
);
CREATE TABLE tailnet_clients (
id uuid NOT NULL,
coordinator_id uuid NOT NULL,
agent_id uuid NOT NULL,
updated_at timestamp with time zone NOT NULL,
node jsonb NOT NULL
);
@ -1144,6 +1194,9 @@ ALTER TABLE ONLY site_configs
ALTER TABLE ONLY tailnet_agents
ADD CONSTRAINT tailnet_agents_pkey PRIMARY KEY (id, coordinator_id);
ALTER TABLE ONLY tailnet_client_subscriptions
ADD CONSTRAINT tailnet_client_subscriptions_pkey PRIMARY KEY (client_id, coordinator_id, agent_id);
ALTER TABLE ONLY tailnet_clients
ADD CONSTRAINT tailnet_clients_pkey PRIMARY KEY (id, coordinator_id);
@ -1248,8 +1301,6 @@ CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lo
CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id);
CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients USING btree (agent_id);
CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients USING btree (coordinator_id);
CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false);
@ -1284,6 +1335,8 @@ CREATE TRIGGER tailnet_notify_agent_change AFTER INSERT OR DELETE OR UPDATE ON t
CREATE TRIGGER tailnet_notify_client_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_clients FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_change();
CREATE TRIGGER tailnet_notify_client_subscription_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_client_subscriptions FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_subscription_change();
CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION tailnet_notify_coordinator_heartbeat();
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
@ -1329,6 +1382,9 @@ ALTER TABLE ONLY provisioner_jobs
ALTER TABLE ONLY tailnet_agents
ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
ALTER TABLE ONLY tailnet_client_subscriptions
ADD CONSTRAINT tailnet_client_subscriptions_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
ALTER TABLE ONLY tailnet_clients
ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;

View File

@ -19,6 +19,7 @@ const (
ForeignKeyProvisionerJobLogsJobID ForeignKeyConstraint = "provisioner_job_logs_job_id_fkey" // ALTER TABLE ONLY provisioner_job_logs ADD CONSTRAINT provisioner_job_logs_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
ForeignKeyProvisionerJobsOrganizationID ForeignKeyConstraint = "provisioner_jobs_organization_id_fkey" // ALTER TABLE ONLY provisioner_jobs ADD CONSTRAINT provisioner_jobs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ForeignKeyTailnetAgentsCoordinatorID ForeignKeyConstraint = "tailnet_agents_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_agents ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
ForeignKeyTailnetClientSubscriptionsCoordinatorID ForeignKeyConstraint = "tailnet_client_subscriptions_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_client_subscriptions ADD CONSTRAINT tailnet_client_subscriptions_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
ForeignKeyTailnetClientsCoordinatorID ForeignKeyConstraint = "tailnet_clients_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_clients ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
ForeignKeyTemplateVersionParametersTemplateVersionID ForeignKeyConstraint = "template_version_parameters_template_version_id_fkey" // ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE;
ForeignKeyTemplateVersionVariablesTemplateVersionID ForeignKeyConstraint = "template_version_variables_template_version_id_fkey" // ALTER TABLE ONLY template_version_variables ADD CONSTRAINT template_version_variables_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE;

View File

@ -0,0 +1,39 @@
-- Down migration: revert single-tailnet support by restoring the agent_id
-- column on tailnet_clients and dropping the subscriptions machinery.
BEGIN;

ALTER TABLE
	tailnet_clients
ADD COLUMN
	agent_id uuid;

UPDATE
	tailnet_clients
SET
	-- there's no reason for us to try and preserve data since coordinators will
	-- have to restart anyways, which will create all of the client mappings.
	agent_id = '00000000-0000-0000-0000-000000000000'::uuid;

ALTER TABLE
	tailnet_clients
ALTER COLUMN
	agent_id SET NOT NULL;

-- Dropping the table also drops its trigger; the function must go separately.
DROP TABLE tailnet_client_subscriptions;
DROP FUNCTION tailnet_notify_client_subscription_change;

-- update the tailnet_clients trigger to the old version.
CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
	-- Payload format is "<client id>,<agent id>"; OLD covers DELETE/UPDATE,
	-- NEW covers INSERT.
	IF (OLD IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id);
		RETURN NULL;
	END IF;
	IF (NEW IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id);
		RETURN NULL;
	END IF;
END;
$$;

COMMIT;

View File

@ -0,0 +1,88 @@
-- Up migration: single-tailnet support — clients subscribe to many agents via
-- tailnet_client_subscriptions instead of a single agent_id column.
BEGIN;

CREATE TABLE tailnet_client_subscriptions (
	client_id uuid NOT NULL,
	coordinator_id uuid NOT NULL,
	-- this isn't a foreign key since it's more of a list of agents the client
	-- *wants* to connect to, and they don't necessarily have to currently
	-- exist in the db.
	agent_id uuid NOT NULL,
	updated_at timestamp with time zone NOT NULL,
	PRIMARY KEY (client_id, coordinator_id, agent_id),
	FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators (id) ON DELETE CASCADE
	-- we don't keep a foreign key to the tailnet_clients table since there's
	-- not a great way to guarantee that a subscription is always added after
	-- the client is inserted. clients are only created after the client sends
	-- its first node update, which can take an undetermined amount of time.
);

CREATE FUNCTION tailnet_notify_client_subscription_change() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
	-- Payload is "<client id>,<agent id>". NEW covers INSERT/UPDATE,
	-- OLD covers DELETE.
	IF (NEW IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id);
		RETURN NULL;
	ELSIF (OLD IS NOT NULL) THEN
		PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id);
		RETURN NULL;
	END IF;
END;
$$;

CREATE TRIGGER tailnet_notify_client_subscription_change
	AFTER INSERT OR UPDATE OR DELETE ON tailnet_client_subscriptions
	FOR EACH ROW
EXECUTE PROCEDURE tailnet_notify_client_subscription_change();

CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger
LANGUAGE plpgsql
AS $$
DECLARE
	var_client_id uuid;
	var_coordinator_id uuid;
	var_agent_ids uuid[];
	var_agent_id uuid;
BEGIN
	-- NEW is null on DELETE and OLD is null on INSERT, so resolve the client
	-- identity from whichever row image is present.
	IF (NEW.id IS NOT NULL) THEN
		var_client_id = NEW.id;
		var_coordinator_id = NEW.coordinator_id;
	ELSIF (OLD.id IS NOT NULL) THEN
		var_client_id = OLD.id;
		var_coordinator_id = OLD.coordinator_id;
	END IF;

	-- Read all agents the client is subscribed to, so we can notify them.
	-- BUGFIX: filter on the variables resolved above rather than NEW directly;
	-- on DELETE, NEW is null, which made this lookup return nothing and
	-- silently skipped notifying subscribed agents of client removal.
	SELECT
		array_agg(agent_id)
	INTO
		var_agent_ids
	FROM
		tailnet_client_subscriptions subs
	WHERE
		subs.client_id = var_client_id AND
		subs.coordinator_id = var_coordinator_id;

	-- No agents to notify
	IF (var_agent_ids IS NULL) THEN
		RETURN NULL;
	END IF;

	-- pg_notify is limited to 8k bytes, which is approximately 221 UUIDs.
	-- Instead of sending all agent ids in a single update, send one for each
	-- agent id to prevent overflow.
	FOREACH var_agent_id IN ARRAY var_agent_ids
	LOOP
		PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || var_agent_id);
	END LOOP;

	RETURN NULL;
END;
$$;

ALTER TABLE
	tailnet_clients
DROP COLUMN
	agent_id;

COMMIT;

View File

@ -18,7 +18,7 @@ VALUES
);
INSERT INTO tailnet_agents
(id, coordinator_id, updated_at, node)
(id, coordinator_id, updated_at, node)
VALUES
(
'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',

View File

@ -0,0 +1,9 @@
-- Fixture: one subscription row linking client b0ee… to agent c0ee… on
-- coordinator a0ee… (presumably matching rows seeded by sibling fixture
-- files — verify against the other testdata inserts).
INSERT INTO tailnet_client_subscriptions
	(client_id, agent_id, coordinator_id, updated_at)
VALUES
	(
		'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
		'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
		'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
		'2023-06-15 10:23:54+00'
	);

View File

@ -1783,11 +1783,17 @@ type TailnetAgent struct {
type TailnetClient struct {
ID uuid.UUID `db:"id" json:"id"`
CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Node json.RawMessage `db:"node" json:"node"`
}
// TailnetClientSubscription is a row of tailnet_client_subscriptions: a
// tailnet client's interest in updates for one agent on one coordinator.
type TailnetClientSubscription struct {
	ClientID      uuid.UUID `db:"client_id" json:"client_id"`
	CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
	AgentID       uuid.UUID `db:"agent_id" json:"agent_id"`
	UpdatedAt     time.Time `db:"updated_at" json:"updated_at"`
}
// We keep this separate from replicas in case we need to break the coordinator out into its own service
type TailnetCoordinator struct {
ID uuid.UUID `db:"id" json:"id"`

View File

@ -36,6 +36,7 @@ type sqlcQuerier interface {
CleanTailnetCoordinators(ctx context.Context) error
DeleteAPIKeyByID(ctx context.Context, id string) error
DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
DeleteAllTailnetClientSubscriptions(ctx context.Context, arg DeleteAllTailnetClientSubscriptionsParams) error
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
DeleteCoordinator(ctx context.Context, id uuid.UUID) error
DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error
@ -50,6 +51,7 @@ type sqlcQuerier interface {
DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error
DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error)
DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error)
DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) error
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
// there is no unique constraint on empty token names
GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error)
@ -59,7 +61,7 @@ type sqlcQuerier interface {
GetActiveUserCount(ctx context.Context) (int64, error)
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, error)
GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error)
GetAllTailnetClients(ctx context.Context) ([]GetAllTailnetClientsRow, error)
GetAppSecurityKey(ctx context.Context) (string, error)
// GetAuditLogsBefore retrieves `row_limit` number of audit logs before the provided
// ID.
@ -324,6 +326,7 @@ type sqlcQuerier interface {
UpsertServiceBanner(ctx context.Context, value string) error
UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error)
UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error)
UpsertTailnetClientSubscription(ctx context.Context, arg UpsertTailnetClientSubscriptionParams) error
UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error)
}

View File

@ -4131,6 +4131,22 @@ func (q *sqlQuerier) CleanTailnetCoordinators(ctx context.Context) error {
return err
}
const deleteAllTailnetClientSubscriptions = `-- name: DeleteAllTailnetClientSubscriptions :exec
DELETE
FROM tailnet_client_subscriptions
WHERE client_id = $1 and coordinator_id = $2
`

// DeleteAllTailnetClientSubscriptionsParams carries the bind parameters for
// deleteAllTailnetClientSubscriptions.
type DeleteAllTailnetClientSubscriptionsParams struct {
	ClientID      uuid.UUID `db:"client_id" json:"client_id"`
	CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
}

// DeleteAllTailnetClientSubscriptions deletes every subscription row for the
// given client on the given coordinator.
// NOTE(review): sqlc-generated — edit queries/tailnet.sql and regenerate.
func (q *sqlQuerier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg DeleteAllTailnetClientSubscriptionsParams) error {
	_, err := q.db.ExecContext(ctx, deleteAllTailnetClientSubscriptions, arg.ClientID, arg.CoordinatorID)
	return err
}
const deleteCoordinator = `-- name: DeleteCoordinator :exec
DELETE
FROM tailnet_coordinators
@ -4190,6 +4206,23 @@ func (q *sqlQuerier) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetC
return i, err
}
const deleteTailnetClientSubscription = `-- name: DeleteTailnetClientSubscription :exec
DELETE
FROM tailnet_client_subscriptions
WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3
`

// DeleteTailnetClientSubscriptionParams carries the bind parameters for
// deleteTailnetClientSubscription.
type DeleteTailnetClientSubscriptionParams struct {
	ClientID      uuid.UUID `db:"client_id" json:"client_id"`
	AgentID       uuid.UUID `db:"agent_id" json:"agent_id"`
	CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
}

// DeleteTailnetClientSubscription deletes a single client→agent subscription
// row on a coordinator.
// NOTE(review): sqlc-generated — edit queries/tailnet.sql and regenerate.
func (q *sqlQuerier) DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) error {
	_, err := q.db.ExecContext(ctx, deleteTailnetClientSubscription, arg.ClientID, arg.AgentID, arg.CoordinatorID)
	return err
}
const getAllTailnetAgents = `-- name: GetAllTailnetAgents :many
SELECT id, coordinator_id, updated_at, node
FROM tailnet_agents
@ -4224,26 +4257,32 @@ func (q *sqlQuerier) GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, e
}
const getAllTailnetClients = `-- name: GetAllTailnetClients :many
SELECT id, coordinator_id, agent_id, updated_at, node
SELECT tailnet_clients.id, tailnet_clients.coordinator_id, tailnet_clients.updated_at, tailnet_clients.node, array_agg(tailnet_client_subscriptions.agent_id)::uuid[] as agent_ids
FROM tailnet_clients
ORDER BY agent_id
LEFT JOIN tailnet_client_subscriptions
ON tailnet_clients.id = tailnet_client_subscriptions.client_id
`
func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error) {
type GetAllTailnetClientsRow struct {
TailnetClient TailnetClient `db:"tailnet_client" json:"tailnet_client"`
AgentIds []uuid.UUID `db:"agent_ids" json:"agent_ids"`
}
func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]GetAllTailnetClientsRow, error) {
rows, err := q.db.QueryContext(ctx, getAllTailnetClients)
if err != nil {
return nil, err
}
defer rows.Close()
var items []TailnetClient
var items []GetAllTailnetClientsRow
for rows.Next() {
var i TailnetClient
var i GetAllTailnetClientsRow
if err := rows.Scan(
&i.ID,
&i.CoordinatorID,
&i.AgentID,
&i.UpdatedAt,
&i.Node,
&i.TailnetClient.ID,
&i.TailnetClient.CoordinatorID,
&i.TailnetClient.UpdatedAt,
&i.TailnetClient.Node,
pq.Array(&i.AgentIds),
); err != nil {
return nil, err
}
@ -4293,9 +4332,13 @@ func (q *sqlQuerier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]Tail
}
const getTailnetClientsForAgent = `-- name: GetTailnetClientsForAgent :many
SELECT id, coordinator_id, agent_id, updated_at, node
SELECT id, coordinator_id, updated_at, node
FROM tailnet_clients
WHERE agent_id = $1
WHERE id IN (
SELECT tailnet_client_subscriptions.client_id
FROM tailnet_client_subscriptions
WHERE tailnet_client_subscriptions.agent_id = $1
)
`
func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) {
@ -4310,7 +4353,6 @@ func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid
if err := rows.Scan(
&i.ID,
&i.CoordinatorID,
&i.AgentID,
&i.UpdatedAt,
&i.Node,
); err != nil {
@ -4369,47 +4411,67 @@ INSERT INTO
tailnet_clients (
id,
coordinator_id,
agent_id,
node,
updated_at
)
VALUES
($1, $2, $3, $4, now() at time zone 'utc')
($1, $2, $3, now() at time zone 'utc')
ON CONFLICT (id, coordinator_id)
DO UPDATE SET
id = $1,
coordinator_id = $2,
agent_id = $3,
node = $4,
node = $3,
updated_at = now() at time zone 'utc'
RETURNING id, coordinator_id, agent_id, updated_at, node
RETURNING id, coordinator_id, updated_at, node
`
type UpsertTailnetClientParams struct {
ID uuid.UUID `db:"id" json:"id"`
CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
Node json.RawMessage `db:"node" json:"node"`
}
func (q *sqlQuerier) UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) {
row := q.db.QueryRowContext(ctx, upsertTailnetClient,
arg.ID,
arg.CoordinatorID,
arg.AgentID,
arg.Node,
)
row := q.db.QueryRowContext(ctx, upsertTailnetClient, arg.ID, arg.CoordinatorID, arg.Node)
var i TailnetClient
err := row.Scan(
&i.ID,
&i.CoordinatorID,
&i.AgentID,
&i.UpdatedAt,
&i.Node,
)
return i, err
}
const upsertTailnetClientSubscription = `-- name: UpsertTailnetClientSubscription :exec
INSERT INTO
	tailnet_client_subscriptions (
	client_id,
	coordinator_id,
	agent_id,
	updated_at
)
VALUES
	($1, $2, $3, now() at time zone 'utc')
ON CONFLICT (client_id, coordinator_id, agent_id)
DO UPDATE SET
	client_id = $1,
	coordinator_id = $2,
	agent_id = $3,
	updated_at = now() at time zone 'utc'
`

// UpsertTailnetClientSubscriptionParams carries the bind parameters for
// upsertTailnetClientSubscription.
type UpsertTailnetClientSubscriptionParams struct {
	ClientID      uuid.UUID `db:"client_id" json:"client_id"`
	CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
	AgentID       uuid.UUID `db:"agent_id" json:"agent_id"`
}

// UpsertTailnetClientSubscription inserts a client→agent subscription row or,
// on conflict with the (client, coordinator, agent) key, refreshes its
// updated_at timestamp.
// NOTE(review): sqlc-generated — edit queries/tailnet.sql and regenerate.
func (q *sqlQuerier) UpsertTailnetClientSubscription(ctx context.Context, arg UpsertTailnetClientSubscriptionParams) error {
	_, err := q.db.ExecContext(ctx, upsertTailnetClientSubscription, arg.ClientID, arg.CoordinatorID, arg.AgentID)
	return err
}
const upsertTailnetCoordinator = `-- name: UpsertTailnetCoordinator :one
INSERT INTO
tailnet_coordinators (

View File

@ -3,21 +3,36 @@ INSERT INTO
tailnet_clients (
id,
coordinator_id,
agent_id,
node,
updated_at
)
VALUES
($1, $2, $3, $4, now() at time zone 'utc')
($1, $2, $3, now() at time zone 'utc')
ON CONFLICT (id, coordinator_id)
DO UPDATE SET
id = $1,
coordinator_id = $2,
agent_id = $3,
node = $4,
node = $3,
updated_at = now() at time zone 'utc'
RETURNING *;
-- name: UpsertTailnetClientSubscription :exec
INSERT INTO
tailnet_client_subscriptions (
client_id,
coordinator_id,
agent_id,
updated_at
)
VALUES
($1, $2, $3, now() at time zone 'utc')
ON CONFLICT (client_id, coordinator_id, agent_id)
DO UPDATE SET
client_id = $1,
coordinator_id = $2,
agent_id = $3,
updated_at = now() at time zone 'utc';
-- name: UpsertTailnetAgent :one
INSERT INTO
tailnet_agents (
@ -43,6 +58,16 @@ FROM tailnet_clients
WHERE id = $1 and coordinator_id = $2
RETURNING id, coordinator_id;
-- name: DeleteTailnetClientSubscription :exec
DELETE
FROM tailnet_client_subscriptions
WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3;
-- name: DeleteAllTailnetClientSubscriptions :exec
DELETE
FROM tailnet_client_subscriptions
WHERE client_id = $1 and coordinator_id = $2;
-- name: DeleteTailnetAgent :one
DELETE
FROM tailnet_agents
@ -66,12 +91,17 @@ FROM tailnet_agents;
-- name: GetTailnetClientsForAgent :many
SELECT *
FROM tailnet_clients
WHERE agent_id = $1;
WHERE id IN (
SELECT tailnet_client_subscriptions.client_id
FROM tailnet_client_subscriptions
WHERE tailnet_client_subscriptions.agent_id = $1
);
-- name: GetAllTailnetClients :many
SELECT *
SELECT sqlc.embed(tailnet_clients), array_agg(tailnet_client_subscriptions.agent_id)::uuid[] as agent_ids
FROM tailnet_clients
ORDER BY agent_id;
LEFT JOIN tailnet_client_subscriptions
ON tailnet_clients.id = tailnet_client_subscriptions.client_id;
-- name: UpsertTailnetCoordinator :one
INSERT INTO

View File

@ -68,6 +68,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
id := uuid.New()
sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id)
ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText)
defer nc.Close()

View File

@ -0,0 +1,137 @@
package tailnet
import (
"context"
"encoding/json"
"io"
"net"
"github.com/google/uuid"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"cdr.dev/slog"
agpl "github.com/coder/coder/v2/tailnet"
)
// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to
// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings
// via its updates TrackedConn, which then writes them.
//
// NOTE(review): this struct has no client field; the "client field set to
// uuid.Nil" sentence above looks stale — confirm against the binding/bKey
// types and update.
type connIO struct {
	pCtx     context.Context    // parent (coordinator) context; outlives ctx
	ctx      context.Context    // per-connection context, canceled on Close
	cancel   context.CancelFunc // cancels ctx
	logger   slog.Logger
	decoder  *json.Decoder     // reads agpl.Node updates off the conn
	updates  *agpl.TrackedConn // queues mappings and writes them to the conn
	bindings chan<- binding    // where decoded node updates are published
}
// newConnIO wires a connIO over conn and immediately starts its read
// (recvLoop) and write (TrackedConn.SendUpdates) goroutines. The returned
// connIO is already serving; stop it via Close or by canceling pCtx.
func newConnIO(pCtx context.Context,
	logger slog.Logger,
	bindings chan<- binding,
	conn net.Conn,
	id uuid.UUID,
	name string,
	kind agpl.QueueKind,
) *connIO {
	// Derive a per-connection context so canceling tears down only this
	// connection, while pCtx is kept for withdrawal on exit (see recvLoop).
	ctx, cancel := context.WithCancel(pCtx)
	c := &connIO{
		pCtx:     pCtx,
		ctx:      ctx,
		cancel:   cancel,
		logger:   logger,
		decoder:  json.NewDecoder(conn),
		updates:  agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind),
		bindings: bindings,
	}
	go c.recvLoop()
	go c.updates.SendUpdates()
	logger.Info(ctx, "serving connection")
	return c
}
// recvLoop decodes agpl.Node updates from the connection and publishes each
// as a binding on c.bindings. On exit (decode error, close, or context
// cancellation) it sends a final binding with a nil node to withdraw this
// connection's bindings.
func (c *connIO) recvLoop() {
	defer func() {
		// withdraw bindings when we exit. We need to use the parent context here, since our own context might be
		// canceled, but we still need to withdraw bindings.
		// The binding's node is left nil, which signals withdrawal.
		b := binding{
			bKey: bKey{
				id: c.UniqueID(),
				kind: c.Kind(),
			},
		}
		if err := sendCtx(c.pCtx, c.bindings, b); err != nil {
			c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err))
		}
	}()
	// Runs before the withdrawal above (deferred later = runs first), so the
	// connection context is canceled before we withdraw.
	defer c.cancel()
	for {
		var node agpl.Node
		err := c.decoder.Decode(&node)
		if err != nil {
			// Expected shutdown paths (peer closed, ctx canceled, websocket
			// close) are logged at debug; anything else is a real error.
			if xerrors.Is(err, io.EOF) ||
				xerrors.Is(err, io.ErrClosedPipe) ||
				xerrors.Is(err, context.Canceled) ||
				xerrors.Is(err, context.DeadlineExceeded) ||
				websocket.CloseStatus(err) > 0 {
				c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err))
			} else {
				c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err))
			}
			return
		}
		c.logger.Debug(c.ctx, "got node update", slog.F("node", node))
		b := binding{
			bKey: bKey{
				id: c.UniqueID(),
				kind: c.Kind(),
			},
			node: &node,
		}
		// Use the per-connection context here: once we're closed there is no
		// point publishing further updates.
		if err := sendCtx(c.ctx, c.bindings, b); err != nil {
			c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err))
			return
		}
	}
}
// UniqueID returns this connection's id (delegates to the TrackedConn).
func (c *connIO) UniqueID() uuid.UUID {
	return c.updates.UniqueID()
}

// Kind reports whether this queue is a client or an agent.
func (c *connIO) Kind() agpl.QueueKind {
	return c.updates.Kind()
}

// Enqueue queues node updates to be written to the connection.
func (c *connIO) Enqueue(n []*agpl.Node) error {
	return c.updates.Enqueue(n)
}

// Name returns the name this connection was created with.
func (c *connIO) Name() string {
	return c.updates.Name()
}

// Stats delegates to the TrackedConn's start/last-write counters.
func (c *connIO) Stats() (start int64, lastWrite int64) {
	return c.updates.Stats()
}

// Overwrites delegates to the TrackedConn's overwrite counter.
func (c *connIO) Overwrites() int64 {
	return c.updates.Overwrites()
}

// CoordinatorClose is used by the coordinator when closing a Queue. It
// should skip removing itself from the coordinator.
func (c *connIO) CoordinatorClose() error {
	c.cancel()
	return c.updates.CoordinatorClose()
}

// Done is closed when this connection's context is canceled.
func (c *connIO) Done() <-chan struct{} {
	return c.ctx.Done()
}

// Close cancels the connection context and closes the underlying conn.
func (c *connIO) Close() error {
	c.cancel()
	return c.updates.Close()
}

View File

@ -58,7 +58,7 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
AgentIsLegacyFunc: c.agentIsLegacy,
OnSubscribe: c.clientSubscribeToAgent,
OnNodeUpdate: c.clientNodeUpdate,
OnRemove: c.clientDisconnected,
OnRemove: func(enq agpl.Queue) { c.clientDisconnected(enq.UniqueID()) },
}).Init()
c.addClient(id, m)
return m
@ -157,7 +157,7 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error
defer cancel()
logger := c.clientLogger(id, agentID)
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0)
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, agpl.QueueKindClient)
defer tc.Close()
c.addClient(id, tc)
@ -300,7 +300,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
}
// This uniquely identifies a connection that belongs to this goroutine.
unique := uuid.New()
tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites)
tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, agpl.QueueKindAgent)
// Publish all nodes on this instance that want to connect to this agent.
nodes := c.nodesSubscribedToAgent(id)

View File

@ -0,0 +1,354 @@
package tailnet_test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/enterprise/tailnet"
agpl "github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/testutil"
)
// TestPGCoordinator_MultiAgent tests a single coordinator with a MultiAgent
// connecting to one agent.
//
// +--------+
// agent1 ---> | coord1 | <--- client
// +--------+
func TestPGCoordinator_MultiAgent(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancel()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	store, ps := dbtestutil.NewDB(t)
	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
	require.NoError(t, err)
	defer coord1.Close()
	// Agent publishes a node before any subscriber exists.
	agent1 := newTestAgent(t, coord1, "agent1")
	defer agent1.close()
	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
	id := uuid.New()
	ma1 := coord1.ServeMultiAgent(id)
	defer ma1.Close()
	// Subscribing must deliver the agent's current node, and later agent
	// updates must keep propagating to the MultiAgent.
	err = ma1.SubscribeAgent(agent1.id)
	require.NoError(t, err)
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
	// Updates flow the other way too: the MultiAgent's own node must reach
	// the subscribed agent.
	err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
	require.NoError(t, err)
	assertEventuallyHasDERPs(ctx, t, agent1, 3)
	// Closing both sides eventually clears all coordinator state in the DB.
	require.NoError(t, ma1.Close())
	require.NoError(t, agent1.close())
	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
	assertEventuallyNoAgents(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
// a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe
// with the MultiAgent closing.
//
// +--------+
// agent1 ---> | coord1 | <--- client
// +--------+
func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancel()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	store, ps := dbtestutil.NewDB(t)
	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
	require.NoError(t, err)
	defer coord1.Close()
	agent1 := newTestAgent(t, coord1, "agent1")
	defer agent1.close()
	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
	id := uuid.New()
	ma1 := coord1.ServeMultiAgent(id)
	defer ma1.Close()
	// Establish bidirectional updates before racing teardown.
	err = ma1.SubscribeAgent(agent1.id)
	require.NoError(t, err)
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
	err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
	require.NoError(t, err)
	assertEventuallyHasDERPs(ctx, t, agent1, 3)
	// The race under test: Unsubscribe immediately followed by Close. Both
	// must succeed without error regardless of interleaving.
	require.NoError(t, ma1.UnsubscribeAgent(agent1.id))
	require.NoError(t, ma1.Close())
	require.NoError(t, agent1.close())
	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
	assertEventuallyNoAgents(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a
// MultiAgent connecting to one agent. It unsubscribes before closing, and
// ensures node updates are no longer propagated.
//
// +--------+
// agent1 ---> | coord1 | <--- client
// +--------+
func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	store, ps := dbtestutil.NewDB(t)
	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
	require.NoError(t, err)
	defer coord1.Close()
	agent1 := newTestAgent(t, coord1, "agent1")
	defer agent1.close()
	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
	id := uuid.New()
	ma1 := coord1.ServeMultiAgent(id)
	defer ma1.Close()
	// Verify updates flow both ways while subscribed.
	err = ma1.SubscribeAgent(agent1.id)
	require.NoError(t, err)
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
	require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}))
	assertEventuallyHasDERPs(ctx, t, agent1, 3)
	require.NoError(t, ma1.UnsubscribeAgent(agent1.id))
	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
	// After unsubscribing, neither direction should propagate. "Never"
	// assertions need a bounded context, so each check gets its own short
	// deadline (IntervalSlow*3) scoped to an anonymous func.
	func() {
		ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
		defer cancel()
		require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 9}))
		assertNeverHasDERPs(ctx, t, agent1, 9)
	}()
	func() {
		ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
		defer cancel()
		agent1.sendNode(&agpl.Node{PreferredDERP: 8})
		assertMultiAgentNeverHasDERPs(ctx, t, ma1, 8)
	}()
	require.NoError(t, ma1.Close())
	require.NoError(t, agent1.close())
	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
	assertEventuallyNoAgents(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a
// MultiAgent connecting to an agent on a separate coordinator.
//
// +--------+
// agent1 ---> | coord1 |
// +--------+
// +--------+
// | coord2 | <--- client
// +--------+
func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancel()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	// Both coordinators share the same store and pubsub, which is what
	// carries updates between them.
	store, ps := dbtestutil.NewDB(t)
	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
	require.NoError(t, err)
	defer coord1.Close()
	coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
	require.NoError(t, err)
	defer coord2.Close()
	// Agent on coord1, MultiAgent on coord2.
	agent1 := newTestAgent(t, coord1, "agent1")
	defer agent1.close()
	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
	id := uuid.New()
	ma1 := coord2.ServeMultiAgent(id)
	defer ma1.Close()
	err = ma1.SubscribeAgent(agent1.id)
	require.NoError(t, err)
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
	err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
	require.NoError(t, err)
	assertEventuallyHasDERPs(ctx, t, agent1, 3)
	require.NoError(t, ma1.Close())
	require.NoError(t, agent1.close())
	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
	assertEventuallyNoAgents(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two
// coordinators with a MultiAgent connecting to an agent on a separate
// coordinator. The MultiAgent updates its own node before subscribing.
//
// +--------+
// agent1 ---> | coord1 |
// +--------+
// +--------+
// | coord2 | <--- client
// +--------+
func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancel()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	store, ps := dbtestutil.NewDB(t)
	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
	require.NoError(t, err)
	defer coord1.Close()
	coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
	require.NoError(t, err)
	defer coord2.Close()
	agent1 := newTestAgent(t, coord1, "agent1")
	defer agent1.close()
	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
	id := uuid.New()
	ma1 := coord2.ServeMultiAgent(id)
	defer ma1.Close()
	// Key ordering under test: UpdateSelf BEFORE SubscribeAgent. The
	// pre-subscription node must still be delivered to the agent afterwards.
	err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
	require.NoError(t, err)
	err = ma1.SubscribeAgent(agent1.id)
	require.NoError(t, err)
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
	assertEventuallyHasDERPs(ctx, t, agent1, 3)
	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
	require.NoError(t, ma1.Close())
	require.NoError(t, agent1.close())
	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
	assertEventuallyNoAgents(ctx, t, store, agent1.id)
}
// TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a
// MultiAgent connecting to two agents on separate coordinators.
//
// +--------+
// agent1 ---> | coord1 |
// +--------+
// +--------+
// agent2 ---> | coord2 |
// +--------+
// +--------+
// | coord3 | <--- client
// +--------+
func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
	t.Parallel()
	if !dbtestutil.WillUsePostgres() {
		t.Skip("test only with postgres")
	}
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
	defer cancel()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	store, ps := dbtestutil.NewDB(t)
	coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
	require.NoError(t, err)
	defer coord1.Close()
	coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
	require.NoError(t, err)
	defer coord2.Close()
	coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store)
	require.NoError(t, err)
	defer coord3.Close()
	agent1 := newTestAgent(t, coord1, "agent1")
	defer agent1.close()
	agent1.sendNode(&agpl.Node{PreferredDERP: 5})
	agent2 := newTestAgent(t, coord2, "agent2")
	// FIX: was `defer agent1.close()` (copy-paste), which leaked agent2's
	// connection and closed agent1 twice.
	defer agent2.close()
	agent2.sendNode(&agpl.Node{PreferredDERP: 6})
	// MultiAgent on a third coordinator subscribes to both agents.
	id := uuid.New()
	ma1 := coord3.ServeMultiAgent(id)
	defer ma1.Close()
	err = ma1.SubscribeAgent(agent1.id)
	require.NoError(t, err)
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
	agent1.sendNode(&agpl.Node{PreferredDERP: 1})
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
	err = ma1.SubscribeAgent(agent2.id)
	require.NoError(t, err)
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 6)
	agent2.sendNode(&agpl.Node{PreferredDERP: 2})
	assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 2)
	// A single self-update must fan out to every subscribed agent.
	err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
	require.NoError(t, err)
	assertEventuallyHasDERPs(ctx, t, agent1, 3)
	assertEventuallyHasDERPs(ctx, t, agent2, 3)
	require.NoError(t, ma1.Close())
	require.NoError(t, agent1.close())
	require.NoError(t, agent2.close())
	// Coordinator state for BOTH agents must eventually be cleaned up
	// (previously only agent1 was checked).
	assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
	assertEventuallyNoAgents(ctx, t, store, agent1.id)
	assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id)
	assertEventuallyNoAgents(ctx, t, store, agent2.id)
}

File diff suppressed because it is too large Load Diff

View File

@ -438,7 +438,7 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id)
}
// TestPGCoordinator_MultiAgent tests when a single agent connects to multiple coordinators.
// TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators.
// We use two agent connections, but they share the same AgentID. This could happen due to a reconnection,
// or an infrastructure problem where an old workspace is not fully cleaned up before a new one started.
//
@ -451,7 +451,7 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
// +---------+
// | coord3 | <--- client
// +---------+
func TestPGCoordinator_MultiAgent(t *testing.T) {
func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test only with postgres")
@ -693,8 +693,79 @@ func assertEventuallyHasDERPs(ctx context.Context, t *testing.T, c *testConn, ex
t.Logf("expected DERP %d to be in %v", e, derps)
continue
}
return
}
}
}
// assertNeverHasDERPs drains node updates from c until ctx expires, failing
// the test if any update contains one of the given (forbidden) DERP ids.
// Callers are expected to pass a bounded ctx.
func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) {
	t.Helper()
	for {
		select {
		case <-ctx.Done():
			// Deadline reached without seeing a forbidden DERP: success.
			return
		case nodes := <-c.nodeChan:
			derps := make([]int, 0, len(nodes))
			for _, n := range nodes {
				derps = append(derps, n.PreferredDERP)
			}
			for _, e := range expected {
				if slices.Contains(derps, e) {
					t.Fatalf("expected not to get DERP %d, but received it", e)
					return
				}
			}
		}
	}
}
// assertMultiAgentEventuallyHasDERPs reads updates from ma until one arrives
// whose node count matches and which contains the expected DERP id(s); fails
// (via require on NextUpdate) if ctx expires first.
//
// NOTE(review): with more than one expected value this returns on the first
// value found after logging the missing ones — same shape as
// assertEventuallyHasDERPs; callers currently pass a single value.
func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
	t.Helper()
	for {
		nodes, ok := ma.NextUpdate(ctx)
		require.True(t, ok)
		if len(nodes) != len(expected) {
			t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
			continue
		}
		derps := make([]int, 0, len(nodes))
		for _, n := range nodes {
			derps = append(derps, n.PreferredDERP)
		}
		for _, e := range expected {
			if !slices.Contains(derps, e) {
				t.Logf("expected DERP %d to be in %v", e, derps)
				continue
			}
			return
		}
	}
}
// assertMultiAgentNeverHasDERPs reads node updates from ma until the context
// expires (NextUpdate returning !ok), failing the test if any update contains
// one of the given (forbidden) DERP ids. Mirrors assertNeverHasDERPs.
//
// The previous implementation could never fail: it returned silently when a
// forbidden DERP WAS present, skipped updates whose node count differed from
// len(expected), and a stray return exited after the first matching update.
func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
	t.Helper()
	for {
		nodes, ok := ma.NextUpdate(ctx)
		if !ok {
			// ctx expired (or the conn closed) without a forbidden DERP
			// showing up: success.
			return
		}
		derps := make([]int, 0, len(nodes))
		for _, n := range nodes {
			derps = append(derps, n.PreferredDERP)
		}
		for _, e := range expected {
			if slices.Contains(derps, e) {
				t.Fatalf("expected not to get DERP %d, but received it", e)
				return
			}
		}
	}
}
@ -712,6 +783,7 @@ func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.
}
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
t.Helper()
assert.Eventually(t, func() bool {
clients, err := store.GetTailnetClientsForAgent(ctx, agentID)
if xerrors.Is(err, sql.ErrNoRows) {

View File

@ -22,7 +22,7 @@ import (
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/workspaceapps"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
agpl "github.com/coder/coder/v2/tailnet"
)
// Client is a HTTP client for a subset of Coder API routes that external
@ -422,14 +422,14 @@ const (
type CoordinateMessage struct {
Type CoordinateMessageType `json:"type"`
AgentID uuid.UUID `json:"agent_id"`
Node *tailnet.Node `json:"node"`
Node *agpl.Node `json:"node"`
}
type CoordinateNodes struct {
Nodes []*tailnet.Node
Nodes []*agpl.Node
}
func (c *Client) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, error) {
func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
ctx, cancel := context.WithCancel(ctx)
coordinateURL, err := c.SDKClient.URL.Parse("/api/v2/workspaceproxies/me/coordinate")
@ -463,13 +463,13 @@ func (c *Client) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, e
legacyAgentCache: map[uuid.UUID]bool{},
}
ma := (&tailnet.MultiAgent{
ma := (&agpl.MultiAgent{
ID: uuid.New(),
AgentIsLegacyFunc: rma.AgentIsLegacy,
OnSubscribe: rma.OnSubscribe,
OnUnsubscribe: rma.OnUnsubscribe,
OnNodeUpdate: rma.OnNodeUpdate,
OnRemove: func(uuid.UUID) { conn.Close(websocket.StatusGoingAway, "closed") },
OnRemove: func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") },
}).Init()
go func() {
@ -515,7 +515,7 @@ func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error {
// Set a deadline so that hung connections don't put back pressure on the system.
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
err = a.nc.SetWriteDeadline(time.Now().Add(tailnet.WriteTimeout))
err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
if err != nil {
return xerrors.Errorf("set write deadline: %w", err)
}
@ -537,21 +537,21 @@ func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error {
return nil
}
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *tailnet.Node) error {
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error {
return a.writeJSON(CoordinateMessage{
Type: CoordinateMessageTypeNodeUpdate,
Node: node,
})
}
func (a *remoteMultiAgentHandler) OnSubscribe(_ tailnet.Queue, agentID uuid.UUID) (*tailnet.Node, error) {
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) {
return nil, a.writeJSON(CoordinateMessage{
Type: CoordinateMessageTypeSubscribe,
AgentID: agentID,
})
}
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ tailnet.Queue, agentID uuid.UUID) error {
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error {
return a.writeJSON(CoordinateMessage{
Type: CoordinateMessageTypeUnsubscribe,
AgentID: agentID,

View File

@ -146,7 +146,7 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
OnSubscribe: c.core.clientSubscribeToAgent,
OnUnsubscribe: c.core.clientUnsubscribeFromAgent,
OnNodeUpdate: c.core.clientNodeUpdate,
OnRemove: c.core.clientDisconnected,
OnRemove: func(enq Queue) { c.core.clientDisconnected(enq.UniqueID()) },
}).Init()
c.core.addClient(id, m)
return m
@ -191,8 +191,16 @@ type core struct {
legacyAgents map[uuid.UUID]struct{}
}
type QueueKind int
const (
QueueKindClient QueueKind = 1 + iota
QueueKindAgent
)
type Queue interface {
UniqueID() uuid.UUID
Kind() QueueKind
Enqueue(n []*Node) error
Name() string
Stats() (start, lastWrite int64)
@ -200,6 +208,7 @@ type Queue interface {
// CoordinatorClose is used by the coordinator when closing a Queue. It
// should skip removing itself from the coordinator.
CoordinatorClose() error
Done() <-chan struct{}
Close() error
}
@ -264,7 +273,7 @@ func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
logger := c.core.clientLogger(id, agentID)
logger.Debug(ctx, "coordinating client")
tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0)
tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, QueueKindClient)
defer tc.Close()
c.core.addClient(id, tc)
@ -509,7 +518,7 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co
overwrites = oldAgentSocket.Overwrites() + 1
_ = oldAgentSocket.Close()
}
tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites)
tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, QueueKindAgent)
c.agentNameCache.Add(id, name)
sockets, ok := c.agentToConnectionSockets[id]

View File

@ -29,9 +29,12 @@ type MultiAgent struct {
OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error)
OnUnsubscribe func(enq Queue, agent uuid.UUID) error
OnNodeUpdate func(id uuid.UUID, node *Node) error
OnRemove func(id uuid.UUID)
OnRemove func(enq Queue)
ctx context.Context
ctxCancel func()
closed bool
updates chan []*Node
closeOnce sync.Once
start int64
@ -44,9 +47,14 @@ type MultiAgent struct {
func (m *MultiAgent) Init() *MultiAgent {
m.updates = make(chan []*Node, 128)
m.start = time.Now().Unix()
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
return m
}
func (*MultiAgent) Kind() QueueKind {
return QueueKindClient
}
func (m *MultiAgent) UniqueID() uuid.UUID {
return m.ID
}
@ -156,8 +164,13 @@ func (m *MultiAgent) CoordinatorClose() error {
return nil
}
func (m *MultiAgent) Done() <-chan struct{} {
return m.ctx.Done()
}
func (m *MultiAgent) Close() error {
_ = m.CoordinatorClose()
m.closeOnce.Do(func() { m.OnRemove(m.ID) })
m.ctxCancel()
m.closeOnce.Do(func() { m.OnRemove(m) })
return nil
}

View File

@ -20,6 +20,7 @@ const WriteTimeout = time.Second * 5
type TrackedConn struct {
ctx context.Context
cancel func()
kind QueueKind
conn net.Conn
updates chan []*Node
logger slog.Logger
@ -35,7 +36,14 @@ type TrackedConn struct {
overwrites int64
}
func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, name string, overwrites int64) *TrackedConn {
func NewTrackedConn(ctx context.Context, cancel func(),
conn net.Conn,
id uuid.UUID,
logger slog.Logger,
name string,
overwrites int64,
kind QueueKind,
) *TrackedConn {
// buffer updates so they don't block, since we hold the
// coordinator mutex while queuing. Node updates don't
// come quickly, so 512 should be plenty for all but
@ -53,6 +61,7 @@ func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.U
lastWrite: now,
name: name,
overwrites: overwrites,
kind: kind,
}
}
@ -70,6 +79,10 @@ func (t *TrackedConn) UniqueID() uuid.UUID {
return t.id
}
func (t *TrackedConn) Kind() QueueKind {
return t.kind
}
func (t *TrackedConn) Name() string {
return t.name
}
@ -86,6 +99,10 @@ func (t *TrackedConn) CoordinatorClose() error {
return t.Close()
}
func (t *TrackedConn) Done() <-chan struct{} {
return t.ctx.Done()
}
// Close the connection and cancel the context for reading node updates from the queue
func (t *TrackedConn) Close() error {
t.cancel()