feat: add conversions from tailnet to proto (#10441)

Adds conversions from existing tailnet types to protobuf
This commit is contained in:
Spike Curtis 2023-11-01 10:54:00 +04:00 committed by GitHub
parent f4026edd71
commit 6882e8e524
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 293 additions and 2 deletions

View File

@ -223,12 +223,12 @@ func TestTimezoneOffsets(t *testing.T) {
// Name: "Central",
// Loc: must(time.LoadLocation("America/Chicago")),
// ExpectedOffset: -5,
//},
// },
//{
// Name: "Ireland",
// Loc: must(time.LoadLocation("Europe/Dublin")),
// ExpectedOffset: 1,
//},
// },
{
Name: "HalfHourTz",
// This timezone is +6:30, but the function rounds to the nearest hour.

138
tailnet/convert.go Normal file
View File

@ -0,0 +1,138 @@
package tailnet
import (
"net/netip"
"github.com/google/uuid"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"github.com/coder/coder/v2/tailnet/proto"
)
func UUIDToByteSlice(u uuid.UUID) []byte {
b := [16]byte(u)
o := make([]byte, 16)
copy(o, b[:]) // copy so that we can't mutate the original
return o
}
func NodeToProto(n *Node) (*proto.Node, error) {
k, err := n.Key.MarshalBinary()
if err != nil {
return nil, err
}
disco, err := n.DiscoKey.MarshalText()
if err != nil {
return nil, err
}
derpForcedWebsocket := make(map[int32]string)
for i, s := range n.DERPForcedWebsocket {
derpForcedWebsocket[int32(i)] = s
}
addresses := make([]string, len(n.Addresses))
for i, prefix := range n.Addresses {
s, err := prefix.MarshalText()
if err != nil {
return nil, err
}
addresses[i] = string(s)
}
allowedIPs := make([]string, len(n.AllowedIPs))
for i, prefix := range n.AllowedIPs {
s, err := prefix.MarshalText()
if err != nil {
return nil, err
}
allowedIPs[i] = string(s)
}
return &proto.Node{
Id: int64(n.ID),
AsOf: timestamppb.New(n.AsOf),
Key: k,
Disco: string(disco),
PreferredDerp: int32(n.PreferredDERP),
DerpLatency: n.DERPLatency,
DerpForcedWebsocket: derpForcedWebsocket,
Addresses: addresses,
AllowedIps: allowedIPs,
Endpoints: n.Endpoints,
}, nil
}
func ProtoToNode(p *proto.Node) (*Node, error) {
k := key.NodePublic{}
err := k.UnmarshalBinary(p.GetKey())
if err != nil {
return nil, err
}
disco := key.DiscoPublic{}
err = disco.UnmarshalText([]byte(p.GetDisco()))
if err != nil {
return nil, err
}
derpForcedWebsocket := make(map[int]string)
for i, s := range p.GetDerpForcedWebsocket() {
derpForcedWebsocket[int(i)] = s
}
addresses := make([]netip.Prefix, len(p.GetAddresses()))
for i, prefix := range p.GetAddresses() {
err = addresses[i].UnmarshalText([]byte(prefix))
if err != nil {
return nil, err
}
}
allowedIPs := make([]netip.Prefix, len(p.GetAllowedIps()))
for i, prefix := range p.GetAllowedIps() {
err = allowedIPs[i].UnmarshalText([]byte(prefix))
if err != nil {
return nil, err
}
}
return &Node{
ID: tailcfg.NodeID(p.GetId()),
AsOf: p.GetAsOf().AsTime(),
Key: k,
DiscoKey: disco,
PreferredDERP: int(p.GetPreferredDerp()),
DERPLatency: p.GetDerpLatency(),
DERPForcedWebsocket: derpForcedWebsocket,
Addresses: addresses,
AllowedIPs: allowedIPs,
Endpoints: p.Endpoints,
}, nil
}
func OnlyNodeUpdates(resp *proto.CoordinateResponse) ([]*Node, error) {
nodes := make([]*Node, 0, len(resp.GetPeerUpdates()))
for _, pu := range resp.GetPeerUpdates() {
if pu.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
continue
}
n, err := ProtoToNode(pu.Node)
if err != nil {
return nil, xerrors.Errorf("failed conversion from protobuf: %w", err)
}
nodes = append(nodes, n)
}
return nodes, nil
}
func SingleNodeUpdate(id uuid.UUID, node *Node, reason string) (*proto.CoordinateResponse, error) {
p, err := NodeToProto(node)
if err != nil {
return nil, xerrors.Errorf("node failed conversion to protobuf: %w", err)
}
return &proto.CoordinateResponse{
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
{
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Uuid: UUIDToByteSlice(id),
Node: p,
Reason: reason,
},
},
}, nil
}

153
tailnet/convert_test.go Normal file
View File

@ -0,0 +1,153 @@
package tailnet_test
import (
"net/netip"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
)
func TestNode(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
node tailnet.Node
}{
{
name: "Zero",
},
{
name: "AllFields",
node: tailnet.Node{
ID: 33,
AsOf: time.Now(),
Key: key.NewNode().Public(),
DiscoKey: key.NewDisco().Public(),
PreferredDERP: 12,
DERPLatency: map[string]float64{
"1": 0.2,
"12": 0.3,
},
DERPForcedWebsocket: map[int]string{
1: "forced",
},
Addresses: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("ff80::aa:1/128"),
},
AllowedIPs: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("ff80::aa:1/128"),
},
Endpoints: []string{
"192.168.0.1:3305",
"[ff80::aa:1]:2049",
},
},
},
{
name: "dbtime",
node: tailnet.Node{
AsOf: dbtime.Now(),
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
p, err := tailnet.NodeToProto(&tc.node)
require.NoError(t, err)
inv, err := tailnet.ProtoToNode(p)
require.NoError(t, err)
require.Equal(t, tc.node.ID, inv.ID)
require.True(t, tc.node.AsOf.Equal(inv.AsOf))
require.Equal(t, tc.node.Key, inv.Key)
require.Equal(t, tc.node.DiscoKey, inv.DiscoKey)
require.Equal(t, tc.node.PreferredDERP, inv.PreferredDERP)
require.Equal(t, tc.node.DERPLatency, inv.DERPLatency)
require.Equal(t, len(tc.node.DERPForcedWebsocket), len(inv.DERPForcedWebsocket))
for k, v := range inv.DERPForcedWebsocket {
nv, ok := tc.node.DERPForcedWebsocket[k]
require.True(t, ok)
require.Equal(t, nv, v)
}
require.ElementsMatch(t, tc.node.Addresses, inv.Addresses)
require.ElementsMatch(t, tc.node.AllowedIPs, inv.AllowedIPs)
require.ElementsMatch(t, tc.node.Endpoints, inv.Endpoints)
})
}
}
func TestUUIDToByteSlice(t *testing.T) {
t.Parallel()
u := uuid.New()
b := tailnet.UUIDToByteSlice(u)
u2, err := uuid.FromBytes(b)
require.NoError(t, err)
require.Equal(t, u, u2)
b = tailnet.UUIDToByteSlice(uuid.Nil)
u2, err = uuid.FromBytes(b)
require.NoError(t, err)
require.Equal(t, uuid.Nil, u2)
}
func TestOnlyNodeUpdates(t *testing.T) {
t.Parallel()
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
p, err := tailnet.NodeToProto(node)
require.NoError(t, err)
resp := &proto.CoordinateResponse{
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
{
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
Node: p,
},
{
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2},
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
Reason: "disconnected",
},
{
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3},
Kind: proto.CoordinateResponse_PeerUpdate_LOST,
Reason: "disconnected",
},
{
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4},
},
},
}
nodes, err := tailnet.OnlyNodeUpdates(resp)
require.NoError(t, err)
require.Len(t, nodes, 1)
require.Equal(t, tailcfg.NodeID(1), nodes[0].ID)
}
func TestSingleNodeUpdate(t *testing.T) {
t.Parallel()
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
u := uuid.New()
resp, err := tailnet.SingleNodeUpdate(u, node, "unit test")
require.NoError(t, err)
require.Len(t, resp.PeerUpdates, 1)
up := resp.PeerUpdates[0]
require.Equal(t, proto.CoordinateResponse_PeerUpdate_NODE, up.Kind)
u2, err := uuid.FromBytes(up.Uuid)
require.NoError(t, err)
require.Equal(t, u, u2)
require.Equal(t, "unit test", up.Reason)
require.EqualValues(t, 1, up.Node.Id)
}