mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
chore: initial memo service definition (#2077)
* chore: initial memo service definition * chore: update * chore: update * chore: update
This commit is contained in:
130
api/v1/openai.go
130
api/v1/openai.go
@@ -1,130 +0,0 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
echosse "github.com/CorrectRoadH/echo-sse"
|
||||
"github.com/PullRequestInc/go-gpt3"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/memos/plugin/openai"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *APIV1Service) registerOpenAIRoutes(g *echo.Group) {
|
||||
g.POST("/openai/chat-completion", func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
openAIConfigSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
|
||||
Name: SystemSettingOpenAIConfigName.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai key").SetInternal(err)
|
||||
}
|
||||
|
||||
openAIConfig := OpenAIConfig{}
|
||||
if openAIConfigSetting != nil {
|
||||
err = json.Unmarshal([]byte(openAIConfigSetting.Value), &openAIConfig)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal openai system setting value").SetInternal(err)
|
||||
}
|
||||
}
|
||||
if openAIConfig.Key == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
|
||||
}
|
||||
|
||||
messages := []openai.ChatCompletionMessage{}
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(&messages); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err)
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "No messages provided")
|
||||
}
|
||||
|
||||
result, err := openai.PostChatCompletion(messages, openAIConfig.Key, openAIConfig.Host)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post chat completion").SetInternal(err)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, result)
|
||||
})
|
||||
|
||||
g.POST("/openai/chat-streaming", func(c echo.Context) error {
|
||||
messages := []gpt3.ChatCompletionRequestMessage{}
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(&messages); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err)
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "No messages provided")
|
||||
}
|
||||
|
||||
ctx := c.Request().Context()
|
||||
openAIConfigSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
|
||||
Name: SystemSettingOpenAIConfigName.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai key").SetInternal(err)
|
||||
}
|
||||
|
||||
openAIConfig := OpenAIConfig{}
|
||||
if openAIConfigSetting != nil {
|
||||
err = json.Unmarshal([]byte(openAIConfigSetting.Value), &openAIConfig)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal openai system setting value").SetInternal(err)
|
||||
}
|
||||
}
|
||||
if openAIConfig.Key == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
|
||||
}
|
||||
|
||||
sse := echosse.NewSSEClint(c)
|
||||
|
||||
// to do these things in server may not elegant.
|
||||
// But move it to openai plugin will break the simple. Because it is a streaming. We must use a channel to do it.
|
||||
// And we can think it is a forward proxy. So it in here is not a bad idea.
|
||||
client := gpt3.NewClient(openAIConfig.Key)
|
||||
err = client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{
|
||||
Model: gpt3.GPT3Dot5Turbo,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
},
|
||||
func(resp *gpt3.ChatCompletionStreamResponse) {
|
||||
// _ is for to pass the golangci-lint check
|
||||
_ = sse.SendEvent(resp.Choices[0].Delta.Content)
|
||||
|
||||
// to delay 0.5 s
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// the delay is a very good way to make the chatbot more comfortable
|
||||
// otherwise the chatbot will reply too fast. Believe me it is not good.🤔
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to chat with OpenAI").SetInternal(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
g.GET("/openai/enabled", func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
openAIConfigSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
|
||||
Name: SystemSettingOpenAIConfigName.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai key").SetInternal(err)
|
||||
}
|
||||
|
||||
openAIConfig := OpenAIConfig{}
|
||||
if openAIConfigSetting != nil {
|
||||
err = json.Unmarshal([]byte(openAIConfigSetting.Value), &openAIConfig)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal openai system setting value").SetInternal(err)
|
||||
}
|
||||
}
|
||||
if openAIConfig.Key == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, openAIConfig.Key != "")
|
||||
})
|
||||
}
|
@@ -42,7 +42,6 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) {
|
||||
s.registerMemoOrganizerRoutes(apiV1Group)
|
||||
s.registerMemoResourceRoutes(apiV1Group)
|
||||
s.registerMemoRelationRoutes(apiV1Group)
|
||||
s.registerOpenAIRoutes(apiV1Group)
|
||||
|
||||
// Register public routes.
|
||||
publicGroup := rootGroup.Group("/o")
|
||||
|
@@ -18,8 +18,18 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// ContextKey is the key type of context value.
|
||||
type ContextKey int
|
||||
|
||||
const (
|
||||
// The key name used to store user id in the context
|
||||
// user id is extracted from the jwt token subject field.
|
||||
UserIDContextKey ContextKey = iota
|
||||
)
|
||||
|
||||
var authenticationAllowlistMethods = map[string]bool{
|
||||
"/memos.api.v2.UserService/GetUser": true,
|
||||
"/memos.api.v2.UserService/GetUser": true,
|
||||
"/memos.api.v2.MemoService/ListMemos": true,
|
||||
}
|
||||
|
||||
// IsAuthenticationAllowed returns whether the method is exempted from authentication.
|
||||
@@ -30,15 +40,6 @@ func IsAuthenticationAllowed(fullMethodName string) bool {
|
||||
return authenticationAllowlistMethods[fullMethodName]
|
||||
}
|
||||
|
||||
// ContextKey is the key type of context value.
|
||||
type ContextKey int
|
||||
|
||||
const (
|
||||
// The key name used to store user id in the context
|
||||
// user id is extracted from the jwt token subject field.
|
||||
UserIDContextKey ContextKey = iota
|
||||
)
|
||||
|
||||
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
|
||||
type GRPCAuthInterceptor struct {
|
||||
store *store.Store
|
||||
|
119
api/v2/memo_service.go
Normal file
119
api/v2/memo_service.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/pkg/errors"
|
||||
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
|
||||
"github.com/usememos/memos/store"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type MemoService struct {
|
||||
apiv2pb.UnimplementedMemoServiceServer
|
||||
|
||||
Store *store.Store
|
||||
}
|
||||
|
||||
// NewMemoService creates a new MemoService.
|
||||
func NewMemoService(store *store.Store) *MemoService {
|
||||
return &MemoService{
|
||||
Store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MemoService) ListMemos(ctx context.Context, request *apiv2pb.ListMemosRequest) (*apiv2pb.ListMemosResponse, error) {
|
||||
memoFind := &store.FindMemo{}
|
||||
if request.PageSize != 0 {
|
||||
offset := int(request.Page * request.PageSize)
|
||||
limit := int(request.PageSize)
|
||||
memoFind.Offset = &offset
|
||||
memoFind.Limit = &limit
|
||||
}
|
||||
if request.Filter != "" {
|
||||
visibilityString, err := getVisibilityFilter(request.Filter)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||
}
|
||||
memoFind.VisibilityList = []store.Visibility{store.Visibility(visibilityString)}
|
||||
}
|
||||
memos, err := s.Store.ListMemos(ctx, memoFind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memoMessages := make([]*apiv2pb.Memo, len(memos))
|
||||
for i, memo := range memos {
|
||||
memoMessages[i] = convertMemoFromStore(memo)
|
||||
}
|
||||
|
||||
response := &apiv2pb.ListMemosResponse{
|
||||
Memos: memoMessages,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
const visibilityFilterExample = `visibility == "PRIVATE"`
|
||||
|
||||
// getVisibilityFilter will parse the simple filter such as `visibility = "PRIVATE"` to "PRIVATE" .
|
||||
func getVisibilityFilter(filter string) (string, error) {
|
||||
formatInvalidErr := errors.Errorf("invalid filter %q, example %q", filter, visibilityFilterExample)
|
||||
e, err := cel.NewEnv(cel.Variable("visibility", cel.StringType))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ast, issues := e.Compile(filter)
|
||||
if issues != nil {
|
||||
return "", status.Errorf(codes.InvalidArgument, issues.String())
|
||||
}
|
||||
expr := ast.Expr()
|
||||
if expr == nil {
|
||||
return "", formatInvalidErr
|
||||
}
|
||||
callExpr := expr.GetCallExpr()
|
||||
if callExpr == nil {
|
||||
return "", formatInvalidErr
|
||||
}
|
||||
if callExpr.Function != "_==_" {
|
||||
return "", formatInvalidErr
|
||||
}
|
||||
if len(callExpr.Args) != 2 {
|
||||
return "", formatInvalidErr
|
||||
}
|
||||
if callExpr.Args[0].GetIdentExpr() == nil || callExpr.Args[0].GetIdentExpr().Name != "visibility" {
|
||||
return "", formatInvalidErr
|
||||
}
|
||||
constExpr := callExpr.Args[1].GetConstExpr()
|
||||
if constExpr == nil {
|
||||
return "", formatInvalidErr
|
||||
}
|
||||
return constExpr.GetStringValue(), nil
|
||||
}
|
||||
|
||||
func convertMemoFromStore(memo *store.Memo) *apiv2pb.Memo {
|
||||
return &apiv2pb.Memo{
|
||||
Id: int32(memo.ID),
|
||||
RowStatus: convertRowStatusFromStore(memo.RowStatus),
|
||||
CreatedTs: memo.CreatedTs,
|
||||
UpdatedTs: memo.UpdatedTs,
|
||||
CreatorId: int32(memo.CreatorID),
|
||||
Content: memo.Content,
|
||||
Visibility: convertVisibilityFromStore(memo.Visibility),
|
||||
Pinned: memo.Pinned,
|
||||
}
|
||||
}
|
||||
|
||||
func convertVisibilityFromStore(visibility store.Visibility) apiv2pb.Visibility {
|
||||
switch visibility {
|
||||
case store.Private:
|
||||
return apiv2pb.Visibility_PRIVATE
|
||||
case store.Protected:
|
||||
return apiv2pb.Visibility_PROTECTED
|
||||
case store.Public:
|
||||
return apiv2pb.Visibility_PUBLIC
|
||||
default:
|
||||
return apiv2pb.Visibility_VISIBILITY_UNSPECIFIED
|
||||
}
|
||||
}
|
@@ -37,20 +37,6 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque
|
||||
// Data desensitization.
|
||||
userMessage.OpenId = ""
|
||||
|
||||
userSettings, err := s.Store.ListUserSettings(ctx, &store.FindUserSetting{
|
||||
UserID: &userMessage.Id,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to list user settings: %v", err)
|
||||
}
|
||||
|
||||
userID, ok := ctx.Value(UserIDContextKey).(int)
|
||||
if ok && userID == int(userMessage.Id) {
|
||||
for _, userSetting := range userSettings {
|
||||
userMessage.Settings = append(userMessage.Settings, convertUserSettingFromStore(userSetting))
|
||||
}
|
||||
}
|
||||
|
||||
response := &apiv2pb.GetUserResponse{
|
||||
User: userMessage,
|
||||
}
|
||||
@@ -69,7 +55,6 @@ func convertUserFromStore(user *store.User) *apiv2pb.User {
|
||||
Nickname: user.Nickname,
|
||||
OpenId: user.OpenID,
|
||||
AvatarUrl: user.AvatarURL,
|
||||
Settings: []*apiv2pb.UserSetting{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,7 +71,8 @@ func convertUserRoleFromStore(role store.Role) apiv2pb.Role {
|
||||
}
|
||||
}
|
||||
|
||||
func convertUserSettingFromStore(userSetting *store.UserSetting) *apiv2pb.UserSetting {
|
||||
// ConvertUserSettingFromStore converts a user setting from store to protobuf.
|
||||
func ConvertUserSettingFromStore(userSetting *store.UserSetting) *apiv2pb.UserSetting {
|
||||
userSettingKey := apiv2pb.UserSetting_KEY_UNSPECIFIED
|
||||
userSettingValue := &apiv2pb.UserSettingValue{}
|
||||
switch userSetting.Key {
|
||||
@@ -103,7 +89,7 @@ func convertUserSettingFromStore(userSetting *store.UserSetting) *apiv2pb.UserSe
|
||||
case "memo-visibility":
|
||||
userSettingKey = apiv2pb.UserSetting_MEMO_VISIBILITY
|
||||
userSettingValue.Value = &apiv2pb.UserSettingValue_VisibilityValue{
|
||||
VisibilityValue: convertVisibilityFromString(userSetting.Value),
|
||||
VisibilityValue: convertVisibilityFromStore(store.Visibility(userSetting.Value)),
|
||||
}
|
||||
case "telegram-user-id":
|
||||
userSettingKey = apiv2pb.UserSetting_TELEGRAM_USER_ID
|
||||
@@ -117,14 +103,3 @@ func convertUserSettingFromStore(userSetting *store.UserSetting) *apiv2pb.UserSe
|
||||
Value: userSettingValue,
|
||||
}
|
||||
}
|
||||
|
||||
func convertVisibilityFromString(visibility string) apiv2pb.Visibility {
|
||||
switch visibility {
|
||||
case "public":
|
||||
return apiv2pb.Visibility_PUBLIC
|
||||
case "private":
|
||||
return apiv2pb.Visibility_PRIVATE
|
||||
default:
|
||||
return apiv2pb.Visibility_VISIBILITY_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user