Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: JzoNg <jzongcode@gmail.com>tags/0.3.12
| @@ -292,4 +292,3 @@ api.add_resource(DatasetDocumentSegmentAddApi, | |||
| '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment') | |||
| api.add_resource(DatasetDocumentSegmentUpdateApi, | |||
| '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>') | |||
| @@ -1,123 +0,0 @@ | |||
| import numpy as np | |||
| import sklearn.decomposition | |||
| import pickle | |||
| import time | |||
| # Apply 'Algorithm 1' to the ada-002 embeddings to make them isotropic, taken from the paper: | |||
| # ALL-BUT-THE-TOP: SIMPLE AND EFFECTIVE POST- PROCESSING FOR WORD REPRESENTATIONS | |||
| # Jiaqi Mu, Pramod Viswanath | |||
| # This uses Principal Component Analysis (PCA) to 'evenly distribute' the embedding vectors (make them isotropic) | |||
| # For more information on PCA, see https://jamesmccaffrey.wordpress.com/2021/07/16/computing-pca-using-numpy-without-scikit/ | |||
| # get the file pointer of the pickle containing the embeddings | |||
| fp = open('/path/to/your/data/Embedding-Latest.pkl', 'rb') | |||
| # the embedding data here is a dict consisting of key / value pairs | |||
| # the key is the hash of the message (SHA3-256), the value is the embedding from ada-002 (array of dimension 1536) | |||
| # the hash can be used to lookup the orignal text in a database | |||
| E = pickle.load(fp) # load the data into memory | |||
| # seperate the keys (hashes) and values (embeddings) into seperate vectors | |||
| K = list(E.keys()) # vector of all the hash values | |||
| X = np.array(list(E.values())) # vector of all the embeddings, converted to numpy arrays | |||
| # list the total number of embeddings | |||
| # this can be truncated if there are too many embeddings to do PCA on | |||
| print(f"Total number of embeddings: {len(X)}") | |||
| # get dimension of embeddings, used later | |||
| Dim = len(X[0]) | |||
| # flash out the first few embeddings | |||
| print("First two embeddings are: ") | |||
| print(X[0]) | |||
| print(f"First embedding length: {len(X[0])}") | |||
| print(X[1]) | |||
| print(f"Second embedding length: {len(X[1])}") | |||
| # compute the mean of all the embeddings, and flash the result | |||
| mu = np.mean(X, axis=0) # same as mu in paper | |||
| print(f"Mean embedding vector: {mu}") | |||
| print(f"Mean embedding vector length: {len(mu)}") | |||
| # subtract the mean vector from each embedding vector ... vectorized in numpy | |||
| X_tilde = X - mu # same as v_tilde(w) in paper | |||
| # do the heavy lifting of extracting the principal components | |||
| # note that this is a function of the embeddings you currently have here, and this set may grow over time | |||
| # therefore the PCA basis vectors may change over time, and your final isotropic embeddings may drift over time | |||
| # but the drift should stabilize after you have extracted enough embedding data to characterize the nature of the embedding engine | |||
| print(f"Performing PCA on the normalized embeddings ...") | |||
| pca = sklearn.decomposition.PCA() # new object | |||
| TICK = time.time() # start timer | |||
| pca.fit(X_tilde) # do the heavy lifting! | |||
| TOCK = time.time() # end timer | |||
| DELTA = TOCK - TICK | |||
| print(f"PCA finished in {DELTA} seconds ...") | |||
| # dimensional reduction stage (the only hyperparameter) | |||
| # pick max dimension of PCA components to express embddings | |||
| # in general this is some integer less than or equal to the dimension of your embeddings | |||
| # it could be set as a high percentile, say 95th percentile of pca.explained_variance_ratio_ | |||
| # but just hardcoding a constant here | |||
| D = 15 # hyperparameter on dimension (out of 1536 for ada-002), paper recommeds D = Dim/100 | |||
| # form the set of v_prime(w), which is the final embedding | |||
| # this could be vectorized in numpy to speed it up, but coding it directly here in a double for-loop to avoid errors and to be transparent | |||
| E_prime = dict() # output dict of the new embeddings | |||
| N = len(X_tilde) | |||
| N10 = round(N/10) | |||
| U = pca.components_ # set of PCA basis vectors, sorted by most significant to least significant | |||
| print(f"Shape of full set of PCA componenents {U.shape}") | |||
| U = U[0:D,:] # take the top D dimensions (or take them all if D is the size of the embedding vector) | |||
| print(f"Shape of downselected PCA componenents {U.shape}") | |||
| for ii in range(N): | |||
| v_tilde = X_tilde[ii] | |||
| v = X[ii] | |||
| v_projection = np.zeros(Dim) # start to build the projection | |||
| # project the original embedding onto the PCA basis vectors, use only first D dimensions | |||
| for jj in range(D): | |||
| u_jj = U[jj,:] # vector | |||
| v_jj = np.dot(u_jj,v) # scaler | |||
| v_projection += v_jj*u_jj # vector | |||
| v_prime = v_tilde - v_projection # final embedding vector | |||
| v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector | |||
| E_prime[K[ii]] = v_prime | |||
| if (ii%N10 == 0) or (ii == N-1): | |||
| print(f"Finished with {ii+1} embeddings out of {N} ({round(100*ii/N)}% done)") | |||
| # save as new pickle | |||
| print("Saving new pickle ...") | |||
| embeddingName = '/path/to/your/data/Embedding-Latest-Isotropic.pkl' | |||
| with open(embeddingName, 'wb') as f: # Python 3: open(..., 'wb') | |||
| pickle.dump([E_prime,mu,U], f) | |||
| print(embeddingName) | |||
| print("Done!") | |||
| # When working with live data with a new embedding from ada-002, be sure to tranform it first with this function before comparing it | |||
| # | |||
| def projectEmbedding(v,mu,U): | |||
| v = np.array(v) | |||
| v_tilde = v - mu | |||
| v_projection = np.zeros(len(v)) # start to build the projection | |||
| # project the original embedding onto the PCA basis vectors, use only first D dimensions | |||
| for u in U: | |||
| v_jj = np.dot(u,v) # scaler | |||
| v_projection += v_jj*u # vector | |||
| v_prime = v_tilde - v_projection # final embedding vector | |||
| v_prime = v_prime/np.linalg.norm(v_prime) # create unit vector | |||
| return v_prime | |||
| @@ -7,6 +7,7 @@ import re | |||
| import threading | |||
| import time | |||
| import uuid | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from multiprocessing import Process | |||
| from typing import Optional, List, cast | |||
| @@ -14,7 +15,6 @@ import openai | |||
| from billiard.pool import Pool | |||
| from flask import current_app, Flask | |||
| from flask_login import current_user | |||
| from gevent.threadpool import ThreadPoolExecutor | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from langchain.schema import Document | |||
| from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter | |||
| @@ -516,43 +516,51 @@ class IndexingRunner: | |||
| model_name='gpt-3.5-turbo', | |||
| max_tokens=2000 | |||
| ) | |||
| self.format_document(llm, documents, split_documents, document_form) | |||
| threads = [] | |||
| for doc in documents: | |||
| document_format_thread = threading.Thread(target=self.format_document, kwargs={ | |||
| 'llm': llm, 'document_node': doc, 'split_documents': split_documents, 'document_form': document_form}) | |||
| threads.append(document_format_thread) | |||
| document_format_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| all_documents.extend(split_documents) | |||
| return all_documents | |||
| def format_document(self, llm: StreamableOpenAI, documents: List[Document], split_documents: List, document_form: str): | |||
| for document_node in documents: | |||
| format_documents = [] | |||
| if document_node.page_content is None or not document_node.page_content.strip(): | |||
| return format_documents | |||
| if document_form == 'text_model': | |||
| # text model document | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(document_node.page_content) | |||
| document_node.metadata['doc_id'] = doc_id | |||
| document_node.metadata['doc_hash'] = hash | |||
| format_documents.append(document_node) | |||
| elif document_form == 'qa_model': | |||
| try: | |||
| # qa model document | |||
| response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content) | |||
| document_qa_list = self.format_split_text(response) | |||
| qa_documents = [] | |||
| for result in document_qa_list: | |||
| qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(result['question']) | |||
| qa_document.metadata['answer'] = result['answer'] | |||
| qa_document.metadata['doc_id'] = doc_id | |||
| qa_document.metadata['doc_hash'] = hash | |||
| qa_documents.append(qa_document) | |||
| format_documents.extend(qa_documents) | |||
| except Exception: | |||
| continue | |||
| split_documents.extend(format_documents) | |||
| def format_document(self, llm: StreamableOpenAI, document_node, split_documents: List, document_form: str): | |||
| print(document_node.page_content) | |||
| format_documents = [] | |||
| if document_node.page_content is None or not document_node.page_content.strip(): | |||
| return format_documents | |||
| if document_form == 'text_model': | |||
| # text model document | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(document_node.page_content) | |||
| document_node.metadata['doc_id'] = doc_id | |||
| document_node.metadata['doc_hash'] = hash | |||
| format_documents.append(document_node) | |||
| elif document_form == 'qa_model': | |||
| try: | |||
| # qa model document | |||
| response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content) | |||
| document_qa_list = self.format_split_text(response) | |||
| qa_documents = [] | |||
| for result in document_qa_list: | |||
| qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(result['question']) | |||
| qa_document.metadata['answer'] = result['answer'] | |||
| qa_document.metadata['doc_id'] = doc_id | |||
| qa_document.metadata['doc_hash'] = hash | |||
| qa_documents.append(qa_document) | |||
| format_documents.extend(qa_documents) | |||
| except Exception: | |||
| logging.error("sss") | |||
| split_documents.extend(format_documents) | |||
| def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter, | |||
| processing_rule: DatasetProcessRule) -> List[Document]: | |||
| @@ -7,3 +7,4 @@ from .clean_when_dataset_deleted import handle | |||
| from .update_app_dataset_join_when_app_model_config_updated import handle | |||
| from .generate_conversation_name_when_first_message_created import handle | |||
| from .generate_conversation_summary_when_few_message_created import handle | |||
| from .create_document_index import handle | |||
| @@ -0,0 +1,48 @@ | |||
| from events.dataset_event import dataset_was_deleted | |||
| from events.event_handlers.document_index_event import document_index_created | |||
| from tasks.clean_dataset_task import clean_dataset_task | |||
| import datetime | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.indexing_runner import IndexingRunner, DocumentIsPausedException | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document | |||
| @document_index_created.connect | |||
| def handle(sender, **kwargs): | |||
| dataset_id = sender | |||
| document_ids = kwargs.get('document_ids', None) | |||
| documents = [] | |||
| start_at = time.perf_counter() | |||
| for document_id in document_ids: | |||
| logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) | |||
| document = db.session.query(Document).filter( | |||
| Document.id == document_id, | |||
| Document.dataset_id == dataset_id | |||
| ).first() | |||
| if not document: | |||
| raise NotFound('Document not found') | |||
| document.indexing_status = 'parsing' | |||
| document.processing_started_at = datetime.datetime.utcnow() | |||
| documents.append(document) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| try: | |||
| indexing_runner = IndexingRunner() | |||
| indexing_runner.run(documents) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) | |||
| except DocumentIsPausedException as ex: | |||
| logging.info(click.style(str(ex), fg='yellow')) | |||
| except Exception: | |||
| pass | |||
| @@ -0,0 +1,4 @@ | |||
| from blinker import signal | |||
| # sender: document | |||
| document_index_created = signal('document-index-created') | |||
| @@ -10,6 +10,7 @@ from flask import current_app | |||
| from sqlalchemy import func | |||
| from core.llm.token_calculator import TokenCalculator | |||
| from events.event_handlers.document_index_event import document_index_created | |||
| from extensions.ext_redis import redis_client | |||
| from flask_login import current_user | |||
| @@ -520,6 +521,7 @@ class DocumentService: | |||
| db.session.commit() | |||
| # trigger async task | |||
| #document_index_created.send(dataset.id, document_ids=document_ids) | |||
| document_indexing_task.delay(dataset.id, document_ids) | |||
| return documents, batch | |||
| @@ -1,24 +0,0 @@ | |||
| import logging | |||
| import time | |||
| import click | |||
| import requests | |||
| from celery import shared_task | |||
| from core.generator.llm_generator import LLMGenerator | |||
| @shared_task | |||
| def generate_test_task(): | |||
| logging.info(click.style('Start generate test', fg='green')) | |||
| start_at = time.perf_counter() | |||
| try: | |||
| #res = requests.post('https://api.openai.com/v1/chat/completions') | |||
| answer = LLMGenerator.generate_conversation_name('84b2202c-c359-46b7-a810-bce50feaa4d1', 'avb', 'ccc') | |||
| print(f'answer: {answer}') | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style('Conversation test, latency: {}'.format(end_at - start_at), fg='green')) | |||
| except Exception: | |||
| logging.exception("generate test failed") | |||