diff --git a/api/openai.go b/api/openai.go deleted file mode 100644 index ca5c5ec5..00000000 --- a/api/openai.go +++ /dev/null @@ -1,5 +0,0 @@ -package api - -type OpenAICompletionRequest struct { - Prompt string `json:"prompt"` -} diff --git a/plugin/openai/chat_completion.go b/plugin/openai/chat_completion.go index 221aa7a8..9e72d681 100644 --- a/plugin/openai/chat_completion.go +++ b/plugin/openai/chat_completion.go @@ -24,7 +24,7 @@ type ChatCompletionResponse struct { Choices []ChatCompletionChoice `json:"choices"` } -func PostChatCompletion(prompt string, apiKey string, apiHost string) (string, error) { +func PostChatCompletion(messages []ChatCompletionMessage, apiKey string, apiHost string) (string, error) { if apiHost == "" { apiHost = "https://api.openai.com" } @@ -34,8 +34,12 @@ func PostChatCompletion(prompt string, apiKey string, apiHost string) (string, e } values := map[string]interface{}{ - "model": "gpt-3.5-turbo", - "messages": []map[string]string{{"role": "user", "content": prompt}}, + "model": "gpt-3.5-turbo", + "messages": messages, + "max_tokens": 2000, + "temperature": 0, + "frequency_penalty": 0.0, + "presence_penalty": 0.0, } jsonValue, err := json.Marshal(values) if err != nil { diff --git a/server/openai.go b/server/openai.go index 17d52f61..522d05e6 100644 --- a/server/openai.go +++ b/server/openai.go @@ -31,15 +31,15 @@ func (s *Server) registerOpenAIRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set") } - completionRequest := api.OpenAICompletionRequest{} - if err := json.NewDecoder(c.Request().Body).Decode(&completionRequest); err != nil { + messages := []openai.ChatCompletionMessage{} + if err := json.NewDecoder(c.Request().Body).Decode(&messages); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err) } - if completionRequest.Prompt == "" { - return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required") + if len(messages) == 0 { + return echo.NewHTTPError(http.StatusBadRequest, "No messages provided") } - result, err := openai.PostChatCompletion(completionRequest.Prompt, openAIConfig.Key, openAIConfig.Host) + result, err := openai.PostChatCompletion(messages, openAIConfig.Key, openAIConfig.Host) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post chat completion").SetInternal(err) } @@ -47,42 +47,6 @@ func (s *Server) registerOpenAIRoutes(g *echo.Group) { return c.JSON(http.StatusOK, composeResponse(result)) }) - g.POST("/openai/text-completion", func(c echo.Context) error { - ctx := c.Request().Context() - openAIConfigSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ - Name: api.SystemSettingOpenAIConfigName, - }) - if err != nil && common.ErrorCode(err) != common.NotFound { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai key").SetInternal(err) - } - - openAIConfig := api.OpenAIConfig{} - if openAIConfigSetting != nil { - err = json.Unmarshal([]byte(openAIConfigSetting.Value), &openAIConfig) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal openai system setting value").SetInternal(err) - } - } - if openAIConfig.Key == "" { - return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set") - } - - textCompletion := api.OpenAICompletionRequest{} - if err := json.NewDecoder(c.Request().Body).Decode(&textCompletion); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post text completion request").SetInternal(err) - } - if textCompletion.Prompt == "" { - return echo.NewHTTPError(http.StatusBadRequest, "Prompt is required") - } - - result, err := openai.PostTextCompletion(textCompletion.Prompt, openAIConfig.Key, openAIConfig.Host) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post text completion").SetInternal(err) - } - - return c.JSON(http.StatusOK, composeResponse(result)) - }) - g.GET("/openai/enabled", func(c echo.Context) error { ctx := c.Request().Context() openAIConfigSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ diff --git a/web/package.json b/web/package.json index 6830f5b6..1c0029d9 100644 --- a/web/package.json +++ b/web/package.json @@ -29,7 +29,8 @@ "react-router-dom": "^6.8.2", "react-use": "^17.4.0", "semver": "^7.3.8", - "tailwindcss": "^3.2.4" + "tailwindcss": "^3.2.4", + "zustand": "^4.3.6" }, "devDependencies": { "@types/lodash-es": "^4.17.5", diff --git a/web/src/components/AskAIDialog.tsx b/web/src/components/AskAIDialog.tsx index a3adb304..4da1f07c 100644 --- a/web/src/components/AskAIDialog.tsx +++ b/web/src/components/AskAIDialog.tsx @@ -4,24 +4,21 @@ import { toast } from "react-hot-toast"; import * as api from "../helpers/api"; import useLoading from "../hooks/useLoading"; import { marked } from "../labs/marked"; +import { useMessageStore } from "../store/zustand/message"; import Icon from "./Icon"; import { generateDialog } from "./Dialog"; import showSettingDialog from "./SettingDialog"; type Props = DialogProps; -interface History { - question: string; - answer: string; -} - const AskAIDialog: React.FC = (props: Props) => { const { destroy, hide } = props; const fetchingState = useLoading(false); - const [historyList, setHistoryList] = useState([]); + const messageStore = useMessageStore(); const [isEnabled, setIsEnabled] = useState(true); const [isInIME, setIsInIME] = useState(false); const [question, setQuestion] = useState(""); + const messageList = messageStore.messageList; useEffect(() => { api.checkOpenAIEnabled().then(({ data }) => { @@ -47,10 +44,18 @@ const AskAIDialog: React.FC = (props: Props) => { }; const handleSendQuestionButtonClick = async () => { + if (!question) { + return; + } + fetchingState.setLoading(); setQuestion(""); + messageStore.addMessage({ + role: "user", + content: question, + }); try { - await askQuestion(question); + await fetchChatCompletion(); } catch (error: any) { console.error(error); toast.error(error.response.data.error); @@ -58,21 +63,15 @@ const AskAIDialog: React.FC = (props: Props) => { fetchingState.setFinish(); }; - const askQuestion = async (question: string) => { - if (question === "") { - return; - } - + const fetchChatCompletion = async () => { + const messageList = messageStore.getState().messageList; const { data: { data: answer }, - } = await api.postChatCompletion(question); - setHistoryList([ - { - question, - answer: answer.replace(/^\n\n/, ""), - }, - ...historyList, - ]); + } = await api.postChatCompletion(messageList); + messageStore.addMessage({ + role: "assistant", + content: answer.replace(/^\n\n/, ""), + }); }; return ( @@ -87,7 +86,36 @@ const AskAIDialog: React.FC = (props: Props) => {
-
+ {messageList.map((message, index) => ( +
+ {message.role === "user" ? ( +
+ + {message.content} + +
+ ) : ( +
+ +
+
{marked(message.content)}
+
+
+ )} +
+ ))} + {fetchingState.isLoading && ( +

+ +

+ )} + {!isEnabled && ( +
+

You have not set up your OpenAI API key.

+ +
+ )} +