Nevar pievienot vairāk nekā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

index.tsx 6.6KB

1–154
  1. import { useMemo } from 'react'
  2. import { useTranslation } from 'react-i18next'
  3. import { useDebounce } from 'ahooks'
  4. import {
  5. RiAlertFill,
  6. RiBrainLine,
  7. } from '@remixicon/react'
  8. import SystemModelSelector from './system-model-selector'
  9. import ProviderAddedCard from './provider-added-card'
  10. import type {
  11. ModelProvider,
  12. } from './declarations'
  13. import {
  14. CustomConfigurationStatusEnum,
  15. ModelTypeEnum,
  16. } from './declarations'
  17. import {
  18. useDefaultModel,
  19. } from './hooks'
  20. import InstallFromMarketplace from './install-from-marketplace'
  21. import { useProviderContext } from '@/context/provider-context'
  22. import cn from '@/utils/classnames'
  23. import { useGlobalPublicStore } from '@/context/global-public-context'
  /** Props accepted by the model provider page component. */
  interface Props {
    /**
     * Raw search query from the caller. The page debounces it before
     * filtering its provider lists by provider id or label.
     */
    searchText: string
  }
  /**
   * Provider ids that are pinned to the top of the configured-providers list.
   * Their position in this array is their display order; providers not listed
   * here follow after them. Typed `readonly` because it is a shared
   * module-level constant that must never be mutated in place.
   */
  const FixedModelProvider: readonly string[] = ['langgenius/openai/openai', 'langgenius/anthropic/anthropic']
  /**
   * True when `provider` belongs in the "configured" group: either its custom
   * configuration status is active, or its system configuration is enabled and
   * contains a quota configuration whose type equals the current quota type.
   */
  const isProviderConfigured = (provider: ModelProvider): boolean =>
    provider.custom_configuration.status === CustomConfigurationStatusEnum.active
    || (
      provider.system_configuration.enabled === true
      && provider.system_configuration.quota_configurations.some(
        item => item.quota_type === provider.system_configuration.current_quota_type,
      )
    )

  /**
   * Sort comparator that moves providers listed in `FixedModelProvider` to the
   * front, ordered by their position in that list. All other providers compare
   * equal, so the (stable, ES2019+) sort keeps their original relative order.
   * Returns 0 for equal ranks so the comparator is consistent (compare(a, a) === 0).
   */
  const compareFixedFirst = (a: ModelProvider, b: ModelProvider): number => {
    const rankA = FixedModelProvider.indexOf(a.provider)
    const rankB = FixedModelProvider.indexOf(b.provider)
    if (rankA !== -1 && rankB !== -1)
      return rankA - rankB
    if (rankA !== -1)
      return -1
    if (rankB !== -1)
      return 1
    return 0
  }

  /**
   * Case-insensitive match of an already lower-cased query against the
   * provider id and every value of the provider's `label` object.
   */
  const matchesSearch = (provider: ModelProvider, lowerCaseQuery: string): boolean =>
    provider.provider.toLowerCase().includes(lowerCaseQuery)
    || Object.values(provider.label).some(text => text.toLowerCase().includes(lowerCaseQuery))

  /**
   * Model provider page.
   *
   * Renders a header containing the system default-model selector. The header
   * shows a warning when none of the text-generation, embedding, rerank,
   * speech-to-text or TTS default models is set. Below it come the configured
   * providers (pinned ones first), then the providers still to be configured,
   * both filtered by the debounced search text. When the marketplace feature is
   * enabled, a marketplace install section follows.
   */
  const ModelProviderPage = ({ searchText }: Props) => {
    // Filter against a value debounced by 500 ms rather than the raw input.
    const debouncedSearchText = useDebounce(searchText, { wait: 500 })
    const { t } = useTranslation()
    const { data: textGenerationDefaultModel } = useDefaultModel(ModelTypeEnum.textGeneration)
    const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
    const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank)
    const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text)
    const { data: ttsDefaultModel } = useDefaultModel(ModelTypeEnum.tts)
    const { modelProviders: providers } = useProviderContext()
    const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)

    // The warning is shown only when no default model of any type is set.
    const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel

    // Split providers into configured / not configured; pinned providers lead the configured group.
    const [configuredProviders, notConfiguredProviders] = useMemo<[ModelProvider[], ModelProvider[]]>(() => {
      const configured: ModelProvider[] = []
      const notConfigured: ModelProvider[] = []
      providers.forEach((provider) => {
        if (isProviderConfigured(provider))
          configured.push(provider)
        else
          notConfigured.push(provider)
      })
      // Sorting a freshly built local array, so the context's `providers` is not mutated.
      configured.sort(compareFixedFirst)
      return [configured, notConfigured]
    }, [providers])

    const [filteredConfiguredProviders, filteredNotConfiguredProviders] = useMemo<[ModelProvider[], ModelProvider[]]>(() => {
      // Lower-case the query once instead of once per provider/label.
      const query = debouncedSearchText.toLowerCase()
      return [
        configuredProviders.filter(provider => matchesSearch(provider, query)),
        notConfiguredProviders.filter(provider => matchesSearch(provider, query)),
      ]
    }, [configuredProviders, debouncedSearchText, notConfiguredProviders])

    return (
      <div className='relative -mt-2 pt-1'>
        <div className='mb-2 flex items-center'>
          <div className='system-md-semibold grow text-text-primary'>{t('common.modelProvider.models')}</div>
          <div className={cn(
            'relative flex shrink-0 items-center justify-end gap-2 rounded-lg border border-transparent p-px',
            defaultModelNotConfigured && 'border-components-panel-border bg-components-panel-bg-blur pl-2 shadow-xs',
          )}>
            {defaultModelNotConfigured && <div className='absolute bottom-0 left-0 right-0 top-0 opacity-40' style={{ background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)' }} />}
            {defaultModelNotConfigured && (
              <div className='system-xs-medium flex items-center gap-1 text-text-primary'>
                <RiAlertFill className='h-4 w-4 text-text-warning-secondary' />
                {t('common.modelProvider.notConfigured')}
              </div>
            )}
            <SystemModelSelector
              notConfigured={defaultModelNotConfigured}
              textGenerationDefaultModel={textGenerationDefaultModel}
              embeddingsDefaultModel={embeddingsDefaultModel}
              rerankDefaultModel={rerankDefaultModel}
              speech2textDefaultModel={speech2textDefaultModel}
              ttsDefaultModel={ttsDefaultModel}
            />
          </div>
        </div>
        {!filteredConfiguredProviders.length && (
          <div className='mb-2 rounded-[10px] bg-workflow-process-bg p-4'>
            <div className='flex h-10 w-10 items-center justify-center rounded-[10px] border-[0.5px] border-components-card-border bg-components-card-bg shadow-lg backdrop-blur'>
              <RiBrainLine className='h-5 w-5 text-text-primary' />
            </div>
            <div className='system-sm-medium mt-2 text-text-secondary'>{t('common.modelProvider.emptyProviderTitle')}</div>
            <div className='system-xs-regular mt-1 text-text-tertiary'>{t('common.modelProvider.emptyProviderTip')}</div>
          </div>
        )}
        {!!filteredConfiguredProviders.length && (
          <div className='relative'>
            {filteredConfiguredProviders.map(provider => (
              <ProviderAddedCard
                key={provider.provider}
                provider={provider}
              />
            ))}
          </div>
        )}
        {!!filteredNotConfiguredProviders.length && (
          <>
            <div className='system-md-semibold mb-2 flex items-center pt-2 text-text-primary'>{t('common.modelProvider.toBeConfigured')}</div>
            <div className='relative'>
              {filteredNotConfiguredProviders.map(provider => (
                <ProviderAddedCard
                  notConfigured
                  key={provider.provider}
                  provider={provider}
                />
              ))}
            </div>
          </>
        )}
        {
          // NOTE(review): receives the raw (not debounced) searchText, unlike the lists above — confirm this is intended.
          enable_marketplace && (
            <InstallFromMarketplace
              providers={providers}
              searchText={searchText}
            />
          )
        }
      </div>
    )
  }

  export default ModelProviderPage