Co-authored-by: StyleZhang <jasonapring2015@outlook.com>tags/0.6.13
| import numpy as np | import numpy as np | ||||
| from sklearn.manifold import TSNE | from sklearn.manifold import TSNE | ||||
| from core.embedding.cached_embedding import CacheEmbedding | |||||
| from core.model_manager import ModelManager | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from core.rag.retrieval.retrival_methods import RetrievalMethod | from core.rag.retrieval.retrival_methods import RetrievalMethod | ||||
| if not retrieval_model: | if not retrieval_model: | ||||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | ||||
| # get embedding model | |||||
| model_manager = ModelManager() | |||||
| embedding_model = model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| provider=dataset.embedding_model_provider, | |||||
| model=dataset.embedding_model | |||||
| ) | |||||
| embeddings = CacheEmbedding(embedding_model) | |||||
| all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | ||||
| dataset_id=dataset.id, | dataset_id=dataset.id, | ||||
| query=query, | query=query, | ||||
| db.session.add(dataset_query) | db.session.add(dataset_query) | ||||
| db.session.commit() | db.session.commit() | ||||
| return cls.compact_retrieve_response(dataset, embeddings, query, all_documents) | |||||
| return cls.compact_retrieve_response(dataset, query, all_documents) | |||||
| @classmethod | @classmethod | ||||
| def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]): | |||||
| text_embeddings = [ | |||||
| embeddings.embed_query(query) | |||||
| ] | |||||
| text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents])) | |||||
| tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings) | |||||
| query_position = tsne_position_data.pop(0) | |||||
| def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): | |||||
| i = 0 | i = 0 | ||||
| records = [] | records = [] | ||||
| for document in documents: | for document in documents: | ||||
| record = { | record = { | ||||
| "segment": segment, | "segment": segment, | ||||
| "score": document.metadata.get('score', None), | "score": document.metadata.get('score', None), | ||||
| "tsne_position": tsne_position_data[i] | |||||
| } | } | ||||
| records.append(record) | records.append(record) | ||||
| return { | return { | ||||
| "query": { | "query": { | ||||
| "content": query, | "content": query, | ||||
| "tsne_position": query_position, | |||||
| }, | }, | ||||
| "records": records | "records": records | ||||
| } | } | 
| import React from 'react' | import React from 'react' | ||||
| import cn from 'classnames' | import cn from 'classnames' | ||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import ReactECharts from 'echarts-for-react' | |||||
| import { SegmentIndexTag } from '../documents/detail/completed' | import { SegmentIndexTag } from '../documents/detail/completed' | ||||
| import s from '../documents/detail/completed/style.module.css' | import s from '../documents/detail/completed/style.module.css' | ||||
| import type { SegmentDetailModel } from '@/models/datasets' | import type { SegmentDetailModel } from '@/models/datasets' | ||||
| import Divider from '@/app/components/base/divider' | import Divider from '@/app/components/base/divider' | ||||
| type IScatterChartProps = { | |||||
| data: Array<number[]> | |||||
| curr: Array<number[]> | |||||
| } | |||||
| const ScatterChart: FC<IScatterChartProps> = ({ data, curr }) => { | |||||
| const option = { | |||||
| xAxis: {}, | |||||
| yAxis: {}, | |||||
| tooltip: { | |||||
| trigger: 'item', | |||||
| axisPointer: { | |||||
| type: 'cross', | |||||
| }, | |||||
| }, | |||||
| series: [ | |||||
| { | |||||
| type: 'effectScatter', | |||||
| symbolSize: 5, | |||||
| data: curr, | |||||
| }, | |||||
| { | |||||
| type: 'scatter', | |||||
| symbolSize: 5, | |||||
| data, | |||||
| }, | |||||
| ], | |||||
| } | |||||
| return ( | |||||
| <ReactECharts option={option} style={{ height: 380, width: 430 }} /> | |||||
| ) | |||||
| } | |||||
| type IHitDetailProps = { | type IHitDetailProps = { | ||||
| segInfo?: Partial<SegmentDetailModel> & { id: string } | segInfo?: Partial<SegmentDetailModel> & { id: string } | ||||
| vectorInfo?: { curr: Array<number[]>; points: Array<number[]> } | |||||
| } | } | ||||
| const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => { | |||||
| const HitDetail: FC<IHitDetailProps> = ({ segInfo }) => { | |||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const renderContent = () => { | const renderContent = () => { | ||||
| } | } | ||||
| return ( | return ( | ||||
| <div className='flex flex-row overflow-x-auto'> | |||||
| <div className="flex-1 bg-gray-25 p-6 min-w-[300px]"> | |||||
| <div className='overflow-x-auto'> | |||||
| <div className="bg-gray-25 p-6"> | |||||
| <div className="flex items-center"> | <div className="flex items-center"> | ||||
| <SegmentIndexTag | <SegmentIndexTag | ||||
| positionId={segInfo?.position || ''} | positionId={segInfo?.position || ''} | ||||
| })} | })} | ||||
| </div> | </div> | ||||
| </div> | </div> | ||||
| <div className="flex-1 bg-white p-6"> | |||||
| <div className="flex items-center"> | |||||
| <div className={cn(s.commonIcon, s.bezierCurveIcon)} /> | |||||
| <span className={s.numberInfo}> | |||||
| {t('datasetDocuments.segment.vectorHash')} | |||||
| </span> | |||||
| </div> | |||||
| <div | |||||
| className={cn(s.numberInfo, 'w-[400px] truncate text-gray-700 mt-1')} | |||||
| > | |||||
| {segInfo?.index_node_hash} | |||||
| </div> | |||||
| <ScatterChart data={vectorInfo?.points || []} curr={vectorInfo?.curr || []} /> | |||||
| </div> | |||||
| </div> | </div> | ||||
| ) | ) | ||||
| } | } | 
| 'use client' | 'use client' | ||||
| import type { FC } from 'react' | import type { FC } from 'react' | ||||
| import React, { useEffect, useMemo, useState } from 'react' | |||||
| import React, { useEffect, useState } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import useSWR from 'swr' | import useSWR from 'swr' | ||||
| import { omit } from 'lodash-es' | import { omit } from 'lodash-es' | ||||
| const total = recordsRes?.total || 0 | const total = recordsRes?.total || 0 | ||||
| const points = useMemo(() => (hitResult?.records.map(v => [v.tsne_position.x, v.tsne_position.y]) || []), [hitResult?.records]) | |||||
| const onClickCard = (detail: HitTestingType) => { | const onClickCard = (detail: HitTestingType) => { | ||||
| setCurrParagraph({ paraInfo: detail, showModal: true }) | setCurrParagraph({ paraInfo: detail, showModal: true }) | ||||
| } | } | ||||
| </div> | </div> | ||||
| </FloatRightContainer> | </FloatRightContainer> | ||||
| <Modal | <Modal | ||||
| className='!max-w-[960px] !p-0' | |||||
| className='w-[520px] p-0' | |||||
| closable | closable | ||||
| onClose={() => setCurrParagraph({ showModal: false })} | onClose={() => setCurrParagraph({ showModal: false })} | ||||
| isShow={currParagraph.showModal} | isShow={currParagraph.showModal} | ||||
| > | > | ||||
| {currParagraph.showModal && <HitDetail | {currParagraph.showModal && <HitDetail | ||||
| segInfo={currParagraph.paraInfo?.segment} | segInfo={currParagraph.paraInfo?.segment} | ||||
| vectorInfo={{ | |||||
| curr: [[currParagraph.paraInfo?.tsne_position?.x || 0, currParagraph.paraInfo?.tsne_position.y || 0]], | |||||
| points, | |||||
| }} | |||||
| />} | />} | ||||
| </Modal> | </Modal> | ||||
| <Drawer isOpen={isShowModifyRetrievalModal} onClose={() => setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'> | <Drawer isOpen={isShowModifyRetrievalModal} onClose={() => setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'> |