feat: add tailnet v2 Service and Client (#11225)

Part of #10532

Adds a tailnet ClientService that accepts a net.Conn and serves v1 or v2 of the tailnet API.

Also adds a DRPCService that implements the DRPC interface for the v2 API.  This component is within the ClientService, but needs to be reusable and exported so that we can also embed it in the Agent API.

Finally, includes a NewDRPCClient function that takes a net.Conn and runs dRPC in yamux over it on the client side.
This commit is contained in:
Spike Curtis 2023-12-15 12:48:39 +04:00 committed by GitHub
parent 9a4e1100fa
commit a58e4febb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 411 additions and 10 deletions

22
tailnet/client.go Normal file
View File

@ -0,0 +1,22 @@
package tailnet
import (
"io"
"net"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk/drpc"
"github.com/coder/coder/v2/tailnet/proto"
)
func NewDRPCClient(conn net.Conn) (proto.DRPCClientClient, error) {
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Client(conn, config)
if err != nil {
return nil, xerrors.Errorf("multiplex client: %w", err)
}
return proto.NewDRPCClientClient(drpc.MultiplexedConn(session)), nil
}

View File

@ -1,8 +1,20 @@
package tailnet
import (
"context"
"io"
"net"
"strconv"
"strings"
"sync/atomic"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"github.com/coder/coder/v2/tailnet/proto"
"golang.org/x/xerrors"
)
@ -15,17 +27,9 @@ const (
var SupportedMajors = []int{2, 1}
func ValidateVersion(version string) error {
parts := strings.Split(version, ".")
if len(parts) != 2 {
return xerrors.Errorf("invalid version string: %s", version)
}
major, err := strconv.Atoi(parts[0])
major, minor, err := parseVersion(version)
if err != nil {
return xerrors.Errorf("invalid major version: %s", version)
}
minor, err := strconv.Atoi(parts[1])
if err != nil {
return xerrors.Errorf("invalid minor version: %s", version)
return err
}
if major > CurrentMajor {
return xerrors.Errorf("server is at version %d.%d, behind requested version %s",
@ -45,3 +49,186 @@ func ValidateVersion(version string) error {
}
return xerrors.Errorf("version %s is no longer supported", version)
}
func parseVersion(version string) (major int, minor int, err error) {
parts := strings.Split(version, ".")
if len(parts) != 2 {
return 0, 0, xerrors.Errorf("invalid version string: %s", version)
}
major, err = strconv.Atoi(parts[0])
if err != nil {
return 0, 0, xerrors.Errorf("invalid major version: %s", version)
}
minor, err = strconv.Atoi(parts[1])
if err != nil {
return 0, 0, xerrors.Errorf("invalid minor version: %s", version)
}
return major, minor, nil
}
type streamIDContextKey struct{}
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this
// on the context, since the information is extracted at the HTTP layer for
// remote clients of the API, or set outside tailnet for local clients (e.g.
// Coderd's single_tailnet)
type StreamID struct {
Name string
ID uuid.UUID
Auth TunnelAuth
}
func WithStreamID(ctx context.Context, streamID StreamID) context.Context {
return context.WithValue(ctx, streamIDContextKey{}, streamID)
}
// ClientService is a tailnet coordination service that accepts a connection and version from a
// tailnet client, and support versions 1.0 and 2.x of the Tailnet API protocol.
type ClientService struct {
logger slog.Logger
coordPtr *atomic.Pointer[Coordinator]
drpc *drpcserver.Server
}
// NewClientService returns a ClientService based on the given Coordinator pointer. The pointer is
// loaded on each processed connection.
func NewClientService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) (*ClientService, error) {
s := &ClientService{logger: logger, coordPtr: coordPtr}
mux := drpcmux.New()
drpcService := NewDRPCService(logger, coordPtr)
err := proto.DRPCRegisterClient(mux, drpcService)
if err != nil {
return nil, xerrors.Errorf("register DRPC service: %w", err)
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
},
})
s.drpc = server
return s, nil
}
func (s *ClientService) ServeClient(ctx context.Context, version string, conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
major, _, err := parseVersion(version)
if err != nil {
s.logger.Warn(ctx, "serve client called with unparsable version", slog.Error(err))
return err
}
switch major {
case 1:
coord := *(s.coordPtr.Load())
return coord.ServeClient(conn, id, agent)
case 2:
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(conn, config)
if err != nil {
return xerrors.Errorf("yamux init failed: %w", err)
}
auth := ClientTunnelAuth{AgentID: agent}
streamID := StreamID{
Name: "client",
ID: id,
Auth: auth,
}
ctx = WithStreamID(ctx, streamID)
return s.drpc.Serve(ctx, session)
default:
s.logger.Warn(ctx, "serve client called with unsupported version", slog.F("version", version))
return xerrors.New("unsupported version")
}
}
// DRPCService is the dRPC-based, version 2.x of the tailnet API and implements proto.DRPCClientServer
type DRPCService struct {
coordPtr *atomic.Pointer[Coordinator]
logger slog.Logger
}
func NewDRPCService(logger slog.Logger, coordPtr *atomic.Pointer[Coordinator]) *DRPCService {
return &DRPCService{
coordPtr: coordPtr,
logger: logger,
}
}
func (*DRPCService) StreamDERPMaps(*proto.StreamDERPMapsRequest, proto.DRPCClient_StreamDERPMapsStream) error {
// TODO integrate with Dean's PR implementation
return xerrors.New("unimplemented")
}
func (s *DRPCService) CoordinateTailnet(stream proto.DRPCClient_CoordinateTailnetStream) error {
ctx := stream.Context()
streamID, ok := ctx.Value(streamIDContextKey{}).(StreamID)
if !ok {
_ = stream.Close()
return xerrors.New("no Stream ID")
}
logger := s.logger.With(slog.F("peer_id", streamID), slog.F("name", streamID.Name))
logger.Debug(ctx, "starting tailnet Coordinate")
coord := *(s.coordPtr.Load())
reqs, resps := coord.Coordinate(ctx, streamID.ID, streamID.Name, streamID.Auth)
c := communicator{
logger: logger,
stream: stream,
reqs: reqs,
resps: resps,
}
c.communicate()
return nil
}
type communicator struct {
logger slog.Logger
stream proto.DRPCClient_CoordinateTailnetStream
reqs chan<- *proto.CoordinateRequest
resps <-chan *proto.CoordinateResponse
}
func (c communicator) communicate() {
go c.loopReq()
c.loopResp()
}
func (c communicator) loopReq() {
ctx := c.stream.Context()
defer close(c.reqs)
for {
req, err := c.stream.Recv()
if err != nil {
c.logger.Debug(ctx, "error receiving requests from DRPC stream", slog.Error(err))
return
}
err = SendCtx(ctx, c.reqs, req)
if err != nil {
c.logger.Debug(ctx, "context done while sending coordinate request", slog.Error(ctx.Err()))
return
}
}
}
func (c communicator) loopResp() {
ctx := c.stream.Context()
defer func() {
err := c.stream.Close()
if err != nil {
c.logger.Debug(ctx, "loopResp hit error closing stream", slog.Error(err))
}
}()
for {
resp, err := RecvCtx(ctx, c.resps)
if err != nil {
c.logger.Debug(ctx, "loopResp failed to get response", slog.Error(err))
return
}
err = c.stream.Send(resp)
if err != nil {
c.logger.Debug(ctx, "loopResp failed to send response to DRPC stream", slog.Error(err))
return
}
}
}

View File

@ -1,9 +1,23 @@
package tailnet_test
import (
"context"
"fmt"
"io"
"net"
"net/http"
"sync/atomic"
"testing"
"golang.org/x/xerrors"
"github.com/google/uuid"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/tailnet"
@ -72,3 +86,181 @@ func TestValidateVersion(t *testing.T) {
})
}
}
func TestClientService_ServeClient_V2(t *testing.T) {
t.Parallel()
fCoord := newFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut, err := tailnet.NewClientService(logger, &coordPtr)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitShort)
c, s := net.Pipe()
defer c.Close()
defer s.Close()
clientID := uuid.MustParse("10000001-0000-0000-0000-000000000000")
agentID := uuid.MustParse("20000001-0000-0000-0000-000000000000")
errCh := make(chan error, 1)
go func() {
err := uut.ServeClient(ctx, "2.0", s, clientID, agentID)
t.Logf("ServeClient returned; err=%v", err)
errCh <- err
}()
client, err := tailnet.NewDRPCClient(c)
require.NoError(t, err)
stream, err := client.CoordinateTailnet(ctx)
require.NoError(t, err)
defer stream.Close()
err = stream.Send(&proto.CoordinateRequest{
UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: &proto.Node{PreferredDerp: 11}},
})
require.NoError(t, err)
call := testutil.RequireRecvCtx(ctx, t, fCoord.coordinateCalls)
require.NotNil(t, call)
require.Equal(t, call.id, clientID)
require.Equal(t, call.name, "client")
require.True(t, call.auth.Authorize(agentID))
req := testutil.RequireRecvCtx(ctx, t, call.reqs)
require.Equal(t, int32(11), req.GetUpdateSelf().GetNode().GetPreferredDerp())
call.resps <- &proto.CoordinateResponse{PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
{
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: &proto.Node{PreferredDerp: 22},
Uuid: agentID[:],
},
}}
resp, err := stream.Recv()
require.NoError(t, err)
u := resp.GetPeerUpdates()
require.Len(t, u, 1)
require.Equal(t, int32(22), u[0].GetNode().GetPreferredDerp())
err = stream.Close()
require.NoError(t, err)
// stream ^^ is just one RPC; we need to close the Conn to end the session.
err = c.Close()
require.NoError(t, err)
err = testutil.RequireRecvCtx(ctx, t, errCh)
require.ErrorIs(t, err, io.EOF)
}
func TestClientService_ServeClient_V1(t *testing.T) {
t.Parallel()
fCoord := newFakeCoordinator()
var coord tailnet.Coordinator = fCoord
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
coordPtr.Store(&coord)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
uut, err := tailnet.NewClientService(logger, &coordPtr)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitShort)
c, s := net.Pipe()
defer c.Close()
defer s.Close()
clientID := uuid.MustParse("10000001-0000-0000-0000-000000000000")
agentID := uuid.MustParse("20000001-0000-0000-0000-000000000000")
errCh := make(chan error, 1)
go func() {
err := uut.ServeClient(ctx, "1.0", s, clientID, agentID)
t.Logf("ServeClient returned; err=%v", err)
errCh <- err
}()
call := testutil.RequireRecvCtx(ctx, t, fCoord.serveClientCalls)
require.NotNil(t, call)
require.Equal(t, call.id, clientID)
require.Equal(t, call.agent, agentID)
require.Equal(t, s, call.conn)
expectedError := xerrors.New("test error")
select {
case call.errCh <- expectedError:
// ok!
case <-ctx.Done():
t.Fatalf("timeout sending error")
}
err = testutil.RequireRecvCtx(ctx, t, errCh)
require.ErrorIs(t, err, expectedError)
}
type fakeCoordinator struct {
coordinateCalls chan *fakeCoordinate
serveClientCalls chan *fakeServeClient
}
func (*fakeCoordinator) ServeHTTPDebug(http.ResponseWriter, *http.Request) {
panic("unimplemented")
}
func (*fakeCoordinator) Node(uuid.UUID) *tailnet.Node {
panic("unimplemented")
}
func (f *fakeCoordinator) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error {
errCh := make(chan error)
f.serveClientCalls <- &fakeServeClient{
conn: conn,
id: id,
agent: agent,
errCh: errCh,
}
return <-errCh
}
func (*fakeCoordinator) ServeAgent(net.Conn, uuid.UUID, string) error {
panic("unimplemented")
}
func (*fakeCoordinator) Close() error {
panic("unimplemented")
}
func (*fakeCoordinator) ServeMultiAgent(uuid.UUID) tailnet.MultiAgentConn {
panic("unimplemented")
}
func (f *fakeCoordinator) Coordinate(ctx context.Context, id uuid.UUID, name string, a tailnet.TunnelAuth) (chan<- *proto.CoordinateRequest, <-chan *proto.CoordinateResponse) {
reqs := make(chan *proto.CoordinateRequest, 100)
resps := make(chan *proto.CoordinateResponse, 100)
f.coordinateCalls <- &fakeCoordinate{
ctx: ctx,
id: id,
name: name,
auth: a,
reqs: reqs,
resps: resps,
}
return reqs, resps
}
func newFakeCoordinator() *fakeCoordinator {
return &fakeCoordinator{
coordinateCalls: make(chan *fakeCoordinate, 100),
serveClientCalls: make(chan *fakeServeClient, 100),
}
}
type fakeCoordinate struct {
ctx context.Context
id uuid.UUID
name string
auth tailnet.TunnelAuth
reqs chan *proto.CoordinateRequest
resps chan *proto.CoordinateResponse
}
type fakeServeClient struct {
conn net.Conn
id uuid.UUID
agent uuid.UUID
errCh chan error
}