feat: Select derived messages from backend #2088 (#2176)

### What problem does this PR solve?

feat: Select derived messages from backend #2088

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
balibabu
2024-08-30 17:53:30 +08:00
committed by GitHub
parent 2c771fb0b4
commit 5400467da1
13 changed files with 556 additions and 220 deletions

View File

@ -1,14 +1,15 @@
import { Authorization } from '@/constants/authorization';
import { MessageType } from '@/constants/chat';
import { LanguageTranslationMap } from '@/constants/common';
import { Pagination } from '@/interfaces/common';
import { ResponseType } from '@/interfaces/database/base';
import { IAnswer, Message } from '@/interfaces/database/chat';
import { IKnowledgeFile } from '@/interfaces/database/knowledge';
import { IChangeParserConfigRequestBody } from '@/interfaces/request/document';
import { IClientConversation } from '@/pages/chat/interface';
import { IClientConversation, IMessage } from '@/pages/chat/interface';
import api from '@/utils/api';
import { getAuthorization } from '@/utils/authorization-util';
import { getMessagePureId } from '@/utils/chat';
import { buildMessageUuid, getMessagePureId } from '@/utils/chat';
import { PaginationProps } from 'antd';
import { FormInstance } from 'antd/lib';
import axios from 'axios';
@ -309,6 +310,108 @@ export const useHandleMessageInputChange = () => {
};
};
export const useSelectDerivedMessages = () => {
const [derivedMessages, setDerivedMessages] = useState<IMessage[]>([]);
const ref = useScrollToBottom(derivedMessages);
const addNewestQuestion = useCallback(
(message: Message, answer: string = '') => {
setDerivedMessages((pre) => {
return [
...pre,
{
...message,
id: buildMessageUuid(message),
},
{
role: MessageType.Assistant,
content: answer,
id: buildMessageUuid({ ...message, role: MessageType.Assistant }),
},
];
});
},
[],
);
// Add the streaming message to the last item in the message list
const addNewestAnswer = useCallback((answer: IAnswer) => {
setDerivedMessages((pre) => {
return [
...(pre?.slice(0, -1) ?? []),
{
role: MessageType.Assistant,
content: answer.answer,
reference: answer.reference,
id: buildMessageUuid({
id: answer.id,
role: MessageType.Assistant,
}),
prompt: answer.prompt,
},
];
});
}, []);
const removeLatestMessage = useCallback(() => {
setDerivedMessages((pre) => {
const nextMessages = pre?.slice(0, -2) ?? [];
return nextMessages;
});
}, []);
const removeMessageById = useCallback(
(messageId: string) => {
setDerivedMessages((pre) => {
const nextMessages =
pre?.filter(
(x) => getMessagePureId(x.id) !== getMessagePureId(messageId),
) ?? [];
return nextMessages;
});
},
[setDerivedMessages],
);
const removeMessagesAfterCurrentMessage = useCallback(
(messageId: string) => {
setDerivedMessages((pre) => {
const index = pre.findIndex((x) => x.id === messageId);
if (index !== -1) {
let nextMessages = pre.slice(0, index + 2) ?? [];
const latestMessage = nextMessages.at(-1);
nextMessages = latestMessage
? [
...nextMessages.slice(0, -1),
{
...latestMessage,
content: '',
reference: undefined,
prompt: undefined,
},
]
: nextMessages;
return nextMessages;
}
return pre;
});
},
[setDerivedMessages],
);
return {
ref,
derivedMessages,
setDerivedMessages,
addNewestQuestion,
addNewestAnswer,
removeLatestMessage,
removeMessageById,
removeMessagesAfterCurrentMessage,
};
};
export interface IRemoveMessageById {
removeMessageById(messageId: string): void;
}
@ -375,7 +478,7 @@ export const useRemoveMessagesAfterCurrentMessage = (
};
export interface IRegenerateMessage {
regenerateMessage(message: Message): void;
regenerateMessage?: (message: Message) => void;
}
export const useRegenerateMessage = ({
@ -384,7 +487,12 @@ export const useRegenerateMessage = ({
messages,
}: {
removeMessagesAfterCurrentMessage(messageId: string): void;
sendMessage({ message }: { message: Message; messages?: Message[] }): void;
sendMessage({
message,
}: {
message: Message;
messages?: Message[];
}): void | Promise<any>;
messages: Message[];
}) => {
const regenerateMessage = useCallback(