WIP: feat(agent): add support for control socket (ctrlsock)

This commit is contained in:
Mathias Fredriksson 2024-02-05 15:45:50 +02:00
parent e5ba586e30
commit 4eea3e9db4
5 changed files with 279 additions and 0 deletions

View File

@ -65,6 +65,7 @@ const EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT"
type Options struct {
Filesystem afero.Fs
LogDir string
RunDir string
TempDir string
ExchangeToken func(ctx context.Context) (string, error)
Client Client
@ -116,6 +117,12 @@ func New(options Options) Agent {
}
options.LogDir = options.TempDir
}
if options.RunDir == "" {
if options.TempDir != os.TempDir() {
options.Logger.Debug(context.Background(), "run dir not set, using temp dir", slog.F("temp_dir", options.TempDir))
}
options.RunDir = options.TempDir
}
if options.ExchangeToken == nil {
options.ExchangeToken = func(ctx context.Context) (string, error) {
return "", nil

34
agent/ctrlsock/client.go Normal file
View File

@ -0,0 +1,34 @@
package ctrlsock
import (
"net"
)
type Client struct {
conn net.Conn
}
func NewClient(sock, authKey string) (*Client, error) {
conn, err := net.Dial("unix", sock)
if err != nil {
return nil, err
}
c := &Client{conn: conn}
// Send auth key.
if err := writeString(c.conn, authKey); err != nil {
_ = conn.Close()
return nil, err
}
return c, nil
}
func (c *Client) Close() error {
return c.conn.Close()
}
func (c *Client) SetEnv(key, value string) error {
return writeSetEnv(c.conn, key, value)
}

77
agent/ctrlsock/command.go Normal file
View File

@ -0,0 +1,77 @@
package ctrlsock
import (
"encoding/binary"
"fmt"
"io"
)
type Command byte
func (c Command) String() string {
switch c {
case SetEnv:
return "SetEnv"
default:
return fmt.Sprintf("Command(%d)", c)
}
}
const (
SetEnv Command = iota + 1
)
func readByte(r io.Reader) (byte, error) {
var b byte
if err := binary.Read(r, binary.BigEndian, &b); err != nil {
return 0, err
}
return b, nil
}
func readString(r io.Reader) (string, error) {
var length uint32
if err := binary.Read(r, binary.BigEndian, &length); err != nil {
return "", err
}
data := make([]byte, length)
if _, err := io.ReadFull(r, data); err != nil {
return "", err
}
return string(data), nil
}
func writeCommand(w io.Writer, c Command) error {
return binary.Write(w, binary.BigEndian, c)
}
func writeString(w io.Writer, s string) error {
if err := binary.Write(w, binary.BigEndian, uint32(len(s))); err != nil {
return err
}
_, err := w.Write([]byte(s))
return err
}
func readSetEnv(r io.Reader) (key string, value string, err error) {
key, err = readString(r)
if err != nil {
return key, value, err
}
value, err = readString(r)
if err != nil {
return key, value, err
}
return key, value, nil
}
func writeSetEnv(w io.Writer, key string, value string) error {
if err := writeCommand(w, SetEnv); err != nil {
return err
}
if err := writeString(w, key); err != nil {
return err
}
return writeString(w, value)
}

152
agent/ctrlsock/server.go Normal file
View File

@ -0,0 +1,152 @@
package ctrlsock
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"io"
"io/fs"
"net"
"os"
"path/filepath"
"sync"
"golang.org/x/xerrors"
"cdr.dev/slog"
)
type Server struct {
logger slog.Logger
handlers Handlers
ln net.Listener
wg sync.WaitGroup
done chan struct{}
authKey string
}
func (s *Server) Addr() net.Addr {
return s.ln.Addr()
}
func (s *Server) AuthKey() string {
return s.authKey
}
type Handlers struct {
SetEnv func(key, value string)
}
func New(logger slog.Logger, runDir string, handlers Handlers) (*Server, error) {
addr := filepath.Join(runDir, "agent.sock")
err := os.Remove(addr)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return nil, xerrors.Errorf("remove existing socket failed: %w", err)
}
ln, err := net.Listen("unix", addr)
if err != nil {
return nil, err
}
authKey, err := generateAuthKey()
if err != nil {
return nil, err
}
s := &Server{
logger: logger.Named("ctrlsock"),
handlers: handlers,
ln: ln,
done: make(chan struct{}),
authKey: authKey,
}
go s.acceptLoop()
return s, nil
}
func (s *Server) acceptLoop() {
for {
conn, err := s.ln.Accept()
if err != nil {
select {
case <-s.done:
// The listener was closed, so we're done.
return
default:
s.logger.Error(context.Background(), "accept connection failed", "err", err)
}
} else {
s.wg.Add(1)
go s.handleConnection(conn)
}
}
}
func (s *Server) handleConnection(conn net.Conn) {
defer s.wg.Done()
defer conn.Close()
// Check the authentication key.
if err := s.handleAuth(conn); err != nil {
s.logger.Error(context.Background(), "authentication failed", "err", err)
return
}
// Handle commands.
for {
cmdByte, err := readByte(conn)
if err != nil {
if errors.Is(err, io.EOF) {
break
}
s.logger.Error(context.Background(), "read command type failed", "err", err)
return
}
cmdType := Command(cmdByte)
logger := s.logger.With(slog.F("command", cmdType.String()))
switch cmdType {
case SetEnv:
key, value, err := readSetEnv(conn)
if err != nil {
logger.Error(context.Background(), "handle command failed", "err", err)
}
logger.Info(context.Background(), "calling command with input", "key", key, "value_length", len(value))
s.handlers.SetEnv(key, value)
default:
s.logger.Error(context.Background(), "unknown command, closing connection")
return
}
}
}
func generateAuthKey() (string, error) {
key := make([]byte, 16)
_, err := rand.Read(key)
if err != nil {
return "", err
}
return hex.EncodeToString(key), nil
}
func (s *Server) handleAuth(conn net.Conn) error {
key, err := readString(conn)
if err != nil {
return err
}
if key != s.authKey {
return xerrors.Errorf("invalid auth key: %s", key)
}
return nil
}
func (s *Server) Close() error {
close(s.done)
err := s.ln.Close()
s.wg.Wait()
return err
}

View File

@ -40,6 +40,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
var (
auth string
logDir string
runDir string
pprofAddress string
noReap bool
sshMaxTimeout time.Duration
@ -284,6 +285,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
Client: client,
Logger: logger,
LogDir: logDir,
RunDir: runDir,
TailnetListenPort: uint16(tailnetListenPort),
ExchangeToken: func(ctx context.Context) (string, error) {
if exchangeToken == nil {
@ -337,6 +339,13 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
Env: "CODER_AGENT_LOG_DIR",
Value: clibase.StringOf(&logDir),
},
{
Flag: "run-dir",
Default: os.TempDir(),
Description: "Specify the location for the agent run files.",
Env: "CODER_AGENT_RUN_DIR",
Value: clibase.StringOf(&runDir),
},
{
Flag: "pprof-address",
Default: "127.0.0.1:6060",