### What problem does this PR solve? Format the code ### Type of change - [x] Refactoring Signed-off-by: Jin Hai <haijin.chn@gmail.com>tags/v0.16.0
| @@ -1,56 +1,56 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from abc import ABC | |||
| import pandas as pd | |||
| import akshare as ak | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class AkShareParam(ComponentParamBase): | |||
| """ | |||
| Define the AkShare component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.top_n = 10 | |||
| def check(self): | |||
| self.check_positive_integer(self.top_n, "Top N") | |||
| class AkShare(ComponentBase, ABC): | |||
| component_name = "AkShare" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = ",".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return AkShare.be_output("") | |||
| try: | |||
| ak_res = [] | |||
| stock_news_em_df = ak.stock_news_em(symbol=ans) | |||
| stock_news_em_df = stock_news_em_df.head(self._param.top_n) | |||
| ak_res = [{"content": '<a href="' + i["新闻链接"] + '">' + i["新闻标题"] + '</a>\n 新闻内容: ' + i[ | |||
| "新闻内容"] + " \n发布时间:" + i["发布时间"] + " \n文章来源: " + i["文章来源"]} for index, i in stock_news_em_df.iterrows()] | |||
| except Exception as e: | |||
| return AkShare.be_output("**ERROR**: " + str(e)) | |||
| if not ak_res: | |||
| return AkShare.be_output("") | |||
| return pd.DataFrame(ak_res) | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from abc import ABC | |||
| import pandas as pd | |||
| import akshare as ak | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class AkShareParam(ComponentParamBase): | |||
| """ | |||
| Define the AkShare component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.top_n = 10 | |||
| def check(self): | |||
| self.check_positive_integer(self.top_n, "Top N") | |||
| class AkShare(ComponentBase, ABC): | |||
| component_name = "AkShare" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = ",".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return AkShare.be_output("") | |||
| try: | |||
| ak_res = [] | |||
| stock_news_em_df = ak.stock_news_em(symbol=ans) | |||
| stock_news_em_df = stock_news_em_df.head(self._param.top_n) | |||
| ak_res = [{"content": '<a href="' + i["新闻链接"] + '">' + i["新闻标题"] + '</a>\n 新闻内容: ' + i[ | |||
| "新闻内容"] + " \n发布时间:" + i["发布时间"] + " \n文章来源: " + i["文章来源"]} for index, i in stock_news_em_df.iterrows()] | |||
| except Exception as e: | |||
| return AkShare.be_output("**ERROR**: " + str(e)) | |||
| if not ak_res: | |||
| return AkShare.be_output("") | |||
| return pd.DataFrame(ak_res) | |||
| @@ -1,36 +1,36 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from abc import ABC | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class ConcentratorParam(ComponentParamBase): | |||
| """ | |||
| Define the Concentrator component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| def check(self): | |||
| return True | |||
| class Concentrator(ComponentBase, ABC): | |||
| component_name = "Concentrator" | |||
| def _run(self, history, **kwargs): | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from abc import ABC | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class ConcentratorParam(ComponentParamBase): | |||
| """ | |||
| Define the Concentrator component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| def check(self): | |||
| return True | |||
| class Concentrator(ComponentBase, ABC): | |||
| component_name = "Concentrator" | |||
| def _run(self, history, **kwargs): | |||
| return Concentrator.be_output("") | |||
| @@ -1,130 +1,130 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import json | |||
| from abc import ABC | |||
| import pandas as pd | |||
| import requests | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class Jin10Param(ComponentParamBase): | |||
| """ | |||
| Define the Jin10 component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.type = "flash" | |||
| self.secret_key = "xxx" | |||
| self.flash_type = '1' | |||
| self.calendar_type = 'cj' | |||
| self.calendar_datatype = 'data' | |||
| self.symbols_type = 'GOODS' | |||
| self.symbols_datatype = 'symbols' | |||
| self.contain = "" | |||
| self.filter = "" | |||
| def check(self): | |||
| self.check_valid_value(self.type, "Type", ['flash', 'calendar', 'symbols', 'news']) | |||
| self.check_valid_value(self.flash_type, "Flash Type", ['1', '2', '3', '4', '5']) | |||
| self.check_valid_value(self.calendar_type, "Calendar Type", ['cj', 'qh', 'hk', 'us']) | |||
| self.check_valid_value(self.calendar_datatype, "Calendar DataType", ['data', 'event', 'holiday']) | |||
| self.check_valid_value(self.symbols_type, "Symbols Type", ['GOODS', 'FOREX', 'FUTURE', 'CRYPTO']) | |||
| self.check_valid_value(self.symbols_datatype, 'Symbols DataType', ['symbols', 'quotes']) | |||
| class Jin10(ComponentBase, ABC): | |||
| component_name = "Jin10" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = " - ".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return Jin10.be_output("") | |||
| jin10_res = [] | |||
| headers = {'secret-key': self._param.secret_key} | |||
| try: | |||
| if self._param.type == "flash": | |||
| params = { | |||
| 'category': self._param.flash_type, | |||
| 'contain': self._param.contain, | |||
| 'filter': self._param.filter | |||
| } | |||
| response = requests.get( | |||
| url='https://open-data-api.jin10.com/data-api/flash?category=' + self._param.flash_type, | |||
| headers=headers, data=json.dumps(params)) | |||
| response = response.json() | |||
| for i in response['data']: | |||
| jin10_res.append({"content": i['data']['content']}) | |||
| if self._param.type == "calendar": | |||
| params = { | |||
| 'category': self._param.calendar_type | |||
| } | |||
| response = requests.get( | |||
| url='https://open-data-api.jin10.com/data-api/calendar/' + self._param.calendar_datatype + '?category=' + self._param.calendar_type, | |||
| headers=headers, data=json.dumps(params)) | |||
| response = response.json() | |||
| jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) | |||
| if self._param.type == "symbols": | |||
| params = { | |||
| 'type': self._param.symbols_type | |||
| } | |||
| if self._param.symbols_datatype == "quotes": | |||
| params['codes'] = 'BTCUSD' | |||
| response = requests.get( | |||
| url='https://open-data-api.jin10.com/data-api/' + self._param.symbols_datatype + '?type=' + self._param.symbols_type, | |||
| headers=headers, data=json.dumps(params)) | |||
| response = response.json() | |||
| if self._param.symbols_datatype == "symbols": | |||
| for i in response['data']: | |||
| i['Commodity Code'] = i['c'] | |||
| i['Stock Exchange'] = i['e'] | |||
| i['Commodity Name'] = i['n'] | |||
| i['Commodity Type'] = i['t'] | |||
| del i['c'], i['e'], i['n'], i['t'] | |||
| if self._param.symbols_datatype == "quotes": | |||
| for i in response['data']: | |||
| i['Selling Price'] = i['a'] | |||
| i['Buying Price'] = i['b'] | |||
| i['Commodity Code'] = i['c'] | |||
| i['Stock Exchange'] = i['e'] | |||
| i['Highest Price'] = i['h'] | |||
| i['Yesterday’s Closing Price'] = i['hc'] | |||
| i['Lowest Price'] = i['l'] | |||
| i['Opening Price'] = i['o'] | |||
| i['Latest Price'] = i['p'] | |||
| i['Market Quote Time'] = i['t'] | |||
| del i['a'], i['b'], i['c'], i['e'], i['h'], i['hc'], i['l'], i['o'], i['p'], i['t'] | |||
| jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) | |||
| if self._param.type == "news": | |||
| params = { | |||
| 'contain': self._param.contain, | |||
| 'filter': self._param.filter | |||
| } | |||
| response = requests.get( | |||
| url='https://open-data-api.jin10.com/data-api/news', | |||
| headers=headers, data=json.dumps(params)) | |||
| response = response.json() | |||
| jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) | |||
| except Exception as e: | |||
| return Jin10.be_output("**ERROR**: " + str(e)) | |||
| if not jin10_res: | |||
| return Jin10.be_output("") | |||
| return pd.DataFrame(jin10_res) | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import json | |||
| from abc import ABC | |||
| import pandas as pd | |||
| import requests | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class Jin10Param(ComponentParamBase): | |||
| """ | |||
| Define the Jin10 component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.type = "flash" | |||
| self.secret_key = "xxx" | |||
| self.flash_type = '1' | |||
| self.calendar_type = 'cj' | |||
| self.calendar_datatype = 'data' | |||
| self.symbols_type = 'GOODS' | |||
| self.symbols_datatype = 'symbols' | |||
| self.contain = "" | |||
| self.filter = "" | |||
| def check(self): | |||
| self.check_valid_value(self.type, "Type", ['flash', 'calendar', 'symbols', 'news']) | |||
| self.check_valid_value(self.flash_type, "Flash Type", ['1', '2', '3', '4', '5']) | |||
| self.check_valid_value(self.calendar_type, "Calendar Type", ['cj', 'qh', 'hk', 'us']) | |||
| self.check_valid_value(self.calendar_datatype, "Calendar DataType", ['data', 'event', 'holiday']) | |||
| self.check_valid_value(self.symbols_type, "Symbols Type", ['GOODS', 'FOREX', 'FUTURE', 'CRYPTO']) | |||
| self.check_valid_value(self.symbols_datatype, 'Symbols DataType', ['symbols', 'quotes']) | |||
| class Jin10(ComponentBase, ABC): | |||
| component_name = "Jin10" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = " - ".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return Jin10.be_output("") | |||
| jin10_res = [] | |||
| headers = {'secret-key': self._param.secret_key} | |||
| try: | |||
| if self._param.type == "flash": | |||
| params = { | |||
| 'category': self._param.flash_type, | |||
| 'contain': self._param.contain, | |||
| 'filter': self._param.filter | |||
| } | |||
| response = requests.get( | |||
| url='https://open-data-api.jin10.com/data-api/flash?category=' + self._param.flash_type, | |||
| headers=headers, data=json.dumps(params)) | |||
| response = response.json() | |||
| for i in response['data']: | |||
| jin10_res.append({"content": i['data']['content']}) | |||
| if self._param.type == "calendar": | |||
| params = { | |||
| 'category': self._param.calendar_type | |||
| } | |||
| response = requests.get( | |||
| url='https://open-data-api.jin10.com/data-api/calendar/' + self._param.calendar_datatype + '?category=' + self._param.calendar_type, | |||
| headers=headers, data=json.dumps(params)) | |||
| response = response.json() | |||
| jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) | |||
| if self._param.type == "symbols": | |||
| params = { | |||
| 'type': self._param.symbols_type | |||
| } | |||
| if self._param.symbols_datatype == "quotes": | |||
| params['codes'] = 'BTCUSD' | |||
| response = requests.get( | |||
| url='https://open-data-api.jin10.com/data-api/' + self._param.symbols_datatype + '?type=' + self._param.symbols_type, | |||
| headers=headers, data=json.dumps(params)) | |||
| response = response.json() | |||
| if self._param.symbols_datatype == "symbols": | |||
| for i in response['data']: | |||
| i['Commodity Code'] = i['c'] | |||
| i['Stock Exchange'] = i['e'] | |||
| i['Commodity Name'] = i['n'] | |||
| i['Commodity Type'] = i['t'] | |||
| del i['c'], i['e'], i['n'], i['t'] | |||
| if self._param.symbols_datatype == "quotes": | |||
| for i in response['data']: | |||
| i['Selling Price'] = i['a'] | |||
| i['Buying Price'] = i['b'] | |||
| i['Commodity Code'] = i['c'] | |||
| i['Stock Exchange'] = i['e'] | |||
| i['Highest Price'] = i['h'] | |||
| i['Yesterday’s Closing Price'] = i['hc'] | |||
| i['Lowest Price'] = i['l'] | |||
| i['Opening Price'] = i['o'] | |||
| i['Latest Price'] = i['p'] | |||
| i['Market Quote Time'] = i['t'] | |||
| del i['a'], i['b'], i['c'], i['e'], i['h'], i['hc'], i['l'], i['o'], i['p'], i['t'] | |||
| jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) | |||
| if self._param.type == "news": | |||
| params = { | |||
| 'contain': self._param.contain, | |||
| 'filter': self._param.filter | |||
| } | |||
| response = requests.get( | |||
| url='https://open-data-api.jin10.com/data-api/news', | |||
| headers=headers, data=json.dumps(params)) | |||
| response = response.json() | |||
| jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) | |||
| except Exception as e: | |||
| return Jin10.be_output("**ERROR**: " + str(e)) | |||
| if not jin10_res: | |||
| return Jin10.be_output("") | |||
| return pd.DataFrame(jin10_res) | |||
| @@ -1,72 +1,72 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import json | |||
| from abc import ABC | |||
| import pandas as pd | |||
| import time | |||
| import requests | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class TuShareParam(ComponentParamBase): | |||
| """ | |||
| Define the TuShare component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.token = "xxx" | |||
| self.src = "eastmoney" | |||
| self.start_date = "2024-01-01 09:00:00" | |||
| self.end_date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||
| self.keyword = "" | |||
| def check(self): | |||
| self.check_valid_value(self.src, "Quick News Source", | |||
| ["sina", "wallstreetcn", "10jqka", "eastmoney", "yuncaijing", "fenghuang", "jinrongjie"]) | |||
| class TuShare(ComponentBase, ABC): | |||
| component_name = "TuShare" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = ",".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return TuShare.be_output("") | |||
| try: | |||
| tus_res = [] | |||
| params = { | |||
| "api_name": "news", | |||
| "token": self._param.token, | |||
| "params": {"src": self._param.src, "start_date": self._param.start_date, | |||
| "end_date": self._param.end_date} | |||
| } | |||
| response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8')) | |||
| response = response.json() | |||
| if response['code'] != 0: | |||
| return TuShare.be_output(response['msg']) | |||
| df = pd.DataFrame(response['data']['items']) | |||
| df.columns = response['data']['fields'] | |||
| tus_res.append({"content": (df[df['content'].str.contains(self._param.keyword, case=False)]).to_markdown()}) | |||
| except Exception as e: | |||
| return TuShare.be_output("**ERROR**: " + str(e)) | |||
| if not tus_res: | |||
| return TuShare.be_output("") | |||
| return pd.DataFrame(tus_res) | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import json | |||
| from abc import ABC | |||
| import pandas as pd | |||
| import time | |||
| import requests | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class TuShareParam(ComponentParamBase): | |||
| """ | |||
| Define the TuShare component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.token = "xxx" | |||
| self.src = "eastmoney" | |||
| self.start_date = "2024-01-01 09:00:00" | |||
| self.end_date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||
| self.keyword = "" | |||
| def check(self): | |||
| self.check_valid_value(self.src, "Quick News Source", | |||
| ["sina", "wallstreetcn", "10jqka", "eastmoney", "yuncaijing", "fenghuang", "jinrongjie"]) | |||
| class TuShare(ComponentBase, ABC): | |||
| component_name = "TuShare" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = ",".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return TuShare.be_output("") | |||
| try: | |||
| tus_res = [] | |||
| params = { | |||
| "api_name": "news", | |||
| "token": self._param.token, | |||
| "params": {"src": self._param.src, "start_date": self._param.start_date, | |||
| "end_date": self._param.end_date} | |||
| } | |||
| response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8')) | |||
| response = response.json() | |||
| if response['code'] != 0: | |||
| return TuShare.be_output(response['msg']) | |||
| df = pd.DataFrame(response['data']['items']) | |||
| df.columns = response['data']['fields'] | |||
| tus_res.append({"content": (df[df['content'].str.contains(self._param.keyword, case=False)]).to_markdown()}) | |||
| except Exception as e: | |||
| return TuShare.be_output("**ERROR**: " + str(e)) | |||
| if not tus_res: | |||
| return TuShare.be_output("") | |||
| return pd.DataFrame(tus_res) | |||
| @@ -1,80 +1,80 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from abc import ABC | |||
| import pandas as pd | |||
| import pywencai | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class WenCaiParam(ComponentParamBase): | |||
| """ | |||
| Define the WenCai component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.top_n = 10 | |||
| self.query_type = "stock" | |||
| def check(self): | |||
| self.check_positive_integer(self.top_n, "Top N") | |||
| self.check_valid_value(self.query_type, "Query type", | |||
| ['stock', 'zhishu', 'fund', 'hkstock', 'usstock', 'threeboard', 'conbond', 'insurance', | |||
| 'futures', 'lccp', | |||
| 'foreign_exchange']) | |||
| class WenCai(ComponentBase, ABC): | |||
| component_name = "WenCai" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = ",".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return WenCai.be_output("") | |||
| try: | |||
| wencai_res = [] | |||
| res = pywencai.get(query=ans, query_type=self._param.query_type, perpage=self._param.top_n) | |||
| if isinstance(res, pd.DataFrame): | |||
| wencai_res.append({"content": res.to_markdown()}) | |||
| if isinstance(res, dict): | |||
| for item in res.items(): | |||
| if isinstance(item[1], list): | |||
| wencai_res.append({"content": item[0] + "\n" + pd.DataFrame(item[1]).to_markdown()}) | |||
| continue | |||
| if isinstance(item[1], str): | |||
| wencai_res.append({"content": item[0] + "\n" + item[1]}) | |||
| continue | |||
| if isinstance(item[1], dict): | |||
| if "meta" in item[1].keys(): | |||
| continue | |||
| wencai_res.append({"content": pd.DataFrame.from_dict(item[1], orient='index').to_markdown()}) | |||
| continue | |||
| if isinstance(item[1], pd.DataFrame): | |||
| if "image_url" in item[1].columns: | |||
| continue | |||
| wencai_res.append({"content": item[1].to_markdown()}) | |||
| continue | |||
| wencai_res.append({"content": item[0] + "\n" + str(item[1])}) | |||
| except Exception as e: | |||
| return WenCai.be_output("**ERROR**: " + str(e)) | |||
| if not wencai_res: | |||
| return WenCai.be_output("") | |||
| return pd.DataFrame(wencai_res) | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from abc import ABC | |||
| import pandas as pd | |||
| import pywencai | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| class WenCaiParam(ComponentParamBase): | |||
| """ | |||
| Define the WenCai component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.top_n = 10 | |||
| self.query_type = "stock" | |||
| def check(self): | |||
| self.check_positive_integer(self.top_n, "Top N") | |||
| self.check_valid_value(self.query_type, "Query type", | |||
| ['stock', 'zhishu', 'fund', 'hkstock', 'usstock', 'threeboard', 'conbond', 'insurance', | |||
| 'futures', 'lccp', | |||
| 'foreign_exchange']) | |||
| class WenCai(ComponentBase, ABC): | |||
| component_name = "WenCai" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = ",".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return WenCai.be_output("") | |||
| try: | |||
| wencai_res = [] | |||
| res = pywencai.get(query=ans, query_type=self._param.query_type, perpage=self._param.top_n) | |||
| if isinstance(res, pd.DataFrame): | |||
| wencai_res.append({"content": res.to_markdown()}) | |||
| if isinstance(res, dict): | |||
| for item in res.items(): | |||
| if isinstance(item[1], list): | |||
| wencai_res.append({"content": item[0] + "\n" + pd.DataFrame(item[1]).to_markdown()}) | |||
| continue | |||
| if isinstance(item[1], str): | |||
| wencai_res.append({"content": item[0] + "\n" + item[1]}) | |||
| continue | |||
| if isinstance(item[1], dict): | |||
| if "meta" in item[1].keys(): | |||
| continue | |||
| wencai_res.append({"content": pd.DataFrame.from_dict(item[1], orient='index').to_markdown()}) | |||
| continue | |||
| if isinstance(item[1], pd.DataFrame): | |||
| if "image_url" in item[1].columns: | |||
| continue | |||
| wencai_res.append({"content": item[1].to_markdown()}) | |||
| continue | |||
| wencai_res.append({"content": item[0] + "\n" + str(item[1])}) | |||
| except Exception as e: | |||
| return WenCai.be_output("**ERROR**: " + str(e)) | |||
| if not wencai_res: | |||
| return WenCai.be_output("") | |||
| return pd.DataFrame(wencai_res) | |||
| @@ -1,84 +1,84 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| from abc import ABC | |||
| import pandas as pd | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| import yfinance as yf | |||
| class YahooFinanceParam(ComponentParamBase): | |||
| """ | |||
| Define the YahooFinance component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.info = True | |||
| self.history = False | |||
| self.count = False | |||
| self.financials = False | |||
| self.income_stmt = False | |||
| self.balance_sheet = False | |||
| self.cash_flow_statement = False | |||
| self.news = True | |||
| def check(self): | |||
| self.check_boolean(self.info, "get all stock info") | |||
| self.check_boolean(self.history, "get historical market data") | |||
| self.check_boolean(self.count, "show share count") | |||
| self.check_boolean(self.financials, "show financials") | |||
| self.check_boolean(self.income_stmt, "income statement") | |||
| self.check_boolean(self.balance_sheet, "balance sheet") | |||
| self.check_boolean(self.cash_flow_statement, "cash flow statement") | |||
| self.check_boolean(self.news, "show news") | |||
| class YahooFinance(ComponentBase, ABC): | |||
| component_name = "YahooFinance" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = "".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return YahooFinance.be_output("") | |||
| yohoo_res = [] | |||
| try: | |||
| msft = yf.Ticker(ans) | |||
| if self._param.info: | |||
| yohoo_res.append({"content": "info:\n" + pd.Series(msft.info).to_markdown() + "\n"}) | |||
| if self._param.history: | |||
| yohoo_res.append({"content": "history:\n" + msft.history().to_markdown() + "\n"}) | |||
| if self._param.financials: | |||
| yohoo_res.append({"content": "calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n"}) | |||
| if self._param.balance_sheet: | |||
| yohoo_res.append({"content": "balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n"}) | |||
| yohoo_res.append( | |||
| {"content": "quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n"}) | |||
| if self._param.cash_flow_statement: | |||
| yohoo_res.append({"content": "cash flow statement:\n" + msft.cashflow.to_markdown() + "\n"}) | |||
| yohoo_res.append( | |||
| {"content": "quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n"}) | |||
| if self._param.news: | |||
| yohoo_res.append({"content": "news:\n" + pd.DataFrame(msft.news).to_markdown() + "\n"}) | |||
| except Exception: | |||
| logging.exception("YahooFinance got exception") | |||
| if not yohoo_res: | |||
| return YahooFinance.be_output("") | |||
| return pd.DataFrame(yohoo_res) | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| from abc import ABC | |||
| import pandas as pd | |||
| from agent.component.base import ComponentBase, ComponentParamBase | |||
| import yfinance as yf | |||
| class YahooFinanceParam(ComponentParamBase): | |||
| """ | |||
| Define the YahooFinance component parameters. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.info = True | |||
| self.history = False | |||
| self.count = False | |||
| self.financials = False | |||
| self.income_stmt = False | |||
| self.balance_sheet = False | |||
| self.cash_flow_statement = False | |||
| self.news = True | |||
| def check(self): | |||
| self.check_boolean(self.info, "get all stock info") | |||
| self.check_boolean(self.history, "get historical market data") | |||
| self.check_boolean(self.count, "show share count") | |||
| self.check_boolean(self.financials, "show financials") | |||
| self.check_boolean(self.income_stmt, "income statement") | |||
| self.check_boolean(self.balance_sheet, "balance sheet") | |||
| self.check_boolean(self.cash_flow_statement, "cash flow statement") | |||
| self.check_boolean(self.news, "show news") | |||
| class YahooFinance(ComponentBase, ABC): | |||
| component_name = "YahooFinance" | |||
| def _run(self, history, **kwargs): | |||
| ans = self.get_input() | |||
| ans = "".join(ans["content"]) if "content" in ans else "" | |||
| if not ans: | |||
| return YahooFinance.be_output("") | |||
| yohoo_res = [] | |||
| try: | |||
| msft = yf.Ticker(ans) | |||
| if self._param.info: | |||
| yohoo_res.append({"content": "info:\n" + pd.Series(msft.info).to_markdown() + "\n"}) | |||
| if self._param.history: | |||
| yohoo_res.append({"content": "history:\n" + msft.history().to_markdown() + "\n"}) | |||
| if self._param.financials: | |||
| yohoo_res.append({"content": "calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n"}) | |||
| if self._param.balance_sheet: | |||
| yohoo_res.append({"content": "balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n"}) | |||
| yohoo_res.append( | |||
| {"content": "quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n"}) | |||
| if self._param.cash_flow_statement: | |||
| yohoo_res.append({"content": "cash flow statement:\n" + msft.cashflow.to_markdown() + "\n"}) | |||
| yohoo_res.append( | |||
| {"content": "quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n"}) | |||
| if self._param.news: | |||
| yohoo_res.append({"content": "news:\n" + pd.DataFrame(msft.news).to_markdown() + "\n"}) | |||
| except Exception: | |||
| logging.exception("YahooFinance got exception") | |||
| if not yohoo_res: | |||
| return YahooFinance.be_output("") | |||
| return pd.DataFrame(yohoo_res) | |||
| @@ -1,113 +1,113 @@ | |||
| { | |||
| "components": { | |||
| "begin": { | |||
| "obj":{ | |||
| "component_name": "Begin", | |||
| "params": { | |||
| "prologue": "Hi there!" | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": [] | |||
| }, | |||
| "answer:0": { | |||
| "obj": { | |||
| "component_name": "Answer", | |||
| "params": {} | |||
| }, | |||
| "downstream": ["categorize:0"], | |||
| "upstream": ["begin"] | |||
| }, | |||
| "categorize:0": { | |||
| "obj": { | |||
| "component_name": "Categorize", | |||
| "params": { | |||
| "llm_id": "deepseek-chat", | |||
| "category_description": { | |||
| "product_related": { | |||
| "description": "The question is about the product usage, appearance and how it works.", | |||
| "examples": "Why it always beaming?\nHow to install it onto the wall?\nIt leaks, what to do?", | |||
| "to": "concentrator:0" | |||
| }, | |||
| "others": { | |||
| "description": "The question is not about the product usage, appearance and how it works.", | |||
| "examples": "How are you doing?\nWhat is your name?\nAre you a robot?\nWhat's the weather?\nWill it rain?", | |||
| "to": "concentrator:1" | |||
| } | |||
| } | |||
| } | |||
| }, | |||
| "downstream": ["concentrator:0","concentrator:1"], | |||
| "upstream": ["answer:0"] | |||
| }, | |||
| "concentrator:0": { | |||
| "obj": { | |||
| "component_name": "Concentrator", | |||
| "params": {} | |||
| }, | |||
| "downstream": ["message:0"], | |||
| "upstream": ["categorize:0"] | |||
| }, | |||
| "concentrator:1": { | |||
| "obj": { | |||
| "component_name": "Concentrator", | |||
| "params": {} | |||
| }, | |||
| "downstream": ["message:1_0","message:1_1","message:1_2"], | |||
| "upstream": ["categorize:0"] | |||
| }, | |||
| "message:0": { | |||
| "obj": { | |||
| "component_name": "Message", | |||
| "params": { | |||
| "messages": [ | |||
| "Message 0_0!!!!!!!" | |||
| ] | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": ["concentrator:0"] | |||
| }, | |||
| "message:1_0": { | |||
| "obj": { | |||
| "component_name": "Message", | |||
| "params": { | |||
| "messages": [ | |||
| "Message 1_0!!!!!!!" | |||
| ] | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": ["concentrator:1"] | |||
| }, | |||
| "message:1_1": { | |||
| "obj": { | |||
| "component_name": "Message", | |||
| "params": { | |||
| "messages": [ | |||
| "Message 1_1!!!!!!!" | |||
| ] | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": ["concentrator:1"] | |||
| }, | |||
| "message:1_2": { | |||
| "obj": { | |||
| "component_name": "Message", | |||
| "params": { | |||
| "messages": [ | |||
| "Message 1_2!!!!!!!" | |||
| ] | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": ["concentrator:1"] | |||
| } | |||
| }, | |||
| "history": [], | |||
| "messages": [], | |||
| "path": [], | |||
| "reference": [], | |||
| "answer": [] | |||
| { | |||
| "components": { | |||
| "begin": { | |||
| "obj":{ | |||
| "component_name": "Begin", | |||
| "params": { | |||
| "prologue": "Hi there!" | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": [] | |||
| }, | |||
| "answer:0": { | |||
| "obj": { | |||
| "component_name": "Answer", | |||
| "params": {} | |||
| }, | |||
| "downstream": ["categorize:0"], | |||
| "upstream": ["begin"] | |||
| }, | |||
| "categorize:0": { | |||
| "obj": { | |||
| "component_name": "Categorize", | |||
| "params": { | |||
| "llm_id": "deepseek-chat", | |||
| "category_description": { | |||
| "product_related": { | |||
| "description": "The question is about the product usage, appearance and how it works.", | |||
| "examples": "Why it always beaming?\nHow to install it onto the wall?\nIt leaks, what to do?", | |||
| "to": "concentrator:0" | |||
| }, | |||
| "others": { | |||
| "description": "The question is not about the product usage, appearance and how it works.", | |||
| "examples": "How are you doing?\nWhat is your name?\nAre you a robot?\nWhat's the weather?\nWill it rain?", | |||
| "to": "concentrator:1" | |||
| } | |||
| } | |||
| } | |||
| }, | |||
| "downstream": ["concentrator:0","concentrator:1"], | |||
| "upstream": ["answer:0"] | |||
| }, | |||
| "concentrator:0": { | |||
| "obj": { | |||
| "component_name": "Concentrator", | |||
| "params": {} | |||
| }, | |||
| "downstream": ["message:0"], | |||
| "upstream": ["categorize:0"] | |||
| }, | |||
| "concentrator:1": { | |||
| "obj": { | |||
| "component_name": "Concentrator", | |||
| "params": {} | |||
| }, | |||
| "downstream": ["message:1_0","message:1_1","message:1_2"], | |||
| "upstream": ["categorize:0"] | |||
| }, | |||
| "message:0": { | |||
| "obj": { | |||
| "component_name": "Message", | |||
| "params": { | |||
| "messages": [ | |||
| "Message 0_0!!!!!!!" | |||
| ] | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": ["concentrator:0"] | |||
| }, | |||
| "message:1_0": { | |||
| "obj": { | |||
| "component_name": "Message", | |||
| "params": { | |||
| "messages": [ | |||
| "Message 1_0!!!!!!!" | |||
| ] | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": ["concentrator:1"] | |||
| }, | |||
| "message:1_1": { | |||
| "obj": { | |||
| "component_name": "Message", | |||
| "params": { | |||
| "messages": [ | |||
| "Message 1_1!!!!!!!" | |||
| ] | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": ["concentrator:1"] | |||
| }, | |||
| "message:1_2": { | |||
| "obj": { | |||
| "component_name": "Message", | |||
| "params": { | |||
| "messages": [ | |||
| "Message 1_2!!!!!!!" | |||
| ] | |||
| } | |||
| }, | |||
| "downstream": ["answer:0"], | |||
| "upstream": ["concentrator:1"] | |||
| } | |||
| }, | |||
| "history": [], | |||
| "messages": [], | |||
| "path": [], | |||
| "reference": [], | |||
| "answer": [] | |||
| } | |||
| @@ -1,60 +1,60 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| import time | |||
| import traceback | |||
| from api.db.db_models import close_connection | |||
| from api.db.services.task_service import TaskService | |||
| from rag.utils.storage_factory import STORAGE_IMPL | |||
| from rag.utils.redis_conn import REDIS_CONN | |||
| def collect(): | |||
| doc_locations = TaskService.get_ongoing_doc_name() | |||
| logging.debug(doc_locations) | |||
| if len(doc_locations) == 0: | |||
| time.sleep(1) | |||
| return | |||
| return doc_locations | |||
| def main(): | |||
| locations = collect() | |||
| if not locations: | |||
| return | |||
| logging.info(f"TASKS: {len(locations)}") | |||
| for kb_id, loc in locations: | |||
| try: | |||
| if REDIS_CONN.is_alive(): | |||
| try: | |||
| key = "{}/{}".format(kb_id, loc) | |||
| if REDIS_CONN.exist(key): | |||
| continue | |||
| file_bin = STORAGE_IMPL.get(kb_id, loc) | |||
| REDIS_CONN.transaction(key, file_bin, 12 * 60) | |||
| logging.info("CACHE: {}".format(loc)) | |||
| except Exception as e: | |||
| traceback.print_stack(e) | |||
| except Exception as e: | |||
| traceback.print_stack(e) | |||
| if __name__ == "__main__": | |||
| while True: | |||
| main() | |||
| close_connection() | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| import time | |||
| import traceback | |||
| from api.db.db_models import close_connection | |||
| from api.db.services.task_service import TaskService | |||
| from rag.utils.storage_factory import STORAGE_IMPL | |||
| from rag.utils.redis_conn import REDIS_CONN | |||
| def collect(): | |||
| doc_locations = TaskService.get_ongoing_doc_name() | |||
| logging.debug(doc_locations) | |||
| if len(doc_locations) == 0: | |||
| time.sleep(1) | |||
| return | |||
| return doc_locations | |||
| def main(): | |||
| locations = collect() | |||
| if not locations: | |||
| return | |||
| logging.info(f"TASKS: {len(locations)}") | |||
| for kb_id, loc in locations: | |||
| try: | |||
| if REDIS_CONN.is_alive(): | |||
| try: | |||
| key = "{}/{}".format(kb_id, loc) | |||
| if REDIS_CONN.exist(key): | |||
| continue | |||
| file_bin = STORAGE_IMPL.get(kb_id, loc) | |||
| REDIS_CONN.transaction(key, file_bin, 12 * 60) | |||
| logging.info("CACHE: {}".format(loc)) | |||
| except Exception as e: | |||
| traceback.print_stack(e) | |||
| except Exception as e: | |||
| traceback.print_stack(e) | |||
| if __name__ == "__main__": | |||
| while True: | |||
| main() | |||
| close_connection() | |||
| time.sleep(1) | |||
| @@ -1,38 +1,38 @@ | |||
| class Base(object): | |||
| def __init__(self, rag, res_dict): | |||
| self.rag = rag | |||
| for k, v in res_dict.items(): | |||
| if isinstance(v, dict): | |||
| self.__dict__[k] = Base(rag, v) | |||
| else: | |||
| self.__dict__[k] = v | |||
| def to_json(self): | |||
| pr = {} | |||
| for name in dir(self): | |||
| value = getattr(self, name) | |||
| if not name.startswith('__') and not callable(value) and name != "rag": | |||
| if isinstance(value, Base): | |||
| pr[name] = value.to_json() | |||
| else: | |||
| pr[name] = value | |||
| return pr | |||
| def post(self, path, json=None, stream=False, files=None): | |||
| res = self.rag.post(path, json, stream=stream,files=files) | |||
| return res | |||
| def get(self, path, params=None): | |||
| res = self.rag.get(path, params) | |||
| return res | |||
| def rm(self, path, json): | |||
| res = self.rag.delete(path, json) | |||
| return res | |||
| def put(self,path, json): | |||
| res = self.rag.put(path,json) | |||
| return res | |||
| def __str__(self): | |||
| return str(self.to_json()) | |||
| class Base(object): | |||
| def __init__(self, rag, res_dict): | |||
| self.rag = rag | |||
| for k, v in res_dict.items(): | |||
| if isinstance(v, dict): | |||
| self.__dict__[k] = Base(rag, v) | |||
| else: | |||
| self.__dict__[k] = v | |||
| def to_json(self): | |||
| pr = {} | |||
| for name in dir(self): | |||
| value = getattr(self, name) | |||
| if not name.startswith('__') and not callable(value) and name != "rag": | |||
| if isinstance(value, Base): | |||
| pr[name] = value.to_json() | |||
| else: | |||
| pr[name] = value | |||
| return pr | |||
| def post(self, path, json=None, stream=False, files=None): | |||
| res = self.rag.post(path, json, stream=stream,files=files) | |||
| return res | |||
| def get(self, path, params=None): | |||
| res = self.rag.get(path, params) | |||
| return res | |||
| def rm(self, path, json): | |||
| res = self.rag.delete(path, json) | |||
| return res | |||
| def put(self,path, json): | |||
| res = self.rag.put(path,json) | |||
| return res | |||
| def __str__(self): | |||
| return str(self.to_json()) | |||
| @@ -1,73 +1,73 @@ | |||
| import json | |||
| from .base import Base | |||
| class Session(Base): | |||
| def __init__(self, rag, res_dict): | |||
| self.id = None | |||
| self.name = "New session" | |||
| self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}] | |||
| for key,value in res_dict.items(): | |||
| if key =="chat_id" and value is not None: | |||
| self.chat_id = None | |||
| self.__session_type = "chat" | |||
| if key == "agent_id" and value is not None: | |||
| self.agent_id = None | |||
| self.__session_type = "agent" | |||
| super().__init__(rag, res_dict) | |||
| def ask(self, question="",stream=True,**kwargs): | |||
| if self.__session_type == "agent": | |||
| res=self._ask_agent(question,stream) | |||
| elif self.__session_type == "chat": | |||
| res=self._ask_chat(question,stream,**kwargs) | |||
| for line in res.iter_lines(): | |||
| line = line.decode("utf-8") | |||
| if line.startswith("{"): | |||
| json_data = json.loads(line) | |||
| raise Exception(json_data["message"]) | |||
| if not line.startswith("data:"): | |||
| continue | |||
| json_data = json.loads(line[5:]) | |||
| if json_data["data"] is True or json_data["data"].get("running_status"): | |||
| continue | |||
| answer = json_data["data"]["answer"] | |||
| reference = json_data["data"].get("reference", {}) | |||
| temp_dict = { | |||
| "content": answer, | |||
| "role": "assistant" | |||
| } | |||
| if reference and "chunks" in reference: | |||
| chunks = reference["chunks"] | |||
| temp_dict["reference"] = chunks | |||
| message = Message(self.rag, temp_dict) | |||
| yield message | |||
| def _ask_chat(self, question: str, stream: bool,**kwargs): | |||
| json_data={"question": question, "stream": True,"session_id":self.id} | |||
| json_data.update(kwargs) | |||
| res = self.post(f"/chats/{self.chat_id}/completions", | |||
| json_data, stream=stream) | |||
| return res | |||
| def _ask_agent(self,question:str,stream:bool): | |||
| res = self.post(f"/agents/{self.agent_id}/completions", | |||
| {"question": question, "stream": True,"session_id":self.id}, stream=stream) | |||
| return res | |||
| def update(self,update_message): | |||
| res = self.put(f"/chats/{self.chat_id}/sessions/{self.id}", | |||
| update_message) | |||
| res = res.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| class Message(Base): | |||
| def __init__(self, rag, res_dict): | |||
| self.content = "Hi! I am your assistant,can I help you?" | |||
| self.reference = None | |||
| self.role = "assistant" | |||
| self.prompt = None | |||
| self.id = None | |||
| super().__init__(rag, res_dict) | |||
| import json | |||
| from .base import Base | |||
| class Session(Base): | |||
| def __init__(self, rag, res_dict): | |||
| self.id = None | |||
| self.name = "New session" | |||
| self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}] | |||
| for key,value in res_dict.items(): | |||
| if key =="chat_id" and value is not None: | |||
| self.chat_id = None | |||
| self.__session_type = "chat" | |||
| if key == "agent_id" and value is not None: | |||
| self.agent_id = None | |||
| self.__session_type = "agent" | |||
| super().__init__(rag, res_dict) | |||
| def ask(self, question="",stream=True,**kwargs): | |||
| if self.__session_type == "agent": | |||
| res=self._ask_agent(question,stream) | |||
| elif self.__session_type == "chat": | |||
| res=self._ask_chat(question,stream,**kwargs) | |||
| for line in res.iter_lines(): | |||
| line = line.decode("utf-8") | |||
| if line.startswith("{"): | |||
| json_data = json.loads(line) | |||
| raise Exception(json_data["message"]) | |||
| if not line.startswith("data:"): | |||
| continue | |||
| json_data = json.loads(line[5:]) | |||
| if json_data["data"] is True or json_data["data"].get("running_status"): | |||
| continue | |||
| answer = json_data["data"]["answer"] | |||
| reference = json_data["data"].get("reference", {}) | |||
| temp_dict = { | |||
| "content": answer, | |||
| "role": "assistant" | |||
| } | |||
| if reference and "chunks" in reference: | |||
| chunks = reference["chunks"] | |||
| temp_dict["reference"] = chunks | |||
| message = Message(self.rag, temp_dict) | |||
| yield message | |||
| def _ask_chat(self, question: str, stream: bool,**kwargs): | |||
| json_data={"question": question, "stream": True,"session_id":self.id} | |||
| json_data.update(kwargs) | |||
| res = self.post(f"/chats/{self.chat_id}/completions", | |||
| json_data, stream=stream) | |||
| return res | |||
| def _ask_agent(self,question:str,stream:bool): | |||
| res = self.post(f"/agents/{self.agent_id}/completions", | |||
| {"question": question, "stream": True,"session_id":self.id}, stream=stream) | |||
| return res | |||
| def update(self,update_message): | |||
| res = self.put(f"/chats/{self.chat_id}/sessions/{self.id}", | |||
| update_message) | |||
| res = res.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| class Message(Base): | |||
| def __init__(self, rag, res_dict): | |||
| self.content = "Hi! I am your assistant,can I help you?" | |||
| self.reference = None | |||
| self.role = "assistant" | |||
| self.prompt = None | |||
| self.id = None | |||
| super().__init__(rag, res_dict) | |||