Selaa lähdekoodia

Fix/remove tsne position test (#5858)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
tags/0.6.13
Jyong 1 vuosi sitten
vanhempi
commit
0944ca9d91
No account linked to committer's email address

+ 2
- 29
api/services/hit_testing_service.py Näytä tiedosto

import numpy as np import numpy as np
from sklearn.manifold import TSNE from sklearn.manifold import TSNE


from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.retrieval.retrival_methods import RetrievalMethod from core.rag.retrieval.retrival_methods import RetrievalMethod
if not retrieval_model: if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model


# get embedding model
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
provider=dataset.embedding_model_provider,
model=dataset.embedding_model
)

embeddings = CacheEmbedding(embedding_model)

all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset_id=dataset.id, dataset_id=dataset.id,
query=query, query=query,
db.session.add(dataset_query) db.session.add(dataset_query)
db.session.commit() db.session.commit()


return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
return cls.compact_retrieve_response(dataset, query, all_documents)


@classmethod @classmethod
def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: list[Document]):
text_embeddings = [
embeddings.embed_query(query)
]

text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))

tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)

query_position = tsne_position_data.pop(0)

def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
i = 0 i = 0
records = [] records = []
for document in documents: for document in documents:
record = { record = {
"segment": segment, "segment": segment,
"score": document.metadata.get('score', None), "score": document.metadata.get('score', None),
"tsne_position": tsne_position_data[i]
} }


records.append(record) records.append(record)
return { return {
"query": { "query": {
"content": query, "content": query,
"tsne_position": query_position,
}, },
"records": records "records": records
} }

+ 3
- 52
web/app/components/datasets/hit-testing/hit-detail.tsx Näytä tiedosto

import React from 'react' import React from 'react'
import cn from 'classnames' import cn from 'classnames'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import ReactECharts from 'echarts-for-react'
import { SegmentIndexTag } from '../documents/detail/completed' import { SegmentIndexTag } from '../documents/detail/completed'
import s from '../documents/detail/completed/style.module.css' import s from '../documents/detail/completed/style.module.css'
import type { SegmentDetailModel } from '@/models/datasets' import type { SegmentDetailModel } from '@/models/datasets'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'


/**
 * Props accepted by `ScatterChart`.
 */
type IScatterChartProps = {
  /** Points handed to the plain `scatter` series; each entry is an array of numbers. */
  data: number[][]
  /** Points handed to the `effectScatter` series; each entry is an array of numbers. */
  curr: number[][]
}

/**
 * Renders a `ReactECharts` chart with two series: `curr` as an
 * `effectScatter` series followed by `data` as a plain `scatter` series.
 *
 * Both axes use default (empty) configuration, the tooltip is triggered per
 * item with a cross-shaped axis pointer, and the chart is given a fixed
 * 380 (height) by 430 (width) inline style.
 */
const ScatterChart: FC<IScatterChartProps> = ({ data, curr }) => {
  // Series order matters to the rendered option: `curr` first, then `data`.
  const currentSeries = { type: 'effectScatter', symbolSize: 5, data: curr }
  const pointsSeries = { type: 'scatter', symbolSize: 5, data }

  const option = {
    xAxis: {},
    yAxis: {},
    tooltip: { trigger: 'item', axisPointer: { type: 'cross' } },
    series: [currentSeries, pointsSeries],
  }

  return <ReactECharts option={option} style={{ height: 380, width: 430 }} />
}

type IHitDetailProps = { type IHitDetailProps = {
segInfo?: Partial<SegmentDetailModel> & { id: string } segInfo?: Partial<SegmentDetailModel> & { id: string }
vectorInfo?: { curr: Array<number[]>; points: Array<number[]> }
} }


const HitDetail: FC<IHitDetailProps> = ({ segInfo, vectorInfo }) => {
const HitDetail: FC<IHitDetailProps> = ({ segInfo }) => {
const { t } = useTranslation() const { t } = useTranslation()


const renderContent = () => { const renderContent = () => {
} }


return ( return (
<div className='flex flex-row overflow-x-auto'>
<div className="flex-1 bg-gray-25 p-6 min-w-[300px]">
<div className='overflow-x-auto'>
<div className="bg-gray-25 p-6">
<div className="flex items-center"> <div className="flex items-center">
<SegmentIndexTag <SegmentIndexTag
positionId={segInfo?.position || ''} positionId={segInfo?.position || ''}
})} })}
</div> </div>
</div> </div>
<div className="flex-1 bg-white p-6">
<div className="flex items-center">
<div className={cn(s.commonIcon, s.bezierCurveIcon)} />
<span className={s.numberInfo}>
{t('datasetDocuments.segment.vectorHash')}
</span>
</div>
<div
className={cn(s.numberInfo, 'w-[400px] truncate text-gray-700 mt-1')}
>
{segInfo?.index_node_hash}
</div>
<ScatterChart data={vectorInfo?.points || []} curr={vectorInfo?.curr || []} />
</div>
</div> </div>
) )
} }

+ 2
- 8
web/app/components/datasets/hit-testing/index.tsx Näytä tiedosto

'use client' 'use client'
import type { FC } from 'react' import type { FC } from 'react'
import React, { useEffect, useMemo, useState } from 'react'
import React, { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import useSWR from 'swr' import useSWR from 'swr'
import { omit } from 'lodash-es' import { omit } from 'lodash-es'


const total = recordsRes?.total || 0 const total = recordsRes?.total || 0


const points = useMemo(() => (hitResult?.records.map(v => [v.tsne_position.x, v.tsne_position.y]) || []), [hitResult?.records])

const onClickCard = (detail: HitTestingType) => { const onClickCard = (detail: HitTestingType) => {
setCurrParagraph({ paraInfo: detail, showModal: true }) setCurrParagraph({ paraInfo: detail, showModal: true })
} }
</div> </div>
</FloatRightContainer> </FloatRightContainer>
<Modal <Modal
className='!max-w-[960px] !p-0'
className='w-[520px] p-0'
closable closable
onClose={() => setCurrParagraph({ showModal: false })} onClose={() => setCurrParagraph({ showModal: false })}
isShow={currParagraph.showModal} isShow={currParagraph.showModal}
> >
{currParagraph.showModal && <HitDetail {currParagraph.showModal && <HitDetail
segInfo={currParagraph.paraInfo?.segment} segInfo={currParagraph.paraInfo?.segment}
vectorInfo={{
curr: [[currParagraph.paraInfo?.tsne_position?.x || 0, currParagraph.paraInfo?.tsne_position.y || 0]],
points,
}}
/>} />}
</Modal> </Modal>
<Drawer isOpen={isShowModifyRetrievalModal} onClose={() => setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'> <Drawer isOpen={isShowModifyRetrievalModal} onClose={() => setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'>

Loading…
Peruuta
Tallenna