feat: add etag to slim binaries endpoint (#5750)

This commit is contained in:
Dean Sheather 2023-01-17 12:38:08 -06:00 committed by GitHub
parent c377cd0fa9
commit b19d644162
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 227 additions and 63 deletions

View File

@ -197,7 +197,7 @@ func New(options *Options) *API {
if siteCacheDir != "" { if siteCacheDir != "" {
siteCacheDir = filepath.Join(siteCacheDir, "site") siteCacheDir = filepath.Join(siteCacheDir, "site")
} }
binFS, err := site.ExtractOrReadBinFS(siteCacheDir, site.FS()) binFS, binHashes, err := site.ExtractOrReadBinFS(siteCacheDir, site.FS())
if err != nil { if err != nil {
panic(xerrors.Errorf("read site bin failed: %w", err)) panic(xerrors.Errorf("read site bin failed: %w", err))
} }
@ -213,7 +213,7 @@ func New(options *Options) *API {
ID: uuid.New(), ID: uuid.New(),
Options: options, Options: options,
RootHandler: r, RootHandler: r,
siteHandler: site.Handler(site.FS(), binFS), siteHandler: site.Handler(site.FS(), binFS, binHashes),
HTTPAuth: &HTTPAuthorizer{ HTTPAuth: &HTTPAuthorizer{
Authorizer: options.Authorizer, Authorizer: options.Authorizer,
Logger: options.Logger, Logger: options.Logger,

View File

@ -16,6 +16,7 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"text/template" // html/template escapes some nonces "text/template" // html/template escapes some nonces
"time" "time"
@ -24,6 +25,7 @@ import (
"github.com/unrolled/secure" "github.com/unrolled/secure"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors" "golang.org/x/xerrors"
"github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpapi"
@ -48,7 +50,7 @@ func init() {
} }
// Handler returns an HTTP handler for serving the static site. // Handler returns an HTTP handler for serving the static site.
func Handler(siteFS fs.FS, binFS http.FileSystem) http.Handler { func Handler(siteFS fs.FS, binFS http.FileSystem, binHashes map[string]string) http.Handler {
// html files are handled by a text/template. Non-html files // html files are handled by a text/template. Non-html files
// are served by the default file server. // are served by the default file server.
// //
@ -59,6 +61,8 @@ func Handler(siteFS fs.FS, binFS http.FileSystem) http.Handler {
panic(xerrors.Errorf("Failed to return handler for static files. Html files failed to load: %w", err)) panic(xerrors.Errorf("Failed to return handler for static files. Html files failed to load: %w", err))
} }
binHashCache := newBinHashCache(binFS, binHashes)
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/bin/", http.StripPrefix("/bin", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { mux.Handle("/bin/", http.StripPrefix("/bin", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
// Convert underscores in the filename to hyphens. We eventually want to // Convert underscores in the filename to hyphens. We eventually want to
@ -66,6 +70,34 @@ func Handler(siteFS fs.FS, binFS http.FileSystem) http.Handler {
// support both for now. // support both for now.
r.URL.Path = strings.ReplaceAll(r.URL.Path, "_", "-") r.URL.Path = strings.ReplaceAll(r.URL.Path, "_", "-")
// Set ETag header to the SHA1 hash of the file contents.
name := filePath(r.URL.Path)
if name == "" || name == "/" {
// Serve the directory listing.
http.FileServer(binFS).ServeHTTP(rw, r)
return
}
if strings.Contains(name, "/") {
// We only serve files from the root of this directory, so avoid any
// shenanigans by blocking slashes in the URL path.
http.NotFound(rw, r)
return
}
hash, err := binHashCache.getHash(name)
if xerrors.Is(err, os.ErrNotExist) {
http.NotFound(rw, r)
return
}
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
// ETag header needs to be quoted.
rw.Header().Set("ETag", fmt.Sprintf(`%q`, hash))
// http.FileServer will see the ETag header and automatically handle
// If-Match and If-None-Match headers on the request properly.
http.FileServer(binFS).ServeHTTP(rw, r) http.FileServer(binFS).ServeHTTP(rw, r)
}))) })))
mux.Handle("/", http.FileServer(http.FS(siteFS))) // All other non-html static files. mux.Handle("/", http.FileServer(http.FS(siteFS))) // All other non-html static files.
@ -409,20 +441,23 @@ func htmlFiles(files fs.FS) (*htmlTemplates, error) {
}, nil }, nil
} }
// ExtractOrReadBinFS checks the provided fs for compressed coder // ExtractOrReadBinFS checks the provided fs for compressed coder binaries and
// binaries and extracts them into dest/bin if found. As a fallback, // extracts them into dest/bin if found. As a fallback, the provided FS is
// the provided FS is checked for a /bin directory, if it is non-empty // checked for a /bin directory, if it is non-empty it is returned. Finally
// it is returned. Finally dest/bin is returned as a fallback allowing // dest/bin is returned as a fallback allowing binaries to be manually placed in
// binaries to be manually placed in dest (usually // dest (usually ${CODER_CACHE_DIRECTORY}/site/bin).
// ${CODER_CACHE_DIRECTORY}/site/bin). //
func ExtractOrReadBinFS(dest string, siteFS fs.FS) (http.FileSystem, error) { // Returns a http.FileSystem that serves unpacked binaries, and a map of binary
// name to SHA1 hash. The returned hash map may be incomplete or contain hashes
// for missing files.
func ExtractOrReadBinFS(dest string, siteFS fs.FS) (http.FileSystem, map[string]string, error) {
if dest == "" { if dest == "" {
// No destination on fs, embedded fs is the only option. // No destination on fs, embedded fs is the only option.
binFS, err := fs.Sub(siteFS, "bin") binFS, err := fs.Sub(siteFS, "bin")
if err != nil { if err != nil {
return nil, xerrors.Errorf("cache path is empty and embedded fs does not have /bin: %w", err) return nil, nil, xerrors.Errorf("cache path is empty and embedded fs does not have /bin: %w", err)
} }
return http.FS(binFS), nil return http.FS(binFS), nil, nil
} }
dest = filepath.Join(dest, "bin") dest = filepath.Join(dest, "bin")
@ -440,51 +475,63 @@ func ExtractOrReadBinFS(dest string, siteFS fs.FS) (http.FileSystem, error) {
files, err := fs.ReadDir(siteFS, "bin") files, err := fs.ReadDir(siteFS, "bin")
if err != nil { if err != nil {
if xerrors.Is(err, fs.ErrNotExist) { if xerrors.Is(err, fs.ErrNotExist) {
// Given fs does not have a bin directory, // Given fs does not have a bin directory, serve from cache
// serve from cache directory. // directory without extracting anything.
return mkdest() binFS, err := mkdest()
if err != nil {
return nil, nil, xerrors.Errorf("mkdest failed: %w", err)
}
return binFS, map[string]string{}, nil
} }
return nil, xerrors.Errorf("site fs read dir failed: %w", err) return nil, nil, xerrors.Errorf("site fs read dir failed: %w", err)
} }
if len(filterFiles(files, "GITKEEP")) > 0 { if len(filterFiles(files, "GITKEEP")) > 0 {
// If there are other files than bin/GITKEEP, // If there are other files than bin/GITKEEP, serve the files.
// serve the files.
binFS, err := fs.Sub(siteFS, "bin") binFS, err := fs.Sub(siteFS, "bin")
if err != nil { if err != nil {
return nil, xerrors.Errorf("site fs sub dir failed: %w", err) return nil, nil, xerrors.Errorf("site fs sub dir failed: %w", err)
} }
return http.FS(binFS), nil return http.FS(binFS), nil, nil
} }
// Nothing we can do, serve the cache directory, // Nothing we can do, serve the cache directory, thus allowing
// thus allowing binaries to be places there. // binaries to be placed there.
return mkdest() binFS, err := mkdest()
if err != nil {
return nil, nil, xerrors.Errorf("mkdest failed: %w", err)
}
return binFS, map[string]string{}, nil
} }
return nil, xerrors.Errorf("open coder binary archive failed: %w", err) return nil, nil, xerrors.Errorf("open coder binary archive failed: %w", err)
} }
defer archive.Close() defer archive.Close()
dir, err := mkdest() binFS, err := mkdest()
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
ok, err := verifyBinSha1IsCurrent(dest, siteFS) shaFiles, err := parseSHA1(siteFS)
if err != nil { if err != nil {
return nil, xerrors.Errorf("verify coder binaries sha1 failed: %w", err) return nil, nil, xerrors.Errorf("parse sha1 file failed: %w", err)
}
ok, err := verifyBinSha1IsCurrent(dest, siteFS, shaFiles)
if err != nil {
return nil, nil, xerrors.Errorf("verify coder binaries sha1 failed: %w", err)
} }
if !ok { if !ok {
n, err := extractBin(dest, archive) n, err := extractBin(dest, archive)
if err != nil { if err != nil {
return nil, xerrors.Errorf("extract coder binaries failed: %w", err) return nil, nil, xerrors.Errorf("extract coder binaries failed: %w", err)
} }
if n == 0 { if n == 0 {
return nil, xerrors.New("no files were extracted from coder binaries archive") return nil, nil, xerrors.New("no files were extracted from coder binaries archive")
} }
} }
return dir, nil return binFS, shaFiles, nil
} }
func filterFiles(files []fs.DirEntry, names ...string) []fs.DirEntry { func filterFiles(files []fs.DirEntry, names ...string) []fs.DirEntry {
@ -501,24 +548,32 @@ func filterFiles(files []fs.DirEntry, names ...string) []fs.DirEntry {
// errHashMismatch is a sentinel error used in verifyBinSha1IsCurrent. // errHashMismatch is a sentinel error used in verifyBinSha1IsCurrent.
var errHashMismatch = xerrors.New("hash mismatch") var errHashMismatch = xerrors.New("hash mismatch")
func verifyBinSha1IsCurrent(dest string, siteFS fs.FS) (ok bool, err error) { func parseSHA1(siteFS fs.FS) (map[string]string, error) {
b, err := fs.ReadFile(siteFS, "bin/coder.sha1")
if err != nil {
return nil, xerrors.Errorf("read coder sha1 from embedded fs failed: %w", err)
}
shaFiles := make(map[string]string)
for _, line := range bytes.Split(bytes.TrimSpace(b), []byte{'\n'}) {
parts := bytes.Split(line, []byte{' ', '*'})
if len(parts) != 2 {
return nil, xerrors.Errorf("malformed sha1 file: %w", err)
}
shaFiles[string(parts[1])] = strings.ToLower(string(parts[0]))
}
if len(shaFiles) == 0 {
return nil, xerrors.Errorf("empty sha1 file: %w", err)
}
return shaFiles, nil
}
func verifyBinSha1IsCurrent(dest string, siteFS fs.FS, shaFiles map[string]string) (ok bool, err error) {
b1, err := fs.ReadFile(siteFS, "bin/coder.sha1") b1, err := fs.ReadFile(siteFS, "bin/coder.sha1")
if err != nil { if err != nil {
return false, xerrors.Errorf("read coder sha1 from embedded fs failed: %w", err) return false, xerrors.Errorf("read coder sha1 from embedded fs failed: %w", err)
} }
// Parse sha1 file.
shaFiles := make(map[string][]byte)
for _, line := range bytes.Split(bytes.TrimSpace(b1), []byte{'\n'}) {
parts := bytes.Split(line, []byte{' ', '*'})
if len(parts) != 2 {
return false, xerrors.Errorf("malformed sha1 file: %w", err)
}
shaFiles[string(parts[1])] = parts[0]
}
if len(shaFiles) == 0 {
return false, xerrors.Errorf("empty sha1 file: %w", err)
}
b2, err := os.ReadFile(filepath.Join(dest, "coder.sha1")) b2, err := os.ReadFile(filepath.Join(dest, "coder.sha1"))
if err != nil { if err != nil {
if xerrors.Is(err, fs.ErrNotExist) { if xerrors.Is(err, fs.ErrNotExist) {
@ -551,7 +606,7 @@ func verifyBinSha1IsCurrent(dest string, siteFS fs.FS) (ok bool, err error) {
} }
return xerrors.Errorf("hash file failed: %w", err) return xerrors.Errorf("hash file failed: %w", err)
} }
if !bytes.Equal(hash1, hash2) { if !strings.EqualFold(hash1, hash2) {
return errHashMismatch return errHashMismatch
} }
return nil return nil
@ -570,24 +625,24 @@ func verifyBinSha1IsCurrent(dest string, siteFS fs.FS) (ok bool, err error) {
// sha1HashFile computes a SHA1 hash of the file, returning the hex // sha1HashFile computes a SHA1 hash of the file, returning the hex
// representation. // representation.
func sha1HashFile(name string) ([]byte, error) { func sha1HashFile(name string) (string, error) {
//#nosec // Not used for cryptography. //#nosec // Not used for cryptography.
hash := sha1.New() hash := sha1.New()
f, err := os.Open(name) f, err := os.Open(name)
if err != nil { if err != nil {
return nil, err return "", err
} }
defer f.Close() defer f.Close()
_, err = io.Copy(hash, f) _, err = io.Copy(hash, f)
if err != nil { if err != nil {
return nil, err return "", err
} }
b := make([]byte, hash.Size()) b := make([]byte, hash.Size())
hash.Sum(b[:0]) hash.Sum(b[:0])
return []byte(hex.EncodeToString(b)), nil return hex.EncodeToString(b), nil
} }
func extractBin(dest string, r io.Reader) (numExtracted int, err error) { func extractBin(dest string, r io.Reader) (numExtracted int, err error) {
@ -672,3 +727,67 @@ func RenderStaticErrorPage(rw http.ResponseWriter, r *http.Request, data ErrorPa
return return
} }
} }
type binHashCache struct {
binFS http.FileSystem
hashes map[string]string
mut sync.RWMutex
sf singleflight.Group
sem chan struct{}
}
func newBinHashCache(binFS http.FileSystem, binHashes map[string]string) *binHashCache {
b := &binHashCache{
binFS: binFS,
hashes: make(map[string]string, len(binHashes)),
mut: sync.RWMutex{},
sf: singleflight.Group{},
sem: make(chan struct{}, 4),
}
// Make a copy since we're gonna be mutating it.
for k, v := range binHashes {
b.hashes[k] = v
}
return b
}
func (b *binHashCache) getHash(name string) (string, error) {
b.mut.RLock()
hash, ok := b.hashes[name]
b.mut.RUnlock()
if ok {
return hash, nil
}
// Avoid DOS by using a pool, and only doing work once per file.
v, err, _ := b.sf.Do(name, func() (interface{}, error) {
b.sem <- struct{}{}
defer func() { <-b.sem }()
f, err := b.binFS.Open(name)
if err != nil {
return "", err
}
defer f.Close()
h := sha1.New() //#nosec // Not used for cryptography.
_, err = io.Copy(h, f)
if err != nil {
return "", err
}
hash := hex.EncodeToString(h.Sum(nil))
b.mut.Lock()
b.hashes[name] = hash
b.mut.Unlock()
return hash, nil
})
if err != nil {
return "", err
}
//nolint:forcetypeassert
return strings.ToLower(v.(string)), nil
}

View File

@ -45,7 +45,7 @@ func TestCaching(t *testing.T) {
} }
binFS := http.FS(fstest.MapFS{}) binFS := http.FS(fstest.MapFS{})
srv := httptest.NewServer(site.Handler(rootFS, binFS)) srv := httptest.NewServer(site.Handler(rootFS, binFS, nil))
defer srv.Close() defer srv.Close()
// Create a context // Create a context
@ -105,7 +105,7 @@ func TestServingFiles(t *testing.T) {
} }
binFS := http.FS(fstest.MapFS{}) binFS := http.FS(fstest.MapFS{})
srv := httptest.NewServer(site.Handler(rootFS, binFS)) srv := httptest.NewServer(site.Handler(rootFS, binFS, nil))
defer srv.Close() defer srv.Close()
// Create a context // Create a context
@ -185,10 +185,18 @@ const (
binCoderTarZstd = "bin/coder.tar.zst" binCoderTarZstd = "bin/coder.tar.zst"
) )
var sampleBinSHAs = map[string]string{
"coder-linux-amd64": "55641d5d56bbb8ccf5850fe923bd971b86364604",
}
func sampleBinFS() fstest.MapFS { func sampleBinFS() fstest.MapFS {
sha1File := bytes.NewBuffer(nil)
for name, sha := range sampleBinSHAs {
_, _ = fmt.Fprintf(sha1File, "%s *%s\n", sha, name)
}
return fstest.MapFS{ return fstest.MapFS{
binCoderSha1: &fstest.MapFile{ binCoderSha1: &fstest.MapFile{
Data: []byte("55641d5d56bbb8ccf5850fe923bd971b86364604 *coder-linux-amd64\n"), Data: sha1File.Bytes(),
}, },
binCoderTarZstd: &fstest.MapFile{ binCoderTarZstd: &fstest.MapFile{
// echo -n compressed >coder-linux-amd64 // echo -n compressed >coder-linux-amd64
@ -241,9 +249,11 @@ func TestServingBin(t *testing.T) {
delete(sampleBinFSMissingSha256, binCoderSha1) delete(sampleBinFSMissingSha256, binCoderSha1)
type req struct { type req struct {
url string url string
wantStatus int ifNoneMatch string
wantBody []byte wantStatus int
wantBody []byte
wantEtag string
} }
tests := []struct { tests := []struct {
name string name string
@ -255,7 +265,19 @@ func TestServingBin(t *testing.T) {
name: "Extract and serve bin", name: "Extract and serve bin",
fs: sampleBinFS(), fs: sampleBinFS(),
reqs: []req{ reqs: []req{
{url: "/bin/coder-linux-amd64", wantStatus: http.StatusOK, wantBody: []byte("compressed")}, {
url: "/bin/coder-linux-amd64",
wantStatus: http.StatusOK,
wantBody: []byte("compressed"),
wantEtag: fmt.Sprintf("%q", sampleBinSHAs["coder-linux-amd64"]),
},
// Test ETag support.
{
url: "/bin/coder-linux-amd64",
ifNoneMatch: fmt.Sprintf("%q", sampleBinSHAs["coder-linux-amd64"]),
wantStatus: http.StatusNotModified,
wantEtag: fmt.Sprintf("%q", sampleBinSHAs["coder-linux-amd64"]),
},
{url: "/bin/GITKEEP", wantStatus: http.StatusNotFound}, {url: "/bin/GITKEEP", wantStatus: http.StatusNotFound},
}, },
}, },
@ -329,14 +351,14 @@ func TestServingBin(t *testing.T) {
t.Parallel() t.Parallel()
dest := t.TempDir() dest := t.TempDir()
binFS, err := site.ExtractOrReadBinFS(dest, tt.fs) binFS, binHashes, err := site.ExtractOrReadBinFS(dest, tt.fs)
if !tt.wantErr && err != nil { if !tt.wantErr && err != nil {
require.NoError(t, err, "extract or read failed") require.NoError(t, err, "extract or read failed")
} else if tt.wantErr { } else if tt.wantErr {
require.Error(t, err, "extraction or read did not fail") require.Error(t, err, "extraction or read did not fail")
} }
srv := httptest.NewServer(site.Handler(rootFS, binFS)) srv := httptest.NewServer(site.Handler(rootFS, binFS, binHashes))
defer srv.Close() defer srv.Close()
// Create a context // Create a context
@ -348,6 +370,10 @@ func TestServingBin(t *testing.T) {
req, err := http.NewRequestWithContext(ctx, "GET", srv.URL+tr.url, nil) req, err := http.NewRequestWithContext(ctx, "GET", srv.URL+tr.url, nil)
require.NoError(t, err, "http request failed") require.NoError(t, err, "http request failed")
if tr.ifNoneMatch != "" {
req.Header.Set("If-None-Match", tr.ifNoneMatch)
}
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
require.NoError(t, err, "http do failed") require.NoError(t, err, "http do failed")
defer resp.Body.Close() defer resp.Body.Close()
@ -361,6 +387,14 @@ func TestServingBin(t *testing.T) {
if tr.wantBody != nil { if tr.wantBody != nil {
assert.Equal(t, string(tr.wantBody), string(gotBody), "body did not match") assert.Equal(t, string(tr.wantBody), string(gotBody), "body did not match")
} }
if tr.wantStatus == http.StatusNoContent || tr.wantStatus == http.StatusNotModified {
assert.Empty(t, gotBody, "body is not empty")
}
if tr.wantEtag != "" {
assert.NotEmpty(t, resp.Header.Get("ETag"), "etag header is empty")
assert.Equal(t, tr.wantEtag, resp.Header.Get("ETag"), "etag did not match")
}
}) })
} }
}) })
@ -374,7 +408,7 @@ func TestExtractOrReadBinFS(t *testing.T) {
siteFS := sampleBinFS() siteFS := sampleBinFS()
dest := t.TempDir() dest := t.TempDir()
_, err := site.ExtractOrReadBinFS(dest, siteFS) _, _, err := site.ExtractOrReadBinFS(dest, siteFS)
require.NoError(t, err) require.NoError(t, err)
checkModtime := func() map[string]time.Time { checkModtime := func() map[string]time.Time {
@ -402,7 +436,7 @@ func TestExtractOrReadBinFS(t *testing.T) {
firstModtimes := checkModtime() firstModtimes := checkModtime()
_, err = site.ExtractOrReadBinFS(dest, siteFS) _, _, err = site.ExtractOrReadBinFS(dest, siteFS)
require.NoError(t, err) require.NoError(t, err)
secondModtimes := checkModtime() secondModtimes := checkModtime()
@ -414,7 +448,7 @@ func TestExtractOrReadBinFS(t *testing.T) {
siteFS := sampleBinFS() siteFS := sampleBinFS()
dest := t.TempDir() dest := t.TempDir()
_, err := site.ExtractOrReadBinFS(dest, siteFS) _, _, err := site.ExtractOrReadBinFS(dest, siteFS)
require.NoError(t, err) require.NoError(t, err)
bin := filepath.Join(dest, "bin", "coder-linux-amd64") bin := filepath.Join(dest, "bin", "coder-linux-amd64")
@ -428,7 +462,7 @@ func TestExtractOrReadBinFS(t *testing.T) {
err = f.Close() err = f.Close()
require.NoError(t, err) require.NoError(t, err)
_, err = site.ExtractOrReadBinFS(dest, siteFS) _, _, err = site.ExtractOrReadBinFS(dest, siteFS)
require.NoError(t, err) require.NoError(t, err)
f, err = os.Open(bin) f, err = os.Open(bin)
@ -441,6 +475,17 @@ func TestExtractOrReadBinFS(t *testing.T) {
assert.NotEqual(t, dontWant, got, "file should be overwritten on hash mismatch") assert.NotEqual(t, dontWant, got, "file should be overwritten on hash mismatch")
}) })
t.Run("ParsesHashes", func(t *testing.T) {
t.Parallel()
siteFS := sampleBinFS()
dest := t.TempDir()
_, hashes, err := site.ExtractOrReadBinFS(dest, siteFS)
require.NoError(t, err)
require.Equal(t, sampleBinSHAs, hashes, "hashes did not match")
})
} }
func TestRenderStaticErrorPage(t *testing.T) { func TestRenderStaticErrorPage(t *testing.T) {