You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. import {
  2. useCallback,
  3. useEffect,
  4. useMemo,
  5. useState,
  6. } from 'react'
  7. import useSWR, { useSWRConfig } from 'swr'
  8. import { useContext } from 'use-context-selector'
  9. import type {
  10. Credential,
  11. CustomConfigurationModelFixedFields,
  12. CustomModel,
  13. DefaultModel,
  14. DefaultModelResponse,
  15. Model,
  16. ModelModalModeEnum,
  17. ModelProvider,
  18. ModelTypeEnum,
  19. } from './declarations'
  20. import {
  21. ConfigurationMethodEnum,
  22. CustomConfigurationStatusEnum,
  23. ModelStatusEnum,
  24. } from './declarations'
  25. import I18n from '@/context/i18n'
  26. import {
  27. fetchDefaultModal,
  28. fetchModelList,
  29. fetchModelProviderCredentials,
  30. fetchModelProviders,
  31. getPayUrl,
  32. } from '@/service/common'
  33. import { useProviderContext } from '@/context/provider-context'
  34. import {
  35. useMarketplacePlugins,
  36. } from '@/app/components/plugins/marketplace/hooks'
  37. import type { Plugin } from '@/app/components/plugins/types'
  38. import { PluginType } from '@/app/components/plugins/types'
  39. import { getMarketplacePluginsByCollectionId } from '@/app/components/plugins/marketplace/utils'
  40. import { useModalContextSelector } from '@/context/modal-context'
  41. import { useEventEmitterContextContext } from '@/context/event-emitter'
  42. import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
  43. type UseDefaultModelAndModelList = (
  44. defaultModel: DefaultModelResponse | undefined,
  45. modelList: Model[],
  46. ) => [DefaultModel | undefined, (model: DefaultModel) => void]
  47. export const useSystemDefaultModelAndModelList: UseDefaultModelAndModelList = (
  48. defaultModel,
  49. modelList,
  50. ) => {
  51. const currentDefaultModel = useMemo(() => {
  52. const currentProvider = modelList.find(provider => provider.provider === defaultModel?.provider.provider)
  53. const currentModel = currentProvider?.models.find(model => model.model === defaultModel?.model)
  54. const currentDefaultModel = currentProvider && currentModel && {
  55. model: currentModel.model,
  56. provider: currentProvider.provider,
  57. }
  58. return currentDefaultModel
  59. }, [defaultModel, modelList])
  60. const [defaultModelState, setDefaultModelState] = useState<DefaultModel | undefined>(currentDefaultModel)
  61. const handleDefaultModelChange = useCallback((model: DefaultModel) => {
  62. setDefaultModelState(model)
  63. }, [])
  64. useEffect(() => {
  65. setDefaultModelState(currentDefaultModel)
  66. }, [currentDefaultModel])
  67. return [defaultModelState, handleDefaultModelChange]
  68. }
  69. export const useLanguage = () => {
  70. const { locale } = useContext(I18n)
  71. return locale.replace('-', '_')
  72. }
  73. export const useProviderCredentialsAndLoadBalancing = (
  74. provider: string,
  75. configurationMethod: ConfigurationMethodEnum,
  76. configured?: boolean,
  77. currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
  78. credentialId?: string,
  79. ) => {
  80. const { data: predefinedFormSchemasValue, mutate: mutatePredefined, isLoading: isPredefinedLoading } = useSWR(
  81. (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured && credentialId)
  82. ? `/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}`
  83. : null,
  84. fetchModelProviderCredentials,
  85. )
  86. const { data: customFormSchemasValue, mutate: mutateCustomized, isLoading: isCustomizedLoading } = useSWR(
  87. (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields && credentialId)
  88. ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}${credentialId ? `&credential_id=${credentialId}` : ''}`
  89. : null,
  90. fetchModelProviderCredentials,
  91. )
  92. const credentials = useMemo(() => {
  93. return configurationMethod === ConfigurationMethodEnum.predefinedModel
  94. ? predefinedFormSchemasValue?.credentials
  95. : customFormSchemasValue?.credentials
  96. ? {
  97. ...customFormSchemasValue?.credentials,
  98. ...currentCustomConfigurationModelFixedFields,
  99. }
  100. : undefined
  101. }, [
  102. configurationMethod,
  103. credentialId,
  104. currentCustomConfigurationModelFixedFields,
  105. customFormSchemasValue?.credentials,
  106. predefinedFormSchemasValue?.credentials,
  107. ])
  108. const mutate = useMemo(() => () => {
  109. mutatePredefined()
  110. mutateCustomized()
  111. }, [mutateCustomized, mutatePredefined])
  112. return {
  113. credentials,
  114. loadBalancing: (configurationMethod === ConfigurationMethodEnum.predefinedModel
  115. ? predefinedFormSchemasValue
  116. : customFormSchemasValue
  117. )?.load_balancing,
  118. mutate,
  119. isLoading: isPredefinedLoading || isCustomizedLoading,
  120. }
  121. // as ([Record<string, string | boolean | undefined> | undefined, ModelLoadBalancingConfig | undefined])
  122. }
  123. export const useModelList = (type: ModelTypeEnum) => {
  124. const { data, mutate, isLoading } = useSWR(`/workspaces/current/models/model-types/${type}`, fetchModelList)
  125. return {
  126. data: data?.data || [],
  127. mutate,
  128. isLoading,
  129. }
  130. }
  131. export const useDefaultModel = (type: ModelTypeEnum) => {
  132. const { data, mutate, isLoading } = useSWR(`/workspaces/current/default-model?model_type=${type}`, fetchDefaultModal)
  133. return {
  134. data: data?.data,
  135. mutate,
  136. isLoading,
  137. }
  138. }
  139. export const useCurrentProviderAndModel = (modelList: Model[], defaultModel?: DefaultModel) => {
  140. const currentProvider = modelList.find(provider => provider.provider === defaultModel?.provider)
  141. const currentModel = currentProvider?.models.find(model => model.model === defaultModel?.model)
  142. return {
  143. currentProvider,
  144. currentModel,
  145. }
  146. }
  147. export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultModel?: DefaultModel) => {
  148. const { textGenerationModelList } = useProviderContext()
  149. const activeTextGenerationModelList = textGenerationModelList.filter(model => model.status === ModelStatusEnum.active)
  150. const {
  151. currentProvider,
  152. currentModel,
  153. } = useCurrentProviderAndModel(textGenerationModelList, defaultModel)
  154. return {
  155. currentProvider,
  156. currentModel,
  157. textGenerationModelList,
  158. activeTextGenerationModelList,
  159. }
  160. }
  161. export const useModelListAndDefaultModel = (type: ModelTypeEnum) => {
  162. const { data: modelList } = useModelList(type)
  163. const { data: defaultModel } = useDefaultModel(type)
  164. return {
  165. modelList,
  166. defaultModel,
  167. }
  168. }
  169. export const useModelListAndDefaultModelAndCurrentProviderAndModel = (type: ModelTypeEnum) => {
  170. const { modelList, defaultModel } = useModelListAndDefaultModel(type)
  171. const { currentProvider, currentModel } = useCurrentProviderAndModel(
  172. modelList,
  173. { provider: defaultModel?.provider.provider || '', model: defaultModel?.model || '' },
  174. )
  175. return {
  176. modelList,
  177. defaultModel,
  178. currentProvider,
  179. currentModel,
  180. }
  181. }
  182. export const useUpdateModelList = () => {
  183. const { mutate } = useSWRConfig()
  184. const updateModelList = useCallback((type: ModelTypeEnum) => {
  185. mutate(`/workspaces/current/models/model-types/${type}`)
  186. }, [mutate])
  187. return updateModelList
  188. }
  189. export const useAnthropicBuyQuota = () => {
  190. const [loading, setLoading] = useState(false)
  191. const handleGetPayUrl = async () => {
  192. if (loading)
  193. return
  194. setLoading(true)
  195. try {
  196. const res = await getPayUrl('/workspaces/current/model-providers/anthropic/checkout-url')
  197. window.location.href = res.url
  198. }
  199. finally {
  200. setLoading(false)
  201. }
  202. }
  203. return handleGetPayUrl
  204. }
  205. export const useModelProviders = () => {
  206. const { data: providersData, mutate, isLoading } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
  207. return {
  208. data: providersData?.data || [],
  209. mutate,
  210. isLoading,
  211. }
  212. }
  213. export const useUpdateModelProviders = () => {
  214. const { mutate } = useSWRConfig()
  215. const updateModelProviders = useCallback(() => {
  216. mutate('/workspaces/current/model-providers')
  217. }, [mutate])
  218. return updateModelProviders
  219. }
  220. export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: string) => {
  221. const exclude = useMemo(() => {
  222. return providers.map(provider => provider.provider.replace(/(.+)\/([^/]+)$/, '$1'))
  223. }, [providers])
  224. const [collectionPlugins, setCollectionPlugins] = useState<Plugin[]>([])
  225. const {
  226. plugins,
  227. queryPlugins,
  228. queryPluginsWithDebounced,
  229. isLoading,
  230. } = useMarketplacePlugins()
  231. const getCollectionPlugins = useCallback(async () => {
  232. const collectionPlugins = await getMarketplacePluginsByCollectionId('__model-settings-pinned-models')
  233. setCollectionPlugins(collectionPlugins)
  234. }, [])
  235. useEffect(() => {
  236. getCollectionPlugins()
  237. }, [getCollectionPlugins])
  238. useEffect(() => {
  239. if (searchText) {
  240. queryPluginsWithDebounced({
  241. query: searchText,
  242. category: PluginType.model,
  243. exclude,
  244. type: 'plugin',
  245. sortBy: 'install_count',
  246. sortOrder: 'DESC',
  247. })
  248. }
  249. else {
  250. queryPlugins({
  251. query: '',
  252. category: PluginType.model,
  253. type: 'plugin',
  254. pageSize: 1000,
  255. exclude,
  256. sortBy: 'install_count',
  257. sortOrder: 'DESC',
  258. })
  259. }
  260. }, [queryPlugins, queryPluginsWithDebounced, searchText, exclude])
  261. const allPlugins = useMemo(() => {
  262. const allPlugins = collectionPlugins.filter(plugin => !exclude.includes(plugin.plugin_id))
  263. if (plugins?.length) {
  264. for (let i = 0; i < plugins.length; i++) {
  265. const plugin = plugins[i]
  266. if (plugin.type !== 'bundle' && !allPlugins.find(p => p.plugin_id === plugin.plugin_id))
  267. allPlugins.push(plugin)
  268. }
  269. }
  270. return allPlugins
  271. }, [plugins, collectionPlugins, exclude])
  272. return {
  273. plugins: allPlugins,
  274. isLoading,
  275. }
  276. }
  277. export const useRefreshModel = () => {
  278. const { eventEmitter } = useEventEmitterContextContext()
  279. const updateModelProviders = useUpdateModelProviders()
  280. const updateModelList = useUpdateModelList()
  281. const handleRefreshModel = useCallback((provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => {
  282. updateModelProviders()
  283. provider.supported_model_types.forEach((type) => {
  284. updateModelList(type)
  285. })
  286. if (configurationMethod === ConfigurationMethodEnum.customizableModel
  287. && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
  288. eventEmitter?.emit({
  289. type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
  290. payload: provider.provider,
  291. } as any)
  292. if (CustomConfigurationModelFixedFields?.__model_type)
  293. updateModelList(CustomConfigurationModelFixedFields.__model_type)
  294. }
  295. }, [eventEmitter, updateModelList, updateModelProviders])
  296. return {
  297. handleRefreshModel,
  298. }
  299. }
  300. export const useModelModalHandler = () => {
  301. const setShowModelModal = useModalContextSelector(state => state.setShowModelModal)
  302. return (
  303. provider: ModelProvider,
  304. configurationMethod: ConfigurationMethodEnum,
  305. CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
  306. extra: {
  307. isModelCredential?: boolean,
  308. credential?: Credential,
  309. model?: CustomModel,
  310. onUpdate?: (newPayload: any, formValues?: Record<string, any>) => void,
  311. mode?: ModelModalModeEnum,
  312. } = {},
  313. ) => {
  314. setShowModelModal({
  315. payload: {
  316. currentProvider: provider,
  317. currentConfigurationMethod: configurationMethod,
  318. currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields,
  319. isModelCredential: extra.isModelCredential,
  320. credential: extra.credential,
  321. model: extra.model,
  322. mode: extra.mode,
  323. },
  324. onSaveCallback: (newPayload, formValues) => {
  325. extra.onUpdate?.(newPayload, formValues)
  326. },
  327. })
  328. }
  329. }