feat: handle response from VertexAI instead of OpenAI

This commit is contained in:
Igor Drozdov 2023-08-18 10:59:13 +00:00 committed by Shekhar Patnaik
parent 35385c5c0d
commit c8057a29f0
2 changed files with 30 additions and 19 deletions

View File

@ -1,7 +1,6 @@
package git
import (
"encoding/json"
"fmt"
"net/http"
"os/exec"
@ -21,14 +20,15 @@ import (
type request struct {
Prompt string `json:"prompt"`
Model string `json:"model"`
}
type response struct {
Choices []struct {
Message struct {
Predictions []struct {
Candidates []struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
} `json:"candidates"`
} `json:"predictions"`
}
type result struct {
@ -42,7 +42,11 @@ type opts struct {
HttpClient func() (*gitlab.Client, error)
}
var cmdRegexp = regexp.MustCompile("`([^`]*)`")
var (
cmdHighlightRegexp = regexp.MustCompile("`+\n?([^`]*)\n?`+\n?")
cmdExecRegexp = regexp.MustCompile("```([^`]*)```")
vertexAI = "vertexai"
)
const (
runCmdsQuestion = "Would you like to run these Git commands"
@ -110,7 +114,7 @@ func (opts *opts) Result() (*result, error) {
return nil, cmdutils.WrapError(err, "failed to get http client")
}
body := request{Prompt: opts.Prompt}
body := request{Prompt: opts.Prompt, Model: vertexAI}
request, err := client.NewRequest(http.MethodPost, gitCmdAPIPath, body, nil)
if err != nil {
return nil, cmdutils.WrapError(err, "failed to create a request")
@ -122,16 +126,21 @@ func (opts *opts) Result() (*result, error) {
return nil, cmdutils.WrapError(err, apiUnreachableErr)
}
if len(r.Choices) == 0 {
if len(r.Predictions) == 0 || len(r.Predictions[0].Candidates) == 0 {
return nil, fmt.Errorf(aiResponseErr)
}
var result result
if err := json.Unmarshal([]byte(r.Choices[0].Message.Content), &result); err != nil {
return nil, fmt.Errorf(aiResponseErr)
content := r.Predictions[0].Candidates[0].Content
var cmds []string
for _, cmd := range cmdExecRegexp.FindAllString(content, -1) {
cmds = append(cmds, strings.Trim(cmd, "\n`"))
}
return &result, nil
return &result{
Commands: cmds,
Explanation: content,
}, nil
}
func (opts *opts) displayResult(result *result) {
@ -147,7 +156,7 @@ func (opts *opts) displayResult(result *result) {
}
opts.IO.LogInfo(color.Bold("\nExplanation:\n"))
explanation := cmdRegexp.ReplaceAllString(result.Explanation, color.Green("$1"))
explanation := cmdHighlightRegexp.ReplaceAllString(result.Explanation, color.Green("$1"))
opts.IO.LogInfo(explanation + "\n")
}

View File

@ -26,6 +26,7 @@ func runCommand(rt http.RoundTripper, isTTY bool, args string) (*test.CmdOut, er
}
func TestGitCmd(t *testing.T) {
initialAiResponse := "The appropriate ```git log --pretty=format:'%h'``` Git command ```non-git cmd``` for listing ```git show``` commit SHAs."
outputWithoutExecution := "Experiment:\n" + experimentMsg + `
Commands:
@ -35,9 +36,10 @@ git show
Explanation:
The appropriate Git command for listing commit SHAs.
The appropriate git log --pretty=format:'%h' Git command non-git cmd for listing git show commit SHAs.
`
tests := []struct {
desc string
content string
@ -47,21 +49,21 @@ The appropriate Git command for listing commit SHAs.
}{
{
desc: "agree to run commands",
content: `{\"commands\": [\"git log --pretty=format:'%h'\", \"non-git cmd\", \"git show\"], \"explanation\":\"The appropriate Git command for listing commit SHAs.\"}`,
content: initialAiResponse,
withPrompt: true,
withExecution: true,
expectedResult: outputWithoutExecution + "git log executed\ngit show executed\n",
},
{
desc: "disagree to run commands",
content: `{\"commands\": [\"git log --pretty=format:'%h'\", \"non-git cmd\", \"git show\"], \"explanation\":\"The appropriate Git command for listing commit SHAs.\"}`,
content: initialAiResponse,
withPrompt: true,
withExecution: false,
expectedResult: outputWithoutExecution,
},
{
desc: "no commands",
content: `{\"commands\": [], \"explanation\":\"There are no Git commands related to the text.\"}`,
content: "There are no Git commands related to the text.",
withPrompt: false,
expectedResult: "Experiment:\n" + experimentMsg + "\nCommands:\n\n\nExplanation:\n\nThere are no Git commands related to the text.\n\n",
},
@ -76,7 +78,7 @@ The appropriate Git command for listing commit SHAs.
}
defer fakeHTTP.Verify(t)
body := `{"choices": [{"message": {"content": "` + tc.content + `"}}]}`
body := `{"predictions": [{ "candidates": [ {"content": "` + tc.content + `"} ]}]}`
response := httpmock.NewStringResponse(http.StatusOK, body)
fakeHTTP.RegisterResponder(http.MethodPost, "/api/v4/ai/llm/git_command", response)
@ -94,7 +96,7 @@ The appropriate Git command for listing commit SHAs.
output, err := runCommand(fakeHTTP, false, "git list 10 commits")
require.Nil(t, err)
require.Equal(t, tc.expectedResult, output.String())
require.Equal(t, output.String(), tc.expectedResult)
require.Empty(t, output.Stderr())
})
}