| @@ -32,6 +32,15 @@ provider_credential_schema: | |||
| zh_Hans: 在此输入您的 API Key | |||
| en_US: Enter your API Key | |||
| show_on: [ ] | |||
| - variable: base_url | |||
| label: | |||
| zh_Hans: API Base | |||
| en_US: API Base | |||
| type: text-input | |||
| required: false | |||
| placeholder: | |||
| zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1 | |||
| en_US: Enter your API Base, e.g. https://api.cohere.ai/v1 | |||
| model_credential_schema: | |||
| model: | |||
| label: | |||
| @@ -70,3 +79,12 @@ model_credential_schema: | |||
| placeholder: | |||
| zh_Hans: 在此输入您的 API Key | |||
| en_US: Enter your API Key | |||
| - variable: base_url | |||
| label: | |||
| zh_Hans: API Base | |||
| en_US: API Base | |||
| type: text-input | |||
| required: false | |||
| placeholder: | |||
| zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1 | |||
| en_US: Enter your API Base, e.g. https://api.cohere.ai/v1 | |||
| @@ -173,7 +173,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): | |||
| :return: full response or stream response chunk generator result | |||
| """ | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) | |||
| if stop: | |||
| model_parameters['end_sequences'] = stop | |||
| @@ -233,7 +233,8 @@ class CohereLargeLanguageModel(LargeLanguageModel): | |||
| return response | |||
| def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse], | |||
| def _handle_generate_stream_response(self, model: str, credentials: dict, | |||
| response: Iterator[GenerateStreamedResponse], | |||
| prompt_messages: list[PromptMessage]) -> Generator: | |||
| """ | |||
| Handle llm stream response | |||
| @@ -317,7 +318,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): | |||
| :return: full response or stream response chunk generator result | |||
| """ | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) | |||
| if stop: | |||
| model_parameters['stop_sequences'] = stop | |||
| @@ -636,7 +637,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): | |||
| :return: number of tokens | |||
| """ | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) | |||
| response = client.tokenize( | |||
| text=text, | |||
| @@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel): | |||
| ) | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) | |||
| response = client.rerank( | |||
| query=query, | |||
| documents=docs, | |||
| @@ -141,7 +141,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): | |||
| return [] | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) | |||
| response = client.tokenize( | |||
| text=text, | |||
| @@ -180,7 +180,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): | |||
| :return: embeddings and used tokens | |||
| """ | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) | |||
| # call embedding model | |||
| response = client.embed( | |||