mirror of https://gitlab.com/gitlab-org/cli.git
feat: handle response from VertexAI instead of OpenAI
This commit is contained in:
parent
35385c5c0d
commit
c8057a29f0
|
@ -1,7 +1,6 @@
|
|||
package git
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
|
@ -21,14 +20,15 @@ import (
|
|||
|
||||
type request struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type response struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Predictions []struct {
|
||||
Candidates []struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
} `json:"candidates"`
|
||||
} `json:"predictions"`
|
||||
}
|
||||
|
||||
type result struct {
|
||||
|
@ -42,7 +42,11 @@ type opts struct {
|
|||
HttpClient func() (*gitlab.Client, error)
|
||||
}
|
||||
|
||||
var cmdRegexp = regexp.MustCompile("`([^`]*)`")
|
||||
var (
|
||||
cmdHighlightRegexp = regexp.MustCompile("`+\n?([^`]*)\n?`+\n?")
|
||||
cmdExecRegexp = regexp.MustCompile("```([^`]*)```")
|
||||
vertexAI = "vertexai"
|
||||
)
|
||||
|
||||
const (
|
||||
runCmdsQuestion = "Would you like to run these Git commands"
|
||||
|
@ -110,7 +114,7 @@ func (opts *opts) Result() (*result, error) {
|
|||
return nil, cmdutils.WrapError(err, "failed to get http client")
|
||||
}
|
||||
|
||||
body := request{Prompt: opts.Prompt}
|
||||
body := request{Prompt: opts.Prompt, Model: vertexAI}
|
||||
request, err := client.NewRequest(http.MethodPost, gitCmdAPIPath, body, nil)
|
||||
if err != nil {
|
||||
return nil, cmdutils.WrapError(err, "failed to create a request")
|
||||
|
@ -122,16 +126,21 @@ func (opts *opts) Result() (*result, error) {
|
|||
return nil, cmdutils.WrapError(err, apiUnreachableErr)
|
||||
}
|
||||
|
||||
if len(r.Choices) == 0 {
|
||||
if len(r.Predictions) == 0 || len(r.Predictions[0].Candidates) == 0 {
|
||||
return nil, fmt.Errorf(aiResponseErr)
|
||||
}
|
||||
|
||||
var result result
|
||||
if err := json.Unmarshal([]byte(r.Choices[0].Message.Content), &result); err != nil {
|
||||
return nil, fmt.Errorf(aiResponseErr)
|
||||
content := r.Predictions[0].Candidates[0].Content
|
||||
|
||||
var cmds []string
|
||||
for _, cmd := range cmdExecRegexp.FindAllString(content, -1) {
|
||||
cmds = append(cmds, strings.Trim(cmd, "\n`"))
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
return &result{
|
||||
Commands: cmds,
|
||||
Explanation: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (opts *opts) displayResult(result *result) {
|
||||
|
@ -147,7 +156,7 @@ func (opts *opts) displayResult(result *result) {
|
|||
}
|
||||
|
||||
opts.IO.LogInfo(color.Bold("\nExplanation:\n"))
|
||||
explanation := cmdRegexp.ReplaceAllString(result.Explanation, color.Green("$1"))
|
||||
explanation := cmdHighlightRegexp.ReplaceAllString(result.Explanation, color.Green("$1"))
|
||||
opts.IO.LogInfo(explanation + "\n")
|
||||
}
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ func runCommand(rt http.RoundTripper, isTTY bool, args string) (*test.CmdOut, er
|
|||
}
|
||||
|
||||
func TestGitCmd(t *testing.T) {
|
||||
initialAiResponse := "The appropriate ```git log --pretty=format:'%h'``` Git command ```non-git cmd``` for listing ```git show``` commit SHAs."
|
||||
outputWithoutExecution := "Experiment:\n" + experimentMsg + `
|
||||
Commands:
|
||||
|
||||
|
@ -35,9 +36,10 @@ git show
|
|||
|
||||
Explanation:
|
||||
|
||||
The appropriate Git command for listing commit SHAs.
|
||||
The appropriate git log --pretty=format:'%h' Git command non-git cmd for listing git show commit SHAs.
|
||||
|
||||
`
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
content string
|
||||
|
@ -47,21 +49,21 @@ The appropriate Git command for listing commit SHAs.
|
|||
}{
|
||||
{
|
||||
desc: "agree to run commands",
|
||||
content: `{\"commands\": [\"git log --pretty=format:'%h'\", \"non-git cmd\", \"git show\"], \"explanation\":\"The appropriate Git command for listing commit SHAs.\"}`,
|
||||
content: initialAiResponse,
|
||||
withPrompt: true,
|
||||
withExecution: true,
|
||||
expectedResult: outputWithoutExecution + "git log executed\ngit show executed\n",
|
||||
},
|
||||
{
|
||||
desc: "disagree to run commands",
|
||||
content: `{\"commands\": [\"git log --pretty=format:'%h'\", \"non-git cmd\", \"git show\"], \"explanation\":\"The appropriate Git command for listing commit SHAs.\"}`,
|
||||
content: initialAiResponse,
|
||||
withPrompt: true,
|
||||
withExecution: false,
|
||||
expectedResult: outputWithoutExecution,
|
||||
},
|
||||
{
|
||||
desc: "no commands",
|
||||
content: `{\"commands\": [], \"explanation\":\"There are no Git commands related to the text.\"}`,
|
||||
content: "There are no Git commands related to the text.",
|
||||
withPrompt: false,
|
||||
expectedResult: "Experiment:\n" + experimentMsg + "\nCommands:\n\n\nExplanation:\n\nThere are no Git commands related to the text.\n\n",
|
||||
},
|
||||
|
@ -76,7 +78,7 @@ The appropriate Git command for listing commit SHAs.
|
|||
}
|
||||
defer fakeHTTP.Verify(t)
|
||||
|
||||
body := `{"choices": [{"message": {"content": "` + tc.content + `"}}]}`
|
||||
body := `{"predictions": [{ "candidates": [ {"content": "` + tc.content + `"} ]}]}`
|
||||
|
||||
response := httpmock.NewStringResponse(http.StatusOK, body)
|
||||
fakeHTTP.RegisterResponder(http.MethodPost, "/api/v4/ai/llm/git_command", response)
|
||||
|
@ -94,7 +96,7 @@ The appropriate Git command for listing commit SHAs.
|
|||
output, err := runCommand(fakeHTTP, false, "git list 10 commits")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, tc.expectedResult, output.String())
|
||||
require.Equal(t, output.String(), tc.expectedResult)
|
||||
require.Empty(t, output.Stderr())
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue