You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

llm-hooks.ts 7.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import { LlmModelType } from '@/constants/knowledge';
  2. import { ResponseGetType } from '@/interfaces/database/base';
  3. import {
  4. IFactory,
  5. IMyLlmValue,
  6. IThirdOAIModelCollection as IThirdAiModelCollection,
  7. IThirdOAIModelCollection,
  8. } from '@/interfaces/database/llm';
  9. import {
  10. IAddLlmRequestBody,
  11. IDeleteLlmRequestBody,
  12. } from '@/interfaces/request/llm';
  13. import userService from '@/services/user-service';
  14. import { sortLLmFactoryListBySpecifiedOrder } from '@/utils/common-util';
  15. import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
  16. import { message } from 'antd';
  17. import { useMemo } from 'react';
  18. import { useTranslation } from 'react-i18next';
  19. export const useFetchLlmList = (
  20. modelType?: LlmModelType,
  21. ): IThirdAiModelCollection => {
  22. const { data } = useQuery({
  23. queryKey: ['llmList'],
  24. initialData: {},
  25. queryFn: async () => {
  26. const { data } = await userService.llm_list({ model_type: modelType });
  27. return data?.data ?? {};
  28. },
  29. });
  30. return data;
  31. };
  32. export const useSelectLlmOptions = () => {
  33. const llmInfo: IThirdOAIModelCollection = useFetchLlmList();
  34. const embeddingModelOptions = useMemo(() => {
  35. return Object.entries(llmInfo).map(([key, value]) => {
  36. return {
  37. label: key,
  38. options: value.map((x) => ({
  39. label: x.llm_name,
  40. value: x.llm_name,
  41. disabled: !x.available,
  42. })),
  43. };
  44. });
  45. }, [llmInfo]);
  46. return embeddingModelOptions;
  47. };
  48. export const useSelectLlmOptionsByModelType = () => {
  49. const llmInfo: IThirdOAIModelCollection = useFetchLlmList();
  50. const groupOptionsByModelType = (modelType: LlmModelType) => {
  51. return Object.entries(llmInfo)
  52. .filter(([, value]) =>
  53. modelType ? value.some((x) => x.model_type.includes(modelType)) : true,
  54. )
  55. .map(([key, value]) => {
  56. return {
  57. label: key,
  58. options: value
  59. .filter(
  60. (x) =>
  61. (modelType ? x.model_type.includes(modelType) : true) &&
  62. x.available,
  63. )
  64. .map((x) => ({
  65. label: x.llm_name,
  66. value: `${x.llm_name}@${x.fid}`,
  67. disabled: !x.available,
  68. })),
  69. };
  70. })
  71. .filter((x) => x.options.length > 0);
  72. };
  73. return {
  74. [LlmModelType.Chat]: groupOptionsByModelType(LlmModelType.Chat),
  75. [LlmModelType.Embedding]: groupOptionsByModelType(LlmModelType.Embedding),
  76. [LlmModelType.Image2text]: groupOptionsByModelType(LlmModelType.Image2text),
  77. [LlmModelType.Speech2text]: groupOptionsByModelType(
  78. LlmModelType.Speech2text,
  79. ),
  80. [LlmModelType.Rerank]: groupOptionsByModelType(LlmModelType.Rerank),
  81. [LlmModelType.TTS]: groupOptionsByModelType(LlmModelType.TTS),
  82. };
  83. };
  84. export const useFetchLlmFactoryList = (): ResponseGetType<IFactory[]> => {
  85. const { data, isFetching: loading } = useQuery({
  86. queryKey: ['factoryList'],
  87. initialData: [],
  88. gcTime: 0,
  89. queryFn: async () => {
  90. const { data } = await userService.factories_list();
  91. return data?.data ?? [];
  92. },
  93. });
  94. return { data, loading };
  95. };
