fix: ensure listen websocket isn't opened for non-latest agents (#2002)

Exponential backoff is only enabled if the websocket fails to open. If
the websocket is opened but immediately killed, the agent will try to
immediately reconnect. This is desirable in cases where coderd is being
replaced or network conditions cause the connection to die, but not for
permanent errors.
This commit is contained in:
Colin Adler 2022-06-02 15:03:01 -05:00 committed by GitHub
parent 0e1f868f5f
commit 89dde21837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 162 additions and 70 deletions

View File

@ -143,16 +143,49 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
Message: fmt.Sprintf("get workspace resource: %s", err),
})
return
}
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("get workspace build job: %s", err),
})
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID != latestBuild.ID {
return xerrors.New("build is outdated")
}
return nil
}
err = ensureLatestBuild()
if err != nil {
api.Logger.Debug(r.Context(), "agent tried to connect from non-latest built",
slog.F("resource", resource),
slog.F("agent", workspaceAgent),
)
httpapi.Write(rw, http.StatusForbidden, httpapi.Response{
Message: fmt.Sprintf("ensure latest build: %s", err),
})
return
}
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("accept websocket: %s", err),
@ -163,6 +196,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
@ -170,6 +204,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
ChannelID: workspaceAgent.ID.String(),
Pubsub: api.Pubsub,
@ -180,6 +215,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
return
}
defer closer.Close()
firstConnectedAt := workspaceAgent.FirstConnectedAt
if !firstConnectedAt.Valid {
firstConnectedAt = sql.NullTime{
@ -204,23 +240,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
}
return nil
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID != latestBuild.ID {
return xerrors.New("build is outdated")
}
return nil
}
defer func() {
disconnectedAt = sql.NullTime{
@ -230,11 +249,6 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
_ = updateConnectionTimes()
}()
err = ensureLatestBuild()
if err != nil {
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())

View File

@ -68,52 +68,130 @@ func TestWorkspaceAgent(t *testing.T) {
func TestWorkspaceAgentListen(t *testing.T) {
t.Parallel()
client, coderAPI := coderdtest.NewWithAPI(t, nil)
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
t.Run("Connect", func(t *testing.T) {
t.Parallel()
client, coderAPI := coderdtest.NewWithAPI(t, nil)
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
t.Cleanup(func() {
_ = agentCloser.Close()
})
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
_, err = conn.Ping()
require.NoError(t, err)
})
t.Cleanup(func() {
_ = agentCloser.Close()
t.Run("FailNonLatestBuild", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client, coderAPI := coderdtest.NewWithAPI(t, nil)
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
defer daemonCloser.Close()
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
version = coderdtest.UpdateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: uuid.NewString(),
},
}},
}},
},
},
}},
}, template.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
stopBuild, err := client.CreateWorkspaceBuild(context.Background(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
TemplateVersionID: version.ID,
Transition: codersdk.WorkspaceTransitionStop,
})
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJob(t, client, stopBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
_, _, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil))
require.Error(t, err)
require.ErrorContains(t, err, "build is outdated")
})
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
conn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = conn.Close()
})
_, err = conn.Ping()
require.NoError(t, err)
}
func TestWorkspaceAgentTURN(t *testing.T) {