mirror of https://github.com/coder/coder.git
feat: add conversions from tailnet to proto (#10441)
Adds conversions from existing tailnet types to protobuf
This commit is contained in:
parent
f4026edd71
commit
6882e8e524
|
@ -223,12 +223,12 @@ func TestTimezoneOffsets(t *testing.T) {
|
|||
// Name: "Central",
|
||||
// Loc: must(time.LoadLocation("America/Chicago")),
|
||||
// ExpectedOffset: -5,
|
||||
//},
|
||||
// },
|
||||
//{
|
||||
// Name: "Ireland",
|
||||
// Loc: must(time.LoadLocation("Europe/Dublin")),
|
||||
// ExpectedOffset: 1,
|
||||
//},
|
||||
// },
|
||||
{
|
||||
Name: "HalfHourTz",
|
||||
// This timezone is +6:30, but the function rounds to the nearest hour.
|
||||
|
|
|
@ -0,0 +1,138 @@
|
|||
package tailnet
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
func UUIDToByteSlice(u uuid.UUID) []byte {
|
||||
b := [16]byte(u)
|
||||
o := make([]byte, 16)
|
||||
copy(o, b[:]) // copy so that we can't mutate the original
|
||||
return o
|
||||
}
|
||||
|
||||
func NodeToProto(n *Node) (*proto.Node, error) {
|
||||
k, err := n.Key.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
disco, err := n.DiscoKey.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
derpForcedWebsocket := make(map[int32]string)
|
||||
for i, s := range n.DERPForcedWebsocket {
|
||||
derpForcedWebsocket[int32(i)] = s
|
||||
}
|
||||
addresses := make([]string, len(n.Addresses))
|
||||
for i, prefix := range n.Addresses {
|
||||
s, err := prefix.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addresses[i] = string(s)
|
||||
}
|
||||
allowedIPs := make([]string, len(n.AllowedIPs))
|
||||
for i, prefix := range n.AllowedIPs {
|
||||
s, err := prefix.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowedIPs[i] = string(s)
|
||||
}
|
||||
return &proto.Node{
|
||||
Id: int64(n.ID),
|
||||
AsOf: timestamppb.New(n.AsOf),
|
||||
Key: k,
|
||||
Disco: string(disco),
|
||||
PreferredDerp: int32(n.PreferredDERP),
|
||||
DerpLatency: n.DERPLatency,
|
||||
DerpForcedWebsocket: derpForcedWebsocket,
|
||||
Addresses: addresses,
|
||||
AllowedIps: allowedIPs,
|
||||
Endpoints: n.Endpoints,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ProtoToNode(p *proto.Node) (*Node, error) {
|
||||
k := key.NodePublic{}
|
||||
err := k.UnmarshalBinary(p.GetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
disco := key.DiscoPublic{}
|
||||
err = disco.UnmarshalText([]byte(p.GetDisco()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
derpForcedWebsocket := make(map[int]string)
|
||||
for i, s := range p.GetDerpForcedWebsocket() {
|
||||
derpForcedWebsocket[int(i)] = s
|
||||
}
|
||||
addresses := make([]netip.Prefix, len(p.GetAddresses()))
|
||||
for i, prefix := range p.GetAddresses() {
|
||||
err = addresses[i].UnmarshalText([]byte(prefix))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
allowedIPs := make([]netip.Prefix, len(p.GetAllowedIps()))
|
||||
for i, prefix := range p.GetAllowedIps() {
|
||||
err = allowedIPs[i].UnmarshalText([]byte(prefix))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &Node{
|
||||
ID: tailcfg.NodeID(p.GetId()),
|
||||
AsOf: p.GetAsOf().AsTime(),
|
||||
Key: k,
|
||||
DiscoKey: disco,
|
||||
PreferredDERP: int(p.GetPreferredDerp()),
|
||||
DERPLatency: p.GetDerpLatency(),
|
||||
DERPForcedWebsocket: derpForcedWebsocket,
|
||||
Addresses: addresses,
|
||||
AllowedIPs: allowedIPs,
|
||||
Endpoints: p.Endpoints,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func OnlyNodeUpdates(resp *proto.CoordinateResponse) ([]*Node, error) {
|
||||
nodes := make([]*Node, 0, len(resp.GetPeerUpdates()))
|
||||
for _, pu := range resp.GetPeerUpdates() {
|
||||
if pu.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
|
||||
continue
|
||||
}
|
||||
n, err := ProtoToNode(pu.Node)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed conversion from protobuf: %w", err)
|
||||
}
|
||||
nodes = append(nodes, n)
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func SingleNodeUpdate(id uuid.UUID, node *Node, reason string) (*proto.CoordinateResponse, error) {
|
||||
p, err := NodeToProto(node)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("node failed conversion to protobuf: %w", err)
|
||||
}
|
||||
return &proto.CoordinateResponse{
|
||||
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
|
||||
{
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
Uuid: UUIDToByteSlice(id),
|
||||
Node: p,
|
||||
Reason: reason,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,153 @@
|
|||
package tailnet_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
func TestNode(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
name string
|
||||
node tailnet.Node
|
||||
}{
|
||||
{
|
||||
name: "Zero",
|
||||
},
|
||||
{
|
||||
name: "AllFields",
|
||||
node: tailnet.Node{
|
||||
ID: 33,
|
||||
AsOf: time.Now(),
|
||||
Key: key.NewNode().Public(),
|
||||
DiscoKey: key.NewDisco().Public(),
|
||||
PreferredDERP: 12,
|
||||
DERPLatency: map[string]float64{
|
||||
"1": 0.2,
|
||||
"12": 0.3,
|
||||
},
|
||||
DERPForcedWebsocket: map[int]string{
|
||||
1: "forced",
|
||||
},
|
||||
Addresses: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("ff80::aa:1/128"),
|
||||
},
|
||||
AllowedIPs: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("ff80::aa:1/128"),
|
||||
},
|
||||
Endpoints: []string{
|
||||
"192.168.0.1:3305",
|
||||
"[ff80::aa:1]:2049",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dbtime",
|
||||
node: tailnet.Node{
|
||||
AsOf: dbtime.Now(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p, err := tailnet.NodeToProto(&tc.node)
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, err := tailnet.ProtoToNode(p)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.node.ID, inv.ID)
|
||||
require.True(t, tc.node.AsOf.Equal(inv.AsOf))
|
||||
require.Equal(t, tc.node.Key, inv.Key)
|
||||
require.Equal(t, tc.node.DiscoKey, inv.DiscoKey)
|
||||
require.Equal(t, tc.node.PreferredDERP, inv.PreferredDERP)
|
||||
require.Equal(t, tc.node.DERPLatency, inv.DERPLatency)
|
||||
require.Equal(t, len(tc.node.DERPForcedWebsocket), len(inv.DERPForcedWebsocket))
|
||||
for k, v := range inv.DERPForcedWebsocket {
|
||||
nv, ok := tc.node.DERPForcedWebsocket[k]
|
||||
require.True(t, ok)
|
||||
require.Equal(t, nv, v)
|
||||
}
|
||||
require.ElementsMatch(t, tc.node.Addresses, inv.Addresses)
|
||||
require.ElementsMatch(t, tc.node.AllowedIPs, inv.AllowedIPs)
|
||||
require.ElementsMatch(t, tc.node.Endpoints, inv.Endpoints)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUUIDToByteSlice(t *testing.T) {
|
||||
t.Parallel()
|
||||
u := uuid.New()
|
||||
b := tailnet.UUIDToByteSlice(u)
|
||||
u2, err := uuid.FromBytes(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, u, u2)
|
||||
|
||||
b = tailnet.UUIDToByteSlice(uuid.Nil)
|
||||
u2, err = uuid.FromBytes(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uuid.Nil, u2)
|
||||
}
|
||||
|
||||
func TestOnlyNodeUpdates(t *testing.T) {
|
||||
t.Parallel()
|
||||
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
|
||||
p, err := tailnet.NodeToProto(node)
|
||||
require.NoError(t, err)
|
||||
resp := &proto.CoordinateResponse{
|
||||
PeerUpdates: []*proto.CoordinateResponse_PeerUpdate{
|
||||
{
|
||||
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_NODE,
|
||||
Node: p,
|
||||
},
|
||||
{
|
||||
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2},
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_DISCONNECTED,
|
||||
Reason: "disconnected",
|
||||
},
|
||||
{
|
||||
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3},
|
||||
Kind: proto.CoordinateResponse_PeerUpdate_LOST,
|
||||
Reason: "disconnected",
|
||||
},
|
||||
{
|
||||
Uuid: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4},
|
||||
},
|
||||
},
|
||||
}
|
||||
nodes, err := tailnet.OnlyNodeUpdates(resp)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, nodes, 1)
|
||||
require.Equal(t, tailcfg.NodeID(1), nodes[0].ID)
|
||||
}
|
||||
|
||||
func TestSingleNodeUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
node := &tailnet.Node{ID: tailcfg.NodeID(1)}
|
||||
u := uuid.New()
|
||||
resp, err := tailnet.SingleNodeUpdate(u, node, "unit test")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.PeerUpdates, 1)
|
||||
up := resp.PeerUpdates[0]
|
||||
require.Equal(t, proto.CoordinateResponse_PeerUpdate_NODE, up.Kind)
|
||||
u2, err := uuid.FromBytes(up.Uuid)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, u, u2)
|
||||
require.Equal(t, "unit test", up.Reason)
|
||||
require.EqualValues(t, 1, up.Node.Id)
|
||||
}
|
Loading…
Reference in New Issue