diff --git a/coderd/coderd.go b/coderd/coderd.go index a06b1aedce..c374b6ac60 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -197,7 +197,7 @@ func New(options *Options) *API { if siteCacheDir != "" { siteCacheDir = filepath.Join(siteCacheDir, "site") } - binFS, err := site.ExtractOrReadBinFS(siteCacheDir, site.FS()) + binFS, binHashes, err := site.ExtractOrReadBinFS(siteCacheDir, site.FS()) if err != nil { panic(xerrors.Errorf("read site bin failed: %w", err)) } @@ -213,7 +213,7 @@ func New(options *Options) *API { ID: uuid.New(), Options: options, RootHandler: r, - siteHandler: site.Handler(site.FS(), binFS), + siteHandler: site.Handler(site.FS(), binFS, binHashes), HTTPAuth: &HTTPAuthorizer{ Authorizer: options.Authorizer, Logger: options.Logger, diff --git a/site/site.go b/site/site.go index a06500b1c1..4a0e739563 100644 --- a/site/site.go +++ b/site/site.go @@ -16,6 +16,7 @@ import ( "path" "path/filepath" "strings" + "sync" "text/template" // html/template escapes some nonces "time" @@ -24,6 +25,7 @@ import ( "github.com/unrolled/secure" "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" + "golang.org/x/sync/singleflight" "golang.org/x/xerrors" "github.com/coder/coder/coderd/httpapi" @@ -48,7 +50,7 @@ func init() { } // Handler returns an HTTP handler for serving the static site. -func Handler(siteFS fs.FS, binFS http.FileSystem) http.Handler { +func Handler(siteFS fs.FS, binFS http.FileSystem, binHashes map[string]string) http.Handler { // html files are handled by a text/template. Non-html files // are served by the default file server. // @@ -59,6 +61,8 @@ func Handler(siteFS fs.FS, binFS http.FileSystem) http.Handler { panic(xerrors.Errorf("Failed to return handler for static files. Html files failed to load: %w", err)) } + binHashCache := newBinHashCache(binFS, binHashes) + mux := http.NewServeMux() mux.Handle("/bin/", http.StripPrefix("/bin", http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { // Convert underscores in the filename to hyphens. We eventually want to @@ -66,6 +70,34 @@ func Handler(siteFS fs.FS, binFS http.FileSystem) http.Handler { // support both for now. r.URL.Path = strings.ReplaceAll(r.URL.Path, "_", "-") + // Set ETag header to the SHA1 hash of the file contents. + name := filePath(r.URL.Path) + if name == "" || name == "/" { + // Serve the directory listing. + http.FileServer(binFS).ServeHTTP(rw, r) + return + } + if strings.Contains(name, "/") { + // We only serve files from the root of this directory, so avoid any + // shenanigans by blocking slashes in the URL path. + http.NotFound(rw, r) + return + } + hash, err := binHashCache.getHash(name) + if xerrors.Is(err, os.ErrNotExist) { + http.NotFound(rw, r) + return + } + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + + // ETag header needs to be quoted. + rw.Header().Set("ETag", fmt.Sprintf(`%q`, hash)) + + // http.FileServer will see the ETag header and automatically handle + // If-Match and If-None-Match headers on the request properly. http.FileServer(binFS).ServeHTTP(rw, r) }))) mux.Handle("/", http.FileServer(http.FS(siteFS))) // All other non-html static files. @@ -409,20 +441,23 @@ func htmlFiles(files fs.FS) (*htmlTemplates, error) { }, nil } -// ExtractOrReadBinFS checks the provided fs for compressed coder -// binaries and extracts them into dest/bin if found. As a fallback, -// the provided FS is checked for a /bin directory, if it is non-empty -// it is returned. Finally dest/bin is returned as a fallback allowing -// binaries to be manually placed in dest (usually -// ${CODER_CACHE_DIRECTORY}/site/bin). -func ExtractOrReadBinFS(dest string, siteFS fs.FS) (http.FileSystem, error) { +// ExtractOrReadBinFS checks the provided fs for compressed coder binaries and +// extracts them into dest/bin if found. As a fallback, the provided FS is +// checked for a /bin directory, if it is non-empty it is returned. Finally +// dest/bin is returned as a fallback allowing binaries to be manually placed in +// dest (usually ${CODER_CACHE_DIRECTORY}/site/bin). +// +// Returns a http.FileSystem that serves unpacked binaries, and a map of binary +// name to SHA1 hash. The returned hash map may be incomplete or contain hashes +// for missing files. +func ExtractOrReadBinFS(dest string, siteFS fs.FS) (http.FileSystem, map[string]string, error) { if dest == "" { // No destination on fs, embedded fs is the only option. binFS, err := fs.Sub(siteFS, "bin") if err != nil { - return nil, xerrors.Errorf("cache path is empty and embedded fs does not have /bin: %w", err) + return nil, nil, xerrors.Errorf("cache path is empty and embedded fs does not have /bin: %w", err) } - return http.FS(binFS), nil + return http.FS(binFS), nil, nil } dest = filepath.Join(dest, "bin") @@ -440,51 +475,63 @@ func ExtractOrReadBinFS(dest string, siteFS fs.FS) (http.FileSystem, error) { files, err := fs.ReadDir(siteFS, "bin") if err != nil { if xerrors.Is(err, fs.ErrNotExist) { - // Given fs does not have a bin directory, - // serve from cache directory. - return mkdest() + // Given fs does not have a bin directory, serve from cache + // directory without extracting anything. + binFS, err := mkdest() + if err != nil { + return nil, nil, xerrors.Errorf("mkdest failed: %w", err) + } + return binFS, map[string]string{}, nil } - return nil, xerrors.Errorf("site fs read dir failed: %w", err) + return nil, nil, xerrors.Errorf("site fs read dir failed: %w", err) } if len(filterFiles(files, "GITKEEP")) > 0 { - // If there are other files than bin/GITKEEP, - // serve the files. + // If there are other files than bin/GITKEEP, serve the files. binFS, err := fs.Sub(siteFS, "bin") if err != nil { - return nil, xerrors.Errorf("site fs sub dir failed: %w", err) + return nil, nil, xerrors.Errorf("site fs sub dir failed: %w", err) } - return http.FS(binFS), nil + return http.FS(binFS), nil, nil } - // Nothing we can do, serve the cache directory, - // thus allowing binaries to be places there. - return mkdest() + // Nothing we can do, serve the cache directory, thus allowing + // binaries to be placed there. + binFS, err := mkdest() + if err != nil { + return nil, nil, xerrors.Errorf("mkdest failed: %w", err) + } + return binFS, map[string]string{}, nil } - return nil, xerrors.Errorf("open coder binary archive failed: %w", err) + return nil, nil, xerrors.Errorf("open coder binary archive failed: %w", err) } defer archive.Close() - dir, err := mkdest() + binFS, err := mkdest() if err != nil { - return nil, err + return nil, nil, err } - ok, err := verifyBinSha1IsCurrent(dest, siteFS) + shaFiles, err := parseSHA1(siteFS) if err != nil { - return nil, xerrors.Errorf("verify coder binaries sha1 failed: %w", err) + return nil, nil, xerrors.Errorf("parse sha1 file failed: %w", err) + } + + ok, err := verifyBinSha1IsCurrent(dest, siteFS, shaFiles) + if err != nil { + return nil, nil, xerrors.Errorf("verify coder binaries sha1 failed: %w", err) } if !ok { n, err := extractBin(dest, archive) if err != nil { - return nil, xerrors.Errorf("extract coder binaries failed: %w", err) + return nil, nil, xerrors.Errorf("extract coder binaries failed: %w", err) } if n == 0 { - return nil, xerrors.New("no files were extracted from coder binaries archive") + return nil, nil, xerrors.New("no files were extracted from coder binaries archive") } } - return dir, nil + return binFS, shaFiles, nil } func filterFiles(files []fs.DirEntry, names ...string) []fs.DirEntry { @@ -501,24 +548,32 @@ func filterFiles(files []fs.DirEntry, names ...string) []fs.DirEntry { // errHashMismatch is a sentinel error used in verifyBinSha1IsCurrent. var errHashMismatch = xerrors.New("hash mismatch") -func verifyBinSha1IsCurrent(dest string, siteFS fs.FS) (ok bool, err error) { +func parseSHA1(siteFS fs.FS) (map[string]string, error) { + b, err := fs.ReadFile(siteFS, "bin/coder.sha1") + if err != nil { + return nil, xerrors.Errorf("read coder sha1 from embedded fs failed: %w", err) + } + + shaFiles := make(map[string]string) + for _, line := range bytes.Split(bytes.TrimSpace(b), []byte{'\n'}) { + parts := bytes.Split(line, []byte{' ', '*'}) + if len(parts) != 2 { + return nil, xerrors.Errorf("malformed sha1 file: %w", err) + } + shaFiles[string(parts[1])] = strings.ToLower(string(parts[0])) + } + if len(shaFiles) == 0 { + return nil, xerrors.Errorf("empty sha1 file: %w", err) + } + + return shaFiles, nil +} + +func verifyBinSha1IsCurrent(dest string, siteFS fs.FS, shaFiles map[string]string) (ok bool, err error) { b1, err := fs.ReadFile(siteFS, "bin/coder.sha1") if err != nil { return false, xerrors.Errorf("read coder sha1 from embedded fs failed: %w", err) } - // Parse sha1 file. - shaFiles := make(map[string][]byte) - for _, line := range bytes.Split(bytes.TrimSpace(b1), []byte{'\n'}) { - parts := bytes.Split(line, []byte{' ', '*'}) - if len(parts) != 2 { - return false, xerrors.Errorf("malformed sha1 file: %w", err) - } - shaFiles[string(parts[1])] = parts[0] - } - if len(shaFiles) == 0 { - return false, xerrors.Errorf("empty sha1 file: %w", err) - } - b2, err := os.ReadFile(filepath.Join(dest, "coder.sha1")) if err != nil { if xerrors.Is(err, fs.ErrNotExist) { @@ -551,7 +606,7 @@ func verifyBinSha1IsCurrent(dest string, siteFS fs.FS) (ok bool, err error) { } return xerrors.Errorf("hash file failed: %w", err) } - if !bytes.Equal(hash1, hash2) { + if !strings.EqualFold(hash1, hash2) { return errHashMismatch } return nil @@ -570,24 +625,24 @@ func verifyBinSha1IsCurrent(dest string, siteFS fs.FS) (ok bool, err error) { // sha1HashFile computes a SHA1 hash of the file, returning the hex // representation. -func sha1HashFile(name string) ([]byte, error) { +func sha1HashFile(name string) (string, error) { //#nosec // Not used for cryptography. hash := sha1.New() f, err := os.Open(name) if err != nil { - return nil, err + return "", err } defer f.Close() _, err = io.Copy(hash, f) if err != nil { - return nil, err + return "", err } b := make([]byte, hash.Size()) hash.Sum(b[:0]) - return []byte(hex.EncodeToString(b)), nil + return hex.EncodeToString(b), nil } func extractBin(dest string, r io.Reader) (numExtracted int, err error) { @@ -672,3 +727,67 @@ func RenderStaticErrorPage(rw http.ResponseWriter, r *http.Request, data ErrorPa return } } + +type binHashCache struct { + binFS http.FileSystem + + hashes map[string]string + mut sync.RWMutex + sf singleflight.Group + sem chan struct{} +} + +func newBinHashCache(binFS http.FileSystem, binHashes map[string]string) *binHashCache { + b := &binHashCache{ + binFS: binFS, + hashes: make(map[string]string, len(binHashes)), + mut: sync.RWMutex{}, + sf: singleflight.Group{}, + sem: make(chan struct{}, 4), + } + // Make a copy since we're gonna be mutating it. + for k, v := range binHashes { + b.hashes[k] = v + } + + return b +} + +func (b *binHashCache) getHash(name string) (string, error) { + b.mut.RLock() + hash, ok := b.hashes[name] + b.mut.RUnlock() + if ok { + return hash, nil + } + + // Avoid DOS by using a pool, and only doing work once per file. + v, err, _ := b.sf.Do(name, func() (interface{}, error) { + b.sem <- struct{}{} + defer func() { <-b.sem }() + + f, err := b.binFS.Open(name) + if err != nil { + return "", err + } + defer f.Close() + + h := sha1.New() //#nosec // Not used for cryptography. + _, err = io.Copy(h, f) + if err != nil { + return "", err + } + + hash := hex.EncodeToString(h.Sum(nil)) + b.mut.Lock() + b.hashes[name] = hash + b.mut.Unlock() + return hash, nil + }) + if err != nil { + return "", err + } + + //nolint:forcetypeassert + return strings.ToLower(v.(string)), nil +} diff --git a/site/site_test.go b/site/site_test.go index a4bf5eccaf..c7b301982f 100644 --- a/site/site_test.go +++ b/site/site_test.go @@ -45,7 +45,7 @@ func TestCaching(t *testing.T) { } binFS := http.FS(fstest.MapFS{}) - srv := httptest.NewServer(site.Handler(rootFS, binFS)) + srv := httptest.NewServer(site.Handler(rootFS, binFS, nil)) defer srv.Close() // Create a context @@ -105,7 +105,7 @@ func TestServingFiles(t *testing.T) { } binFS := http.FS(fstest.MapFS{}) - srv := httptest.NewServer(site.Handler(rootFS, binFS)) + srv := httptest.NewServer(site.Handler(rootFS, binFS, nil)) defer srv.Close() // Create a context @@ -185,10 +185,18 @@ const ( binCoderTarZstd = "bin/coder.tar.zst" ) +var sampleBinSHAs = map[string]string{ + "coder-linux-amd64": "55641d5d56bbb8ccf5850fe923bd971b86364604", +} + func sampleBinFS() fstest.MapFS { + sha1File := bytes.NewBuffer(nil) + for name, sha := range sampleBinSHAs { + _, _ = fmt.Fprintf(sha1File, "%s *%s\n", sha, name) + } return fstest.MapFS{ binCoderSha1: &fstest.MapFile{ - Data: []byte("55641d5d56bbb8ccf5850fe923bd971b86364604 *coder-linux-amd64\n"), + Data: sha1File.Bytes(), }, binCoderTarZstd: &fstest.MapFile{ // echo -n compressed >coder-linux-amd64 @@ -241,9 +249,11 @@ func TestServingBin(t *testing.T) { delete(sampleBinFSMissingSha256, binCoderSha1) type req struct { - url string - wantStatus int - wantBody []byte + url string + ifNoneMatch string + wantStatus int + wantBody []byte + wantEtag string } tests := []struct { name string @@ -255,7 +265,19 @@ func TestServingBin(t *testing.T) { name: "Extract and serve bin", fs: sampleBinFS(), reqs: []req{ - {url: "/bin/coder-linux-amd64", wantStatus: http.StatusOK, wantBody: []byte("compressed")}, + { + url: "/bin/coder-linux-amd64", + wantStatus: http.StatusOK, + wantBody: []byte("compressed"), + wantEtag: fmt.Sprintf("%q", sampleBinSHAs["coder-linux-amd64"]), + }, + // Test ETag support. + { + url: "/bin/coder-linux-amd64", + ifNoneMatch: fmt.Sprintf("%q", sampleBinSHAs["coder-linux-amd64"]), + wantStatus: http.StatusNotModified, + wantEtag: fmt.Sprintf("%q", sampleBinSHAs["coder-linux-amd64"]), + }, {url: "/bin/GITKEEP", wantStatus: http.StatusNotFound}, }, }, @@ -329,14 +351,14 @@ func TestServingBin(t *testing.T) { t.Parallel() dest := t.TempDir() - binFS, err := site.ExtractOrReadBinFS(dest, tt.fs) + binFS, binHashes, err := site.ExtractOrReadBinFS(dest, tt.fs) if !tt.wantErr && err != nil { require.NoError(t, err, "extract or read failed") } else if tt.wantErr { require.Error(t, err, "extraction or read did not fail") } - srv := httptest.NewServer(site.Handler(rootFS, binFS)) + srv := httptest.NewServer(site.Handler(rootFS, binFS, binHashes)) defer srv.Close() // Create a context @@ -348,6 +370,10 @@ func TestServingBin(t *testing.T) { req, err := http.NewRequestWithContext(ctx, "GET", srv.URL+tr.url, nil) require.NoError(t, err, "http request failed") + if tr.ifNoneMatch != "" { + req.Header.Set("If-None-Match", tr.ifNoneMatch) + } + resp, err := http.DefaultClient.Do(req) require.NoError(t, err, "http do failed") defer resp.Body.Close() @@ -361,6 +387,14 @@ func TestServingBin(t *testing.T) { if tr.wantBody != nil { assert.Equal(t, string(tr.wantBody), string(gotBody), "body did not match") } + if tr.wantStatus == http.StatusNoContent || tr.wantStatus == http.StatusNotModified { + assert.Empty(t, gotBody, "body is not empty") + } + + if tr.wantEtag != "" { + assert.NotEmpty(t, resp.Header.Get("ETag"), "etag header is empty") + assert.Equal(t, tr.wantEtag, resp.Header.Get("ETag"), "etag did not match") + } }) } }) @@ -374,7 +408,7 @@ func TestExtractOrReadBinFS(t *testing.T) { siteFS := sampleBinFS() dest := t.TempDir() - _, err := site.ExtractOrReadBinFS(dest, siteFS) + _, _, err := site.ExtractOrReadBinFS(dest, siteFS) require.NoError(t, err) checkModtime := func() map[string]time.Time { @@ -402,7 +436,7 @@ func TestExtractOrReadBinFS(t *testing.T) { firstModtimes := checkModtime() - _, err = site.ExtractOrReadBinFS(dest, siteFS) + _, _, err = site.ExtractOrReadBinFS(dest, siteFS) require.NoError(t, err) secondModtimes := checkModtime() @@ -414,7 +448,7 @@ func TestExtractOrReadBinFS(t *testing.T) { siteFS := sampleBinFS() dest := t.TempDir() - _, err := site.ExtractOrReadBinFS(dest, siteFS) + _, _, err := site.ExtractOrReadBinFS(dest, siteFS) require.NoError(t, err) bin := filepath.Join(dest, "bin", "coder-linux-amd64") @@ -428,7 +462,7 @@ func TestExtractOrReadBinFS(t *testing.T) { err = f.Close() require.NoError(t, err) - _, err = site.ExtractOrReadBinFS(dest, siteFS) + _, _, err = site.ExtractOrReadBinFS(dest, siteFS) require.NoError(t, err) f, err = os.Open(bin) @@ -441,6 +475,17 @@ func TestExtractOrReadBinFS(t *testing.T) { assert.NotEqual(t, dontWant, got, "file should be overwritten on hash mismatch") }) + + t.Run("ParsesHashes", func(t *testing.T) { + t.Parallel() + + siteFS := sampleBinFS() + dest := t.TempDir() + _, hashes, err := site.ExtractOrReadBinFS(dest, siteFS) + require.NoError(t, err) + + require.Equal(t, sampleBinSHAs, hashes, "hashes did not match") + }) } func TestRenderStaticErrorPage(t *testing.T) {