You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

hooks.ts 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. import {
  2. useCallback,
  3. useEffect,
  4. useMemo,
  5. useState,
  6. } from 'react'
  7. import useSWR, { useSWRConfig } from 'swr'
  8. import { useContext } from 'use-context-selector'
  9. import type {
  10. Credential,
  11. CustomConfigurationModelFixedFields,
  12. CustomModel,
  13. DefaultModel,
  14. DefaultModelResponse,
  15. Model,
  16. ModelProvider,
  17. ModelTypeEnum,
  18. } from './declarations'
  19. import {
  20. ConfigurationMethodEnum,
  21. CustomConfigurationStatusEnum,
  22. ModelStatusEnum,
  23. } from './declarations'
  24. import I18n from '@/context/i18n'
  25. import {
  26. fetchDefaultModal,
  27. fetchModelList,
  28. fetchModelProviderCredentials,
  29. fetchModelProviders,
  30. getPayUrl,
  31. } from '@/service/common'
  32. import { useProviderContext } from '@/context/provider-context'
  33. import {
  34. useMarketplacePlugins,
  35. } from '@/app/components/plugins/marketplace/hooks'
  36. import type { Plugin } from '@/app/components/plugins/types'
  37. import { PluginType } from '@/app/components/plugins/types'
  38. import { getMarketplacePluginsByCollectionId } from '@/app/components/plugins/marketplace/utils'
  39. import { useModalContextSelector } from '@/context/modal-context'
  40. import { useEventEmitterContextContext } from '@/context/event-emitter'
  41. import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
  42. type UseDefaultModelAndModelList = (
  43. defaultModel: DefaultModelResponse | undefined,
  44. modelList: Model[],
  45. ) => [DefaultModel | undefined, (model: DefaultModel) => void]
  46. export const useSystemDefaultModelAndModelList: UseDefaultModelAndModelList = (
  47. defaultModel,
  48. modelList,
  49. ) => {
  50. const currentDefaultModel = useMemo(() => {
  51. const currentProvider = modelList.find(provider => provider.provider === defaultModel?.provider.provider)
  52. const currentModel = currentProvider?.models.find(model => model.model === defaultModel?.model)
  53. const currentDefaultModel = currentProvider && currentModel && {
  54. model: currentModel.model,
  55. provider: currentProvider.provider,
  56. }
  57. return currentDefaultModel
  58. }, [defaultModel, modelList])
  59. const [defaultModelState, setDefaultModelState] = useState<DefaultModel | undefined>(currentDefaultModel)
  60. const handleDefaultModelChange = useCallback((model: DefaultModel) => {
  61. setDefaultModelState(model)
  62. }, [])
  63. useEffect(() => {
  64. setDefaultModelState(currentDefaultModel)
  65. }, [currentDefaultModel])
  66. return [defaultModelState, handleDefaultModelChange]
  67. }
  68. export const useLanguage = () => {
  69. const { locale } = useContext(I18n)
  70. return locale.replace('-', '_')
  71. }
  72. export const useProviderCredentialsAndLoadBalancing = (
  73. provider: string,
  74. configurationMethod: ConfigurationMethodEnum,
  75. configured?: boolean,
  76. currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
  77. credentialId?: string,
  78. ) => {
  79. const { data: predefinedFormSchemasValue, mutate: mutatePredefined, isLoading: isPredefinedLoading } = useSWR(
  80. (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured && credentialId)
  81. ? `/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}`
  82. : null,
  83. fetchModelProviderCredentials,
  84. )
  85. const { data: customFormSchemasValue, mutate: mutateCustomized, isLoading: isCustomizedLoading } = useSWR(
  86. (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields && credentialId)
  87. ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}${credentialId ? `&credential_id=${credentialId}` : ''}`
  88. : null,
  89. fetchModelProviderCredentials,
  90. )
  91. const credentials = useMemo(() => {
  92. return configurationMethod === ConfigurationMethodEnum.predefinedModel
  93. ? predefinedFormSchemasValue?.credentials
  94. : customFormSchemasValue?.credentials
  95. ? {
  96. ...customFormSchemasValue?.credentials,
  97. ...currentCustomConfigurationModelFixedFields,
  98. }
  99. : undefined
  100. }, [
  101. configurationMethod,
  102. credentialId,
  103. currentCustomConfigurationModelFixedFields,
  104. customFormSchemasValue?.credentials,
  105. predefinedFormSchemasValue?.credentials,
  106. ])
  107. const mutate = useMemo(() => () => {
  108. mutatePredefined()
  109. mutateCustomized()
  110. }, [mutateCustomized, mutatePredefined])
  111. return {
  112. credentials,
  113. loadBalancing: (configurationMethod === ConfigurationMethodEnum.predefinedModel
  114. ? predefinedFormSchemasValue
  115. : customFormSchemasValue
  116. )?.load_balancing,
  117. mutate,
  118. isLoading: isPredefinedLoading || isCustomizedLoading,
  119. }
  120. // as ([Record<string, string | boolean | undefined> | undefined, ModelLoadBalancingConfig | undefined])
  121. }
  122. export const useModelList = (type: ModelTypeEnum) => {
  123. const { data, mutate, isLoading } = useSWR(`/workspaces/current/models/model-types/${type}`, fetchModelList)
  124. return {
  125. data: data?.data || [],
  126. mutate,
  127. isLoading,
  128. }
  129. }
  130. export const useDefaultModel = (type: ModelTypeEnum) => {
  131. const { data, mutate, isLoading } = useSWR(`/workspaces/current/default-model?model_type=${type}`, fetchDefaultModal)
  132. return {
  133. data: data?.data,
  134. mutate,
  135. isLoading,
  136. }
  137. }
  138. export const useCurrentProviderAndModel = (modelList: Model[], defaultModel?: DefaultModel) => {
  139. const currentProvider = modelList.find(provider => provider.provider === defaultModel?.provider)
  140. const currentModel = currentProvider?.models.find(model => model.model === defaultModel?.model)
  141. return {
  142. currentProvider,
  143. currentModel,
  144. }
  145. }
  146. export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultModel?: DefaultModel) => {
  147. const { textGenerationModelList } = useProviderContext()
  148. const activeTextGenerationModelList = textGenerationModelList.filter(model => model.status === ModelStatusEnum.active)
  149. const {
  150. currentProvider,
  151. currentModel,
  152. } = useCurrentProviderAndModel(textGenerationModelList, defaultModel)
  153. return {
  154. currentProvider,
  155. currentModel,
  156. textGenerationModelList,
  157. activeTextGenerationModelList,
  158. }
  159. }
  160. export const useModelListAndDefaultModel = (type: ModelTypeEnum) => {
  161. const { data: modelList } = useModelList(type)
  162. const { data: defaultModel } = useDefaultModel(type)
  163. return {
  164. modelList,
  165. defaultModel,
  166. }
  167. }
  168. export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: ModelTypeEnum) => {
  169. const { modelList, defaultModel } = useModelListAndDefaultModel(type)
  170. const { currentProvider, currentModel } = useCurrentProviderAndModel(
  171. modelList,
  172. { provider: defaultModel?.provider.provider || '', model: defaultModel?.model || '' },
  173. )
  174. return {
  175. modelList,
  176. defaultModel,
  177. currentProvider,
  178. currentModel,
  179. }
  180. }
  181. export const useUpdateModelList = () => {
  182. const { mutate } = useSWRConfig()
  183. const updateModelList = useCallback((type: ModelTypeEnum) => {
  184. mutate(`/workspaces/current/models/model-types/${type}`)
  185. }, [mutate])
  186. return updateModelList
  187. }
  188. export const useAnthropicBuyQuota = () => {
  189. const [loading, setLoading] = useState(false)
  190. const handleGetPayUrl = async () => {
  191. if (loading)
  192. return
  193. setLoading(true)
  194. try {
  195. const res = await getPayUrl('/workspaces/current/model-providers/anthropic/checkout-url')
  196. window.location.href = res.url
  197. }
  198. finally {
  199. setLoading(false)
  200. }
  201. }
  202. return handleGetPayUrl
  203. }
  204. export const useModelProviders = () => {
  205. const { data: providersData, mutate, isLoading } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
  206. return {
  207. data: providersData?.data || [],
  208. mutate,
  209. isLoading,
  210. }
  211. }
  212. export const useUpdateModelProviders = () => {
  213. const { mutate } = useSWRConfig()
  214. const updateModelProviders = useCallback(() => {
  215. mutate('/workspaces/current/model-providers')
  216. }, [mutate])
  217. return updateModelProviders
  218. }
  219. export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: string) => {
  220. const exclude = useMemo(() => {
  221. return providers.map(provider => provider.provider.replace(/(.+)\/([^/]+)$/, '$1'))
  222. }, [providers])
  223. const [collectionPlugins, setCollectionPlugins] = useState<Plugin[]>([])
  224. const {
  225. plugins,
  226. queryPlugins,
  227. queryPluginsWithDebounced,
  228. isLoading,
  229. } = useMarketplacePlugins()
  230. const getCollectionPlugins = useCallback(async () => {
  231. const collectionPlugins = await getMarketplacePluginsByCollectionId('__model-settings-pinned-models')
  232. setCollectionPlugins(collectionPlugins)
  233. }, [])
  234. useEffect(() => {
  235. getCollectionPlugins()
  236. }, [getCollectionPlugins])
  237. useEffect(() => {
  238. if (searchText) {
  239. queryPluginsWithDebounced({
  240. query: searchText,
  241. category: PluginType.model,
  242. exclude,
  243. type: 'plugin',
  244. sortBy: 'install_count',
  245. sortOrder: 'DESC',
  246. })
  247. }
  248. else {
  249. queryPlugins({
  250. query: '',
  251. category: PluginType.model,
  252. type: 'plugin',
  253. pageSize: 1000,
  254. exclude,
  255. sortBy: 'install_count',
  256. sortOrder: 'DESC',
  257. })
  258. }
  259. }, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])
  260. const allPlugins = useMemo(() => {
  261. const allPlugins = [...collectionPlugins.filter(plugin => !exclude.includes(plugin.plugin_id))]
  262. if (plugins?.length) {
  263. for (let i = 0; i < plugins.length; i++) {
  264. const plugin = plugins[i]
  265. if (plugin.type !== 'bundle' && !allPlugins.find(p => p.plugin_id === plugin.plugin_id))
  266. allPlugins.push(plugin)
  267. }
  268. }
  269. return allPlugins
  270. }, [plugins, collectionPlugins, exclude])
  271. return {
  272. plugins: allPlugins,
  273. isLoading,
  274. }
  275. }
  276. export const useRefreshModel = () => {
  277. const { eventEmitter } = useEventEmitterContextContext()
  278. const updateModelProviders = useUpdateModelProviders()
  279. const updateModelList = useUpdateModelList()
  280. const handleRefreshModel = useCallback((provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => {
  281. updateModelProviders()
  282. provider.supported_model_types.forEach((type) => {
  283. updateModelList(type)
  284. })
  285. if (configurationMethod === ConfigurationMethodEnum.customizableModel
  286. && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
  287. eventEmitter?.emit({
  288. type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
  289. payload: provider.provider,
  290. } as any)
  291. if (CustomConfigurationModelFixedFields?.__model_type)
  292. updateModelList(CustomConfigurationModelFixedFields.__model_type)
  293. }
  294. }, [eventEmitter, updateModelList, updateModelProviders])
  295. return {
  296. handleRefreshModel,
  297. }
  298. }
  299. export const useModelModalHandler = () => {
  300. const setShowModelModal = useModalContextSelector(state => state.setShowModelModal)
  301. const { handleRefreshModel } = useRefreshModel()
  302. return (
  303. provider: ModelProvider,
  304. configurationMethod: ConfigurationMethodEnum,
  305. CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
  306. isModelCredential?: boolean,
  307. credential?: Credential,
  308. model?: CustomModel,
  309. onUpdate?: () => void,
  310. ) => {
  311. setShowModelModal({
  312. payload: {
  313. currentProvider: provider,
  314. currentConfigurationMethod: configurationMethod,
  315. currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields,
  316. isModelCredential,
  317. credential,
  318. model,
  319. },
  320. onSaveCallback: () => {
  321. handleRefreshModel(provider, configurationMethod, CustomConfigurationModelFixedFields)
  322. onUpdate?.()
  323. },
  324. })
  325. }
  326. }