/** One configured-model entry flattened for rendering: factory name + logo plus the raw value. */
export type LlmItem = { name: string; logo: string } & IMyLlmValue;
  97. export const useFetchMyLlmList = (): ResponseGetType<
  98. Record<string, IMyLlmValue>
  99. > => {
  100. const { data, isFetching: loading } = useQuery({
  101. queryKey: ['myLlmList'],
  102. initialData: {},
  103. gcTime: 0,
  104. queryFn: async () => {
  105. const { data } = await userService.my_llm();
  106. return data?.data ?? {};
  107. },
  108. });
  109. return { data, loading };
  110. };
  111. export const useSelectLlmList = () => {
  112. const { data: myLlmList, loading: myLlmListLoading } = useFetchMyLlmList();
  113. const { data: factoryList, loading: factoryListLoading } =
  114. useFetchLlmFactoryList();
  115. const nextMyLlmList: Array<LlmItem> = useMemo(() => {
  116. return Object.entries(myLlmList).map(([key, value]) => ({
  117. name: key,
  118. logo: factoryList.find((x) => x.name === key)?.logo ?? '',
  119. ...value,
  120. }));
  121. }, [myLlmList, factoryList]);
  122. const nextFactoryList = useMemo(() => {
  123. const currentList = factoryList.filter((x) =>
  124. Object.keys(myLlmList).every((y) => y !== x.name),
  125. );
  126. return sortLLmFactoryListBySpecifiedOrder(currentList);
  127. }, [factoryList, myLlmList]);
  128. return {
  129. myLlmList: nextMyLlmList,
  130. factoryList: nextFactoryList,
  131. loading: myLlmListLoading || factoryListLoading,
  132. };
  133. };
/**
 * Request body for saving a provider API key.
 * `llm_name`, `model_type` and `base_url` are only needed by providers that
 * require per-model configuration — TODO confirm against the backend contract.
 */
export interface IApiKeySavingParams {
  llm_factory: string;
  api_key: string;
  llm_name?: string;
  model_type?: string;
  base_url?: string;
}
  141. export const useSaveApiKey = () => {
  142. const queryClient = useQueryClient();
  143. const { t } = useTranslation();
  144. const {
  145. data,
  146. isPending: loading,
  147. mutateAsync,
  148. } = useMutation({
  149. mutationKey: ['saveApiKey'],
  150. mutationFn: async (params: IApiKeySavingParams) => {
  151. const { data } = await userService.set_api_key(params);
  152. if (data.retcode === 0) {
  153. message.success(t('message.modified'));
  154. queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
  155. queryClient.invalidateQueries({ queryKey: ['factoryList'] });
  156. }
  157. return data.retcode;
  158. },
  159. });
  160. return { data, loading, saveApiKey: mutateAsync };
  161. };
/**
 * Request body for saving the tenant's system-model defaults
 * (ASR, embedding, image-to-text and chat model ids).
 */
export interface ISystemModelSettingSavingParams {
  tenant_id: string;
  name?: string;
  asr_id: string;
  embd_id: string;
  img2txt_id: string;
  llm_id: string;
}
  170. export const useSaveTenantInfo = () => {
  171. const { t } = useTranslation();
  172. const {
  173. data,
  174. isPending: loading,
  175. mutateAsync,
  176. } = useMutation({
  177. mutationKey: ['saveTenantInfo'],
  178. mutationFn: async (params: ISystemModelSettingSavingParams) => {
  179. const { data } = await userService.set_tenant_info(params);
  180. if (data.retcode === 0) {
  181. message.success(t('message.modified'));
  182. }
  183. return data.retcode;
  184. },
  185. });
  186. return { data, loading, saveTenantInfo: mutateAsync };
  187. };
  188. export const useAddLlm = () => {
  189. const queryClient = useQueryClient();
  190. const { t } = useTranslation();
  191. const {
  192. data,
  193. isPending: loading,
  194. mutateAsync,
  195. } = useMutation({
  196. mutationKey: ['addLlm'],
  197. mutationFn: async (params: IAddLlmRequestBody) => {
  198. const { data } = await userService.add_llm(params);
  199. if (data.retcode === 0) {
  200. queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
  201. queryClient.invalidateQueries({ queryKey: ['factoryList'] });
  202. message.success(t('message.modified'));
  203. }
  204. return data.retcode;
  205. },
  206. });
  207. return { data, loading, addLlm: mutateAsync };
  208. };
  209. export const useDeleteLlm = () => {
  210. const queryClient = useQueryClient();
  211. const { t } = useTranslation();
  212. const {
  213. data,
  214. isPending: loading,
  215. mutateAsync,
  216. } = useMutation({
  217. mutationKey: ['deleteLlm'],
  218. mutationFn: async (params: IDeleteLlmRequestBody) => {
  219. const { data } = await userService.delete_llm(params);
  220. if (data.retcode === 0) {
  221. queryClient.invalidateQueries({ queryKey: ['myLlmList'] });
  222. queryClient.invalidateQueries({ queryKey: ['factoryList'] });
  223. message.success(t('message.deleted'));
  224. }
  225. return data.retcode;
  226. },
  227. });
  228. return { data, loading, deleteLlm: mutateAsync };
  229. };