2023-07-19 16:11:11 +00:00
|
|
|
package tailnet
|
|
|
|
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"io"
	"net"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
	"golang.org/x/xerrors"
	"tailscale.com/tailcfg"

	"cdr.dev/slog"
	"github.com/coder/coder/v2/coderd/util/apiversion"
	"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
	agpl "github.com/coder/coder/v2/tailnet"
)
|
|
|
|
|
2024-01-18 06:10:36 +00:00
|
|
|
type ClientService struct {
|
|
|
|
*agpl.ClientService
|
|
|
|
}
|
|
|
|
|
|
|
|
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
|
|
|
|
// loaded on each processed connection.
|
|
|
|
func NewClientService(
|
|
|
|
logger slog.Logger,
|
|
|
|
coordPtr *atomic.Pointer[agpl.Coordinator],
|
|
|
|
derpMapUpdateFrequency time.Duration,
|
|
|
|
derpMapFn func() *tailcfg.DERPMap,
|
|
|
|
) (
|
|
|
|
*ClientService, error,
|
|
|
|
) {
|
|
|
|
s, err := agpl.NewClientService(logger, coordPtr, derpMapUpdateFrequency, derpMapFn)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return &ClientService{ClientService: s}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *ClientService) ServeMultiAgentClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID) error {
|
|
|
|
major, _, err := apiversion.Parse(version)
|
|
|
|
if err != nil {
|
|
|
|
s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
switch major {
|
|
|
|
case 1:
|
|
|
|
coord := *(s.CoordPtr.Load())
|
|
|
|
sub := coord.ServeMultiAgent(id)
|
|
|
|
return ServeWorkspaceProxy(ctx, conn, sub)
|
|
|
|
case 2:
|
|
|
|
auth := agpl.SingleTailnetTunnelAuth{}
|
|
|
|
streamID := agpl.StreamID{
|
|
|
|
Name: id.String(),
|
|
|
|
ID: id,
|
|
|
|
Auth: auth,
|
|
|
|
}
|
|
|
|
return s.ServeConnV2(ctx, conn, streamID)
|
|
|
|
default:
|
|
|
|
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
|
|
|
|
return xerrors.New("unsupported version")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-07-19 16:11:11 +00:00
|
|
|
func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
|
|
|
|
go func() {
|
|
|
|
err := forwardNodesToWorkspaceProxy(ctx, conn, ma)
|
|
|
|
if err != nil {
|
|
|
|
_ = conn.Close()
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
decoder := json.NewDecoder(conn)
|
|
|
|
for {
|
|
|
|
var msg wsproxysdk.CoordinateMessage
|
|
|
|
err := decoder.Decode(&msg)
|
|
|
|
if err != nil {
|
2023-08-22 02:55:39 +00:00
|
|
|
if errors.Is(err, net.ErrClosed) {
|
|
|
|
return nil
|
|
|
|
}
|
2023-07-19 16:11:11 +00:00
|
|
|
return xerrors.Errorf("read json: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
switch msg.Type {
|
|
|
|
case wsproxysdk.CoordinateMessageTypeSubscribe:
|
|
|
|
err := ma.SubscribeAgent(msg.AgentID)
|
|
|
|
if err != nil {
|
|
|
|
return xerrors.Errorf("subscribe agent: %w", err)
|
|
|
|
}
|
|
|
|
case wsproxysdk.CoordinateMessageTypeUnsubscribe:
|
|
|
|
err := ma.UnsubscribeAgent(msg.AgentID)
|
|
|
|
if err != nil {
|
|
|
|
return xerrors.Errorf("unsubscribe agent: %w", err)
|
|
|
|
}
|
|
|
|
case wsproxysdk.CoordinateMessageTypeNodeUpdate:
|
2024-01-22 07:07:50 +00:00
|
|
|
pn, err := agpl.NodeToProto(msg.Node)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
err = ma.UpdateSelf(pn)
|
2023-07-19 16:11:11 +00:00
|
|
|
if err != nil {
|
|
|
|
return xerrors.Errorf("update self: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
default:
|
|
|
|
return xerrors.Errorf("unknown message type %q", msg.Type)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error {
|
|
|
|
var lastData []byte
|
|
|
|
for {
|
2024-01-22 07:07:50 +00:00
|
|
|
resp, ok := ma.NextUpdate(ctx)
|
2023-07-19 16:11:11 +00:00
|
|
|
if !ok {
|
|
|
|
return xerrors.New("multiagent is closed")
|
|
|
|
}
|
2024-01-22 07:07:50 +00:00
|
|
|
nodes, err := agpl.OnlyNodeUpdates(resp)
|
|
|
|
if err != nil {
|
|
|
|
return xerrors.Errorf("failed to convert response: %w", err)
|
|
|
|
}
|
2023-07-19 16:11:11 +00:00
|
|
|
data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes})
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if bytes.Equal(lastData, data) {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set a deadline so that hung connections don't put back pressure on the system.
|
|
|
|
// Node updates are tiny, so even the dinkiest connection can handle them if it's not hung.
|
|
|
|
err = conn.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout))
|
|
|
|
if err != nil {
|
|
|
|
// often, this is just because the connection is closed/broken, so only log at debug.
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
_, err = conn.Write(data)
|
|
|
|
if err != nil {
|
|
|
|
// often, this is just because the connection is closed/broken, so only log at debug.
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are
|
|
|
|
// *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write()
|
|
|
|
// fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context.
|
|
|
|
// If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after
|
|
|
|
// our successful write, it is important that we reset the deadline before it fires.
|
|
|
|
err = conn.SetWriteDeadline(time.Time{})
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
lastData = data
|
|
|
|
}
|
|
|
|
}
|