mirror of https://github.com/coder/coder.git
feat: add single tailnet support to pgcoord (#9351)
This commit is contained in:
parent
fbad06f406
commit
c900b5f8df
|
@ -694,6 +694,13 @@ func (q *querier) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) e
|
|||
return q.db.DeleteAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteAllTailnetClientSubscriptions(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error {
|
||||
// TODO: This is not 100% correct because it omits apikey IDs.
|
||||
err := q.authorizeContext(ctx, rbac.ActionDelete,
|
||||
|
@ -783,6 +790,13 @@ func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTa
|
|||
return q.db.DeleteTailnetClient(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) error {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteTailnetClientSubscription(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id)
|
||||
}
|
||||
|
@ -825,9 +839,9 @@ func (q *querier) GetAllTailnetAgents(ctx context.Context) ([]database.TailnetAg
|
|||
return q.db.GetAllTailnetAgents(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) {
|
||||
func (q *querier) GetAllTailnetClients(ctx context.Context) ([]database.GetAllTailnetClientsRow, error) {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return []database.TailnetClient{}, err
|
||||
return []database.GetAllTailnetClientsRow{}, err
|
||||
}
|
||||
return q.db.GetAllTailnetClients(ctx)
|
||||
}
|
||||
|
@ -2794,6 +2808,13 @@ func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTa
|
|||
return q.db.UpsertTailnetClient(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertTailnetClientSubscription(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) {
|
||||
if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil {
|
||||
return database.TailnetCoordinator{}, err
|
||||
|
|
|
@ -854,6 +854,15 @@ func (q *FakeQuerier) DeleteAPIKeysByUserID(_ context.Context, userID uuid.UUID)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (*FakeQuerier) DeleteAllTailnetClientSubscriptions(_ context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error {
|
||||
err := validateDatabaseType(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ErrUnimplemented
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, userID uuid.UUID) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
@ -987,6 +996,10 @@ func (*FakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetC
|
|||
return database.DeleteTailnetClientRow{}, ErrUnimplemented
|
||||
}
|
||||
|
||||
func (*FakeQuerier) DeleteTailnetClientSubscription(context.Context, database.DeleteTailnetClientSubscriptionParams) error {
|
||||
return ErrUnimplemented
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
@ -1102,7 +1115,7 @@ func (*FakeQuerier) GetAllTailnetAgents(_ context.Context) ([]database.TailnetAg
|
|||
return nil, ErrUnimplemented
|
||||
}
|
||||
|
||||
func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.TailnetClient, error) {
|
||||
func (*FakeQuerier) GetAllTailnetClients(_ context.Context) ([]database.GetAllTailnetClientsRow, error) {
|
||||
return nil, ErrUnimplemented
|
||||
}
|
||||
|
||||
|
@ -6112,6 +6125,10 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC
|
|||
return database.TailnetClient{}, ErrUnimplemented
|
||||
}
|
||||
|
||||
func (*FakeQuerier) UpsertTailnetClientSubscription(context.Context, database.UpsertTailnetClientSubscriptionParams) error {
|
||||
return ErrUnimplemented
|
||||
}
|
||||
|
||||
func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) {
|
||||
return database.TailnetCoordinator{}, ErrUnimplemented
|
||||
}
|
||||
|
|
|
@ -128,6 +128,13 @@ func (m metricsStore) DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUI
|
|||
return err
|
||||
}
|
||||
|
||||
func (m metricsStore) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg database.DeleteAllTailnetClientSubscriptionsParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteAllTailnetClientSubscriptions(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteAllTailnetClientSubscriptions").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
err := m.s.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
|
||||
|
@ -209,6 +216,13 @@ func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.Dele
|
|||
return m.s.DeleteTailnetClient(ctx, arg)
|
||||
}
|
||||
|
||||
func (m metricsStore) DeleteTailnetClientSubscription(ctx context.Context, arg database.DeleteTailnetClientSubscriptionParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteTailnetClientSubscription(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteTailnetClientSubscription").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) {
|
||||
start := time.Now()
|
||||
apiKey, err := m.s.GetAPIKeyByID(ctx, id)
|
||||
|
@ -265,7 +279,7 @@ func (m metricsStore) GetAllTailnetAgents(ctx context.Context) ([]database.Tailn
|
|||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) GetAllTailnetClients(ctx context.Context) ([]database.TailnetClient, error) {
|
||||
func (m metricsStore) GetAllTailnetClients(ctx context.Context) ([]database.GetAllTailnetClientsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAllTailnetClients(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetAllTailnetClients").Observe(time.Since(start).Seconds())
|
||||
|
@ -1752,6 +1766,13 @@ func (m metricsStore) UpsertTailnetClient(ctx context.Context, arg database.Upse
|
|||
return m.s.UpsertTailnetClient(ctx, arg)
|
||||
}
|
||||
|
||||
func (m metricsStore) UpsertTailnetClientSubscription(ctx context.Context, arg database.UpsertTailnetClientSubscriptionParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertTailnetClientSubscription(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertTailnetClientSubscription").Observe(time.Since(start).Seconds())
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) {
|
||||
start := time.Now()
|
||||
defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds())
|
||||
|
|
|
@ -139,6 +139,20 @@ func (mr *MockStoreMockRecorder) DeleteAPIKeysByUserID(arg0, arg1 interface{}) *
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteAPIKeysByUserID), arg0, arg1)
|
||||
}
|
||||
|
||||
// DeleteAllTailnetClientSubscriptions mocks base method.
|
||||
func (m *MockStore) DeleteAllTailnetClientSubscriptions(arg0 context.Context, arg1 database.DeleteAllTailnetClientSubscriptionsParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllTailnetClientSubscriptions", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllTailnetClientSubscriptions indicates an expected call of DeleteAllTailnetClientSubscriptions.
|
||||
func (mr *MockStoreMockRecorder) DeleteAllTailnetClientSubscriptions(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllTailnetClientSubscriptions", reflect.TypeOf((*MockStore)(nil).DeleteAllTailnetClientSubscriptions), arg0, arg1)
|
||||
}
|
||||
|
||||
// DeleteApplicationConnectAPIKeysByUserID mocks base method.
|
||||
func (m *MockStore) DeleteApplicationConnectAPIKeysByUserID(arg0 context.Context, arg1 uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -310,6 +324,20 @@ func (mr *MockStoreMockRecorder) DeleteTailnetClient(arg0, arg1 interface{}) *go
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClient", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClient), arg0, arg1)
|
||||
}
|
||||
|
||||
// DeleteTailnetClientSubscription mocks base method.
|
||||
func (m *MockStore) DeleteTailnetClientSubscription(arg0 context.Context, arg1 database.DeleteTailnetClientSubscriptionParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteTailnetClientSubscription", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteTailnetClientSubscription indicates an expected call of DeleteTailnetClientSubscription.
|
||||
func (mr *MockStoreMockRecorder) DeleteTailnetClientSubscription(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClientSubscription", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClientSubscription), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetAPIKeyByID mocks base method.
|
||||
func (m *MockStore) GetAPIKeyByID(arg0 context.Context, arg1 string) (database.APIKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -431,10 +459,10 @@ func (mr *MockStoreMockRecorder) GetAllTailnetAgents(arg0 interface{}) *gomock.C
|
|||
}
|
||||
|
||||
// GetAllTailnetClients mocks base method.
|
||||
func (m *MockStore) GetAllTailnetClients(arg0 context.Context) ([]database.TailnetClient, error) {
|
||||
func (m *MockStore) GetAllTailnetClients(arg0 context.Context) ([]database.GetAllTailnetClientsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAllTailnetClients", arg0)
|
||||
ret0, _ := ret[0].([]database.TailnetClient)
|
||||
ret0, _ := ret[0].([]database.GetAllTailnetClientsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
@ -3681,6 +3709,20 @@ func (mr *MockStoreMockRecorder) UpsertTailnetClient(arg0, arg1 interface{}) *go
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClient", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClient), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpsertTailnetClientSubscription mocks base method.
|
||||
func (m *MockStore) UpsertTailnetClientSubscription(arg0 context.Context, arg1 database.UpsertTailnetClientSubscriptionParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertTailnetClientSubscription", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertTailnetClientSubscription indicates an expected call of UpsertTailnetClientSubscription.
|
||||
func (mr *MockStoreMockRecorder) UpsertTailnetClientSubscription(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClientSubscription", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClientSubscription), arg0, arg1)
|
||||
}
|
||||
|
||||
// UpsertTailnetCoordinator mocks base method.
|
||||
func (m *MockStore) UpsertTailnetCoordinator(arg0 context.Context, arg1 uuid.UUID) (database.TailnetCoordinator, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -219,13 +219,57 @@ $$;
|
|||
CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
var_client_id uuid;
|
||||
var_coordinator_id uuid;
|
||||
var_agent_ids uuid[];
|
||||
var_agent_id uuid;
|
||||
BEGIN
|
||||
IF (OLD IS NOT NULL) THEN
|
||||
PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id);
|
||||
RETURN NULL;
|
||||
IF (NEW.id IS NOT NULL) THEN
|
||||
var_client_id = NEW.id;
|
||||
var_coordinator_id = NEW.coordinator_id;
|
||||
ELSIF (OLD.id IS NOT NULL) THEN
|
||||
var_client_id = OLD.id;
|
||||
var_coordinator_id = OLD.coordinator_id;
|
||||
END IF;
|
||||
|
||||
-- Read all agents the client is subscribed to, so we can notify them.
|
||||
SELECT
|
||||
array_agg(agent_id)
|
||||
INTO
|
||||
var_agent_ids
|
||||
FROM
|
||||
tailnet_client_subscriptions subs
|
||||
WHERE
|
||||
subs.client_id = NEW.id AND
|
||||
subs.coordinator_id = NEW.coordinator_id;
|
||||
|
||||
-- No agents to notify
|
||||
if (var_agent_ids IS NULL) THEN
|
||||
return NULL;
|
||||
END IF;
|
||||
|
||||
-- pg_notify is limited to 8k bytes, which is approximately 221 UUIDs.
|
||||
-- Instead of sending all agent ids in a single update, send one for each
|
||||
-- agent id to prevent overflow.
|
||||
FOREACH var_agent_id IN ARRAY var_agent_ids
|
||||
LOOP
|
||||
PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || var_agent_id);
|
||||
END LOOP;
|
||||
|
||||
return NULL;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION tailnet_notify_client_subscription_change() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF (NEW IS NOT NULL) THEN
|
||||
PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id);
|
||||
PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id);
|
||||
RETURN NULL;
|
||||
ELSIF (OLD IS NOT NULL) THEN
|
||||
PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id);
|
||||
RETURN NULL;
|
||||
END IF;
|
||||
END;
|
||||
|
@ -495,10 +539,16 @@ CREATE TABLE tailnet_agents (
|
|||
node jsonb NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE tailnet_client_subscriptions (
|
||||
client_id uuid NOT NULL,
|
||||
coordinator_id uuid NOT NULL,
|
||||
agent_id uuid NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE tailnet_clients (
|
||||
id uuid NOT NULL,
|
||||
coordinator_id uuid NOT NULL,
|
||||
agent_id uuid NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
node jsonb NOT NULL
|
||||
);
|
||||
|
@ -1144,6 +1194,9 @@ ALTER TABLE ONLY site_configs
|
|||
ALTER TABLE ONLY tailnet_agents
|
||||
ADD CONSTRAINT tailnet_agents_pkey PRIMARY KEY (id, coordinator_id);
|
||||
|
||||
ALTER TABLE ONLY tailnet_client_subscriptions
|
||||
ADD CONSTRAINT tailnet_client_subscriptions_pkey PRIMARY KEY (client_id, coordinator_id, agent_id);
|
||||
|
||||
ALTER TABLE ONLY tailnet_clients
|
||||
ADD CONSTRAINT tailnet_clients_pkey PRIMARY KEY (id, coordinator_id);
|
||||
|
||||
|
@ -1248,8 +1301,6 @@ CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lo
|
|||
|
||||
CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id);
|
||||
|
||||
CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients USING btree (agent_id);
|
||||
|
||||
CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients USING btree (coordinator_id);
|
||||
|
||||
CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false);
|
||||
|
@ -1284,6 +1335,8 @@ CREATE TRIGGER tailnet_notify_agent_change AFTER INSERT OR DELETE OR UPDATE ON t
|
|||
|
||||
CREATE TRIGGER tailnet_notify_client_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_clients FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_change();
|
||||
|
||||
CREATE TRIGGER tailnet_notify_client_subscription_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_client_subscriptions FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_subscription_change();
|
||||
|
||||
CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION tailnet_notify_coordinator_heartbeat();
|
||||
|
||||
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
|
||||
|
@ -1329,6 +1382,9 @@ ALTER TABLE ONLY provisioner_jobs
|
|||
ALTER TABLE ONLY tailnet_agents
|
||||
ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY tailnet_client_subscriptions
|
||||
ADD CONSTRAINT tailnet_client_subscriptions_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY tailnet_clients
|
||||
ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ const (
|
|||
ForeignKeyProvisionerJobLogsJobID ForeignKeyConstraint = "provisioner_job_logs_job_id_fkey" // ALTER TABLE ONLY provisioner_job_logs ADD CONSTRAINT provisioner_job_logs_job_id_fkey FOREIGN KEY (job_id) REFERENCES provisioner_jobs(id) ON DELETE CASCADE;
|
||||
ForeignKeyProvisionerJobsOrganizationID ForeignKeyConstraint = "provisioner_jobs_organization_id_fkey" // ALTER TABLE ONLY provisioner_jobs ADD CONSTRAINT provisioner_jobs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyTailnetAgentsCoordinatorID ForeignKeyConstraint = "tailnet_agents_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_agents ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
|
||||
ForeignKeyTailnetClientSubscriptionsCoordinatorID ForeignKeyConstraint = "tailnet_client_subscriptions_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_client_subscriptions ADD CONSTRAINT tailnet_client_subscriptions_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
|
||||
ForeignKeyTailnetClientsCoordinatorID ForeignKeyConstraint = "tailnet_clients_coordinator_id_fkey" // ALTER TABLE ONLY tailnet_clients ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE;
|
||||
ForeignKeyTemplateVersionParametersTemplateVersionID ForeignKeyConstraint = "template_version_parameters_template_version_id_fkey" // ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE;
|
||||
ForeignKeyTemplateVersionVariablesTemplateVersionID ForeignKeyConstraint = "template_version_variables_template_version_id_fkey" // ALTER TABLE ONLY template_version_variables ADD CONSTRAINT template_version_variables_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE;
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
BEGIN;
|
||||
|
||||
ALTER TABLE
|
||||
tailnet_clients
|
||||
ADD COLUMN
|
||||
agent_id uuid;
|
||||
|
||||
UPDATE
|
||||
tailnet_clients
|
||||
SET
|
||||
-- there's no reason for us to try and preserve data since coordinators will
|
||||
-- have to restart anyways, which will create all of the client mappings.
|
||||
agent_id = '00000000-0000-0000-0000-000000000000'::uuid;
|
||||
|
||||
ALTER TABLE
|
||||
tailnet_clients
|
||||
ALTER COLUMN
|
||||
agent_id SET NOT NULL;
|
||||
|
||||
DROP TABLE tailnet_client_subscriptions;
|
||||
DROP FUNCTION tailnet_notify_client_subscription_change;
|
||||
|
||||
-- update the tailnet_clients trigger to the old version.
|
||||
CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF (OLD IS NOT NULL) THEN
|
||||
PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id);
|
||||
RETURN NULL;
|
||||
END IF;
|
||||
IF (NEW IS NOT NULL) THEN
|
||||
PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id);
|
||||
RETURN NULL;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
COMMIT;
|
|
@ -0,0 +1,88 @@
|
|||
BEGIN;
|
||||
|
||||
CREATE TABLE tailnet_client_subscriptions (
|
||||
client_id uuid NOT NULL,
|
||||
coordinator_id uuid NOT NULL,
|
||||
-- this isn't a foreign key since it's more of a list of agents the client
|
||||
-- *wants* to connect to, and they don't necessarily have to currently
|
||||
-- exist in the db.
|
||||
agent_id uuid NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL,
|
||||
PRIMARY KEY (client_id, coordinator_id, agent_id),
|
||||
FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators (id) ON DELETE CASCADE
|
||||
-- we don't keep a foreign key to the tailnet_clients table since there's
|
||||
-- not a great way to guarantee that a subscription is always added after
|
||||
-- the client is inserted. clients are only created after the client sends
|
||||
-- its first node update, which can take an undetermined amount of time.
|
||||
);
|
||||
|
||||
CREATE FUNCTION tailnet_notify_client_subscription_change() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
IF (NEW IS NOT NULL) THEN
|
||||
PERFORM pg_notify('tailnet_client_update', NEW.client_id || ',' || NEW.agent_id);
|
||||
RETURN NULL;
|
||||
ELSIF (OLD IS NOT NULL) THEN
|
||||
PERFORM pg_notify('tailnet_client_update', OLD.client_id || ',' || OLD.agent_id);
|
||||
RETURN NULL;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE TRIGGER tailnet_notify_client_subscription_change
|
||||
AFTER INSERT OR UPDATE OR DELETE ON tailnet_client_subscriptions
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE tailnet_notify_client_subscription_change();
|
||||
|
||||
CREATE OR REPLACE FUNCTION tailnet_notify_client_change() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
var_client_id uuid;
|
||||
var_coordinator_id uuid;
|
||||
var_agent_ids uuid[];
|
||||
var_agent_id uuid;
|
||||
BEGIN
|
||||
IF (NEW.id IS NOT NULL) THEN
|
||||
var_client_id = NEW.id;
|
||||
var_coordinator_id = NEW.coordinator_id;
|
||||
ELSIF (OLD.id IS NOT NULL) THEN
|
||||
var_client_id = OLD.id;
|
||||
var_coordinator_id = OLD.coordinator_id;
|
||||
END IF;
|
||||
|
||||
-- Read all agents the client is subscribed to, so we can notify them.
|
||||
SELECT
|
||||
array_agg(agent_id)
|
||||
INTO
|
||||
var_agent_ids
|
||||
FROM
|
||||
tailnet_client_subscriptions subs
|
||||
WHERE
|
||||
subs.client_id = NEW.id AND
|
||||
subs.coordinator_id = NEW.coordinator_id;
|
||||
|
||||
-- No agents to notify
|
||||
if (var_agent_ids IS NULL) THEN
|
||||
return NULL;
|
||||
END IF;
|
||||
|
||||
-- pg_notify is limited to 8k bytes, which is approximately 221 UUIDs.
|
||||
-- Instead of sending all agent ids in a single update, send one for each
|
||||
-- agent id to prevent overflow.
|
||||
FOREACH var_agent_id IN ARRAY var_agent_ids
|
||||
LOOP
|
||||
PERFORM pg_notify('tailnet_client_update', var_client_id || ',' || var_agent_id);
|
||||
END LOOP;
|
||||
|
||||
return NULL;
|
||||
END;
|
||||
$$;
|
||||
|
||||
ALTER TABLE
|
||||
tailnet_clients
|
||||
DROP COLUMN
|
||||
agent_id;
|
||||
|
||||
COMMIT;
|
|
@ -18,7 +18,7 @@ VALUES
|
|||
);
|
||||
|
||||
INSERT INTO tailnet_agents
|
||||
(id, coordinator_id, updated_at, node)
|
||||
(id, coordinator_id, updated_at, node)
|
||||
VALUES
|
||||
(
|
||||
'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
|
||||
|
|
9
coderd/database/migrations/testdata/fixtures/000156_pg_coordinator_single_tailnet.up.sql
vendored
Normal file
9
coderd/database/migrations/testdata/fixtures/000156_pg_coordinator_single_tailnet.up.sql
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
INSERT INTO tailnet_client_subscriptions
|
||||
(client_id, agent_id, coordinator_id, updated_at)
|
||||
VALUES
|
||||
(
|
||||
'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
|
||||
'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
|
||||
'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11',
|
||||
'2023-06-15 10:23:54+00'
|
||||
);
|
|
@ -1783,11 +1783,17 @@ type TailnetAgent struct {
|
|||
type TailnetClient struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
|
||||
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Node json.RawMessage `db:"node" json:"node"`
|
||||
}
|
||||
|
||||
type TailnetClientSubscription struct {
|
||||
ClientID uuid.UUID `db:"client_id" json:"client_id"`
|
||||
CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
|
||||
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
// We keep this separate from replicas in case we need to break the coordinator out into its own service
|
||||
type TailnetCoordinator struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
|
|
|
@ -36,6 +36,7 @@ type sqlcQuerier interface {
|
|||
CleanTailnetCoordinators(ctx context.Context) error
|
||||
DeleteAPIKeyByID(ctx context.Context, id string) error
|
||||
DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteAllTailnetClientSubscriptions(ctx context.Context, arg DeleteAllTailnetClientSubscriptionsParams) error
|
||||
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteCoordinator(ctx context.Context, id uuid.UUID) error
|
||||
DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error
|
||||
|
@ -50,6 +51,7 @@ type sqlcQuerier interface {
|
|||
DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error
|
||||
DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error)
|
||||
DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error)
|
||||
DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) error
|
||||
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
|
||||
// there is no unique constraint on empty token names
|
||||
GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error)
|
||||
|
@ -59,7 +61,7 @@ type sqlcQuerier interface {
|
|||
GetActiveUserCount(ctx context.Context) (int64, error)
|
||||
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, error)
|
||||
GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error)
|
||||
GetAllTailnetClients(ctx context.Context) ([]GetAllTailnetClientsRow, error)
|
||||
GetAppSecurityKey(ctx context.Context) (string, error)
|
||||
// GetAuditLogsBefore retrieves `row_limit` number of audit logs before the provided
|
||||
// ID.
|
||||
|
@ -324,6 +326,7 @@ type sqlcQuerier interface {
|
|||
UpsertServiceBanner(ctx context.Context, value string) error
|
||||
UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error)
|
||||
UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error)
|
||||
UpsertTailnetClientSubscription(ctx context.Context, arg UpsertTailnetClientSubscriptionParams) error
|
||||
UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error)
|
||||
}
|
||||
|
||||
|
|
|
@ -4131,6 +4131,22 @@ func (q *sqlQuerier) CleanTailnetCoordinators(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
const deleteAllTailnetClientSubscriptions = `-- name: DeleteAllTailnetClientSubscriptions :exec
|
||||
DELETE
|
||||
FROM tailnet_client_subscriptions
|
||||
WHERE client_id = $1 and coordinator_id = $2
|
||||
`
|
||||
|
||||
type DeleteAllTailnetClientSubscriptionsParams struct {
|
||||
ClientID uuid.UUID `db:"client_id" json:"client_id"`
|
||||
CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteAllTailnetClientSubscriptions(ctx context.Context, arg DeleteAllTailnetClientSubscriptionsParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteAllTailnetClientSubscriptions, arg.ClientID, arg.CoordinatorID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteCoordinator = `-- name: DeleteCoordinator :exec
|
||||
DELETE
|
||||
FROM tailnet_coordinators
|
||||
|
@ -4190,6 +4206,23 @@ func (q *sqlQuerier) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetC
|
|||
return i, err
|
||||
}
|
||||
|
||||
const deleteTailnetClientSubscription = `-- name: DeleteTailnetClientSubscription :exec
|
||||
DELETE
|
||||
FROM tailnet_client_subscriptions
|
||||
WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3
|
||||
`
|
||||
|
||||
type DeleteTailnetClientSubscriptionParams struct {
|
||||
ClientID uuid.UUID `db:"client_id" json:"client_id"`
|
||||
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
|
||||
CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteTailnetClientSubscription(ctx context.Context, arg DeleteTailnetClientSubscriptionParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteTailnetClientSubscription, arg.ClientID, arg.AgentID, arg.CoordinatorID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getAllTailnetAgents = `-- name: GetAllTailnetAgents :many
|
||||
SELECT id, coordinator_id, updated_at, node
|
||||
FROM tailnet_agents
|
||||
|
@ -4224,26 +4257,32 @@ func (q *sqlQuerier) GetAllTailnetAgents(ctx context.Context) ([]TailnetAgent, e
|
|||
}
|
||||
|
||||
const getAllTailnetClients = `-- name: GetAllTailnetClients :many
|
||||
SELECT id, coordinator_id, agent_id, updated_at, node
|
||||
SELECT tailnet_clients.id, tailnet_clients.coordinator_id, tailnet_clients.updated_at, tailnet_clients.node, array_agg(tailnet_client_subscriptions.agent_id)::uuid[] as agent_ids
|
||||
FROM tailnet_clients
|
||||
ORDER BY agent_id
|
||||
LEFT JOIN tailnet_client_subscriptions
|
||||
ON tailnet_clients.id = tailnet_client_subscriptions.client_id
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]TailnetClient, error) {
|
||||
type GetAllTailnetClientsRow struct {
|
||||
TailnetClient TailnetClient `db:"tailnet_client" json:"tailnet_client"`
|
||||
AgentIds []uuid.UUID `db:"agent_ids" json:"agent_ids"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAllTailnetClients(ctx context.Context) ([]GetAllTailnetClientsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAllTailnetClients)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []TailnetClient
|
||||
var items []GetAllTailnetClientsRow
|
||||
for rows.Next() {
|
||||
var i TailnetClient
|
||||
var i GetAllTailnetClientsRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CoordinatorID,
|
||||
&i.AgentID,
|
||||
&i.UpdatedAt,
|
||||
&i.Node,
|
||||
&i.TailnetClient.ID,
|
||||
&i.TailnetClient.CoordinatorID,
|
||||
&i.TailnetClient.UpdatedAt,
|
||||
&i.TailnetClient.Node,
|
||||
pq.Array(&i.AgentIds),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -4293,9 +4332,13 @@ func (q *sqlQuerier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]Tail
|
|||
}
|
||||
|
||||
const getTailnetClientsForAgent = `-- name: GetTailnetClientsForAgent :many
|
||||
SELECT id, coordinator_id, agent_id, updated_at, node
|
||||
SELECT id, coordinator_id, updated_at, node
|
||||
FROM tailnet_clients
|
||||
WHERE agent_id = $1
|
||||
WHERE id IN (
|
||||
SELECT tailnet_client_subscriptions.client_id
|
||||
FROM tailnet_client_subscriptions
|
||||
WHERE tailnet_client_subscriptions.agent_id = $1
|
||||
)
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) {
|
||||
|
@ -4310,7 +4353,6 @@ func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid
|
|||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CoordinatorID,
|
||||
&i.AgentID,
|
||||
&i.UpdatedAt,
|
||||
&i.Node,
|
||||
); err != nil {
|
||||
|
@ -4369,47 +4411,67 @@ INSERT INTO
|
|||
tailnet_clients (
|
||||
id,
|
||||
coordinator_id,
|
||||
agent_id,
|
||||
node,
|
||||
updated_at
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, now() at time zone 'utc')
|
||||
($1, $2, $3, now() at time zone 'utc')
|
||||
ON CONFLICT (id, coordinator_id)
|
||||
DO UPDATE SET
|
||||
id = $1,
|
||||
coordinator_id = $2,
|
||||
agent_id = $3,
|
||||
node = $4,
|
||||
node = $3,
|
||||
updated_at = now() at time zone 'utc'
|
||||
RETURNING id, coordinator_id, agent_id, updated_at, node
|
||||
RETURNING id, coordinator_id, updated_at, node
|
||||
`
|
||||
|
||||
type UpsertTailnetClientParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
|
||||
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
|
||||
Node json.RawMessage `db:"node" json:"node"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) {
|
||||
row := q.db.QueryRowContext(ctx, upsertTailnetClient,
|
||||
arg.ID,
|
||||
arg.CoordinatorID,
|
||||
arg.AgentID,
|
||||
arg.Node,
|
||||
)
|
||||
row := q.db.QueryRowContext(ctx, upsertTailnetClient, arg.ID, arg.CoordinatorID, arg.Node)
|
||||
var i TailnetClient
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.CoordinatorID,
|
||||
&i.AgentID,
|
||||
&i.UpdatedAt,
|
||||
&i.Node,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const upsertTailnetClientSubscription = `-- name: UpsertTailnetClientSubscription :exec
|
||||
INSERT INTO
|
||||
tailnet_client_subscriptions (
|
||||
client_id,
|
||||
coordinator_id,
|
||||
agent_id,
|
||||
updated_at
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, now() at time zone 'utc')
|
||||
ON CONFLICT (client_id, coordinator_id, agent_id)
|
||||
DO UPDATE SET
|
||||
client_id = $1,
|
||||
coordinator_id = $2,
|
||||
agent_id = $3,
|
||||
updated_at = now() at time zone 'utc'
|
||||
`
|
||||
|
||||
type UpsertTailnetClientSubscriptionParams struct {
|
||||
ClientID uuid.UUID `db:"client_id" json:"client_id"`
|
||||
CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"`
|
||||
AgentID uuid.UUID `db:"agent_id" json:"agent_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpsertTailnetClientSubscription(ctx context.Context, arg UpsertTailnetClientSubscriptionParams) error {
|
||||
_, err := q.db.ExecContext(ctx, upsertTailnetClientSubscription, arg.ClientID, arg.CoordinatorID, arg.AgentID)
|
||||
return err
|
||||
}
|
||||
|
||||
const upsertTailnetCoordinator = `-- name: UpsertTailnetCoordinator :one
|
||||
INSERT INTO
|
||||
tailnet_coordinators (
|
||||
|
|
|
@ -3,21 +3,36 @@ INSERT INTO
|
|||
tailnet_clients (
|
||||
id,
|
||||
coordinator_id,
|
||||
agent_id,
|
||||
node,
|
||||
updated_at
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, now() at time zone 'utc')
|
||||
($1, $2, $3, now() at time zone 'utc')
|
||||
ON CONFLICT (id, coordinator_id)
|
||||
DO UPDATE SET
|
||||
id = $1,
|
||||
coordinator_id = $2,
|
||||
agent_id = $3,
|
||||
node = $4,
|
||||
node = $3,
|
||||
updated_at = now() at time zone 'utc'
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpsertTailnetClientSubscription :exec
|
||||
INSERT INTO
|
||||
tailnet_client_subscriptions (
|
||||
client_id,
|
||||
coordinator_id,
|
||||
agent_id,
|
||||
updated_at
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, now() at time zone 'utc')
|
||||
ON CONFLICT (client_id, coordinator_id, agent_id)
|
||||
DO UPDATE SET
|
||||
client_id = $1,
|
||||
coordinator_id = $2,
|
||||
agent_id = $3,
|
||||
updated_at = now() at time zone 'utc';
|
||||
|
||||
-- name: UpsertTailnetAgent :one
|
||||
INSERT INTO
|
||||
tailnet_agents (
|
||||
|
@ -43,6 +58,16 @@ FROM tailnet_clients
|
|||
WHERE id = $1 and coordinator_id = $2
|
||||
RETURNING id, coordinator_id;
|
||||
|
||||
-- name: DeleteTailnetClientSubscription :exec
|
||||
DELETE
|
||||
FROM tailnet_client_subscriptions
|
||||
WHERE client_id = $1 and agent_id = $2 and coordinator_id = $3;
|
||||
|
||||
-- name: DeleteAllTailnetClientSubscriptions :exec
|
||||
DELETE
|
||||
FROM tailnet_client_subscriptions
|
||||
WHERE client_id = $1 and coordinator_id = $2;
|
||||
|
||||
-- name: DeleteTailnetAgent :one
|
||||
DELETE
|
||||
FROM tailnet_agents
|
||||
|
@ -66,12 +91,17 @@ FROM tailnet_agents;
|
|||
-- name: GetTailnetClientsForAgent :many
|
||||
SELECT *
|
||||
FROM tailnet_clients
|
||||
WHERE agent_id = $1;
|
||||
WHERE id IN (
|
||||
SELECT tailnet_client_subscriptions.client_id
|
||||
FROM tailnet_client_subscriptions
|
||||
WHERE tailnet_client_subscriptions.agent_id = $1
|
||||
);
|
||||
|
||||
-- name: GetAllTailnetClients :many
|
||||
SELECT *
|
||||
SELECT sqlc.embed(tailnet_clients), array_agg(tailnet_client_subscriptions.agent_id)::uuid[] as agent_ids
|
||||
FROM tailnet_clients
|
||||
ORDER BY agent_id;
|
||||
LEFT JOIN tailnet_client_subscriptions
|
||||
ON tailnet_clients.id = tailnet_client_subscriptions.client_id;
|
||||
|
||||
-- name: UpsertTailnetCoordinator :one
|
||||
INSERT INTO
|
||||
|
|
|
@ -68,6 +68,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
|
|||
|
||||
id := uuid.New()
|
||||
sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id)
|
||||
|
||||
ctx, nc := websocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer nc.Close()
|
||||
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
package tailnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
)
|
||||
|
||||
// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to
|
||||
// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings
|
||||
// via its updates TrackedConn, which then writes them.
|
||||
type connIO struct {
|
||||
pCtx context.Context
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger slog.Logger
|
||||
decoder *json.Decoder
|
||||
updates *agpl.TrackedConn
|
||||
bindings chan<- binding
|
||||
}
|
||||
|
||||
func newConnIO(pCtx context.Context,
|
||||
logger slog.Logger,
|
||||
bindings chan<- binding,
|
||||
conn net.Conn,
|
||||
id uuid.UUID,
|
||||
name string,
|
||||
kind agpl.QueueKind,
|
||||
) *connIO {
|
||||
ctx, cancel := context.WithCancel(pCtx)
|
||||
c := &connIO{
|
||||
pCtx: pCtx,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
decoder: json.NewDecoder(conn),
|
||||
updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, name, 0, kind),
|
||||
bindings: bindings,
|
||||
}
|
||||
go c.recvLoop()
|
||||
go c.updates.SendUpdates()
|
||||
logger.Info(ctx, "serving connection")
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *connIO) recvLoop() {
|
||||
defer func() {
|
||||
// withdraw bindings when we exit. We need to use the parent context here, since our own context might be
|
||||
// canceled, but we still need to withdraw bindings.
|
||||
b := binding{
|
||||
bKey: bKey{
|
||||
id: c.UniqueID(),
|
||||
kind: c.Kind(),
|
||||
},
|
||||
}
|
||||
if err := sendCtx(c.pCtx, c.bindings, b); err != nil {
|
||||
c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err))
|
||||
}
|
||||
}()
|
||||
defer c.cancel()
|
||||
for {
|
||||
var node agpl.Node
|
||||
err := c.decoder.Decode(&node)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, io.EOF) ||
|
||||
xerrors.Is(err, io.ErrClosedPipe) ||
|
||||
xerrors.Is(err, context.Canceled) ||
|
||||
xerrors.Is(err, context.DeadlineExceeded) ||
|
||||
websocket.CloseStatus(err) > 0 {
|
||||
c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err))
|
||||
} else {
|
||||
c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
c.logger.Debug(c.ctx, "got node update", slog.F("node", node))
|
||||
b := binding{
|
||||
bKey: bKey{
|
||||
id: c.UniqueID(),
|
||||
kind: c.Kind(),
|
||||
},
|
||||
node: &node,
|
||||
}
|
||||
if err := sendCtx(c.ctx, c.bindings, b); err != nil {
|
||||
c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connIO) UniqueID() uuid.UUID {
|
||||
return c.updates.UniqueID()
|
||||
}
|
||||
|
||||
func (c *connIO) Kind() agpl.QueueKind {
|
||||
return c.updates.Kind()
|
||||
}
|
||||
|
||||
func (c *connIO) Enqueue(n []*agpl.Node) error {
|
||||
return c.updates.Enqueue(n)
|
||||
}
|
||||
|
||||
func (c *connIO) Name() string {
|
||||
return c.updates.Name()
|
||||
}
|
||||
|
||||
func (c *connIO) Stats() (start int64, lastWrite int64) {
|
||||
return c.updates.Stats()
|
||||
}
|
||||
|
||||
func (c *connIO) Overwrites() int64 {
|
||||
return c.updates.Overwrites()
|
||||
}
|
||||
|
||||
// CoordinatorClose is used by the coordinator when closing a Queue. It
|
||||
// should skip removing itself from the coordinator.
|
||||
func (c *connIO) CoordinatorClose() error {
|
||||
c.cancel()
|
||||
return c.updates.CoordinatorClose()
|
||||
}
|
||||
|
||||
func (c *connIO) Done() <-chan struct{} {
|
||||
return c.ctx.Done()
|
||||
}
|
||||
|
||||
func (c *connIO) Close() error {
|
||||
c.cancel()
|
||||
return c.updates.Close()
|
||||
}
|
|
@ -58,7 +58,7 @@ func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn {
|
|||
AgentIsLegacyFunc: c.agentIsLegacy,
|
||||
OnSubscribe: c.clientSubscribeToAgent,
|
||||
OnNodeUpdate: c.clientNodeUpdate,
|
||||
OnRemove: c.clientDisconnected,
|
||||
OnRemove: func(enq agpl.Queue) { c.clientDisconnected(enq.UniqueID()) },
|
||||
}).Init()
|
||||
c.addClient(id, m)
|
||||
return m
|
||||
|
@ -157,7 +157,7 @@ func (c *haCoordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error
|
|||
defer cancel()
|
||||
logger := c.clientLogger(id, agentID)
|
||||
|
||||
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0)
|
||||
tc := agpl.NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, agpl.QueueKindClient)
|
||||
defer tc.Close()
|
||||
|
||||
c.addClient(id, tc)
|
||||
|
@ -300,7 +300,7 @@ func (c *haCoordinator) ServeAgent(conn net.Conn, id uuid.UUID, name string) err
|
|||
}
|
||||
// This uniquely identifies a connection that belongs to this goroutine.
|
||||
unique := uuid.New()
|
||||
tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites)
|
||||
tc := agpl.NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, agpl.QueueKindAgent)
|
||||
|
||||
// Publish all nodes on this instance that want to connect to this agent.
|
||||
nodes := c.nodesSubscribedToAgent(id)
|
||||
|
|
|
@ -0,0 +1,354 @@
|
|||
package tailnet_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/enterprise/tailnet"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestPGCoordinator_MultiAgent tests a single coordinator with a MultiAgent
|
||||
// connecting to one agent.
|
||||
//
|
||||
// +--------+
|
||||
// agent1 ---> | coord1 | <--- client
|
||||
// +--------+
|
||||
func TestPGCoordinator_MultiAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
||||
agent1 := newTestAgent(t, coord1, "agent1")
|
||||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord1.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_UnsubscribeRace tests a single coordinator with
|
||||
// a MultiAgent connecting to one agent. It tries to race a call to Unsubscribe
|
||||
// with the MultiAgent closing.
|
||||
//
|
||||
// +--------+
|
||||
// agent1 ---> | coord1 | <--- client
|
||||
// +--------+
|
||||
func TestPGCoordinator_MultiAgent_UnsubscribeRace(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
||||
agent1 := newTestAgent(t, coord1, "agent1")
|
||||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord1.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
require.NoError(t, ma1.UnsubscribeAgent(agent1.id))
|
||||
require.NoError(t, ma1.Close())
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_Unsubscribe tests a single coordinator with a
|
||||
// MultiAgent connecting to one agent. It unsubscribes before closing, and
|
||||
// ensures node updates are no longer propagated.
|
||||
//
|
||||
// +--------+
|
||||
// agent1 ---> | coord1 | <--- client
|
||||
// +--------+
|
||||
func TestPGCoordinator_MultiAgent_Unsubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
|
||||
agent1 := newTestAgent(t, coord1, "agent1")
|
||||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord1.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
|
||||
require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3}))
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
require.NoError(t, ma1.UnsubscribeAgent(agent1.id))
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
|
||||
defer cancel()
|
||||
require.NoError(t, ma1.UpdateSelf(&agpl.Node{PreferredDERP: 9}))
|
||||
assertNeverHasDERPs(ctx, t, agent1, 9)
|
||||
}()
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(ctx, testutil.IntervalSlow*3)
|
||||
defer cancel()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 8})
|
||||
assertMultiAgentNeverHasDERPs(ctx, t, ma1, 8)
|
||||
}()
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_MultiCoordinator tests two coordinators with a
|
||||
// MultiAgent connecting to an agent on a separate coordinator.
|
||||
//
|
||||
// +--------+
|
||||
// agent1 ---> | coord1 |
|
||||
// +--------+
|
||||
// +--------+
|
||||
// | coord2 | <--- client
|
||||
// +--------+
|
||||
func TestPGCoordinator_MultiAgent_MultiCoordinator(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord2.Close()
|
||||
|
||||
agent1 := newTestAgent(t, coord1, "agent1")
|
||||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord2.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe tests two
|
||||
// coordinators with a MultiAgent connecting to an agent on a separate
|
||||
// coordinator. The MultiAgent updates its own node before subscribing.
|
||||
//
|
||||
// +--------+
|
||||
// agent1 ---> | coord1 |
|
||||
// +--------+
|
||||
// +--------+
|
||||
// | coord2 | <--- client
|
||||
// +--------+
|
||||
func TestPGCoordinator_MultiAgent_MultiCoordinator_UpdateBeforeSubscribe(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord2.Close()
|
||||
|
||||
agent1 := newTestAgent(t, coord1, "agent1")
|
||||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord2.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
require.NoError(t, agent1.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent_TwoAgents tests three coordinators with a
|
||||
// MultiAgent connecting to two agents on separate coordinators.
|
||||
//
|
||||
// +--------+
|
||||
// agent1 ---> | coord1 |
|
||||
// +--------+
|
||||
// +--------+
|
||||
// agent2 ---> | coord2 |
|
||||
// +--------+
|
||||
// +--------+
|
||||
// | coord3 | <--- client
|
||||
// +--------+
|
||||
func TestPGCoordinator_MultiAgent_TwoAgents(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
|
||||
defer cancel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
coord1, err := tailnet.NewPGCoord(ctx, logger.Named("coord1"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord1.Close()
|
||||
coord2, err := tailnet.NewPGCoord(ctx, logger.Named("coord2"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord2.Close()
|
||||
coord3, err := tailnet.NewPGCoord(ctx, logger.Named("coord3"), ps, store)
|
||||
require.NoError(t, err)
|
||||
defer coord3.Close()
|
||||
|
||||
agent1 := newTestAgent(t, coord1, "agent1")
|
||||
defer agent1.close()
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 5})
|
||||
|
||||
agent2 := newTestAgent(t, coord2, "agent2")
|
||||
defer agent1.close()
|
||||
agent2.sendNode(&agpl.Node{PreferredDERP: 6})
|
||||
|
||||
id := uuid.New()
|
||||
ma1 := coord3.ServeMultiAgent(id)
|
||||
defer ma1.Close()
|
||||
|
||||
err = ma1.SubscribeAgent(agent1.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 5)
|
||||
|
||||
agent1.sendNode(&agpl.Node{PreferredDERP: 1})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 1)
|
||||
|
||||
err = ma1.SubscribeAgent(agent2.id)
|
||||
require.NoError(t, err)
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 6)
|
||||
|
||||
agent2.sendNode(&agpl.Node{PreferredDERP: 2})
|
||||
assertMultiAgentEventuallyHasDERPs(ctx, t, ma1, 2)
|
||||
|
||||
err = ma1.UpdateSelf(&agpl.Node{PreferredDERP: 3})
|
||||
require.NoError(t, err)
|
||||
assertEventuallyHasDERPs(ctx, t, agent1, 3)
|
||||
assertEventuallyHasDERPs(ctx, t, agent2, 3)
|
||||
|
||||
require.NoError(t, ma1.Close())
|
||||
require.NoError(t, agent1.close())
|
||||
require.NoError(t, agent2.close())
|
||||
|
||||
assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id)
|
||||
assertEventuallyNoAgents(ctx, t, store, agent1.id)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -438,7 +438,7 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
|
|||
assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id)
|
||||
}
|
||||
|
||||
// TestPGCoordinator_MultiAgent tests when a single agent connects to multiple coordinators.
|
||||
// TestPGCoordinator_MultiCoordinatorAgent tests when a single agent connects to multiple coordinators.
|
||||
// We use two agent connections, but they share the same AgentID. This could happen due to a reconnection,
|
||||
// or an infrastructure problem where an old workspace is not fully cleaned up before a new one started.
|
||||
//
|
||||
|
@ -451,7 +451,7 @@ func TestPGCoordinatorDual_Mainline(t *testing.T) {
|
|||
// +---------+
|
||||
// | coord3 | <--- client
|
||||
// +---------+
|
||||
func TestPGCoordinator_MultiAgent(t *testing.T) {
|
||||
func TestPGCoordinator_MultiCoordinatorAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test only with postgres")
|
||||
|
@ -693,8 +693,79 @@ func assertEventuallyHasDERPs(ctx context.Context, t *testing.T, c *testConn, ex
|
|||
t.Logf("expected DERP %d to be in %v", e, derps)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertNeverHasDERPs(ctx context.Context, t *testing.T, c *testConn, expected ...int) {
|
||||
t.Helper()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case nodes := <-c.nodeChan:
|
||||
derps := make([]int, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
derps = append(derps, n.PreferredDERP)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if slices.Contains(derps, e) {
|
||||
t.Fatalf("expected not to get DERP %d, but received it", e)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertMultiAgentEventuallyHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
|
||||
t.Helper()
|
||||
for {
|
||||
nodes, ok := ma.NextUpdate(ctx)
|
||||
require.True(t, ok)
|
||||
if len(nodes) != len(expected) {
|
||||
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
|
||||
continue
|
||||
}
|
||||
|
||||
derps := make([]int, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
derps = append(derps, n.PreferredDERP)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.Contains(derps, e) {
|
||||
t.Logf("expected DERP %d to be in %v", e, derps)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertMultiAgentNeverHasDERPs(ctx context.Context, t *testing.T, ma agpl.MultiAgentConn, expected ...int) {
|
||||
t.Helper()
|
||||
for {
|
||||
nodes, ok := ma.NextUpdate(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if len(nodes) != len(expected) {
|
||||
t.Logf("expected %d, got %d nodes", len(expected), len(nodes))
|
||||
continue
|
||||
}
|
||||
|
||||
derps := make([]int, 0, len(nodes))
|
||||
for _, n := range nodes {
|
||||
derps = append(derps, n.PreferredDERP)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.Contains(derps, e) {
|
||||
t.Logf("expected DERP %d to be in %v", e, derps)
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -712,6 +783,7 @@ func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.
|
|||
}
|
||||
|
||||
func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) {
|
||||
t.Helper()
|
||||
assert.Eventually(t, func() bool {
|
||||
clients, err := store.GetTailnetClientsForAgent(ctx, agentID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
|
|
|
@ -22,7 +22,7 @@ import (
|
|||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/workspaceapps"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
agpl "github.com/coder/coder/v2/tailnet"
|
||||
)
|
||||
|
||||
// Client is a HTTP client for a subset of Coder API routes that external
|
||||
|
@ -422,14 +422,14 @@ const (
|
|||
type CoordinateMessage struct {
|
||||
Type CoordinateMessageType `json:"type"`
|
||||
AgentID uuid.UUID `json:"agent_id"`
|
||||
Node *tailnet.Node `json:"node"`
|
||||
Node *agpl.Node `json:"node"`
|
||||
}
|
||||
|
||||
type CoordinateNodes struct {
|
||||
Nodes []*tailnet.Node
|
||||
Nodes []*agpl.Node
|
||||
}
|
||||
|
||||
func (c *Client) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, error) {
|
||||
func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
coordinateURL, err := c.SDKClient.URL.Parse("/api/v2/workspaceproxies/me/coordinate")
|
||||
|
@ -463,13 +463,13 @@ func (c *Client) DialCoordinator(ctx context.Context) (tailnet.MultiAgentConn, e
|
|||
legacyAgentCache: map[uuid.UUID]bool{},
|
||||
}
|
||||
|
||||
ma := (&tailnet.MultiAgent{
|
||||
ma := (&agpl.MultiAgent{
|
||||
ID: uuid.New(),
|
||||
AgentIsLegacyFunc: rma.AgentIsLegacy,
|
||||
OnSubscribe: rma.OnSubscribe,
|
||||
OnUnsubscribe: rma.OnUnsubscribe,
|
||||
OnNodeUpdate: rma.OnNodeUpdate,
|
||||
OnRemove: func(uuid.UUID) { conn.Close(websocket.StatusGoingAway, "closed") },
|
||||
OnRemove: func(agpl.Queue) { conn.Close(websocket.StatusGoingAway, "closed") },
|
||||
}).Init()
|
||||
|
||||
go func() {
|
||||
|
@ -515,7 +515,7 @@ func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error {
|
|||
|
||||
// Set a deadline so that hung connections don't put back pressure on the system.
|
||||
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
|
||||
err = a.nc.SetWriteDeadline(time.Now().Add(tailnet.WriteTimeout))
|
||||
err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
|
@ -537,21 +537,21 @@ func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *tailnet.Node) error {
|
||||
func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error {
|
||||
return a.writeJSON(CoordinateMessage{
|
||||
Type: CoordinateMessageTypeNodeUpdate,
|
||||
Node: node,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnSubscribe(_ tailnet.Queue, agentID uuid.UUID) (*tailnet.Node, error) {
|
||||
func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) {
|
||||
return nil, a.writeJSON(CoordinateMessage{
|
||||
Type: CoordinateMessageTypeSubscribe,
|
||||
AgentID: agentID,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ tailnet.Queue, agentID uuid.UUID) error {
|
||||
func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error {
|
||||
return a.writeJSON(CoordinateMessage{
|
||||
Type: CoordinateMessageTypeUnsubscribe,
|
||||
AgentID: agentID,
|
||||
|
|
|
@ -146,7 +146,7 @@ func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn {
|
|||
OnSubscribe: c.core.clientSubscribeToAgent,
|
||||
OnUnsubscribe: c.core.clientUnsubscribeFromAgent,
|
||||
OnNodeUpdate: c.core.clientNodeUpdate,
|
||||
OnRemove: c.core.clientDisconnected,
|
||||
OnRemove: func(enq Queue) { c.core.clientDisconnected(enq.UniqueID()) },
|
||||
}).Init()
|
||||
c.core.addClient(id, m)
|
||||
return m
|
||||
|
@ -191,8 +191,16 @@ type core struct {
|
|||
legacyAgents map[uuid.UUID]struct{}
|
||||
}
|
||||
|
||||
type QueueKind int
|
||||
|
||||
const (
|
||||
QueueKindClient QueueKind = 1 + iota
|
||||
QueueKindAgent
|
||||
)
|
||||
|
||||
type Queue interface {
|
||||
UniqueID() uuid.UUID
|
||||
Kind() QueueKind
|
||||
Enqueue(n []*Node) error
|
||||
Name() string
|
||||
Stats() (start, lastWrite int64)
|
||||
|
@ -200,6 +208,7 @@ type Queue interface {
|
|||
// CoordinatorClose is used by the coordinator when closing a Queue. It
|
||||
// should skip removing itself from the coordinator.
|
||||
CoordinatorClose() error
|
||||
Done() <-chan struct{}
|
||||
Close() error
|
||||
}
|
||||
|
||||
|
@ -264,7 +273,7 @@ func (c *coordinator) ServeClient(conn net.Conn, id, agentID uuid.UUID) error {
|
|||
logger := c.core.clientLogger(id, agentID)
|
||||
logger.Debug(ctx, "coordinating client")
|
||||
|
||||
tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0)
|
||||
tc := NewTrackedConn(ctx, cancel, conn, id, logger, id.String(), 0, QueueKindClient)
|
||||
defer tc.Close()
|
||||
|
||||
c.core.addClient(id, tc)
|
||||
|
@ -509,7 +518,7 @@ func (c *core) initAndTrackAgent(ctx context.Context, cancel func(), conn net.Co
|
|||
overwrites = oldAgentSocket.Overwrites() + 1
|
||||
_ = oldAgentSocket.Close()
|
||||
}
|
||||
tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites)
|
||||
tc := NewTrackedConn(ctx, cancel, conn, unique, logger, name, overwrites, QueueKindAgent)
|
||||
c.agentNameCache.Add(id, name)
|
||||
|
||||
sockets, ok := c.agentToConnectionSockets[id]
|
||||
|
|
|
@ -29,9 +29,12 @@ type MultiAgent struct {
|
|||
OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error)
|
||||
OnUnsubscribe func(enq Queue, agent uuid.UUID) error
|
||||
OnNodeUpdate func(id uuid.UUID, node *Node) error
|
||||
OnRemove func(id uuid.UUID)
|
||||
OnRemove func(enq Queue)
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel func()
|
||||
closed bool
|
||||
|
||||
updates chan []*Node
|
||||
closeOnce sync.Once
|
||||
start int64
|
||||
|
@ -44,9 +47,14 @@ type MultiAgent struct {
|
|||
func (m *MultiAgent) Init() *MultiAgent {
|
||||
m.updates = make(chan []*Node, 128)
|
||||
m.start = time.Now().Unix()
|
||||
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
|
||||
return m
|
||||
}
|
||||
|
||||
func (*MultiAgent) Kind() QueueKind {
|
||||
return QueueKindClient
|
||||
}
|
||||
|
||||
func (m *MultiAgent) UniqueID() uuid.UUID {
|
||||
return m.ID
|
||||
}
|
||||
|
@ -156,8 +164,13 @@ func (m *MultiAgent) CoordinatorClose() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Done() <-chan struct{} {
|
||||
return m.ctx.Done()
|
||||
}
|
||||
|
||||
func (m *MultiAgent) Close() error {
|
||||
_ = m.CoordinatorClose()
|
||||
m.closeOnce.Do(func() { m.OnRemove(m.ID) })
|
||||
m.ctxCancel()
|
||||
m.closeOnce.Do(func() { m.OnRemove(m) })
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ const WriteTimeout = time.Second * 5
|
|||
type TrackedConn struct {
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
kind QueueKind
|
||||
conn net.Conn
|
||||
updates chan []*Node
|
||||
logger slog.Logger
|
||||
|
@ -35,7 +36,14 @@ type TrackedConn struct {
|
|||
overwrites int64
|
||||
}
|
||||
|
||||
func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.UUID, logger slog.Logger, name string, overwrites int64) *TrackedConn {
|
||||
func NewTrackedConn(ctx context.Context, cancel func(),
|
||||
conn net.Conn,
|
||||
id uuid.UUID,
|
||||
logger slog.Logger,
|
||||
name string,
|
||||
overwrites int64,
|
||||
kind QueueKind,
|
||||
) *TrackedConn {
|
||||
// buffer updates so they don't block, since we hold the
|
||||
// coordinator mutex while queuing. Node updates don't
|
||||
// come quickly, so 512 should be plenty for all but
|
||||
|
@ -53,6 +61,7 @@ func NewTrackedConn(ctx context.Context, cancel func(), conn net.Conn, id uuid.U
|
|||
lastWrite: now,
|
||||
name: name,
|
||||
overwrites: overwrites,
|
||||
kind: kind,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,6 +79,10 @@ func (t *TrackedConn) UniqueID() uuid.UUID {
|
|||
return t.id
|
||||
}
|
||||
|
||||
func (t *TrackedConn) Kind() QueueKind {
|
||||
return t.kind
|
||||
}
|
||||
|
||||
func (t *TrackedConn) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
@ -86,6 +99,10 @@ func (t *TrackedConn) CoordinatorClose() error {
|
|||
return t.Close()
|
||||
}
|
||||
|
||||
func (t *TrackedConn) Done() <-chan struct{} {
|
||||
return t.ctx.Done()
|
||||
}
|
||||
|
||||
// Close the connection and cancel the context for reading node updates from the queue
|
||||
func (t *TrackedConn) Close() error {
|
||||
t.cancel()
|
||||
|
|
Loading…
Reference in New Issue