您最多可以选择 25 个主题。主题必须以字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import {
  2. uniq,
  3. xorBy,
  4. } from 'lodash-es'
  5. import type { MultipleRetrievalConfig } from './types'
  6. import type {
  7. DataSet,
  8. SelectedDatasetsMode,
  9. } from '@/models/datasets'
  10. import {
  11. DEFAULT_WEIGHTED_SCORE,
  12. RerankingModeEnum,
  13. } from '@/models/datasets'
  14. import { RETRIEVE_METHOD } from '@/types/app'
  15. import { DATASET_DEFAULT } from '@/config'
  /**
   * Validity check that takes no input and unconditionally returns `true`.
   *
   * @returns Always `true`.
   */
  export const checkNodeValid = () => true
  /**
   * Summarises a list of datasets as a set of boolean flags.
   *
   * Every "all…"/"mixture…" flag starts out `true` for a non-empty list and is
   * knocked down while walking the datasets; for an empty (or `null`/omitted)
   * list every flag is `false`.
   *
   * - A dataset whose `indexing_technique` is `'economy'` clears the three
   *   high-quality flags.
   * - A dataset whose `indexing_technique` is `'high_quality'` clears
   *   `allEconomic`, and clears the vector / full-text flag when its
   *   `retrieval_model_dict.search_method` is not the matching method.
   * - A dataset whose `provider` is `'external'` clears `allInternal`, the
   *   high-quality flags and `mixtureHighQualityAndEconomic`; any other
   *   provider clears `allExternal`.
   * - `inconsistentEmbeddingModel` is only evaluated when `allHighQuality`
   *   holds, and is `true` when more than one distinct `embedding_model`
   *   value is present.
   *
   * @param datasets Datasets to classify; `null` is treated like an empty list.
   * @returns The computed flags, typed as `SelectedDatasetsMode`.
   */
  export const getSelectedDatasetsMode = (datasets: DataSet[] = []) => {
    // The default parameter only covers `undefined`; an explicit `null` is normalised here.
    const items: DataSet[] = datasets === null ? [] : datasets
    const hasItems = items.length > 0

    const mode = {
      allHighQuality: hasItems,
      allHighQualityVectorSearch: hasItems,
      allHighQualityFullTextSearch: hasItems,
      allEconomic: hasItems,
      mixtureHighQualityAndEconomic: hasItems,
      allInternal: hasItems,
      allExternal: hasItems,
      mixtureInternalAndExternal: hasItems,
      inconsistentEmbeddingModel: false,
    }

    // Economy and external datasets both rule out every "all high quality" variant.
    const clearHighQualityFlags = () => {
      mode.allHighQuality = false
      mode.allHighQualityVectorSearch = false
      mode.allHighQualityFullTextSearch = false
    }

    for (const dataset of items) {
      const technique = dataset.indexing_technique

      if (technique === 'economy') {
        clearHighQualityFlags()
      }
      else if (technique === 'high_quality') {
        mode.allEconomic = false
        const searchMethod = dataset.retrieval_model_dict.search_method
        if (searchMethod !== RETRIEVE_METHOD.semantic)
          mode.allHighQualityVectorSearch = false
        if (searchMethod !== RETRIEVE_METHOD.fullText)
          mode.allHighQualityFullTextSearch = false
      }

      if (dataset.provider === 'external') {
        mode.allInternal = false
        clearHighQualityFlags()
        mode.mixtureHighQualityAndEconomic = false
      }
      else {
        mode.allExternal = false
      }
    }

    // A "mixture" flag survives only when neither of its homogeneous counterparts holds.
    if (mode.allExternal || mode.allInternal)
      mode.mixtureInternalAndExternal = false
    if (mode.allHighQuality || mode.allEconomic)
      mode.mixtureHighQualityAndEconomic = false

    if (mode.allHighQuality) {
      const embeddingModels = items.map(item => item.embedding_model)
      mode.inconsistentEmbeddingModel = uniq(embeddingModels).length > 1
    }

    return mode as SelectedDatasetsMode
  }
  /**
   * Builds the retrieval configuration to use for the currently selected
   * datasets, starting from the existing configuration.
   *
   * `top_k` falls back to `DATASET_DEFAULT.top_k`; the remaining fields are
   * copied from `multipleRetrievalConfig` and then adjusted according to the
   * flags returned by `getSelectedDatasetsMode(selectedDatasets)`:
   *
   * - When the selection is all-economic, a high-quality/economic mixture,
   *   uses more than one embedding model, is all-external, or mixes internal
   *   and external datasets, the mode is forced to
   *   `RerankingModeEnum.RerankingModel`. An incomplete `reranking_model` is
   *   replaced by `validRerankModel` when that one is complete, otherwise by
   *   an empty provider/model pair.
   * - When the selection is all high quality, internal and uses a single
   *   embedding model, the mode/weights are filled in or switched depending
   *   on the configured mode, whether weights exist, whether a complete
   *   `validRerankModel` is available and whether the selection changed.
   *
   * The two cases are evaluated one after the other (not as alternatives).
   *
   * @param multipleRetrievalConfig Existing configuration; a falsy value is
   *   treated as a configuration containing only the default `top_k`.
   * @param selectedDatasets Datasets currently selected.
   * @param originalDatasets Previously selected datasets; compared with
   *   `selectedDatasets` by `id` to detect a changed selection.
   * @param validRerankModel Candidate rerank model; used only when both
   *   `provider` and `model` are truthy.
   * @returns A new configuration object; the input configuration is not mutated.
   */
  export const getMultipleRetrievalConfig = (
    multipleRetrievalConfig: MultipleRetrievalConfig,
    selectedDatasets: DataSet[],
    originalDatasets: DataSet[],
    validRerankModel?: { provider?: string; model?: string },
  ) => {
    // Any dataset present on only one side (by `id`) means the selection changed.
    const datasetSelectionChanged = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
    const hasValidRerankModel = validRerankModel?.provider && validRerankModel?.model
    const datasetsMode = getSelectedDatasetsMode(selectedDatasets)

    const {
      top_k = DATASET_DEFAULT.top_k,
      score_threshold,
      reranking_mode,
      reranking_model,
      weights,
      reranking_enable,
    } = multipleRetrievalConfig || { top_k: DATASET_DEFAULT.top_k }

    // The configured switch is kept only for all-internal-economic or all-external
    // selections; otherwise it initially mirrors "selection changed".
    const keepsConfiguredSwitch = (datasetsMode.allInternal && datasetsMode.allEconomic) || datasetsMode.allExternal
    // Reranking counts as enabled unless it was explicitly set to `false`.
    const enabledUnlessTurnedOff = reranking_enable !== false

    const result = {
      top_k,
      score_threshold,
      reranking_mode,
      reranking_model,
      weights,
      reranking_enable: keepsConfiguredSwitch ? reranking_enable : datasetSelectionChanged,
    }

    // Only invoked when `hasValidRerankModel` is truthy, so the `|| ''` fallbacks never fire.
    const adoptValidRerankModel = () => {
      result.reranking_mode = RerankingModeEnum.RerankingModel
      result.reranking_enable = enabledUnlessTurnedOff
      result.reranking_model = {
        provider: validRerankModel?.provider || '',
        model: validRerankModel?.model || '',
      }
    }

    // Writes default weights, using the first selected dataset's embedding model.
    const applyDefaultWeights = () => {
      const preset = datasetsMode.allHighQualityVectorSearch
        ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch
        : datasetsMode.allHighQualityFullTextSearch
          ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch
          : DEFAULT_WEIGHTED_SCORE.other

      result.weights = {
        vector_setting: {
          vector_weight: preset.semantic,
          embedding_provider_name: selectedDatasets[0].embedding_model_provider,
          embedding_model_name: selectedDatasets[0].embedding_model,
        },
        keyword_setting: {
          keyword_weight: preset.keyword,
        },
      }
    }

    const requiresRerankModel = datasetsMode.allEconomic
      || datasetsMode.mixtureHighQualityAndEconomic
      || datasetsMode.inconsistentEmbeddingModel
      || datasetsMode.allExternal
      || datasetsMode.mixtureInternalAndExternal

    if (requiresRerankModel) {
      result.reranking_mode = RerankingModeEnum.RerankingModel

      const modelAlreadyConfigured = result.reranking_model?.provider && result.reranking_model?.model
      if (modelAlreadyConfigured) {
        result.reranking_enable = enabledUnlessTurnedOff
      }
      else if (hasValidRerankModel) {
        adoptValidRerankModel()
      }
      else {
        result.reranking_model = {
          provider: '',
          model: '',
        }
      }
    }

    // Deliberately a second, independent `if`: both conditions can hold for the same selection.
    if (datasetsMode.allHighQuality && !datasetsMode.inconsistentEmbeddingModel && datasetsMode.allInternal) {
      // All checks below look at the *incoming* mode/weights, not at `result`.
      if (!reranking_mode) {
        if (hasValidRerankModel) {
          adoptValidRerankModel()
        }
        else {
          result.reranking_mode = RerankingModeEnum.WeightedScore
          applyDefaultWeights()
        }
      }

      if (reranking_mode === RerankingModeEnum.WeightedScore) {
        if (!weights) {
          applyDefaultWeights()
        }
        else if (datasetSelectionChanged) {
          if (hasValidRerankModel)
            adoptValidRerankModel()
          else
            applyDefaultWeights()
        }
      }

      if (reranking_mode === RerankingModeEnum.RerankingModel && !hasValidRerankModel && datasetSelectionChanged) {
        result.reranking_mode = RerankingModeEnum.WeightedScore
        applyDefaultWeights()
      }
    }

    return result
  }
  /**
   * Checks whether the rerank-model part of a retrieval configuration is
   * acceptable for the given datasets.
   *
   * The result is `true` when:
   * - no configuration is supplied, or
   * - the mode is not `RerankingModeEnum.RerankingModel`, or
   * - the rerank model has both a provider and a model, or
   * - the model is incomplete but reranking is not enabled and the datasets
   *   are all economic or all external.
   *
   * Otherwise it is `false`.
   *
   * @param datasets Datasets the configuration applies to.
   * @param multipleRetrievalConfig Configuration to check; may be omitted.
   * @returns `true` when the configuration passes the check, `false` otherwise.
   */
  export const checkoutRerankModelConfigedInRetrievalSettings = (
    datasets: DataSet[],
    multipleRetrievalConfig?: MultipleRetrievalConfig,
  ) => {
    if (!multipleRetrievalConfig)
      return true

    const datasetsMode = getSelectedDatasetsMode(datasets)
    const {
      reranking_enable: rerankEnabled,
      reranking_mode: rerankMode,
      reranking_model: rerankModel,
    } = multipleRetrievalConfig

    const usesRerankModel = rerankMode === RerankingModeEnum.RerankingModel
    const modelIsComplete = Boolean(rerankModel?.provider && rerankModel?.model)
    if (!usesRerankModel || modelIsComplete)
      return true

    // An incomplete model is tolerated only while reranking is off for
    // all-economic or all-external selections.
    if ((datasetsMode.allEconomic || datasetsMode.allExternal) && !rerankEnabled)
      return true

    return false
  }
  215. }
  216. return true
  217. }