fix: correctly reject quota-violating builds (#9233)

Due to a logical error in CommitQuota, all workspace Stop->Start operations
were being accepted regardless of the quota limit. The issue only appeared
after #9201, so it was a minor regression in main for about 3 days. This PR
adds a test to make sure this kind of bug doesn't recur.
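
For context, here is a minimal, self-contained sketch of the decision CommitQuota is meant to make. It is not the coderd implementation: the function name, parameters, and the exact meaning of the "consumed" value are simplified assumptions, and the real check is the committer code in the diff below. Reading that diff, the old code compared the previous build's cost against the next build's cost as stored in the database, which is not yet populated at commit time, so a Stop->Start always looked like a net decrease and was accepted.

```go
package main

import "fmt"

// allowBuild reports whether a build requesting requestedCost should be
// accepted. consumedByOthers stands in for the quota already consumed by the
// user's other workspaces; the names are illustrative, not coderd's fields.
func allowBuild(requestedCost, prevBuildCost int32, consumedByOthers, budget int64) bool {
	// A build that does not increase overall consumption is allowed even if
	// the user is over budget (e.g. stopping a workspace).
	netIncrease := requestedCost >= prevBuildCost
	newConsumed := int64(requestedCost) + consumedByOthers
	return !(newConsumed > budget && netIncrease)
}

func main() {
	// Budget of 4, with another workspace already consuming 2.
	fmt.Println(allowBuild(2, 1, 2, 4)) // true: restarting a stopped workspace still fits
	fmt.Println(allowBuild(2, 0, 4, 4)) // false: an additional workspace would exceed the budget
}
```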

To make the new test possible, we give the echo provisioner the ability
to simulate responses to specific transitions.
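
As a usage sketch, the per-transition maps added to echo.Responses let a test hand the echo provisioner different mocked results for STOP builds than for everything else. The helper below is illustrative only (the package, helper name, and costs are made up; the import paths are assumed from the repo layout), but it mirrors the pattern the new StartStop test uses:

```go
package example

import (
	"github.com/coder/coder/v2/provisioner/echo"
	"github.com/coder/coder/v2/provisionersdk/proto"
)

// stopCostsLess builds echo responses where STOP builds report a cheaper
// resource than other transitions. The generic ProvisionPlan/ProvisionApply
// fields remain the fallback for transitions without a map entry.
func stopCostsLess() *echo.Responses {
	resp := func(cost int32) []*proto.Provision_Response {
		return []*proto.Provision_Response{{
			Type: &proto.Provision_Response_Complete{
				Complete: &proto.Provision_Complete{
					Resources: []*proto.Resource{{
						Name:      "example",
						Type:      "aws_instance",
						DailyCost: cost,
					}},
				},
			},
		}}
	}
	return &echo.Responses{
		Parse:          echo.ParseComplete,
		ProvisionPlan:  resp(2),
		ProvisionApply: resp(2),
		// New in this PR: per-transition responses take priority over the
		// generic ones above.
		ProvisionPlanMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
			proto.WorkspaceTransition_STOP: resp(1),
		},
		ProvisionApplyMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
			proto.WorkspaceTransition_STOP: resp(1),
		},
	}
}
```

A test would pass such a value to coderdtest.CreateTemplateVersion (or echo.Tar directly), as the new quota and echo tests in this diff do.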
Authored by Ammar Bandukwala on 2023-08-21 21:55:39 -05:00, committed by GitHub
parent 69ec8d774b
commit 545a256b57
7 changed files with 293 additions and 55 deletions

View File

@ -498,7 +498,10 @@ func (api *API) updateEntitlements(ctx context.Context) error {
if initial, changed, enabled := featureChanged(codersdk.FeatureTemplateRBAC); shouldUpdate(initial, changed, enabled) {
if enabled {
- committer := committer{Database: api.Database}
+ committer := committer{
+ Log: api.Logger.Named("quota_committer"),
+ Database: api.Database,
+ }
ptr := proto.QuotaCommitter(&committer)
api.AGPL.QuotaCommitter.Store(&ptr)
} else {

View File

@ -3,10 +3,12 @@ package coderd
import (
"context"
"database/sql"
"errors"
"net/http"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/httpapi"
@ -17,6 +19,7 @@ import (
)
type committer struct {
Log slog.Logger
Database database.Store
}
@ -28,12 +31,12 @@ func (c *committer) CommitQuota(
return nil, err
}
- build, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID)
+ nextBuild, err := c.Database.GetWorkspaceBuildByJobID(ctx, jobID)
if err != nil {
return nil, err
}
- workspace, err := c.Database.GetWorkspaceByID(ctx, build.WorkspaceID)
+ workspace, err := c.Database.GetWorkspaceByID(ctx, nextBuild.WorkspaceID)
if err != nil {
return nil, err
}
@ -58,25 +61,35 @@ func (c *committer) CommitQuota(
// If the new build will reduce overall quota consumption, then we
// allow it even if the user is over quota.
netIncrease := true
- previousBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
+ prevBuild, err := s.GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx, database.GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams{
WorkspaceID: workspace.ID,
- BuildNumber: build.BuildNumber - 1,
+ BuildNumber: nextBuild.BuildNumber - 1,
})
if err == nil {
- if build.DailyCost < previousBuild.DailyCost {
- netIncrease = false
- }
- } else if !xerrors.Is(err, sql.ErrNoRows) {
+ netIncrease = request.DailyCost >= prevBuild.DailyCost
+ c.Log.Debug(
+ ctx, "previous build cost",
+ slog.F("prev_cost", prevBuild.DailyCost),
+ slog.F("next_cost", request.DailyCost),
+ slog.F("net_increase", netIncrease),
+ )
+ } else if !errors.Is(err, sql.ErrNoRows) {
return err
}
newConsumed := int64(request.DailyCost) + consumed
if newConsumed > budget && netIncrease {
+ c.Log.Debug(
+ ctx, "over quota, rejecting",
+ slog.F("prev_consumed", consumed),
+ slog.F("next_consumed", newConsumed),
+ slog.F("budget", budget),
+ )
return nil
}
err = s.UpdateWorkspaceBuildCostByID(ctx, database.UpdateWorkspaceBuildCostByIDParams{
- ID: build.ID,
+ ID: nextBuild.ID,
DailyCost: request.DailyCost,
})
if err != nil {

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
@ -31,12 +32,13 @@ func verifyQuota(ctx context.Context, t *testing.T, client *codersdk.Client, con
}
func TestWorkspaceQuota(t *testing.T) {
- // TODO: refactor for new impl
t.Parallel()
t.Run("BlocksBuild", func(t *testing.T) {
// This first test verifies the behavior of creating and deleting workspaces.
// It also tests multi-group quota stacking and the everyone group.
t.Run("CreateDelete", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
max := 1
@ -49,8 +51,6 @@ func TestWorkspaceQuota(t *testing.T) {
},
})
coderdtest.NewProvisionerDaemon(t, api.AGPL)
- coderdtest.NewProvisionerDaemon(t, api.AGPL)
- coderdtest.NewProvisionerDaemon(t, api.AGPL)
verifyQuota(ctx, t, client, 0, 0)
@ -157,4 +157,104 @@ func TestWorkspaceQuota(t *testing.T) {
verifyQuota(ctx, t, client, 4, 4)
require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
})
t.Run("StartStop", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
max := 1
client, _, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
UserWorkspaceQuota: max,
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureTemplateRBAC: 1,
},
},
})
coderdtest.NewProvisionerDaemon(t, api.AGPL)
verifyQuota(ctx, t, client, 0, 0)
// Patch the 'Everyone' group to verify its quota allowance is being accounted for.
_, err := client.PatchGroup(ctx, user.OrganizationID, codersdk.PatchGroupRequest{
QuotaAllowance: ptr.Ref(4),
})
require.NoError(t, err)
verifyQuota(ctx, t, client, 0, 4)
stopResp := []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
DailyCost: 1,
}},
},
},
}}
startResp := []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
DailyCost: 2,
}},
},
},
}}
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlanMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
proto.WorkspaceTransition_START: startResp,
proto.WorkspaceTransition_STOP: stopResp,
},
ProvisionApplyMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
proto.WorkspaceTransition_START: startResp,
proto.WorkspaceTransition_STOP: stopResp,
},
})
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
// Spin up two workspaces.
var wg sync.WaitGroup
var workspaces []codersdk.Workspace
for i := 0; i < 2; i++ {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
workspaces = append(workspaces, workspace)
build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
assert.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
}
wg.Wait()
verifyQuota(ctx, t, client, 4, 4)
// Next one must fail
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
build := coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
require.Contains(t, build.Job.Error, "quota")
// Consumed shouldn't bump
verifyQuota(ctx, t, client, 4, 4)
require.Equal(t, codersdk.WorkspaceStatusFailed, build.Status)
build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStop)
build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID)
// Quota goes down one
verifyQuota(ctx, t, client, 3, 4)
require.Equal(t, codersdk.WorkspaceStatusStopped, build.Status)
build = coderdtest.CreateWorkspaceBuild(t, client, workspaces[0], database.WorkspaceTransitionStart)
build = coderdtest.AwaitWorkspaceBuildJob(t, client, build.ID)
// Quota goes back up
verifyQuota(ctx, t, client, 4, 4)
require.Equal(t, codersdk.WorkspaceStatusRunning, build.Status)
})
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"net"
"time"
@ -26,6 +27,9 @@ func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentC
var msg wsproxysdk.CoordinateMessage
err := decoder.Decode(&msg)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
return xerrors.Errorf("read json: %w", err)
}

View File

@ -127,19 +127,35 @@ func (e *echo) Provision(stream proto.DRPCProvisioner_ProvisionStream) error {
return nil
}
- for index := 0; ; index++ {
+ outer:
+ for i := 0; ; i++ {
var extension string
if msg.GetPlan() != nil {
extension = ".plan.protobuf"
} else {
extension = ".apply.protobuf"
}
- path := filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, index))
- _, err := e.filesystem.Stat(path)
- if err != nil {
- if index == 0 {
- // Error if nothing is around to enable failed states.
- return xerrors.New("no state")
+ var (
+ path string
+ pathIndex int
+ )
+ // Try more specific path first, then fallback to generic.
+ paths := []string{
+ filepath.Join(config.Directory, fmt.Sprintf("%d.%s.provision"+extension, i, strings.ToLower(config.GetMetadata().GetWorkspaceTransition().String()))),
+ filepath.Join(config.Directory, fmt.Sprintf("%d.provision"+extension, i)),
+ }
+ for pathIndex, path = range paths {
+ _, err := e.filesystem.Stat(path)
+ if err != nil && pathIndex == len(paths)-1 {
+ // If there are zero messages, something is wrong.
+ if i == 0 {
+ // Error if nothing is around to enable failed states.
+ return xerrors.New("no state")
+ }
+ // Otherwise, we're done with the entire provision.
+ break outer
+ } else if err != nil {
+ continue
+ }
break
}
@ -170,16 +186,28 @@ func (*echo) Shutdown(_ context.Context, _ *proto.Empty) (*proto.Empty, error) {
return &proto.Empty{}, nil
}
// Responses is a collection of mocked responses to Provision operations.
type Responses struct {
Parse []*proto.Parse_Response
// ProvisionApply and ProvisionPlan are used to mock ALL responses of
// Apply and Plan, regardless of transition.
ProvisionApply []*proto.Provision_Response
ProvisionPlan []*proto.Provision_Response
// ProvisionApplyMap and ProvisionPlanMap are used to mock specific
// transition responses. They are prioritized over the generic responses.
ProvisionApplyMap map[proto.WorkspaceTransition][]*proto.Provision_Response
ProvisionPlanMap map[proto.WorkspaceTransition][]*proto.Provision_Response
}
// Tar returns a tar archive of responses to provisioner operations.
func Tar(responses *Responses) ([]byte, error) {
if responses == nil {
- responses = &Responses{ParseComplete, ProvisionComplete, ProvisionComplete}
+ responses = &Responses{
+ ParseComplete, ProvisionComplete, ProvisionComplete,
+ nil, nil,
+ }
}
if responses.ProvisionPlan == nil {
responses.ProvisionPlan = responses.ProvisionApply
@ -187,58 +215,61 @@ func Tar(responses *Responses) ([]byte, error) {
var buffer bytes.Buffer
writer := tar.NewWriter(&buffer)
- for index, response := range responses.Parse {
- data, err := protobuf.Marshal(response)
+ writeProto := func(name string, message protobuf.Message) error {
+ data, err := protobuf.Marshal(message)
if err != nil {
- return nil, err
+ return err
}
err = writer.WriteHeader(&tar.Header{
- Name: fmt.Sprintf("%d.parse.protobuf", index),
+ Name: name,
Size: int64(len(data)),
Mode: 0o644,
})
if err != nil {
- return nil, err
+ return err
}
_, err = writer.Write(data)
if err != nil {
+ return err
+ }
+ return nil
+ }
+ for index, response := range responses.Parse {
+ err := writeProto(fmt.Sprintf("%d.parse.protobuf", index), response)
+ if err != nil {
return nil, err
}
}
for index, response := range responses.ProvisionApply {
- data, err := protobuf.Marshal(response)
- if err != nil {
- return nil, err
- }
- err = writer.WriteHeader(&tar.Header{
- Name: fmt.Sprintf("%d.provision.apply.protobuf", index),
- Size: int64(len(data)),
- Mode: 0o644,
- })
- if err != nil {
- return nil, err
- }
- _, err = writer.Write(data)
+ err := writeProto(fmt.Sprintf("%d.provision.apply.protobuf", index), response)
if err != nil {
return nil, err
}
}
for index, response := range responses.ProvisionPlan {
- data, err := protobuf.Marshal(response)
+ err := writeProto(fmt.Sprintf("%d.provision.plan.protobuf", index), response)
if err != nil {
return nil, err
}
- err = writer.WriteHeader(&tar.Header{
- Name: fmt.Sprintf("%d.provision.plan.protobuf", index),
- Size: int64(len(data)),
- Mode: 0o644,
- })
- if err != nil {
- return nil, err
- }
+ for trans, m := range responses.ProvisionApplyMap {
+ for i, rs := range m {
+ err := writeProto(fmt.Sprintf("%d.%s.provision.apply.protobuf", i, strings.ToLower(trans.String())), rs)
+ if err != nil {
+ return nil, err
+ }
+ }
- _, err = writer.Write(data)
- if err != nil {
- return nil, err
- }
+ for trans, m := range responses.ProvisionPlanMap {
+ for i, rs := range m {
+ err := writeProto(fmt.Sprintf("%d.%s.provision.plan.protobuf", i, strings.ToLower(trans.String())), rs)
+ if err != nil {
+ return nil, err
+ }
+ }
}
err := writer.Flush()

View File

@ -112,6 +112,92 @@ func TestEcho(t *testing.T) {
complete.GetComplete().Resources[0].Name)
})
t.Run("ProvisionStop", func(t *testing.T) {
t.Parallel()
// Stop responses should be returned when the workspace is being stopped.
defaultResponses := []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "DEFAULT",
}},
},
},
}}
stopResponses := []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "STOP",
}},
},
},
}}
data, err := echo.Tar(&echo.Responses{
ProvisionApply: defaultResponses,
ProvisionPlan: defaultResponses,
ProvisionPlanMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
proto.WorkspaceTransition_STOP: stopResponses,
},
ProvisionApplyMap: map[proto.WorkspaceTransition][]*proto.Provision_Response{
proto.WorkspaceTransition_STOP: stopResponses,
},
})
require.NoError(t, err)
client, err := api.Provision(ctx)
require.NoError(t, err)
// Do stop.
err = client.Send(&proto.Provision_Request{
Type: &proto.Provision_Request_Plan{
Plan: &proto.Provision_Plan{
Config: &proto.Provision_Config{
Directory: unpackTar(t, fs, data),
Metadata: &proto.Provision_Metadata{
WorkspaceTransition: proto.WorkspaceTransition_STOP,
},
},
},
},
})
require.NoError(t, err)
complete, err := client.Recv()
require.NoError(t, err)
require.Equal(t,
stopResponses[0].GetComplete().Resources[0].Name,
complete.GetComplete().Resources[0].Name,
)
// Do start.
client, err = api.Provision(ctx)
require.NoError(t, err)
err = client.Send(&proto.Provision_Request{
Type: &proto.Provision_Request_Plan{
Plan: &proto.Provision_Plan{
Config: &proto.Provision_Config{
Directory: unpackTar(t, fs, data),
Metadata: &proto.Provision_Metadata{
WorkspaceTransition: proto.WorkspaceTransition_START,
},
},
},
},
})
require.NoError(t, err)
complete, err = client.Recv()
require.NoError(t, err)
require.Equal(t,
defaultResponses[0].GetComplete().Resources[0].Name,
complete.GetComplete().Resources[0].Name,
)
})
t.Run("ProvisionWithLogLevel", func(t *testing.T) {
t.Parallel()

View File

@ -964,10 +964,11 @@ func (r *Runner) buildWorkspace(ctx context.Context, stage string, req *sdkproto
}
func (r *Runner) commitQuota(ctx context.Context, resources []*sdkproto.Resource) *proto.FailedJob {
cost := sumDailyCost(resources)
+ r.logger.Debug(ctx, "committing quota",
+ slog.F("resources", resources),
+ slog.F("cost", cost),
+ )
if cost == 0 {
return nil
}