diff --git a/coderd/coderd.go b/coderd/coderd.go index 119953dfaa..04fa372879 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -125,6 +125,8 @@ type Options struct { ExternalAuthConfigs []*externalauth.Config RealIPConfig *httpmw.RealIPConfig TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error + // RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license. + RefreshEntitlements func(ctx context.Context) error // TLSCertificates is used to mesh DERP servers securely. TLSCertificates []tls.Certificate TailnetCoordinator tailnet.Coordinator diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index f1a53815bb..85d92a5ef6 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -107,6 +107,7 @@ type Options struct { TLSCertificates []tls.Certificate ExternalAuthConfigs []*externalauth.Config TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error + RefreshEntitlements func(ctx context.Context) error TemplateScheduleStore schedule.TemplateScheduleStore Coordinator tailnet.Coordinator @@ -434,6 +435,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can AccessControlStore: accessControlStore, TLSCertificates: options.TLSCertificates, TrialGenerator: options.TrialGenerator, + RefreshEntitlements: options.RefreshEntitlements, TailnetCoordinator: options.Coordinator, BaseDERPMap: derpMap, DERPMapUpdateFrequency: 150 * time.Millisecond, diff --git a/coderd/users.go b/coderd/users.go index be4e46ea7f..cbc9a75059 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -191,6 +191,16 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { return } + if api.RefreshEntitlements != nil { + err = api.RefreshEntitlements(ctx) + if err != nil { + api.Logger.Error(ctx, "failed to refresh entitlements after generating trial license") + return + } + } else { + api.Logger.Debug(ctx, "entitlements will not be refreshed") + } + telemetryUser := telemetry.ConvertUser(user) // Send the initial users email address! telemetryUser.Email = &user.Email diff --git a/coderd/users_test.go b/coderd/users_test.go index 0d7f5a7bb2..1c962a50b9 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -75,10 +75,16 @@ func TestFirstUser(t *testing.T) { t.Run("Trial", func(t *testing.T) { t.Parallel() - called := make(chan struct{}) + trialGenerated := make(chan struct{}) + entitlementsRefreshed := make(chan struct{}) + client := coderdtest.New(t, &coderdtest.Options{ TrialGenerator: func(context.Context, codersdk.LicensorTrialRequest) error { - close(called) + close(trialGenerated) + return nil + }, + RefreshEntitlements: func(context.Context) error { + close(entitlementsRefreshed) return nil }, }) @@ -94,7 +100,9 @@ func TestFirstUser(t *testing.T) { } _, err := client.CreateFirstUser(ctx, req) require.NoError(t, err) - <-called + + _ = testutil.RequireRecvCtx(ctx, t, trialGenerated) + _ = testutil.RequireRecvCtx(ctx, t, entitlementsRefreshed) }) } diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 85903d023b..8c22ea0f0b 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -203,6 +203,9 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { }) }) }) + api.AGPL.RefreshEntitlements = func(ctx context.Context) error { + return api.refreshEntitlements(ctx) + } api.AGPL.APIHandler.Group(func(r chi.Router) { r.Get("/entitlements", api.serveEntitlements) diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index b7c7b5af6e..9c32268b39 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -197,20 +197,10 @@ func (api *API) postRefreshEntitlements(rw http.ResponseWriter, r *http.Request) return } - err = api.updateEntitlements(ctx) + err = api.refreshEntitlements(ctx) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to update entitlements", - Detail: err.Error(), - }) - return - } - - err = api.Pubsub.Publish(PubsubEventLicenses, []byte("refresh")) - if err != nil { - api.Logger.Error(context.Background(), "failed to publish forced entitlement update", slog.Error(err)) - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to publish forced entitlement update. Other replicas might not be updated.", + Message: "Failed to refresh entitlements", Detail: err.Error(), }) return @@ -221,6 +211,21 @@ func (api *API) postRefreshEntitlements(rw http.ResponseWriter, r *http.Request) }) } +func (api *API) refreshEntitlements(ctx context.Context) error { + api.Logger.Info(ctx, "refresh entitlements now") + + err := api.updateEntitlements(ctx) + if err != nil { + return xerrors.Errorf("failed to update entitlements: %w", err) + } + err = api.Pubsub.Publish(PubsubEventLicenses, []byte("refresh")) + if err != nil { + api.Logger.Error(ctx, "failed to publish forced entitlement update", slog.Error(err)) + return xerrors.Errorf("failed to publish forced entitlement update, other replicas might not be updated: %w", err) + } + return nil +} + // @Summary Get licenses // @ID get-licenses // @Security CoderSessionToken diff --git a/enterprise/coderd/users_test.go b/enterprise/coderd/users_test.go index 05bfa80e87..ede99551ef 100644 --- a/enterprise/coderd/users_test.go +++ b/enterprise/coderd/users_test.go @@ -1,6 +1,7 @@ package coderd_test import ( + "context" "net/http" "testing" "time" @@ -218,3 +219,21 @@ func TestUserQuietHours(t *testing.T) { require.Contains(t, sdkErr.Message, "cannot set custom quiet hours schedule") }) } + +func TestCreateFirstUser_Entitlements_Trial(t *testing.T) { + t.Parallel() + + adminClient, _ := coderdenttest.New(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Trial: true, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + //nolint:gocritic // we need the first user so admin + entitlements, err := adminClient.Entitlements(ctx) + require.NoError(t, err) + require.True(t, entitlements.Trial, "Trial license should be immediately active.") +}