coder/coderd/azureidentity/azureidentity.go

79 lines
2.3 KiB
Go

package azureidentity
import (
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"regexp"
"go.mozilla.org/pkcs7"
"golang.org/x/xerrors"
)
var (
	// allowedSigners is matched against the subject common name of the
	// certificate that signed an attested-data document. It accepts
	// "metadata.<cloud>" for each cloud below, optionally preceded by any
	// prefix that ends in a dot, following Microsoft's published list:
	// https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#tabgroup_14
	allowedSigners = regexp.MustCompile(`^(.*\.)?metadata\.(azure\.(com|us|cn)|microsoftazure\.de)$`)
)

// metadata holds the fields Validate reads from the signed JSON content of
// an attested-data document.
type metadata struct {
	// VMID is decoded from the "vmId" key of the signed content.
	VMID string `json:"vmId"`
}
// maxCertificateSize bounds how many bytes are read from an issuing
// certificate URL. DER-encoded certificates are a few kilobytes, so 1 MiB is
// generous while still preventing an unbounded read from a URL that was taken
// from a not-yet-trusted certificate.
const maxCertificateSize = 1 << 20

// Validate ensures the signature was signed by an Azure certificate.
// It returns the associated VM ID if successful.
//
// signature is the base64-encoded PKCS#7 document. Checks happen in order:
//  1. the document must contain exactly one signer;
//  2. the signer's common name must match allowedSigners;
//  3. the PKCS#7 signature over the content must verify with the signer's
//     public key;
//  4. the signer certificate must verify under options (see
//     x509.Certificate.Verify). When options.Intermediates is nil, the
//     intermediates are downloaded from the signer's IssuingCertificateURL
//     entries using ctx.
//
// Only after all checks pass is the content decoded as JSON and its "vmId"
// returned.
func Validate(ctx context.Context, signature string, options x509.VerifyOptions) (string, error) {
	data, err := base64.StdEncoding.DecodeString(signature)
	if err != nil {
		return "", xerrors.Errorf("decode base64: %w", err)
	}
	pkcs7Data, err := pkcs7.Parse(data)
	if err != nil {
		return "", xerrors.Errorf("parse pkcs7: %w", err)
	}
	signer := pkcs7Data.GetOnlySigner()
	if signer == nil {
		return "", xerrors.New("no signers for signature")
	}
	if !allowedSigners.MatchString(signer.Subject.CommonName) {
		return "", xerrors.Errorf("unmatched common name of signer: %q", signer.Subject.CommonName)
	}
	// GetOnlySigner only locates the signer's certificate; it does not check
	// the signature. Without this call, a document carrying a genuine Azure
	// certificate could hold arbitrary content (and therefore any vmId).
	// Verify uses a nil trust store, so it checks the signature only; the
	// certificate chain is checked below with the caller's options.
	err = pkcs7Data.Verify()
	if err != nil {
		return "", xerrors.Errorf("verify signature: %w", err)
	}
	if options.Intermediates == nil {
		// options is a value copy, so this pool is not visible to the caller.
		options.Intermediates = x509.NewCertPool()
		// NOTE(review): these URLs come from a certificate whose chain has not
		// been verified yet (a self-signed cert can pass the checks above), so
		// this makes outbound requests to caller-chosen URLs. Consider pinning
		// known intermediates or restricting hosts.
		for _, certURL := range signer.IssuingCertificateURL {
			cert, err := fetchCertificate(ctx, certURL)
			if err != nil {
				return "", err
			}
			options.Intermediates.AddCert(cert)
		}
	}
	_, err = signer.Verify(options)
	if err != nil {
		return "", xerrors.Errorf("verify certificates: %w", err)
	}
	var metadata metadata
	err = json.Unmarshal(pkcs7Data.Content, &metadata)
	if err != nil {
		return "", xerrors.Errorf("unmarshal metadata: %w", err)
	}
	return metadata.VMID, nil
}

// fetchCertificate downloads a single DER-encoded certificate from certURL
// and parses it. Only a 200 OK response is accepted, and at most
// maxCertificateSize bytes are read.
func fetchCertificate(ctx context.Context, certURL string) (*x509.Certificate, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, certURL, nil)
	if err != nil {
		return nil, xerrors.Errorf("new request %q: %w", certURL, err)
	}
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, xerrors.Errorf("perform request %q: %w", certURL, err)
	}
	defer func() { _ = res.Body.Close() }()
	if res.StatusCode != http.StatusOK {
		return nil, xerrors.Errorf("perform request %q: unexpected status %d", certURL, res.StatusCode)
	}
	// Read one byte past the limit so an oversized body is detected rather
	// than silently truncated.
	data, err := io.ReadAll(io.LimitReader(res.Body, maxCertificateSize+1))
	if err != nil {
		return nil, xerrors.Errorf("read body %q: %w", certURL, err)
	}
	if len(data) > maxCertificateSize {
		return nil, xerrors.Errorf("read body %q: exceeds %d bytes", certURL, maxCertificateSize)
	}
	cert, err := x509.ParseCertificate(data)
	if err != nil {
		return nil, xerrors.Errorf("parse certificate %q: %w", certURL, err)
	}
	return cert, nil
}