mirror of https://github.com/coder/coder.git
WIP: feat(agent): add support for control socket (ctrlsock)
This commit is contained in:
parent
e5ba586e30
commit
4eea3e9db4
|
@ -65,6 +65,7 @@ const EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT"
|
|||
type Options struct {
|
||||
Filesystem afero.Fs
|
||||
LogDir string
|
||||
RunDir string
|
||||
TempDir string
|
||||
ExchangeToken func(ctx context.Context) (string, error)
|
||||
Client Client
|
||||
|
@ -116,6 +117,12 @@ func New(options Options) Agent {
|
|||
}
|
||||
options.LogDir = options.TempDir
|
||||
}
|
||||
if options.RunDir == "" {
|
||||
if options.TempDir != os.TempDir() {
|
||||
options.Logger.Debug(context.Background(), "run dir not set, using temp dir", slog.F("temp_dir", options.TempDir))
|
||||
}
|
||||
options.RunDir = options.TempDir
|
||||
}
|
||||
if options.ExchangeToken == nil {
|
||||
options.ExchangeToken = func(ctx context.Context) (string, error) {
|
||||
return "", nil
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
package ctrlsock
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func NewClient(sock, authKey string) (*Client, error) {
|
||||
conn, err := net.Dial("unix", sock)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &Client{conn: conn}
|
||||
|
||||
// Send auth key.
|
||||
if err := writeString(c.conn, authKey); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Client) SetEnv(key, value string) error {
|
||||
return writeSetEnv(c.conn, key, value)
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package ctrlsock
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Command byte
|
||||
|
||||
func (c Command) String() string {
|
||||
switch c {
|
||||
case SetEnv:
|
||||
return "SetEnv"
|
||||
default:
|
||||
return fmt.Sprintf("Command(%d)", c)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
SetEnv Command = iota + 1
|
||||
)
|
||||
|
||||
func readByte(r io.Reader) (byte, error) {
|
||||
var b byte
|
||||
if err := binary.Read(r, binary.BigEndian, &b); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func readString(r io.Reader) (string, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(r, binary.BigEndian, &length); err != nil {
|
||||
return "", err
|
||||
}
|
||||
data := make([]byte, length)
|
||||
if _, err := io.ReadFull(r, data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func writeCommand(w io.Writer, c Command) error {
|
||||
return binary.Write(w, binary.BigEndian, c)
|
||||
}
|
||||
|
||||
func writeString(w io.Writer, s string) error {
|
||||
if err := binary.Write(w, binary.BigEndian, uint32(len(s))); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.Write([]byte(s))
|
||||
return err
|
||||
}
|
||||
|
||||
func readSetEnv(r io.Reader) (key string, value string, err error) {
|
||||
key, err = readString(r)
|
||||
if err != nil {
|
||||
return key, value, err
|
||||
}
|
||||
value, err = readString(r)
|
||||
if err != nil {
|
||||
return key, value, err
|
||||
}
|
||||
|
||||
return key, value, nil
|
||||
}
|
||||
|
||||
func writeSetEnv(w io.Writer, key string, value string) error {
|
||||
if err := writeCommand(w, SetEnv); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeString(w, key); err != nil {
|
||||
return err
|
||||
}
|
||||
return writeString(w, value)
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
package ctrlsock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
logger slog.Logger
|
||||
handlers Handlers
|
||||
ln net.Listener
|
||||
wg sync.WaitGroup
|
||||
done chan struct{}
|
||||
authKey string
|
||||
}
|
||||
|
||||
func (s *Server) Addr() net.Addr {
|
||||
return s.ln.Addr()
|
||||
}
|
||||
|
||||
func (s *Server) AuthKey() string {
|
||||
return s.authKey
|
||||
}
|
||||
|
||||
type Handlers struct {
|
||||
SetEnv func(key, value string)
|
||||
}
|
||||
|
||||
func New(logger slog.Logger, runDir string, handlers Handlers) (*Server, error) {
|
||||
addr := filepath.Join(runDir, "agent.sock")
|
||||
|
||||
err := os.Remove(addr)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, xerrors.Errorf("remove existing socket failed: %w", err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authKey, err := generateAuthKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
logger: logger.Named("ctrlsock"),
|
||||
handlers: handlers,
|
||||
ln: ln,
|
||||
done: make(chan struct{}),
|
||||
authKey: authKey,
|
||||
}
|
||||
go s.acceptLoop()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) acceptLoop() {
|
||||
for {
|
||||
conn, err := s.ln.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.done:
|
||||
// The listener was closed, so we're done.
|
||||
return
|
||||
default:
|
||||
s.logger.Error(context.Background(), "accept connection failed", "err", err)
|
||||
}
|
||||
} else {
|
||||
s.wg.Add(1)
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer s.wg.Done()
|
||||
defer conn.Close()
|
||||
|
||||
// Check the authentication key.
|
||||
if err := s.handleAuth(conn); err != nil {
|
||||
s.logger.Error(context.Background(), "authentication failed", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle commands.
|
||||
for {
|
||||
cmdByte, err := readByte(conn)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
s.logger.Error(context.Background(), "read command type failed", "err", err)
|
||||
return
|
||||
}
|
||||
cmdType := Command(cmdByte)
|
||||
logger := s.logger.With(slog.F("command", cmdType.String()))
|
||||
|
||||
switch cmdType {
|
||||
case SetEnv:
|
||||
key, value, err := readSetEnv(conn)
|
||||
if err != nil {
|
||||
logger.Error(context.Background(), "handle command failed", "err", err)
|
||||
}
|
||||
logger.Info(context.Background(), "calling command with input", "key", key, "value_length", len(value))
|
||||
s.handlers.SetEnv(key, value)
|
||||
default:
|
||||
s.logger.Error(context.Background(), "unknown command, closing connection")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generateAuthKey() (string, error) {
|
||||
key := make([]byte, 16)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAuth(conn net.Conn) error {
|
||||
key, err := readString(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if key != s.authKey {
|
||||
return xerrors.Errorf("invalid auth key: %s", key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
close(s.done)
|
||||
err := s.ln.Close()
|
||||
s.wg.Wait()
|
||||
return err
|
||||
}
|
|
@ -40,6 +40,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
|
|||
var (
|
||||
auth string
|
||||
logDir string
|
||||
runDir string
|
||||
pprofAddress string
|
||||
noReap bool
|
||||
sshMaxTimeout time.Duration
|
||||
|
@ -284,6 +285,7 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
|
|||
Client: client,
|
||||
Logger: logger,
|
||||
LogDir: logDir,
|
||||
RunDir: runDir,
|
||||
TailnetListenPort: uint16(tailnetListenPort),
|
||||
ExchangeToken: func(ctx context.Context) (string, error) {
|
||||
if exchangeToken == nil {
|
||||
|
@ -337,6 +339,13 @@ func (r *RootCmd) workspaceAgent() *clibase.Cmd {
|
|||
Env: "CODER_AGENT_LOG_DIR",
|
||||
Value: clibase.StringOf(&logDir),
|
||||
},
|
||||
{
|
||||
Flag: "run-dir",
|
||||
Default: os.TempDir(),
|
||||
Description: "Specify the location for the agent run files.",
|
||||
Env: "CODER_AGENT_RUN_DIR",
|
||||
Value: clibase.StringOf(&runDir),
|
||||
},
|
||||
{
|
||||
Flag: "pprof-address",
|
||||
Default: "127.0.0.1:6060",
|
||||
|
|
Loading…
Reference in New Issue