Co-authored-by: -LAN- <laipz8200@outlook.com>
@@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline:
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
-        elif isinstance(e, InvokeError) or isinstance(e, ValueError):
+        elif isinstance(e, InvokeError | ValueError):
             err = e
         else:
             err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
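For readers unfamiliar with the `X | Y` form above: since Python 3.10 (PEP 604), a union type is accepted directly as the second argument to isinstance(), so the two separate checks collapse into one. A minimal, self-contained sketch (the InvokeError class below is a stand-in, not the real Dify exception):

    class InvokeError(Exception):
        """Placeholder for the provider-runtime InvokeError."""

    def classify(e: Exception) -> Exception:
        # isinstance(x, A | B) is equivalent to isinstance(x, (A, B)) on 3.10+
        if isinstance(e, InvokeError | ValueError):
            return e
        return Exception(str(e))

    assert isinstance(classify(ValueError("boom")), ValueError)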
@@ -45,7 +45,7 @@ class BaichuanModel:
         parameters: dict[str, Any],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> dict[str, Any]:
-        if model in self._model_mapping.keys():
+        if model in self._model_mapping:
             # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
             # we need to rename it to res_format to get its value
             if parameters.get("res_format") == "json_object":
@@ -94,7 +94,7 @@ class BaichuanModel:
         timeout: int,
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> Union[Iterator, dict]:
-        if model in self._model_mapping.keys():
+        if model in self._model_mapping:
             api_base = "https://api.baichuan-ai.com/v1/chat/completions"
         else:
             raise BadRequestError(f"Unknown model: {model}")
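Membership tests on a dict already consult its keys, so `model in self._model_mapping` is equivalent to the `.keys()` form while skipping construction of a dict_keys view (this is ruff's SIM118, and it recurs throughout this commit). A quick check with made-up entries:

    mapping = {"baichuan2-turbo": "Baichuan2-Turbo"}  # illustrative values only
    assert ("baichuan2-turbo" in mapping) == ("baichuan2-turbo" in mapping.keys())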
@@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = f"{human_prompt} {content}"
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -442,9 +442,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = f"{human_prompt} {content}"
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -350,9 +350,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
                 break
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = content
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = content
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -633,9 +633,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = f"{human_prompt} {content}"
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = f"{human_prompt} {content}"
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -272,11 +272,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         text = ""
         for item in message:
-            if isinstance(item, UserPromptMessage):
-                text += item.content
-            elif isinstance(item, SystemPromptMessage):
-                text += item.content
-            elif isinstance(item, AssistantPromptMessage):
+            if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage):
                 text += item.content
             else:
                 raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
@@ -209,9 +209,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             ):
                 new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
             else:
-                if copy_prompt_message.role == PromptMessageRole.USER:
-                    new_prompt_messages.append(copy_prompt_message)
-                elif copy_prompt_message.role == PromptMessageRole.TOOL:
+                if (
+                    copy_prompt_message.role == PromptMessageRole.USER
+                    or copy_prompt_message.role == PromptMessageRole.TOOL
+                ):
                     new_prompt_messages.append(copy_prompt_message)
                 elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                     new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
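An equivalent, often tidier spelling for the multi-way role test is a membership check against a tuple of enum members; a sketch with a stand-in enum (the commit itself keeps the explicit `or` form):

    from enum import Enum

    class PromptMessageRole(Enum):  # stand-in for the real enum
        USER = "user"
        TOOL = "tool"
        SYSTEM = "system"

    role = PromptMessageRole.TOOL
    # One membership test instead of `role == USER or role == TOOL`:
    assert role in (PromptMessageRole.USER, PromptMessageRole.TOOL)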
@@ -461,9 +462,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             message_text = f"{human_prompt} {content}"
         elif isinstance(message, AssistantPromptMessage):
             message_text = f"{ai_prompt} {content}"
-        elif isinstance(message, SystemPromptMessage):
-            message_text = content
-        elif isinstance(message, ToolPromptMessage):
+        elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
             message_text = content
         else:
             raise ValueError(f"Got unknown type {message}")
@@ -56,14 +56,7 @@ class KeywordsModeration(Moderation):
         )

     def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
-        for value in inputs.values():
-            if self._check_keywords_in_value(keywords_list, value):
-                return True
-        return False
+        return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())

-    def _check_keywords_in_value(self, keywords_list, value):
-        for keyword in keywords_list:
-            if keyword.lower() in value.lower():
-                return True
-        return False
+    def _check_keywords_in_value(self, keywords_list, value) -> bool:
+        return any(keyword.lower() in value.lower() for keyword in keywords_list)
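any() over a generator preserves the replaced loops' semantics exactly: it short-circuits at the first truthy result and returns False on an empty input. A minimal check with made-up data:

    keywords = ["spam", "scam"]
    values = ["hello there", "Free SCAM offer"]
    assert any(k in v.lower() for v in values for k in keywords)
    assert not any(k in "innocuous" for k in keywords)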
@@ -223,7 +223,7 @@ class OpsTraceManager:
         :return:
         """
         # auth check
-        if tracing_provider not in provider_config_map.keys() and tracing_provider is not None:
+        if tracing_provider not in provider_config_map and tracing_provider is not None:
             raise ValueError(f"Invalid tracing provider: {tracing_provider}")

         app_config: App = db.session.query(App).filter(App.id == app_id).first()
@@ -127,27 +127,26 @@ class RelytVector(BaseVector):
         )
         chunks_table_data = []
-        with self.client.connect() as conn:
-            with conn.begin():
-                for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
-                    chunks_table_data.append(
-                        {
-                            "id": chunk_id,
-                            "embedding": embedding,
-                            "document": document,
-                            "metadata": metadata,
-                        }
-                    )
-                    # Execute the batch insert when the batch size is reached
-                    if len(chunks_table_data) == 500:
-                        conn.execute(insert(chunks_table).values(chunks_table_data))
-                        # Clear the chunks_table_data list for the next batch
-                        chunks_table_data.clear()
-                # Insert any remaining records that didn't make up a full batch
-                if chunks_table_data:
-                    conn.execute(insert(chunks_table).values(chunks_table_data))
+        with self.client.connect() as conn, conn.begin():
+            for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
+                chunks_table_data.append(
+                    {
+                        "id": chunk_id,
+                        "embedding": embedding,
+                        "document": document,
+                        "metadata": metadata,
+                    }
+                )
+                # Execute the batch insert when the batch size is reached
+                if len(chunks_table_data) == 500:
+                    conn.execute(insert(chunks_table).values(chunks_table_data))
+                    # Clear the chunks_table_data list for the next batch
+                    chunks_table_data.clear()
+            # Insert any remaining records that didn't make up a full batch
+            if chunks_table_data:
+                conn.execute(insert(chunks_table).values(chunks_table_data))

         return ids
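A single `with` statement can hold several context managers separated by commas; on exit they unwind in reverse order, exactly like the nested form it replaces (this is the rewrite ruff's SIM117 targets, applied by hand here). A runnable sketch with stand-in context managers, not SQLAlchemy objects:

    from contextlib import contextmanager

    @contextmanager
    def managed(name):
        print("enter", name)
        yield name
        print("exit", name)

    with managed("conn") as conn, managed("txn"):
        print("working inside", conn)
    # prints: enter conn, enter txn, working inside conn, exit txn, exit conn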
@@ -186,11 +185,10 @@ class RelytVector(BaseVector):
         )
         try:
-            with self.client.connect() as conn:
-                with conn.begin():
-                    delete_condition = chunks_table.c.id.in_(ids)
-                    conn.execute(chunks_table.delete().where(delete_condition))
-                    return True
+            with self.client.connect() as conn, conn.begin():
+                delete_condition = chunks_table.c.id.in_(ids)
+                conn.execute(chunks_table.delete().where(delete_condition))
+                return True
         except Exception as e:
             print("Delete operation failed:", str(e))
             return False
@@ -63,10 +63,7 @@ class TencentVector(BaseVector):
     def _has_collection(self) -> bool:
         collections = self._db.list_collections()
-        for collection in collections:
-            if collection.collection_name == self._collection_name:
-                return True
-        return False
+        return any(collection.collection_name == self._collection_name for collection in collections)

     def _create_collection(self, dimension: int) -> None:
         lock_name = "vector_indexing_lock_{}".format(self._collection_name)
@@ -124,20 +124,19 @@ class TiDBVector(BaseVector):
         texts = [d.page_content for d in documents]
         chunks_table_data = []
-        with self._engine.connect() as conn:
-            with conn.begin():
-                for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
-                    chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
-                    # Execute the batch insert when the batch size is reached
-                    if len(chunks_table_data) == 500:
-                        conn.execute(insert(table).values(chunks_table_data))
-                        # Clear the chunks_table_data list for the next batch
-                        chunks_table_data.clear()
-                # Insert any remaining records that didn't make up a full batch
-                if chunks_table_data:
-                    conn.execute(insert(table).values(chunks_table_data))
+        with self._engine.connect() as conn, conn.begin():
+            for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
+                chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
+                # Execute the batch insert when the batch size is reached
+                if len(chunks_table_data) == 500:
+                    conn.execute(insert(table).values(chunks_table_data))
+                    # Clear the chunks_table_data list for the next batch
+                    chunks_table_data.clear()
+            # Insert any remaining records that didn't make up a full batch
+            if chunks_table_data:
+                conn.execute(insert(table).values(chunks_table_data))

         return ids

     def text_exists(self, id: str) -> bool:
@@ -160,11 +159,10 @@ class TiDBVector(BaseVector):
             raise ValueError("No ids provided to delete.")
         table = self._table(self._dimension)
         try:
-            with self._engine.connect() as conn:
-                with conn.begin():
-                    delete_condition = table.c.id.in_(ids)
-                    conn.execute(table.delete().where(delete_condition))
-                    return True
+            with self._engine.connect() as conn, conn.begin():
+                delete_condition = table.c.id.in_(ids)
+                conn.execute(table.delete().where(delete_condition))
+                return True
         except Exception as e:
             print("Delete operation failed:", str(e))
             return False
@@ -48,7 +48,8 @@ class WordExtractor(BaseExtractor):
                 raise ValueError(f"Check the url of your file; returned status code {r.status_code}")

             self.web_path = self.file_path
-            self.temp_file = tempfile.NamedTemporaryFile()
+            # TODO: use a better way to handle the file
+            self.temp_file = tempfile.NamedTemporaryFile()  # noqa: SIM115
             self.temp_file.write(r.content)
             self.file_path = self.temp_file.name
         elif not os.path.isfile(self.file_path):
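The `noqa: SIM115` is needed because ruff's SIM115 wants file handles opened in a `with` block, but here the temporary file must outlive the enclosing statement (the extractor keeps reading from self.file_path afterwards). For contrast, the SIM115-clean shape scopes the handle to a block:

    import tempfile

    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(b"downloaded bytes")  # illustrative payload
        tmp.flush()
        path = tmp.name  # the file is deleted as soon as the block exits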
@@ -120,8 +120,8 @@ class WeightRerankRunner:
         intersection = set(vec1.keys()) & set(vec2.keys())
         numerator = sum(vec1[x] * vec2[x] for x in intersection)

-        sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
-        sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+        sum1 = sum(vec1[x] ** 2 for x in vec1)
+        sum2 = sum(vec2[x] ** 2 for x in vec2)
         denominator = math.sqrt(sum1) * math.sqrt(sum2)

         if not denominator:
@@ -581,8 +581,8 @@ class DatasetRetrieval:
         intersection = set(vec1.keys()) & set(vec2.keys())
         numerator = sum(vec1[x] * vec2[x] for x in intersection)

-        sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
-        sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+        sum1 = sum(vec1[x] ** 2 for x in vec1)
+        sum2 = sum(vec2[x] ** 2 for x in vec2)
         denominator = math.sqrt(sum1) * math.sqrt(sum2)

         if not denominator:
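Both hunks are the same cosine-similarity computation over sparse dict vectors: cos(v1, v2) = (v1 . v2) / (|v1| * |v2|), with the dot product taken over the shared keys. A self-contained restatement for reference (function name and data are illustrative, not Dify's):

    import math

    def cosine(vec1: dict, vec2: dict) -> float:
        numerator = sum(vec1[x] * vec2[x] for x in vec1.keys() & vec2.keys())
        denominator = math.sqrt(sum(v * v for v in vec1.values())) * math.sqrt(
            sum(v * v for v in vec2.values())
        )
        return numerator / denominator if denominator else 0.0

    assert abs(cosine({"a": 1.0, "b": 1.0}, {"a": 1.0}) - 1 / math.sqrt(2)) < 1e-9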
@@ -201,9 +201,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
             elif value.startswith('[{"organizeId"'):
                 value = json.loads(value)
                 value = "、".join([item["organizeName"] for item in value])
-            elif value.startswith('[{"file_id"'):
-                value = ""
-            elif value == "[]":
+            elif value.startswith('[{"file_id"') or value == "[]":
                 value = ""
             elif hasattr(value, "accountId"):
                 value = value["fullname"]
@@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
             models_data=[],
             headers=headers,
             params=params,
-            recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True,
+            recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
         )

         result_str = ""
@@ -39,7 +39,7 @@ class QRCodeGeneratorTool(BuiltinTool):
         # get error_correction
         error_correction = tool_parameters.get("error_correction", "")
-        if error_correction not in self.error_correction_levels.keys():
+        if error_correction not in self.error_correction_levels:
             return self.create_text_message("Invalid parameter error_correction")

         try:
@@ -44,36 +44,36 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict, type: str) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")

         toret = ""
         if type == "text":
-            if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
+            if "answer_box" in res and "answer" in res["answer_box"]:
                 toret += res["answer_box"]["answer"] + "\n"
-            if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
+            if "answer_box" in res and "snippet" in res["answer_box"]:
                 toret += res["answer_box"]["snippet"] + "\n"
-            if "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
+            if "knowledge_graph" in res and "description" in res["knowledge_graph"]:
                 toret += res["knowledge_graph"]["description"] + "\n"
-            if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys():
+            if "organic_results" in res and "snippet" in res["organic_results"][0]:
                 for item in res["organic_results"]:
                     toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
             if toret == "":
                 toret = "No good search result found"
         elif type == "link":
-            if "answer_box" in res.keys() and "organic_result" in res["answer_box"].keys():
-                if "title" in res["answer_box"]["organic_result"].keys():
+            if "answer_box" in res and "organic_result" in res["answer_box"]:
+                if "title" in res["answer_box"]["organic_result"]:
                     toret = f"[{res['answer_box']['organic_result']['title']}]({res['answer_box']['organic_result']['link']})\n"
-            elif "organic_results" in res.keys() and "link" in res["organic_results"][0].keys():
+            elif "organic_results" in res and "link" in res["organic_results"][0]:
                 toret = ""
                 for item in res["organic_results"]:
                     toret += f"[{item['title']}]({item['link']})\n"
-            elif "related_questions" in res.keys() and "link" in res["related_questions"][0].keys():
+            elif "related_questions" in res and "link" in res["related_questions"][0]:
                 toret = ""
                 for item in res["related_questions"]:
                     toret += f"[{item['title']}]({item['link']})\n"
-            elif "related_searches" in res.keys() and "link" in res["related_searches"][0].keys():
+            elif "related_searches" in res and "link" in res["related_searches"][0]:
                 toret = ""
                 for item in res["related_searches"]:
                     toret += f"[{item['title']}]({item['link']})\n"
@@ -44,12 +44,12 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict, type: str) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")

         toret = ""
         if type == "text":
-            if "jobs" in res.keys() and "title" in res["jobs"][0].keys():
+            if "jobs" in res and "title" in res["jobs"][0]:
                 for item in res["jobs"]:
                     toret += (
                         "title: "
@@ -65,7 +65,7 @@ class SearchAPI:
             toret = "No good search result found"
         elif type == "link":
-            if "jobs" in res.keys() and "apply_link" in res["jobs"][0].keys():
+            if "jobs" in res and "apply_link" in res["jobs"][0]:
                 for item in res["jobs"]:
                     toret += f"[{item['title']} - {item['company_name']}]({item['apply_link']})\n"
             else:
@@ -44,25 +44,25 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict, type: str) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")

         toret = ""
         if type == "text":
-            if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys():
+            if "organic_results" in res and "snippet" in res["organic_results"][0]:
                 for item in res["organic_results"]:
                     toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n"
-            if "top_stories" in res.keys() and "title" in res["top_stories"][0].keys():
+            if "top_stories" in res and "title" in res["top_stories"][0]:
                 for item in res["top_stories"]:
                     toret += "title: " + item["title"] + "\n" + "link: " + item["link"] + "\n"
             if toret == "":
                 toret = "No good search result found"
         elif type == "link":
-            if "organic_results" in res.keys() and "title" in res["organic_results"][0].keys():
+            if "organic_results" in res and "title" in res["organic_results"][0]:
                 for item in res["organic_results"]:
                     toret += f"[{item['title']}]({item['link']})\n"
-            elif "top_stories" in res.keys() and "title" in res["top_stories"][0].keys():
+            elif "top_stories" in res and "title" in res["top_stories"][0]:
                 for item in res["top_stories"]:
                     toret += f"[{item['title']}]({item['link']})\n"
             else:
@@ -44,11 +44,11 @@ class SearchAPI:
     @staticmethod
     def _process_response(res: dict) -> str:
         """Process response from SearchAPI."""
-        if "error" in res.keys():
+        if "error" in res:
             raise ValueError(f"Got error from SearchApi: {res['error']}")

         toret = ""
-        if "transcripts" in res.keys() and "text" in res["transcripts"][0].keys():
+        if "transcripts" in res and "text" in res["transcripts"][0]:
             for item in res["transcripts"]:
                 toret += item["text"] + " "
         if toret == "":
@@ -35,7 +35,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
         if model in ["sd3", "sd3-turbo"]:
             payload["model"] = tool_parameters.get("model")

-            if not model == "sd3-turbo":
+            if model != "sd3-turbo":
                 payload["negative_prompt"] = tool_parameters.get("negative_prompt", "")

         response = post(
@@ -206,10 +206,9 @@ class StableDiffusionTool(BuiltinTool):
         # Convert image to RGB and save as PNG
         try:
-            with Image.open(io.BytesIO(image_binary)) as image:
-                with io.BytesIO() as buffer:
-                    image.convert("RGB").save(buffer, format="PNG")
-                    image_binary = buffer.getvalue()
+            with Image.open(io.BytesIO(image_binary)) as image, io.BytesIO() as buffer:
+                image.convert("RGB").save(buffer, format="PNG")
+                image_binary = buffer.getvalue()
         except Exception as e:
             return self.create_text_message(f"Failed to process the image: {str(e)}")
@@ -27,7 +27,7 @@ class WikipediaAPIWrapper:
         self.doc_content_chars_max = doc_content_chars_max

     def run(self, query: str, lang: str = "") -> str:
-        if lang in wikipedia.languages().keys():
+        if lang in wikipedia.languages():
             self.lang = lang

         wikipedia.set_lang(self.lang)
@@ -19,9 +19,7 @@ class ToolFileMessageTransformer:
         result = []

         for message in messages:
-            if message.type == ToolInvokeMessage.MessageType.TEXT:
-                result.append(message)
-            elif message.type == ToolInvokeMessage.MessageType.LINK:
+            if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK:
                 result.append(message)
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                 # try to download image
@@ -224,9 +224,7 @@ class Graph(BaseModel):
         """
         leaf_node_ids = []
         for node_id in self.node_ids:
-            if node_id not in self.edge_mapping:
-                leaf_node_ids.append(node_id)
-            elif (
+            if node_id not in self.edge_mapping or (
                 len(self.edge_mapping[node_id]) == 1
                 and self.edge_mapping[node_id][0].target_node_id == self.root_node_id
             ):
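Folding the `elif` into the `if` with `or` is safe here only because `or` short-circuits: when the key is missing, the right-hand operand (which subscripts self.edge_mapping[node_id]) is never evaluated. A small demonstration with plain dicts:

    edge_mapping = {}  # no outgoing edges recorded
    node_id = "n1"
    # Does not raise KeyError, because the subscript is skipped:
    assert node_id not in edge_mapping or len(edge_mapping[node_id]) == 1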
@@ -24,7 +24,7 @@ class AnswerStreamGeneratorRouter:
         # parse stream output node value selectors of answer nodes
         answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
         for answer_node_id, node_config in node_id_config_mapping.items():
-            if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value:
+            if node_config.get("data", {}).get("type") != NodeType.ANSWER.value:
                 continue

             # get generate route for stream output
@@ -17,7 +17,7 @@ class EndStreamGeneratorRouter:
         # parse stream output node value selector of end nodes
         end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
         for end_node_id, node_config in node_id_config_mapping.items():
-            if not node_config.get("data", {}).get("type") == NodeType.END.value:
+            if node_config.get("data", {}).get("type") != NodeType.END.value:
                 continue

             # skip end node in parallel
@@ -20,7 +20,7 @@ class ToolEntity(BaseModel):
         if not isinstance(value, dict):
             raise ValueError("tool_configurations must be a dictionary")

-        for key in values.data.get("tool_configurations", {}).keys():
+        for key in values.data.get("tool_configurations", {}):
             value = values.data.get("tool_configurations", {}).get(key)
             if not isinstance(value, str | int | float | bool):
                 raise ValueError(f"{key} must be a string")
@@ -17,14 +17,12 @@ select = [
    "F", # pyflakes rules
    "I", # isort rules
    "N", # pep8-naming
-   "UP", # pyupgrade rules
    "RUF019", # unnecessary-key-check
    "RUF100", # unused-noqa
    "RUF101", # redirected-noqa
    "S506", # unsafe-yaml-load
-   "SIM116", # if-else-block-instead-of-dict-lookup
-   "SIM401", # if-else-block-instead-of-dict-get
-   "SIM910", # dict-get-with-none-default
+   "SIM", # flake8-simplify rules
+   "UP", # pyupgrade rules
    "W191", # tab-indentation
    "W605", # invalid-escape-sequence
 ]
@@ -50,6 +48,15 @@ ignore = [
    "B905", # zip-without-explicit-strict
    "N806", # non-lowercase-variable-in-function
    "N815", # mixed-case-variable-in-class-scope
+   "SIM102", # collapsible-if
+   "SIM103", # needless-bool
+   "SIM105", # suppressible-exception
+   "SIM107", # return-in-try-except-finally
+   "SIM108", # if-else-block-instead-of-if-exp
+   "SIM113", # enumerate-for-loop
+   "SIM117", # multiple-with-statements
+   "SIM210", # if-expr-with-true-false
+   "SIM300", # yoda-conditions
 ]
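The broadened "SIM" selector enables the whole flake8-simplify family, while the ignore list opts out of the rules whose rewrites weren't wanted wholesale. For instance, SIM108 (ignored above) would collapse an if/else assignment into a conditional expression; a runnable illustration:

    condition = True  # illustrative

    if condition:
        x = 1
    else:
        x = 2
    assert x == (1 if condition else 2)  # the form SIM108 would suggest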
[tool.ruff.lint.per-file-ignores]
@@ -56,9 +56,7 @@ class FileService:
             if etl_type == "Unstructured"
             else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
         )
-        if extension.lower() not in allowed_extensions:
-            raise UnsupportedFileTypeError()
-        elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
+        if extension.lower() not in allowed_extensions or only_image and extension.lower() not in IMAGE_EXTENSIONS:
            raise UnsupportedFileTypeError()

         # read file content
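The merged condition relies on operator precedence: `and` binds tighter than `or`, so `A or B and C` parses as `A or (B and C)`, which matches the original if/elif pair (the elif only ran when the first test was false, exactly as `or` short-circuits). A truth-table spot check:

    for A in (False, True):
        for B in (False, True):
            for C in (False, True):
                assert (A or B and C) == (A or (B and C))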
@@ -54,7 +54,7 @@ class OpsService:
         :param tracing_config: tracing config
         :return:
         """
-        if tracing_provider not in provider_config_map.keys() and tracing_provider:
+        if tracing_provider not in provider_config_map and tracing_provider:
             return {"error": f"Invalid tracing provider: {tracing_provider}"}

         config_class, other_keys = (
@@ -113,7 +113,7 @@ class OpsService:
         :param tracing_config: tracing config
         :return:
         """
-        if tracing_provider not in provider_config_map.keys():
+        if tracing_provider not in provider_config_map:
             raise ValueError(f"Invalid tracing provider: {tracing_provider}")

         # check if trace config already exists