feat: peer wireguard (#2445)

This commit is contained in:
Colin Adler 2022-06-24 10:25:01 -05:00 committed by GitHub
parent d21ab2115d
commit 05b67ab1cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1935 additions and 236 deletions

View File

@ -60,7 +60,7 @@ coderd/database/dump.sql: $(wildcard coderd/database/migrations/*.sql)
go run coderd/database/dump/main.go
# Generates Go code for querying the database.
coderd/database/querier.go: coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql)
coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql)
coderd/database/generate.sh
# This target is deprecated, as GNU make has issues passing signals to subprocesses.

View File

@ -27,10 +27,13 @@ import (
"go.uber.org/atomic"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"inet.af/netaddr"
"tailscale.com/types/key"
"cdr.dev/slog"
"github.com/coder/coder/agent/usershell"
"github.com/coder/coder/peer"
"github.com/coder/coder/peer/peerwg"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/pty"
"github.com/coder/retry"
@ -43,20 +46,31 @@ const (
)
type Options struct {
EnableWireguard bool
UploadWireguardKeys UploadWireguardKeys
ListenWireguardPeers ListenWireguardPeers
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
}
type Metadata struct {
OwnerEmail string `json:"owner_email"`
OwnerUsername string `json:"owner_username"`
EnvironmentVariables map[string]string `json:"environment_variables"`
StartupScript string `json:"startup_script"`
Directory string `json:"directory"`
WireguardAddresses []netaddr.IPPrefix `json:"addresses"`
OwnerEmail string `json:"owner_email"`
OwnerUsername string `json:"owner_username"`
EnvironmentVariables map[string]string `json:"environment_variables"`
StartupScript string `json:"startup_script"`
Directory string `json:"directory"`
}
type WireguardPublicKeys struct {
Public key.NodePublic `json:"public"`
Disco key.DiscoPublic `json:"disco"`
}
type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error)
type UploadWireguardKeys func(ctx context.Context, keys WireguardPublicKeys) error
type ListenWireguardPeers func(ctx context.Context, logger slog.Logger) (<-chan peerwg.Handshake, func(), error)
func New(dialer Dialer, options *Options) io.Closer {
if options == nil {
@ -73,6 +87,9 @@ func New(dialer Dialer, options *Options) io.Closer {
closeCancel: cancelFunc,
closed: make(chan struct{}),
envVars: options.EnvironmentVariables,
enableWireguard: options.EnableWireguard,
postKeys: options.UploadWireguardKeys,
listenWireguardPeers: options.ListenWireguardPeers,
}
server.init(ctx)
return server
@ -95,6 +112,11 @@ type agent struct {
metadata atomic.Value
startupScript atomic.Bool
sshServer *ssh.Server
enableWireguard bool
network *peerwg.Network
postKeys UploadWireguardKeys
listenWireguardPeers ListenWireguardPeers
}
func (a *agent) run(ctx context.Context) {
@ -138,6 +160,13 @@ func (a *agent) run(ctx context.Context) {
}()
}
if a.enableWireguard {
err = a.startWireguard(ctx, metadata.WireguardAddresses)
if err != nil {
a.logger.Error(ctx, "start wireguard", slog.Error(err))
}
}
for {
conn, err := peerListener.Accept()
if err != nil {
@ -366,17 +395,17 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
// Load environment variables passed via the agent.
// These should override all variables we manually specify.
for key, value := range metadata.EnvironmentVariables {
for envKey, value := range metadata.EnvironmentVariables {
// Expanding environment variables allows for customization
// of the $PATH, among other variables. Customers can prepand
// or append to the $PATH, so allowing expand is required!
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, os.ExpandEnv(value)))
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value)))
}
// Agent-level environment variables should take over all!
// This is used for setting agent-specific variables like "CODER_AGENT_TOKEN".
for key, value := range a.envVars {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
for envKey, value := range a.envVars {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value))
}
return cmd, nil

63
agent/wireguard.go Normal file
View File

@ -0,0 +1,63 @@
package agent
import (
"context"
"golang.org/x/xerrors"
"inet.af/netaddr"
"cdr.dev/slog"
"github.com/coder/coder/peer/peerwg"
)
func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) error {
if a.network != nil {
_ = a.network.Close()
a.network = nil
}
// We can't create a wireguard network without these.
if len(addrs) == 0 || a.listenWireguardPeers == nil || a.postKeys == nil {
return xerrors.New("wireguard is enabled, but no addresses were provided or necessary functions were not provided")
}
wg, err := peerwg.New(a.logger.Named("wireguard"), addrs)
if err != nil {
return xerrors.Errorf("create wireguard network: %w", err)
}
// A new keypair is generated on each agent start.
// This keypair must be sent to Coder to allow for incoming connections.
err = a.postKeys(ctx, WireguardPublicKeys{
Public: wg.NodePrivateKey.Public(),
Disco: wg.DiscoPublicKey,
})
if err != nil {
a.logger.Warn(ctx, "post keys", slog.Error(err))
}
go func() {
for {
ch, listenClose, err := a.listenWireguardPeers(ctx, a.logger)
if err != nil {
a.logger.Warn(ctx, "listen wireguard peers", slog.Error(err))
return
}
for {
peer, ok := <-ch
if !ok {
break
}
err := wg.AddPeer(peer)
a.logger.Info(ctx, "added wireguard peer", slog.F("peer", peer.NodePublicKey.ShortString()), slog.Error(err))
}
listenClose()
}
}()
a.network = wg
return nil
}

View File

@ -14,17 +14,15 @@ import (
"cloud.google.com/go/compute/metadata"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"gopkg.in/natefinch/lumberjack.v2"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/agent"
"github.com/coder/coder/agent/reaper"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/codersdk"
"github.com/coder/retry"
"gopkg.in/natefinch/lumberjack.v2"
)
func workspaceAgent() *cobra.Command {
@ -33,6 +31,7 @@ func workspaceAgent() *cobra.Command {
pprofEnabled bool
pprofAddress string
noReap bool
wireguard bool
)
cmd := &cobra.Command{
Use: "agent",
@ -178,6 +177,9 @@ func workspaceAgent() *cobra.Command {
// shells so "gitssh" works!
"CODER_AGENT_TOKEN": client.SessionToken,
},
EnableWireguard: wireguard,
UploadWireguardKeys: client.UploadWorkspaceAgentKeys,
ListenWireguardPeers: client.WireguardPeerListener,
})
<-cmd.Context().Done()
return closer.Close()
@ -188,5 +190,6 @@ func workspaceAgent() *cobra.Command {
cliflag.BoolVarP(cmd.Flags(), &pprofEnabled, "pprof-enable", "", "CODER_AGENT_PPROF_ENABLE", false, "Enable serving pprof metrics on the address defined by --pprof-address.")
cliflag.BoolVarP(cmd.Flags(), &noReap, "no-reap", "", "", false, "Do not start a process reaper.")
cliflag.StringVarP(cmd.Flags(), &pprofAddress, "pprof-address", "", "CODER_AGENT_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.")
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_AGENT_WIREGUARD", true, "Whether to start the Wireguard interface.")
return cmd
}

View File

@ -46,7 +46,7 @@ func TestWorkspaceAgent(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String(), "--wireguard=false")
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
errC := make(chan error)
@ -101,7 +101,7 @@ func TestWorkspaceAgent(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String(), "--wireguard=false")
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
errC := make(chan error)
@ -156,7 +156,7 @@ func TestWorkspaceAgent(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String(), "--wireguard=false")
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
errC := make(chan error)

View File

@ -73,22 +73,23 @@ func Root() *cobra.Command {
list(),
login(),
logout(),
parameters(),
portForward(),
publickey(),
resetPassword(),
schedules(),
server(),
show(),
ssh(),
start(),
state(),
stop(),
ssh(),
templates(),
update(),
users(),
portForward(),
workspaceAgent(),
versionCmd(),
parameters(),
wireguardPortForward(),
workspaceAgent(),
)
cmd.SetUsageTemplate(usageTemplate())

258
cli/wireguardtunnel.go Normal file
View File

@ -0,0 +1,258 @@
package cli
import (
"context"
"fmt"
"net"
"os"
"os/signal"
"strconv"
"sync"
"syscall"
"github.com/google/uuid"
"github.com/pion/udp"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"inet.af/netaddr"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
coderagent "github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer/peerwg"
)
func wireguardPortForward() *cobra.Command {
var (
tcpForwards []string // <port>:<port>
udpForwards []string // <port>:<port>
// TODO: unix support
// unixForwards []string // <path>:<path> OR <port>:<path>
)
cmd := &cobra.Command{
Use: "wireguard-port-forward <workspace>",
Aliases: []string{"wireguard-tunnel"},
Args: cobra.ExactArgs(1),
// Hide all wireguard commands for now while we test!
Hidden: true,
Example: `
- Port forward a single TCP port from 1234 in the workspace to port 5678 on
your local machine
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --tcp 5678:1234") + `
- Port forward a single UDP port from port 9000 to port 9000 on your local
machine
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --udp 9000") + `
- Port forward multiple TCP ports and a UDP port
` + cliui.Styles.Code.Render("$ coder port-forward <workspace> --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53") + `
`,
RunE: func(cmd *cobra.Command, args []string) error {
specs, err := parsePortForwards(tcpForwards, nil, nil)
if err != nil {
return xerrors.Errorf("parse port-forward specs: %w", err)
}
if len(specs) == 0 {
err = cmd.Help()
if err != nil {
return xerrors.Errorf("generate help output: %w", err)
}
return xerrors.New("no port-forwards requested")
}
client, err := createClient(cmd)
if err != nil {
return err
}
workspace, workspaceAgent, err := getWorkspaceAndAgent(cmd, client, codersdk.Me, args[0], false)
if err != nil {
return err
}
if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart {
return xerrors.New("workspace must be in start transition to port-forward")
}
if workspace.LatestBuild.Job.CompletedAt == nil {
err = cliui.WorkspaceBuild(cmd.Context(), cmd.ErrOrStderr(), client, workspace.LatestBuild.ID, workspace.CreatedAt)
if err != nil {
return err
}
}
err = cliui.Agent(cmd.Context(), cmd.ErrOrStderr(), cliui.AgentOptions{
WorkspaceName: workspace.Name,
Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) {
return client.WorkspaceAgent(ctx, workspaceAgent.ID)
},
})
if err != nil {
return xerrors.Errorf("await agent: %w", err)
}
ipv6 := peerwg.UUIDToNetaddr(uuid.New())
wgn, err := peerwg.New(
slog.Make(sloghuman.Sink(os.Stderr)),
[]netaddr.IPPrefix{netaddr.IPPrefixFrom(ipv6, 128)},
)
if err != nil {
return xerrors.Errorf("create wireguard network: %w", err)
}
err = client.PostWireguardPeer(cmd.Context(), workspace.ID, peerwg.Handshake{
Recipient: workspaceAgent.ID,
NodePublicKey: wgn.NodePrivateKey.Public(),
DiscoPublicKey: wgn.DiscoPublicKey,
IPv6: ipv6,
})
if err != nil {
return xerrors.Errorf("post wireguard peer: %w", err)
}
err = wgn.AddPeer(peerwg.Handshake{
Recipient: workspaceAgent.ID,
DiscoPublicKey: workspaceAgent.DiscoPublicKey,
NodePublicKey: workspaceAgent.WireguardPublicKey,
IPv6: workspaceAgent.IPv6.IP(),
})
if err != nil {
return xerrors.Errorf("add workspace agent as peer: %w", err)
}
// Start all listeners.
var (
ctx, cancel = context.WithCancel(cmd.Context())
wg = new(sync.WaitGroup)
listeners = make([]net.Listener, len(specs))
closeAllListeners = func() {
for _, l := range listeners {
if l == nil {
continue
}
_ = l.Close()
}
}
)
defer cancel()
for i, spec := range specs {
l, err := listenAndPortForwardWireguard(ctx, cmd, wgn, wg, spec, workspaceAgent.IPv6.IP())
if err != nil {
closeAllListeners()
return err
}
listeners[i] = l
}
// Wait for the context to be canceled or for a signal and close
// all listeners.
var closeErr error
go func() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
select {
case <-ctx.Done():
closeErr = ctx.Err()
case <-sigs:
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections")
closeErr = xerrors.New("signal received")
}
cancel()
closeAllListeners()
}()
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!")
wg.Wait()
return closeErr
},
}
cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine")
cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols")
// cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port")
return cmd
}
func listenAndPortForwardWireguard(ctx context.Context, cmd *cobra.Command,
wgn *peerwg.Network,
wg *sync.WaitGroup,
spec portForwardSpec,
agentIP netaddr.IP,
) (net.Listener, error) {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
var (
l net.Listener
err error
)
switch spec.listenNetwork {
case "tcp":
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
case "udp":
var host, port string
host, port, err = net.SplitHostPort(spec.listenAddress)
if err != nil {
return nil, xerrors.Errorf("split %q: %w", spec.listenAddress, err)
}
var portInt int
portInt, err = strconv.Atoi(port)
if err != nil {
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, spec.listenAddress, err)
}
l, err = udp.Listen(spec.listenNetwork, &net.UDPAddr{
IP: net.ParseIP(host),
Port: portInt,
})
// case "unix":
// l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
default:
return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork)
}
if err != nil {
return nil, xerrors.Errorf("listen '%v://%v': %w", spec.listenNetwork, spec.listenAddress, err)
}
wg.Add(1)
go func(spec portForwardSpec) {
defer wg.Done()
for {
netConn, err := l.Accept()
if err != nil {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err)
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener")
return
}
go func(netConn net.Conn) {
defer netConn.Close()
ipPort := netaddr.MustParseIPPort(spec.dialAddress).WithIP(agentIP)
var remoteConn net.Conn
switch spec.dialNetwork {
case "tcp":
remoteConn, err = wgn.Netstack.DialContextTCP(ctx, ipPort)
case "udp":
remoteConn, err = wgn.Netstack.DialContextUDP(ctx, ipPort)
}
if err != nil {
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err)
return
}
defer remoteConn.Close()
coderagent.Bicopy(ctx, netConn, remoteConn)
}(netConn)
}
}(spec)
return l, nil
}

View File

@ -6,7 +6,6 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
)

View File

@ -307,6 +307,9 @@ func New(options *Options) *API {
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/wireguardlisten", api.workspaceAgentWireguardListener)
r.Post("/keys", api.postWorkspaceAgentKeys)
r.Get("/derp", api.derpMap)
})
r.Route("/{workspaceagent}", func(r chi.Router) {
r.Use(
@ -315,10 +318,12 @@ func New(options *Options) *API {
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspaceAgent)
r.Post("/peer", api.postWorkspaceAgentWireguardPeer)
r.Get("/dial", api.workspaceAgentDial)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/pty", api.workspaceAgentPTY)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/derp", api.derpMap)
})
})
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {

View File

@ -154,8 +154,12 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
"GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/derp": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/wireguardlisten": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/keys": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/turn": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/derp": {NoAuthorize: true},
// These endpoints have more assertions. This is good, add more endpoints to assert if you can!
"GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(admin.OrganizationID)},

View File

@ -25,17 +25,12 @@ import (
"testing"
"time"
"github.com/spf13/afero"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/util/ptr"
"cloud.google.com/go/compute/metadata"
"github.com/fullsailor/pkcs7"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/api/idtoken"
@ -50,7 +45,10 @@ import (
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/database/postgres"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/provisioner/echo"

View File

@ -1168,7 +1168,7 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- {
agent := q.provisionerJobAgents[i]
if agent.AuthToken.String() == authToken.String() {
if agent.AuthToken == authToken {
return agent, nil
}
}
@ -1182,7 +1182,7 @@ func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (da
// The schema sorts this by created at, so we iterate the array backwards.
for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- {
agent := q.provisionerJobAgents[i]
if agent.ID.String() == id.String() {
if agent.ID == id {
return agent, nil
}
}
@ -1210,7 +1210,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourc
workspaceAgents := make([]database.WorkspaceAgent, 0)
for _, agent := range q.provisionerJobAgents {
for _, resourceID := range resourceIDs {
if agent.ResourceID.String() != resourceID.String() {
if agent.ResourceID != resourceID {
continue
}
workspaceAgents = append(workspaceAgents, agent)
@ -1269,7 +1269,7 @@ func (q *fakeQuerier) GetProvisionerJobByID(_ context.Context, id uuid.UUID) (da
defer q.mutex.RUnlock()
for _, provisionerJob := range q.provisionerJobs {
if provisionerJob.ID.String() != id.String() {
if provisionerJob.ID != id {
continue
}
return provisionerJob, nil
@ -1604,23 +1604,26 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser
q.mutex.Lock()
defer q.mutex.Unlock()
//nolint:gosimple
agent := database.WorkspaceAgent{
ID: arg.ID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
ResourceID: arg.ResourceID,
AuthToken: arg.AuthToken,
AuthInstanceID: arg.AuthInstanceID,
EnvironmentVariables: arg.EnvironmentVariables,
Name: arg.Name,
Architecture: arg.Architecture,
OperatingSystem: arg.OperatingSystem,
Directory: arg.Directory,
StartupScript: arg.StartupScript,
InstanceMetadata: arg.InstanceMetadata,
ResourceMetadata: arg.ResourceMetadata,
ID: arg.ID,
CreatedAt: arg.CreatedAt,
UpdatedAt: arg.UpdatedAt,
ResourceID: arg.ResourceID,
AuthToken: arg.AuthToken,
AuthInstanceID: arg.AuthInstanceID,
EnvironmentVariables: arg.EnvironmentVariables,
Name: arg.Name,
Architecture: arg.Architecture,
OperatingSystem: arg.OperatingSystem,
Directory: arg.Directory,
StartupScript: arg.StartupScript,
InstanceMetadata: arg.InstanceMetadata,
ResourceMetadata: arg.ResourceMetadata,
WireguardNodeIPv6: arg.WireguardNodeIPv6,
WireguardNodePublicKey: arg.WireguardNodePublicKey,
WireguardDiscoPublicKey: arg.WireguardDiscoPublicKey,
}
q.provisionerJobAgents = append(q.provisionerJobAgents, agent)
return agent, nil
}
@ -1874,7 +1877,7 @@ func (q *fakeQuerier) UpdateTemplateVersionDescriptionByJobID(_ context.Context,
continue
}
templateVersion.Readme = arg.Readme
templateVersion.UpdatedAt = time.Now()
templateVersion.UpdatedAt = database.Now()
q.templateVersions[index] = templateVersion
return nil
}
@ -1914,6 +1917,24 @@ func (q *fakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg
return sql.ErrNoRows
}
func (q *fakeQuerier) UpdateWorkspaceAgentKeysByID(_ context.Context, arg database.UpdateWorkspaceAgentKeysByIDParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()
for index, agent := range q.provisionerJobAgents {
if agent.ID != arg.ID {
continue
}
agent.WireguardNodePublicKey = arg.WireguardNodePublicKey
agent.WireguardDiscoPublicKey = arg.WireguardDiscoPublicKey
agent.UpdatedAt = database.Now()
q.provisionerJobAgents[index] = agent
return nil
}
return sql.ErrNoRows
}
func (q *fakeQuerier) UpdateProvisionerJobByID(_ context.Context, arg database.UpdateProvisionerJobByIDParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()

View File

@ -292,7 +292,10 @@ CREATE TABLE workspace_agents (
startup_script character varying(65534),
instance_metadata jsonb,
resource_metadata jsonb,
directory character varying(4096) DEFAULT ''::character varying NOT NULL
directory character varying(4096) DEFAULT ''::character varying NOT NULL,
wireguard_node_ipv6 inet DEFAULT '::'::inet NOT NULL,
wireguard_node_public_key character varying(128) DEFAULT 'mkey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL,
wireguard_disco_public_key character varying(128) DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000'::character varying NOT NULL
);
CREATE TABLE workspace_apps (

View File

@ -0,0 +1,4 @@
ALTER TABLE workspace_agents
DROP COLUMN wireguard_node_ipv6,
DROP COLUMN wireguard_node_public_key,
DROP COLUMN wireguard_disco_public_key;

View File

@ -0,0 +1,4 @@
ALTER TABLE workspace_agents
ADD COLUMN wireguard_node_ipv6 inet NOT NULL DEFAULT '::/128',
ADD COLUMN wireguard_node_public_key varchar(128) NOT NULL DEFAULT 'mkey:0000000000000000000000000000000000000000000000000000000000000000',
ADD COLUMN wireguard_disco_public_key varchar(128) NOT NULL DEFAULT 'discokey:0000000000000000000000000000000000000000000000000000000000000000';

View File

@ -503,23 +503,26 @@ type Workspace struct {
}
type WorkspaceAgent struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
Directory string `db:"directory" json:"directory"`
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
FirstConnectedAt sql.NullTime `db:"first_connected_at" json:"first_connected_at"`
LastConnectedAt sql.NullTime `db:"last_connected_at" json:"last_connected_at"`
DisconnectedAt sql.NullTime `db:"disconnected_at" json:"disconnected_at"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
Directory string `db:"directory" json:"directory"`
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
}
type WorkspaceApp struct {

View File

@ -28,7 +28,7 @@ type pgPubsub struct {
pgListener *pq.Listener
db *sql.DB
mut sync.Mutex
listeners map[string]map[string]Listener
listeners map[string]map[uuid.UUID]Listener
}
// Subscribe calls the listener when an event matching the name is received.
@ -45,20 +45,22 @@ func (p *pgPubsub) Subscribe(event string, listener Listener) (cancel func(), er
return nil, xerrors.Errorf("listen: %w", err)
}
var listeners map[string]Listener
var eventListeners map[uuid.UUID]Listener
var ok bool
if listeners, ok = p.listeners[event]; !ok {
listeners = map[string]Listener{}
p.listeners[event] = listeners
if eventListeners, ok = p.listeners[event]; !ok {
eventListeners = map[uuid.UUID]Listener{}
p.listeners[event] = eventListeners
}
var id string
var id uuid.UUID
for {
id = uuid.New().String()
if _, ok = listeners[id]; !ok {
id = uuid.New()
if _, ok = eventListeners[id]; !ok {
break
}
}
listeners[id] = listener
eventListeners[id] = listener
return func() {
p.mut.Lock()
defer p.mut.Unlock()
@ -77,7 +79,7 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
//nolint:gosec
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
if err != nil {
return xerrors.Errorf("exec: %w", err)
return xerrors.Errorf("exec pg_notify: %w", err)
}
return nil
}
@ -128,7 +130,7 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) {
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
// Creates a new listener using pq.
errCh := make(chan error)
listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) {
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(event pq.ListenerEventType, err error) {
// This callback gets events whenever the connection state changes.
// Don't send if the errChannel has already been closed.
select {
@ -150,7 +152,7 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
pgPubsub := &pgPubsub{
db: database,
pgListener: listener,
listeners: make(map[string]map[string]Listener),
listeners: make(map[string]map[uuid.UUID]Listener),
}
go pgPubsub.listen(ctx)

View File

@ -127,6 +127,7 @@ type querier interface {
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error
UpdateWorkspaceAutostart(ctx context.Context, arg UpdateWorkspaceAutostartParams) error
UpdateWorkspaceBuildByID(ctx context.Context, arg UpdateWorkspaceBuildByIDParams) error
UpdateWorkspaceDeletedByID(ctx context.Context, arg UpdateWorkspaceDeletedByIDParams) error

View File

@ -2850,7 +2850,7 @@ func (q *sqlQuerier) UpdateUserStatus(ctx context.Context, arg UpdateUserStatusP
const getWorkspaceAgentByAuthToken = `-- name: GetWorkspaceAgentByAuthToken :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
FROM
workspace_agents
WHERE
@ -2880,13 +2880,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
)
return i, err
}
const getWorkspaceAgentByID = `-- name: GetWorkspaceAgentByID :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
FROM
workspace_agents
WHERE
@ -2914,13 +2917,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (W
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
)
return i, err
}
const getWorkspaceAgentByInstanceID = `-- name: GetWorkspaceAgentByInstanceID :one
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
FROM
workspace_agents
WHERE
@ -2950,13 +2956,16 @@ func (q *sqlQuerier) GetWorkspaceAgentByInstanceID(ctx context.Context, authInst
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
)
return i, err
}
const getWorkspaceAgentsByResourceIDs = `-- name: GetWorkspaceAgentsByResourceIDs :many
SELECT
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
FROM
workspace_agents
WHERE
@ -2990,6 +2999,9 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
); err != nil {
return nil, err
}
@ -3005,7 +3017,7 @@ func (q *sqlQuerier) GetWorkspaceAgentsByResourceIDs(ctx context.Context, ids []
}
const getWorkspaceAgentsCreatedAfter = `-- name: GetWorkspaceAgentsCreatedAfter :many
SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory FROM workspace_agents WHERE created_at > $1
SELECT id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key FROM workspace_agents WHERE created_at > $1
`
func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceAgent, error) {
@ -3035,6 +3047,9 @@ func (q *sqlQuerier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, created
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
); err != nil {
return nil, err
}
@ -3065,27 +3080,33 @@ INSERT INTO
startup_script,
directory,
instance_metadata,
resource_metadata
resource_metadata,
wireguard_node_ipv6,
wireguard_node_public_key,
wireguard_disco_public_key
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, wireguard_node_ipv6, wireguard_node_public_key, wireguard_disco_public_key
`
type InsertWorkspaceAgentParams struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
Directory string `db:"directory" json:"directory"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
AuthToken uuid.UUID `db:"auth_token" json:"auth_token"`
AuthInstanceID sql.NullString `db:"auth_instance_id" json:"auth_instance_id"`
Architecture string `db:"architecture" json:"architecture"`
EnvironmentVariables pqtype.NullRawMessage `db:"environment_variables" json:"environment_variables"`
OperatingSystem string `db:"operating_system" json:"operating_system"`
StartupScript sql.NullString `db:"startup_script" json:"startup_script"`
Directory string `db:"directory" json:"directory"`
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
}
func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) {
@ -3104,6 +3125,9 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
arg.Directory,
arg.InstanceMetadata,
arg.ResourceMetadata,
arg.WireguardNodeIPv6,
arg.WireguardNodePublicKey,
arg.WireguardDiscoPublicKey,
)
var i WorkspaceAgent
err := row.Scan(
@ -3124,6 +3148,9 @@ func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspa
&i.InstanceMetadata,
&i.ResourceMetadata,
&i.Directory,
&i.WireguardNodeIPv6,
&i.WireguardNodePublicKey,
&i.WireguardDiscoPublicKey,
)
return i, err
}
@ -3132,6 +3159,7 @@ const updateWorkspaceAgentConnectionByID = `-- name: UpdateWorkspaceAgentConnect
UPDATE
workspace_agents
SET
updated_at = now(),
first_connected_at = $2,
last_connected_at = $3,
disconnected_at = $4
@ -3156,6 +3184,28 @@ func (q *sqlQuerier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg
return err
}
const updateWorkspaceAgentKeysByID = `-- name: UpdateWorkspaceAgentKeysByID :exec
UPDATE
workspace_agents
SET
updated_at = now(),
wireguard_node_public_key = $2,
wireguard_disco_public_key = $3
WHERE
id = $1
`
type UpdateWorkspaceAgentKeysByIDParams struct {
ID uuid.UUID `db:"id" json:"id"`
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
}
func (q *sqlQuerier) UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error {
_, err := q.db.ExecContext(ctx, updateWorkspaceAgentKeysByID, arg.ID, arg.WireguardNodePublicKey, arg.WireguardDiscoPublicKey)
return err
}
const getWorkspaceAppByAgentIDAndName = `-- name: GetWorkspaceAppByAgentIDAndName :one
SELECT id, created_at, agent_id, name, icon, command, url, relative_path FROM workspace_apps WHERE agent_id = $1 AND name = $2
`

View File

@ -53,17 +53,31 @@ INSERT INTO
startup_script,
directory,
instance_metadata,
resource_metadata
resource_metadata,
wireguard_node_ipv6,
wireguard_node_public_key,
wireguard_disco_public_key
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING *;
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) RETURNING *;
-- name: UpdateWorkspaceAgentConnectionByID :exec
UPDATE
workspace_agents
SET
updated_at = now(),
first_connected_at = $2,
last_connected_at = $3,
disconnected_at = $4
WHERE
id = $1;
-- name: UpdateWorkspaceAgentKeysByID :exec
UPDATE
workspace_agents
SET
updated_at = now(),
wireguard_node_public_key = $2,
wireguard_disco_public_key = $3
WHERE
id = $1;

View File

@ -15,9 +15,13 @@ packages:
# to add support for transactions. This file is
# deleted after generation.
output_db_file_name: db_tmp.go
overrides:
- db_type: citext
go_type: string
- column: workspaces.wireguard_public_key
go_type: tailscale.com/types/key.MachinePublic
- column: workspaces.disco_public_key
go_type: tailscale.com/types/key.DiscoPublic
rename:
api_key: APIKey
login_type_oidc: LoginTypeOIDC
@ -30,3 +34,4 @@ rename:
gitsshkey: GitSSHKey
rbac_roles: RBACRoles
ip_address: IPAddress
wireguard_node_ipv6: WireguardNodeIPv6

View File

@ -37,21 +37,20 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
}
token, err := uuid.Parse(cookie.Value)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("Parse token %q: %s.", cookie.Value, err),
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "Agent token is invalid.",
})
return
}
agent, err := db.GetWorkspaceAgentByAuthToken(r.Context(), token)
if errors.Is(err, sql.ErrNoRows) {
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
Message: "Agent token is invalid.",
})
return
}
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error fetching workspace agent.",
Detail: err.Error(),

View File

@ -31,6 +31,7 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl
if !parsed {
return
}
agent, err := db.GetWorkspaceAgentByID(r.Context(), agentUUID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
@ -45,6 +46,7 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl
})
return
}
resource, err := db.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{

View File

@ -19,6 +19,7 @@ import (
protobuf "google.golang.org/protobuf/proto"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/types/key"
"cdr.dev/slog"
@ -27,6 +28,7 @@ import (
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/peer/peerwg"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
@ -714,17 +716,17 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
}
snapshot.WorkspaceResources = append(snapshot.WorkspaceResources, telemetry.ConvertWorkspaceResource(resource))
for _, agent := range protoResource.Agents {
for _, prAgent := range protoResource.Agents {
var instanceID sql.NullString
if agent.GetInstanceId() != "" {
if prAgent.GetInstanceId() != "" {
instanceID = sql.NullString{
String: agent.GetInstanceId(),
String: prAgent.GetInstanceId(),
Valid: true,
}
}
var env pqtype.NullRawMessage
if agent.Env != nil {
data, err := json.Marshal(agent.Env)
if prAgent.Env != nil {
data, err := json.Marshal(prAgent.Env)
if err != nil {
return xerrors.Errorf("marshal env: %w", err)
}
@ -734,36 +736,40 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
}
}
authToken := uuid.New()
if agent.GetToken() != "" {
authToken, err = uuid.Parse(agent.GetToken())
if prAgent.GetToken() != "" {
authToken, err = uuid.Parse(prAgent.GetToken())
if err != nil {
return xerrors.Errorf("invalid auth token format; must be uuid: %w", err)
}
}
agentID := uuid.New()
dbAgent, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
ID: uuid.New(),
ID: agentID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
ResourceID: resource.ID,
Name: agent.Name,
Name: prAgent.Name,
AuthToken: authToken,
AuthInstanceID: instanceID,
Architecture: agent.Architecture,
Architecture: prAgent.Architecture,
EnvironmentVariables: env,
Directory: agent.Directory,
OperatingSystem: agent.OperatingSystem,
Directory: prAgent.Directory,
OperatingSystem: prAgent.OperatingSystem,
StartupScript: sql.NullString{
String: agent.StartupScript,
Valid: agent.StartupScript != "",
String: prAgent.StartupScript,
Valid: prAgent.StartupScript != "",
},
WireguardNodeIPv6: peerwg.UUIDToInet(agentID),
WireguardNodePublicKey: key.NodePublic{}.String(),
WireguardDiscoPublicKey: key.DiscoPublic{}.String(),
})
if err != nil {
return xerrors.Errorf("insert agent: %w", err)
}
snapshot.WorkspaceAgents = append(snapshot.WorkspaceAgents, telemetry.ConvertWorkspaceAgent(dbAgent))
for _, app := range agent.Apps {
for _, app := range prAgent.Apps {
dbApp, err := db.InsertWorkspaceApp(ctx, database.InsertWorkspaceAppParams{
ID: uuid.New(),
CreatedAt: database.Now(),

View File

@ -13,8 +13,11 @@ import (
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"github.com/tabbed/pqtype"
"golang.org/x/xerrors"
"inet.af/netaddr"
"nhooyr.io/websocket"
"tailscale.com/types/key"
"cdr.dev/slog"
"github.com/coder/coder/agent"
@ -25,6 +28,7 @@ import (
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/peer/peerwg"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
@ -156,7 +160,18 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
})
return
}
ipp, ok := netaddr.FromStdIPNet(&workspaceAgent.WireguardNodeIPv6.IPNet)
if !ok {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Workspace agent has an invalid ipv6 address.",
Detail: workspaceAgent.WireguardNodeIPv6.IPNet.String(),
})
return
}
httpapi.Write(rw, http.StatusOK, agent.Metadata{
WireguardAddresses: []netaddr.IPPrefix{ipp},
OwnerEmail: owner.Email,
OwnerUsername: owner.Username,
EnvironmentVariables: apiAgent.EnvironmentVariables,
@ -452,6 +467,133 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(ptNetConn, wsNetConn)
}
func (*API) derpMap(rw http.ResponseWriter, _ *http.Request) {
httpapi.Write(rw, http.StatusOK, peerwg.DerpMap)
}
type WorkspaceKeysRequest struct {
Public key.NodePublic `json:"public"`
Disco key.DiscoPublic `json:"disco"`
}
func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
workspaceAgent = httpmw.WorkspaceAgent(r)
keys WorkspaceKeysRequest
)
if !httpapi.Read(rw, r, &keys) {
return
}
err := api.Database.UpdateWorkspaceAgentKeysByID(ctx, database.UpdateWorkspaceAgentKeysByIDParams{
ID: workspaceAgent.ID,
WireguardNodePublicKey: keys.Public.String(),
WireguardDiscoPublicKey: keys.Disco.String(),
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error setting agent keys.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) postWorkspaceAgentWireguardPeer(rw http.ResponseWriter, r *http.Request) {
var (
req peerwg.Handshake
workspaceAgent = httpmw.WorkspaceAgentParam(r)
workspace = httpmw.WorkspaceParam(r)
)
if !api.Authorize(r, rbac.ActionUpdate, workspace) {
httpapi.ResourceNotFound(rw)
return
}
if !httpapi.Read(rw, r, &req) {
return
}
if req.Recipient != workspaceAgent.ID {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "Invalid recipient.",
})
return
}
raw, err := req.MarshalText()
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error marshaling wireguard peer message.",
Detail: err.Error(),
})
return
}
err = api.Pubsub.Publish("wireguard_peers", raw)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: "Internal error publishing wireguard peer message.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) workspaceAgentWireguardListener(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
defer conn.Close(websocket.StatusNormalClosure, "")
agentIDBytes, _ := workspaceAgent.ID.MarshalText()
subCancel, err := api.Pubsub.Subscribe("wireguard_peers", func(ctx context.Context, message []byte) {
// Since we subscribe to all peer broadcasts, we do a light check to
// make sure we're the intended recipient without fully decoding the
// message.
hint, err := peerwg.HandshakeRecipientHint(agentIDBytes, message)
if err != nil {
api.Logger.Error(ctx, "invalid wireguard peer message", slog.Error(err))
return
}
// We aren't the intended recipient.
if !hint {
return
}
_ = conn.Write(ctx, websocket.MessageBinary, message)
})
if err != nil {
api.Logger.Error(ctx, "pubsub listen", slog.Error(err))
return
}
defer subCancel()
// Wait for the connection to close or the client to send a message.
//nolint:dogsled
_, _, _ = conn.Reader(ctx)
}
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
// r.Context() for cancellation if it's use is safe or r.Hijack() has
// not been performed.
@ -533,6 +675,19 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
return apps
}
func inetToNetaddr(inet pqtype.Inet) netaddr.IPPrefix {
if !inet.Valid {
return netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 128)
}
ipp, ok := netaddr.FromStdIPNet(&inet.IPNet)
if !ok {
return netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 128)
}
return ipp
}
func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentUpdateFrequency time.Duration) (codersdk.WorkspaceAgent, error) {
var envs map[string]string
if dbAgent.EnvironmentVariables.Valid {
@ -541,6 +696,7 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err)
}
}
workspaceAgent := codersdk.WorkspaceAgent{
ID: dbAgent.ID,
CreatedAt: dbAgent.CreatedAt,
@ -554,7 +710,18 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
EnvironmentVariables: envs,
Directory: dbAgent.Directory,
Apps: apps,
IPv6: inetToNetaddr(dbAgent.WireguardNodeIPv6),
}
err := workspaceAgent.WireguardPublicKey.UnmarshalText([]byte(dbAgent.WireguardNodePublicKey))
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal wireguard node public key %q: %w", dbAgent.WireguardNodePublicKey, err)
}
err = workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.WireguardDiscoPublicKey))
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal disco public key %q: %w", dbAgent.WireguardDiscoPublicKey, err)
}
if dbAgent.FirstConnectedAt.Valid {
workspaceAgent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
}

View File

@ -23,6 +23,7 @@ import (
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/peer"
"github.com/coder/coder/peer/peerwg"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
@ -252,6 +253,97 @@ func (c *Client) ListenWorkspaceAgent(ctx context.Context, logger slog.Logger) (
return agentMetadata, listener, json.NewDecoder(res.Body).Decode(&agentMetadata)
}
// PostWireguardPeer announces your public keys and IPv6 address to the
// specified recipient.
func (c *Client) PostWireguardPeer(ctx context.Context, workspaceID uuid.UUID, peerMsg peerwg.Handshake) error {
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/workspaceagents/%s/peer?workspace=%s",
peerMsg.Recipient,
workspaceID.String(),
), peerMsg)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return readBodyAsError(res)
}
_, _ = io.Copy(io.Discard, res.Body)
return nil
}
// WireguardPeerListener listens for wireguard peer messages. Peer messages are
// sent when a new client wants to connect. Once receiving a peer message, the
// peer should be added to the NetworkMap of the wireguard interface.
func (c *Client) WireguardPeerListener(ctx context.Context, logger slog.Logger) (<-chan peerwg.Handshake, func(), error) {
serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/wireguardlisten")
if err != nil {
return nil, nil, xerrors.Errorf("parse url: %w", err)
}
jar, err := cookiejar.New(nil)
if err != nil {
return nil, nil, xerrors.Errorf("create cookie jar: %w", err)
}
jar.SetCookies(serverURL, []*http.Cookie{{
Name: httpmw.SessionTokenKey,
Value: c.SessionToken,
}})
httpClient := &http.Client{
Jar: jar,
}
conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
if res == nil {
return nil, nil, xerrors.Errorf("websocket dial: %w", err)
}
return nil, nil, readBodyAsError(res)
}
ch := make(chan peerwg.Handshake, 1)
go func() {
defer conn.Close(websocket.StatusGoingAway, "")
defer close(ch)
for {
_, message, err := conn.Read(ctx)
if err != nil {
break
}
var msg peerwg.Handshake
err = msg.UnmarshalText(message)
if err != nil {
logger.Error(ctx, "unmarshal wireguard peer message", slog.Error(err))
continue
}
ch <- msg
}
}()
return ch, func() { _ = conn.Close(websocket.StatusGoingAway, "") }, nil
}
// UploadWorkspaceAgentKeys uploads the public keys of the workspace agent that
// were generated on startup. These keys are used by clients to communicate with
// the workspace agent over the wireguard interface.
func (c *Client) UploadWorkspaceAgentKeys(ctx context.Context, keys agent.WireguardPublicKeys) error {
res, err := c.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/keys", keys)
if err != nil {
return xerrors.Errorf("do request: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return readBodyAsError(res)
}
return nil
}
// DialWorkspaceAgent creates a connection to the specified resource.
func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *peer.ConnOptions) (*agent.Conn, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/dial", agentID.String()))

View File

@ -8,6 +8,8 @@ import (
"time"
"github.com/google/uuid"
"inet.af/netaddr"
"tailscale.com/types/key"
)
type WorkspaceAgentStatus string
@ -45,6 +47,9 @@ type WorkspaceAgent struct {
StartupScript string `json:"startup_script,omitempty"`
Directory string `json:"directory,omitempty"`
Apps []WorkspaceApp `json:"apps"`
WireguardPublicKey key.NodePublic `json:"wireguard_public_key"`
DiscoPublicKey key.DiscoPublic `json:"disco_public_key"`
IPv6 netaddr.IPPrefix `json:"ipv6"`
}
type WorkspaceAgentResourceMetadata struct {

56
go.mod
View File

@ -39,6 +39,8 @@ replace github.com/golang/glog => github.com/coder/glog v1.0.1-0.20220322161911-
// https://github.com/coder/kcp-go/commit/83c0904cec69dcf21ec10c54ea666bda18ada831
replace github.com/fatedier/kcp-go => github.com/coder/kcp-go v2.0.4-0.20220409183554-83c0904cec69+incompatible
replace golang.zx2c4.com/wireguard/tun/netstack => github.com/coder/wireguard-go/tun/netstack v0.0.0-20220614153727-d82b4ba8619f
require (
cdr.dev/slog v1.4.2-0.20220525200111-18dce5c2cd5f
cloud.google.com/go/compute v1.6.1
@ -120,25 +122,18 @@ require (
golang.org/x/text v0.3.7
golang.org/x/tools v0.1.11
golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df
golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d
golang.zx2c4.com/wireguard/tun/netstack v0.0.0-20220407013110-ef5c587f782d
golang.zx2c4.com/wireguard v0.0.0-20220601130007-6a08d81f6bc4
golang.zx2c4.com/wireguard/tun/netstack v0.0.0-00010101000000-000000000000
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20220504211119-3d4a969bb56b
google.golang.org/api v0.82.0
google.golang.org/protobuf v1.28.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1
inet.af/netaddr v0.0.0-20211027220019-c74959edd3b6
k8s.io/utils v0.0.0-20220210201930-3a6ce19ff2f9
nhooyr.io/websocket v1.8.7
storj.io/drpc v0.0.30
)
require (
github.com/agnivade/levenshtein v1.0.1 // indirect
github.com/elastic/go-windows v1.0.0 // indirect
github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901 // indirect
github.com/vektah/gqlparser/v2 v2.4.4 // indirect
github.com/yuin/goldmark v1.4.12 // indirect
howett.net/plist v0.0.0-20181124034731-591f970eefbb // indirect
tailscale.com v1.26.0
)
require (
@ -147,7 +142,10 @@ require (
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
github.com/OneOfOne/xxhash v1.2.8 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/agnivade/levenshtein v1.0.1 // indirect
github.com/akutz/memconn v0.1.0 // indirect
github.com/alecthomas/chroma v0.10.0 // indirect
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
@ -160,17 +158,20 @@ require (
github.com/clbanning/mxj/v2 v2.5.6 // indirect
github.com/containerd/console v1.0.3 // indirect
github.com/containerd/continuity v0.3.0 // indirect
github.com/coreos/go-iptables v0.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.4.0 // indirect
github.com/docker/cli v20.10.14+incompatible // indirect
github.com/docker/docker v20.10.13+incompatible // indirect
github.com/docker/cli v20.10.16+incompatible // indirect
github.com/docker/docker v20.10.16+incompatible // indirect
github.com/docker/go-connections v0.4.0 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/elastic/go-windows v1.0.0 // indirect
github.com/ghodss/yaml v1.0.0 // indirect
github.com/gin-gonic/gin v1.7.0 // indirect
github.com/go-chi/chi v1.5.4
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-toast/toast v0.0.0-20190211030409-01e6764cf0a4 // indirect
@ -191,16 +192,26 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/insomniacslk/dhcp v0.0.0-20211209223715-7d93572ebe8e // indirect
github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901 // indirect
github.com/josharian/native v1.0.0 // indirect
github.com/jsimonetti/rtnetlink v1.1.2-0.20220408201609-d380b505068b // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/klauspost/compress v1.15.6
github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-runewidth v0.0.13 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect
github.com/mdlayher/genetlink v1.2.0 // indirect
github.com/mdlayher/netlink v1.6.0 // indirect
github.com/mdlayher/sdnotify v1.0.0 // indirect
github.com/mdlayher/socket v0.2.3 // indirect
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect
github.com/miekg/dns v1.1.45 // indirect
github.com/mitchellh/go-ps v1.0.0 // indirect
github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6 // indirect
github.com/muesli/ansi v0.0.0-20211031195517-c9f0611b6c70 // indirect
github.com/muesli/reflow v0.3.0 // indirect
@ -209,7 +220,7 @@ require (
github.com/niklasfasching/go-org v1.6.5 // indirect
github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect
github.com/opencontainers/image-spec v1.0.3-0.20220114050600-8b9d41f48198 // indirect
github.com/opencontainers/runc v1.1.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.2 // indirect
github.com/pion/dtls/v2 v2.1.5 // indirect
@ -235,12 +246,22 @@ require (
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/tadvi/systray v0.0.0-20190226123456-11a2b8fa57af // indirect
github.com/tailscale/certstore v0.1.1-0.20220316223106-78d6e1c49d8d // indirect
github.com/tailscale/golang-x-crypto v0.0.0-20220428210705-0b941c09a5e1 // indirect
github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 // indirect
github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 // indirect
github.com/tcnksm/go-httpstat v0.2.0 // indirect
github.com/tdewolff/parse/v2 v2.6.0 // indirect
github.com/u-root/uio v0.0.0-20210528151154-e40b768296a7 // indirect
github.com/vektah/gqlparser/v2 v2.4.4 // indirect
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect
github.com/yashtewari/glob-intersection v0.1.0 // indirect
github.com/yuin/goldmark v1.4.12 // indirect
github.com/zclconf/go-cty v1.10.0 // indirect
github.com/zeebo/errs v1.2.2 // indirect
go.opencensus.io v0.23.0 // indirect
@ -251,11 +272,16 @@ require (
go.opentelemetry.io/otel/sdk v1.7.0
go.opentelemetry.io/otel/trace v1.7.0
go.opentelemetry.io/proto/otlp v0.16.0 // indirect
go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect
go4.org/mem v0.0.0-20210711025021-927187094b94 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 // indirect
golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect
golang.zx2c4.com/wireguard/windows v0.4.10 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20220527130721-00d5c0f3be58 // indirect
google.golang.org/grpc v1.47.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6 // indirect
gvisor.dev/gvisor v0.0.0-20220407223209-21871174d445 // indirect
howett.net/plist v1.0.0 // indirect
)

507
go.sum

File diff suppressed because it is too large Load Diff

67
peer/peerwg/derp.go Normal file
View File

@ -0,0 +1,67 @@
package peerwg
import (
"net"
"tailscale.com/tailcfg"
"tailscale.com/wgengine/magicsock"
)
// This is currently set to use Tailscale's DERP server in DFW while we build in
// our own support for DERP servers.
var DerpMap = &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
9: {
RegionID: 9,
RegionCode: "dfw",
RegionName: "Dallas",
Avoid: false,
Nodes: []*tailcfg.DERPNode{
{
Name: "9a",
RegionID: 9,
HostName: "derp9.tailscale.com",
CertName: "",
IPv4: "207.148.3.137",
IPv6: "2001:19f0:6401:1d9c:5400:2ff:feef:bb82",
STUNPort: 0,
STUNOnly: false,
DERPPort: 0,
InsecureForTests: false,
STUNTestIP: "",
},
{
Name: "9c",
RegionID: 9,
HostName: "derp9c.tailscale.com",
CertName: "",
IPv4: "155.138.243.219",
IPv6: "2001:19f0:6401:fe7:5400:3ff:fe8d:6d9c",
STUNPort: 0,
STUNOnly: false,
DERPPort: 0,
InsecureForTests: false,
STUNTestIP: "",
},
{
Name: "9b",
RegionID: 9,
HostName: "derp9b.tailscale.com",
CertName: "",
IPv4: "144.202.67.195",
IPv6: "2001:19f0:6401:eb5:5400:3ff:fe8d:6d9b",
STUNPort: 0,
STUNOnly: false,
DERPPort: 0,
InsecureForTests: false,
STUNTestIP: "",
},
},
},
},
OmitDefaultRegions: true,
}
// DefaultDerpHome is the ipv4 representation of a DERP server. The port is the
// DERP id. We only support using DERP 9 for now.
var DefaultDerpHome = net.JoinHostPort(magicsock.DerpMagicIP, "9")

94
peer/peerwg/handshake.go Normal file
View File

@ -0,0 +1,94 @@
package peerwg
import (
"bytes"
"strconv"
"github.com/google/uuid"
"golang.org/x/xerrors"
"inet.af/netaddr"
"tailscale.com/types/key"
)
const handshakeSeparator byte = '|'
// Handshake is a message received from a wireguard peer, indicating
// it would like to connect.
type Handshake struct {
// Recipient is the uuid of the agent that the message was intended for.
Recipient uuid.UUID `json:"recipient"`
// DiscoPublicKey is the disco public key of the peer.
DiscoPublicKey key.DiscoPublic `json:"disco"`
// NodePublicKey is the public key of the peer.
NodePublicKey key.NodePublic `json:"public"`
// IPv6 is the IPv6 address of the peer.
IPv6 netaddr.IP `json:"ipv6"`
}
// HandshakeRecipientHint parses the first part of a serialized
// Handshake to quickly determine if the message is meant for the
// provided recipient.
func HandshakeRecipientHint(agentID []byte, msg []byte) (bool, error) {
idx := bytes.Index(msg, []byte{handshakeSeparator})
if idx == -1 {
return false, xerrors.Errorf("invalid peer message, no separator")
}
return bytes.Equal(agentID, msg[:idx]), nil
}
func (h *Handshake) UnmarshalText(text []byte) error {
sp := bytes.Split(text, []byte{handshakeSeparator})
if len(sp) != 4 {
return xerrors.Errorf("expected 4 parts, got %d", len(sp))
}
err := h.Recipient.UnmarshalText(sp[0])
if err != nil {
return xerrors.Errorf("parse recipient: %w", err)
}
err = h.DiscoPublicKey.UnmarshalText(sp[1])
if err != nil {
return xerrors.Errorf("parse disco: %w", err)
}
err = h.NodePublicKey.UnmarshalText(sp[2])
if err != nil {
return xerrors.Errorf("parse public: %w", err)
}
h.IPv6, err = netaddr.ParseIP(string(sp[3]))
if err != nil {
return xerrors.Errorf("parse ipv6: %w", err)
}
return nil
}
func (h Handshake) MarshalText() ([]byte, error) {
const expectedLen = 223
var buf bytes.Buffer
buf.Grow(expectedLen)
recp, _ := h.Recipient.MarshalText()
_, _ = buf.Write(recp)
_ = buf.WriteByte(handshakeSeparator)
disco, _ := h.DiscoPublicKey.MarshalText()
_, _ = buf.Write(disco)
_ = buf.WriteByte(handshakeSeparator)
pub, _ := h.NodePublicKey.MarshalText()
_, _ = buf.Write(pub)
_ = buf.WriteByte(handshakeSeparator)
ipv6 := h.IPv6.StringExpanded()
_, _ = buf.WriteString(ipv6)
// Ensure we're always allocating exactly enough.
if buf.Len() != expectedLen {
panic("buffer length mismatch: want 223, got " + strconv.Itoa(buf.Len()))
}
return buf.Bytes(), nil
}

435
peer/peerwg/wireguard.go Normal file
View File

@ -0,0 +1,435 @@
package peerwg
import (
"context"
"fmt"
"hash/fnv"
"io"
"log"
"net"
"strconv"
"sync"
"time"
"github.com/google/uuid"
"github.com/tabbed/pqtype"
"golang.org/x/xerrors"
"inet.af/netaddr"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/dns"
"tailscale.com/net/netns"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
tslogger "tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/magicsock"
"tailscale.com/wgengine/monitor"
"tailscale.com/wgengine/netstack"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/wgcfg/nmcfg"
"cdr.dev/slog"
)
var logf tslogger.Logf = log.Printf
func init() {
// Globally disable network namespacing.
// All networking happens in userspace.
netns.SetEnabled(false)
}
func UUIDToInet(uid uuid.UUID) pqtype.Inet {
uid = privateUUID(uid)
return pqtype.Inet{
Valid: true,
IPNet: net.IPNet{
IP: uid[:],
Mask: net.CIDRMask(128, 128),
},
}
}
func UUIDToNetaddr(uid uuid.UUID) netaddr.IP {
return netaddr.IPFrom16(privateUUID(uid))
}
// privateUUID sets the uid to have the tailscale private ipv6 prefix.
func privateUUID(uid uuid.UUID) uuid.UUID {
// fd7a:115c:a1e0
uid[0] = 0xfd
uid[1] = 0x7a
uid[2] = 0x11
uid[3] = 0x5c
uid[4] = 0xa1
uid[5] = 0xe0
return uid
}
type Network struct {
mu sync.Mutex
logger slog.Logger
Netstack *netstack.Impl
magicSock *magicsock.Conn
netMap *netmap.NetworkMap
router *router.Config
wgEngine wgengine.Engine
// listeners is a map of listening sockets that will be forwarded traffic
// from the wireguard interface.
listeners map[listenKey]*listener
DiscoPublicKey key.DiscoPublic
NodePrivateKey key.NodePrivate
}
// New constructs a Wireguard network that filters traffic
// to destinations matching the addresses provided.
func New(logger slog.Logger, addresses []netaddr.IPPrefix) (*Network, error) {
nodePrivateKey := key.NewNode()
nodePublicKey := nodePrivateKey.Public()
id, stableID := nodeIDs(nodePublicKey)
netMap := &netmap.NetworkMap{
NodeKey: nodePublicKey,
PrivateKey: nodePrivateKey,
Addresses: addresses,
PacketFilter: []filter.Match{{
// Allow any protocol!
IPProto: []ipproto.Proto{ipproto.TCP, ipproto.UDP, ipproto.ICMPv4, ipproto.ICMPv6, ipproto.SCTP},
// Allow traffic sourced from anywhere.
Srcs: []netaddr.IPPrefix{
netaddr.IPPrefixFrom(netaddr.IPv4(0, 0, 0, 0), 0),
netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 0),
},
// Allow traffic to route anywhere.
Dsts: []filter.NetPortRange{
{
Net: netaddr.IPPrefixFrom(netaddr.IPv4(0, 0, 0, 0), 0),
Ports: filter.PortRange{
First: 0,
Last: 65535,
},
},
{
Net: netaddr.IPPrefixFrom(netaddr.IPv6Unspecified(), 0),
Ports: filter.PortRange{
First: 0,
Last: 65535,
},
},
},
Caps: []filter.CapMatch{},
}},
}
// Identify itself as a node on the network with the addresses provided.
netMap.SelfNode = &tailcfg.Node{
ID: id,
StableID: stableID,
Key: nodePublicKey,
Addresses: netMap.Addresses,
AllowedIPs: append(netMap.Addresses, netaddr.MustParseIPPrefix("::/0")),
Endpoints: []string{},
DERP: DefaultDerpHome,
}
wgMonitor, err := monitor.New(logf)
if err != nil {
return nil, xerrors.Errorf("create link monitor: %w", err)
}
dialer := new(tsdial.Dialer)
dialer.Logf = logf
// Create a wireguard engine in userspace.
engine, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
LinkMonitor: wgMonitor,
Dialer: dialer,
})
if err != nil {
return nil, xerrors.Errorf("create wgengine: %w", err)
}
// This is taken from Tailscale:
// https://github.com/tailscale/tailscale/blob/0f05b2c13ff0c305aa7a1655fa9c17ed969d65be/tsnet/tsnet.go#L247-L255
// nolint
tunDev, magicConn, dnsManager, ok := engine.(wgengine.InternalsGetter).GetInternals()
if !ok {
return nil, xerrors.New("could not get wgengine internals")
}
// Update the keys for the magic connection!
err = magicConn.SetPrivateKey(nodePrivateKey)
if err != nil {
return nil, xerrors.Errorf("set node private key: %w", err)
}
netMap.SelfNode.DiscoKey = magicConn.DiscoPublicKey()
// Create the networking stack.
// This is called to route connections.
netStack, err := netstack.Create(logf, tunDev, engine, magicConn, dialer, dnsManager)
if err != nil {
return nil, xerrors.Errorf("create netstack: %w", err)
}
netStack.ProcessLocalIPs = true
netStack.ProcessSubnets = true
dialer.UseNetstackForIP = func(ip netaddr.IP) bool {
_, ok := engine.PeerForIP(ip)
return ok
}
dialer.NetstackDialTCP = func(ctx context.Context, dst netaddr.IPPort) (net.Conn, error) {
return netStack.DialContextTCP(ctx, dst)
}
err = netStack.Start()
if err != nil {
return nil, xerrors.Errorf("start netstack: %w", err)
}
engine = wgengine.NewWatchdog(engine)
// Update the wireguard configuration to allow traffic to flow.
cfg, err := nmcfg.WGCfg(netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, netMap.SelfNode.StableID)
if err != nil {
return nil, xerrors.Errorf("create wgcfg: %w", err)
}
rtr := &router.Config{
LocalAddrs: cfg.Addresses,
}
err = engine.Reconfig(cfg, rtr, &dns.Config{}, &tailcfg.Debug{})
if err != nil {
return nil, xerrors.Errorf("reconfig: %w", err)
}
engine.SetDERPMap(DerpMap)
engine.SetNetworkMap(copyNetMap(netMap))
ipb := netaddr.IPSetBuilder{}
for _, addr := range netMap.Addresses {
ipb.AddPrefix(addr)
}
ips, _ := ipb.IPSet()
iplb := netaddr.IPSetBuilder{}
ipl, _ := iplb.IPSet()
engine.SetFilter(filter.New(netMap.PacketFilter, ips, ipl, nil, logf))
wn := &Network{
logger: logger,
NodePrivateKey: nodePrivateKey,
DiscoPublicKey: magicConn.DiscoPublicKey(),
wgEngine: engine,
Netstack: netStack,
magicSock: magicConn,
netMap: netMap,
router: rtr,
listeners: map[listenKey]*listener{},
}
netStack.ForwardTCPIn = wn.forwardTCP
return wn, nil
}
// forwardTCP handles incoming connections from Wireguard in userspace.
func (n *Network) forwardTCP(conn net.Conn, port uint16) {
n.mu.Lock()
listener, ok := n.listeners[listenKey{"tcp", "", fmt.Sprint(port)}]
n.mu.Unlock()
if !ok {
// No in-memory listener exists, forward to host.
n.forwardTCPToLocalHandler(conn, port)
return
}
timer := time.NewTimer(time.Second)
defer timer.Stop()
select {
case listener.conn <- conn:
case <-timer.C:
_ = conn.Close()
}
}
// forwardTCPToLocalHandler forwards the provided net.Conn to the
// matching port bound to localhost.
func (n *Network) forwardTCPToLocalHandler(c net.Conn, port uint16) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
defer c.Close()
dialAddrStr := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(port)))
var stdDialer net.Dialer
server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr)
if err != nil {
n.logger.Debug(ctx, "dial local port", slog.F("port", port), slog.Error(err))
return
}
defer server.Close()
connClosed := make(chan error, 2)
go func() {
_, err := io.Copy(server, c)
connClosed <- err
}()
go func() {
_, err := io.Copy(c, server)
connClosed <- err
}()
err = <-connClosed
if err != nil {
n.logger.Debug(ctx, "proxy connection closed with error", slog.Error(err))
}
n.logger.Debug(ctx, "forwarded connection closed", slog.F("local_addr", dialAddrStr))
}
// AddPeer allows connections from another Wireguard instance with the
// handshake credentials.
func (n *Network) AddPeer(handshake Handshake) error {
n.mu.Lock()
defer n.mu.Unlock()
// If the peer already exists in the network map, do nothing.
for _, p := range n.netMap.Peers {
if p.Key == handshake.NodePublicKey {
n.logger.Debug(context.Background(), "peer already in netmap", slog.F("peer", handshake.NodePublicKey.ShortString()))
return nil
}
}
// The Tailscale engine owns this slice, so we need to copy to make
// modifications.
peers := append(([]*tailcfg.Node)(nil), n.netMap.Peers...)
id, stableID := nodeIDs(handshake.NodePublicKey)
peers = append(peers, &tailcfg.Node{
ID: id,
StableID: stableID,
Name: handshake.NodePublicKey.String() + ".com",
Key: handshake.NodePublicKey,
DiscoKey: handshake.DiscoPublicKey,
Addresses: []netaddr.IPPrefix{netaddr.IPPrefixFrom(handshake.IPv6, 128)},
AllowedIPs: []netaddr.IPPrefix{netaddr.IPPrefixFrom(handshake.IPv6, 128)},
DERP: DefaultDerpHome,
Endpoints: []string{DefaultDerpHome},
})
n.netMap.Peers = peers
cfg, err := nmcfg.WGCfg(n.netMap, logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, tailcfg.StableNodeID("nBBoJZ5CNTRL"))
if err != nil {
return xerrors.Errorf("create wgcfg: %w", err)
}
err = n.wgEngine.Reconfig(cfg, n.router, &dns.Config{}, &tailcfg.Debug{})
if err != nil {
return xerrors.Errorf("reconfig: %w", err)
}
// Always give the Tailscale engine a copy of our network map.
n.wgEngine.SetNetworkMap(copyNetMap(n.netMap))
return nil
}
// Ping sends a discovery ping to the provided peer.
// The peer address must be connected before a successful ping will work.
func (n *Network) Ping(ip netaddr.IP) *ipnstate.PingResult {
ch := make(chan *ipnstate.PingResult)
n.wgEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
ch <- pr
})
return <-ch
}
// Listener returns a net.Listener in userspace that can be used to accept
// connections from the Wireguard network to the specified address. If a
// listener exists for a given address, all connections will be forwarded to the
// listener instead of being routed to the host.
func (n *Network) Listen(network, addr string) (net.Listener, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, xerrors.Errorf("split addr host port: %w", err)
}
lkey := listenKey{network, host, port}
ln := &listener{
wn: n,
key: lkey,
addr: addr,
conn: make(chan net.Conn, 1),
}
n.mu.Lock()
defer n.mu.Unlock()
if _, ok := n.listeners[lkey]; ok {
return nil, xerrors.Errorf("listener already open for %s, %s", network, addr)
}
n.listeners[lkey] = ln
return ln, nil
}
func (n *Network) Close() error {
_ = n.Netstack.Close()
n.wgEngine.Close()
return nil
}
type listenKey struct {
network string
host string
port string
}
type listener struct {
wn *Network
key listenKey
addr string
conn chan net.Conn
}
func (ln *listener) Accept() (net.Conn, error) {
c, ok := <-ln.conn
if !ok {
return nil, xerrors.Errorf("tsnet: %w", net.ErrClosed)
}
return c, nil
}
func (ln *listener) Addr() net.Addr { return addr{ln} }
func (ln *listener) Close() error {
ln.wn.mu.Lock()
defer ln.wn.mu.Unlock()
if v, ok := ln.wn.listeners[ln.key]; ok && v == ln {
delete(ln.wn.listeners, ln.key)
close(ln.conn)
}
return nil
}
type addr struct{ ln *listener }
func (a addr) Network() string { return a.ln.key.network }
func (a addr) String() string { return a.ln.addr }
// nodeIDs generates Tailscale node IDs for the provided public key.
func nodeIDs(public key.NodePublic) (tailcfg.NodeID, tailcfg.StableNodeID) {
idhash := fnv.New64()
pub, _ := public.MarshalText()
_, _ = idhash.Write(pub)
return tailcfg.NodeID(idhash.Sum64()), tailcfg.StableNodeID(pub)
}
func copyNetMap(nm *netmap.NetworkMap) *netmap.NetworkMap {
nmCopy := *nm
return &nmCopy
}

View File

@ -1,6 +1,6 @@
// Code generated by 'make coder/scripts/apitypings/main.go'. DO NOT EDIT.
// From codersdk/workspaceagents.go:35:6
// From codersdk/workspaceagents.go:36:6
export interface AWSInstanceIdentityToken {
readonly signature: string
readonly document: string
@ -18,7 +18,7 @@ export interface AuthMethods {
readonly github: boolean
}
// From codersdk/workspaceagents.go:40:6
// From codersdk/workspaceagents.go:41:6
export interface AzureInstanceIdentityToken {
readonly signature: string
readonly encoding: string
@ -128,7 +128,7 @@ export interface GitSSHKey {
readonly public_key: string
}
// From codersdk/workspaceagents.go:31:6
// From codersdk/workspaceagents.go:32:6
export interface GoogleInstanceIdentityToken {
readonly json_web_token: string
}
@ -380,7 +380,7 @@ export interface Workspace {
readonly ttl_ms?: number
}
// From codersdk/workspaceresources.go:31:6
// From codersdk/workspaceresources.go:33:6
export interface WorkspaceAgent {
readonly id: string
readonly created_at: string
@ -398,14 +398,23 @@ export interface WorkspaceAgent {
readonly startup_script?: string
readonly directory?: string
readonly apps: WorkspaceApp[]
// Named type "tailscale.com/types/key.NodePublic" unknown, using "any"
// eslint-disable-next-line @typescript-eslint/no-explicit-any
readonly wireguard_public_key: any
// Named type "tailscale.com/types/key.DiscoPublic" unknown, using "any"
// eslint-disable-next-line @typescript-eslint/no-explicit-any
readonly disco_public_key: any
// Named type "inet.af/netaddr.IPPrefix" unknown, using "any"
// eslint-disable-next-line @typescript-eslint/no-explicit-any
readonly ipv6: any
}
// From codersdk/workspaceagents.go:47:6
// From codersdk/workspaceagents.go:48:6
export interface WorkspaceAgentAuthenticateResponse {
readonly session_token: string
}
// From codersdk/workspaceresources.go:58:6
// From codersdk/workspaceresources.go:63:6
export interface WorkspaceAgentInstanceMetadata {
readonly jail_orchestrator: string
readonly operating_system: string
@ -418,7 +427,7 @@ export interface WorkspaceAgentInstanceMetadata {
readonly vnc: boolean
}
// From codersdk/workspaceresources.go:50:6
// From codersdk/workspaceresources.go:55:6
export interface WorkspaceAgentResourceMetadata {
readonly memory_total: number
readonly disk_total: number
@ -470,7 +479,7 @@ export interface WorkspaceOptions {
readonly include_deleted?: boolean
}
// From codersdk/workspaceresources.go:21:6
// From codersdk/workspaceresources.go:23:6
export interface WorkspaceResource {
readonly id: string
readonly created_at: string
@ -514,7 +523,7 @@ export type ProvisionerType = "echo" | "terraform"
// From codersdk/users.go:18:6
export type UserStatus = "active" | "suspended"
// From codersdk/workspaceresources.go:13:6
// From codersdk/workspaceresources.go:15:6
export type WorkspaceAgentStatus = "connected" | "connecting" | "disconnected"
// From codersdk/workspacebuilds.go:14:6

View File

@ -247,6 +247,9 @@ export const MockWorkspaceAgent: TypesGen.WorkspaceAgent = {
resource_id: "",
status: "connected",
updated_at: "",
wireguard_public_key: "",
disco_public_key: "",
ipv6: "",
}
export const MockWorkspaceAgentDisconnected: TypesGen.WorkspaceAgent = {