coderd: tighten /login rate limiting (#4432)

* coderd: tighten /login rate limit

* coderd: add Bypass rate limit header
This commit is contained in:
Ammar Bandukwala 2022-10-20 12:01:23 -05:00 committed by GitHub
parent 43f199a987
commit 423ac04156
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 211 additions and 20 deletions

View File

@ -235,10 +235,15 @@ linters:
- noctx
- paralleltest
- revive
- rowserrcheck
- sqlclosecheck
# These don't work until the following issue is solved.
# https://github.com/golangci/golangci-lint/issues/2649
# - rowserrcheck
# - sqlclosecheck
# - structcheck
# - wastedassign
- staticcheck
- structcheck
- tenv
# In Go, it's possible for a package to test it's internal functionality
# without testing any exported functions. This is enabled to promote
@ -253,4 +258,3 @@ linters:
- unconvert
- unused
- varcheck
- wastedassign

View File

@ -204,7 +204,7 @@ func New(options *Options) *API {
// app URL. If it is, it will serve that application.
api.handleSubdomainApplications(
// Middleware to impose on the served application.
httpmw.RateLimitPerMinute(options.APIRateLimit),
httpmw.RateLimit(options.APIRateLimit, time.Minute),
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
@ -229,7 +229,7 @@ func New(options *Options) *API {
apps := func(r chi.Router) {
r.Use(
tracing.Middleware(api.TracerProvider),
httpmw.RateLimitPerMinute(options.APIRateLimit),
httpmw.RateLimit(options.APIRateLimit, time.Minute),
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
DB: options.Database,
OAuth2Configs: oauthConfigs,
@ -267,7 +267,7 @@ func New(options *Options) *API {
r.Use(
tracing.Middleware(api.TracerProvider),
// Specific routes can specify smaller limits.
httpmw.RateLimitPerMinute(options.APIRateLimit),
httpmw.RateLimit(options.APIRateLimit, time.Minute),
)
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Response{
@ -304,7 +304,7 @@ func New(options *Options) *API {
apiKeyMiddleware,
// This number is arbitrary, but reading/writing
// file content is expensive so it should be small.
httpmw.RateLimitPerMinute(12),
httpmw.RateLimit(12, time.Minute),
)
r.Get("/{fileID}", api.fileByID)
r.Post("/", api.postFile)
@ -391,7 +391,15 @@ func New(options *Options) *API {
r.Route("/users", func(r chi.Router) {
r.Get("/first", api.firstUser)
r.Post("/first", api.postFirstUser)
r.Post("/login", api.postLogin)
r.Group(func(r chi.Router) {
// We use a tight limit for password login to protect
// against audit-log write DoS, pbkdf2 DoS, and simple
// brute-force attacks.
//
// Making this too small can break tests.
r.Use(httpmw.RateLimit(60, time.Minute))
r.Post("/login", api.postLogin)
})
r.Get("/authmethods", api.userAuthMethods)
r.Route("/oauth2", func(r chi.Router) {
r.Route("/github", func(r chi.Router) {

View File

@ -631,8 +631,8 @@ func TestAPIKey(t *testing.T) {
})
}
func createUser(ctx context.Context, t *testing.T, db database.Store) database.User {
user, err := db.InsertUser(ctx, database.InsertUserParams{
func createUser(ctx context.Context, t *testing.T, db database.Store, opts ...func(u *database.InsertUserParams)) database.User {
insert := database.InsertUserParams{
ID: uuid.New(),
Email: "email@coder.com",
Username: "username",
@ -640,7 +640,11 @@ func createUser(ctx context.Context, t *testing.T, db database.Store) database.U
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
RBACRoles: []string{},
})
}
for _, opt := range opts {
opt(&insert)
}
user, err := db.InsertUser(ctx, insert)
require.NoError(t, err, "create user")
return user
}

View File

@ -1,39 +1,71 @@
package httpmw
import (
"fmt"
"net/http"
"strconv"
"time"
"github.com/go-chi/httprate"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
)
// RateLimitPerMinute returns a handler that limits requests per-minute based
// RateLimit returns a handler that limits requests per-minute based
// on IP, endpoint, and user ID (if available).
func RateLimitPerMinute(count int) func(http.Handler) http.Handler {
func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler {
// -1 is no rate limit
if count <= 0 {
return func(handler http.Handler) http.Handler {
return handler
}
}
return httprate.Limit(
count,
1*time.Minute,
window,
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
// Prioritize by user, but fallback to IP.
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
if ok {
if !ok {
return httprate.KeyByIP(r)
}
if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok {
// No bypass attempt, just ratelimit.
return apiKey.UserID.String(), nil
}
return httprate.KeyByIP(r)
// Allow Owner to bypass rate limiting for load tests
// and automation.
auth := UserAuthorization(r)
// We avoid using rbac.Authorizer since rego is CPU-intensive
// and undermines the DoS-prevention goal of the rate limiter.
for _, role := range auth.Roles {
if role == rbac.RoleOwner() {
// HACK: use a random key each time to
// de facto disable rate limiting. The
// `httprate` package has no
// support for selectively changing the limit
// for particular keys.
return cryptorand.String(16)
}
}
return apiKey.UserID.String(), xerrors.Errorf(
"%q provided but user is not %v",
codersdk.BypassRatelimitHeader, rbac.RoleOwner(),
)
}, httprate.KeyByEndpoint),
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(r.Context(), w, http.StatusTooManyRequests, codersdk.Response{
Message: "You've been rate limited for sending too many requests!",
Message: fmt.Sprintf("You've been rate limited for sending more than %v requests in %v.", count, window),
})
}),
)

View File

@ -1,23 +1,60 @@
package httpmw_test
import (
"context"
"crypto/sha256"
"fmt"
"math/rand"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
func insertAPIKey(ctx context.Context, t *testing.T, db database.Store, userID uuid.UUID) string {
id, secret := randomAPIKeyParts()
hashed := sha256.Sum256([]byte(secret))
_, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{
ID: id,
HashedSecret: hashed[:],
LastUsed: database.Now().AddDate(0, 0, -1),
ExpiresAt: database.Now().AddDate(0, 0, 1),
UserID: userID,
LoginType: database.LoginTypePassword,
Scope: database.APIKeyScopeAll,
})
require.NoError(t, err)
return fmt.Sprintf("%s-%s", id, secret)
}
func randRemoteAddr() string {
var b [4]byte
// nolint:gosec
rand.Read(b[:])
// nolint:gosec
return fmt.Sprintf("%s:%v", net.IP(b[:]).String(), rand.Int31()%(1<<16))
}
func TestRateLimit(t *testing.T) {
t.Parallel()
t.Run("NoUser", func(t *testing.T) {
t.Run("NoUserSucceeds", func(t *testing.T) {
t.Parallel()
rtr := chi.NewRouter()
rtr.Use(httpmw.RateLimitPerMinute(5))
rtr.Use(httpmw.RateLimit(5, time.Second))
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})
@ -31,4 +68,107 @@ func TestRateLimit(t *testing.T) {
return resp.StatusCode == http.StatusTooManyRequests
}, testutil.WaitShort, testutil.IntervalFast)
})
t.Run("RandomIPs", func(t *testing.T) {
t.Parallel()
rtr := chi.NewRouter()
rtr.Use(httpmw.RateLimit(5, time.Second))
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})
require.Never(t, func() bool {
req := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
req.RemoteAddr = randRemoteAddr()
rtr.ServeHTTP(rec, req)
resp := rec.Result()
defer resp.Body.Close()
return resp.StatusCode == http.StatusTooManyRequests
}, testutil.WaitShort, testutil.IntervalFast)
})
t.Run("RegularUser", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := databasefake.New()
u := createUser(ctx, t, db)
key := insertAPIKey(ctx, t, db, u.ID)
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
DB: db,
Optional: false,
}))
rtr.Use(httpmw.RateLimit(5, time.Second))
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})
// Bypass must fail
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set(codersdk.SessionCustomHeader, key)
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
rec := httptest.NewRecorder()
// Assert we're not using IP address.
req.RemoteAddr = randRemoteAddr()
rtr.ServeHTTP(rec, req)
resp := rec.Result()
defer resp.Body.Close()
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode)
require.Eventually(t, func() bool {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set(codersdk.SessionCustomHeader, key)
rec := httptest.NewRecorder()
// Assert we're not using IP address.
req.RemoteAddr = randRemoteAddr()
rtr.ServeHTTP(rec, req)
resp := rec.Result()
defer resp.Body.Close()
return resp.StatusCode == http.StatusTooManyRequests
}, testutil.WaitShort, testutil.IntervalFast)
})
t.Run("OwnerBypass", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := databasefake.New()
u := createUser(ctx, t, db, func(u *database.InsertUserParams) {
u.RBACRoles = []string{rbac.RoleOwner()}
})
key := insertAPIKey(ctx, t, db, u.ID)
rtr := chi.NewRouter()
rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
DB: db,
Optional: false,
}))
rtr.Use(httpmw.RateLimit(5, time.Second))
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
})
require.Never(t, func() bool {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set(codersdk.SessionCustomHeader, key)
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
rec := httptest.NewRecorder()
// Assert we're not using IP address.
req.RemoteAddr = randRemoteAddr()
rtr.ServeHTTP(rec, req)
resp := rec.Result()
defer resp.Body.Close()
return resp.StatusCode == http.StatusTooManyRequests
}, testutil.WaitShort, testutil.IntervalFast)
})
}

View File

@ -24,6 +24,9 @@ const (
SessionCustomHeader = "Coder-Session-Token"
OAuth2StateKey = "oauth_state"
OAuth2RedirectKey = "oauth_redirect"
// nolint: gosec
BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit"
)
// New creates a Coder client for the provided URL.