mirror of https://github.com/coder/coder.git
coderd: tighten /login rate limiting (#4432)
* coderd: tighten /login rate limit * coderd: add Bypass rate limit header
This commit is contained in:
parent
43f199a987
commit
423ac04156
|
@ -235,10 +235,15 @@ linters:
|
|||
- noctx
|
||||
- paralleltest
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
|
||||
# These don't work until the following issue is solved.
|
||||
# https://github.com/golangci/golangci-lint/issues/2649
|
||||
# - rowserrcheck
|
||||
# - sqlclosecheck
|
||||
# - structcheck
|
||||
# - wastedassign
|
||||
|
||||
- staticcheck
|
||||
- structcheck
|
||||
- tenv
|
||||
# In Go, it's possible for a package to test it's internal functionality
|
||||
# without testing any exported functions. This is enabled to promote
|
||||
|
@ -253,4 +258,3 @@ linters:
|
|||
- unconvert
|
||||
- unused
|
||||
- varcheck
|
||||
- wastedassign
|
||||
|
|
|
@ -204,7 +204,7 @@ func New(options *Options) *API {
|
|||
// app URL. If it is, it will serve that application.
|
||||
api.handleSubdomainApplications(
|
||||
// Middleware to impose on the served application.
|
||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
||||
httpmw.RateLimit(options.APIRateLimit, time.Minute),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
|
@ -229,7 +229,7 @@ func New(options *Options) *API {
|
|||
apps := func(r chi.Router) {
|
||||
r.Use(
|
||||
tracing.Middleware(api.TracerProvider),
|
||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
||||
httpmw.RateLimit(options.APIRateLimit, time.Minute),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
|
@ -267,7 +267,7 @@ func New(options *Options) *API {
|
|||
r.Use(
|
||||
tracing.Middleware(api.TracerProvider),
|
||||
// Specific routes can specify smaller limits.
|
||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
||||
httpmw.RateLimit(options.APIRateLimit, time.Minute),
|
||||
)
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Response{
|
||||
|
@ -304,7 +304,7 @@ func New(options *Options) *API {
|
|||
apiKeyMiddleware,
|
||||
// This number is arbitrary, but reading/writing
|
||||
// file content is expensive so it should be small.
|
||||
httpmw.RateLimitPerMinute(12),
|
||||
httpmw.RateLimit(12, time.Minute),
|
||||
)
|
||||
r.Get("/{fileID}", api.fileByID)
|
||||
r.Post("/", api.postFile)
|
||||
|
@ -391,7 +391,15 @@ func New(options *Options) *API {
|
|||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/first", api.firstUser)
|
||||
r.Post("/first", api.postFirstUser)
|
||||
r.Post("/login", api.postLogin)
|
||||
r.Group(func(r chi.Router) {
|
||||
// We use a tight limit for password login to protect
|
||||
// against audit-log write DoS, pbkdf2 DoS, and simple
|
||||
// brute-force attacks.
|
||||
//
|
||||
// Making this too small can break tests.
|
||||
r.Use(httpmw.RateLimit(60, time.Minute))
|
||||
r.Post("/login", api.postLogin)
|
||||
})
|
||||
r.Get("/authmethods", api.userAuthMethods)
|
||||
r.Route("/oauth2", func(r chi.Router) {
|
||||
r.Route("/github", func(r chi.Router) {
|
||||
|
|
|
@ -631,8 +631,8 @@ func TestAPIKey(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func createUser(ctx context.Context, t *testing.T, db database.Store) database.User {
|
||||
user, err := db.InsertUser(ctx, database.InsertUserParams{
|
||||
func createUser(ctx context.Context, t *testing.T, db database.Store, opts ...func(u *database.InsertUserParams)) database.User {
|
||||
insert := database.InsertUserParams{
|
||||
ID: uuid.New(),
|
||||
Email: "email@coder.com",
|
||||
Username: "username",
|
||||
|
@ -640,7 +640,11 @@ func createUser(ctx context.Context, t *testing.T, db database.Store) database.U
|
|||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
RBACRoles: []string{},
|
||||
})
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&insert)
|
||||
}
|
||||
user, err := db.InsertUser(ctx, insert)
|
||||
require.NoError(t, err, "create user")
|
||||
return user
|
||||
}
|
||||
|
|
|
@ -1,39 +1,71 @@
|
|||
package httpmw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/httprate"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
// RateLimitPerMinute returns a handler that limits requests per-minute based
|
||||
// RateLimit returns a handler that limits requests per-minute based
|
||||
// on IP, endpoint, and user ID (if available).
|
||||
func RateLimitPerMinute(count int) func(http.Handler) http.Handler {
|
||||
func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler {
|
||||
// -1 is no rate limit
|
||||
if count <= 0 {
|
||||
return func(handler http.Handler) http.Handler {
|
||||
return handler
|
||||
}
|
||||
}
|
||||
|
||||
return httprate.Limit(
|
||||
count,
|
||||
1*time.Minute,
|
||||
window,
|
||||
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
|
||||
// Prioritize by user, but fallback to IP.
|
||||
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
|
||||
if ok {
|
||||
if !ok {
|
||||
return httprate.KeyByIP(r)
|
||||
}
|
||||
|
||||
if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok {
|
||||
// No bypass attempt, just ratelimit.
|
||||
return apiKey.UserID.String(), nil
|
||||
}
|
||||
return httprate.KeyByIP(r)
|
||||
|
||||
// Allow Owner to bypass rate limiting for load tests
|
||||
// and automation.
|
||||
auth := UserAuthorization(r)
|
||||
|
||||
// We avoid using rbac.Authorizer since rego is CPU-intensive
|
||||
// and undermines the DoS-prevention goal of the rate limiter.
|
||||
for _, role := range auth.Roles {
|
||||
if role == rbac.RoleOwner() {
|
||||
// HACK: use a random key each time to
|
||||
// de facto disable rate limiting. The
|
||||
// `httprate` package has no
|
||||
// support for selectively changing the limit
|
||||
// for particular keys.
|
||||
return cryptorand.String(16)
|
||||
}
|
||||
}
|
||||
|
||||
return apiKey.UserID.String(), xerrors.Errorf(
|
||||
"%q provided but user is not %v",
|
||||
codersdk.BypassRatelimitHeader, rbac.RoleOwner(),
|
||||
)
|
||||
}, httprate.KeyByEndpoint),
|
||||
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), w, http.StatusTooManyRequests, codersdk.Response{
|
||||
Message: "You've been rate limited for sending too many requests!",
|
||||
Message: fmt.Sprintf("You've been rate limited for sending more than %v requests in %v.", count, window),
|
||||
})
|
||||
}),
|
||||
)
|
||||
|
|
|
@ -1,23 +1,60 @@
|
|||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func insertAPIKey(ctx context.Context, t *testing.T, db database.Store, userID uuid.UUID) string {
|
||||
id, secret := randomAPIKeyParts()
|
||||
hashed := sha256.Sum256([]byte(secret))
|
||||
|
||||
_, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{
|
||||
ID: id,
|
||||
HashedSecret: hashed[:],
|
||||
LastUsed: database.Now().AddDate(0, 0, -1),
|
||||
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||
UserID: userID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
Scope: database.APIKeyScopeAll,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return fmt.Sprintf("%s-%s", id, secret)
|
||||
}
|
||||
|
||||
func randRemoteAddr() string {
|
||||
var b [4]byte
|
||||
// nolint:gosec
|
||||
rand.Read(b[:])
|
||||
// nolint:gosec
|
||||
return fmt.Sprintf("%s:%v", net.IP(b[:]).String(), rand.Int31()%(1<<16))
|
||||
}
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("NoUser", func(t *testing.T) {
|
||||
t.Run("NoUserSucceeds", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.RateLimitPerMinute(5))
|
||||
rtr.Use(httpmw.RateLimit(5, time.Second))
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
@ -31,4 +68,107 @@ func TestRateLimit(t *testing.T) {
|
|||
return resp.StatusCode == http.StatusTooManyRequests
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("RandomIPs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.RateLimit(5, time.Second))
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
require.Never(t, func() bool {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
req.RemoteAddr = randRemoteAddr()
|
||||
rtr.ServeHTTP(rec, req)
|
||||
resp := rec.Result()
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusTooManyRequests
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("RegularUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
db := databasefake.New()
|
||||
|
||||
u := createUser(ctx, t, db)
|
||||
key := insertAPIKey(ctx, t, db, u.ID)
|
||||
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
Optional: false,
|
||||
}))
|
||||
|
||||
rtr.Use(httpmw.RateLimit(5, time.Second))
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Bypass must fail
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(codersdk.SessionCustomHeader, key)
|
||||
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||
rec := httptest.NewRecorder()
|
||||
// Assert we're not using IP address.
|
||||
req.RemoteAddr = randRemoteAddr()
|
||||
rtr.ServeHTTP(rec, req)
|
||||
resp := rec.Result()
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(codersdk.SessionCustomHeader, key)
|
||||
rec := httptest.NewRecorder()
|
||||
// Assert we're not using IP address.
|
||||
req.RemoteAddr = randRemoteAddr()
|
||||
rtr.ServeHTTP(rec, req)
|
||||
resp := rec.Result()
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusTooManyRequests
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("OwnerBypass", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
db := databasefake.New()
|
||||
|
||||
u := createUser(ctx, t, db, func(u *database.InsertUserParams) {
|
||||
u.RBACRoles = []string{rbac.RoleOwner()}
|
||||
})
|
||||
|
||||
key := insertAPIKey(ctx, t, db, u.ID)
|
||||
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: db,
|
||||
Optional: false,
|
||||
}))
|
||||
|
||||
rtr.Use(httpmw.RateLimit(5, time.Second))
|
||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
require.Never(t, func() bool {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(codersdk.SessionCustomHeader, key)
|
||||
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||
rec := httptest.NewRecorder()
|
||||
// Assert we're not using IP address.
|
||||
req.RemoteAddr = randRemoteAddr()
|
||||
rtr.ServeHTTP(rec, req)
|
||||
resp := rec.Result()
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusTooManyRequests
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -24,6 +24,9 @@ const (
|
|||
SessionCustomHeader = "Coder-Session-Token"
|
||||
OAuth2StateKey = "oauth_state"
|
||||
OAuth2RedirectKey = "oauth_redirect"
|
||||
|
||||
// nolint: gosec
|
||||
BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit"
|
||||
)
|
||||
|
||||
// New creates a Coder client for the provided URL.
|
||||
|
|
Loading…
Reference in New Issue