mirror of https://github.com/coder/coder.git
355 lines
8.6 KiB
Go
355 lines
8.6 KiB
Go
package codersdk
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/go-logr/logr"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.opentelemetry.io/otel"
|
|
"go.opentelemetry.io/otel/propagation"
|
|
"go.opentelemetry.io/otel/sdk/resource"
|
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
semconv "go.opentelemetry.io/otel/semconv/v1.14.0"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/sloghuman"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
const jsonCT = "application/json"
|
|
|
|
func TestIsConnectionErr(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type tc = struct {
|
|
name string
|
|
err error
|
|
expectedResult bool
|
|
}
|
|
|
|
cases := []tc{
|
|
{
|
|
// E.g. "no such host"
|
|
name: "DNSError",
|
|
err: &net.DNSError{
|
|
Err: "no such host",
|
|
Name: "foofoo",
|
|
Server: "1.1.1.1:53",
|
|
IsTimeout: false,
|
|
IsTemporary: false,
|
|
IsNotFound: true,
|
|
},
|
|
expectedResult: true,
|
|
},
|
|
{
|
|
// E.g. "connection refused"
|
|
name: "OpErr",
|
|
err: &net.OpError{
|
|
Op: "dial",
|
|
Net: "tcp",
|
|
Source: nil,
|
|
Addr: nil,
|
|
Err: &os.SyscallError{},
|
|
},
|
|
expectedResult: true,
|
|
},
|
|
{
|
|
name: "OpaqueError",
|
|
err: xerrors.Errorf("I'm opaque!"),
|
|
expectedResult: false,
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
c := c
|
|
|
|
t.Run(c.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
require.Equal(t, c.expectedResult, IsConnectionError(c.err))
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_Client(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
const method = http.MethodPost
|
|
const path = "/ok"
|
|
const token = "token"
|
|
const reqBody = `{"msg": "request body"}`
|
|
const resBody = `{"status": "ok"}`
|
|
|
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, method, r.Method)
|
|
assert.Equal(t, path, r.URL.Path)
|
|
assert.Equal(t, token, r.Header.Get(SessionTokenHeader))
|
|
assert.NotEmpty(t, r.Header.Get("Traceparent"))
|
|
for k, v := range r.Header {
|
|
t.Logf("header %q: %q", k, strings.Join(v, ", "))
|
|
}
|
|
|
|
w.Header().Set("Content-Type", jsonCT)
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, resBody)
|
|
}))
|
|
|
|
u, err := url.Parse(s.URL)
|
|
require.NoError(t, err)
|
|
client := New(u)
|
|
client.SetSessionToken(token)
|
|
|
|
logBuf := bytes.NewBuffer(nil)
|
|
client.SetLogger(slog.Make(sloghuman.Sink(logBuf)).Leveled(slog.LevelDebug))
|
|
client.SetLogBodies(true)
|
|
|
|
// Setup tracing.
|
|
res := resource.NewWithAttributes(
|
|
semconv.SchemaURL,
|
|
semconv.ServiceNameKey.String("codersdk_test"),
|
|
)
|
|
tracerOpts := []sdktrace.TracerProviderOption{
|
|
sdktrace.WithResource(res),
|
|
}
|
|
tracerProvider := sdktrace.NewTracerProvider(tracerOpts...)
|
|
otel.SetTracerProvider(tracerProvider)
|
|
otel.SetErrorHandler(otel.ErrorHandlerFunc(func(err error) {}))
|
|
otel.SetTextMapPropagator(
|
|
propagation.NewCompositeTextMapPropagator(
|
|
propagation.TraceContext{},
|
|
propagation.Baggage{},
|
|
),
|
|
)
|
|
otel.SetLogger(logr.Discard())
|
|
client.Trace = true
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
|
defer cancel()
|
|
ctx, span := tracerProvider.Tracer("codersdk_test").Start(ctx, "codersdk client test 1")
|
|
defer span.End()
|
|
|
|
resp, err := client.Request(ctx, method, path, []byte(reqBody))
|
|
require.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
require.Equal(t, jsonCT, resp.Header.Get("Content-Type"))
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, resBody, string(body))
|
|
|
|
logStr := logBuf.String()
|
|
require.Contains(t, logStr, "sdk request")
|
|
require.Contains(t, logStr, method)
|
|
require.Contains(t, logStr, path)
|
|
require.Contains(t, logStr, strings.ReplaceAll(reqBody, `"`, `\"`))
|
|
require.Contains(t, logStr, "sdk response")
|
|
require.Contains(t, logStr, "200")
|
|
require.Contains(t, logStr, strings.ReplaceAll(resBody, `"`, `\"`))
|
|
}
|
|
|
|
func Test_readBodyAsError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
exampleURL := "http://example.com"
|
|
simpleResponse := Response{
|
|
Message: "test",
|
|
Detail: "hi",
|
|
}
|
|
|
|
longResponse := ""
|
|
for i := 0; i < 4000; i++ {
|
|
longResponse += "a"
|
|
}
|
|
|
|
unexpectedJSON := marshal(map[string]any{
|
|
"hello": "world",
|
|
"foo": "bar",
|
|
})
|
|
|
|
//nolint:bodyclose
|
|
tests := []struct {
|
|
name string
|
|
req *http.Request
|
|
res *http.Response
|
|
assert func(t *testing.T, err error)
|
|
}{
|
|
{
|
|
name: "JSONWithRequest",
|
|
req: httptest.NewRequest(http.MethodGet, exampleURL, nil),
|
|
res: newResponse(http.StatusNotFound, jsonCT, marshal(simpleResponse)),
|
|
assert: func(t *testing.T, err error) {
|
|
sdkErr := assertSDKError(t, err)
|
|
|
|
assert.Equal(t, simpleResponse, sdkErr.Response)
|
|
assert.ErrorContains(t, err, sdkErr.Response.Message)
|
|
assert.ErrorContains(t, err, sdkErr.Response.Detail)
|
|
|
|
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
|
assert.ErrorContains(t, err, strconv.Itoa(sdkErr.StatusCode()))
|
|
|
|
assert.Equal(t, http.MethodGet, sdkErr.method)
|
|
assert.ErrorContains(t, err, sdkErr.method)
|
|
|
|
assert.Equal(t, exampleURL, sdkErr.url)
|
|
assert.ErrorContains(t, err, sdkErr.url)
|
|
|
|
assert.Empty(t, sdkErr.Helper)
|
|
},
|
|
},
|
|
{
|
|
name: "JSONWithoutRequest",
|
|
req: nil,
|
|
res: newResponse(http.StatusNotFound, jsonCT, marshal(simpleResponse)),
|
|
assert: func(t *testing.T, err error) {
|
|
sdkErr := assertSDKError(t, err)
|
|
|
|
assert.Equal(t, simpleResponse, sdkErr.Response)
|
|
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
|
assert.Empty(t, sdkErr.method)
|
|
assert.Empty(t, sdkErr.url)
|
|
assert.Empty(t, sdkErr.Helper)
|
|
},
|
|
},
|
|
{
|
|
name: "UnauthorizedHelper",
|
|
req: nil,
|
|
res: newResponse(http.StatusUnauthorized, jsonCT, marshal(simpleResponse)),
|
|
assert: func(t *testing.T, err error) {
|
|
sdkErr := assertSDKError(t, err)
|
|
|
|
assert.Contains(t, sdkErr.Helper, "Try logging in")
|
|
assert.ErrorContains(t, err, sdkErr.Helper)
|
|
},
|
|
},
|
|
{
|
|
name: "NonJSON",
|
|
req: nil,
|
|
res: newResponse(http.StatusNotFound, "text/plain; charset=utf-8", "hello world"),
|
|
assert: func(t *testing.T, err error) {
|
|
sdkErr := assertSDKError(t, err)
|
|
|
|
assert.Contains(t, sdkErr.Response.Message, "unexpected non-JSON response")
|
|
assert.Equal(t, "hello world", sdkErr.Response.Detail)
|
|
},
|
|
},
|
|
{
|
|
name: "NonJSONLong",
|
|
req: nil,
|
|
res: newResponse(http.StatusNotFound, "text/plain; charset=utf-8", longResponse),
|
|
assert: func(t *testing.T, err error) {
|
|
sdkErr := assertSDKError(t, err)
|
|
|
|
assert.Contains(t, sdkErr.Response.Message, "unexpected non-JSON response")
|
|
|
|
expected := longResponse[0:2048] + "..."
|
|
assert.Equal(t, expected, sdkErr.Response.Detail)
|
|
},
|
|
},
|
|
{
|
|
name: "JSONNoBody",
|
|
req: nil,
|
|
res: newResponse(http.StatusNotFound, jsonCT, ""),
|
|
assert: func(t *testing.T, err error) {
|
|
sdkErr := assertSDKError(t, err)
|
|
|
|
assert.Contains(t, sdkErr.Response.Message, "empty response body")
|
|
},
|
|
},
|
|
{
|
|
name: "JSONNoMessage",
|
|
req: nil,
|
|
res: newResponse(http.StatusNotFound, jsonCT, unexpectedJSON),
|
|
assert: func(t *testing.T, err error) {
|
|
sdkErr := assertSDKError(t, err)
|
|
|
|
assert.Contains(t, sdkErr.Response.Message, "unexpected status code")
|
|
assert.Contains(t, sdkErr.Response.Message, "has no message")
|
|
assert.Equal(t, unexpectedJSON, sdkErr.Response.Detail)
|
|
},
|
|
},
|
|
{
|
|
// Even status code 200 should be considered an error if this function
|
|
// is called. There are parts of the code that require this function
|
|
// to always return an error.
|
|
name: "OKResp",
|
|
req: nil,
|
|
res: newResponse(http.StatusOK, jsonCT, marshal(map[string]any{})),
|
|
assert: func(t *testing.T, err error) {
|
|
require.Error(t, err)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, c := range tests {
|
|
c := c
|
|
t.Run(c.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
c.res.Request = c.req
|
|
|
|
err := ReadBodyAsError(c.res)
|
|
c.assert(t, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func assertSDKError(t *testing.T, err error) *Error {
|
|
t.Helper()
|
|
|
|
var sdkErr *Error
|
|
require.Error(t, err)
|
|
require.True(t, xerrors.As(err, &sdkErr))
|
|
|
|
return sdkErr
|
|
}
|
|
|
|
func newResponse(status int, contentType string, body interface{}) *http.Response {
|
|
var r io.ReadCloser
|
|
switch v := body.(type) {
|
|
case string:
|
|
r = io.NopCloser(strings.NewReader(v))
|
|
case []byte:
|
|
r = io.NopCloser(bytes.NewReader(v))
|
|
case io.ReadCloser:
|
|
r = v
|
|
case io.Reader:
|
|
r = io.NopCloser(v)
|
|
default:
|
|
panic(fmt.Sprintf("unknown body type: %T", body))
|
|
}
|
|
|
|
return &http.Response{
|
|
Status: http.StatusText(status),
|
|
StatusCode: status,
|
|
Header: http.Header{
|
|
"Content-Type": []string{contentType},
|
|
},
|
|
Body: r,
|
|
}
|
|
}
|
|
|
|
func marshal(res any) string {
|
|
b, err := json.Marshal(res)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return string(b)
|
|
}
|