feat: handle response from VertexAI instead of OpenAI

This commit is contained in:
Igor Drozdov 2023-08-18 10:59:13 +00:00 committed by Shekhar Patnaik
parent 35385c5c0d
commit c8057a29f0
2 changed files with 30 additions and 19 deletions

View File

@ -1,7 +1,6 @@
package git package git
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"os/exec" "os/exec"
@ -21,14 +20,15 @@ import (
type request struct { type request struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Model string `json:"model"`
} }
type response struct { type response struct {
Choices []struct { Predictions []struct {
Message struct { Candidates []struct {
Content string `json:"content"` Content string `json:"content"`
} `json:"message"` } `json:"candidates"`
} `json:"choices"` } `json:"predictions"`
} }
type result struct { type result struct {
@ -42,7 +42,11 @@ type opts struct {
HttpClient func() (*gitlab.Client, error) HttpClient func() (*gitlab.Client, error)
} }
var cmdRegexp = regexp.MustCompile("`([^`]*)`") var (
cmdHighlightRegexp = regexp.MustCompile("`+\n?([^`]*)\n?`+\n?")
cmdExecRegexp = regexp.MustCompile("```([^`]*)```")
vertexAI = "vertexai"
)
const ( const (
runCmdsQuestion = "Would you like to run these Git commands" runCmdsQuestion = "Would you like to run these Git commands"
@ -110,7 +114,7 @@ func (opts *opts) Result() (*result, error) {
return nil, cmdutils.WrapError(err, "failed to get http client") return nil, cmdutils.WrapError(err, "failed to get http client")
} }
body := request{Prompt: opts.Prompt} body := request{Prompt: opts.Prompt, Model: vertexAI}
request, err := client.NewRequest(http.MethodPost, gitCmdAPIPath, body, nil) request, err := client.NewRequest(http.MethodPost, gitCmdAPIPath, body, nil)
if err != nil { if err != nil {
return nil, cmdutils.WrapError(err, "failed to create a request") return nil, cmdutils.WrapError(err, "failed to create a request")
@ -122,16 +126,21 @@ func (opts *opts) Result() (*result, error) {
return nil, cmdutils.WrapError(err, apiUnreachableErr) return nil, cmdutils.WrapError(err, apiUnreachableErr)
} }
if len(r.Choices) == 0 { if len(r.Predictions) == 0 || len(r.Predictions[0].Candidates) == 0 {
return nil, fmt.Errorf(aiResponseErr) return nil, fmt.Errorf(aiResponseErr)
} }
var result result content := r.Predictions[0].Candidates[0].Content
if err := json.Unmarshal([]byte(r.Choices[0].Message.Content), &result); err != nil {
return nil, fmt.Errorf(aiResponseErr) var cmds []string
for _, cmd := range cmdExecRegexp.FindAllString(content, -1) {
cmds = append(cmds, strings.Trim(cmd, "\n`"))
} }
return &result, nil return &result{
Commands: cmds,
Explanation: content,
}, nil
} }
func (opts *opts) displayResult(result *result) { func (opts *opts) displayResult(result *result) {
@ -147,7 +156,7 @@ func (opts *opts) displayResult(result *result) {
} }
opts.IO.LogInfo(color.Bold("\nExplanation:\n")) opts.IO.LogInfo(color.Bold("\nExplanation:\n"))
explanation := cmdRegexp.ReplaceAllString(result.Explanation, color.Green("$1")) explanation := cmdHighlightRegexp.ReplaceAllString(result.Explanation, color.Green("$1"))
opts.IO.LogInfo(explanation + "\n") opts.IO.LogInfo(explanation + "\n")
} }

View File

@ -26,6 +26,7 @@ func runCommand(rt http.RoundTripper, isTTY bool, args string) (*test.CmdOut, er
} }
func TestGitCmd(t *testing.T) { func TestGitCmd(t *testing.T) {
initialAiResponse := "The appropriate ```git log --pretty=format:'%h'``` Git command ```non-git cmd``` for listing ```git show``` commit SHAs."
outputWithoutExecution := "Experiment:\n" + experimentMsg + ` outputWithoutExecution := "Experiment:\n" + experimentMsg + `
Commands: Commands:
@ -35,9 +36,10 @@ git show
Explanation: Explanation:
The appropriate Git command for listing commit SHAs. The appropriate git log --pretty=format:'%h' Git command non-git cmd for listing git show commit SHAs.
` `
tests := []struct { tests := []struct {
desc string desc string
content string content string
@ -47,21 +49,21 @@ The appropriate Git command for listing commit SHAs.
}{ }{
{ {
desc: "agree to run commands", desc: "agree to run commands",
content: `{\"commands\": [\"git log --pretty=format:'%h'\", \"non-git cmd\", \"git show\"], \"explanation\":\"The appropriate Git command for listing commit SHAs.\"}`, content: initialAiResponse,
withPrompt: true, withPrompt: true,
withExecution: true, withExecution: true,
expectedResult: outputWithoutExecution + "git log executed\ngit show executed\n", expectedResult: outputWithoutExecution + "git log executed\ngit show executed\n",
}, },
{ {
desc: "disagree to run commands", desc: "disagree to run commands",
content: `{\"commands\": [\"git log --pretty=format:'%h'\", \"non-git cmd\", \"git show\"], \"explanation\":\"The appropriate Git command for listing commit SHAs.\"}`, content: initialAiResponse,
withPrompt: true, withPrompt: true,
withExecution: false, withExecution: false,
expectedResult: outputWithoutExecution, expectedResult: outputWithoutExecution,
}, },
{ {
desc: "no commands", desc: "no commands",
content: `{\"commands\": [], \"explanation\":\"There are no Git commands related to the text.\"}`, content: "There are no Git commands related to the text.",
withPrompt: false, withPrompt: false,
expectedResult: "Experiment:\n" + experimentMsg + "\nCommands:\n\n\nExplanation:\n\nThere are no Git commands related to the text.\n\n", expectedResult: "Experiment:\n" + experimentMsg + "\nCommands:\n\n\nExplanation:\n\nThere are no Git commands related to the text.\n\n",
}, },
@ -76,7 +78,7 @@ The appropriate Git command for listing commit SHAs.
} }
defer fakeHTTP.Verify(t) defer fakeHTTP.Verify(t)
body := `{"choices": [{"message": {"content": "` + tc.content + `"}}]}` body := `{"predictions": [{ "candidates": [ {"content": "` + tc.content + `"} ]}]}`
response := httpmock.NewStringResponse(http.StatusOK, body) response := httpmock.NewStringResponse(http.StatusOK, body)
fakeHTTP.RegisterResponder(http.MethodPost, "/api/v4/ai/llm/git_command", response) fakeHTTP.RegisterResponder(http.MethodPost, "/api/v4/ai/llm/git_command", response)
@ -94,7 +96,7 @@ The appropriate Git command for listing commit SHAs.
output, err := runCommand(fakeHTTP, false, "git list 10 commits") output, err := runCommand(fakeHTTP, false, "git list 10 commits")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, tc.expectedResult, output.String()) require.Equal(t, output.String(), tc.expectedResult)
require.Empty(t, output.Stderr()) require.Empty(t, output.Stderr())
}) })
} }