mirror of https://github.com/coder/coder.git
302 lines
8.9 KiB
Go
302 lines
8.9 KiB
Go
package httpmw_test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbmem"
|
|
"github.com/coder/coder/v2/coderd/httpapi"
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/cryptorand"
|
|
)
|
|
|
|
func TestExtractWorkspaceProxy(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
// Only called if the API key passes through the handler.
|
|
httpapi.Write(context.Background(), rw, http.StatusOK, codersdk.Response{
|
|
Message: "It worked!",
|
|
})
|
|
})
|
|
|
|
t.Run("NoHeader", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
)
|
|
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
|
|
DB: db,
|
|
})(successHandler).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
|
})
|
|
|
|
t.Run("InvalidFormat", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
)
|
|
r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, "test:wow-hello")
|
|
|
|
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
|
|
DB: db,
|
|
})(successHandler).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
|
})
|
|
|
|
t.Run("InvalidID", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
)
|
|
r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, "test:wow")
|
|
|
|
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
|
|
DB: db,
|
|
})(successHandler).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
|
})
|
|
|
|
t.Run("InvalidSecretLength", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
)
|
|
r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, fmt.Sprintf("%s:%s", uuid.NewString(), "wow"))
|
|
|
|
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
|
|
DB: db,
|
|
})(successHandler).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
|
})
|
|
|
|
t.Run("NotFound", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
)
|
|
|
|
secret, err := cryptorand.HexString(64)
|
|
require.NoError(t, err)
|
|
r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, fmt.Sprintf("%s:%s", uuid.NewString(), secret))
|
|
|
|
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
|
|
DB: db,
|
|
})(successHandler).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
|
})
|
|
|
|
t.Run("InvalidSecret", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
|
|
proxy, _ = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
|
|
)
|
|
|
|
// Use a different secret so they don't match!
|
|
secret, err := cryptorand.HexString(64)
|
|
require.NoError(t, err)
|
|
r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, fmt.Sprintf("%s:%s", proxy.ID.String(), secret))
|
|
|
|
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
|
|
DB: db,
|
|
})(successHandler).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
|
})
|
|
|
|
t.Run("Valid", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
|
|
proxy, secret = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
|
|
)
|
|
r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, fmt.Sprintf("%s:%s", proxy.ID.String(), secret))
|
|
|
|
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
|
|
DB: db,
|
|
})(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
// Checks that it exists on the context!
|
|
_ = httpmw.WorkspaceProxy(r)
|
|
successHandler.ServeHTTP(rw, r)
|
|
})).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusOK, res.StatusCode)
|
|
})
|
|
|
|
t.Run("Deleted", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
|
|
proxy, secret = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
|
|
)
|
|
err := db.UpdateWorkspaceProxyDeleted(context.Background(), database.UpdateWorkspaceProxyDeletedParams{
|
|
ID: proxy.ID,
|
|
Deleted: true,
|
|
})
|
|
require.NoError(t, err, "failed to delete workspace proxy")
|
|
|
|
r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, fmt.Sprintf("%s:%s", proxy.ID.String(), secret))
|
|
|
|
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
|
|
DB: db,
|
|
})(successHandler).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
|
})
|
|
}
|
|
|
|
func TestExtractWorkspaceProxyParam(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
// Only called if the API key passes through the handler.
|
|
httpapi.Write(context.Background(), rw, http.StatusOK, codersdk.Response{
|
|
Message: "It worked!",
|
|
})
|
|
})
|
|
|
|
t.Run("OKName", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
|
|
proxy, _ = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
|
|
)
|
|
|
|
routeContext := chi.NewRouteContext()
|
|
routeContext.URLParams.Add("workspaceproxy", proxy.Name)
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
|
|
|
|
httpmw.ExtractWorkspaceProxyParam(db, uuid.NewString(), nil)(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
|
// Checks that it exists on the context!
|
|
_ = httpmw.WorkspaceProxyParam(request)
|
|
successHandler.ServeHTTP(writer, request)
|
|
})).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusOK, res.StatusCode)
|
|
})
|
|
|
|
t.Run("OKID", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
|
|
proxy, _ = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
|
|
)
|
|
|
|
routeContext := chi.NewRouteContext()
|
|
routeContext.URLParams.Add("workspaceproxy", proxy.ID.String())
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
|
|
|
|
httpmw.ExtractWorkspaceProxyParam(db, uuid.NewString(), nil)(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
|
// Checks that it exists on the context!
|
|
_ = httpmw.WorkspaceProxyParam(request)
|
|
successHandler.ServeHTTP(writer, request)
|
|
})).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusOK, res.StatusCode)
|
|
})
|
|
|
|
t.Run("NotFound", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
)
|
|
|
|
routeContext := chi.NewRouteContext()
|
|
routeContext.URLParams.Add("workspaceproxy", uuid.NewString())
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
|
|
|
|
httpmw.ExtractWorkspaceProxyParam(db, uuid.NewString(), nil)(successHandler).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusNotFound, res.StatusCode)
|
|
})
|
|
|
|
t.Run("FetchPrimary", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
db = dbmem.New()
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
rw = httptest.NewRecorder()
|
|
deploymentID = uuid.New()
|
|
primaryProxy = database.WorkspaceProxy{
|
|
ID: deploymentID,
|
|
Name: "primary",
|
|
DisplayName: "Default",
|
|
Icon: "Icon",
|
|
Url: "Url",
|
|
WildcardHostname: "Wildcard",
|
|
}
|
|
fetchPrimary = func(ctx context.Context) (database.WorkspaceProxy, error) {
|
|
return primaryProxy, nil
|
|
}
|
|
)
|
|
|
|
routeContext := chi.NewRouteContext()
|
|
routeContext.URLParams.Add("workspaceproxy", deploymentID.String())
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
|
|
|
|
httpmw.ExtractWorkspaceProxyParam(db, deploymentID.String(), fetchPrimary)(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
|
// Checks that it exists on the context!
|
|
found := httpmw.WorkspaceProxyParam(request)
|
|
require.Equal(t, primaryProxy, found)
|
|
successHandler.ServeHTTP(writer, request)
|
|
})).ServeHTTP(rw, r)
|
|
res := rw.Result()
|
|
defer res.Body.Close()
|
|
require.Equal(t, http.StatusOK, res.StatusCode)
|
|
})
|
|
}
|