You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

index.tsx 9.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. import type {
  2. FC,
  3. ReactNode,
  4. } from 'react'
  5. import { useMemo, useState } from 'react'
  6. import { useTranslation } from 'react-i18next'
  7. import type {
  8. DefaultModel,
  9. FormValue,
  10. } from '@/app/components/header/account-setting/model-provider-page/declarations'
  11. import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
  12. import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
  13. import {
  14. useModelList,
  15. } from '@/app/components/header/account-setting/model-provider-page/hooks'
  16. import AgentModelTrigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger'
  17. import Trigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
  18. import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
  19. import {
  20. PortalToFollowElem,
  21. PortalToFollowElemContent,
  22. PortalToFollowElemTrigger,
  23. } from '@/app/components/base/portal-to-follow-elem'
  24. import LLMParamsPanel from './llm-params-panel'
  25. import TTSParamsPanel from './tts-params-panel'
  26. import { useProviderContext } from '@/context/provider-context'
  27. import cn from '@/utils/classnames'
  28. import Toast from '@/app/components/base/toast'
  29. import { fetchAndMergeValidCompletionParams } from '@/utils/completion-params'
  30. export type ModelParameterModalProps = {
  31. popupClassName?: string
  32. portalToFollowElemContentClassName?: string
  33. isAdvancedMode: boolean
  34. value: any
  35. setModel: (model: any) => void
  36. renderTrigger?: (v: TriggerProps) => ReactNode
  37. readonly?: boolean
  38. isInWorkflow?: boolean
  39. isAgentStrategy?: boolean
  40. scope?: string
  41. }
  42. const ModelParameterModal: FC<ModelParameterModalProps> = ({
  43. popupClassName,
  44. portalToFollowElemContentClassName,
  45. isAdvancedMode,
  46. value,
  47. setModel,
  48. renderTrigger,
  49. readonly,
  50. isInWorkflow,
  51. isAgentStrategy,
  52. scope = ModelTypeEnum.textGeneration,
  53. }) => {
  54. const { t } = useTranslation()
  55. const { isAPIKeySet } = useProviderContext()
  56. const [open, setOpen] = useState(false)
  57. const scopeArray = scope.split('&')
  58. const scopeFeatures = useMemo(() => {
  59. if (scopeArray.includes('all'))
  60. return []
  61. return scopeArray.filter(item => ![
  62. ModelTypeEnum.textGeneration,
  63. ModelTypeEnum.textEmbedding,
  64. ModelTypeEnum.rerank,
  65. ModelTypeEnum.moderation,
  66. ModelTypeEnum.speech2text,
  67. ModelTypeEnum.tts,
  68. ].includes(item as ModelTypeEnum))
  69. }, [scopeArray])
  70. const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration)
  71. const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding)
  72. const { data: rerankList } = useModelList(ModelTypeEnum.rerank)
  73. const { data: moderationList } = useModelList(ModelTypeEnum.moderation)
  74. const { data: sttList } = useModelList(ModelTypeEnum.speech2text)
  75. const { data: ttsList } = useModelList(ModelTypeEnum.tts)
  76. const scopedModelList = useMemo(() => {
  77. const resultList: any[] = []
  78. if (scopeArray.includes('all')) {
  79. return [
  80. ...textGenerationList,
  81. ...textEmbeddingList,
  82. ...rerankList,
  83. ...sttList,
  84. ...ttsList,
  85. ...moderationList,
  86. ]
  87. }
  88. if (scopeArray.includes(ModelTypeEnum.textGeneration))
  89. return textGenerationList
  90. if (scopeArray.includes(ModelTypeEnum.textEmbedding))
  91. return textEmbeddingList
  92. if (scopeArray.includes(ModelTypeEnum.rerank))
  93. return rerankList
  94. if (scopeArray.includes(ModelTypeEnum.moderation))
  95. return moderationList
  96. if (scopeArray.includes(ModelTypeEnum.speech2text))
  97. return sttList
  98. if (scopeArray.includes(ModelTypeEnum.tts))
  99. return ttsList
  100. return resultList
  101. }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
  102. const { currentProvider, currentModel } = useMemo(() => {
  103. const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
  104. const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
  105. return {
  106. currentProvider,
  107. currentModel,
  108. }
  109. }, [scopedModelList, value?.provider, value?.model])
  110. const hasDeprecated = useMemo(() => {
  111. return !currentProvider || !currentModel
  112. }, [currentModel, currentProvider])
  113. const modelDisabled = useMemo(() => {
  114. return currentModel?.status !== ModelStatusEnum.active
  115. }, [currentModel?.status])
  116. const disabled = useMemo(() => {
  117. return !isAPIKeySet || hasDeprecated || modelDisabled
  118. }, [hasDeprecated, isAPIKeySet, modelDisabled])
  119. const handleChangeModel = async ({ provider, model }: DefaultModel) => {
  120. const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
  121. const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
  122. const model_type = targetModelItem?.model_type as string
  123. let nextCompletionParams: FormValue = {}
  124. if (model_type === ModelTypeEnum.textGeneration) {
  125. try {
  126. const { params: filtered, removedDetails } = await fetchAndMergeValidCompletionParams(
  127. provider,
  128. model,
  129. value?.completion_params,
  130. isAdvancedMode,
  131. )
  132. nextCompletionParams = filtered
  133. const keys = Object.keys(removedDetails || {})
  134. if (keys.length) {
  135. Toast.notify({
  136. type: 'warning',
  137. message: `${t('common.modelProvider.parametersInvalidRemoved')}: ${keys.map(k => `${k} (${removedDetails[k]})`).join(', ')}`,
  138. })
  139. }
  140. }
  141. catch (e) {
  142. Toast.notify({ type: 'error', message: t('common.error') })
  143. }
  144. }
  145. setModel({
  146. provider,
  147. model,
  148. model_type,
  149. ...(model_type === ModelTypeEnum.textGeneration ? {
  150. mode: targetModelItem?.model_properties.mode as string,
  151. completion_params: nextCompletionParams,
  152. } : {}),
  153. })
  154. }
  155. const handleLLMParamsChange = (newParams: FormValue) => {
  156. const newValue = {
  157. ...value?.completionParams,
  158. completion_params: newParams,
  159. }
  160. setModel({
  161. ...value,
  162. ...newValue,
  163. })
  164. }
  165. const handleTTSParamsChange = (language: string, voice: string) => {
  166. setModel({
  167. ...value,
  168. language,
  169. voice,
  170. })
  171. }
  172. return (
  173. <PortalToFollowElem
  174. open={open}
  175. onOpenChange={setOpen}
  176. placement={isInWorkflow ? 'left' : 'bottom-end'}
  177. offset={4}
  178. >
  179. <div className='relative'>
  180. <PortalToFollowElemTrigger
  181. onClick={() => {
  182. if (readonly)
  183. return
  184. setOpen(v => !v)
  185. }}
  186. className='block'
  187. >
  188. {
  189. renderTrigger
  190. ? renderTrigger({
  191. open,
  192. disabled,
  193. modelDisabled,
  194. hasDeprecated,
  195. currentProvider,
  196. currentModel,
  197. providerName: value?.provider,
  198. modelId: value?.model,
  199. })
  200. : (isAgentStrategy
  201. ? <AgentModelTrigger
  202. disabled={disabled}
  203. hasDeprecated={hasDeprecated}
  204. currentProvider={currentProvider}
  205. currentModel={currentModel}
  206. providerName={value?.provider}
  207. modelId={value?.model}
  208. scope={scope}
  209. />
  210. : <Trigger
  211. disabled={disabled}
  212. isInWorkflow={isInWorkflow}
  213. modelDisabled={modelDisabled}
  214. hasDeprecated={hasDeprecated}
  215. currentProvider={currentProvider}
  216. currentModel={currentModel}
  217. providerName={value?.provider}
  218. modelId={value?.model}
  219. />
  220. )
  221. }
  222. </PortalToFollowElemTrigger>
  223. <PortalToFollowElemContent className={cn('z-50', portalToFollowElemContentClassName)}>
  224. <div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
  225. <div className={cn('max-h-[420px] overflow-y-auto p-4 pt-3')}>
  226. <div className='relative'>
  227. <div className={cn('system-sm-semibold mb-1 flex h-6 items-center text-text-secondary')}>
  228. {t('common.modelProvider.model').toLocaleUpperCase()}
  229. </div>
  230. <ModelSelector
  231. defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
  232. modelList={scopedModelList}
  233. scopeFeatures={scopeFeatures}
  234. onSelect={handleChangeModel}
  235. />
  236. </div>
  237. {(currentModel?.model_type === ModelTypeEnum.textGeneration || currentModel?.model_type === ModelTypeEnum.tts) && (
  238. <div className='my-3 h-px bg-divider-subtle' />
  239. )}
  240. {currentModel?.model_type === ModelTypeEnum.textGeneration && (
  241. <LLMParamsPanel
  242. provider={value?.provider}
  243. modelId={value?.model}
  244. completionParams={value?.completion_params || {}}
  245. onCompletionParamsChange={handleLLMParamsChange}
  246. isAdvancedMode={isAdvancedMode}
  247. />
  248. )}
  249. {currentModel?.model_type === ModelTypeEnum.tts && (
  250. <TTSParamsPanel
  251. currentModel={currentModel}
  252. language={value?.language}
  253. voice={value?.voice}
  254. onChange={handleTTSParamsChange}
  255. />
  256. )}
  257. </div>
  258. </div>
  259. </PortalToFollowElemContent>
  260. </div>
  261. </PortalToFollowElem>
  262. )
  263. }
  264. export default ModelParameterModal