You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. #!/usr/bin/env python3
  2. """
  3. Test suite for the new Service API functionality in the Python SDK.
  4. This test validates the implementation of the missing Service API endpoints
  5. that were added to the Python SDK to achieve complete coverage.
  6. """
  7. import unittest
  8. from unittest.mock import Mock, patch, MagicMock
  9. import json
  10. from dify_client import (
  11. DifyClient,
  12. ChatClient,
  13. WorkflowClient,
  14. KnowledgeBaseClient,
  15. WorkspaceClient,
  16. )
  17. class TestNewServiceAPIs(unittest.TestCase):
  18. """Test cases for new Service API implementations."""
  19. def setUp(self):
  20. """Set up test fixtures."""
  21. self.api_key = "test-api-key"
  22. self.base_url = "https://api.dify.ai/v1"
  23. @patch("dify_client.client.requests.request")
  24. def test_app_info_apis(self, mock_request):
  25. """Test application info APIs."""
  26. mock_response = Mock()
  27. mock_response.json.return_value = {
  28. "name": "Test App",
  29. "description": "Test Description",
  30. "tags": ["test", "api"],
  31. "mode": "chat",
  32. "author_name": "Test Author",
  33. }
  34. mock_request.return_value = mock_response
  35. client = DifyClient(self.api_key, self.base_url)
  36. # Test get_app_info
  37. result = client.get_app_info()
  38. mock_request.assert_called_with(
  39. "GET",
  40. f"{self.base_url}/info",
  41. json=None,
  42. params=None,
  43. headers={
  44. "Authorization": f"Bearer {self.api_key}",
  45. "Content-Type": "application/json",
  46. },
  47. stream=False,
  48. )
  49. # Test get_app_site_info
  50. client.get_app_site_info()
  51. mock_request.assert_called_with(
  52. "GET",
  53. f"{self.base_url}/site",
  54. json=None,
  55. params=None,
  56. headers={
  57. "Authorization": f"Bearer {self.api_key}",
  58. "Content-Type": "application/json",
  59. },
  60. stream=False,
  61. )
  62. # Test get_file_preview
  63. file_id = "test-file-id"
  64. client.get_file_preview(file_id)
  65. mock_request.assert_called_with(
  66. "GET",
  67. f"{self.base_url}/files/{file_id}/preview",
  68. json=None,
  69. params=None,
  70. headers={
  71. "Authorization": f"Bearer {self.api_key}",
  72. "Content-Type": "application/json",
  73. },
  74. stream=False,
  75. )
  76. @patch("dify_client.client.requests.request")
  77. def test_annotation_apis(self, mock_request):
  78. """Test annotation APIs."""
  79. mock_response = Mock()
  80. mock_response.json.return_value = {"result": "success"}
  81. mock_request.return_value = mock_response
  82. client = ChatClient(self.api_key, self.base_url)
  83. # Test annotation_reply_action - enable
  84. client.annotation_reply_action(
  85. action="enable",
  86. score_threshold=0.8,
  87. embedding_provider_name="openai",
  88. embedding_model_name="text-embedding-ada-002",
  89. )
  90. mock_request.assert_called_with(
  91. "POST",
  92. f"{self.base_url}/apps/annotation-reply/enable",
  93. json={
  94. "score_threshold": 0.8,
  95. "embedding_provider_name": "openai",
  96. "embedding_model_name": "text-embedding-ada-002",
  97. },
  98. params=None,
  99. headers={
  100. "Authorization": f"Bearer {self.api_key}",
  101. "Content-Type": "application/json",
  102. },
  103. stream=False,
  104. )
  105. # Test annotation_reply_action - disable (now requires same fields as enable)
  106. client.annotation_reply_action(
  107. action="disable",
  108. score_threshold=0.5,
  109. embedding_provider_name="openai",
  110. embedding_model_name="text-embedding-ada-002",
  111. )
  112. # Test annotation_reply_action with score_threshold=0 (edge case)
  113. client.annotation_reply_action(
  114. action="enable",
  115. score_threshold=0.0, # This should work and not raise ValueError
  116. embedding_provider_name="openai",
  117. embedding_model_name="text-embedding-ada-002",
  118. )
  119. # Test get_annotation_reply_status
  120. client.get_annotation_reply_status("enable", "job-123")
  121. # Test list_annotations
  122. client.list_annotations(page=1, limit=20, keyword="test")
  123. # Test create_annotation
  124. client.create_annotation("Test question?", "Test answer.")
  125. # Test update_annotation
  126. client.update_annotation("annotation-123", "Updated question?", "Updated answer.")
  127. # Test delete_annotation
  128. client.delete_annotation("annotation-123")
  129. # Verify all calls were made (8 calls: enable + disable + enable with 0.0 + 5 other operations)
  130. self.assertEqual(mock_request.call_count, 8)
  131. @patch("dify_client.client.requests.request")
  132. def test_knowledge_base_advanced_apis(self, mock_request):
  133. """Test advanced knowledge base APIs."""
  134. mock_response = Mock()
  135. mock_response.json.return_value = {"result": "success"}
  136. mock_request.return_value = mock_response
  137. dataset_id = "test-dataset-id"
  138. client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id)
  139. # Test hit_testing
  140. client.hit_testing("test query", {"type": "vector"})
  141. mock_request.assert_called_with(
  142. "POST",
  143. f"{self.base_url}/datasets/{dataset_id}/hit-testing",
  144. json={"query": "test query", "retrieval_model": {"type": "vector"}},
  145. params=None,
  146. headers={
  147. "Authorization": f"Bearer {self.api_key}",
  148. "Content-Type": "application/json",
  149. },
  150. stream=False,
  151. )
  152. # Test metadata operations
  153. client.get_dataset_metadata()
  154. client.create_dataset_metadata({"key": "value"})
  155. client.update_dataset_metadata("meta-123", {"key": "new_value"})
  156. client.get_built_in_metadata()
  157. client.manage_built_in_metadata("enable", {"type": "built_in"})
  158. client.update_documents_metadata([{"document_id": "doc1", "metadata": {"key": "value"}}])
  159. # Test tag operations
  160. client.list_dataset_tags()
  161. client.bind_dataset_tags(["tag1", "tag2"])
  162. client.unbind_dataset_tag("tag1")
  163. client.get_dataset_tags()
  164. # Verify multiple calls were made
  165. self.assertGreater(mock_request.call_count, 5)
  166. @patch("dify_client.client.requests.request")
  167. def test_rag_pipeline_apis(self, mock_request):
  168. """Test RAG pipeline APIs."""
  169. mock_response = Mock()
  170. mock_response.json.return_value = {"result": "success"}
  171. mock_request.return_value = mock_response
  172. dataset_id = "test-dataset-id"
  173. client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id)
  174. # Test get_datasource_plugins
  175. client.get_datasource_plugins(is_published=True)
  176. mock_request.assert_called_with(
  177. "GET",
  178. f"{self.base_url}/datasets/{dataset_id}/pipeline/datasource-plugins",
  179. json=None,
  180. params={"is_published": True},
  181. headers={
  182. "Authorization": f"Bearer {self.api_key}",
  183. "Content-Type": "application/json",
  184. },
  185. stream=False,
  186. )
  187. # Test run_datasource_node
  188. client.run_datasource_node(
  189. node_id="node-123",
  190. inputs={"param": "value"},
  191. datasource_type="online_document",
  192. is_published=True,
  193. credential_id="cred-123",
  194. )
  195. # Test run_rag_pipeline with blocking mode
  196. client.run_rag_pipeline(
  197. inputs={"query": "test"},
  198. datasource_type="online_document",
  199. datasource_info_list=[{"id": "ds1"}],
  200. start_node_id="start-node",
  201. is_published=True,
  202. response_mode="blocking",
  203. )
  204. # Test run_rag_pipeline with streaming mode
  205. client.run_rag_pipeline(
  206. inputs={"query": "test"},
  207. datasource_type="online_document",
  208. datasource_info_list=[{"id": "ds1"}],
  209. start_node_id="start-node",
  210. is_published=True,
  211. response_mode="streaming",
  212. )
  213. self.assertEqual(mock_request.call_count, 4)
  214. @patch("dify_client.client.requests.request")
  215. def test_workspace_apis(self, mock_request):
  216. """Test workspace APIs."""
  217. mock_response = Mock()
  218. mock_response.json.return_value = {
  219. "data": [{"name": "gpt-3.5-turbo", "type": "llm"}, {"name": "gpt-4", "type": "llm"}]
  220. }
  221. mock_request.return_value = mock_response
  222. client = WorkspaceClient(self.api_key, self.base_url)
  223. # Test get_available_models
  224. result = client.get_available_models("llm")
  225. mock_request.assert_called_with(
  226. "GET",
  227. f"{self.base_url}/workspaces/current/models/model-types/llm",
  228. json=None,
  229. params=None,
  230. headers={
  231. "Authorization": f"Bearer {self.api_key}",
  232. "Content-Type": "application/json",
  233. },
  234. stream=False,
  235. )
  236. @patch("dify_client.client.requests.request")
  237. def test_workflow_advanced_apis(self, mock_request):
  238. """Test advanced workflow APIs."""
  239. mock_response = Mock()
  240. mock_response.json.return_value = {"result": "success"}
  241. mock_request.return_value = mock_response
  242. client = WorkflowClient(self.api_key, self.base_url)
  243. # Test get_workflow_logs
  244. client.get_workflow_logs(keyword="test", status="succeeded", page=1, limit=20)
  245. mock_request.assert_called_with(
  246. "GET",
  247. f"{self.base_url}/workflows/logs",
  248. json=None,
  249. params={"page": 1, "limit": 20, "keyword": "test", "status": "succeeded"},
  250. headers={
  251. "Authorization": f"Bearer {self.api_key}",
  252. "Content-Type": "application/json",
  253. },
  254. stream=False,
  255. )
  256. # Test get_workflow_logs with additional filters
  257. client.get_workflow_logs(
  258. keyword="test",
  259. status="succeeded",
  260. page=1,
  261. limit=20,
  262. created_at__before="2024-01-01",
  263. created_at__after="2023-01-01",
  264. created_by_account="user123",
  265. )
  266. # Test run_specific_workflow
  267. client.run_specific_workflow(
  268. workflow_id="workflow-123", inputs={"param": "value"}, response_mode="streaming", user="user-123"
  269. )
  270. self.assertEqual(mock_request.call_count, 3)
  271. def test_error_handling(self):
  272. """Test error handling for required parameters."""
  273. client = ChatClient(self.api_key, self.base_url)
  274. # Test annotation_reply_action with missing required parameters would be a TypeError now
  275. # since parameters are required in method signature
  276. with self.assertRaises(TypeError):
  277. client.annotation_reply_action("enable")
  278. # Test annotation_reply_action with explicit None values should raise ValueError
  279. with self.assertRaises(ValueError) as context:
  280. client.annotation_reply_action("enable", None, "provider", "model")
  281. self.assertIn("cannot be None", str(context.exception))
  282. # Test KnowledgeBaseClient without dataset_id
  283. kb_client = KnowledgeBaseClient(self.api_key, self.base_url)
  284. with self.assertRaises(ValueError) as context:
  285. kb_client.hit_testing("test query")
  286. self.assertIn("dataset_id is not set", str(context.exception))
  287. @patch("dify_client.client.open")
  288. @patch("dify_client.client.requests.request")
  289. def test_file_upload_apis(self, mock_request, mock_open):
  290. """Test file upload APIs."""
  291. mock_response = Mock()
  292. mock_response.json.return_value = {"result": "success"}
  293. mock_request.return_value = mock_response
  294. mock_file = MagicMock()
  295. mock_open.return_value.__enter__.return_value = mock_file
  296. dataset_id = "test-dataset-id"
  297. client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id)
  298. # Test upload_pipeline_file
  299. client.upload_pipeline_file("/path/to/test.pdf")
  300. mock_open.assert_called_with("/path/to/test.pdf", "rb")
  301. mock_request.assert_called_once()
  302. def test_comprehensive_coverage(self):
  303. """Test that all previously missing APIs are now implemented."""
  304. # Test DifyClient methods
  305. dify_methods = ["get_app_info", "get_app_site_info", "get_file_preview"]
  306. client = DifyClient(self.api_key)
  307. for method in dify_methods:
  308. self.assertTrue(hasattr(client, method), f"DifyClient missing method: {method}")
  309. # Test ChatClient annotation methods
  310. chat_methods = [
  311. "annotation_reply_action",
  312. "get_annotation_reply_status",
  313. "list_annotations",
  314. "create_annotation",
  315. "update_annotation",
  316. "delete_annotation",
  317. ]
  318. chat_client = ChatClient(self.api_key)
  319. for method in chat_methods:
  320. self.assertTrue(hasattr(chat_client, method), f"ChatClient missing method: {method}")
  321. # Test WorkflowClient advanced methods
  322. workflow_methods = ["get_workflow_logs", "run_specific_workflow"]
  323. workflow_client = WorkflowClient(self.api_key)
  324. for method in workflow_methods:
  325. self.assertTrue(hasattr(workflow_client, method), f"WorkflowClient missing method: {method}")
  326. # Test KnowledgeBaseClient advanced methods
  327. kb_methods = [
  328. "hit_testing",
  329. "get_dataset_metadata",
  330. "create_dataset_metadata",
  331. "update_dataset_metadata",
  332. "get_built_in_metadata",
  333. "manage_built_in_metadata",
  334. "update_documents_metadata",
  335. "list_dataset_tags",
  336. "bind_dataset_tags",
  337. "unbind_dataset_tag",
  338. "get_dataset_tags",
  339. "get_datasource_plugins",
  340. "run_datasource_node",
  341. "run_rag_pipeline",
  342. "upload_pipeline_file",
  343. ]
  344. kb_client = KnowledgeBaseClient(self.api_key)
  345. for method in kb_methods:
  346. self.assertTrue(hasattr(kb_client, method), f"KnowledgeBaseClient missing method: {method}")
  347. # Test WorkspaceClient methods
  348. workspace_methods = ["get_available_models"]
  349. workspace_client = WorkspaceClient(self.api_key)
  350. for method in workspace_methods:
  351. self.assertTrue(hasattr(workspace_client, method), f"WorkspaceClient missing method: {method}")
  352. if __name__ == "__main__":
  353. unittest.main()