2023-12-15 07:49:30 +00:00
|
|
|
package tailnet
|
|
|
|
|
|
|
|
import (
|
2023-12-15 08:48:39 +00:00
|
|
|
"context"
|
|
|
|
"io"
|
|
|
|
"net"
|
|
|
|
"sync/atomic"
|
2024-01-02 04:07:57 +00:00
|
|
|
"time"
|
2023-12-15 08:48:39 +00:00
|
|
|
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/hashicorp/yamux"
|
|
|
|
"storj.io/drpc/drpcmux"
|
|
|
|
"storj.io/drpc/drpcserver"
|
2024-01-02 04:07:57 +00:00
|
|
|
"tailscale.com/tailcfg"
|
2023-12-15 08:48:39 +00:00
|
|
|
|
|
|
|
"cdr.dev/slog"
|
2024-02-16 18:43:07 +00:00
|
|
|
"github.com/coder/coder/v2/apiversion"
|
2023-12-15 08:48:39 +00:00
|
|
|
"github.com/coder/coder/v2/tailnet/proto"
|
2023-12-15 07:49:30 +00:00
|
|
|
|
|
|
|
"golang.org/x/xerrors"
|
|
|
|
)
|
|
|
|
|
2023-12-15 08:48:39 +00:00
|
|
|
// streamIDContextKey is the context key under which WithStreamID stores a
// StreamID. Being an unexported empty struct type, it cannot collide with
// keys defined by other packages.
type streamIDContextKey struct{}
|
|
|
|
|
|
|
|
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
|
|
|
|
// on the context, since the information is extracted at the HTTP layer for
|
|
|
|
// remote clients of the API, or set outside tailnet for local clients (e.g.
|
|
|
|
// Coderd's single_tailnet)
|
|
|
|
type StreamID struct {
|
|
|
|
Name string
|
|
|
|
ID uuid.UUID
|
|
|
|
Auth TunnelAuth
|
|
|
|
}
|
|
|
|
|
|
|
|
func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
|
|
|
|
return context.WithValue(ctx, streamIDContextKey{}, streamID)
|
|
|
|
}
|
|
|
|
|
|
|
|
// ClientService is a tailnet coordination service that accepts a connection and version from a
|
|
|
|
// tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol.
|
|
|
|
type ClientService struct {
|
2024-01-18 06:10:36 +00:00
|
|
|
Logger slog.Logger
|
|
|
|
CoordPtr *atomic.Pointer[Coordinator]
|
2023-12-15 08:48:39 +00:00
|
|
|
drpc *drpcserver.Server
|
|
|
|
}
|
|
|
|
|
|
|
|
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
|
|
|
|
// loaded on each processed connection.
|
2024-01-02 04:07:57 +00:00
|
|
|
func NewClientService(
|
|
|
|
logger slog.Logger,
|
|
|
|
coordPtr *atomic.Pointer[Coordinator],
|
|
|
|
derpMapUpdateFrequency time.Duration,
|
|
|
|
derpMapFn func() *tailcfg.DERPMap,
|
|
|
|
) (
|
|
|
|
*ClientService, error,
|
|
|
|
) {
|
2024-01-18 06:10:36 +00:00
|
|
|
s := &ClientService{Logger: logger, CoordPtr: coordPtr}
|
2023-12-15 08:48:39 +00:00
|
|
|
mux := drpcmux.New()
|
2024-01-02 04:07:57 +00:00
|
|
|
drpcService := &DRPCService{
|
|
|
|
CoordPtr: coordPtr,
|
|
|
|
Logger: logger,
|
|
|
|
DerpMapUpdateFrequency: derpMapUpdateFrequency,
|
|
|
|
DerpMapFn: derpMapFn,
|
|
|
|
}
|
2024-01-02 06:02:45 +00:00
|
|
|
err := proto.DRPCRegisterTailnet(mux, drpcService)
|
2023-12-15 08:48:39 +00:00
|
|
|
if err != nil {
|
|
|
|
return nil, xerrors.Errorf("register DRPC service: %w", err)
|
|
|
|
}
|
|
|
|
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
|
|
|
|
Log: func(err error) {
|
2024-01-22 07:07:50 +00:00
|
|
|
if xerrors.Is(err, io.EOF) ||
|
|
|
|
xerrors.Is(err, context.Canceled) ||
|
|
|
|
xerrors.Is(err, context.DeadlineExceeded) {
|
2023-12-15 08:48:39 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
|
|
|
|
},
|
|
|
|
})
|
|
|
|
s.drpc = server
|
|
|
|
return s, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
|
2024-01-05 10:22:07 +00:00
|
|
|
major, _, err := apiversion.Parse(version)
|
2023-12-15 08:48:39 +00:00
|
|
|
if err != nil {
|
2024-01-18 06:10:36 +00:00
|
|
|
s.Logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
|
2023-12-15 08:48:39 +00:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
switch major {
|
|
|
|
case 1:
|
2024-01-18 06:10:36 +00:00
|
|
|
coord := *(s.CoordPtr.Load())
|
2023-12-15 08:48:39 +00:00
|
|
|
return coord.ServeClient(conn, id, agent)
|
|
|
|
case 2:
|
|
|
|
auth := ClientTunnelAuth{AgentID: agent}
|
|
|
|
streamID := StreamID{
|
|
|
|
Name: "client",
|
|
|
|
ID: id,
|
|
|
|
Auth: auth,
|
|
|
|
}
|
2024-01-18 06:10:36 +00:00
|
|
|
return s.ServeConnV2(ctx, conn, streamID)
|
2023-12-15 08:48:39 +00:00
|
|
|
default:
|
2024-01-18 06:10:36 +00:00
|
|
|
s.Logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
|
2023-12-15 08:48:39 +00:00
|
|
|
return xerrors.New("unsupported version")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-01-18 06:10:36 +00:00
|
|
|
func (s ClientService) ServeConnV2(ctx context.Context, conn net.Conn, streamID StreamID) error {
|
|
|
|
config := yamux.DefaultConfig()
|
|
|
|
config.LogOutput = io.Discard
|
|
|
|
session, err := yamux.Server(conn, config)
|
|
|
|
if err != nil {
|
|
|
|
return xerrors.Errorf("yamux init failed: %w", err)
|
|
|
|
}
|
|
|
|
ctx = WithStreamID(ctx, streamID)
|
|
|
|
return s.drpc.Serve(ctx, session)
|
|
|
|
}
|
|
|
|
|
2023-12-15 08:48:39 +00:00
|
|
|
// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer
|
|
|
|
type DRPCService struct {
|
2024-01-02 04:07:57 +00:00
|
|
|
CoordPtr *atomic.Pointer[Coordinator]
|
|
|
|
Logger slog.Logger
|
|
|
|
DerpMapUpdateFrequency time.Duration
|
|
|
|
DerpMapFn func() *tailcfg.DERPMap
|
2023-12-15 08:48:39 +00:00
|
|
|
}
|
|
|
|
|
2024-01-02 06:02:45 +00:00
|
|
|
func (s *DRPCService) StreamDERPMaps(_ *proto.StreamDERPMapsRequest, stream proto.DRPCTailnet_StreamDERPMapsStream) error {
|
2024-01-02 04:07:57 +00:00
|
|
|
defer stream.Close()
|
|
|
|
|
|
|
|
ticker := time.NewTicker(s.DerpMapUpdateFrequency)
|
|
|
|
defer ticker.Stop()
|
2023-12-15 08:48:39 +00:00
|
|
|
|
2024-01-02 04:07:57 +00:00
|
|
|
var lastDERPMap *tailcfg.DERPMap
|
|
|
|
for {
|
|
|
|
derpMap := s.DerpMapFn()
|
2024-01-23 10:42:07 +00:00
|
|
|
if derpMap == nil {
|
|
|
|
// in testing, we send nil to close the stream.
|
|
|
|
return io.EOF
|
|
|
|
}
|
2024-01-02 04:07:57 +00:00
|
|
|
if lastDERPMap == nil || !CompareDERPMaps(lastDERPMap, derpMap) {
|
|
|
|
protoDERPMap := DERPMapToProto(derpMap)
|
|
|
|
err := stream.Send(protoDERPMap)
|
|
|
|
if err != nil {
|
|
|
|
return xerrors.Errorf("send derp map: %w", err)
|
|
|
|
}
|
|
|
|
lastDERPMap = derpMap
|
|
|
|
}
|
|
|
|
|
|
|
|
ticker.Reset(s.DerpMapUpdateFrequency)
|
|
|
|
select {
|
|
|
|
case <-stream.Context().Done():
|
|
|
|
return nil
|
|
|
|
case <-ticker.C:
|
|
|
|
}
|
|
|
|
}
|
2023-12-15 08:48:39 +00:00
|
|
|
}
|
|
|
|
|
2024-01-02 06:02:45 +00:00
|
|
|
func (s *DRPCService) Coordinate(stream proto.DRPCTailnet_CoordinateStream) error {
|
2023-12-15 08:48:39 +00:00
|
|
|
ctx := stream.Context()
|
|
|
|
streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
|
|
|
|
if !ok {
|
|
|
|
_ = stream.Close()
|
|
|
|
return xerrors.New("no Stream ID")
|
|
|
|
}
|
2024-01-02 04:07:57 +00:00
|
|
|
logger := s.Logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name))
|
2023-12-15 08:48:39 +00:00
|
|
|
logger.Debug(ctx, "starting tailnet Coordinate")
|
2024-01-02 04:07:57 +00:00
|
|
|
coord := *(s.CoordPtr.Load())
|
2023-12-15 08:48:39 +00:00
|
|
|
reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth)
|
|
|
|
c := communicator{
|
|
|
|
logger: logger,
|
|
|
|
stream: stream,
|
|
|
|
reqs: reqs,
|
|
|
|
resps: resps,
|
|
|
|
}
|
|
|
|
c.communicate()
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
type communicator struct {
|
|
|
|
logger slog.Logger
|
2024-01-02 06:02:45 +00:00
|
|
|
stream proto.DRPCTailnet_CoordinateStream
|
2023-12-15 08:48:39 +00:00
|
|
|
reqs chan<- *proto.CoordinateRequest
|
|
|
|
resps <-chan *proto.CoordinateResponse
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c communicator) communicate() {
|
|
|
|
go c.loopReq()
|
|
|
|
c.loopResp()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c communicator) loopReq() {
|
|
|
|
ctx := c.stream.Context()
|
|
|
|
defer close(c.reqs)
|
|
|
|
for {
|
|
|
|
req, err := c.stream.Recv()
|
|
|
|
if err != nil {
|
|
|
|
c.logger.Debug(ctx, "error receiving requests from DRPC stream", slog.Error(err))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
err = SendCtx(ctx, c.reqs, req)
|
|
|
|
if err != nil {
|
|
|
|
c.logger.Debug(ctx, "context done while sending coordinate request", slog.Error(ctx.Err()))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c communicator) loopResp() {
|
|
|
|
ctx := c.stream.Context()
|
|
|
|
defer func() {
|
|
|
|
err := c.stream.Close()
|
|
|
|
if err != nil {
|
|
|
|
c.logger.Debug(ctx, "loopResp hit error closing stream", slog.Error(err))
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
for {
|
|
|
|
resp, err := RecvCtx(ctx, c.resps)
|
|
|
|
if err != nil {
|
|
|
|
c.logger.Debug(ctx, "loopResp failed to get response", slog.Error(err))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
err = c.stream.Send(resp)
|
|
|
|
if err != nil {
|
|
|
|
c.logger.Debug(ctx, "loopResp failed to send response to DRPC stream", slog.Error(err))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|